# Copyright (c) 2014, 2023, Oracle and/or its affiliates. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License, version 2.0, as # published by the Free Software Foundation. # # This program is also distributed with certain software (including # but not limited to OpenSSL) that is licensed under separate terms, # as designated in a particular file or component or in included license # documentation. The authors of MySQL hereby grant you an # additional permission to link the program and your derivative works # with the separately licensed software that they have included with # MySQL. # # Without limiting anything contained in the foregoing, this file, # which is part of MySQL Connector/Python, is also subject to the # Universal FOSS Exception, version 1.0, a copy of which can be found at # http://oss.oracle.com/licenses/universal-foss-exception. # # This program is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the GNU General Public License, version 2.0, for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA # mypy: disable-error-code="assignment,attr-defined" """Module gathering all abstract base classes.""" from __future__ import annotations import importlib import os import re import weakref from abc import ABC, abstractmethod from datetime import date, datetime, time, timedelta from decimal import Decimal from inspect import signature from time import sleep from types import TracebackType from typing import ( Any, BinaryIO, Callable, Dict, Generator, List, Mapping, Optional, Sequence, Tuple, Type, Union, ) TLS_V1_3_SUPPORTED = False try: import ssl if hasattr(ssl, "HAS_TLSv1_3") and ssl.HAS_TLSv1_3: TLS_V1_3_SUPPORTED = True except ImportError: # If import fails, we don't have SSL support. pass from .constants import ( CONN_ATTRS_DN, DEFAULT_CONFIGURATION, DEPRECATED_TLS_VERSIONS, OPENSSL_CS_NAMES, TLS_CIPHER_SUITES, TLS_VERSIONS, CharacterSet, ClientFlag, ) from .conversion import MySQLConverter, MySQLConverterBase from .errors import ( Error, InterfaceError, NotSupportedError, OperationalError, ProgrammingError, ) from .optionfiles import read_option_files from .types import ( ConnAttrsType, DescriptionType, HandShakeType, QueryAttrType, StrOrBytes, SupportedMysqlBinaryProtocolTypes, WarningType, ) NAMED_TUPLE_CACHE: weakref.WeakValueDictionary[Any, Any] = weakref.WeakValueDictionary() DUPLICATED_IN_LIST_ERROR = ( "The '{list}' list must not contain repeated values, the value " "'{value}' is duplicated." ) TLS_VERSION_ERROR = ( "The given tls_version: '{}' is not recognized as a valid " "TLS protocol version (should be one of {})." ) TLS_VERSION_DEPRECATED_ERROR = ( "The given tls_version: '{}' are no longer allowed (should be one of {})." ) TLS_VER_NO_SUPPORTED = ( "No supported TLS protocol version found in the 'tls-versions' list '{}'. " ) KRB_SERVICE_PINCIPAL_ERROR = ( 'Option "krb_service_principal" {error}, must be a string in the form ' '"primary/instance@realm" e.g "ldap/ldapauth@MYSQL.COM" where "@realm" ' "is optional and if it is not given will be assumed to belong to the " "default realm, as configured in the krb5.conf file." ) MYSQL_PY_TYPES = ( Decimal, bytes, date, datetime, float, int, str, time, timedelta, ) class MySQLConnectionAbstract(ABC): """Abstract class for classes connecting to a MySQL server""" def __init__(self) -> None: """Initialize""" self._client_flags: int = ClientFlag.get_default() self._charset_id: int = 45 self._sql_mode: Optional[str] = None self._time_zone: Optional[str] = None self._autocommit: bool = False self._server_version: Optional[Tuple[int, ...]] = None self._handshake: Optional[HandShakeType] = None self._conn_attrs: ConnAttrsType = {} self._user: str = "" self._password: str = "" self._password1: str = "" self._password2: str = "" self._password3: str = "" self._database: str = "" self._host: str = "127.0.0.1" self._port: int = 3306 self._unix_socket: Optional[str] = None self._client_host: str = "" self._client_port: int = 0 self._ssl: Dict[str, Optional[Union[str, bool, List[str]]]] = {} self._ssl_disabled: bool = DEFAULT_CONFIGURATION["ssl_disabled"] self._force_ipv6: bool = False self._oci_config_file: Optional[str] = None self._oci_config_profile: Optional[str] = None self._fido_callback: Optional[Union[str, Callable]] = None self._krb_service_principal: Optional[str] = None self._use_unicode: bool = True self._get_warnings: bool = False self._raise_on_warnings: bool = False self._connection_timeout: Optional[int] = DEFAULT_CONFIGURATION[ "connect_timeout" ] self._buffered: bool = False self._unread_result: bool = False self._have_next_result: bool = False self._raw: bool = False self._in_transaction: bool = False self._allow_local_infile: bool = DEFAULT_CONFIGURATION["allow_local_infile"] self._allow_local_infile_in_path: Optional[str] = DEFAULT_CONFIGURATION[ "allow_local_infile_in_path" ] self._prepared_statements: Any = None self._query_attrs: QueryAttrType = [] self._ssl_active: bool = False self._auth_plugin: Optional[str] = None self._auth_plugin_class: Optional[str] = None self._pool_config_version: Any = None self.converter: Optional[MySQLConverter] = None self._converter_class: Optional[Type[MySQLConverter]] = None self._converter_str_fallback: bool = False self._compress: bool = False self._consume_results: bool = False self._init_command: Optional[str] = None def __enter__(self) -> MySQLConnectionAbstract: return self def __exit__( self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType, ) -> None: self.close() def get_self(self) -> MySQLConnectionAbstract: """Return self for weakref.proxy This method is used when the original object is needed when using weakref.proxy. """ return self @property def is_secure(self) -> bool: """Return True if is a secure connection.""" return self._ssl_active or ( self._unix_socket is not None and os.name == "posix" ) @property def have_next_result(self) -> bool: """Return if have next result.""" return self._have_next_result @property def query_attrs(self) -> QueryAttrType: """Return query attributes list.""" return self._query_attrs def query_attrs_append( self, value: Tuple[str, SupportedMysqlBinaryProtocolTypes] ) -> None: """Add element to the query attributes list.""" self._query_attrs.append(value) def query_attrs_clear(self) -> None: """Clear query attributes list.""" del self._query_attrs[:] def _validate_tls_ciphersuites(self) -> None: """Validates the tls_ciphersuites option.""" tls_ciphersuites = [] tls_cs = self._ssl["tls_ciphersuites"] if isinstance(tls_cs, str): if not (tls_cs.startswith("[") and tls_cs.endswith("]")): raise AttributeError( f"tls_ciphersuites must be a list, found: '{tls_cs}'" ) tls_css = tls_cs[1:-1].split(",") if not tls_css: raise AttributeError( "No valid cipher suite found in 'tls_ciphersuites' list" ) for _tls_cs in tls_css: _tls_cs = tls_cs.strip().upper() if _tls_cs: tls_ciphersuites.append(_tls_cs) elif isinstance(tls_cs, (list, set)): tls_ciphersuites = [tls_cs for tls_cs in tls_cs if tls_cs] else: raise AttributeError( "tls_ciphersuites should be a list with one or more " f"ciphersuites. Found: '{tls_cs}'" ) tls_versions = ( TLS_VERSIONS[:] if self._ssl.get("tls_versions", None) is None else self._ssl["tls_versions"][:] # type: ignore[index] ) # A newer TLS version can use a cipher introduced on # an older version. tls_versions.sort(reverse=True) # type: ignore[union-attr] newer_tls_ver = tls_versions[0] # translated_names[0] belongs to TLSv1, TLSv1.1 and TLSv1.2 # translated_names[1] are TLSv1.3 only translated_names: List[List[str]] = [[], []] iani_cipher_suites_names = {} ossl_cipher_suites_names: List[str] = [] # Old ciphers can work with new TLS versions. # Find all the ciphers introduced on previous TLS versions. for tls_ver in TLS_VERSIONS[: TLS_VERSIONS.index(newer_tls_ver) + 1]: iani_cipher_suites_names.update(TLS_CIPHER_SUITES[tls_ver]) ossl_cipher_suites_names.extend(OPENSSL_CS_NAMES[tls_ver]) for name in tls_ciphersuites: if "-" in name and name in ossl_cipher_suites_names: if name in OPENSSL_CS_NAMES["TLSv1.3"]: translated_names[1].append(name) else: translated_names[0].append(name) elif name in iani_cipher_suites_names: translated_name = iani_cipher_suites_names[name] if translated_name in translated_names: raise AttributeError( DUPLICATED_IN_LIST_ERROR.format( list="tls_ciphersuites", value=translated_name ) ) if name in TLS_CIPHER_SUITES["TLSv1.3"]: translated_names[1].append(iani_cipher_suites_names[name]) else: translated_names[0].append(iani_cipher_suites_names[name]) else: raise AttributeError( f"The value '{name}' in tls_ciphersuites is not a valid " "cipher suite" ) if not translated_names[0] and not translated_names[1]: raise AttributeError( "No valid cipher suite found in the 'tls_ciphersuites' list" ) self._ssl["tls_ciphersuites"] = [ ":".join(translated_names[0]), ":".join(translated_names[1]), ] def _validate_tls_versions(self) -> None: """Validates the tls_versions option.""" tls_versions = [] tls_version = self._ssl["tls_versions"] if isinstance(tls_version, str): if not (tls_version.startswith("[") and tls_version.endswith("]")): raise AttributeError( f"tls_versions must be a list, found: '{tls_version}'" ) tls_vers = tls_version[1:-1].split(",") for tls_ver in tls_vers: tls_version = tls_ver.strip() if tls_version == "": continue if tls_version in tls_versions: raise AttributeError( DUPLICATED_IN_LIST_ERROR.format( list="tls_versions", value=tls_version ) ) tls_versions.append(tls_version) if tls_vers == ["TLSv1.3"] and not TLS_V1_3_SUPPORTED: raise AttributeError( TLS_VER_NO_SUPPORTED.format(tls_version, TLS_VERSIONS) ) elif isinstance(tls_version, list): if not tls_version: raise AttributeError( "At least one TLS protocol version must be specified in " "'tls_versions' list" ) for tls_ver in tls_version: if tls_ver in tls_versions: raise AttributeError( DUPLICATED_IN_LIST_ERROR.format( list="tls_versions", value=tls_ver ) ) tls_versions.append(tls_ver) elif isinstance(tls_version, set): for tls_ver in tls_version: tls_versions.append(tls_ver) else: raise AttributeError( "tls_versions should be a list with one or more of versions " f"in {', '.join(TLS_VERSIONS)}. found: '{tls_versions}'" ) if not tls_versions: raise AttributeError( "At least one TLS protocol version must be specified " "in 'tls_versions' list when this option is given" ) use_tls_versions = [] deprecated_tls_versions = [] invalid_tls_versions = [] for tls_ver in tls_versions: if tls_ver in TLS_VERSIONS: use_tls_versions.append(tls_ver) if tls_ver in DEPRECATED_TLS_VERSIONS: deprecated_tls_versions.append(tls_ver) else: invalid_tls_versions.append(tls_ver) if use_tls_versions: if use_tls_versions == ["TLSv1.3"] and not TLS_V1_3_SUPPORTED: raise NotSupportedError( TLS_VER_NO_SUPPORTED.format(tls_version, TLS_VERSIONS) ) use_tls_versions.sort() self._ssl["tls_versions"] = use_tls_versions elif deprecated_tls_versions: raise NotSupportedError( TLS_VERSION_DEPRECATED_ERROR.format( deprecated_tls_versions, TLS_VERSIONS ) ) elif invalid_tls_versions: raise AttributeError(TLS_VERSION_ERROR.format(tls_ver, TLS_VERSIONS)) @property def user(self) -> str: """User used while connecting to MySQL""" return self._user @property def server_host(self) -> str: """MySQL server IP address or name""" return self._host @property def server_port(self) -> int: "MySQL server TCP/IP port" return self._port @property def unix_socket(self) -> Optional[str]: "MySQL Unix socket file location" return self._unix_socket @property @abstractmethod def database(self) -> str: """Get the current database""" @database.setter def database(self, value: str) -> None: """Set the current database""" self.cmd_query(f"USE {value}") @property def can_consume_results(self) -> bool: """Returns whether to consume results""" return self._consume_results @can_consume_results.setter def can_consume_results(self, value: bool) -> None: """Set if can consume results.""" assert isinstance(value, bool) self._consume_results = value @property def pool_config_version(self) -> Any: """Return the pool configuration version""" return self._pool_config_version @pool_config_version.setter def pool_config_version(self, value: Any) -> None: """Set the pool configuration version""" self._pool_config_version = value def config(self, **kwargs: Any) -> None: """Configure the MySQL Connection This method allows you to configure the MySQLConnection instance. Raises on errors. """ config = kwargs.copy() if "dsn" in config: raise NotSupportedError("Data source name is not supported") # Read option files config = read_option_files(**config) # Configure how we handle MySQL warnings try: self.get_warnings = config["get_warnings"] del config["get_warnings"] except KeyError: pass # Leave what was set or default try: self.raise_on_warnings = config["raise_on_warnings"] del config["raise_on_warnings"] except KeyError: pass # Leave what was set or default # Configure client flags try: default = ClientFlag.get_default() self.set_client_flags(config["client_flags"] or default) del config["client_flags"] except KeyError: pass # Missing client_flags-argument is OK try: if config["compress"]: self._compress = True self.set_client_flags([ClientFlag.COMPRESS]) except KeyError: pass # Missing compress argument is OK self._allow_local_infile = config.get( "allow_local_infile", DEFAULT_CONFIGURATION["allow_local_infile"] ) self._allow_local_infile_in_path = config.get( "allow_local_infile_in_path", DEFAULT_CONFIGURATION["allow_local_infile_in_path"], ) infile_in_path = None if self._allow_local_infile_in_path: infile_in_path = os.path.abspath(self._allow_local_infile_in_path) if ( infile_in_path and os.path.exists(infile_in_path) and not os.path.isdir(infile_in_path) or os.path.islink(infile_in_path) ): raise AttributeError("allow_local_infile_in_path must be a directory") if self._allow_local_infile or self._allow_local_infile_in_path: self.set_client_flags([ClientFlag.LOCAL_FILES]) else: self.set_client_flags([-ClientFlag.LOCAL_FILES]) try: if not config["consume_results"]: self._consume_results = False else: self._consume_results = True except KeyError: self._consume_results = False # Configure auth_plugin try: self._auth_plugin = config["auth_plugin"] del config["auth_plugin"] except KeyError: self._auth_plugin = "" # Configure character set and collation if "charset" in config or "collation" in config: try: charset = config["charset"] del config["charset"] except KeyError: charset = None try: collation = config["collation"] del config["collation"] except KeyError: collation = None self._charset_id = CharacterSet.get_charset_info(charset, collation)[0] # Set converter class try: self.set_converter_class(config["converter_class"]) except KeyError: pass # Using default converter class except TypeError as err: raise AttributeError( "Converter class should be a subclass of " "conversion.MySQLConverterBase" ) from err # Compatible configuration with other drivers compat_map = [ # (,) ("db", "database"), ("username", "user"), ("passwd", "password"), ("connect_timeout", "connection_timeout"), ("read_default_file", "option_files"), ] for compat, translate in compat_map: try: if translate not in config: config[translate] = config[compat] del config[compat] except KeyError: pass # Missing compat argument is OK # Configure login information if "user" in config or "password" in config: try: user = config["user"] del config["user"] except KeyError: user = self._user try: password = config["password"] del config["password"] except KeyError: password = self._password self.set_login(user, password) # Configure host information if "host" in config and config["host"]: self._host = config["host"] # Check network locations try: self._port = int(config["port"]) del config["port"] except KeyError: pass # Missing port argument is OK except ValueError as err: raise InterfaceError("TCP/IP port number should be an integer") from err if "ssl_disabled" in config: self._ssl_disabled = config.pop("ssl_disabled") # If an init_command is set, keep it, so we can execute it in _post_connection if "init_command" in config: self._init_command = config["init_command"] del config["init_command"] # Other configuration set_ssl_flag = False for key, value in config.items(): try: DEFAULT_CONFIGURATION[key] except KeyError: raise AttributeError(f"Unsupported argument '{key}'") from None # SSL Configuration if key.startswith("ssl_"): set_ssl_flag = True self._ssl.update({key.replace("ssl_", ""): value}) elif key.startswith("tls_"): set_ssl_flag = True self._ssl.update({key: value}) else: attribute = "_" + key try: setattr(self, attribute, value.strip()) except AttributeError: setattr(self, attribute, value) # Disable SSL for unix socket connections if self._unix_socket and os.name == "posix": self._ssl_disabled = True if self._ssl_disabled and self._auth_plugin == "mysql_clear_password": raise InterfaceError( "Clear password authentication is not supported over insecure channels" ) if set_ssl_flag: if "verify_cert" not in self._ssl: self._ssl["verify_cert"] = DEFAULT_CONFIGURATION["ssl_verify_cert"] if "verify_identity" not in self._ssl: self._ssl["verify_identity"] = DEFAULT_CONFIGURATION[ "ssl_verify_identity" ] # Make sure both ssl_key/ssl_cert are set, or neither (XOR) if "ca" not in self._ssl or self._ssl["ca"] is None: self._ssl["ca"] = "" if bool("key" in self._ssl) != bool("cert" in self._ssl): raise AttributeError( "ssl_key and ssl_cert need to be both specified, or neither" ) # Make sure key/cert are set to None if not set(("key", "cert")) <= set(self._ssl): self._ssl["key"] = None self._ssl["cert"] = None elif (self._ssl["key"] is None) != (self._ssl["cert"] is None): raise AttributeError( "ssl_key and ssl_cert need to be both set, or neither" ) if "tls_versions" in self._ssl and self._ssl["tls_versions"] is not None: self._validate_tls_versions() if ( "tls_ciphersuites" in self._ssl and self._ssl["tls_ciphersuites"] is not None ): self._validate_tls_ciphersuites() if self._conn_attrs is None: self._conn_attrs = {} elif not isinstance(self._conn_attrs, dict): raise InterfaceError("conn_attrs must be of type dict") else: for attr_name, attr_value in self._conn_attrs.items(): if attr_name in CONN_ATTRS_DN: continue # Validate name type if not isinstance(attr_name, str): raise InterfaceError( "Attribute name should be a string, found: " f"'{attr_name}' in '{self._conn_attrs}'" ) # Validate attribute name limit 32 characters if len(attr_name) > 32: raise InterfaceError( f"Attribute name '{attr_name}' exceeds 32 characters limit size" ) # Validate names in connection attributes cannot start with "_" if attr_name.startswith("_"): raise InterfaceError( "Key names in connection attributes cannot start with " "'_', found: '{attr_name}'" ) # Validate value type if not isinstance(attr_value, str): raise InterfaceError( f"Attribute '{attr_name}' value: '{attr_value}' must " "be a string type" ) # Validate attribute value limit 1024 characters if len(attr_value) > 1024: raise InterfaceError( f"Attribute '{attr_name}' value: '{attr_value}' " "exceeds 1024 characters limit size" ) if self._client_flags & ClientFlag.CONNECT_ARGS: self._add_default_conn_attrs() if "kerberos_auth_mode" in config and config["kerberos_auth_mode"] is not None: if not isinstance(config["kerberos_auth_mode"], str): raise InterfaceError("'kerberos_auth_mode' must be of type str") kerberos_auth_mode = config["kerberos_auth_mode"].lower() if kerberos_auth_mode == "sspi": if os.name != "nt": raise InterfaceError( "'kerberos_auth_mode=SSPI' is only available on Windows" ) self._auth_plugin_class = "MySQLSSPIKerberosAuthPlugin" elif kerberos_auth_mode == "gssapi": self._auth_plugin_class = "MySQLKerberosAuthPlugin" else: raise InterfaceError( "Invalid 'kerberos_auth_mode' mode. Please use 'SSPI' or 'GSSAPI'" ) if ( "krb_service_principal" in config and config["krb_service_principal"] is not None ): self._krb_service_principal = config["krb_service_principal"] if not isinstance(self._krb_service_principal, str): raise InterfaceError( KRB_SERVICE_PINCIPAL_ERROR.format(error="is not a string") ) if self._krb_service_principal == "": raise InterfaceError( KRB_SERVICE_PINCIPAL_ERROR.format( error="can not be an empty string" ) ) if "/" not in self._krb_service_principal: raise InterfaceError( KRB_SERVICE_PINCIPAL_ERROR.format(error="is incorrectly formatted") ) if self._fido_callback: # Import the callable if it's a str if isinstance(self._fido_callback, str): try: module, callback = self._fido_callback.rsplit(".", 1) except ValueError: raise ProgrammingError( f"No callable named '{self._fido_callback}'" ) from None try: module = importlib.import_module(module) self._fido_callback = getattr(module, callback) except (AttributeError, ModuleNotFoundError) as err: raise ProgrammingError(f"{err}") from err # Check if it's a callable if not callable(self._fido_callback): raise ProgrammingError("Expected a callable for 'fido_callback'") # Check the callable signature if has only 1 positional argument params = len(signature(self._fido_callback).parameters) if params != 1: raise ProgrammingError( "'fido_callback' requires 1 positional argument, but the " f"callback provided has {params}" ) def _add_default_conn_attrs(self) -> Any: """Add the default connection attributes.""" @staticmethod def _check_server_version(server_version: StrOrBytes) -> Tuple[int, ...]: """Check the MySQL version This method will check the MySQL version and raise an InterfaceError when it is not supported or invalid. It will return the version as a tuple with major, minor and patch. Raises InterfaceError if invalid server version. Returns tuple """ if isinstance(server_version, (bytearray, bytes)): server_version = server_version.decode() regex_ver = re.compile(r"^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)") match = regex_ver.match(server_version) if not match: raise InterfaceError("Failed parsing MySQL version") version = tuple(int(v) for v in match.groups()[0:3]) if version < (4, 1): raise InterfaceError(f"MySQL Version '{server_version}' is not supported") return version def get_server_version(self) -> Tuple[int, ...]: """Get the MySQL version This method returns the MySQL server version as a tuple. If not previously connected, it will return None. Returns a tuple or None. """ return self._server_version def get_server_info(self) -> Optional[str]: """Get the original MySQL version information This method returns the original MySQL server as text. If not previously connected, it will return None. Returns a string or None. """ try: return self._handshake["server_version_original"] # type: ignore[return-value] except (TypeError, KeyError): return None @property @abstractmethod def in_transaction(self) -> Any: """MySQL session has started a transaction""" def set_client_flags(self, flags: Union[int, Sequence[int]]) -> int: """Set the client flags The flags-argument can be either an int or a list (or tuple) of ClientFlag-values. If it is an integer, it will set client_flags to flags as is. If flags is a list (or tuple), each flag will be set or unset when it's negative. set_client_flags([ClientFlag.FOUND_ROWS,-ClientFlag.LONG_FLAG]) Raises ProgrammingError when the flags argument is not a set or an integer bigger than 0. Returns self.client_flags """ if isinstance(flags, int) and flags > 0: self._client_flags = flags elif isinstance(flags, (tuple, list)): for flag in flags: if flag < 0: self._client_flags &= ~abs(flag) else: self._client_flags |= flag else: raise ProgrammingError("set_client_flags expect integer (>0) or set") return self._client_flags def isset_client_flag(self, flag: int) -> bool: """Check if a client flag is set""" if (self._client_flags & flag) > 0: return True return False @property def time_zone(self) -> str: """Get the current time zone""" return self.info_query("SELECT @@session.time_zone")[0] @time_zone.setter def time_zone(self, value: str) -> None: """Set the time zone""" self.cmd_query(f"SET @@session.time_zone = '{value}'") self._time_zone = value @property def sql_mode(self) -> str: """Get the SQL mode""" if self._sql_mode is None: self._sql_mode = self.info_query("SELECT @@session.sql_mode")[0] return self._sql_mode @sql_mode.setter def sql_mode(self, value: Union[str, Sequence[int]]) -> None: """Set the SQL mode This method sets the SQL Mode for the current connection. The value argument can be either a string with comma separate mode names, or a sequence of mode names. It is good practice to use the constants class SQLMode: from mysql.connector.constants import SQLMode cnx.sql_mode = [SQLMode.NO_ZERO_DATE, SQLMode.REAL_AS_FLOAT] """ if isinstance(value, (list, tuple)): value = ",".join(value) self.cmd_query(f"SET @@session.sql_mode = '{value}'") self._sql_mode = value @abstractmethod def info_query(self, query: Any) -> Any: """Send a query which only returns 1 row""" def set_login( self, username: Optional[str] = None, password: Optional[str] = None ) -> None: """Set login information for MySQL Set the username and/or password for the user connecting to the MySQL Server. """ if username is not None: self._user = username.strip() else: self._user = "" if password is not None: self._password = password else: self._password = "" def set_unicode(self, value: bool = True) -> None: """Toggle unicode mode Set whether we return string fields as unicode or not. Default is True. """ self._use_unicode = value if self.converter: self.converter.set_unicode(value) @property def autocommit(self) -> bool: """Get whether autocommit is on or off""" value = self.info_query("SELECT @@session.autocommit")[0] return value == 1 @autocommit.setter def autocommit(self, value: bool) -> None: """Toggle autocommit""" switch = "ON" if value else "OFF" self.cmd_query(f"SET @@session.autocommit = {switch}") self._autocommit = value @property def get_warnings(self) -> bool: """Get whether this connection retrieves warnings automatically This method returns whether this connection retrieves warnings automatically. Returns True, or False when warnings are not retrieved. """ return self._get_warnings @get_warnings.setter def get_warnings(self, value: bool) -> None: """Set whether warnings should be automatically retrieved The toggle-argument must be a boolean. When True, cursors for this connection will retrieve information about warnings (if any). Raises ValueError on error. """ if not isinstance(value, bool): raise ValueError("Expected a boolean type") self._get_warnings = value @property def raise_on_warnings(self) -> bool: """Get whether this connection raises an error on warnings This method returns whether this connection will raise errors when MySQL reports warnings. Returns True or False. """ return self._raise_on_warnings @raise_on_warnings.setter def raise_on_warnings(self, value: bool) -> None: """Set whether warnings raise an error The toggle-argument must be a boolean. When True, cursors for this connection will raise an error when MySQL reports warnings. Raising on warnings implies retrieving warnings automatically. In other words: warnings will be set to True. If set to False, warnings will be also set to False. Raises ValueError on error. """ if not isinstance(value, bool): raise ValueError("Expected a boolean type") self._raise_on_warnings = value # Don't disable warning retrieval if raising explicitly disabled if value: self._get_warnings = value @property def unread_result(self) -> bool: """Get whether there is an unread result This method is used by cursors to check whether another cursor still needs to retrieve its result set. Returns True, or False when there is no unread result. """ return self._unread_result @unread_result.setter def unread_result(self, value: bool) -> None: """Set whether there is an unread result This method is used by cursors to let other cursors know there is still a result set that needs to be retrieved. Raises ValueError on errors. """ if not isinstance(value, bool): raise ValueError("Expected a boolean type") self._unread_result = value @property def charset(self) -> str: """Returns the character set for current connection This property returns the character set name of the current connection. The server is queried when the connection is active. If not connected, the configured character set name is returned. Returns a string. """ return CharacterSet.get_info(self._charset_id)[0] @property def python_charset(self) -> str: """Returns the Python character set for current connection This property returns the character set name of the current connection. Note that, unlike property charset, this checks if the previously set character set is supported by Python and if not, it returns the equivalent character set that Python supports. Returns a string. """ encoding = CharacterSet.get_info(self._charset_id)[0] if encoding in ("utf8mb4", "utf8mb3", "binary"): return "utf8" return encoding def set_charset_collation( self, charset: Optional[Union[int, str]] = None, collation: Optional[str] = None ) -> None: """Sets the character set and collation for the current connection This method sets the character set and collation to be used for the current connection. The charset argument can be either the name of a character set as a string, or the numerical equivalent as defined in constants.CharacterSet. When the collation is not given, the default will be looked up and used. For example, the following will set the collation for the latin1 character set to latin1_general_ci: set_charset('latin1','latin1_general_ci') """ err_msg = "{} should be either integer, string or None" if not isinstance(charset, (int, str)) and charset is not None: raise ValueError(err_msg.format("charset")) if not isinstance(collation, str) and collation is not None: raise ValueError("collation should be either string or None") if charset: if isinstance(charset, int): ( self._charset_id, charset_name, collation_name, ) = CharacterSet.get_charset_info(charset) elif isinstance(charset, str): ( self._charset_id, charset_name, collation_name, ) = CharacterSet.get_charset_info(charset, collation) else: raise ValueError(err_msg.format("charset")) elif collation: ( self._charset_id, charset_name, collation_name, ) = CharacterSet.get_charset_info(collation=collation) else: charset = DEFAULT_CONFIGURATION["charset"] ( self._charset_id, charset_name, collation_name, ) = CharacterSet.get_charset_info(charset, collation=None) self._execute_query(f"SET NAMES '{charset_name}' COLLATE '{collation_name}'") try: # Required for C Extension self.set_character_set_name(charset_name) except AttributeError: # Not required for pure Python connection pass if self.converter: self.converter.set_charset(charset_name) @property def collation(self) -> str: """Returns the collation for current connection This property returns the collation name of the current connection. The server is queried when the connection is active. If not connected, the configured collation name is returned. Returns a string. """ return CharacterSet.get_charset_info(self._charset_id)[2] @abstractmethod def _do_handshake(self) -> Any: """Gather information of the MySQL server before authentication""" @abstractmethod def _open_connection(self) -> Any: """Open the connection to the MySQL server""" def _post_connection(self) -> None: """Executes commands after connection has been established This method executes commands after the connection has been established. Some setting like autocommit, character set, and SQL mode are set using this method. """ self.set_charset_collation(self._charset_id) self.autocommit = self._autocommit if self._time_zone: self.time_zone = self._time_zone if self._sql_mode: self.sql_mode = self._sql_mode if self._init_command: self._execute_query(self._init_command) @abstractmethod def disconnect(self) -> Any: """Disconnect from the MySQL server""" close: Callable[[], Any] = disconnect def connect(self, **kwargs: Any) -> None: """Connect to the MySQL server This method sets up the connection to the MySQL server. If no arguments are given, it will use the already configured or default values. """ if kwargs: self.config(**kwargs) self.disconnect() self._open_connection() # Server does not allow to run any other statement different from ALTER # when user's password has been expired. if not self._client_flags & ClientFlag.CAN_HANDLE_EXPIRED_PASSWORDS: self._post_connection() def reconnect(self, attempts: int = 1, delay: int = 0) -> None: """Attempt to reconnect to the MySQL server The argument attempts should be the number of times a reconnect is tried. The delay argument is the number of seconds to wait between each retry. You may want to set the number of attempts higher and use delay when you expect the MySQL server to be down for maintenance or when you expect the network to be temporary unavailable. Raises InterfaceError on errors. """ counter = 0 while counter != attempts: counter = counter + 1 try: self.disconnect() self.connect() if self.is_connected(): break except (Error, IOError) as err: if counter == attempts: msg = ( f"Can not reconnect to MySQL after {attempts} " f"attempt(s): {err}" ) raise InterfaceError(msg) from err if delay > 0: sleep(delay) @abstractmethod def is_connected(self) -> Any: """Reports whether the connection to MySQL Server is available""" @abstractmethod def ping(self, reconnect: bool = False, attempts: int = 1, delay: int = 0) -> Any: """Check availability of the MySQL server""" @abstractmethod def commit(self) -> Any: """Commit current transaction""" @abstractmethod def cursor( self, buffered: Optional[bool] = None, raw: Optional[bool] = None, prepared: Optional[bool] = None, cursor_class: Optional[type] = None, dictionary: Optional[bool] = None, named_tuple: Optional[bool] = None, ) -> "MySQLCursorAbstract": """Instantiates and returns a cursor""" @abstractmethod def _execute_query(self, query: Any) -> Any: """Execute a query""" @abstractmethod def rollback(self) -> Any: """Rollback current transaction""" def start_transaction( self, consistent_snapshot: bool = False, isolation_level: Optional[str] = None, readonly: Optional[bool] = None, ) -> None: """Start a transaction This method explicitly starts a transaction sending the START TRANSACTION statement to the MySQL server. You can optionally set whether there should be a consistent snapshot, which isolation level you need or which access mode i.e. READ ONLY or READ WRITE. For example, to start a transaction with isolation level SERIALIZABLE, you would do the following: >>> cnx = mysql.connector.connect(..) >>> cnx.start_transaction(isolation_level='SERIALIZABLE') Raises ProgrammingError when a transaction is already in progress and when ValueError when isolation_level specifies an Unknown level. """ if self.in_transaction: raise ProgrammingError("Transaction already in progress") if isolation_level: level = isolation_level.strip().replace("-", " ").upper() levels = [ "READ UNCOMMITTED", "READ COMMITTED", "REPEATABLE READ", "SERIALIZABLE", ] if level not in levels: raise ValueError(f'Unknown isolation level "{isolation_level}"') self._execute_query(f"SET TRANSACTION ISOLATION LEVEL {level}") if readonly is not None: if self._server_version < (5, 6, 5): raise ValueError( f"MySQL server version {self._server_version} does not " "support this feature" ) if readonly: access_mode = "READ ONLY" else: access_mode = "READ WRITE" self._execute_query(f"SET TRANSACTION {access_mode}") query = "START TRANSACTION" if consistent_snapshot: query += " WITH CONSISTENT SNAPSHOT" self.cmd_query(query) def reset_session( self, user_variables: Optional[Dict[str, Any]] = None, session_variables: Optional[Dict[str, Any]] = None, ) -> None: """Clears the current active session This method resets the session state, if the MySQL server is 5.7.3 or later active session will be reset without re-authenticating. For other server versions session will be reset by re-authenticating. It is possible to provide a sequence of variables and their values to be set after clearing the session. This is possible for both user defined variables and session variables. This method takes two arguments user_variables and session_variables which are dictionaries. Raises OperationalError if not connected, InternalError if there are unread results and InterfaceError on errors. """ if not self.is_connected(): raise OperationalError("MySQL Connection not available") try: self.cmd_reset_connection() except (NotSupportedError, NotImplementedError): if self._compress: raise NotSupportedError( "Reset session is not supported with compression for " "MySQL server version 5.7.2 or earlier" ) from None self.cmd_change_user( self._user, self._password, self._database, self._charset_id, ) if user_variables or session_variables: cur = self.cursor() if user_variables: for key, value in user_variables.items(): cur.execute(f"SET @`{key}` = {value}") if session_variables: for key, value in session_variables.items(): cur.execute(f"SET SESSION `{key}` = {value}") cur.close() def set_converter_class(self, convclass: Optional[Type[MySQLConverter]]) -> None: """ Set the converter class to be used. This should be a class overloading methods and members of conversion.MySQLConverter. """ if convclass and issubclass(convclass, MySQLConverterBase): charset_name = CharacterSet.get_info(self._charset_id)[0] self._converter_class = convclass self.converter = convclass(charset_name, self._use_unicode) self.converter.str_fallback = self._converter_str_fallback else: raise TypeError( "Converter class should be a subclass of conversion.MySQLConverterBase." ) @abstractmethod def get_rows( self, count: Optional[int] = None, binary: bool = False, columns: Optional[List[DescriptionType]] = None, raw: Optional[bool] = None, prep_stmt: Any = None, ) -> Tuple[List[Any], Optional[Mapping[str, Any]]]: """Get all rows returned by the MySQL server""" def cmd_init_db(self, database: str) -> Optional[Mapping[str, Any]]: """Change the current database""" raise NotImplementedError def cmd_query( self, query: Any, raw: bool = False, buffered: bool = False, raw_as_string: bool = False, ) -> Optional[Mapping[str, Any]]: """Send a query to the MySQL server""" raise NotImplementedError def cmd_query_iter( self, statements: Any ) -> Generator[Mapping[str, Any], None, None]: """Send one or more statements to the MySQL server""" raise NotImplementedError def cmd_refresh(self, options: int) -> Optional[Mapping[str, Any]]: """Send the Refresh command to the MySQL server""" raise NotImplementedError def cmd_quit(self) -> Any: """Close the current connection with the server""" raise NotImplementedError def cmd_shutdown( self, shutdown_type: Optional[int] = None ) -> Optional[Mapping[str, Any]]: """Shut down the MySQL Server""" raise NotImplementedError def cmd_statistics(self) -> Optional[Mapping[str, Any]]: """Send the statistics command to the MySQL Server""" raise NotImplementedError @staticmethod def cmd_process_info() -> Any: """Get the process list of the MySQL Server This method is a placeholder to notify that the PROCESS_INFO command is not supported by raising the NotSupportedError. The command "SHOW PROCESSLIST" should be send using the cmd_query()-method or using the INFORMATION_SCHEMA database. Raises NotSupportedError exception """ raise NotSupportedError( "Not implemented. Use SHOW PROCESSLIST or INFORMATION_SCHEMA" ) def cmd_process_kill(self, mysql_pid: int) -> Optional[Mapping[str, Any]]: """Kill a MySQL process""" raise NotImplementedError def cmd_debug(self) -> Optional[Mapping[str, Any]]: """Send the DEBUG command""" raise NotImplementedError def cmd_ping(self) -> Optional[Mapping[str, Any]]: """Send the PING command""" raise NotImplementedError def cmd_change_user( self, username: str = "", password: str = "", database: str = "", charset: int = 45, password1: str = "", password2: str = "", password3: str = "", oci_config_file: str = "", ) -> Optional[Mapping[str, Any]]: """Change the current logged in user""" raise NotImplementedError def cmd_stmt_prepare(self, statement: Any) -> Optional[Mapping[str, Any]]: """Prepare a MySQL statement""" raise NotImplementedError def cmd_stmt_execute( self, statement_id: Any, data: Sequence[Any] = (), parameters: Sequence[Any] = (), flags: int = 0, ) -> Any: """Execute a prepared MySQL statement""" raise NotImplementedError def cmd_stmt_close(self, statement_id: Any) -> Any: """Deallocate a prepared MySQL statement""" raise NotImplementedError def cmd_stmt_send_long_data( self, statement_id: Any, param_id: int, data: BinaryIO ) -> Any: """Send data for a column""" raise NotImplementedError def cmd_stmt_reset(self, statement_id: Any) -> Any: """Reset data for prepared statement sent as long data""" raise NotImplementedError def cmd_reset_connection(self) -> Any: """Resets the session state without re-authenticating""" raise NotImplementedError class MySQLCursorAbstract(ABC): """Abstract cursor class Abstract class defining cursor class with method and members required by the Python Database API Specification v2.0. """ def __init__(self) -> None: """Initialization""" self._description: Optional[List[DescriptionType]] = None self._rowcount: int = -1 self._last_insert_id: Optional[int] = None self._warnings: Optional[List[WarningType]] = None self._warning_count: int = 0 self._executed: Optional[StrOrBytes] = None self._executed_list: List[StrOrBytes] = [] self._stored_results: List[Any] = [] self.arraysize: int = 1 def __enter__(self) -> MySQLCursorAbstract: return self def __exit__( self, exc_type: Type[BaseException], exc_value: BaseException, traceback: TracebackType, ) -> None: self.close() @abstractmethod def callproc(self, procname: str, args: Sequence[Any] = ()) -> Any: """Calls a stored procedure with the given arguments The arguments will be set during this session, meaning they will be called like ___arg where is an enumeration (+1) of the arguments. Coding Example: 1) Defining the Stored Routine in MySQL: CREATE PROCEDURE multiply(IN pFac1 INT, IN pFac2 INT, OUT pProd INT) BEGIN SET pProd := pFac1 * pFac2; END 2) Executing in Python: args = (5,5,0) # 0 is to hold pprod cursor.callproc('multiply', args) print(cursor.fetchone()) Does not return a value, but a result set will be available when the CALL-statement execute successfully. Raises exceptions when something is wrong. """ @abstractmethod def close(self) -> Any: """Close the cursor.""" @abstractmethod def execute( self, operation: Any, params: Union[Sequence[Any], Dict[str, Any]] = (), multi: bool = False, ) -> Any: """Executes the given operation Executes the given operation substituting any markers with the given parameters. For example, getting all rows where id is 5: cursor.execute("SELECT * FROM t1 WHERE id = %s", (5,)) The multi argument should be set to True when executing multiple statements in one operation. If not set and multiple results are found, an InterfaceError will be raised. If warnings where generated, and connection.get_warnings is True, then self._warnings will be a list containing these warnings. Returns an iterator when multi is True, otherwise None. """ @abstractmethod def executemany( self, operation: Any, seq_params: Sequence[Union[Sequence[Any], Dict[str, Any]]] ) -> Any: """Execute the given operation multiple times The executemany() method will execute the operation iterating over the list of parameters in seq_params. Example: Inserting 3 new employees and their phone number data = [ ('Jane','555-001'), ('Joe', '555-001'), ('John', '555-003') ] stmt = "INSERT INTO employees (name, phone) VALUES ('%s','%s')" cursor.executemany(stmt, data) INSERT statements are optimized by batching the data, that is using the MySQL multiple rows syntax. Results are discarded. If they are needed, consider looping over data using the execute() method. """ @abstractmethod def fetchone(self) -> Optional[Sequence[Any]]: """Returns next row of a query result set Returns a tuple or None. """ @abstractmethod def fetchmany(self, size: int = 1) -> List[Sequence[Any]]: """Returns the next set of rows of a query result, returning a list of tuples. When no more rows are available, it returns an empty list. The number of rows returned can be specified using the size argument, which defaults to one """ @abstractmethod def fetchall(self) -> Sequence[Any]: """Returns all rows of a query result set Returns a list of tuples. """ def nextset(self) -> Any: """Not Implemented.""" def setinputsizes(self, sizes: Any) -> Any: """Not Implemented.""" def setoutputsize(self, size: Any, column: Any = None) -> Any: """Not Implemented.""" def reset(self, free: bool = True) -> Any: """Reset the cursor to default""" @property @abstractmethod def description( self, ) -> Optional[List[DescriptionType]]: """Returns description of columns in a result This property returns a list of tuples describing the columns in in a result set. A tuple is described as follows:: (column_name, type, None, None, None, None, null_ok, column_flags) # Addition to PEP-249 specs Returns a list of tuples. """ return self._description @property @abstractmethod def rowcount(self) -> int: """Returns the number of rows produced or affected This property returns the number of rows produced by queries such as a SELECT, or affected rows when executing DML statements like INSERT or UPDATE. Note that for non-buffered cursors it is impossible to know the number of rows produced before having fetched them all. For those, the number of rows will be -1 right after execution, and incremented when fetching rows. Returns an integer. """ return self._rowcount @property def lastrowid(self) -> Optional[int]: """Returns the value generated for an AUTO_INCREMENT column Returns the value generated for an AUTO_INCREMENT column by the previous INSERT or UPDATE statement or None when there is no such value available. Returns a long value or None. """ return self._last_insert_id @property def warnings(self) -> Optional[List[WarningType]]: """Return warnings.""" return self._warnings @property def warning_count(self) -> int: """Returns the number of warnings This property returns the number of warnings generated by the previously executed operation. Returns an integer value. """ return self._warning_count def fetchwarnings(self) -> Optional[List[WarningType]]: """Returns Warnings.""" return self._warnings def get_attributes(self) -> Optional[List[Tuple[Any, Any]]]: """Get the added query attributes so far.""" if hasattr(self, "_cnx"): return self._cnx.query_attrs if hasattr(self, "_connection"): return self._connection.query_attrs return None def add_attribute(self, name: str, value: Any) -> None: """Add a query attribute and his value.""" if not isinstance(name, str): raise ProgrammingError("Parameter `name` must be a string type") if value is not None and not isinstance(value, MYSQL_PY_TYPES): raise ProgrammingError( f"Object {value} cannot be converted to a MySQL type" ) if hasattr(self, "_cnx"): self._cnx.query_attrs_append((name, value)) elif hasattr(self, "_connection"): self._connection.query_attrs_append((name, value)) def clear_attributes(self) -> None: """Remove all the query attributes.""" if hasattr(self, "_cnx"): self._cnx.query_attrs_clear() elif hasattr(self, "_connection"): self._connection.query_attrs_clear()