import logging import os import socket import typing from collections import deque from copy import deepcopy from datetime import datetime as Datetime from datetime import timedelta as Timedelta from decimal import Decimal from hashlib import md5 from itertools import count from os import getpid from struct import pack from typing import TYPE_CHECKING from warnings import warn from packaging import version from scramp import ScramClient # type: ignore from redshift_connector.config import ( DEFAULT_PROTOCOL_VERSION, ClientProtocolVersion, DbApiParamstyle, _client_encoding, max_int2, max_int4, max_int8, min_int2, min_int4, min_int8, pg_array_types, pg_to_py_encodings, ) from redshift_connector.cursor import Cursor from redshift_connector.error import ( ArrayContentNotHomogenousError, ArrayContentNotSupportedError, DatabaseError, Error, IntegrityError, InterfaceError, InternalError, NotSupportedError, OperationalError, ProgrammingError, Warning, ) from redshift_connector.utils import ( FC_BINARY, FC_TEXT, NULL, NULL_BYTE, DriverInfo, array_check_dimensions, array_dim_lengths, array_find_first_element, array_flatten, array_has_null, array_recv_binary, array_recv_text, bh_unpack, cccc_unpack, ci_unpack, date_in, date_recv_binary, float_array_recv, geographyhex_recv, h_pack, h_unpack, i_pack, i_unpack, ihihih_unpack, ii_pack, iii_pack, int_array_recv, make_divider_block, numeric_in, numeric_in_binary, numeric_to_float_binary, numeric_to_float_in, ) from redshift_connector.utils import py_types as PY_TYPES from redshift_connector.utils import q_pack from redshift_connector.utils import redshift_types as REDSHIFT_TYPES from redshift_connector.utils import ( text_recv, time_in, time_recv_binary, timetz_in, timetz_recv_binary, varbytehex_recv, walk_array, ) from redshift_connector.utils.oids import RedshiftOID if TYPE_CHECKING: from ssl import SSLSocket # Copyright (c) 2007-2009, Mathieu Fenniak # Copyright (c) The Contributors # All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright notice, # this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # * The name of the author may not be used to endorse or promote products # derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
__author__ = "Mathieu Fenniak"

_logger: logging.Logger = logging.getLogger(__name__)

ZERO: Timedelta = Timedelta(0)
BINARY: type = bytes


# The purpose of this function is to change the placeholders of the original query into $1, $2, ...
# so that they can be identified by the database.
# example: INSERT INTO book (title) VALUES (:title) -> INSERT INTO book (title) VALUES ($1)
# It also returns the function make_args().
def convert_paramstyle(style: str, query: str) -> typing.Tuple[str, typing.Any]:
    """
    Rewrites the DB-API placeholders in ``query`` into positional ``$1, $2, ...`` placeholders.

    The query is scanned character by character with a small state machine, so
    placeholder-looking text inside single-quoted strings, ``E'...'`` escape
    strings, double-quoted identifiers, ``--`` comments and ``/* ... */``
    comments is copied through unchanged.

    Parameters
    ----------
    style : str
        One of the ``DbApiParamstyle`` values (qmark, numeric, named, format or
        pyformat) describing how ``query`` writes its parameters.
    query : str
        The SQL text to convert.

    Returns
    -------
    Tuple[str, Callable]
        The rewritten query and a ``make_args(vals)`` function. For numeric,
        qmark and format style (including a pyformat query that used a
        positional ``%s``), ``make_args`` returns ``vals`` unchanged. For named
        and pyformat style it returns ``tuple(vals[name] ...)`` ordered by the
        ``$n`` position assigned to each distinct name.

    Raises
    ------
    InterfaceError
        If a format/pyformat query contains a ``%`` outside quotes/comments that
        is not part of ``%s``, ``%%`` or (pyformat only) ``%(name)s``.
    """
    # I don't see any way to avoid scanning the query string char by char,
    # so we might as well take that careful approach and create a
    # state-based scanner. We'll use int variables for the state.
    OUTSIDE: int = 0  # outside quoted string
    INSIDE_SQ: int = 1  # inside single-quote string '...'
    INSIDE_QI: int = 2  # inside quoted identifier "..."
    INSIDE_ES: int = 3  # inside escaped single-quote string, E'...'
    INSIDE_PN: int = 4  # inside parameter name e.g. :name
    INSIDE_CO: int = 5  # inside inline comment e.g. --
    INSIDE_MC: int = 6  # inside multiline comment e.g. /*

    in_quote_escape: bool = False  # the previous ' inside '...' was the first half of ''
    in_param_escape: bool = False  # the previous % was the first half of %%
    in_backslash_escape: bool = False  # the previous char inside E'...' was an unescaped backslash
    placeholders: typing.List[str] = []
    output_query: typing.List[str] = []
    param_idx: typing.Iterator[str] = map(lambda x: "$" + str(x), count(1))
    state: int = OUTSIDE
    prev_c: typing.Optional[str] = None

    def close_placeholder() -> None:
        # Emit the $n for the name just collected in placeholders[-1]. A name that
        # was already seen reuses its earlier position (and is dropped from the
        # list); a new name keeps its slot, whose position is len(placeholders).
        try:
            pidx: int = placeholders.index(placeholders[-1], 0, -1)
        except ValueError:
            output_query.append("$" + str(len(placeholders)))
        else:
            output_query.append("$" + str(pidx + 1))
            del placeholders[-1]

    for i, c in enumerate(query):
        next_c: typing.Optional[str] = query[i + 1] if i + 1 < len(query) else None

        if state == OUTSIDE:
            if c == "'":
                output_query.append(c)
                if prev_c == "E":
                    state = INSIDE_ES
                else:
                    state = INSIDE_SQ
            elif c == '"':
                output_query.append(c)
                state = INSIDE_QI
            elif c == "-":
                output_query.append(c)
                if prev_c == "-":
                    state = INSIDE_CO
            elif c == "*":
                output_query.append(c)
                if prev_c == "/":
                    state = INSIDE_MC
            elif style == DbApiParamstyle.QMARK.value and c == "?":
                output_query.append(next(param_idx))
            elif (
                style == DbApiParamstyle.NUMERIC.value
                and c == ":"
                and next_c is not None
                and next_c not in ":="
                and prev_c != ":"
            ):
                # Treat : as the beginning of a parameter only if it is a lone :
                # (not part of a :: type cast such as sum(x)::float, nor :=).
                # A : at the very end of the query has no parameter after it and
                # is copied through as a literal.
                output_query.append("$")
            elif (
                style == DbApiParamstyle.NAMED.value
                and c == ":"
                and next_c is not None
                and next_c not in ":="
                and prev_c != ":"
            ):
                # Same logic for : as in numeric parameters
                state = INSIDE_PN
                placeholders.append("")
            elif style in (DbApiParamstyle.FORMAT.value, DbApiParamstyle.PYFORMAT.value) and c == "%":
                if in_param_escape:
                    # second half of %% -> a literal %
                    in_param_escape = False
                    output_query.append(c)
                elif style == DbApiParamstyle.PYFORMAT.value and next_c == "(":
                    # start of a named %(name)s parameter
                    state = INSIDE_PN
                    placeholders.append("")
                elif next_c == "%":
                    # %% escape: the literal % is emitted on the next character.
                    # This deliberately does not change the style, so a pyformat
                    # query may still use %(name)s after a %% escape.
                    in_param_escape = True
                elif next_c == "s":
                    # A positional %s means the parameters are passed as a sequence,
                    # so from here on the query is handled as format style.
                    style = DbApiParamstyle.FORMAT.value
                    state = INSIDE_PN
                    output_query.append(next(param_idx))
                else:
                    raise InterfaceError("Only %s and %% are supported in the query.")
            else:
                output_query.append(c)

        elif state == INSIDE_SQ:
            if c == "'":
                if in_quote_escape:
                    in_quote_escape = False
                elif next_c == "'":
                    # '' is an escaped quote, not the end of the string
                    in_quote_escape = True
                else:
                    state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_QI:
            if c == '"':
                state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_ES:
            # Backslash escapes are tracked explicitly: looking only at the
            # previous character would treat the quote in E'...\\' as escaped
            # and never leave the string.
            if in_backslash_escape:
                in_backslash_escape = False
            elif c == "\\":
                in_backslash_escape = True
            elif c == "'":
                state = OUTSIDE
            output_query.append(c)

        elif state == INSIDE_PN:
            if style == DbApiParamstyle.NAMED.value:
                placeholders[-1] += c
                # the name ends at the first character that is not alphanumeric or _
                if next_c is None or (not next_c.isalnum() and next_c != "_"):
                    state = OUTSIDE
                    close_placeholder()
            elif style == DbApiParamstyle.PYFORMAT.value:
                if prev_c == ")" and c == "s":
                    state = OUTSIDE
                    close_placeholder()
                elif c in "()":
                    pass
                else:
                    placeholders[-1] += c
            elif style == DbApiParamstyle.FORMAT.value:
                # swallow the "s" of %s; its $n was already emitted
                state = OUTSIDE

        elif state == INSIDE_CO:
            output_query.append(c)
            if c == "\n":
                state = OUTSIDE

        elif state == INSIDE_MC:
            output_query.append(c)
            if c == "/" and prev_c == "*":
                state = OUTSIDE

        prev_c = c

    if style in (DbApiParamstyle.NUMERIC.value, DbApiParamstyle.QMARK.value, DbApiParamstyle.FORMAT.value):

        def make_args(vals):
            return vals

    else:

        def make_args(vals):
            return tuple(vals[p] for p in placeholders)

    return "".join(output_query), make_args


# Message codes
# All communication happens through a stream of messages:
# the driver sends one or more messages to the database,
# and the database responds with one or more messages.
# The first byte of a message specifies the type of the message.
NOTICE_RESPONSE: bytes = b"N"
AUTHENTICATION_REQUEST: bytes = b"R"
PARAMETER_STATUS: bytes = b"S"
BACKEND_KEY_DATA: bytes = b"K"
READY_FOR_QUERY: bytes = b"Z"
ROW_DESCRIPTION: bytes = b"T"
ERROR_RESPONSE: bytes = b"E"
DATA_ROW: bytes = b"D"
COMMAND_COMPLETE: bytes = b"C"
PARSE_COMPLETE: bytes = b"1"
BIND_COMPLETE: bytes = b"2"
CLOSE_COMPLETE: bytes = b"3"
PORTAL_SUSPENDED: bytes = b"s"
NO_DATA: bytes = b"n"
PARAMETER_DESCRIPTION: bytes = b"t"
NOTIFICATION_RESPONSE: bytes = b"A"
COPY_DONE: bytes = b"c"
COPY_DATA: bytes = b"d"
COPY_IN_RESPONSE: bytes = b"G"
COPY_OUT_RESPONSE: bytes = b"H"
EMPTY_QUERY_RESPONSE: bytes = b"I"

BIND: bytes = b"B"
PARSE: bytes = b"P"
EXECUTE: bytes = b"E"
FLUSH: bytes = b"H"
SYNC: bytes = b"S"
PASSWORD: bytes = b"p"
DESCRIBE: bytes = b"D"
TERMINATE: bytes = b"X"
CLOSE: bytes = b"C"


# This defines the format of a message:
# the first byte, the code, is the type of the message,
# followed by 4 bytes giving the length of the rest of the message (including those 4 bytes),
# followed by the actual data we want to send.
def create_message(code: bytes, data: bytes = b"") -> bytes:
    """
    Frames an outgoing message as ``code`` + length + ``data``.

    The length field written by ``i_pack`` is ``len(data) + 4``, i.e. it counts
    itself but not the one-byte ``code``.
    """
    return code + typing.cast(bytes, i_pack(len(data) + 4)) + data


FLUSH_MSG: bytes = create_message(FLUSH)
SYNC_MSG: bytes = create_message(SYNC)
TERMINATE_MSG: bytes = create_message(TERMINATE)
COPY_DONE_MSG: bytes = create_message(COPY_DONE)
# Empty portal name followed by a row limit of 0 (presumably "no limit", as in the
# PostgreSQL protocol — confirm for Redshift).
EXECUTE_MSG: bytes = create_message(EXECUTE, NULL_BYTE + i_pack(0))

# DESCRIBE constants
STATEMENT: bytes = b"S"
PORTAL: bytes = b"P" # ErrorResponse codes RESPONSE_SEVERITY: str = "S" # always present RESPONSE_SEVERITY = "V" # always present RESPONSE_CODE: str = "C" # always present RESPONSE_MSG: str = "M" # always present RESPONSE_DETAIL: str = "D" RESPONSE_HINT: str = "H" RESPONSE_POSITION: str = "P" RESPONSE__POSITION: str = "p" RESPONSE__QUERY: str = "q" RESPONSE_WHERE: str = "W" RESPONSE_FILE: str = "F" RESPONSE_LINE: str = "L" RESPONSE_ROUTINE: str = "R" IDLE: bytes = b"I" IDLE_IN_TRANSACTION: bytes = b"T" IDLE_IN_FAILED_TRANSACTION: bytes = b"E" arr_trans: typing.Mapping[int, typing.Optional[str]] = dict(zip(map(ord, "[] 'u"), ["{", "}", None, None, None])) class Connection: # DBAPI Extension: supply exceptions as attributes on the connection Warning = property(lambda self: self._getError(Warning)) Error = property(lambda self: self._getError(Error)) InterfaceError = property(lambda self: self._getError(InterfaceError)) DatabaseError = property(lambda self: self._getError(DatabaseError)) OperationalError = property(lambda self: self._getError(OperationalError)) IntegrityError = property(lambda self: self._getError(IntegrityError)) InternalError = property(lambda self: self._getError(InternalError)) ProgrammingError = property(lambda self: self._getError(ProgrammingError)) NotSupportedError = property(lambda self: self._getError(NotSupportedError)) def __enter__(self: "Connection") -> "Connection": return self def __exit__(self: "Connection", exc_type, exc_value, traceback) -> None: self.close() def _getError(self: "Connection", error): warn("DB-API extension connection.%s used" % error.__name__, stacklevel=3) return error @property def client_os_version(self: "Connection") -> str: from platform import platform as CLIENT_PLATFORM try: os_version: str = CLIENT_PLATFORM() except: os_version = "unknown" return os_version @staticmethod def __get_host_address_info(host: str, port: int): """ Returns IPv4 address and port given a host name and port """ # 
https://docs.python.org/3/library/socket.html#socket.getaddrinfo response = socket.getaddrinfo(host=host, port=port, family=socket.AF_INET) _logger.debug("getaddrinfo response {}".format(response)) if not response: raise InterfaceError("Unable to determine ip for host {} port {}".format(host, port)) return response[0][4] def __init__( self: "Connection", user: str, password: str, database: str, host: str = "localhost", port: int = 5439, source_address: typing.Optional[str] = None, unix_sock: typing.Optional[str] = None, ssl: bool = True, sslmode: str = "verify-ca", timeout: typing.Optional[int] = None, max_prepared_statements: int = 1000, tcp_keepalive: typing.Optional[bool] = True, application_name: typing.Optional[str] = None, replication: typing.Optional[str] = None, client_protocol_version: int = DEFAULT_PROTOCOL_VERSION, database_metadata_current_db_only: bool = True, credentials_provider: typing.Optional[str] = None, provider_name: typing.Optional[str] = None, web_identity_token: typing.Optional[str] = None, numeric_to_float: bool = False, ): """ Creates a :class:`Connection` to an Amazon Redshift cluster. For more information on establishing a connection to an Amazon Redshift cluster using `federated API access `_ see our examples page. This is the underlying :class:`Connection` constructor called from :func:`redshift_connector.connect`. Parameters ---------- user : str The username to use for authentication with the Amazon Redshift cluster. password : str The password to use for authentication with the Amazon Redshift cluster. database : str The name of the database instance to connect to. host : str The hostname of the Amazon Redshift cluster. port : int The port number of the Amazon Redshift cluster. Default value is 5439. source_address : Optional[str] unix_sock : Optional[str] ssl : bool Is SSL enabled. Default value is ``True``. SSL must be enabled when authenticating using IAM. 
sslmode : str The security of the connection to the Amazon Redshift cluster. 'verify-ca' and 'verify-full' are supported. timeout : Optional[int] The number of seconds before the connection to the server will timeout. By default there is no timeout. max_prepared_statements : int tcp_keepalive : Optional[bool] Is `TCP keepalive `_ used. The default value is ``True``. application_name : Optional[str] Sets the application name. The default value is None. replication : Optional[str] Used to run in `streaming replication mode `_. client_protocol_version : int The requested server protocol version. The default value is 1 representing `EXTENDED_RESULT_METADATA`. If the requested server protocol cannot be satisfied, a warning will be displayed to the user. database_metadata_current_db_only : bool Is `datashare `_ disabled. Default value is True, implying datasharing will not be used. credentials_provider : Optional[str] The class-path of the IdP plugin used for authentication with Amazon Redshift. provider_name : Optional[str] The name of the Redshift Native Auth Provider. web_identity_token: Optional[str] A web identity token used for authentication via Redshift Native IDP Integration numeric_to_float: bool Specifies if NUMERIC datatype values will be converted from ``decimal.Decimal`` to ``float``. By default NUMERIC values are received as ``decimal.Decimal``. """ self.merge_socket_read = True _client_encoding = "utf8" self._commands_with_count: typing.Tuple[bytes, ...] 
= ( b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH", b"COPY", b"SELECT", ) self.notifications: deque = deque(maxlen=100) self.notices: deque = deque(maxlen=100) self.parameter_statuses: deque = deque(maxlen=100) self.max_prepared_statements: int = int(max_prepared_statements) self._run_cursor: Cursor = Cursor(self, paramstyle=DbApiParamstyle.NAMED.value) self._client_protocol_version: int = client_protocol_version self._database = database self.py_types = deepcopy(PY_TYPES) self.redshift_types = deepcopy(REDSHIFT_TYPES) self._database_metadata_current_db_only: bool = database_metadata_current_db_only self.numeric_to_float: bool = numeric_to_float # based on _client_protocol_version value, we must use different conversion functions # for receiving some datatypes self._enable_protocol_based_conversion_funcs() self.web_identity_token = web_identity_token if user is None: raise InterfaceError("The 'user' connection parameter cannot be None") redshift_native_auth: bool = False if application_name is None or application_name == "": def get_calling_module() -> str: import inspect module_name: str = "" stack: typing.List[inspect.FrameInfo] = inspect.stack() try: # get_calling_module -> init -> connect -> init -> calling module start: int = min(4, len(stack) - 1) parent = stack[start][0] calling_module = inspect.getmodule(parent) if calling_module: module_name = calling_module.__name__ except: pass finally: del parent del stack return module_name application_name = get_calling_module() init_params: typing.Dict[str, typing.Optional[typing.Union[str, bytes]]] = { "user": "", "database": database, "application_name": application_name, "replication": replication, "client_protocol_version": str(self._client_protocol_version), "driver_version": DriverInfo.driver_full_name(), "os_version": self.client_os_version, } if credentials_provider: init_params["plugin_name"] = credentials_provider if credentials_provider.split(".")[-1] in ( "BasicJwtCredentialsProvider", 
"BrowserAzureOAuth2CredentialsProvider", ): redshift_native_auth = True init_params["idp_type"] = "AzureAD" if provider_name: init_params["provider_name"] = provider_name if not redshift_native_auth or user: init_params["user"] = user _logger.debug(make_divider_block()) _logger.debug("Establishing a connection") _logger.debug(init_params) _logger.debug(make_divider_block()) for k, v in tuple(init_params.items()): if isinstance(v, str): init_params[k] = v.encode("utf8") elif v is None: del init_params[k] elif not isinstance(v, (bytes, bytearray)): raise InterfaceError("The parameter " + k + " can't be of type " + str(type(v)) + ".") if "user" in init_params: self.user: bytes = typing.cast(bytes, init_params["user"]) else: self.user = b"" if isinstance(password, str): self.password: bytes = password.encode("utf8") else: self.password = password self.autocommit: bool = False self._xid = None self._caches: typing.Dict = {} # Create the TCP/Ip socket and connect to specific database # if there already has a socket, it will not create new connection when run connect again try: if unix_sock is None and host is not None: self._usock: typing.Union[socket.socket, "SSLSocket"] = socket.socket( socket.AF_INET, socket.SOCK_STREAM ) if source_address is not None: self._usock.bind((source_address, 0)) elif unix_sock is not None: if not hasattr(socket, "AF_UNIX"): raise InterfaceError("attempt to connect to unix socket on unsupported " "platform") self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) else: raise ProgrammingError("one of host or unix_sock must be provided") if timeout is not None: self._usock.settimeout(timeout) if unix_sock is None and host is not None: hostport: typing.Tuple[str, int] = Connection.__get_host_address_info(host, port) _logger.debug( "Attempting to create connection socket with address {} {}".format(hostport[0], str(hostport[1])) ) self._usock.connect(hostport) elif unix_sock is not None: self._usock.connect(unix_sock) # For Redshift, we 
the default ssl approve is True # create ssl connection with Redshift CA certificates and check the hostname if ssl is True: try: from ssl import CERT_REQUIRED, SSLContext # ssl_context = ssl.create_default_context() path = os.path.abspath(__file__) if os.name == "nt": path = "\\".join(path.split("\\")[:-1]) + "\\files\\redshift-ca-bundle.crt" else: path = "/".join(path.split("/")[:-1]) + "/files/redshift-ca-bundle.crt" ssl_context: SSLContext = SSLContext() ssl_context.verify_mode = CERT_REQUIRED ssl_context.load_default_certs() ssl_context.load_verify_locations(path) # Int32(8) - Message length, including self. # Int32(80877103) - The SSL request code. self._usock.sendall(ii_pack(8, 80877103)) resp: bytes = self._usock.recv(1) if resp != b"S": _logger.debug( "Server response code when attempting to establish ssl connection: {!r}".format(resp) ) raise InterfaceError("Server refuses SSL") if sslmode == "verify-ca": self._usock = ssl_context.wrap_socket(self._usock) elif sslmode == "verify-full": ssl_context.check_hostname = True self._usock = ssl_context.wrap_socket(self._usock, server_hostname=host) except ImportError: raise InterfaceError("SSL required but ssl module not available in " "this python installation") self._sock: typing.Optional[typing.BinaryIO] = self._usock.makefile(mode="rwb") if tcp_keepalive: self._usock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) except socket.error as e: self._usock.close() raise InterfaceError("communication error", e) self._flush: typing.Callable = self._sock.flush self._read: typing.Callable = self._sock.read self._write: typing.Callable = self._sock.write self._backend_key_data: typing.Optional[bytes] = None trans_tab = dict(zip(map(ord, "{}"), "[]")) glbls = {"Decimal": Decimal} self.inspect_funcs: typing.Dict[type, typing.Callable] = { Datetime: self.inspect_datetime, list: self.array_inspect, tuple: self.array_inspect, int: self.inspect_int, } # it's a dictionary whose key is type of message, # value is the 
corresponding function to process message self.message_types: typing.Dict[bytes, typing.Callable] = { NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE, AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST, PARAMETER_STATUS: self.handle_PARAMETER_STATUS, BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA, READY_FOR_QUERY: self.handle_READY_FOR_QUERY, ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION, ERROR_RESPONSE: self.handle_ERROR_RESPONSE, EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE, DATA_ROW: self.handle_DATA_ROW, COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE, PARSE_COMPLETE: self.handle_PARSE_COMPLETE, BIND_COMPLETE: self.handle_BIND_COMPLETE, CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE, PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED, NO_DATA: self.handle_NO_DATA, PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION, NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE, COPY_DONE: self.handle_COPY_DONE, COPY_DATA: self.handle_COPY_DATA, COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE, COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE, } # Int32 - Message length, including self. # Int32(196608) - Protocol version number. Version 3.0. 
# Any number of key/value pairs, terminated by a zero byte: # String - A parameter name (user, database, or options) # String - Parameter value # Conduct start-up communication with database # Message's first part is the protocol version - Int32(196608) protocol: int = 196608 val: bytearray = bytearray(i_pack(protocol)) # Message include parameters name and value (user, database, application_name, replication) for k, v in init_params.items(): val.extend(k.encode("ascii") + NULL_BYTE + typing.cast(bytes, v) + NULL_BYTE) val.append(0) # Use write and flush function to write the content of the buffer # and then send the message to the database self._write(i_pack(len(val) + 4)) self._write(val) self._flush() self._cursor: Cursor = self.cursor() code = None self.error: typing.Optional[Exception] = None _logger.debug("Sending start-up message") # When driver send the start-up message to database, DB will respond multi messages to driver # whose format is same with the message that driver send to DB. while code not in (READY_FOR_QUERY, ERROR_RESPONSE): # Thus use a loop to process each message # Each time will read 5 bytes, the first byte, the code, inform the type of message # following 4 bytes inform the message's length # then can use this length to minus 4 to get the real data. 
code, data_len = ci_unpack(self._read(5)) self.message_types[code](self._read(data_len - 4), None) if self.error is not None: raise self.error # if we didn't receive a server_protocol_version from the server, default to # using BASE_SERVER as the server is likely lacking this functionality due to # being out of date if ( self._client_protocol_version > ClientProtocolVersion.BASE_SERVER and not (b"server_protocol_version", str(self._client_protocol_version).encode()) in self.parameter_statuses ): _logger.debug("Server_protocol_version not received from server") self._client_protocol_version = ClientProtocolVersion.BASE_SERVER self._enable_protocol_based_conversion_funcs() self.in_transaction = False def _enable_protocol_based_conversion_funcs(self: "Connection"): if self._client_protocol_version >= ClientProtocolVersion.BINARY.value: self.redshift_types[RedshiftOID.NUMERIC] = (FC_BINARY, numeric_in_binary) self.redshift_types[RedshiftOID.DATE] = (FC_BINARY, date_recv_binary) self.redshift_types[RedshiftOID.GEOGRAPHY] = (FC_BINARY, geographyhex_recv) # GEOGRAPHY self.redshift_types[RedshiftOID.TIME] = (FC_BINARY, time_recv_binary) self.redshift_types[RedshiftOID.TIMETZ] = (FC_BINARY, timetz_recv_binary) self.redshift_types[RedshiftOID.CHAR_ARRAY] = (FC_BINARY, array_recv_binary) # CHAR[] self.redshift_types[RedshiftOID.SMALLINT_ARRAY] = (FC_BINARY, array_recv_binary) # INT2[] self.redshift_types[RedshiftOID.INTEGER_ARRAY] = (FC_BINARY, array_recv_binary) # INT4[] self.redshift_types[RedshiftOID.TEXT_ARRAY] = (FC_BINARY, array_recv_binary) # TEXT[] self.redshift_types[RedshiftOID.VARCHAR_ARRAY] = (FC_BINARY, array_recv_binary) # VARCHAR[] self.redshift_types[RedshiftOID.REAL_ARRAY] = (FC_BINARY, array_recv_binary) # FLOAT4[] self.redshift_types[RedshiftOID.OID_ARRAY] = (FC_BINARY, array_recv_binary) # OID[] self.redshift_types[RedshiftOID.ACLITEM_ARRAY] = (FC_BINARY, array_recv_binary) # ACLITEM[] self.redshift_types[RedshiftOID.VARBYTE] = (FC_TEXT, text_recv) # 
VARBYTE if self.numeric_to_float: self.redshift_types[RedshiftOID.NUMERIC] = (FC_BINARY, numeric_to_float_binary) else: # text protocol self.redshift_types[RedshiftOID.NUMERIC] = (FC_TEXT, numeric_in) self.redshift_types[RedshiftOID.TIME] = (FC_TEXT, time_in) self.redshift_types[RedshiftOID.DATE] = (FC_TEXT, date_in) self.redshift_types[RedshiftOID.GEOGRAPHY] = (FC_TEXT, text_recv) # GEOGRAPHY self.redshift_types[RedshiftOID.TIMETZ] = (FC_BINARY, timetz_recv_binary) self.redshift_types[RedshiftOID.CHAR_ARRAY] = (FC_TEXT, array_recv_text) # CHAR[] self.redshift_types[RedshiftOID.SMALLINT_ARRAY] = (FC_TEXT, int_array_recv) # INT2[] self.redshift_types[RedshiftOID.INTEGER_ARRAY] = (FC_TEXT, int_array_recv) # INT4[] self.redshift_types[RedshiftOID.TEXT_ARRAY] = (FC_TEXT, array_recv_text) # TEXT[] self.redshift_types[RedshiftOID.VARCHAR_ARRAY] = (FC_TEXT, array_recv_text) # VARCHAR[] self.redshift_types[RedshiftOID.REAL_ARRAY] = (FC_TEXT, float_array_recv) # FLOAT4[] self.redshift_types[RedshiftOID.OID_ARRAY] = (FC_TEXT, int_array_recv) # OID[] self.redshift_types[RedshiftOID.ACLITEM_ARRAY] = (FC_TEXT, array_recv_text) # ACLITEM[] self.redshift_types[RedshiftOID.VARBYTE] = (FC_TEXT, varbytehex_recv) # VARBYTE if self.numeric_to_float: self.redshift_types[RedshiftOID.NUMERIC] = (FC_TEXT, numeric_to_float_in) @property def _is_multi_databases_catalog_enable_in_server(self: "Connection") -> bool: if (b"datashare_enabled", str("on").encode()) in self.parameter_statuses: return True else: # if we don't receive this param from the server, we do not support return False @property def is_single_database_metadata(self): return self._database_metadata_current_db_only or not self._is_multi_databases_catalog_enable_in_server def handle_ERROR_RESPONSE(self: "Connection", data, ps): """ Handler for ErrorResponse message received via Amazon Redshift wire protocol, represented by b'E' code. ErrorResponse (B) Byte1('E') Identifies the message as an error. 
Int32 Length of message contents in bytes, including self. The message body consists of one or more identified fields, followed by a zero byte as a terminator. Fields may appear in any order. For each field there is the following: Byte1 A code identifying the field type; if zero, this is the message terminator and no string follows. The presently defined field types are listed in Section 42.5. Since more field types may be added in future, frontends should silently ignore fields of unrecognized type. String The field value. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ msg: typing.Dict[str, str] = dict( (s[:1].decode(_client_encoding), s[1:].decode(_client_encoding)) for s in data.split(NULL_BYTE) if s != b"" ) response_code: str = msg[RESPONSE_CODE] if response_code == "28000": cls: type = InterfaceError elif response_code == "23505": cls = IntegrityError else: cls = ProgrammingError self.error = cls(msg) def handle_EMPTY_QUERY_RESPONSE(self: "Connection", data, ps): """ Handler for EmptyQueryResponse message received via Amazon Redshift wire protocol, represented by b'I' code. EmptyQueryResponse (B) Byte1('I') Identifies the message as a response to an empty query string. (This substitutes for CommandComplete.) Int32(4) Length of message contents in bytes, including self. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ self.error = ProgrammingError("query was empty") def handle_CLOSE_COMPLETE(self: "Connection", data, ps): """ Handler for CloseComplete message received via Amazon Redshift wire protocol, represented by b'3' code. Currently a no-op. CloseComplete (B) Byte1('3') Identifies the message as a Close-complete indicator. 
Int32(4) Length of message contents in bytes, including self. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ pass def handle_PARSE_COMPLETE(self: "Connection", data, ps): """ Handler for ParseComplete message received via Amazon Redshift wire protocol, represented by b'1' code. Currently a no-op. ParseComplete (B) Byte1('1') Identifies the message as a Parse-complete indicator. Int32(4) Length of message contents in bytes, including self. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ pass def handle_BIND_COMPLETE(self: "Connection", data, ps): """ Handler for BindComplete message received via Amazon Redshift wire protocol, represented by b'2' code. Currently a no-op. BindComplete (B) Byte1('2') Identifies the message as a Bind-complete indicator. Int32(4) Length of message contents in bytes, including self. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ pass def handle_PORTAL_SUSPENDED(self: "Connection", data, cursor: Cursor): """ Handler for PortalSuspend message received via Amazon Redshift wire protocol, represented by b's' code. Currently a no-op. PortalSuspended (B) Byte1('s') Identifies the message as a portal-suspended indicator. Note this only appears if an Execute message's row-count limit was reached. Int32(4) Length of message contents in bytes, including self. Parameters ---------- :param data: bytes: Message content :param cursor: `Cursor` The `Cursor` object associated with the given statements execution. 
Returns ------- None:None """ pass def handle_PARAMETER_DESCRIPTION(self: "Connection", data, ps): """ Handler for ParameterDescription message received via Amazon Redshift wire protocol, represented by b't' code. ParameterDescription (B) Byte1('t') Identifies the message as a parameter description. Int32 Length of message contents in bytes, including self. Int16 The number of parameters used by the statement (may be zero). Then, for each parameter, there is the following: Int32 Specifies the object ID of the parameter data type. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ # Well, we don't really care -- we're going to send whatever we # want and let the database deal with it. But thanks anyways! # count = h_unpack(data)[0] # type_oids = unpack_from("!" + "i" * count, data, 2) pass def handle_COPY_DONE(self: "Connection", data, ps): """ Handler for CopyDone message received via Amazon Redshift wire protocol, represented by b'c' code. CopyDone (F & B) Byte1('c') Identifies the message as a COPY-complete indicator. Int32(4) Length of message contents in bytes, including self. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ self._copy_done = True def handle_COPY_OUT_RESPONSE(self: "Connection", data, ps): """ Handler for CopyOutResponse message received via Amazon Redshift wire protocol, represented by b'H' code. CopyOutResponse (B) Byte1('H') Identifies the message as a Start Copy Out response. This message will be followed by copy-out data. Int32 Length of message contents in bytes, including self. Int8 0 indicates the overall COPY format is textual (rows separated by newlines, columns separated by separator characters, etc). 
1 indicates the overall copy format is binary (similar to DataRow format). See COPY for more information. Int16 The number of columns in the data to be copied (denoted N below). Int16[N] The format codes to be used for each column. Each must presently be zero (text) or one (binary). All must be zero if the overall copy format is textual. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ is_binary, num_cols = bh_unpack(data) # column_formats = unpack_from('!' + 'h' * num_cols, data, 3) if ps.stream is None: raise InterfaceError("An output stream is required for the COPY OUT response.") def handle_COPY_DATA(self: "Connection", data, ps) -> None: """ Handler for CopyData message received via Amazon Redshift wire protocol, represented by b'd' code. CopyData (F & B) Byte1('d') Identifies the message as COPY data. Int32 Length of message contents in bytes, including self. Byten Data that forms part of a COPY data stream. Messages sent from the backend will always correspond to single data rows, but messages sent by frontends may divide the data stream arbitrarily. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ ps.stream.write(data) def handle_COPY_IN_RESPONSE(self: "Connection", data, ps): """ Handler for CopyInResponse message received via Amazon Redshift wire protocol, represented by b'G' code. CopyInResponse (B) Byte1('G') Identifies the message as a Start Copy In response. The frontend must now send copy-in data (if not prepared to do so, send a CopyFail message). Int32 Length of message contents in bytes, including self. Int8 0 indicates the overall COPY format is textual (rows separated by newlines, columns separated by separator characters, etc). 
1 indicates the overall copy format is binary (similar to DataRow format). See COPY for more information. Int16 The number of columns in the data to be copied (denoted N below). Int16[N] The format codes to be used for each column. Each must presently be zero (text) or one (binary). All must be zero if the overall copy format is textual. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ # Int16(2) - Number of columns # Int16(N) - Format codes for each column (0 text, 1 binary) is_binary, num_cols = bh_unpack(data) # column_formats = unpack_from('!' + 'h' * num_cols, data, 3) if ps.stream is None: raise InterfaceError("An input stream is required for the COPY IN response.") bffr: bytearray = bytearray(8192) while True: bytes_read = ps.stream.readinto(bffr) if bytes_read == 0: break self._write(COPY_DATA + i_pack(bytes_read + 4)) self._write(bffr[:bytes_read]) self._flush() # Send CopyDone # Byte1('c') - Identifier. # Int32(4) - Message length, including self. self._write(COPY_DONE_MSG) self._write(SYNC_MSG) self._flush() def handle_NOTIFICATION_RESPONSE(self: "Connection", data, ps): """ Handler for NotificationResponse message received via Amazon Redshift wire protocol, represented by b'A' code. A message sent if this connection receives a NOTIFY that it was listening for. NotificationResponse (B) Byte1('A') Identifies the message as a notification response. Int32 Length of message contents in bytes, including self. Int32 The process ID of the notifying backend process. String The name of the condition that the notify has been raised on. String Additional information passed from the notifying process. (Currently, this feature is unimplemented so the field is always an empty string.) 
Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ backend_pid = i_unpack(data)[0] idx: int = 4 null: int = data.find(NULL_BYTE, idx) - idx condition: str = data[idx : idx + null].decode("ascii") idx += null + 1 null = data.find(NULL_BYTE, idx) - idx # additional_info = data[idx:idx + null] self.notifications.append((backend_pid, condition)) def cursor(self: "Connection") -> Cursor: """Creates a :class:`Cursor` object bound to this connection. This function is part of the `DBAPI 2.0 specification `_. Returns ------- A Cursor object associated with the current Connection: :class:`Cursor` """ return Cursor(self) @property def description(self: "Connection") -> typing.Optional[typing.List]: return self._run_cursor._getDescription() def run(self: "Connection", sql, stream=None, **params) -> typing.Tuple[typing.Any, ...]: """ Executes an sql statement, and returns the results as a `tuple`. Returns ------- Result of executing an sql statement:tuple[Any, ...] """ self._run_cursor.execute(sql, params, stream=stream) return tuple(self._run_cursor._cached_rows) def commit(self: "Connection") -> None: """Commits the current database transaction. This function is part of the `DBAPI 2.0 specification `_. Returns ------- None:None """ self.execute(self._cursor, "commit", None) def rollback(self: "Connection") -> None: """Rolls back the current database transaction. This function is part of the `DBAPI 2.0 specification `_. Returns ------- None:None """ if not self.in_transaction: return self.execute(self._cursor, "rollback", None) def close(self: "Connection") -> None: """Closes the database connection. This function is part of the `DBAPI 2.0 specification `_. Returns ------- None:None """ try: # Byte1('X') - Identifies the message as a terminate message. # Int32(4) - Message length, including self. 
self._write(TERMINATE_MSG) self._flush() if self._sock is not None: self._sock.close() except AttributeError: raise InterfaceError("connection is closed") except ValueError: raise InterfaceError("connection is closed") except socket.error: pass finally: self._usock.close() self._sock = None def handle_AUTHENTICATION_REQUEST(self: "Connection", data: bytes, cursor: Cursor) -> None: """ Handler for AuthenticationRequest message received via Amazon Redshift wire protocol, represented by b'R' code. AuthenticationRequest (B) Byte1('R') Identifies the message as an authentication request. Int32(8) Length of message contents in bytes, including self. Int32(1) An authentication code that represents different authentication messages: 0 = AuthenticationOk 5 = MD5 pwd 2 = Kerberos v5 (not supported) 3 = Cleartext pwd 4 = crypt() pwd (not supported) 6 = SCM credential (not supported) 7 = GSSAPI (not supported) 8 = GSSAPI data (not supported) 9 = SSPI (not supported) 14 = Redshift Native IDP Integration Please note that some authentication messages have additional data following the authentication code. That data is documented in the appropriate conditional branch below. Parameters ---------- :param data: bytes: Message content :param cursor: `Cursor` The `Cursor` object associated with the given statements execution. Returns ------- None:None """ auth_code: int = i_unpack(data)[0] if auth_code == 0: pass elif auth_code == 3: if self.password is None: raise InterfaceError("server requesting password authentication, but no " "password was provided") self._send_message(PASSWORD, self.password + NULL_BYTE) self._flush() elif auth_code == 5: ## # A message representing the backend requesting an MD5 hashed # password response. The response will be sent as # md5(md5(pwd + login) + salt). # Additional message data: # Byte4 - Hash salt. 
salt: bytes = b"".join(cccc_unpack(data, 4)) if self.password is None: raise InterfaceError("server requesting MD5 password authentication, but no " "password was provided") pwd: bytes = b"md5" + md5( md5(self.password + self.user).hexdigest().encode("ascii") + salt ).hexdigest().encode("ascii") # Byte1('p') - Identifies the message as a password message. # Int32 - Message length including self. # String - The password. Password may be encrypted. self._send_message(PASSWORD, pwd + NULL_BYTE) self._flush() elif auth_code == 10: # AuthenticationSASL mechanisms: typing.List[str] = [m.decode("ascii") for m in data[4:-1].split(NULL_BYTE)] self.auth: ScramClient = ScramClient(mechanisms, self.user.decode("utf8"), self.password.decode("utf8")) init: bytes = self.auth.get_client_first().encode("utf8") # SASLInitialResponse self._write(create_message(PASSWORD, b"SCRAM-SHA-256" + NULL_BYTE + i_pack(len(init)) + init)) self._flush() elif auth_code == 11: # AuthenticationSASLContinue self.auth.set_server_first(data[4:].decode("utf8")) # SASLResponse msg: bytes = self.auth.get_client_final().encode("utf8") self._write(create_message(PASSWORD, msg)) self._flush() elif auth_code == 12: # AuthenticationSASLFinal self.auth.set_server_final(data[4:].decode("utf8")) elif auth_code == 14: # Redshift Native IDP Integration aad_token: str = typing.cast(str, self.web_identity_token) _logger.debug("<=BE Authentication request IDP") if not aad_token: raise ConnectionAbortedError( "The server requested AAD token-based authentication, but no token was provided." 
) _logger.debug("FE=> IDP(AAD Token)") token: bytes = aad_token.encode(encoding="utf-8") self._write(create_message(b"i", token)) # self._write(NULL_BYTE) self._flush() elif auth_code == 13: # AUTH_REQ_DIGEST offset: int = 4 algo: int = i_unpack(data, offset)[0] algo_names: typing.Tuple[str] = ("SHA256",) offset += 4 salt_len: int = i_unpack(data, offset)[0] offset += 4 salt = data[offset : offset + salt_len] offset += salt_len server_nonce_len: int = i_unpack(data, offset)[0] offset += 4 server_nonce: bytes = data[offset : offset + server_nonce_len] offset += server_nonce_len ms_since_epoch: int = int((Datetime.utcnow() - Datetime.utcfromtimestamp(0)).total_seconds() * 1000.0) client_nonce: bytes = str(ms_since_epoch).encode("utf-8") _logger.debug("handle_AUTHENTICATION_REQUEST: AUTH_REQ_DIGEST") _logger.debug("Algo:{}".format(algo)) if self.password is None: raise InterfaceError( "The server requested password-based authentication, but no password was provided." ) if algo > len(algo_names): raise InterfaceError( "The server requested password-based authentication, " "but requested algorithm {} is not supported.".format(algo) ) from redshift_connector.utils.extensible_digest import ExtensibleDigest digest: bytes = ExtensibleDigest.encode( client_nonce=client_nonce, password=typing.cast(bytes, self.password), salt=salt, algo_name=algo_names[algo], server_nonce=server_nonce, ) _logger.debug("Password(extensible digest)") self._write(b"d") self._write(i_pack(4 + 4 + len(digest) + 4 + len(client_nonce))) self._write(i_pack(len(digest))) self._write(digest) self._write(i_pack(len(client_nonce))) self._write(client_nonce) self._flush() elif auth_code in (2, 4, 6, 7, 8, 9): raise InterfaceError("Authentication method " + str(auth_code) + " not supported by redshift_connector.") else: raise InterfaceError("Authentication method " + str(auth_code) + " not recognized by redshift_connector.") def handle_READY_FOR_QUERY(self: "Connection", data: bytes, ps) -> None: """ 
Handler for ReadyForQuery message received via Amazon Redshift wire protocol, represented by b'Z' code. ReadyForQuery (B) Byte1('Z') Identifies the message type. ReadyForQuery is sent whenever the backend is ready for a new query cycle. Int32(5) Length of message contents in bytes, including self. Byte1 Current backend transaction status indicator. Possible values are 'I' if idle (not in a transaction block); 'T' if in a transaction block; or 'E' if in a failed transaction block (queries will be rejected until block is ended). Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ # Byte1 - Status indicator. self.in_transaction = data != IDLE def handle_BACKEND_KEY_DATA(self: "Connection", data: bytes, ps) -> None: self._backend_key_data = data def inspect_datetime(self: "Connection", value: Datetime): if value.tzinfo is None: return self.py_types[RedshiftOID.TIMESTAMP] # timestamp else: return self.py_types[RedshiftOID.TIMESTAMPTZ] # send as timestamptz def inspect_int(self: "Connection", value: int): if min_int2 < value < max_int2: return self.py_types[RedshiftOID.SMALLINT] if min_int4 < value < max_int4: return self.py_types[RedshiftOID.INTEGER] if min_int8 < value < max_int8: return self.py_types[RedshiftOID.BIGINT] return self.py_types[Decimal] def make_params(self: "Connection", values) -> typing.Tuple[typing.Tuple[int, int, typing.Callable], ...]: params: typing.List[typing.Tuple[int, int, typing.Callable]] = [] for value in values: typ: typing.Type = type(value) try: params.append(self.py_types[typ]) except KeyError: try: params.append(self.inspect_funcs[typ](value)) except KeyError as e: param: typing.Optional[typing.Tuple[int, int, typing.Callable]] = None for k, v in self.py_types.items(): try: if isinstance(value, typing.cast(type, k)): param = v break except TypeError: pass if param is None: for k, v in 
self.inspect_funcs.items(): # type: ignore try: if isinstance(value, k): v_func: typing.Callable = typing.cast(typing.Callable, v) param = v_func(value) break except TypeError: pass except KeyError: pass if param is None: raise NotSupportedError("type " + str(e) + " not mapped to pg type") else: params.append(param) return tuple(params) def handle_ROW_DESCRIPTION(self: "Connection", data, cursor: Cursor) -> None: """ Handler for RowDescription message received via Amazon Redshift wire protocol, represented by b'T' code. Sets ``Connection.ps`` to store metadata. RowDescription (B) Byte1('T') Identifies the message as a row description. Int32 Length of message contents in bytes, including self. Int16 Specifies the number of fields in a row (may be zero). Then, for each field, there is the following: String The field name. Int32 If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. Int16 If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. Int32 The object ID of the field's data type. Int16 The data type size (see pg_type.typlen). Note that negative values denote variable-width types. Int32 The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. Int16 The format code being used for the field. Currently will be zero (text) or one (binary). In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero. Parameters ---------- :param data: bytes: Message content :param cursor: `Cursor` The `Cursor` object associated with the given statements execution. 
Returns ------- None:None """ if cursor.ps is None: raise InterfaceError("Cursor is missing prepared statement") elif "row_desc" not in cursor.ps: raise InterfaceError("Prepared Statement is missing row description") count: int = h_unpack(data)[0] _logger.debug("field count={}".format(count)) idx = 2 for i in range(count): column_label = data[idx : data.find(NULL_BYTE, idx)] idx += len(column_label) + 1 field: typing.Dict = dict( zip( ("table_oid", "column_attrnum", "type_oid", "type_size", "type_modifier", "format"), ihihih_unpack(data, idx), ) ) field["label"] = column_label idx += 18 if self._client_protocol_version >= ClientProtocolVersion.EXTENDED_RESULT_METADATA: for entry in ("schema_name", "table_name", "column_name", "catalog_name"): field[entry] = data[idx : data.find(NULL_BYTE, idx)] idx += len(field[entry]) + 1 temp: int = h_unpack(data, idx)[0] field["nullable"] = temp & 0x1 field["autoincrement"] = (temp >> 4) & 0x1 field["read_only"] = (temp >> 8) & 0x1 field["searchable"] = (temp >> 12) & 0x1 idx += 2 cursor.ps["row_desc"].append(field) field["redshift_connector_fc"], field["func"] = self.redshift_types[field["type_oid"]] _logger.debug(cursor.ps["row_desc"]) def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None: """ Executes a database operation. Parameters may be provided as a sequence, or as a mapping, depending upon the value of `redshift_connector.paramstyle`. Parameters ---------- cursor : :class:`Cursor` operation : str The SQL statement to execute. vals : If `redshift_connector.paramstyle` is `qmark`, `numeric`, or `format` this argument should be an array of parameters to bind into the statement. If `redshift_connector.paramstyle` is `named` the argument should be a `dict` mapping of parameters. If `redshift_connector.paramstyle` is `pyformat`, the argument value may be either an array or mapping. Returns ------- None:None """ # get the process ID of the calling process. 
pid: int = getpid() args: typing.Tuple[typing.Optional[typing.Tuple[str, typing.Any]], ...] = () # transforms user provided bind parameters to server friendly bind parameters params: typing.Tuple[typing.Optional[typing.Tuple[int, int, typing.Callable]], ...] = () has_bind_parameters: bool = False if vals is None else True # multi dimensional dictionary to store the data # cache = self._caches[cursor.paramstyle][pid] # cache = {'statement': {}, 'ps': {}} # statement stores the data of the statement, ps store the data of the prepared statement # statement = {operation(query): tuple from 'convert_paramstyle'(statement, make_args)} try: cache = self._caches[cursor.paramstyle][pid] except KeyError: try: param_cache = self._caches[cursor.paramstyle] except KeyError: param_cache = self._caches[cursor.paramstyle] = {} try: cache = param_cache[pid] except KeyError: cache = param_cache[pid] = {"statement": {}, "ps": {}} try: statement, make_args = cache["statement"][operation] except KeyError: if has_bind_parameters: statement, make_args = cache["statement"][operation] = convert_paramstyle(cursor.paramstyle, operation) else: # use a no-op make_args in lieu of parsing the sql statement statement, make_args = cache["statement"][operation] = operation, lambda p: () if has_bind_parameters: args = make_args(vals) # change the args to the format that the DB will identify # take reference from self.py_types params = self.make_params(args) key = operation, params try: ps = cache["ps"][key] cursor.ps = ps except KeyError: statement_nums: typing.List[int] = [0] for style_cache in self._caches.values(): try: pid_cache = style_cache[pid] for csh in pid_cache["ps"].values(): statement_nums.append(csh["statement_num"]) except KeyError: pass # statement_num is the id of statement increasing from 1 statement_num: int = sorted(statement_nums)[-1] + 1 # consist of "redshift_connector", statement, process id and statement number. 
# e.g redshift_connector_statement_11432_2 statement_name: str = "_".join(("redshift_connector", "statement", str(pid), str(statement_num))) statement_name_bin: bytes = statement_name.encode("ascii") + NULL_BYTE # row_desc: list that used to store metadata of rows from DB # param_funcs: type transform function ps = { "statement_name_bin": statement_name_bin, "pid": pid, "statement_num": statement_num, "row_desc": [], "param_funcs": tuple(x[2] for x in params), # type: ignore } cursor.ps = ps param_fcs: typing.Tuple[typing.Optional[int], ...] = tuple(x[1] for x in params) # type: ignore # Byte1('P') - Identifies the message as a Parse command. # Int32 - Message length, including self. # String - Prepared statement name. An empty string selects the # unnamed prepared statement. # String - The query string. # Int16 - Number of parameter data types specified (can be zero). # For each parameter: # Int32 - The OID of the parameter data type. val: typing.Union[bytes, bytearray] = bytearray(statement_name_bin) typing.cast(bytearray, val).extend(statement.encode(_client_encoding) + NULL_BYTE) typing.cast(bytearray, val).extend(h_pack(len(params))) for oid, fc, send_func in params: # type: ignore # Parse message doesn't seem to handle the -1 type_oid for NULL # values that other messages handle. So we'll provide type_oid # 705, the PG "unknown" type. typing.cast(bytearray, val).extend(i_pack(705 if oid == -1 else oid)) # Byte1('D') - Identifies the message as a describe command. # Int32 - Message length, including self. # Byte1 - 'S' for prepared statement, 'P' for portal. # String - The name of the item to describe. 
# PARSE message will notify database to create a prepared statement object self._send_message(PARSE, val) # DESCRIBE message will specify the name of the existing prepared statement # the response will be a parameterDescribing message describe the parameters needed # and a RowDescription message describe the rows will be return(nodata message when no return rows) self._send_message(DESCRIBE, STATEMENT + statement_name_bin) # at completion of query message, driver issue a sync message self._write(SYNC_MSG) try: self._flush() except AttributeError as e: if self._sock is None: raise InterfaceError("connection is closed") else: raise e self.handle_messages(cursor) # We've got row_desc that allows us to identify what we're # going to get back from this statement. output_fc = tuple(self.redshift_types[f["type_oid"]][0] for f in ps["row_desc"]) ps["input_funcs"] = tuple(f["func"] for f in ps["row_desc"]) # Byte1('B') - Identifies the Bind command. # Int32 - Message length, including self. # String - Name of the destination portal. # String - Name of the source prepared statement. # Int16 - Number of parameter format codes. # For each parameter format code: # Int16 - The parameter format code. # Int16 - Number of parameter values. # For each parameter value: # Int32 - The length of the parameter value, in bytes, not # including this length. -1 indicates a NULL parameter # value, in which no value bytes follow. # Byte[n] - Value of the parameter. # Int16 - The number of result-column format codes. # For each result-column format code: # Int16 - The format code. ps["bind_1"] = ( NULL_BYTE + statement_name_bin + h_pack(len(params)) + pack("!" + "h" * len(param_fcs), *param_fcs) + h_pack(len(params)) ) ps["bind_2"] = h_pack(len(output_fc)) + pack("!" 
+ "h" * len(output_fc), *output_fc) if len(cache["ps"]) > self.max_prepared_statements: for p in cache["ps"].values(): self.close_prepared_statement(p["statement_name_bin"]) cache["ps"].clear() cache["ps"][key] = ps cursor._cached_rows.clear() cursor._row_count = -1 cursor._redshift_row_count = -1 # Byte1('B') - Identifies the Bind command. # Int32 - Message length, including self. # String - Name of the destination portal. # String - Name of the source prepared statement. # Int16 - Number of parameter format codes. # For each parameter format code: # Int16 - The parameter format code. # Int16 - Number of parameter values. # For each parameter value: # Int32 - The length of the parameter value, in bytes, not # including this length. -1 indicates a NULL parameter # value, in which no value bytes follow. # Byte[n] - Value of the parameter. # Int16 - The number of result-column format codes. # For each result-column format code: # Int16 - The format code. retval: bytearray = bytearray(ps["bind_1"]) for value, send_func in zip(args, ps["param_funcs"]): if value is None: val = NULL else: val = send_func(value) retval.extend(i_pack(len(val))) retval.extend(val) retval.extend(ps["bind_2"]) # send BIND message which includes name of parepared statement, # name of destination portal and the value of placeholders in prepared statement. 
# these parameters need to match the prepared statements self._send_message(BIND, retval) self.send_EXECUTE(cursor) self._write(SYNC_MSG) self._flush() # handle multi messages including BIND_COMPLETE, DATA_ROW, COMMAND_COMPLETE # READY_FOR_QUERY if self.merge_socket_read: self.handle_messages_merge_socket_read(cursor) else: self.handle_messages(cursor) def _send_message(self: "Connection", code: bytes, data: bytes) -> None: try: self._write(code) self._write(i_pack(len(data) + 4)) self._write(data) self._write(FLUSH_MSG) except ValueError as e: if str(e) == "write to closed file": raise InterfaceError("connection is closed") else: raise e except AttributeError: raise InterfaceError("connection is closed") def send_EXECUTE(self: "Connection", cursor: Cursor) -> None: """ Sends an Execute message in ordinance with Amazon Redshift wire protocol. Execute (F) Byte1('E') Identifies the message as an Execute command. Int32 Length of message contents in bytes, including self. String The name of the portal to execute (an empty string selects the unnamed portal). Int32 Maximum number of rows to return, if portal contains a query that returns rows (ignored otherwise). Zero denotes "no limit". Parameters ---------- :param cursor: `Cursor` The `Cursor` object associated with the given statements execution. Returns ------- None:None """ self._write(EXECUTE_MSG) self._write(FLUSH_MSG) def handle_NO_DATA(self: "Connection", msg, ps) -> None: """ Handler for NoData message received via Amazon Redshift wire protocol, represented by b'B' code. Currently a no-op. NoData (B) Byte1('n') Identifies the message as a no-data indicator. Int32(4) Length of message contents in bytes, including self. 
Parameters ---------- :param msg: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ pass def handle_COMMAND_COMPLETE(self: "Connection", data: bytes, cursor: Cursor) -> None: """ Handler for CommandComplete message received via Amazon Redshift wire protocol, represented by b'C' code. Modifies the cursor object and prepared statement. CommandComplete (B) Byte1('C') Identifies the message as a command-completed response. Int32 Length of message contents in bytes, including self. String The command tag. This is usually a single word that identifies which SQL command was completed. For an INSERT command, the tag is INSERT oid rows, where rows is the number of rows inserted. oid is the object ID of the inserted row if rows is 1 and the target table has OIDs; otherwise oid is 0. For a DELETE command, the tag is DELETE rows where rows is the number of rows deleted. For an UPDATE command, the tag is UPDATE rows where rows is the number of rows updated. For a MOVE command, the tag is MOVE rows where rows is the number of rows the cursor's position has been changed by. For a FETCH command, the tag is FETCH rows where rows is the number of rows that have been retrieved from the cursor. Parameters ---------- :param data: bytes: Message content :param cursor: `Cursor` The `Cursor` object associated with the given statements execution. 
Returns ------- None:None """ values: typing.List[bytes] = data[:-1].split(b" ") command = values[0] if command in self._commands_with_count: row_count: int = int(values[-1]) if cursor._row_count == -1: cursor._row_count = row_count else: cursor._row_count += row_count cursor._redshift_row_count = cursor._row_count elif command == b"SELECT": # Redshift server does not support row count for SELECT statement # so we derive this from the size of the rows associated with the # cursor object cursor._redshift_row_count = len(cursor._cached_rows) if command in (b"ALTER", b"CREATE"): for scache in self._caches.values(): for pcache in scache.values(): for ps in pcache["ps"].values(): self.close_prepared_statement(ps["statement_name_bin"]) pcache["ps"].clear() def handle_DATA_ROW(self: "Connection", data: bytes, cursor: Cursor) -> None: """ Handler for DataRow message received via Amazon Redshift wire protocol, represented by b'D' code. Processes incoming data rows from Amazon Redshift into Python data types, storing the transformed row in the cursor object's `_cached_rows`. NoData (B) Byte1('n') Identifies the message as a no-data indicator. Int32(4) Length of message contents in bytes, including self. Parameters ---------- :param data: bytes: Message content :param cursor: `Cursor` The `Cursor` object associated with the given statements execution. Returns ------- None:None """ data_idx: int = 2 row: typing.List = [] for desc in cursor.truncated_row_desc(): vlen: int = i_unpack(data, data_idx)[0] data_idx += 4 if vlen == -1: row.append(None) elif desc[0] in (numeric_in_binary, numeric_to_float_binary): row.append(desc[0](data, data_idx, vlen, desc[1])) data_idx += vlen else: row.append(desc[0](data, data_idx, vlen)) data_idx += vlen cursor._cached_rows.append(row) def handle_messages(self: "Connection", cursor: Cursor) -> None: """ Reads messages formatted in ordinance with Amazon Redshift wire protocol, modifying the connection and cursor. 
Parameters ---------- :param cursor: `Cursor` The `Cursor` object associated with the given connection object. Returns ------- None:None """ code = self.error = None while code != READY_FOR_QUERY: code, data_len = ci_unpack(self._read(5)) self.message_types[code](self._read(data_len - 4), cursor) if self.error is not None: raise self.error def handle_messages_merge_socket_read(self: "Connection", cursor: Cursor): """ An optimized version of :func:`Connection.handle_messages` which reduces reads. Parameters ---------- :param cursor: `Cursor` The `Cursor` object associated with the given connection object. Returns ------- None:None """ code = self.error = None # read 5 bytes of message firstly code, data_len = ci_unpack(self._read(5)) while True: if code == READY_FOR_QUERY: # for last message self.message_types[code](self._read(data_len - 4), cursor) break else: # read data body of last message and read next 5 bytes of next message data = self._read(data_len - 4 + 5) last_message_body = data[0:-5] self.message_types[code](last_message_body, cursor) code, data_len = ci_unpack(data[-5:]) if self.error is not None: raise self.error def close_prepared_statement(self: "Connection", statement_name_bin: bytes) -> None: """ Handler for Close message received via Amazon Redshift wire protocol, represented by b'C' code. Clears attributes associated with the prepared statement from the current connection object. Close (F) Byte1('C') Identifies the message as a Close command. Int32 Length of message contents in bytes, including self. Byte1 'S' to close a prepared statement; or 'P' to close a portal. String The name of the prepared statement or portal to close (an empty string selects the unnamed prepared statement or portal). 
Parameters ---------- :param statement_name_bin: bytes: Message content Returns ------- None:None """ self._send_message(CLOSE, STATEMENT + statement_name_bin) self._write(SYNC_MSG) self._flush() self.handle_messages(self._cursor) def handle_NOTICE_RESPONSE(self: "Connection", data: bytes, ps) -> None: """ Handler for NoticeResponse message received via Amazon Redshift wire protocol, represented by b'N' code. Adds the received notice to ``Connection.notices``. NoticeResponse (B) Byte1('N') Identifies the message as a notice. Int32 Length of message contents in bytes, including self. The message body consists of one or more identified fields, followed by a zero byte as a terminator. Fields may appear in any order. For each field there is the following: Byte1 A code identifying the field type; if zero, this is the message terminator and no string follows. The presently defined field types are listed in Section 42.5. Since more field types may be added in future, frontends should silently ignore fields of unrecognized type. String The field value. Parameters ---------- :param data: bytes: Message content :param ps: typing.Optional[typing.Dict[str, typing.Any]]: Prepared Statement from associated Cursor Returns ------- None:None """ self.notices.append(dict((s[0:1], s[1:]) for s in data.split(NULL_BYTE))) def handle_PARAMETER_STATUS(self: "Connection", data: bytes, ps) -> None: """ Handler for ParameterStatus message received via Amazon Redshift wire protocol, represented by b'S' code. Modifies the connection object inline with parameter values received in preperation for statment execution. ParameterStatus (B) Byte1('S') Identifies the message as a run-time parameter status report. Int32 Length of message contents in bytes, including self. String The name of the run-time parameter being reported. String The current value of the parameter. 
Parameters ---------- :param statement_name_bin: bytes: Message content Returns ------- None:None """ pos: int = data.find(NULL_BYTE) key, value = data[:pos], data[pos + 1 : -1] self.parameter_statuses.append((key, value)) if key == b"client_encoding": encoding = value.decode("ascii").lower() _client_encoding = pg_to_py_encodings.get(encoding, encoding) elif key == b"server_protocol_version": # when a mismatch occurs between the client's requested protocol version, and the server's response, # warn the user and follow server if self._client_protocol_version != int(value): _logger.debug( "Server indicated {} transfer protocol will be used rather than protocol requested by client: {}".format( ClientProtocolVersion.get_name(int(value)), ClientProtocolVersion.get_name(self._client_protocol_version), ) ) self._client_protocol_version = int(value) self._enable_protocol_based_conversion_funcs() elif key == b"server_version": self._server_version: typing.Union[version.LegacyVersion, version.Version] = version.parse( (value.decode("ascii")) ) if self._server_version < version.parse("8.2.0"): self._commands_with_count = (b"INSERT", b"DELETE", b"UPDATE", b"MOVE") elif self._server_version < version.parse("9.0.0"): self._commands_with_count = (b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH", b"COPY") def array_inspect(self: "Connection", value): # Check if array has any values. 
If empty, we can just assume it's an # array of strings first_element = array_find_first_element(value) if first_element is None: oid: int = 25 # Use binary ARRAY format to avoid having to properly # escape text in the array literals fc: int = FC_BINARY array_oid: int = pg_array_types[oid] else: # supported array output typ: type = type(first_element) if issubclass(typ, int): # special int array support -- send as smallest possible array # type typ = int int2_ok, int4_ok, int8_ok = True, True, True for v in array_flatten(value): if v is None: continue if min_int2 < v < max_int2: continue int2_ok = False if min_int4 < v < max_int4: continue int4_ok = False if min_int8 < v < max_int8: continue int8_ok = False if int2_ok: array_oid = 1005 # INT2[] oid, fc, send_func = (21, FC_BINARY, h_pack) elif int4_ok: array_oid = 1007 # INT4[] oid, fc, send_func = (23, FC_BINARY, i_pack) elif int8_ok: array_oid = 1016 # INT8[] oid, fc, send_func = (20, FC_BINARY, q_pack) else: raise ArrayContentNotSupportedError("numeric not supported as array contents") else: try: oid, fc, send_func = self.make_params((first_element,))[0] # If unknown or string, assume it's a string array if oid in (705, 1043, 25): oid = 25 # Use binary ARRAY format to avoid having to properly # escape text in the array literals fc = FC_BINARY array_oid = pg_array_types[oid] except KeyError: raise ArrayContentNotSupportedError("oid " + str(oid) + " not supported as array contents") except NotSupportedError: raise ArrayContentNotSupportedError("type " + str(typ) + " not supported as array contents") if fc == FC_BINARY: def send_array(arr: typing.List) -> typing.Union[bytes, bytearray]: # check that all array dimensions are consistent array_check_dimensions(arr) has_null: bool = array_has_null(arr) dim_lengths: typing.List[int] = array_dim_lengths(arr) data: bytearray = bytearray(iii_pack(len(dim_lengths), has_null, oid)) for i in dim_lengths: data.extend(ii_pack(i, 1)) for v in array_flatten(arr): if v is None: 
data += i_pack(-1) elif isinstance(v, typ): inner_data = send_func(v) data += i_pack(len(inner_data)) data += inner_data else: raise ArrayContentNotHomogenousError("not all array elements are of type " + str(typ)) return data else: def send_array(arr: typing.List) -> typing.Union[bytes, bytearray]: array_check_dimensions(arr) ar: typing.List = deepcopy(arr) for a, i, v in walk_array(ar): if v is None: a[i] = "NULL" elif isinstance(v, typ): a[i] = send_func(v).decode("ascii") else: raise ArrayContentNotHomogenousError("not all array elements are of type " + str(typ)) return str(ar).translate(arr_trans).encode("ascii") return (array_oid, fc, send_array) def xid(self: "Connection", format_id, global_transaction_id, branch_qualifier) -> typing.Tuple: """Create a Transaction IDs (only global_transaction_id is used in pg) format_id and branch_qualifier are not used in Amazon Redshift global_transaction_id may be any string identifier supported by Amazon Redshift. Returns ------- (format_id, global_transaction_id, branch_qualifier):typing.Tuple """ return (format_id, global_transaction_id, branch_qualifier) def tpc_begin(self: "Connection", xid) -> None: """Begins a TPC transaction with the given transaction ID xid. This method should be called outside of a transaction (i.e. nothing may have executed since the last .commit() or .rollback()). Furthermore, it is an error to call .commit() or .rollback() within the TPC transaction. A ProgrammingError is raised, if the application calls .commit() or .rollback() during an active TPC transaction. This function is part of the `DBAPI 2.0 specification `_. Returns ------- None:None """ self._xid = xid if self.autocommit: self.execute(self._cursor, "begin transaction", None) def tpc_prepare(self: "Connection") -> None: """Performs the first phase of a transaction started with .tpc_begin(). A ProgrammingError is be raised if this method is called outside of a TPC transaction. 
After calling .tpc_prepare(), no statements can be executed until .tpc_commit() or .tpc_rollback() have been called. This function is part of the `DBAPI 2.0 specification `_. Returns ------- None:None """ if self._xid is None or len(self._xid) < 2: raise InterfaceError("Malformed Transaction Id") q: str = "PREPARE TRANSACTION '%s';" % (self._xid[1],) self.execute(self._cursor, q, None) def tpc_commit(self: "Connection", xid=None) -> None: """When called with no arguments, .tpc_commit() commits a TPC transaction previously prepared with .tpc_prepare(). If .tpc_commit() is called prior to .tpc_prepare(), a single phase commit is performed. A transaction manager may choose to do this if only a single resource is participating in the global transaction. When called with a transaction ID xid, the database commits the given transaction. If an invalid transaction ID is provided, a ProgrammingError will be raised. This form should be called outside of a transaction, and is intended for use in recovery. On return, the TPC transaction is ended. This function is part of the `DBAPI 2.0 specification `_. Returns ------- None:None """ if xid is None: xid = self._xid if xid is None: raise ProgrammingError("Cannot tpc_commit() without a TPC transaction!") try: previous_autocommit_mode: bool = self.autocommit self.autocommit = True if xid in self.tpc_recover(): self.execute(self._cursor, "COMMIT PREPARED '%s';" % (xid[1],), None) else: # a single-phase commit self.commit() finally: self.autocommit = previous_autocommit_mode self._xid = None def tpc_rollback(self: "Connection", xid=None) -> None: """When called with no arguments, .tpc_rollback() rolls back a TPC transaction. It may be called before or after .tpc_prepare(). When called with a transaction ID xid, it rolls back the given transaction. If an invalid transaction ID is provided, a ProgrammingError is raised. This form should be called outside of a transaction, and is intended for use in recovery. 
On return, the TPC transaction is ended. This function is part of the `DBAPI 2.0 specification `_. Returns ------- None:None """ if xid is None: xid = self._xid if xid is None: raise ProgrammingError("Cannot tpc_rollback() without a TPC prepared transaction!") try: previous_autocommit_mode: bool = self.autocommit self.autocommit = True if xid in self.tpc_recover(): # a two-phase rollback self.execute(self._cursor, "ROLLBACK PREPARED '%s';" % (xid[1],), None) else: # a single-phase rollback self.rollback() finally: self.autocommit = previous_autocommit_mode self._xid = None def tpc_recover(self: "Connection") -> typing.List[typing.Tuple[typing.Any, ...]]: """Returns a list of pending transaction IDs suitable for use with .tpc_commit(xid) or .tpc_rollback(xid). This function is part of the `DBAPI 2.0 specification `_. Returns ------- List of pending transaction IDs:List[tuple[Any, ...]] """ try: previous_autocommit_mode: bool = self.autocommit self.autocommit = True curs = self.cursor() curs.execute("select xact_id FROM stl_undone") return [self.xid(0, row[0], "") for row in curs] finally: self.autocommit = previous_autocommit_mode