# Copyright (c) 2009, 2022, Oracle and/or its affiliates. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License, version 2.0, as # published by the Free Software Foundation. # # This program is also distributed with certain software (including # but not limited to OpenSSL) that is licensed under separate terms, # as designated in a particular file or component or in included license # documentation. The authors of MySQL hereby grant you an # additional permission to link the program and your derivative works # with the separately licensed software that they have included with # MySQL. # # Without limiting anything contained in the foregoing, this file, # which is part of MySQL Connector/Python, is also subject to the # Universal FOSS Exception, version 1.0, a copy of which can be found at # http://oss.oracle.com/licenses/universal-foss-exception. # # This program is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the GNU General Public License, version 2.0, for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA # mypy: disable-error-code="assignment" """Implements the MySQL Client/Server protocol.""" import datetime import struct from decimal import Decimal, DecimalException from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from . import utils from .authentication import get_auth_plugin from .constants import ( PARAMETER_COUNT_AVAILABLE, ClientFlag, FieldFlag, FieldType, ServerCmd, ) from .errors import DatabaseError, InterfaceError, ProgrammingError, get_exception from .types import ( ConnAttrsType, DescriptionType, EofPacketType, HandShakeType, OkPacketType, ParseValueFromBinaryResultPacketTypes, QueryAttrType, SocketType, StatsPacketType, StrOrBytes, SupportedMysqlBinaryProtocolTypes, ) PROTOCOL_VERSION: int = 10 class MySQLProtocol: """Implements MySQL client/server protocol Create and parses MySQL packets. """ @staticmethod def _connect_with_db(client_flags: int, database: Optional[str]) -> bytes: """Prepare database string for handshake response""" if client_flags & ClientFlag.CONNECT_WITH_DB and database: return database.encode("utf8") + b"\x00" return b"\x00" @staticmethod def _auth_response( client_flags: int, username: Optional[StrOrBytes], password: Optional[str], database: Optional[str], auth_plugin_class: str, auth_plugin: str, auth_data: Optional[bytes], ssl_enabled: bool, ) -> bytes: """Prepare the authentication response""" if not password: return b"\x00" try: auth = get_auth_plugin(auth_plugin, auth_plugin_class)( auth_data, username=username, password=password, database=database, ssl_enabled=ssl_enabled, ) plugin_auth_response = auth.auth_response() # # We could receive a NULL response, an Interface error is called when so # if plugin_auth_response is None: # raise InterfaceError except (TypeError, InterfaceError) as err: raise InterfaceError(f"Failed authentication: {err}") from err if client_flags & ClientFlag.SECURE_CONNECTION: resplen = len(plugin_auth_response) auth_response = struct.pack(" bytes: """Make a MySQL Authentication packet""" if handshake is None: handshake = {} try: auth_data: Optional[bytes] = handshake["auth_data"] auth_plugin = auth_plugin or handshake["auth_plugin"] except (TypeError, KeyError) as err: raise ProgrammingError( f"Handshake misses authentication info ({err})" ) from None if not username: username = b"" try: username_bytes = username.encode("utf8") # type: ignore[union-attr] except AttributeError: # Username is already bytes username_bytes = username filler = "x" * 22 username_len = len(username_bytes) packet = struct.pack( f" bytes: """Encode the connection attributes""" for attr_name in conn_attrs: if conn_attrs[attr_name] is None: conn_attrs[attr_name] = "" conn_attrs_len = ( sum(len(x) + len(conn_attrs[x]) for x in conn_attrs) + len(conn_attrs.keys()) + len(conn_attrs.values()) ) conn_attrs_packet = struct.pack(" bytearray: """Make a SSL authentication packet""" return ( utils.int4store(client_flags) + utils.int4store(max_allowed_packet) + utils.int2store(charset) + b"\x00" * 22 ) @staticmethod def make_command(command: int, argument: Optional[bytes] = None) -> bytearray: """Make a MySQL packet containing a command""" data = utils.int1store(command) if argument is not None: data += argument return data @staticmethod def make_stmt_fetch(statement_id: int, rows: int = 1) -> bytearray: """Make a MySQL packet with Fetch Statement command""" return utils.int4store(statement_id) + utils.int4store(rows) def make_change_user( self, handshake: HandShakeType, username: Optional[StrOrBytes] = None, password: Optional[str] = None, database: Optional[str] = None, charset: int = 45, client_flags: int = 0, ssl_enabled: bool = False, auth_plugin: Optional[str] = None, conn_attrs: Optional[ConnAttrsType] = None, auth_plugin_class: Optional[str] = None, ) -> bytes: """Make a MySQL packet with the Change User command""" try: auth_data: Optional[bytes] = handshake["auth_data"] auth_plugin = auth_plugin or handshake["auth_plugin"] except (TypeError, KeyError) as err: raise ProgrammingError( f"Handshake misses authentication info ({err})" ) from None if not username: username = b"" try: username_bytes = username.encode("utf8") # type: ignore[union-attr] except AttributeError: # Username is already bytes username_bytes = username username_len = len(username_bytes) packet = struct.pack( f" HandShakeType: """Parse a MySQL Handshake-packet""" res = {} res["protocol"] = struct.unpack(" Tuple[bytes, str]: """Parse a MySQL AuthNextFactor packet.""" packet, status = utils.read_int(packet, 1) if status != 2: raise InterfaceError("Failed parsing AuthNextFactor packet (invalid)") packet, auth_plugin = utils.read_string(packet, end=b"\x00") return packet, auth_plugin.decode("utf-8") @staticmethod def parse_ok(packet: bytes) -> OkPacketType: """Parse a MySQL OK-packet""" if not packet[4] == 0: raise InterfaceError("Failed parsing OK packet (invalid).") ok_packet = {} try: ok_packet["field_count"] = struct.unpack(" Optional[int]: """Parse a MySQL packet with the number of columns in result set""" try: count = utils.read_lc_int(packet[4:])[1] return count except (struct.error, ValueError) as err: raise InterfaceError("Failed parsing column count") from err @staticmethod def parse_column(packet: bytes, encoding: str = "utf-8") -> DescriptionType: """Parse a MySQL column-packet""" packet, _ = utils.read_lc_string(packet[4:]) # catalog packet, _ = utils.read_lc_string(packet) # db packet, _ = utils.read_lc_string(packet) # table packet, _ = utils.read_lc_string(packet) # org_table packet, name = utils.read_lc_string(packet) # name packet, _ = utils.read_lc_string(packet) # org_name try: ( charset, _, column_type, flags, _, ) = struct.unpack(" EofPacketType: """Parse a MySQL EOF-packet""" if packet[4] == 0: # EOF packet deprecation return self.parse_ok(packet) err_msg = "Failed parsing EOF packet." res = {} try: unpacked = struct.unpack(" StatsPacketType: """Parse the statistics packet""" errmsg = "Failed getting COM_STATISTICS information" res: Dict[str, Union[int, Decimal]] = {} # Information is separated by 2 spaces pairs: List[bytes] = [b""] lbl: StrOrBytes = b"" if with_header: pairs = packet[4:].split(b"\x20\x20") else: pairs = packet.split(b"\x20\x20") for pair in pairs: try: lbl, val = [v.strip() for v in pair.split(b":", 2)] except ValueError as err: raise InterfaceError(errmsg) from err # It's either an integer or a decimal lbl = lbl.decode("utf-8") try: res[lbl] = int(val) except (KeyError, ValueError): try: res[lbl] = Decimal(val.decode("utf-8")) except DecimalException as err: raise InterfaceError(f"{errmsg} ({lbl}:{repr(val)})") from err return res def read_text_result( self, sock: SocketType, version: Tuple[int, ...], count: int = 1 ) -> Tuple[List[Tuple[Optional[bytes], ...]], Optional[EofPacketType],]: """Read MySQL text result Reads all or given number of rows from the socket. Returns a tuple with 2 elements: a list with all rows and the EOF packet. """ # Keep unused 'version' for API backward compatibility _ = version rows = [] eof = None rowdata = None i = 0 while True: if eof or i == count: break packet = sock.recv() if packet.startswith(b"\xff\xff\xff"): datas = [packet[4:]] packet = sock.recv() while packet.startswith(b"\xff\xff\xff"): datas.append(packet[4:]) packet = sock.recv() datas.append(packet[4:]) rowdata = utils.read_lc_string_list(bytearray(b"").join(datas)) elif packet[4] == 254 and packet[0] < 7: eof = self.parse_eof(packet) rowdata = None else: eof = None rowdata = utils.read_lc_string_list(packet[4:]) if eof is None and rowdata is not None: rows.append(rowdata) elif eof is None and rowdata is None: raise get_exception(packet) i += 1 return rows, eof @staticmethod def _parse_binary_integer( packet: bytes, field: DescriptionType ) -> Tuple[bytes, int]: """Parse an integer from a binary packet""" if field[1] == FieldType.TINY: format_ = " Tuple[bytes, float]: """Parse a float/double from a binary packet""" if field[1] == FieldType.DOUBLE: length = 8 format_ = " Tuple[bytes, Decimal]: """Parse a New Decimal from a binary packet""" (packet, value) = utils.read_lc_string(packet) return (packet, Decimal(value.decode(charset))) @staticmethod def _parse_binary_timestamp( packet: bytes, field_type: int, ) -> Tuple[bytes, Optional[Union[datetime.date, datetime.datetime]]]: """Parse a timestamp from a binary packet""" length = packet[0] value = None if length == 4: year = struct.unpack("= 7: mcs = 0 if length == 11: mcs = struct.unpack(" Tuple[bytes, datetime.timedelta]: """Parse a time value from a binary packet""" length = packet[0] if not length: return (packet[1:], datetime.timedelta()) data = packet[1 : length + 1] mcs = 0 if length > 8: mcs = struct.unpack(" Tuple[ParseValueFromBinaryResultPacketTypes, ...]: """Parse values from a binary result packet""" null_bitmap_length = (len(fields) + 7 + 2) // 8 null_bitmap = [int(i) for i in packet[0:null_bitmap_length]] packet = packet[null_bitmap_length:] values: List[Any] = [] value: Any = None for pos, field in enumerate(fields): if null_bitmap[int((pos + 2) / 8)] & (1 << (pos + 2) % 8): values.append(None) continue if field[1] in ( FieldType.TINY, FieldType.SHORT, FieldType.INT24, FieldType.LONG, FieldType.LONGLONG, ): packet, value = self._parse_binary_integer(packet, field) values.append(value) elif field[1] in (FieldType.DOUBLE, FieldType.FLOAT): packet, value = self._parse_binary_float(packet, field) values.append(value) elif field[1] in (FieldType.DECIMAL, FieldType.NEWDECIMAL): packet, value = self._parse_binary_new_decimal(packet, charset) values.append(value) elif field[1] in ( FieldType.DATETIME, FieldType.DATE, FieldType.TIMESTAMP, ): (packet, value) = self._parse_binary_timestamp(packet, field[1]) values.append(value) elif field[1] == FieldType.TIME: (packet, value) = self._parse_binary_time(packet) values.append(value) else: (packet, value) = utils.read_lc_string(packet) try: values.append(value.decode(charset)) except UnicodeDecodeError: values.append(value) return tuple(values) def read_binary_result( self, sock: SocketType, columns: List[DescriptionType], count: int = 1, charset: str = "utf-8", ) -> Tuple[ List[Tuple[ParseValueFromBinaryResultPacketTypes, ...]], Optional[EofPacketType], ]: """Read MySQL binary protocol result Reads all or given number of binary resultset rows from the socket. """ rows = [] eof = None values = None i = 0 while True: if eof is not None: break if i == count: break packet = sock.recv() if packet[4] == 254: eof = self.parse_eof(packet) values = None elif packet[4] == 0: eof = None values = self._parse_binary_values(columns, packet[5:], charset) if eof is None and values is not None: rows.append(values) elif eof is None and values is None: raise get_exception(packet) i += 1 return (rows, eof) @staticmethod def parse_binary_prepare_ok(packet: bytes) -> Dict[str, int]: """Parse a MySQL Binary Protocol OK packet""" if not packet[4] == 0: raise InterfaceError("Failed parsing Binary OK packet") ok_pkt = {} try: packet, ok_pkt["statement_id"] = utils.read_int(packet[5:], 4) packet, ok_pkt["num_columns"] = utils.read_int(packet, 2) packet, ok_pkt["num_params"] = utils.read_int(packet, 2) packet = packet[1:] # Filler 1 * \x00 packet, ok_pkt["warning_count"] = utils.read_int(packet, 2) except ValueError as err: raise InterfaceError("Failed parsing Binary OK packet") from err return ok_pkt @staticmethod def prepare_binary_integer(value: int) -> Tuple[bytes, int, int]: """Prepare an integer for the MySQL binary protocol""" field_type = None flags = 0 if value < 0: if value >= -128: format_ = "= -32768: format_ = "= -2147483648: format_ = " Tuple[bytearray, int]: """Prepare a timestamp object for the MySQL binary protocol This method prepares a timestamp of type datetime.datetime or datetime.date for sending over the MySQL binary protocol. A tuple is returned with the prepared value and field type as elements. Raises ValueError when the argument value is of invalid type. Returns a tuple. """ if isinstance(value, datetime.datetime): field_type = FieldType.DATETIME elif isinstance(value, datetime.date): field_type = FieldType.DATE else: raise ValueError("Argument must a datetime.datetime or datetime.date") packed = ( utils.int2store(value.year) + utils.int1store(value.month) + utils.int1store(value.day) ) if isinstance(value, datetime.datetime): packed = ( packed + utils.int1store(value.hour) + utils.int1store(value.minute) + utils.int1store(value.second) ) if value.microsecond > 0: packed += utils.int4store(value.microsecond) packed = utils.int1store(len(packed)) + packed return (packed, field_type) @staticmethod def prepare_binary_time( value: Union[datetime.timedelta, datetime.time] ) -> Tuple[bytearray, int]: """Prepare a time object for the MySQL binary protocol This method prepares a time object of type datetime.timedelta or datetime.time for sending over the MySQL binary protocol. A tuple is returned with the prepared value and field type as elements. Raises ValueError when the argument value is of invalid type. Returns a tuple. """ if not isinstance(value, (datetime.timedelta, datetime.time)): raise ValueError("Argument must a datetime.timedelta or datetime.time") field_type = FieldType.TIME negative = 0 mcs = None packed = b"" if isinstance(value, datetime.timedelta): if value.days < 0: negative = 1 (hours, remainder) = divmod(value.seconds, 3600) (mins, secs) = divmod(remainder, 60) packed += ( utils.int4store(abs(value.days)) + utils.int1store(hours) + utils.int1store(mins) + utils.int1store(secs) ) mcs = value.microseconds else: packed += ( utils.int4store(0) + utils.int1store(value.hour) + utils.int1store(value.minute) + utils.int1store(value.second) ) mcs = value.microsecond if mcs: packed += utils.int4store(mcs) packed = utils.int1store(negative) + packed packed = utils.int1store(len(packed)) + packed return (packed, field_type) @staticmethod def prepare_stmt_send_long_data( statement: int, param: int, data: bytes ) -> bytearray: """Prepare long data for prepared statements Returns a string. """ packet: bytearray = utils.int4store(statement) + utils.int2store(param) + data return packet def make_stmt_execute( self, statement_id: int, data: Sequence[SupportedMysqlBinaryProtocolTypes] = (), parameters: Sequence[Any] = (), flags: int = 0, long_data_used: Optional[Dict[int, Tuple[bool]]] = None, charset: str = "utf8", query_attrs: Optional[QueryAttrType] = None, converter_str_fallback: bool = False, ) -> bytearray: """Make a MySQL packet with the Statement Execute command""" iteration_count = 1 null_bitmap = [0] * ((len(data) + 7) // 8) values = [] types = [] packed = b"" data_len = len(data) query_attr_names = [] flags = flags if not query_attrs else flags + PARAMETER_COUNT_AVAILABLE if charset == "utf8mb4": charset = "utf8" if long_data_used is None: long_data_used = {} if query_attrs: data = list(data) for _, attr_val in query_attrs: data.append(attr_val) null_bitmap = [0] * ((len(data) + 7) // 8) if parameters or data: if data_len != len(parameters): raise InterfaceError( "Failed executing prepared statement: data values does not" " match number of parameters" ) for pos, value in enumerate(data): _flags = 0 if value is None: null_bitmap[(pos // 8)] |= 1 << (pos % 8) types.append( utils.int1store(FieldType.NULL) + utils.int1store(_flags) ) continue if pos in long_data_used: if long_data_used[pos][0]: # We suppose binary data field_type = FieldType.BLOB else: # We suppose text data field_type = FieldType.STRING elif isinstance(value, int): ( packed, field_type, _flags, ) = self.prepare_binary_integer(value) values.append(packed) elif isinstance(value, str): value = value.encode(charset) values.append(utils.lc_int(len(value)) + value) field_type = FieldType.VARCHAR elif isinstance(value, bytes): values.append(utils.lc_int(len(value)) + value) field_type = FieldType.BLOB elif isinstance(value, Decimal): values.append( utils.lc_int(len(str(value).encode(charset))) + str(value).encode(charset) ) field_type = FieldType.DECIMAL elif isinstance(value, float): values.append(struct.pack(" data_len: name = query_attrs[pos - data_len][0].encode(charset) query_attr_names.append(utils.lc_int(len(name)) + name) packet = ( utils.int4store(statement_id) + utils.int1store(flags) + utils.int4store(iteration_count) ) # if (num_params > 0 || (CLIENT_QUERY_ATTRIBUTES \ # && (flags & PARAMETER_COUNT_AVAILABLE)) { if query_attrs is not None: parameter_count = data_len + len(query_attrs) else: parameter_count = data_len if parameter_count: # if CLIENT_QUERY_ATTRIBUTES is on if query_attrs is not None: packet += utils.lc_int(parameter_count) packet += b"".join( [struct.pack("B", bit) for bit in null_bitmap] ) + utils.int1store(1) count = 0 for a_type in types: packet += a_type # if CLIENT_QUERY_ATTRIBUTES is on { # string parameter_name Name of the parameter # or empty if not present # } if CLIENT_QUERY_ATTRIBUTES is on if query_attrs is not None: if count + 1 > data_len: packet += query_attr_names[count - data_len] else: packet += b"\x00" count += 1 for a_value in values: packet += a_value return packet @staticmethod def parse_auth_switch_request(packet: bytes) -> Tuple[str, bytes]: """Parse a MySQL AuthSwitchRequest-packet""" if not packet[4] == 254: raise InterfaceError("Failed parsing AuthSwitchRequest packet") packet, plugin_name = utils.read_string(packet[5:], end=b"\x00") if packet and packet[-1] == 0: packet = packet[:-1] return plugin_name.decode("utf8"), packet @staticmethod def parse_auth_more_data(packet: bytes) -> bytes: """Parse a MySQL AuthMoreData-packet""" if not packet[4] == 1: raise InterfaceError("Failed parsing AuthMoreData packet") return packet[5:]