# Copyright (c) 2016, 2022, Oracle and/or its affiliates. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License, version 2.0, as # published by the Free Software Foundation. # # This program is also distributed with certain software (including # but not limited to OpenSSL) that is licensed under separate terms, # as designated in a particular file or component or in included license # documentation. The authors of MySQL hereby grant you an # additional permission to link the program and your derivative works # with the separately licensed software that they have included with # MySQL. # # Without limiting anything contained in the foregoing, this file, # which is part of MySQL Connector/Python, is also subject to the # Universal FOSS Exception, version 1.0, a copy of which can be found at # http://oss.oracle.com/licenses/universal-foss-exception. # # This program is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the GNU General Public License, version 2.0, for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA # mypy: disable-error-code="return-value" """Implementation of Statements.""" from __future__ import annotations import copy import json import warnings from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from .constants import LockContention from .dbdoc import DbDoc from .errors import NotSupportedError, ProgrammingError from .expr import ExprParser from .helpers import deprecated from .protobuf import mysqlxpb_enum from .result import DocResult, Result, RowResult, SqlResult from .types import ( ConnectionType, DatabaseTargetType, MessageType, ProtobufMessageCextType, ProtobufMessageType, SchemaType, ) ERR_INVALID_INDEX_NAME = 'The given index name "{}" is not valid' class Expr: """Expression wrapper.""" def __init__(self, expr: Any) -> None: self.expr: Any = expr def flexible_params(*values: Any) -> Union[List, Tuple]: """Parse flexible parameters.""" if len(values) == 1 and isinstance(values[0], (list, tuple)): return values[0] return values def is_quoted_identifier(identifier: str, sql_mode: str = "") -> bool: """Check if the given identifier is quoted. Args: identifier (string): Identifier to check. sql_mode (Optional[string]): SQL mode. Returns: `True` if the identifier has backtick quotes, and False otherwise. """ if "ANSI_QUOTES" in sql_mode: return (identifier[0] == "`" and identifier[-1] == "`") or ( identifier[0] == '"' and identifier[-1] == '"' ) return identifier[0] == "`" and identifier[-1] == "`" def quote_identifier(identifier: str, sql_mode: str = "") -> str: """Quote the given identifier with backticks, converting backticks (`) in the identifier name with the correct escape sequence (``). Args: identifier (string): Identifier to quote. sql_mode (Optional[string]): SQL mode. Returns: A string with the identifier quoted with backticks. """ if len(identifier) == 0: return "``" if "ANSI_QUOTES" in sql_mode: quoted = identifier.replace('"', '""') return f'"{quoted}"' quoted = identifier.replace("`", "``") return f"`{quoted}`" def quote_multipart_identifier(identifiers: Iterable[str], sql_mode: str = "") -> str: """Quote the given multi-part identifier with backticks. Args: identifiers (iterable): List of identifiers to quote. sql_mode (Optional[string]): SQL mode. Returns: A string with the multi-part identifier quoted with backticks. """ return ".".join( [quote_identifier(identifier, sql_mode) for identifier in identifiers] ) def parse_table_name( default_schema: str, table_name: str, sql_mode: str = "" ) -> Tuple[str, str]: """Parse table name. Args: default_schema (str): The default schema. table_name (str): The table name. sql_mode(Optional[str]): The SQL mode. Returns: str: The parsed table name. """ quote = '"' if "ANSI_QUOTES" in sql_mode else "`" delimiter = f".{quote}" if quote in table_name else "." temp = table_name.split(delimiter, 1) return ( default_schema if len(temp) == 1 else temp[0].strip(quote), temp[-1].strip(quote), ) class Statement: """Provides base functionality for statement objects. Args: target (object): The target database object, it can be :class:`mysqlx.Collection` or :class:`mysqlx.Table`. doc_based (bool): `True` if it is document based. """ def __init__(self, target: DatabaseTargetType, doc_based: bool = True) -> None: self._target: DatabaseTargetType = target self._doc_based: bool = doc_based self._connection: Optional[ConnectionType] = ( target.get_connection() if target else None ) self._stmt_id: Optional[int] = None self._exec_counter: int = 0 self._changed: bool = True self._prepared: bool = False self._deallocate_prepare_execute: bool = False @property def target(self) -> DatabaseTargetType: """object: The database object target.""" return self._target @property def schema(self) -> SchemaType: """:class:`mysqlx.Schema`: The Schema object.""" return self._target.schema @property def stmt_id(self) -> int: """Returns this statement ID. Returns: int: The statement ID. """ return self._stmt_id @stmt_id.setter def stmt_id(self, value: int) -> None: self._stmt_id = value @property def exec_counter(self) -> int: """int: The number of times this statement was executed.""" return self._exec_counter @property def changed(self) -> bool: """bool: `True` if this statement has changes.""" return self._changed @changed.setter def changed(self, value: bool) -> None: self._changed = value @property def prepared(self) -> bool: """bool: `True` if this statement has been prepared.""" return self._prepared @prepared.setter def prepared(self, value: bool) -> None: self._prepared = value @property def repeated(self) -> bool: """bool: `True` if this statement was executed more than once.""" return self._exec_counter > 1 @property def deallocate_prepare_execute(self) -> bool: """bool: `True` to deallocate + prepare + execute statement.""" return self._deallocate_prepare_execute @deallocate_prepare_execute.setter def deallocate_prepare_execute(self, value: bool) -> None: self._deallocate_prepare_execute = value def is_doc_based(self) -> bool: """Check if it is document based. Returns: bool: `True` if it is document based. """ return self._doc_based def increment_exec_counter(self) -> None: """Increments the number of times this statement has been executed.""" self._exec_counter += 1 def reset_exec_counter(self) -> None: """Resets the number of times this statement has been executed.""" self._exec_counter = 0 def execute(self) -> Any: """Execute the statement. Raises: NotImplementedError: This method must be implemented. """ raise NotImplementedError class FilterableStatement(Statement): """A statement to be used with filterable statements. Args: target (object): The target database object, it can be :class:`mysqlx.Collection` or :class:`mysqlx.Table`. doc_based (Optional[bool]): `True` if it is document based (default: `True`). condition (Optional[str]): Sets the search condition to filter documents or records. """ def __init__( self, target: DatabaseTargetType, doc_based: bool = True, condition: Optional[str] = None, ) -> None: super().__init__(target=target, doc_based=doc_based) self._binding_map: Dict[str, Any] = {} self._bindings: Union[Dict[str, Any], List] = {} self._having: Optional[MessageType] = None self._grouping_str: str = "" self._grouping: Optional[ List[Union[ProtobufMessageType, ProtobufMessageCextType]] ] = None self._limit_offset: int = 0 self._limit_row_count: int = None self._projection_str: str = "" self._projection_expr: Optional[ List[Union[ProtobufMessageType, ProtobufMessageCextType]] ] = None self._sort_str: str = "" self._sort_expr: Optional[ List[Union[ProtobufMessageType, ProtobufMessageCextType]] ] = None self._where_str: str = "" self._where_expr: MessageType = None self.has_bindings: bool = False self.has_limit: bool = False self.has_group_by: bool = False self.has_having: bool = False self.has_projection: bool = False self.has_sort: bool = False self.has_where: bool = False if condition: self._set_where(condition) def _bind_single(self, obj: Union[DbDoc, Dict[str, Any], str]) -> None: """Bind single object. Args: obj (:class:`mysqlx.DbDoc` or str): DbDoc or JSON string object. Raises: :class:`mysqlx.ProgrammingError`: If invalid JSON string to bind. ValueError: If JSON loaded is not a dictionary. """ if isinstance(obj, dict): self.bind(DbDoc(obj).as_str()) elif isinstance(obj, DbDoc): self.bind(obj.as_str()) elif isinstance(obj, str): try: res = json.loads(obj) if not isinstance(res, dict): raise ValueError except ValueError as err: raise ProgrammingError("Invalid JSON string to bind") from err for key in res.keys(): self.bind(key, res[key]) else: raise ProgrammingError("Invalid JSON string or object to bind") def _sort(self, *clauses: str) -> FilterableStatement: """Sets the sorting criteria. Args: *clauses: The expression strings defining the sort criteria. Returns: mysqlx.FilterableStatement: FilterableStatement object. """ self.has_sort = True self._sort_str = ",".join(flexible_params(*clauses)) self._sort_expr = ExprParser( self._sort_str, not self._doc_based ).parse_order_spec() self._changed = True return self def _set_where(self, condition: str) -> FilterableStatement: """Sets the search condition to filter. Args: condition (str): Sets the search condition to filter documents or records. Returns: mysqlx.FilterableStatement: FilterableStatement object. """ self.has_where = True self._where_str = condition try: expr = ExprParser(condition, not self._doc_based) self._where_expr = expr.expr() except ValueError as err: raise ProgrammingError("Invalid condition") from err self._binding_map = expr.placeholder_name_to_position self._changed = True return self def _set_group_by(self, *fields: str) -> None: """Set group by. Args: *fields: List of fields. """ fields = flexible_params(*fields) self.has_group_by = True self._grouping_str = ",".join(fields) self._grouping = ExprParser( self._grouping_str, not self._doc_based ).parse_expr_list() self._changed = True def _set_having(self, condition: str) -> None: """Set having. Args: condition (str): The condition. """ self.has_having = True self._having = ExprParser(condition, not self._doc_based).expr() self._changed = True def _set_projection(self, *fields: str) -> FilterableStatement: """Set the projection. Args: *fields: List of fields. Returns: :class:`mysqlx.FilterableStatement`: Returns self. """ fields = flexible_params(*fields) self.has_projection = True self._projection_str = ",".join(fields) self._projection_expr = ExprParser( self._projection_str, not self._doc_based ).parse_table_select_projection() self._changed = True return self def get_binding_map(self) -> Dict[str, Any]: """Returns the binding map dictionary. Returns: dict: The binding map dictionary. """ return self._binding_map def get_bindings(self) -> Union[Dict[str, Any], List]: """Returns the bindings list. Returns: `list`: The bindings list. """ return self._bindings def get_grouping(self) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]: """Returns the grouping expression list. Returns: `list`: The grouping expression list. """ return self._grouping def get_having(self) -> MessageType: """Returns the having expression. Returns: object: The having expression. """ return self._having def get_limit_row_count(self) -> int: """Returns the limit row count. Returns: int: The limit row count. """ return self._limit_row_count def get_limit_offset(self) -> int: """Returns the limit offset. Returns: int: The limit offset. """ return self._limit_offset def get_where_expr(self) -> MessageType: """Returns the where expression. Returns: object: The where expression. """ return self._where_expr def get_projection_expr( self, ) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]: """Returns the projection expression. Returns: object: The projection expression. """ return self._projection_expr def get_sort_expr( self, ) -> List[Union[ProtobufMessageType, ProtobufMessageCextType]]: """Returns the sort expression. Returns: object: The sort expression. """ return self._sort_expr @deprecated("8.0.12") def where(self, condition: str) -> FilterableStatement: """Sets the search condition to filter. Args: condition (str): Sets the search condition to filter documents or records. Returns: mysqlx.FilterableStatement: FilterableStatement object. .. deprecated:: 8.0.12 """ return self._set_where(condition) @deprecated("8.0.12") def sort(self, *clauses: str) -> FilterableStatement: """Sets the sorting criteria. Args: *clauses: The expression strings defining the sort criteria. Returns: mysqlx.FilterableStatement: FilterableStatement object. .. deprecated:: 8.0.12 """ return self._sort(*clauses) def limit( self, row_count: int, offset: Optional[int] = None ) -> FilterableStatement: """Sets the maximum number of items to be returned. Args: row_count (int): The maximum number of items. Returns: mysqlx.FilterableStatement: FilterableStatement object. Raises: ValueError: If ``row_count`` is not a positive integer. .. versionchanged:: 8.0.12 The usage of ``offset`` was deprecated. """ if not isinstance(row_count, int) or row_count < 0: raise ValueError("The 'row_count' value must be a positive integer") if not self.has_limit: self._changed = bool(self._exec_counter == 0) self._deallocate_prepare_execute = bool(not self._exec_counter == 0) self._limit_row_count = row_count self.has_limit = True if offset: self.offset(offset) warnings.warn( "'limit(row_count, offset)' is deprecated, please " "use 'offset(offset)' to set the number of items to " "skip", category=DeprecationWarning, ) return self def offset(self, offset: int) -> FilterableStatement: """Sets the number of items to skip. Args: offset (int): The number of items to skip. Returns: mysqlx.FilterableStatement: FilterableStatement object. Raises: ValueError: If ``offset`` is not a positive integer. .. versionadded:: 8.0.12 """ if not isinstance(offset, int) or offset < 0: raise ValueError("The 'offset' value must be a positive integer") self._limit_offset = offset return self def bind(self, *args: Any) -> FilterableStatement: """Binds value(s) to a specific placeholder(s). Args: *args: The name of the placeholder and the value to bind. A :class:`mysqlx.DbDoc` object or a JSON string representation can be used. Returns: mysqlx.FilterableStatement: FilterableStatement object. Raises: ProgrammingError: If the number of arguments is invalid. """ self.has_bindings = True count = len(args) if count == 1: self._bind_single(args[0]) elif count == 2: self._bindings[args[0]] = args[1] else: raise ProgrammingError("Invalid number of arguments to bind") return self def execute(self) -> Any: """Execute the statement. Raises: NotImplementedError: This method must be implemented. """ raise NotImplementedError class SqlStatement(Statement): """A statement for SQL execution. Args: connection (mysqlx.connection.Connection): Connection object. sql (string): The sql statement to be executed. """ def __init__(self, connection: ConnectionType, sql: str) -> None: super().__init__(target=None, doc_based=False) self._connection: ConnectionType = connection self._sql: str = sql self._binding_map: Optional[Dict[str, Any]] = None self._bindings: Union[List, Tuple] = [] self.has_bindings: bool = False self.has_limit: bool = False @property def sql(self) -> str: """string: The SQL text statement.""" return self._sql def get_binding_map(self) -> Dict[str, Any]: """Returns the binding map dictionary. Returns: dict: The binding map dictionary. """ return self._binding_map def get_bindings(self) -> Union[Tuple, List]: """Returns the bindings list. Returns: `list`: The bindings list. """ return self._bindings def bind(self, *args: Any) -> SqlStatement: """Binds value(s) to a specific placeholder(s). Args: *args: The value(s) to bind. Returns: mysqlx.SqlStatement: SqlStatement object. """ if len(args) == 0: raise ProgrammingError("Invalid number of arguments to bind") self.has_bindings = True bindings = flexible_params(*args) if isinstance(bindings, (list, tuple)): self._bindings = bindings else: self._bindings.append(bindings) return self def execute(self) -> SqlResult: """Execute the statement. Returns: mysqlx.SqlResult: SqlResult object. """ return self._connection.send_sql(self) class WriteStatement(Statement): """Provide common write operation attributes.""" def __init__(self, target: DatabaseTargetType, doc_based: bool) -> None: super().__init__(target, doc_based) self._values: List[ Union[ int, str, DbDoc, Dict[str, Any], List[Optional[Union[str, int, float, ExprParser, Dict[str, Any]]]], ] ] = [] def get_values( self, ) -> List[ Union[ int, str, DbDoc, Dict[str, Any], List[Optional[Union[str, int, float, ExprParser, Dict[str, Any]]]], ] ]: """Returns the list of values. Returns: `list`: The list of values. """ return self._values def execute(self) -> Any: """Execute the statement. Raises: NotImplementedError: This method must be implemented. """ raise NotImplementedError class AddStatement(WriteStatement): """A statement for document addition on a collection. Args: collection (mysqlx.Collection): The Collection object. """ def __init__(self, collection: DatabaseTargetType) -> None: super().__init__(collection, True) self._upsert: bool = False self.ids: List = [] def is_upsert(self) -> bool: """Returns `True` if it's an upsert. Returns: bool: `True` if it's an upsert. """ return self._upsert def upsert(self, value: bool = True) -> AddStatement: """Sets the upset flag to the boolean of the value provided. Setting of this flag allows updating of the matched rows/documents with the provided value. Args: value (optional[bool]): Set or unset the upsert flag. """ self._upsert = value return self def add(self, *values: DbDoc) -> AddStatement: """Adds a list of documents into a collection. Args: *values: The documents to be added into the collection. Returns: mysqlx.AddStatement: AddStatement object. """ for val in flexible_params(*values): if isinstance(val, DbDoc): self._values.append(val) else: self._values.append(DbDoc(val)) return self def execute(self) -> Result: """Execute the statement. Returns: mysqlx.Result: Result object. """ if len(self._values) == 0: return Result() return self._connection.send_insert(self) class UpdateSpec: """Update specification class implementation. Args: update_type (int): The update type. source (str): The source. value (Optional[str]): The value. Raises: ProgrammingError: If `source` is invalid. """ def __init__(self, update_type: int, source: str, value: Any = None) -> None: if update_type == mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET"): self._table_set(source, value) else: self.update_type: int = update_type try: self.source: Any = ExprParser(source, False).document_field().identifier except ValueError as err: raise ProgrammingError(f"{err}") from err self.value: Any = value def _table_set(self, source: str, value: Any) -> None: """Table set. Args: source (str): The source. value (str): The value. """ self.update_type = mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET") self.source = ExprParser(source, True).parse_table_update_field() self.value = value class ModifyStatement(FilterableStatement): """A statement for document update operations on a Collection. Args: collection (mysqlx.Collection): The Collection object. condition (str): Sets the search condition to identify the documents to be modified. .. versionchanged:: 8.0.12 The ``condition`` parameter is now mandatory. """ def __init__(self, collection: DatabaseTargetType, condition: str) -> None: super().__init__(target=collection, condition=condition) self._update_ops: Dict[str, Any] = {} def sort(self, *clauses: str) -> ModifyStatement: """Sets the sorting criteria. Args: *clauses: The expression strings defining the sort criteria. Returns: mysqlx.ModifyStatement: ModifyStatement object. """ return self._sort(*clauses) def get_update_ops(self) -> Dict[str, Any]: """Returns the list of update operations. Returns: `list`: The list of update operations. """ return self._update_ops def set(self, doc_path: str, value: Any) -> ModifyStatement: """Sets or updates attributes on documents in a collection. Args: doc_path (string): The document path of the item to be set. value (string): The value to be set on the specified attribute. Returns: mysqlx.ModifyStatement: ModifyStatement object. """ self._update_ops[doc_path] = UpdateSpec( mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_SET"), doc_path, value, ) self._changed = True return self @deprecated("8.0.12") def change(self, doc_path: str, value: Any) -> ModifyStatement: """Add an update to the statement setting the field, if it exists at the document path, to the given value. Args: doc_path (string): The document path of the item to be set. value (object): The value to be set on the specified attribute. Returns: mysqlx.ModifyStatement: ModifyStatement object. .. deprecated:: 8.0.12 """ self._update_ops[doc_path] = UpdateSpec( mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REPLACE"), doc_path, value, ) self._changed = True return self def unset(self, *doc_paths: str) -> ModifyStatement: """Removes attributes from documents in a collection. Args: doc_paths (list): The list of document paths of the attributes to be removed. Returns: mysqlx.ModifyStatement: ModifyStatement object. """ for item in flexible_params(*doc_paths): self._update_ops[item] = UpdateSpec( mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REMOVE"), item, ) self._changed = True return self def array_insert(self, field: str, value: Any) -> ModifyStatement: """Insert a value into the specified array in documents of a collection. Args: field (string): A document path that identifies the array attribute and position where the value will be inserted. value (object): The value to be inserted. Returns: mysqlx.ModifyStatement: ModifyStatement object. """ self._update_ops[field] = UpdateSpec( mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_INSERT"), field, value, ) self._changed = True return self def array_append(self, doc_path: str, value: Any) -> ModifyStatement: """Inserts a value into a specific position in an array attribute in documents of a collection. Args: doc_path (string): A document path that identifies the array attribute and position where the value will be inserted. value (object): The value to be inserted. Returns: mysqlx.ModifyStatement: ModifyStatement object. """ self._update_ops[doc_path] = UpdateSpec( mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_APPEND"), doc_path, value, ) self._changed = True return self def patch(self, doc: Union[Dict, DbDoc, ExprParser, str]) -> ModifyStatement: """Takes a :class:`mysqlx.DbDoc`, string JSON format or a dict with the changes and applies it on all matching documents. Args: doc (object): A generic document (DbDoc), string in JSON format or dict, with the changes to apply to the matching documents. Returns: mysqlx.ModifyStatement: ModifyStatement object. """ if doc is None: doc = "" if not isinstance(doc, (ExprParser, dict, DbDoc, str)): raise ProgrammingError( "Invalid data for update operation on document collection table" ) self._update_ops["patch"] = UpdateSpec( mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.MERGE_PATCH"), "$", doc.expr() if isinstance(doc, ExprParser) else doc, ) self._changed = True return self def execute(self) -> Result: """Execute the statement. Returns: mysqlx.Result: Result object. Raises: ProgrammingError: If condition was not set. """ if not self.has_where: raise ProgrammingError("No condition was found for modify") return self._connection.send_update(self) class ReadStatement(FilterableStatement): """Provide base functionality for Read operations Args: target (object): The target database object, it can be :class:`mysqlx.Collection` or :class:`mysqlx.Table`. doc_based (Optional[bool]): `True` if it is document based (default: `True`). condition (Optional[str]): Sets the search condition to filter documents or records. """ def __init__( self, target: DatabaseTargetType, doc_based: bool = True, condition: Optional[str] = None, ) -> None: super().__init__(target, doc_based, condition) self._lock_exclusive: bool = False self._lock_shared: bool = False self._lock_contention: LockContention = LockContention.DEFAULT @property def lock_contention(self) -> LockContention: """:class:`mysqlx.LockContention`: The lock contention value.""" return self._lock_contention def _set_lock_contention(self, lock_contention: LockContention) -> None: """Set the lock contention. Args: lock_contention (:class:`mysqlx.LockContention`): Lock contention. Raises: ProgrammingError: If is an invalid lock contention value. """ try: # Check if is a valid lock contention value _ = LockContention(lock_contention.value) except ValueError as err: raise ProgrammingError( "Invalid lock contention mode. Use 'NOWAIT' or 'SKIP_LOCKED'" ) from err self._lock_contention = lock_contention def is_lock_exclusive(self) -> bool: """Returns `True` if is `EXCLUSIVE LOCK`. Returns: bool: `True` if is `EXCLUSIVE LOCK`. """ return self._lock_exclusive def is_lock_shared(self) -> bool: """Returns `True` if is `SHARED LOCK`. Returns: bool: `True` if is `SHARED LOCK`. """ return self._lock_shared def lock_shared( self, lock_contention: LockContention = LockContention.DEFAULT ) -> ReadStatement: """Execute a read operation with `SHARED LOCK`. Only one lock can be active at a time. Args: lock_contention (:class:`mysqlx.LockContention`): Lock contention. """ self._lock_exclusive = False self._lock_shared = True self._set_lock_contention(lock_contention) return self def lock_exclusive( self, lock_contention: LockContention = LockContention.DEFAULT ) -> ReadStatement: """Execute a read operation with `EXCLUSIVE LOCK`. Only one lock can be active at a time. Args: lock_contention (:class:`mysqlx.LockContention`): Lock contention. """ self._lock_exclusive = True self._lock_shared = False self._set_lock_contention(lock_contention) return self def group_by(self, *fields: str) -> ReadStatement: """Sets a grouping criteria for the resultset. Args: *fields: The string expressions identifying the grouping criteria. Returns: mysqlx.ReadStatement: ReadStatement object. """ self._set_group_by(*fields) return self def having(self, condition: str) -> ReadStatement: """Sets a condition for records to be considered in agregate function operations. Args: condition (string): A condition on the agregate functions used on the grouping criteria. Returns: mysqlx.ReadStatement: ReadStatement object. """ self._set_having(condition) return self def execute(self) -> Union[DocResult, RowResult]: """Execute the statement. Returns: mysqlx.Result: Result object. """ return self._connection.send_find(self) class FindStatement(ReadStatement): """A statement document selection on a Collection. Args: collection (mysqlx.Collection): The Collection object. condition (Optional[str]): An optional expression to identify the documents to be retrieved. If not specified all the documents will be included on the result unless a limit is set. """ def __init__( self, collection: DatabaseTargetType, condition: Optional[str] = None ) -> None: super().__init__(collection, True, condition) def fields(self, *fields: str) -> FindStatement: """Sets a document field filter. Args: *fields: The string expressions identifying the fields to be extracted. Returns: mysqlx.FindStatement: FindStatement object. """ return self._set_projection(*fields) def sort(self, *clauses: str) -> FindStatement: """Sets the sorting criteria. Args: *clauses: The expression strings defining the sort criteria. Returns: mysqlx.FindStatement: FindStatement object. """ return self._sort(*clauses) class SelectStatement(ReadStatement): """A statement for record retrieval operations on a Table. Args: table (mysqlx.Table): The Table object. *fields: The fields to be retrieved. """ def __init__(self, table: DatabaseTargetType, *fields: str) -> None: super().__init__(table, False) self._set_projection(*fields) def where(self, condition: str) -> SelectStatement: """Sets the search condition to filter. Args: condition (str): Sets the search condition to filter records. Returns: mysqlx.SelectStatement: SelectStatement object. """ return self._set_where(condition) def order_by(self, *clauses: str) -> SelectStatement: """Sets the order by criteria. Args: *clauses: The expression strings defining the order by criteria. Returns: mysqlx.SelectStatement: SelectStatement object. """ return self._sort(*clauses) def get_sql(self) -> str: """Returns the generated SQL. Returns: str: The generated SQL. """ where = f" WHERE {self._where_str}" if self.has_where else "" group_by = f" GROUP BY {self._grouping_str}" if self.has_group_by else "" having = f" HAVING {self._having}" if self.has_having else "" order_by = f" ORDER BY {self._sort_str}" if self.has_sort else "" limit = ( f" LIMIT {self._limit_row_count} OFFSET {self._limit_offset}" if self.has_limit else "" ) stmt = ( f"SELECT {self._projection_str or '*'} " f"FROM {self.schema.name}.{self.target.name}" f"{where}{group_by}{having}{order_by}{limit}" ) return stmt class InsertStatement(WriteStatement): """A statement for insert operations on Table. Args: table (mysqlx.Table): The Table object. *fields: The fields to be inserted. """ def __init__(self, table: DatabaseTargetType, *fields: Any) -> None: super().__init__(table, False) self._fields: Union[List, Tuple] = flexible_params(*fields) def values(self, *values: Any) -> InsertStatement: """Set the values to be inserted. Args: *values: The values of the columns to be inserted. Returns: mysqlx.InsertStatement: InsertStatement object. """ self._values.append(list(flexible_params(*values))) return self def execute(self) -> Result: """Execute the statement. Returns: mysqlx.Result: Result object. """ return self._connection.send_insert(self) class UpdateStatement(FilterableStatement): """A statement for record update operations on a Table. Args: table (mysqlx.Table): The Table object. .. versionchanged:: 8.0.12 The ``fields`` parameters were removed. """ def __init__(self, table: DatabaseTargetType) -> None: super().__init__(target=table, doc_based=False) self._update_ops: Dict[str, Any] = {} def where(self, condition: str) -> UpdateStatement: """Sets the search condition to filter. Args: condition (str): Sets the search condition to filter records. Returns: mysqlx.UpdateStatement: UpdateStatement object. """ return self._set_where(condition) def order_by(self, *clauses: str) -> UpdateStatement: """Sets the order by criteria. Args: *clauses: The expression strings defining the order by criteria. Returns: mysqlx.UpdateStatement: UpdateStatement object. """ return self._sort(*clauses) def get_update_ops(self) -> Dict[str, Any]: """Returns the list of update operations. Returns: `list`: The list of update operations. """ return self._update_ops def set(self, field: str, value: Any) -> UpdateStatement: """Updates the column value on records in a table. Args: field (string): The column name to be updated. value (object): The value to be set on the specified column. Returns: mysqlx.UpdateStatement: UpdateStatement object. """ self._update_ops[field] = UpdateSpec( mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET"), field, value, ) self._changed = True return self def execute(self) -> Result: """Execute the statement. Returns: mysqlx.Result: Result object Raises: ProgrammingError: If condition was not set. """ if not self.has_where: raise ProgrammingError("No condition was found for update") return self._connection.send_update(self) class RemoveStatement(FilterableStatement): """A statement for document removal from a collection. Args: collection (mysqlx.Collection): The Collection object. condition (str): Sets the search condition to identify the documents to be removed. .. versionchanged:: 8.0.12 The ``condition`` parameter was added. """ def __init__(self, collection: DatabaseTargetType, condition: str) -> None: super().__init__(target=collection, condition=condition) def sort(self, *clauses: str) -> RemoveStatement: """Sets the sorting criteria. Args: *clauses: The expression strings defining the sort criteria. Returns: mysqlx.FindStatement: FindStatement object. """ return self._sort(*clauses) def execute(self) -> Result: """Execute the statement. Returns: mysqlx.Result: Result object. Raises: ProgrammingError: If condition was not set. """ if not self.has_where: raise ProgrammingError("No condition was found for remove") return self._connection.send_delete(self) class DeleteStatement(FilterableStatement): """A statement that drops a table. Args: table (mysqlx.Table): The Table object. .. versionchanged:: 8.0.12 The ``condition`` parameter was removed. """ def __init__(self, table: DatabaseTargetType) -> None: super().__init__(target=table, doc_based=False) def where(self, condition: str) -> DeleteStatement: """Sets the search condition to filter. Args: condition (str): Sets the search condition to filter records. Returns: mysqlx.DeleteStatement: DeleteStatement object. """ return self._set_where(condition) def order_by(self, *clauses: str) -> DeleteStatement: """Sets the order by criteria. Args: *clauses: The expression strings defining the order by criteria. Returns: mysqlx.DeleteStatement: DeleteStatement object. """ return self._sort(*clauses) def execute(self) -> Result: """Execute the statement. Returns: mysqlx.Result: Result object. Raises: ProgrammingError: If condition was not set. """ if not self.has_where: raise ProgrammingError("No condition was found for delete") return self._connection.send_delete(self) class CreateCollectionIndexStatement(Statement): """A statement that creates an index on a collection. Args: collection (mysqlx.Collection): Collection. index_name (string): Index name. index_desc (dict): A dictionary containing the fields members that constraints the index to be created. It must have the form as shown in the following:: {"fields": [{"field": member_path, "type": member_type, "required": member_required, "collation": collation, "options": options, "srid": srid}, # {... more members, # repeated as many times # as needed} ], "type": type} """ def __init__( self, collection: DatabaseTargetType, index_name: str, index_desc: Dict[str, Any], ) -> None: super().__init__(target=collection) self._index_desc: Dict[str, Any] = copy.deepcopy(index_desc) self._index_name: str = index_name self._fields_desc: List[Dict[str, Any]] = self._index_desc.pop("fields", []) def execute(self) -> Result: """Execute the statement. Returns: mysqlx.Result: Result object. """ # Validate index name is a valid identifier if self._index_name is None: raise ProgrammingError(ERR_INVALID_INDEX_NAME.format(self._index_name)) try: parsed_ident = ExprParser(self._index_name).expr().get_message() # The message is type dict when the Protobuf cext is used if isinstance(parsed_ident, dict): if parsed_ident["type"] != mysqlxpb_enum("Mysqlx.Expr.Expr.Type.IDENT"): raise ProgrammingError( ERR_INVALID_INDEX_NAME.format(self._index_name) ) else: if parsed_ident.type != mysqlxpb_enum("Mysqlx.Expr.Expr.Type.IDENT"): raise ProgrammingError( ERR_INVALID_INDEX_NAME.format(self._index_name) ) except (ValueError, AttributeError) as err: raise ProgrammingError( ERR_INVALID_INDEX_NAME.format(self._index_name) ) from err # Validate members that constraint the index if not self._fields_desc: raise ProgrammingError( "Required member 'fields' not found in the given index " f"description: {self._index_desc}" ) if not isinstance(self._fields_desc, list): raise ProgrammingError("Required member 'fields' must contain a list") args: Dict[str, Any] = {} args["name"] = self._index_name args["collection"] = self._target.name args["schema"] = self._target.schema.name if "type" in self._index_desc: args["type"] = self._index_desc.pop("type") else: args["type"] = "INDEX" args["unique"] = self._index_desc.pop("unique", False) # Currently unique indexes are not supported: if args["unique"]: raise NotSupportedError("Unique indexes are not supported.") args["constraint"] = [] if self._index_desc: raise ProgrammingError(f"Unidentified fields: {self._index_desc}") try: for field_desc in self._fields_desc: constraint = {} constraint["member"] = field_desc.pop("field") constraint["type"] = field_desc.pop("type") constraint["required"] = field_desc.pop("required", False) constraint["array"] = field_desc.pop("array", False) if not isinstance(constraint["required"], bool): raise TypeError("Field member 'required' must be Boolean") if not isinstance(constraint["array"], bool): raise TypeError("Field member 'array' must be Boolean") if args["type"].upper() == "SPATIAL" and not constraint["required"]: raise ProgrammingError( "Field member 'required' must be set to 'True' when " "index type is set to 'SPATIAL'" ) if args["type"].upper() == "INDEX" and constraint["type"] == "GEOJSON": raise ProgrammingError( "Index 'type' must be set to 'SPATIAL' when field " "type is set to 'GEOJSON'" ) if "collation" in field_desc: if not constraint["type"].upper().startswith("TEXT"): raise ProgrammingError( "The 'collation' member can only be used when " "field type is set to " f"'{constraint['type'].upper()}'" ) constraint["collation"] = field_desc.pop("collation") # "options" and "srid" fields in IndexField can be # present only if "type" is set to "GEOJSON" if "options" in field_desc: if constraint["type"].upper() != "GEOJSON": raise ProgrammingError( "The 'options' member can only be used when " "index type is set to 'GEOJSON'" ) constraint["options"] = field_desc.pop("options") if "srid" in field_desc: if constraint["type"].upper() != "GEOJSON": raise ProgrammingError( "The 'srid' member can only be used when index " "type is set to 'GEOJSON'" ) constraint["srid"] = field_desc.pop("srid") args["constraint"].append(constraint) except KeyError as err: raise ProgrammingError( f"Required inner member {err} not found in constraint: {field_desc}" ) from err for field_desc in self._fields_desc: if field_desc: raise ProgrammingError(f"Unidentified inner fields: {field_desc}") return self._connection.execute_nonquery( "mysqlx", "create_collection_index", True, args )