# Copyright (c) 2012, 2023, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

# mypy: disable-error-code="attr-defined"

"""Module implementing low-level socket communication with MySQL servers.
"""

import os
import socket
import struct
import warnings
import zlib

from collections import deque

try:
    import ssl

    TLS_VERSIONS = {
        "TLSv1": ssl.PROTOCOL_TLSv1,
        "TLSv1.1": ssl.PROTOCOL_TLSv1_1,
        "TLSv1.2": ssl.PROTOCOL_TLSv1_2,
    }
    # TLSv1.3 included in PROTOCOL_TLS, but PROTOCOL_TLS is not included on 3.4
    TLS_VERSIONS["TLSv1.3"] = (
        ssl.PROTOCOL_TLS
        if hasattr(ssl, "PROTOCOL_TLS")
        else ssl.PROTOCOL_SSLv23  # Alias of PROTOCOL_TLS
    )
    TLS_V1_3_SUPPORTED = hasattr(ssl, "HAS_TLSv1_3") and ssl.HAS_TLSv1_3
except ImportError:
    # If import fails, we don't have SSL support.
    TLS_V1_3_SUPPORTED = False

from typing import Any, Deque, List, Optional, Tuple, Union

from .constants import MAX_PACKET_LENGTH
from .errors import InterfaceError, NotSupportedError, OperationalError
from .types import StrOrBytesPath
from .utils import init_bytearray


def _strioerror(err: IOError) -> str:
    """Reformat the IOError error message.

    Returns ``"<errno> <strerror>"`` when the exception carries an errno,
    otherwise falls back to ``str(err)``.
    """
    if not err.errno:
        return str(err)
    return f"{err.errno} {err.strerror}"


def _prepare_packets(buf: bytes, pktnr: int) -> List[bytes]:
    """Prepare one or more packets for sending to the MySQL server.

    The payload is split into chunks of ``MAX_PACKET_LENGTH`` bytes, each
    prefixed with a 3-byte little-endian length and a 1-byte sequence id.

    Args:
        buf: The payload to send.
        pktnr: Sequence id to use for the first packet.

    Returns:
        The list of framed packets, in sending order.
    """
    pkts = []
    pllen = len(buf)
    # Assumes MAX_PACKET_LENGTH == 0xFFFFFF, matching the hard-coded
    # ff ff ff length header written for full-size chunks below.
    maxpktlen = MAX_PACKET_LENGTH
    # ">=" (not ">"): a payload that is an exact multiple of the maximum
    # must be followed by a shorter (possibly empty) packet, otherwise the
    # receiver keeps waiting for the continuation.
    while pllen >= maxpktlen:
        pkts.append(b"\xff\xff\xff" + struct.pack("<B", pktnr) + buf[:maxpktlen])
        buf = buf[maxpktlen:]
        pllen = len(buf)
        # The sequence id is a single byte; wrap around like
        # BaseMySQLSocket.next_packet_number() does.
        pktnr = (pktnr + 1) % 256
    pkts.append(struct.pack("<I", pllen)[0:3] + struct.pack("<B", pktnr) + buf)
    return pkts


class BaseMySQLSocket:
    """Base class for MySQL socket communication

    This class should not be used directly but overloaded, changing at
    least the open_connection()-method.

    Examples of subclasses are
      mysql.connector.network.MySQLTCPSocket
      mysql.connector.network.MySQLUnixSocket
    """

    def __init__(self) -> None:
        # holds the socket connection
        self.sock: Optional[socket.socket] = None
        self._connection_timeout: Optional[int] = None
        # Both counters start at -1 so the first next_*() call yields 0
        self._packet_number: int = -1
        self._compressed_packet_number: int = -1
        # Packets already received (from a compressed bunch) but not yet
        # handed out by recv_compressed()
        self._packet_queue: Deque[bytearray] = deque()
        self.server_host: Optional[str] = None
        self.recvsize: int = 8192

    def next_packet_number(self) -> int:
        """Increment the packet number (wrapping after 255) and return it"""
        self._packet_number = self._packet_number + 1
        if self._packet_number > 255:
            self._packet_number = 0
        return self._packet_number

    def next_compressed_packet_number(self) -> int:
        """Increment the compressed packet number (wrapping after 255)
        and return it"""
        self._compressed_packet_number = self._compressed_packet_number + 1
        if self._compressed_packet_number > 255:
            self._compressed_packet_number = 0
        return self._compressed_packet_number

    def open_connection(self) -> Any:
        """Open the socket; must be implemented by subclasses"""
        raise NotImplementedError

    def get_address(self) -> Any:
        """Get the location of the socket; must be implemented by subclasses"""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shut down the socket before closing it

        Best effort: a missing socket/queue (AttributeError) or a socket
        error (OSError) is silently ignored.
        """
        try:
            self.sock.shutdown(socket.SHUT_RDWR)
            self.sock.close()
            del self._packet_queue
        except (AttributeError, OSError):
            pass

    def close_connection(self) -> None:
        """Close the socket

        Best effort: a missing socket/queue (AttributeError) or a socket
        error (OSError) is silently ignored.
        """
        try:
            self.sock.close()
            del self._packet_queue
        except (AttributeError, OSError):
            pass

    def __del__(self) -> None:
        self.shutdown()

    def send_plain(
        self,
        buf: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
    ) -> None:
        """Send packets to the MySQL server

        Args:
            buf: Payload to send.
            packet_number: Sequence id to use; when None the internal
                counter is advanced instead.
            compressed_packet_number: Ignored; kept for API compatibility.

        Raises:
            OperationalError: errno 2055 when the socket reports an I/O
                error, errno 2006 when there is no usable socket.
        """
        # Keep 'compressed_packet_number' for API backward compatibility
        _ = compressed_packet_number
        if packet_number is None:
            self.next_packet_number()
        else:
            self._packet_number = packet_number
        packets = _prepare_packets(buf, self._packet_number)
        for packet in packets:
            try:
                self.sock.sendall(packet)
            except IOError as err:
                raise OperationalError(
                    errno=2055, values=(self.get_address(), _strioerror(err))
                ) from err
            except AttributeError as err:
                # self.sock is None (or otherwise lacks sendall)
                raise OperationalError(errno=2006) from err

    send = send_plain

    def send_compressed(
        self,
        buf: bytes,
        packet_number: Optional[int] = None,
        compressed_packet_number: Optional[int] = None,
    ) -> None:
        """Send compressed packets to the MySQL server

        Every compressed packet carries a 7-byte header: 3 bytes compressed
        length, 1 byte compressed sequence id and 3 bytes uncompressed
        length (0 meaning the body is sent uncompressed).

        Args:
            buf: Payload to send.
            packet_number: Sequence id to use; when None the internal
                counter is advanced instead.
            compressed_packet_number: Compressed sequence id to use; when
                None the internal counter is advanced instead.

        Raises:
            OperationalError: errno 2055 when the socket reports an I/O
                error, errno 2006 when there is no usable socket.
        """
        if packet_number is None:
            self.next_packet_number()
        else:
            self._packet_number = packet_number
        if compressed_packet_number is None:
            self.next_compressed_packet_number()
        else:
            self._compressed_packet_number = compressed_packet_number

        pktnr = self._packet_number
        pllen = len(buf)
        zpkts = []
        maxpktlen = MAX_PACKET_LENGTH
        if pllen > maxpktlen:
            # Frame the payload as regular packets first, then compress the
            # resulting byte stream in pieces.
            pkts = _prepare_packets(buf, pktnr)
            tmpbuf = b"".join(pkts)
            del pkts
            # First piece: 16384 bytes (0x004000, the uncompressed length
            # written into the header below)
            zbuf = zlib.compress(tmpbuf[:16384])
            header = (
                struct.pack("<I", len(zbuf))[0:3]
                + struct.pack("<B", self._compressed_packet_number)
                + b"\x00\x40\x00"
            )
            zpkts.append(header + zbuf)
            tmpbuf = tmpbuf[16384:]
            pllen = len(tmpbuf)
            self.next_compressed_packet_number()
            # Full-size pieces
            while pllen > maxpktlen:
                zbuf = zlib.compress(tmpbuf[:maxpktlen])
                header = (
                    struct.pack("<I", len(zbuf))[0:3]
                    + struct.pack("<B", self._compressed_packet_number)
                    + b"\xff\xff\xff"
                )
                zpkts.append(header + zbuf)
                tmpbuf = tmpbuf[maxpktlen:]
                pllen = len(tmpbuf)
                self.next_compressed_packet_number()
            # Remainder
            if tmpbuf:
                zbuf = zlib.compress(tmpbuf)
                header = (
                    struct.pack("<I", len(zbuf))[0:3]
                    + struct.pack("<B", self._compressed_packet_number)
                    + struct.pack("<I", pllen)[0:3]
                )
                zpkts.append(header + zbuf)
            del tmpbuf
        else:
            pkt = struct.pack("<I", pllen)[0:3] + struct.pack("<B", pktnr) + buf
            pllen = len(pkt)
            if pllen > 50:
                zbuf = zlib.compress(pkt)
                zpkts.append(
                    struct.pack("<I", len(zbuf))[0:3]
                    + struct.pack("<B", self._compressed_packet_number)
                    + struct.pack("<I", pllen)[0:3]
                    + zbuf
                )
            else:
                # Packets of at most 50 bytes are sent uncompressed; this
                # is signalled by an uncompressed length of 0.
                header = (
                    struct.pack("<I", pllen)[0:3]
                    + struct.pack("<B", self._compressed_packet_number)
                    + struct.pack("<I", 0)[0:3]
                )
                zpkts.append(header + pkt)

        for zip_packet in zpkts:
            try:
                self.sock.sendall(zip_packet)
            except IOError as err:
                raise OperationalError(
                    errno=2055, values=(self.get_address(), _strioerror(err))
                ) from err
            except AttributeError as err:
                # self.sock is None (or otherwise lacks sendall)
                raise OperationalError(errno=2006) from err

    def recv_plain(self) -> bytearray:
        """Receive one packet from the MySQL server

        Returns:
            The complete packet: 4-byte header followed by the payload.

        Raises:
            InterfaceError: errno 2013 when the peer closes the connection
                before the packet is complete.
            OperationalError: errno 2055 when the socket reports an I/O
                error.
        """
        try:
            # Read the header of the MySQL packet, 4 bytes
            packet = bytearray(b"")
            packet_len = 0
            while packet_len < 4:
                chunk = self.sock.recv(4 - packet_len)
                if not chunk:
                    raise InterfaceError(errno=2013)
                packet += chunk
                packet_len = len(packet)

            # Save the packet number and payload length
            self._packet_number = packet[3]
            payload_len = struct.unpack("<I", packet[0:3] + b"\x00")[0]

            # Read the payload directly into the pre-sized buffer
            rest = payload_len
            packet.extend(bytearray(payload_len))
            packet_view = memoryview(packet)
            packet_view = packet_view[4:]
            while rest:
                read = self.sock.recv_into(packet_view, rest)
                if read == 0 and rest > 0:
                    raise InterfaceError(errno=2013)
                packet_view = packet_view[read:]
                rest -= read
            return packet
        except IOError as err:
            raise OperationalError(
                errno=2055, values=(self.get_address(), _strioerror(err))
            ) from err

    recv = recv_plain

    def _split_zipped_payload(self, packet_bunch: bytearray) -> None:
        """Split a bunch of concatenated packets and queue each of them

        Each packet is cut out using the 3-byte length in its header and
        appended, header included, to the internal packet queue.
        """
        while packet_bunch:
            payload_length = struct.unpack("<I", packet_bunch[0:3] + b"\x00")[0]
            self._packet_queue.append(packet_bunch[0 : payload_length + 4])
            packet_bunch = packet_bunch[payload_length + 4 :]

    def recv_compressed(self) -> Optional[bytearray]:
        """Receive compressed packets from the MySQL server

        Returns:
            The next queued packet, reading and decompressing more data from
            the socket when the queue is empty; None when nothing could be
            read.

        Raises:
            InterfaceError: errno 2013 when the peer closes the connection
                in the middle of a compressed packet.
            OperationalError: errno 2055 when the socket reports an I/O
                error.
        """
        # Serve from the queue first
        try:
            pkt = self._packet_queue.popleft()
            self._packet_number = pkt[3]
            return pkt
        except IndexError:
            pass

        header = bytearray(b"")
        packets = []
        try:
            # Read the 7-byte header one byte at a time. The loop leaves one
            # extra byte in 'abyte': the first byte of the body, which seeds
            # 'zip_payload' below.
            abyte = self.sock.recv(1)
            while abyte and len(header) < 7:
                header += abyte
                abyte = self.sock.recv(1)
            while header:
                if len(header) < 7:
                    raise InterfaceError(errno=2013)

                # Get length of compressed packet
                zip_payload_length = struct.unpack("<I", header[0:3] + b"\x00")[0]
                self._compressed_packet_number = header[3]

                # Get payload length before compression
                payload_length = struct.unpack("<I", header[4:7] + b"\x00")[0]

                zip_payload = init_bytearray(abyte)
                while len(zip_payload) < zip_payload_length:
                    chunk = self.sock.recv(zip_payload_length - len(zip_payload))
                    if not chunk:
                        raise InterfaceError(errno=2013)
                    zip_payload = zip_payload + chunk

                # Payload was not compressed
                if payload_length == 0:
                    self._split_zipped_payload(zip_payload)
                    pkt = self._packet_queue.popleft()
                    self._packet_number = pkt[3]
                    return pkt

                packets.append((payload_length, zip_payload))

                if zip_payload_length <= 16384:
                    # We received the full compressed packet
                    break

                # Get next compressed packet
                header = init_bytearray(b"")
                abyte = self.sock.recv(1)
                while abyte and len(header) < 7:
                    header += abyte
                    abyte = self.sock.recv(1)

        except IOError as err:
            raise OperationalError(
                errno=2055, values=(self.get_address(), _strioerror(err))
            ) from err

        # Compressed packet can contain more than 1 MySQL packets
        # We decompress and make one so we can split it up
        tmp = init_bytearray(b"")
        for payload_length, payload in packets:
            # payload_length can not be 0; this was previously handled
            tmp += zlib.decompress(payload)
        self._split_zipped_payload(tmp)
        del tmp

        try:
            pkt = self._packet_queue.popleft()
            self._packet_number = pkt[3]
            return pkt
        except IndexError:
            pass
        return None

    def set_connection_timeout(self, timeout: Optional[int]) -> None:
        """Set the connection timeout

        The value is remembered for sockets opened later and applied right
        away when a socket is already open.
        """
        self._connection_timeout = timeout
        if self.sock:
            self.sock.settimeout(timeout)

    def switch_to_ssl(
        self,
        ca: StrOrBytesPath,
        cert: StrOrBytesPath,
        key: StrOrBytesPath,
        verify_cert: bool = False,
        verify_identity: bool = False,
        cipher_suites: Optional[str] = None,
        tls_versions: Optional[List[str]] = None,
    ) -> None:
        """Switch the socket to use SSL

        Args:
            ca: Path to the CA certificate file (skipped when falsy).
            cert: Path to the client certificate file (skipped when falsy).
            key: Path to the client private key file.
            verify_cert: Require and verify the server certificate.
            verify_identity: Additionally match the server certificate
                against the server host name.
            cipher_suites: OpenSSL cipher string passed to set_ciphers().
            tls_versions: Allowed TLS versions, e.g. ["TLSv1.2", "TLSv1.3"].
                The list is not modified.

        Raises:
            InterfaceError: errno 2048 when no socket is open; also on
                invalid certificates, failed identity verification and
                SSL/socket errors.
            NotSupportedError: When Python has no SSL support.
        """
        if not self.sock:
            raise InterfaceError(errno=2048)

        try:
            if verify_cert:
                cert_reqs = ssl.CERT_REQUIRED
            elif verify_identity:
                cert_reqs = ssl.CERT_OPTIONAL
            else:
                cert_reqs = ssl.CERT_NONE

            if tls_versions is None or not tls_versions:
                context = ssl.create_default_context()
                if not verify_identity:
                    context.check_hostname = False
            else:
                # Work on a sorted copy so the caller's list is left intact
                ordered_versions = sorted(tls_versions, reverse=True)
                tls_version = ordered_versions[0]
                if (
                    not TLS_V1_3_SUPPORTED
                    and tls_version == "TLSv1.3"
                    and len(ordered_versions) > 1
                ):
                    # Fall back to the next highest requested version
                    tls_version = ordered_versions[1]
                ssl_protocol = TLS_VERSIONS[tls_version]
                context = ssl.SSLContext(ssl_protocol)

                if tls_version == "TLSv1.3":
                    # The TLSv1.3 entry maps to a version-negotiating
                    # protocol; switch off the versions not asked for.
                    if "TLSv1.2" not in tls_versions:
                        context.options |= ssl.OP_NO_TLSv1_2
                    if "TLSv1.1" not in tls_versions:
                        context.options |= ssl.OP_NO_TLSv1_1
                    if "TLSv1" not in tls_versions:
                        context.options |= ssl.OP_NO_TLSv1

            # check_hostname must be False before verify_mode can be
            # relaxed; identity is matched explicitly further down.
            context.check_hostname = False
            context.verify_mode = cert_reqs
            context.load_default_certs()

            if ca:
                try:
                    context.load_verify_locations(ca)
                except (IOError, ssl.SSLError) as err:
                    self.sock.close()
                    raise InterfaceError(f"Invalid CA Certificate: {err}") from err
            if cert:
                try:
                    context.load_cert_chain(cert, key)
                except (IOError, ssl.SSLError) as err:
                    self.sock.close()
                    raise InterfaceError(f"Invalid Certificate/Key: {err}") from err
            if cipher_suites:
                context.set_ciphers(cipher_suites)

            if hasattr(self, "server_host"):
                self.sock = context.wrap_socket(
                    self.sock, server_hostname=self.server_host
                )
            else:
                self.sock = context.wrap_socket(self.sock)

            if verify_identity:
                context.check_hostname = True
                hostnames: List[str] = [self.server_host] if self.server_host else []
                if os.name == "nt" and self.server_host == "localhost":
                    hostnames = ["localhost", "127.0.0.1"]
                    aliases = socket.gethostbyaddr(self.server_host)
                    hostnames.extend([aliases[0]] + aliases[1])
                match_found = False
                errs = []
                for hostname in hostnames:
                    try:
                        # Deprecated in Python 3.7 without a replacement and
                        # should be removed in the future, since OpenSSL now
                        # performs hostname matching
                        # NOTE(review): ssl.match_hostname no longer exists
                        # on Python 3.12+ -- confirm supported versions.
                        # pylint: disable=deprecated-method
                        ssl.match_hostname(self.sock.getpeercert(), hostname)
                        # pylint: enable=deprecated-method
                    except ssl.CertificateError as err:
                        errs.append(str(err))
                    else:
                        match_found = True
                        break
                if not match_found:
                    self.sock.close()
                    raise InterfaceError(
                        f"Unable to verify server identity: {', '.join(errs)}"
                    )
        except NameError as err:
            # 'ssl' is undefined when the import at module level failed
            raise NotSupportedError("Python installation has no SSL support") from err
        except (ssl.SSLError, IOError) as err:
            raise InterfaceError(
                errno=2055, values=(self.get_address(), _strioerror(err))
            ) from err
        # NOTE(review): on Python 3.7+ ssl.CertificateError is a subclass of
        # ssl.SSLError and is therefore already handled by the clause above.
        except ssl.CertificateError as err:
            raise InterfaceError(str(err)) from err
        except NotImplementedError as err:
            raise InterfaceError(str(err)) from err


class MySQLUnixSocket(BaseMySQLSocket):
    """MySQL socket class using UNIX sockets

    Opens a connection through the UNIX socket of the MySQL Server.
    """

    def __init__(self, unix_socket: str = "/tmp/mysql.sock") -> None:
        super().__init__()
        self.unix_socket: str = unix_socket

    def get_address(self) -> str:
        """Return the path of the UNIX socket"""
        return self.unix_socket

    def open_connection(self) -> None:
        """Open the UNIX socket connection to the MySQL server

        Raises:
            InterfaceError: errno 2002 on socket I/O errors; a plain
                InterfaceError for any other failure.
        """
        try:
            self.sock = socket.socket(
                socket.AF_UNIX, socket.SOCK_STREAM  # pylint: disable=no-member
            )
            self.sock.settimeout(self._connection_timeout)
            self.sock.connect(self.unix_socket)
        except IOError as err:
            raise InterfaceError(
                errno=2002, values=(self.get_address(), _strioerror(err))
            ) from err
        except Exception as err:
            raise InterfaceError(str(err)) from err

    def switch_to_ssl(
        self, *args: Any, **kwargs: Any  # pylint: disable=unused-argument
    ) -> None:
        """Switch the socket to use SSL.

        Not done for UNIX sockets: only a warning is emitted and the
        socket is left untouched.
        """
        warnings.warn(
            "SSL is disabled when using unix socket connections",
            Warning,
        )


class MySQLTCPSocket(BaseMySQLSocket):
    """MySQL socket class using TCP/IP

    Opens a TCP/IP connection to the MySQL Server.
    """

    def __init__(
        self, host: str = "127.0.0.1", port: int = 3306, force_ipv6: bool = False
    ) -> None:
        super().__init__()
        self.server_host: str = host
        self.server_port: int = port
        self.force_ipv6: bool = force_ipv6
        # Address family of the opened socket; 0 until open_connection()
        self._family: int = 0

    def get_address(self) -> str:
        """Return the server location as ``host:port``"""
        return f"{self.server_host}:{self.server_port}"

    def open_connection(self) -> None:
        """Open the TCP/IP connection to the MySQL server

        Raises:
            InterfaceError: errno 2003 when address resolution or the
                connection fails; also when force_ipv6 is set and the host
                has no IPv6 address.
            OperationalError: For any other failure while connecting.
        """
        # pylint: disable=no-member
        # Get address information
        addrinfo: Union[
            Tuple[None, None, None, None, None],
            Tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                Union[Tuple[str, int], Tuple[str, int, int, int]],
            ],
        ] = (None, None, None, None, None)
        try:
            addrinfos = socket.getaddrinfo(
                self.server_host,
                self.server_port,
                0,
                socket.SOCK_STREAM,
                socket.SOL_TCP,
            )
            # If multiple results we favor IPv4, unless IPv6 was forced.
            for info in addrinfos:
                if self.force_ipv6:
                    # Never settle for IPv4 here, even when it is listed
                    # before the first IPv6 entry.
                    if info[0] == socket.AF_INET6:
                        addrinfo = info
                        break
                elif info[0] == socket.AF_INET:
                    addrinfo = info
                    break
            if self.force_ipv6 and addrinfo[0] is None:
                raise InterfaceError(f"No IPv6 address found for {self.server_host}")
            if addrinfo[0] is None:
                # No IPv4 entry: use whatever was returned first
                addrinfo = addrinfos[0]
        except IOError as err:
            raise InterfaceError(
                errno=2003, values=(self.get_address(), _strioerror(err))
            ) from err

        (self._family, socktype, proto, _, sockaddr) = addrinfo

        # Instantiate the socket and connect
        try:
            self.sock = socket.socket(self._family, socktype, proto)
            self.sock.settimeout(self._connection_timeout)
            self.sock.connect(sockaddr)
        except IOError as err:
            raise InterfaceError(
                errno=2003,
                values=(
                    self.server_host,
                    self.server_port,
                    _strioerror(err),
                ),
            ) from err
        except Exception as err:
            raise OperationalError(str(err)) from err