# Copyright (c) 2016, 2022, Oracle and/or its affiliates. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License, version 2.0, as # published by the Free Software Foundation. # # This program is also distributed with certain software (including # but not limited to OpenSSL) that is licensed under separate terms, # as designated in a particular file or component or in included license # documentation. The authors of MySQL hereby grant you an # additional permission to link the program and your derivative works # with the separately licensed software that they have included with # MySQL. # # Without limiting anything contained in the foregoing, this file, # which is part of MySQL Connector/Python, is also subject to the # Universal FOSS Exception, version 1.0, a copy of which can be found at # http://oss.oracle.com/licenses/universal-foss-exception. # # This program is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the GNU General Public License, version 2.0, for more details. 
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

"""Implementation of the X protocol for MySQL servers."""

import struct
import zlib

from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

try:
    import lz4.frame

    HAVE_LZ4 = True
except ImportError:
    HAVE_LZ4 = False

try:
    import zstandard as zstd

    HAVE_ZSTD = True
except ImportError:
    HAVE_ZSTD = False

from .errors import (
    InterfaceError,
    NotSupportedError,
    OperationalError,
    ProgrammingError,
)
from .expr import (
    ExprParser,
    build_bool_scalar,
    build_expr,
    build_int_scalar,
    build_scalar,
    build_unsigned_int_scalar,
)
from .helpers import encode_to_bytes, get_item_or_attr
from .logger import logger
from .protobuf import (
    CRUD_PREPARE_MAPPING,
    PROTOBUF_REPEATED_TYPES,
    SERVER_MESSAGES,
    Message,
    mysqlxpb_enum,
)
from .result import Column
from .statement import (
    AddStatement,
    DeleteStatement,
    FilterableStatement,
    FindStatement,
    InsertStatement,
    ModifyStatement,
    ReadStatement,
    RemoveStatement,
    SqlStatement,
    UpdateStatement,
)
from .types import (
    ColumnType,
    MessageType,
    ProtobufMessageCextType,
    ProtobufMessageType,
    ResultBaseType,
    SocketType,
    StatementType,
    StrOrBytes,
)

# MessageWriter.write_message only takes its compression branch when a
# compressor is set AND the message's byte_size() exceeds this value.
_COMPRESSION_THRESHOLD = 1000


class Compressor:
    """Implements compression/decompression using `zstd_stream`, `lz4_message`
    and `deflate_stream` algorithms.

    The two ``*_stream`` algorithms keep one compressor and one decompressor
    object alive for the lifetime of this instance. ``lz4_message`` creates a
    fresh LZ4 frame compressor/decompressor on every call.

    Args:
        algorithm (str): Compression algorithm.

    .. versionadded:: 8.0.21
    """

    def __init__(self, algorithm: str) -> None:
        self._algorithm: str = algorithm
        self._compressobj: Any = None
        self._decompressobj: Any = None
        # NOTE(review): `zstd` is only bound when the optional `zstandard`
        # import at module top succeeded (HAVE_ZSTD); presumably callers only
        # negotiate "zstd_stream" when it is available -- confirm.
        # NOTE(review): any algorithm other than the three documented ones
        # leaves both objects as None; compress()/decompress() would then
        # fall through to the deflate path and fail on None.
        if algorithm == "zstd_stream":
            self._compressobj = zstd.ZstdCompressor()
            self._decompressobj = zstd.ZstdDecompressor()
        elif algorithm == "deflate_stream":
            self._compressobj = zlib.compressobj()
            self._decompressobj = zlib.decompressobj()

    def compress(self, data: StrOrBytes) -> bytes:
        """Compresses data and returns it.

        Args:
            data (str, bytes or buffer object): Data to be compressed.

        Returns:
            bytes: Compressed data.
        """
        if self._algorithm == "zstd_stream":
            return self._compressobj.compress(data)
        if self._algorithm == "lz4_message":
            # Each call produces one complete LZ4 frame (begin + data + flush).
            with lz4.frame.LZ4FrameCompressor() as compressor:
                compressed = compressor.begin()
                compressed += compressor.compress(data)
                compressed += compressor.flush()
            return compressed
        # Using 'deflate_stream' algorithm
        compressed = self._compressobj.compress(data)
        # Z_SYNC_FLUSH emits all pending output without ending the deflate
        # stream, so the same compressobj keeps being used for later calls.
        compressed += self._compressobj.flush(zlib.Z_SYNC_FLUSH)
        return compressed

    def decompress(self, data: StrOrBytes) -> bytes:
        """Decompresses a frame of data and returns it as a string of bytes.

        Args:
            data (str, bytes or buffer object): Data to be decompressed.

        Returns:
            bytes: Decompressed data.
        """
        if self._algorithm == "zstd_stream":
            return self._decompressobj.decompress(data)
        if self._algorithm == "lz4_message":
            # A fresh decompressor per call, mirroring compress() above.
            with lz4.frame.LZ4FrameDecompressor() as decompressor:
                decompressed = decompressor.decompress(data)
            return decompressed
        # Using 'deflate' algorithm
        decompressed = self._decompressobj.decompress(data)
        # NOTE(review): zlib's Decompress.flush() takes an optional output
        # buffer *length*, not a flush mode, so Z_SYNC_FLUSH only acts as a
        # size hint here. The Python docs also say decompress() should not be
        # called after flush(), yet this object is reused on the next call --
        # confirm this is safe on supported interpreters.
        decompressed += self._decompressobj.flush(zlib.Z_SYNC_FLUSH)
        return decompressed


class MessageReader:
    """Implements a Message Reader.

    Args:
        socket_stream (mysqlx.connection.SocketStream): `SocketStream` object.

    .. versionadded:: 8.0.21
    """

    def __init__(self, socket_stream: SocketType) -> None:
        self._stream: SocketType = socket_stream
        self._compressor: Optional[Compressor] = None
        # Single push-back slot: filled by push_message(), returned (and
        # cleared) by the next read_message() call.
        self._msg: MessageType = None
        # Pending messages; _read_message() pops from the front of this list
        # before reading anything else.
        self._msg_queue: List[Message] = []

    def _read_message(self) -> MessageType:
        """Reads X Protocol messages from the stream and returns a
        :class:`mysqlx.protobuf.Message` object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the connected server does not
                                              have the MySQL X protocol plugin
                                              enabled.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        # Serve already-queued messages first.
        if self._msg_queue:
            return self._msg_queue.pop(0)
        # NOTE(review): SOURCE appears corrupted from here. The text between
        # the `<` that opened the struct format string below and the `->` of
        # the following `def read_message(self) -> MessageType:` header seems
        # to have been stripped. That presumably removed the rest of this
        # method and the `def read_message` line itself. The surviving
        # fragment is kept verbatim; restore it from the upstream file.
        frame_size, frame_type = struct.unpack(" MessageType:
        """Read message.

        Returns the message stored by push_message() if there is one
        (clearing the slot); otherwise reads via _read_message().

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        if self._msg is not None:
            msg = self._msg
            self._msg = None
            return msg
        return self._read_message()

    def push_message(self, msg: MessageType) -> None:
        """Push message.

        Stores ``msg`` so that the next read_message() call returns it.

        Args:
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

        Raises:
            :class:`mysqlx.OperationalError`: If message push slot is full.
        """
        if self._msg is not None:
            raise OperationalError("Message push slot is full")
        self._msg = msg

    def set_compression(self, algorithm: str) -> None:
        """Creates a :class:`mysqlx.protocol.Compressor` object based on the
        compression algorithm.

        A falsy ``algorithm`` disables compression (the compressor is unset).

        Args:
            algorithm (str): Compression algorithm.

        .. versionadded:: 8.0.21
        """
        self._compressor = Compressor(algorithm) if algorithm else None


class MessageWriter:
    """Implements a Message Writer.

    Args:
        socket_stream (mysqlx.connection.SocketStream): `SocketStream` object.

    .. versionadded:: 8.0.21
    """

    def __init__(self, socket_stream: SocketType) -> None:
        self._stream: SocketType = socket_stream
        self._compressor: Optional[Compressor] = None

    def write_message(self, msg_type: int, msg: MessageType) -> None:
        """Write message.

        Args:
            msg_type (int): The message type.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
        """
        msg_size = msg.byte_size(msg)
        # Compress only when enabled and the message is large enough.
        if self._compressor and msg_size > _COMPRESSION_THRESHOLD:
            msg_str = encode_to_bytes(msg.serialize_to_string())
            # NOTE(review): SOURCE appears corrupted from here. The text
            # between the `<` that opened the struct format string below and
            # the `->` of MessageWriter's
            # `def set_compression(self, algorithm: str) -> None:` header
            # seems to have been stripped. That presumably removed the rest
            # of the compressed branch, the uncompressed branch, the socket
            # write, and the `def set_compression` line. The surviving
            # fragment is kept verbatim; restore it from the upstream file.
            header = struct.pack(" None:
        """Creates a :class:`mysqlx.protocol.Compressor` object based on the
        compression algorithm.

        A falsy ``algorithm`` disables compression (the compressor is unset).

        Args:
            algorithm (str): Compression algorithm.
        """
        self._compressor = Compressor(algorithm) if algorithm else None


class Protocol:
    """Implements the MySQL X Protocol.

    Builds X Protocol messages, sends them through a MessageWriter and
    interprets the replies read from a MessageReader.

    Args:
        reader (mysqlx.protocol.MessageReader): A Message Reader object.
        writer (mysqlx.protocol.MessageWriter): A Message Writer object.

    .. versionchanged:: 8.0.21
    """

    def __init__(self, reader: MessageReader, writer: MessageWriter) -> None:
        self._reader: MessageReader = reader
        self._writer: MessageWriter = writer
        self._compression_algorithm: Optional[str] = None
        # Text of every Warning notice seen (appended by _process_frame).
        self._warnings: List[str] = []

    @property
    def compression_algorithm(self) -> Optional[str]:
        """str: The compression algorithm (as last passed to
        set_compression)."""
        return self._compression_algorithm

    @staticmethod
    def _apply_filter(msg: MessageType, stmt: FilterableStatement) -> None:
        """Apply filter.

        Copies the statement's WHERE, ORDER BY, GROUP BY and HAVING parts,
        when present, into the CRUD message ``msg`` (mutated in place).

        Args:
            msg (mysqlx.protobuf.Message): The MySQL X Protobuf Message.
            stmt (Statement): A `Statement` based type object.
        """
        if stmt.has_where:
            msg["criteria"] = stmt.get_where_expr()
        if stmt.has_sort:
            msg["order"].extend(stmt.get_sort_expr())
        if stmt.has_group_by:
            msg["grouping"].extend(stmt.get_grouping())
        if stmt.has_having:
            msg["grouping_criteria"] = stmt.get_having()

    def _create_any(self, arg: Any) -> Optional[MessageType]:
        """Create any.

        Wraps a Python value in a ``Mysqlx.Datatypes.Any`` message:

        * ``str``, ``bool`` and ``int`` become a scalar (``type=1``).
        * A 2-tuple ``(key, value)`` becomes a one-field object (``type=2``).
        * A dict, or a list/tuple whose first item is a dict, becomes an
          array of objects (``type=3``).
        * A list of ``(key, value)`` pairs becomes one object (``type=2``).

        Any other value yields ``None``.

        Args:
            arg (object): Arbitrary object.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message, or ``None`` if
            the type of ``arg`` is not handled.
        """
        if isinstance(arg, str):
            value = Message("Mysqlx.Datatypes.Scalar.String", value=arg)
            scalar = Message("Mysqlx.Datatypes.Scalar", type=8, v_string=value)
            return Message("Mysqlx.Datatypes.Any", type=1, scalar=scalar)
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(arg, bool):
            return Message(
                "Mysqlx.Datatypes.Any", type=1, scalar=build_bool_scalar(arg)
            )
        if isinstance(arg, int):
            # Negative values need a signed scalar; others use unsigned.
            if arg < 0:
                return Message(
                    "Mysqlx.Datatypes.Any",
                    type=1,
                    scalar=build_int_scalar(arg),
                )
            return Message(
                "Mysqlx.Datatypes.Any",
                type=1,
                scalar=build_unsigned_int_scalar(arg),
            )
        if isinstance(arg, tuple) and len(arg) == 2:
            arg_key, arg_value = arg
            obj_fld = Message(
                "Mysqlx.Datatypes.Object.ObjectField",
                key=arg_key,
                value=self._create_any(arg_value),
            )
            obj = Message("Mysqlx.Datatypes.Object", fld=[obj_fld.get_message()])
            return Message("Mysqlx.Datatypes.Any", type=2, obj=obj)
        # NOTE(review): `arg[0]` raises IndexError for an empty list/tuple.
        # Also, iterating a plain dict below yields its keys, so
        # `items.items()` would fail for e.g. str keys. The branch looks
        # usable only for sequences of dicts -- confirm intent before relying
        # on the bare-dict case.
        if isinstance(arg, dict) or (
            isinstance(arg, (list, tuple)) and isinstance(arg[0], dict)
        ):
            array_values = []
            for items in arg:
                obj_flds = []
                for key, value in items.items():
                    # Array can only handle Any types, Mysqlx.Datatypes.Any.obj
                    obj_fld = Message(
                        "Mysqlx.Datatypes.Object.ObjectField",
                        key=key,
                        value=self._create_any(value),
                    )
                    obj_flds.append(obj_fld.get_message())
                msg_obj = Message("Mysqlx.Datatypes.Object", fld=obj_flds)
                msg_any = Message("Mysqlx.Datatypes.Any", type=2, obj=msg_obj)
                array_values.append(msg_any.get_message())

            msg = Message("Mysqlx.Datatypes.Array")
            msg["value"] = array_values
            return Message("Mysqlx.Datatypes.Any", type=3, array=msg)
        if isinstance(arg, list):
            obj_flds = []
            for key, value in arg:
                obj_fld = Message(
                    "Mysqlx.Datatypes.Object.ObjectField",
                    key=key,
                    value=self._create_any(value),
                )
                obj_flds.append(obj_fld.get_message())
            msg_obj = Message("Mysqlx.Datatypes.Object", fld=obj_flds)
            msg_any = Message("Mysqlx.Datatypes.Any", type=2, obj=msg_obj)
            return msg_any
        # NOTE(review): unhandled types (e.g. float, None) return None here;
        # callers that immediately call .get_message() on the result will
        # fail with AttributeError.
        return None

    def _get_binding_args(
        self, stmt: Union[FilterableStatement, SqlStatement], is_scalar: bool = True
    ) -> Union[
        List[None], List[Union[ProtobufMessageType, ProtobufMessageCextType]]
    ]:
        """Returns the binding any/scalar.

        For a SqlStatement (no binding map) the bindings are converted in
        order. Otherwise each named binding is placed at the position its
        placeholder has in the statement's binding map.

        Args:
            stmt (Statement): A `Statement` based type object.
            is_scalar (bool): `True` to return scalar values.

        Raises:
            :class:`mysqlx.ProgrammingError`: If unable to find placeholder for
                                              parameter, or if the number of
                                              bindings and placeholders differ.

        Returns:
            list: A list of ``Any`` or ``Scalar`` objects.
        """

        def build_value(
            value: Any,
        ) -> Union[ProtobufMessageType, ProtobufMessageCextType]:
            if is_scalar:
                return build_scalar(value).get_message()
            return self._create_any(value).get_message()

        bindings = stmt.get_bindings()
        binding_map = stmt.get_binding_map()

        # If binding_map is None it's a SqlStatement object
        if binding_map is None:
            return [build_value(value) for value in bindings]

        # binding_map: placeholder name -> argument position.
        count = len(binding_map)
        args: List[Any] = count * [None]
        if count != len(bindings):
            raise ProgrammingError(
                "The number of bind parameters and placeholders do not match"
            )
        for name, value in bindings.items():  # type: ignore[union-attr]
            if name not in binding_map:
                raise ProgrammingError(
                    f"Unable to find placeholder for parameter: {name}"
                )
            pos = binding_map[name]
            args[pos] = build_value(value)
        return args

    def _process_frame(self, msg: MessageType, result: ResultBaseType) -> None:
        """Process frame.

        Decodes a ``Mysqlx.Notice.Frame`` payload and records its effect on
        ``result``:

        * type 1 (Warning): stored in ``self._warnings``, logged and appended
          to the result's warnings.
        * type 2 (SessionVariableChanged): decoded only; nothing is recorded.
        * type 3 (SessionStateChanged): updates generated document ids, rows
          affected or generated insert id on ``result``.

        Args:
            msg (mysqlx.protobuf.Message): A MySQL X Protobuf Message.
            result (Result): A `Result` based type object.
        """
        if msg["type"] == 1:
            warn_msg = Message.from_message("Mysqlx.Notice.Warning", msg["payload"])
            self._warnings.append(warn_msg.msg)
            logger.warning(
                "Protocol.process_frame Received Warning Notice code %s: %s",
                warn_msg.code,
                warn_msg.msg,
            )
            result.append_warning(warn_msg.level, warn_msg.code, warn_msg.msg)
        elif msg["type"] == 2:
            Message.from_message("Mysqlx.Notice.SessionVariableChanged", msg["payload"])
        elif msg["type"] == 3:
            sess_state_msg = Message.from_message(
                "Mysqlx.Notice.SessionStateChanged", msg["payload"]
            )
            if sess_state_msg["param"] == mysqlxpb_enum(
                "Mysqlx.Notice.SessionStateChanged.Parameter.GENERATED_DOCUMENT_IDS"
            ):
                # One id per value, each carried as octets and decoded to str.
                result.set_generated_ids(
                    [
                        get_item_or_attr(
                            get_item_or_attr(value, "v_octets"), "value"
                        ).decode()
                        for value in sess_state_msg["value"]
                    ]
                )
            else:
                # Following results are unitary and not a list
                sess_state_value = (
                    sess_state_msg["value"][0]
                    if isinstance(
                        sess_state_msg["value"], tuple(PROTOBUF_REPEATED_TYPES)
                    )
                    else sess_state_msg["value"]
                )
                if sess_state_msg["param"] == mysqlxpb_enum(
                    "Mysqlx.Notice.SessionStateChanged.Parameter.ROWS_AFFECTED"
                ):
                    result.set_rows_affected(
                        get_item_or_attr(sess_state_value, "v_unsigned_int")
                    )
                elif sess_state_msg["param"] == mysqlxpb_enum(
                    "Mysqlx.Notice.SessionStateChanged.Parameter.GENERATED_INSERT_ID"
                ):
                    result.set_generated_insert_id(
                        get_item_or_attr(sess_state_value, "v_unsigned_int")
                    )

    def _read_message(self, result: ResultBaseType) -> Optional[MessageType]:
        """Read message.

        Reads messages until one is relevant to the caller. Notice frames,
        ``FetchDone`` and ``FetchDoneMoreResultsets`` are consumed here and
        only update ``result``.

        Args:
            result (Result): A `Result` based type object.

        Raises:
            :class:`mysqlx.OperationalError`: If the server sends
                                              ``Mysqlx.Error``.

        Returns:
            mysqlx.protobuf.Message: ``None`` on ``Mysqlx.Sql.StmtExecuteOk``,
            otherwise the first other message (e.g. a
            ``Mysqlx.Resultset.Row``).
        """
        while True:
            try:
                msg = self._reader.read_message()
            except RuntimeError as err:
                # NOTE(review): repr() always returns a non-empty string, so
                # this test is always true and every RuntimeError is re-raised
                # (with e.g. "reason: []" when there are no warnings). If it is
                # changed to test the raw warnings, add a bare `raise` for the
                # else path. Otherwise `msg` below would be unbound, or stale
                # from a previous loop iteration.
                warnings = repr(result.get_warnings())
                if warnings:
                    raise RuntimeError(f"{err} reason: {warnings}") from err
            if msg.type == "Mysqlx.Error":
                raise OperationalError(msg["msg"], msg["code"])
            if msg.type == "Mysqlx.Notice.Frame":
                # Malformed/unexpected notices are skipped rather than fatal.
                try:
                    self._process_frame(msg, result)
                except (AttributeError, KeyError):
                    continue
            elif msg.type == "Mysqlx.Sql.StmtExecuteOk":
                return None
            elif msg.type == "Mysqlx.Resultset.FetchDone":
                result.set_closed(True)
            elif msg.type == "Mysqlx.Resultset.FetchDoneMoreResultsets":
                result.set_has_more_results(True)
            elif msg.type == "Mysqlx.Resultset.Row":
                result.set_has_data(True)
                break
            else:
                break
        return msg

    def set_compression(self, algorithm: str) -> None:
        """Sets the compression algorithm to be used by the compression
        object, for uplink and downlink.

        Args:
            algorithm (str): Algorithm to be used in compression/decompression.

        .. versionadded:: 8.0.21
        """
        self._compression_algorithm = algorithm
        self._reader.set_compression(algorithm)
        self._writer.set_compression(algorithm)

    def get_capabilites(self) -> MessageType:
        """Get capabilities.

        Sends ``CapabilitiesGet`` and returns the first reply that is not a
        notice frame.

        Raises:
            :class:`mysqlx.OperationalError`: If the server replies with
                                              ``Mysqlx.Error``.

        Returns:
            mysqlx.protobuf.Message: MySQL X Protobuf Message.
        """
        # NOTE(review): the public name is misspelled ("capabilites"); keep
        # it for compatibility with existing callers.
        msg = Message("Mysqlx.Connection.CapabilitiesGet")
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CAPABILITIES_GET"),
            msg,
        )
        msg = self._reader.read_message()
        while msg.type == "Mysqlx.Notice.Frame":
            msg = self._reader.read_message()

        if msg.type == "Mysqlx.Error":
            raise OperationalError(msg["msg"], msg["code"])

        return msg

    def set_capabilities(self, **kwargs: Any) -> None:
        """Set capabilities.

        Each keyword becomes a capability name. Dict values are sent as an
        X Protocol object; other values go through ``_create_any``. Does
        nothing when no keywords are given.

        Args:
            **kwargs: Arbitrary keyword arguments.

        Raises:
            :class:`mysqlx.InterfaceError`: If the server rejects the request,
                                            except for errno 5002, which is
                                            ignored.

        Returns:
            None: ``read_ok`` returns ``None``, so this always returns ``None``.
        """
        if not kwargs:
            return None
        capabilities = Message("Mysqlx.Connection.Capabilities")
        for key, value in kwargs.items():
            capability = Message("Mysqlx.Connection.Capability")
            capability["name"] = key
            if isinstance(value, dict):
                items = value
                obj_flds = []
                for item in items:
                    obj_fld = Message(
                        "Mysqlx.Datatypes.Object.ObjectField",
                        key=item,
                        value=self._create_any(items[item]),
                    )
                    obj_flds.append(obj_fld.get_message())
                msg_obj = Message("Mysqlx.Datatypes.Object", fld=obj_flds)
                msg_any = Message("Mysqlx.Datatypes.Any", type=2, obj=msg_obj)
                capability["value"] = msg_any.get_message()
            else:
                capability["value"] = self._create_any(value)

            capabilities["capabilities"].extend([capability.get_message()])
        msg = Message("Mysqlx.Connection.CapabilitiesSet")
        msg["capabilities"] = capabilities
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CAPABILITIES_SET"),
            msg,
        )

        try:
            return self.read_ok()
        except InterfaceError as err:
            # Skip capability "session_connect_attrs" error since
            # it is only available on version >= 8.0.16
            if err.errno != 5002:
                raise
        return None

    def send_auth_start(
        self,
        method: str,
        auth_data: Optional[str] = None,
        initial_response: Optional[str] = None,
    ) -> None:
        """Send authenticate start.

        Optional fields are only set when the argument is not ``None``.

        Args:
            method (str): Message method.
            auth_data (Optional[str]): Authentication data.
            initial_response (Optional[str]): Initial response.
        """
        msg = Message("Mysqlx.Session.AuthenticateStart")
        msg["mech_name"] = method
        if auth_data is not None:
            msg["auth_data"] = auth_data
        if initial_response is not None:
            msg["initial_response"] = initial_response
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_AUTHENTICATE_START"),
            msg,
        )

    def read_auth_continue(self) -> bytes:
        """Read authenticate continue.

        Skips notice frames, then expects an ``AuthenticateContinue`` message.

        Raises:
            :class:`InterfaceError`: If the message type is not
                                     `Mysqlx.Session.AuthenticateContinue`

        Returns:
            bytes: The ``auth_data`` field of the message (annotated as bytes;
            the earlier docstring said str -- confirm the actual type).
        """
        msg = self._reader.read_message()
        while msg.type == "Mysqlx.Notice.Frame":
            msg = self._reader.read_message()
        if msg.type != "Mysqlx.Session.AuthenticateContinue":
            raise InterfaceError(
                "Unexpected message encountered during authentication handshake"
            )
        return msg["auth_data"]

    def send_auth_continue(self, auth_data: str) -> None:
        """Send authenticate continue.

        Args:
            auth_data (str): Authentication data.
        """
        msg = Message("Mysqlx.Session.AuthenticateContinue", auth_data=auth_data)
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_AUTHENTICATE_CONTINUE"),
            msg,
        )

    def read_auth_ok(self) -> None:
        """Read authenticate OK.

        Reads until ``AuthenticateOk``; every other message type except
        ``Mysqlx.Error`` is ignored.

        Raises:
            :class:`mysqlx.InterfaceError`: If message type is `Mysqlx.Error`.
        """
        while True:
            msg = self._reader.read_message()
            if msg.type == "Mysqlx.Session.AuthenticateOk":
                break
            if msg.type == "Mysqlx.Error":
                raise InterfaceError(msg.msg)

    def send_prepare_prepare(
        self,
        msg_type: str,
        msg: MessageType,
        stmt: Union[
            FindStatement,
            DeleteStatement,
            ModifyStatement,
            ReadStatement,
            RemoveStatement,
            UpdateStatement,
        ],
    ) -> None:
        """
        Send prepare statement.

        When the statement has a limit, the CRUD message is rebuilt without a
        literal limit and given placeholder ``limit_expr`` entries instead, so
        the limit values can be supplied at execute time.

        Args:
            msg_type (str): Message ID string.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
            stmt (Statement): A `Statement` based type object.

        Raises:
            :class:`mysqlx.NotSupportedError`: If prepared statements are not
                                               supported.
            ValueError: If the statement has a limit and ``msg`` is not a
                        Find, Update or Delete message.

        .. versionadded:: 8.0.16
        """
        if stmt.has_limit and msg.type != "Mysqlx.Crud.Insert":
            # Remove 'limit' from message by building a new one
            if msg.type == "Mysqlx.Crud.Find":
                _, msg = self.build_find(stmt)  # type: ignore[arg-type]
            elif msg.type == "Mysqlx.Crud.Update":
                _, msg = self.build_update(stmt)  # type: ignore[arg-type]
            elif msg.type == "Mysqlx.Crud.Delete":
                _, msg = self.build_delete(stmt)  # type: ignore[arg-type]
            else:
                raise ValueError(f"Invalid message type: {msg_type}")
            # Build 'limit_expr' message
            # Limit placeholders come right after the statement's own bindings,
            # matching send_prepare_execute, which appends the limit args last.
            # NOTE(review): Update/Delete declare only a row_count placeholder
            # here, but send_prepare_execute appends both row count and offset;
            # confirm the server tolerates the extra argument.
            position = len(stmt.get_bindings())
            placeholder = mysqlxpb_enum("Mysqlx.Expr.Expr.Type.PLACEHOLDER")
            msg_limit_expr = Message("Mysqlx.Crud.LimitExpr")
            msg_limit_expr["row_count"] = Message(
                "Mysqlx.Expr.Expr", type=placeholder, position=position
            )
            if msg.type == "Mysqlx.Crud.Find":
                msg_limit_expr["offset"] = Message(
                    "Mysqlx.Expr.Expr", type=placeholder, position=position + 1
                )
            msg["limit_expr"] = msg_limit_expr

        oneof_type, oneof_op = CRUD_PREPARE_MAPPING[msg_type]
        msg_oneof = Message("Mysqlx.Prepare.Prepare.OneOfMessage")
        msg_oneof["type"] = mysqlxpb_enum(oneof_type)
        msg_oneof[oneof_op] = msg
        msg_prepare = Message("Mysqlx.Prepare.Prepare")
        msg_prepare["stmt_id"] = stmt.stmt_id
        msg_prepare["stmt"] = msg_oneof

        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_PREPARE"),
            msg_prepare,
        )

        try:
            self.read_ok()
        except InterfaceError as err:
            raise NotSupportedError from err

    def send_prepare_execute(
        self, msg_type: str, msg: MessageType, stmt: FilterableStatement
    ) -> None:
        """
        Send execute statement.

        Sends the bound arguments as ``Any`` values, followed by the limit
        row count and offset when the statement has a limit.

        Args:
            msg_type (str): Message ID string.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
            stmt (Statement): A `Statement` based type object.

        .. versionadded:: 8.0.16
        """
        # NOTE(review): msg_oneof is built but never attached to msg_execute;
        # it looks like dead code (the prepared statement is referenced by
        # stmt_id only).
        oneof_type, oneof_op = CRUD_PREPARE_MAPPING[msg_type]
        msg_oneof = Message("Mysqlx.Prepare.Prepare.OneOfMessage")
        msg_oneof["type"] = mysqlxpb_enum(oneof_type)
        msg_oneof[oneof_op] = msg
        msg_execute = Message("Mysqlx.Prepare.Execute")
        msg_execute["stmt_id"] = stmt.stmt_id

        args = self._get_binding_args(stmt, is_scalar=False)
        if args:
            msg_execute["args"].extend(args)

        if stmt.has_limit:
            msg_execute["args"].extend(
                [
                    self._create_any(stmt.get_limit_row_count()).get_message(),
                    self._create_any(stmt.get_limit_offset()).get_message(),
                ]
            )

        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_EXECUTE"),
            msg_execute,
        )

    def send_prepare_deallocate(self, stmt_id: int) -> None:
        """
        Send prepare deallocate statement and wait for ``Mysqlx.Ok``.

        Args:
            stmt_id (int): Statement ID.

        .. versionadded:: 8.0.16
        """
        msg_dealloc = Message("Mysqlx.Prepare.Deallocate")
        msg_dealloc["stmt_id"] = stmt_id
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.PREPARE_DEALLOCATE"),
            msg_dealloc,
        )
        self.read_ok()

    def send_msg_without_ps(
        self,
        msg_type: str,
        msg: MessageType,
        stmt: Union[FilterableStatement, SqlStatement],
    ) -> None:
        """
        Send a message without prepared statements support.

        A limit, if any, is written literally into the message, and the
        bound arguments are attached before sending.

        Args:
            msg_type (str): Message ID string.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.
            stmt (Statement): A `Statement` based type object.

        .. versionadded:: 8.0.16
        """
        if stmt.has_limit:
            msg_limit = Message("Mysqlx.Crud.Limit")
            msg_limit["row_count"] = stmt.get_limit_row_count()  # type: ignore[union-attr]
            # Only Find messages carry an offset.
            if msg.type == "Mysqlx.Crud.Find":
                msg_limit["offset"] = stmt.get_limit_offset()  # type: ignore[union-attr]
            msg["limit"] = msg_limit
        # SQL statements take Any arguments; CRUD messages take Scalars.
        is_scalar = msg_type != "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE"
        args = self._get_binding_args(stmt, is_scalar=is_scalar)
        if args:
            msg["args"].extend(args)
        self.send_msg(msg_type, msg)

    def send_msg(self, msg_type: str, msg: MessageType) -> None:
        """
        Send a message.

        Args:
            msg_type (str): Message ID string.
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        self._writer.write_message(mysqlxpb_enum(msg_type), msg)

    def build_find(
        self, stmt: Union[FindStatement, ReadStatement]
    ) -> Tuple[str, MessageType]:
        """Build find/read message.

        The limit is not set here; it is added by send_msg_without_ps or
        send_prepare_prepare.

        Args:
            stmt (Statement): A :class:`mysqlx.ReadStatement` or
                              :class:`mysqlx.FindStatement` object.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        data_model = mysqlxpb_enum(
            "Mysqlx.Crud.DataModel.DOCUMENT"
            if stmt.is_doc_based()
            else "Mysqlx.Crud.DataModel.TABLE"
        )
        collection = Message(
            "Mysqlx.Crud.Collection",
            name=stmt.target.name,
            schema=stmt.schema.name,
        )
        msg = Message("Mysqlx.Crud.Find", data_model=data_model, collection=collection)
        if stmt.has_projection:
            msg["projection"] = stmt.get_projection_expr()
        self._apply_filter(msg, stmt)

        if stmt.is_lock_exclusive():
            msg["locking"] = mysqlxpb_enum("Mysqlx.Crud.Find.RowLock.EXCLUSIVE_LOCK")
        elif stmt.is_lock_shared():
            msg["locking"] = mysqlxpb_enum("Mysqlx.Crud.Find.RowLock.SHARED_LOCK")

        # Only non-zero lock contention values are sent.
        if stmt.lock_contention.value > 0:
            msg["locking_options"] = stmt.lock_contention.value

        return "Mysqlx.ClientMessages.Type.CRUD_FIND", msg

    def build_update(
        self, stmt: Union[ModifyStatement, UpdateStatement]
    ) -> Tuple[str, MessageType]:
        """Build update message.

        Args:
            stmt (Statement): A :class:`mysqlx.ModifyStatement` or
                              :class:`mysqlx.UpdateStatement` object.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        data_model = mysqlxpb_enum(
            "Mysqlx.Crud.DataModel.DOCUMENT"
            if stmt.is_doc_based()
            else "Mysqlx.Crud.DataModel.TABLE"
        )
        collection = Message(
            "Mysqlx.Crud.Collection",
            name=stmt.target.name,
            schema=stmt.schema.name,
        )
        msg = Message(
            "Mysqlx.Crud.Update", data_model=data_model, collection=collection
        )
        self._apply_filter(msg, stmt)
        for _, update_op in stmt.get_update_ops().items():
            operation = Message("Mysqlx.Crud.UpdateOperation")
            operation["operation"] = update_op.update_type
            operation["source"] = update_op.source
            # Operations without a value (value is None) send no value field.
            if update_op.value is not None:
                operation["value"] = build_expr(update_op.value)
            msg["operation"].extend([operation.get_message()])

        return "Mysqlx.ClientMessages.Type.CRUD_UPDATE", msg

    def build_delete(
        self, stmt: Union[DeleteStatement, RemoveStatement]
    ) -> Tuple[str, MessageType]:
        """Build delete message.

        Args:
            stmt (Statement): A :class:`mysqlx.DeleteStatement` or
                              :class:`mysqlx.RemoveStatement` object.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        data_model = mysqlxpb_enum(
            "Mysqlx.Crud.DataModel.DOCUMENT"
            if stmt.is_doc_based()
            else "Mysqlx.Crud.DataModel.TABLE"
        )
        collection = Message(
            "Mysqlx.Crud.Collection",
            name=stmt.target.name,
            schema=stmt.schema.name,
        )
        msg = Message(
            "Mysqlx.Crud.Delete", data_model=data_model, collection=collection
        )
        self._apply_filter(msg, stmt)
        return "Mysqlx.ClientMessages.Type.CRUD_DELETE", msg

    def build_execute_statement(
        self,
        namespace: str,
        stmt: Union[str, StatementType],
        fields: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, MessageType]:
        """Build execute statement.

        When ``fields`` is given, it is sent as a single object argument.

        Args:
            namespace (str): The namespace.
            stmt (Statement): A `Statement` based type object.
            fields (Optional[dict]): The message fields.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        msg = Message(
            "Mysqlx.Sql.StmtExecute",
            namespace=namespace,
            stmt=stmt,
            compact_metadata=False,
        )

        if fields:
            obj_flds = []
            for key, value in fields.items():
                obj_fld = Message(
                    "Mysqlx.Datatypes.Object.ObjectField",
                    key=key,
                    value=self._create_any(value),
                )
                obj_flds.append(obj_fld.get_message())
            msg_obj = Message("Mysqlx.Datatypes.Object", fld=obj_flds)
            msg_any = Message("Mysqlx.Datatypes.Any", type=2, obj=msg_obj)
            msg["args"] = [msg_any.get_message()]
        return "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE", msg

    @staticmethod
    def build_insert(
        stmt: Union[AddStatement, InsertStatement]
    ) -> Tuple[str, MessageType]:
        """Build insert statement.

        Column projections are taken from ``stmt._fields`` when that attribute
        exists. Each value becomes one row; list values become multi-field
        rows.

        Args:
            stmt (Statement): A :class:`mysqlx.AddStatement` or
                              :class:`mysqlx.InsertStatement` object.

        Returns:
            (tuple): Tuple containing:

                * `str`: Message ID string.
                * :class:`mysqlx.protobuf.Message`: MySQL X Protobuf Message.

        .. versionadded:: 8.0.16
        """
        data_model = mysqlxpb_enum(
            "Mysqlx.Crud.DataModel.DOCUMENT"
            if stmt.is_doc_based()
            else "Mysqlx.Crud.DataModel.TABLE"
        )
        collection = Message(
            "Mysqlx.Crud.Collection",
            name=stmt.target.name,
            schema=stmt.schema.name,
        )
        msg = Message(
            "Mysqlx.Crud.Insert", data_model=data_model, collection=collection
        )

        if hasattr(stmt, "_fields"):
            for field in stmt._fields:
                expr = ExprParser(
                    field, not stmt.is_doc_based()
                ).parse_table_insert_field()
                msg["projection"].extend([expr.get_message()])

        for value in stmt.get_values():
            row = Message("Mysqlx.Crud.Insert.TypedRow")
            if isinstance(value, list):
                for val in value:
                    row["field"].extend([build_expr(val).get_message()])
            else:
                row["field"].extend([build_expr(value).get_message()])
            msg["row"].extend([row.get_message()])

        if hasattr(stmt, "is_upsert"):
            msg["upsert"] = stmt.is_upsert()

        return "Mysqlx.ClientMessages.Type.CRUD_INSERT", msg

    def close_result(self, result: ResultBaseType) -> None:
        """Close the result.

        Args:
            result (Result): A `Result` based type object.

        Raises:
            :class:`mysqlx.OperationalError`: If the message read is not
                                              ``None`` (i.e. the result did
                                              not end with
                                              ``StmtExecuteOk``).
        """
        msg = self._read_message(result)
        if msg is not None:
            raise OperationalError("Expected to close the result")

    def read_row(self, result: ResultBaseType) -> Optional[MessageType]:
        """Read row.

        A non-row message is pushed back to the reader so the next read sees
        it again.

        Args:
            result (Result): A `Result` based type object.

        Returns:
            mysqlx.protobuf.Message: The ``Mysqlx.Resultset.Row`` message, or
            ``None`` if the next message is not a row.
        """
        msg = self._read_message(result)
        if msg is None:
            return None
        if msg.type == "Mysqlx.Resultset.Row":
            return msg
        self._reader.push_message(msg)
        return None

    def get_column_metadata(self, result: ResultBaseType) -> List[ColumnType]:
        """Returns column metadata.

        Collects ``ColumnMetaData`` messages until the first row (pushed back
        for read_row) or the end of the result.

        Args:
            result (Result): A `Result` based type object.

        Raises:
            :class:`mysqlx.InterfaceError`: If unexpected message.
        """
        columns = []
        while True:
            msg = self._read_message(result)
            if msg is None:
                break
            if msg.type == "Mysqlx.Resultset.Row":
                self._reader.push_message(msg)
                break
            if msg.type != "Mysqlx.Resultset.ColumnMetaData":
                raise InterfaceError("Unexpected msg type")
            # The second argument of msg.get() is presumably the fallback used
            # when the optional field is absent -- confirm against Message.
            col = Column(
                msg["type"],
                msg["catalog"],
                msg["schema"],
                msg["table"],
                msg["original_table"],
                msg["name"],
                msg["original_name"],
                msg.get("length", 21),
                msg.get("collation", 0),
                msg.get("fractional_digits", 0),
                msg.get("flags", 16),
                msg.get("content_type"),
            )
            columns.append(col)
        return columns

    def read_ok(self) -> None:
        """Read OK.

        Raises:
            :class:`mysqlx.InterfaceError`: If the message is ``Mysqlx.Error``
                                            (``errno`` set to the server error
                                            code) or any message other than
                                            ``Mysqlx.Ok``.
        """
        msg = self._reader.read_message()
        if msg.type == "Mysqlx.Error":
            raise InterfaceError(f"Mysqlx.Error: {msg['msg']}", errno=msg["code"])
        if msg.type != "Mysqlx.Ok":
            raise InterfaceError("Unexpected message encountered")

    def send_connection_close(self) -> None:
        """Send connection close."""
        msg = Message("Mysqlx.Connection.Close")
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.CON_CLOSE"), msg
        )

    def send_close(self) -> None:
        """Send close."""
        msg = Message("Mysqlx.Session.Close")
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_CLOSE"), msg
        )

    def send_expect_open(self) -> None:
        """Send expectation.

        Sends an ``Expect.Open`` with a single EXPECT_FIELD_EXIST condition.
        send_reset uses it to probe whether the server supports keeping the
        session open.
        """
        cond_key = mysqlxpb_enum("Mysqlx.Expect.Open.Condition.Key.EXPECT_FIELD_EXIST")
        msg_oc = Message("Mysqlx.Expect.Open.Condition")
        msg_oc["condition_key"] = cond_key
        # NOTE(review): "6.1" presumably names the message/field checked for
        # existence (the Session.Reset keep_open field) -- confirm against
        # the X Protocol spec.
        msg_oc["condition_value"] = "6.1"

        msg_eo = Message("Mysqlx.Expect.Open")
        msg_eo["cond"] = [msg_oc.get_message()]

        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.EXPECT_OPEN"), msg_eo
        )

    def send_reset(self, keep_open: Optional[bool] = None) -> bool:
        """Send reset session message.

        Args:
            keep_open (Optional[bool]): Whether to ask the server to keep the
                                        session open. When ``None``, the
                                        server is probed with an expectation
                                        first, and the answer decides.

        Returns:
            boolean: ``True`` if the server will keep the session open,
                     otherwise ``False``.
        """
        msg = Message("Mysqlx.Session.Reset")
        if keep_open is None:
            try:
                # Send expectation: keep connection open
                self.send_expect_open()
                self.read_ok()
                keep_open = True
            except InterfaceError:
                # Expectation is unknown by this version of the server
                keep_open = False
        if keep_open:
            msg["keep_open"] = True
        self._writer.write_message(
            mysqlxpb_enum("Mysqlx.ClientMessages.Type.SESS_RESET"), msg
        )
        self.read_ok()
        if keep_open:
            return True
        return False