# Copyright (c) 2017, 2022, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

"""This module contains helper functions."""

import binascii
import decimal
import functools
import inspect
import warnings

from typing import Any, Callable, List, Optional, Union

from .constants import TLS_CIPHER_SUITES, TLS_VERSIONS
from .errors import InterfaceError
from .types import EscapeTypes, StrOrBytes

BYTE_TYPES = (bytearray, bytes)
NUMERIC_TYPES = (int, float, decimal.Decimal)


def encode_to_bytes(value: StrOrBytes, encoding: str = "utf-8") -> bytes:
    """Returns an encoded version of the string as a bytes object.

    Args:
        value (str, bytes or bytearray): The value to be encoded. ``bytes``
            are returned unchanged; a ``bytearray`` is copied into ``bytes``.
        encoding (str): The encoding used for ``str`` values.

    Returns:
        bytes: The encoded version of the string as a bytes object.
    """
    if isinstance(value, bytes):
        return value
    if isinstance(value, bytearray):
        # bytearray has no ``encode`` method; copy it so the result is
        # always an immutable ``bytes`` object, as the signature promises.
        return bytes(value)
    return value.encode(encoding)


def decode_from_bytes(value: StrOrBytes, encoding: str = "utf-8") -> str:
    """Returns a string decoded from the given bytes.

    Args:
        value (str, bytes or bytearray): The value to be decoded. ``str``
            values are returned unchanged.
        encoding (str): The encoding.

    Returns:
        str: The value decoded from bytes.
    """
    # Check against BYTE_TYPES so a bytearray is decoded too, instead of
    # leaking out of a function declared to return ``str``.
    return value.decode(encoding) if isinstance(value, BYTE_TYPES) else value


def get_item_or_attr(obj: object, key: str) -> Any:
    """Get item from dictionary or attribute from object.

    Args:
        obj (object): Dictionary or object.
        key (str): Key.

    Returns:
        object: The object for the provided key.
    """
    return obj[key] if isinstance(obj, dict) else getattr(obj, key)


def escape(*args: EscapeTypes) -> Union[EscapeTypes, List[EscapeTypes]]:
    """Escapes special characters as they are expected to be when MySQL
    receives them.
    As found in MySQL source mysys/charset.c

    Args:
        *args: One or more values to be escaped. ``None`` and numeric values
            (int, float, Decimal) are returned unchanged; ``bytes`` and
            ``bytearray`` values are escaped byte-wise; any other value is
            treated as a ``str``.

    Returns:
        The escaped value when a single argument is given, or a list with
        every argument escaped (in order) when more than one is given.
    """

    def _escape(value: EscapeTypes) -> EscapeTypes:
        """Escapes special characters."""
        if value is None:
            return value
        if isinstance(value, NUMERIC_TYPES):
            return value
        # Backslash must be escaped first so the backslashes introduced by
        # the following replacements are not doubled.
        if isinstance(value, (bytes, bytearray)):
            value = value.replace(b"\\", b"\\\\")
            value = value.replace(b"\n", b"\\n")
            value = value.replace(b"\r", b"\\r")
            value = value.replace(b"\047", b"\134\047")  # single quotes
            value = value.replace(b"\042", b"\134\042")  # double quotes
            value = value.replace(b"\032", b"\134\032")  # for Win32
        else:
            value = value.replace("\\", "\\\\")
            value = value.replace("\n", "\\n")
            value = value.replace("\r", "\\r")
            value = value.replace("\047", "\134\047")  # single quotes
            value = value.replace("\042", "\134\042")  # double quotes
            value = value.replace("\032", "\134\032")  # for Win32
        return value

    if len(args) > 1:
        return [_escape(arg) for arg in args]
    return _escape(args[0])


def quote_identifier(identifier: str, sql_mode: str = "") -> str:
    """Quote the given identifier with backticks, converting backticks (`)
    in the identifier name with the correct escape sequence (``) unless the
    identifier is quoted (") as in sql_mode set to ANSI_QUOTES.

    Args:
        identifier (str): Identifier to quote.
        sql_mode (str): The SQL mode. Either a single mode name or a
            comma-separated list of modes (the format of ``@@sql_mode``).
            Double quotes are used when it contains ``ANSI_QUOTES`` or the
            ``ANSI`` combination mode, which implies ``ANSI_QUOTES``.

    Returns:
        str: Returns string with the identifier quoted with backticks, or
        with double quotes when ANSI_QUOTES is active.
    """
    # sql_mode is usually a list such as "ANSI_QUOTES,STRICT_TRANS_TABLES",
    # so test for membership rather than comparing the whole string.
    modes = {mode.strip().upper() for mode in (sql_mode or "").split(",")}
    if modes & {"ANSI_QUOTES", "ANSI"}:
        quoted = identifier.replace('"', '""')
        return f'"{quoted}"'
    quoted = identifier.replace("`", "``")
    return f"`{quoted}`"


def deprecated(version: Optional[str] = None, reason: Optional[str] = None) -> Callable:
    """This is a decorator used to mark functions as deprecated.

    Args:
        version (Optional[string]): Version when was deprecated.
        reason (Optional[string]): Reason or extra information to be shown.

    Returns:
        Callable: A decorator used to mark functions as deprecated.

    Usage:

    .. code-block:: python

       from mysqlx.helpers import deprecated

       @deprecated('8.0.12', 'Please use other_function() instead')
       def deprecated_function(x, y):
           return x + y
    """

    def decorate(func: Callable) -> Callable:
        """Decorate function."""

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Wrapper function.

            Emits a ``DeprecationWarning`` attributed to the caller of the
            decorated function, then calls it.

            Args:
                *args: Variable length argument list.
                **kwargs: Arbitrary keyword arguments.

            Returns:
                Whatever the decorated function returns.
            """
            message = [f"'{func.__name__}' is deprecated"]
            if version:
                message.append(f" since version {version}")
            if reason:
                message.append(f". {reason}")
            text = "".join(message)

            frame = inspect.currentframe()
            caller = frame.f_back if frame is not None else None
            try:
                if caller is None:
                    # currentframe() may return None on Python implementations
                    # without stack frame support; fall back to the regular
                    # API, pointing at the caller of the decorated function.
                    warnings.warn(text, category=DeprecationWarning, stacklevel=2)
                else:
                    warnings.warn_explicit(
                        text,
                        category=DeprecationWarning,
                        filename=inspect.getfile(caller.f_code),
                        lineno=caller.f_lineno,
                    )
            finally:
                # Drop frame references promptly to avoid reference cycles.
                del frame, caller
            return func(*args, **kwargs)

        return wrapper

    return decorate


def iani_to_openssl_cs_name(
    tls_version: str, cipher_suites_names: List[str]
) -> List[str]:
    """Translates a cipher suites names list; from IANA names to OpenSSL names.

    Names that already contain a hyphen are assumed to be OpenSSL names and
    are kept as given.

    Args:
        tls_version (str): The TLS version to look at for a translation. The
            cipher suites of this version and of every version listed before
            it in ``TLS_VERSIONS`` are considered.
        cipher_suites_names (list): A list of cipher suites names.

    Returns:
        List[str]: List of translated names.

    Raises:
        InterfaceError: If a name is neither an OpenSSL name nor a known
            cipher suite for the considered TLS versions.
    """
    translated_names = []
    cipher_suites = {}

    # Merge the cipher suites of the given TLS version and of the previous
    # TLS versions (those listed before it in TLS_VERSIONS).
    for version in TLS_VERSIONS[: TLS_VERSIONS.index(tls_version) + 1]:
        cipher_suites.update(TLS_CIPHER_SUITES[version])

    for name in cipher_suites_names:
        if "-" in name:
            translated_names.append(name)
        elif name in cipher_suites:
            translated_names.append(cipher_suites[name])
        else:
            raise InterfaceError(
                f"The '{name}' in cipher suites is not a valid cipher suite"
            )
    return translated_names


def hexlify(data: bytes) -> str:
    """Return the hexadecimal representation of the binary data.

    Args:
        data (bytes): The binary data.

    Returns:
        str: The decoded hexadecimal representation of data.
    """
    return binascii.hexlify(data).decode("utf-8")