# Copyright (c) 2016, 2022, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have included with
# MySQL.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

"""Implementation of the CRUD database objects."""

from __future__ import annotations

import json
import warnings

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from .dbdoc import DbDoc
from .errorcode import (
    ER_NO_SUCH_TABLE,
    ER_TABLE_EXISTS_ERROR,
    ER_X_CMD_NUM_ARGUMENTS,
    ER_X_INVALID_ADMIN_COMMAND,
)
from .errors import NotSupportedError, OperationalError, ProgrammingError
from .helpers import deprecated, escape, quote_identifier
from .statement import (
    AddStatement,
    CreateCollectionIndexStatement,
    DeleteStatement,
    FindStatement,
    InsertStatement,
    ModifyStatement,
    RemoveStatement,
    SelectStatement,
    UpdateStatement,
)
from .types import ConnectionType, SchemaType, SessionType, StrOrBytes

if TYPE_CHECKING:
    from .result import Result

# SQL templates used by the objects below. Callers pass values substituted
# into the quoted '{0}'/'{1}' placeholders through ``escape()``, and values
# substituted into the unquoted {0}.{1} identifier slots through
# ``quote_identifier()``.
_COUNT_VIEWS_QUERY = (
    "SELECT COUNT(*) FROM information_schema.views "
    "WHERE table_schema = '{0}' AND table_name = '{1}'"
)
_COUNT_TABLES_QUERY = (
    "SELECT COUNT(*) FROM information_schema.tables "
    "WHERE table_schema = '{0}' AND table_name = '{1}'"
)
_COUNT_SCHEMAS_QUERY = (
    "SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = '{0}'"
)
_COUNT_QUERY = "SELECT COUNT(*) FROM {0}.{1}"
_DROP_TABLE_QUERY = "DROP TABLE IF EXISTS {0}.{1}"


def _parse_validation_options(validation: Any) -> List[Tuple[str, str]]:
    """Validates a collection ``validation`` dict and converts it to options.

    Shared by :meth:`Schema.create_collection` and
    :meth:`Schema.modify_collection` so both apply exactly the same rules.

    Args:
        validation (dict): A non-empty dict that may contain only the keys
                           ``level`` (a string) and ``schema`` (a dict or a
                           string with a JSON schema specification).

    Returns:
        list: ``(name, value)`` option pairs, in the order ``level`` then
              ``schema``. A dict ``schema`` is serialized with
              ``json.dumps``; a string ``schema`` is passed through as is.

    Raises:
        :class:`mysqlx.ProgrammingError`: If ``validation`` is not a
            non-empty dict, contains an unknown key, or holds a value of
            the wrong type.
    """
    if not isinstance(validation, dict) or not validation:
        raise ProgrammingError("Invalid value for 'validation'")

    valid_options = ("level", "schema")
    for option in validation:
        if option not in valid_options:
            raise ProgrammingError(f"Invalid option in 'validation': {option}")

    options: List[Tuple[str, str]] = []
    if "level" in validation:
        level = validation["level"]
        if not isinstance(level, str):
            raise ProgrammingError("Invalid value for 'level'")
        options.append(("level", level))
    if "schema" in validation:
        schema = validation["schema"]
        if not isinstance(schema, (str, dict)):
            raise ProgrammingError("Invalid value for 'schema'")
        options.append(
            (
                "schema",
                json.dumps(schema) if isinstance(schema, dict) else schema,
            )
        )
    return options


class DatabaseObject:
    """Provides base functionality for database objects.

    Args:
        schema (mysqlx.Schema): The Schema object.
        name (str or bytes): The database object name. A ``bytes`` name is
                             decoded with ``bytes.decode()`` (UTF-8).
    """

    def __init__(self, schema: SchemaType, name: StrOrBytes) -> None:
        self._schema: SchemaType = schema
        self._name: str = name.decode() if isinstance(name, bytes) else name
        self._session: SessionType = self._schema.get_session()
        self._connection: ConnectionType = self._session.get_connection()

    @property
    def session(self) -> SessionType:
        """:class:`mysqlx.Session`: The Session object."""
        return self._session

    @property
    def schema(self) -> SchemaType:
        """:class:`mysqlx.Schema`: The Schema object."""
        return self._schema

    @property
    def name(self) -> str:
        """str: The name of this database object."""
        return self._name

    def get_connection(self) -> ConnectionType:
        """Returns the underlying connection.

        Returns:
            mysqlx.connection.Connection: The connection object.
        """
        return self._connection

    def get_session(self) -> SessionType:
        """Returns the session of this database object.

        Returns:
            mysqlx.Session: The Session object.
        """
        return self._session

    def get_schema(self) -> SchemaType:
        """Returns the Schema object of this database object.

        Returns:
            mysqlx.Schema: The Schema object.
        """
        return self._schema

    def get_name(self) -> str:
        """Returns the name of this database object.

        Returns:
            str: The name of this database object.
        """
        return self._name

    def exists_in_database(self) -> Any:
        """Verifies if this object exists in the database.

        Returns:
            bool: `True` if object exists in database.

        Raises:
            NotImplementedError: This method must be implemented.
        """
        raise NotImplementedError

    @deprecated("8.0.12", "Use 'exists_in_database()' method instead")
    def am_i_real(self) -> Any:
        """Verifies if this object exists in the database.

        Returns:
            bool: `True` if object exists in database.

        Raises:
            NotImplementedError: This method must be implemented.

        .. deprecated:: 8.0.12
           Use ``exists_in_database()`` method instead.
        """
        return self.exists_in_database()

    @deprecated("8.0.12", "Use 'get_name()' method instead")
    def who_am_i(self) -> str:
        """Returns the name of this database object.

        Returns:
            str: The name of this database object.

        .. deprecated:: 8.0.12
           Use ``get_name()`` method instead.
        """
        return self.get_name()


class Schema(DatabaseObject):
    """A client-side representation of a database schema. Provides access to
    the schema contents.

    Args:
        session (mysqlx.Session): Session object.
        name (str): The Schema name.
    """

    def __init__(self, session: SessionType, name: str) -> None:
        # ``DatabaseObject.__init__`` resolves the session through
        # ``self._schema.get_session()``. A Schema is passed as its own
        # schema, so ``_session`` must be set before delegating.
        self._session: SessionType = session
        super().__init__(self, name)

    def exists_in_database(self) -> bool:
        """Verifies if this object exists in the database.

        Returns:
            bool: `True` if object exists in database.
        """
        sql = _COUNT_SCHEMAS_QUERY.format(escape(self._name))
        return self._connection.execute_sql_scalar(sql) == 1

    def get_collections(self) -> List[Collection]:
        """Returns a list of collections for this schema.

        Returns:
            `list`: List of Collection objects.
        """
        rows = self._connection.get_row_result("list_objects", {"schema": self._name})
        rows.fetch_all()
        collections = []
        for row in rows:
            if row["type"] != "COLLECTION":
                continue
            # If using the ``TABLE_NAME`` column raises ValueError, fall back
            # to the ``name`` column. NOTE(review): presumably the column
            # name differs between server versions - TODO confirm.
            try:
                collection = Collection(self, row["TABLE_NAME"])
            except ValueError:
                collection = Collection(self, row["name"])
            collections.append(collection)
        return collections

    def get_collection_as_table(
        self, name: str, check_existence: bool = False
    ) -> Table:
        """Returns a table object for the given collection.

        Args:
            name (str): The name of the collection.
            check_existence (bool): `True` to verify that the object exists
                                    in the database.

        Returns:
            mysqlx.Table: Table object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``check_existence`` is
                `True` and the table does not exist.
        """
        return self.get_table(name, check_existence)

    def get_tables(self) -> List[Table]:
        """Returns a list of tables for this schema.

        Both base tables and views are returned as Table objects.

        Returns:
            `list`: List of Table objects.
        """
        rows = self._connection.get_row_result("list_objects", {"schema": self._name})
        rows.fetch_all()
        tables = []
        object_types = (
            "TABLE",
            "VIEW",
        )
        for row in rows:
            if row["type"] in object_types:
                # Same ``TABLE_NAME`` -> ``name`` fallback as
                # ``get_collections``.
                try:
                    table = Table(self, row["TABLE_NAME"])
                except ValueError:
                    table = Table(self, row["name"])
                tables.append(table)
        return tables

    def get_table(self, name: str, check_existence: bool = False) -> Table:
        """Returns the table of the given name for this schema.

        Args:
            name (str): The name of the table.
            check_existence (bool): `True` to verify that the table exists
                                    in the database.

        Returns:
            mysqlx.Table: Table object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``check_existence`` is
                `True` and the table does not exist.
        """
        table = Table(self, name)
        if check_existence:
            if not table.exists_in_database():
                raise ProgrammingError("Table does not exist")
        return table

    def get_view(self, name: str, check_existence: bool = False) -> View:
        """Returns the view of the given name for this schema.

        Args:
            name (str): The name of the view.
            check_existence (bool): `True` to verify that the view exists
                                    in the database.

        Returns:
            mysqlx.View: View object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``check_existence`` is
                `True` and the view does not exist.
        """
        view = View(self, name)
        if check_existence:
            if not view.exists_in_database():
                raise ProgrammingError("View does not exist")
        return view

    def get_collection(self, name: str, check_existence: bool = False) -> Collection:
        """Returns the collection of the given name for this schema.

        Args:
            name (str): The name of the collection.
            check_existence (bool): `True` to verify that the collection
                                    exists in the database.

        Returns:
            mysqlx.Collection: Collection object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``check_existence`` is
                `True` and the collection does not exist.
        """
        collection = Collection(self, name)
        if check_existence:
            if not collection.exists_in_database():
                raise ProgrammingError("Collection does not exist")
        return collection

    def drop_collection(self, name: str) -> None:
        """Drops a collection.

        Issues ``DROP TABLE IF EXISTS``, so dropping a missing collection
        does not raise.

        Args:
            name (str): The name of the collection to be dropped.
        """
        self._connection.execute_nonquery(
            "sql",
            _DROP_TABLE_QUERY.format(
                quote_identifier(self._name), quote_identifier(name)
            ),
            False,
        )

    def create_collection(
        self,
        name: str,
        reuse_existing: bool = False,
        validation: Optional[Dict[str, Union[str, Dict]]] = None,
        **kwargs: Any,
    ) -> Collection:
        """Creates in the current schema a new collection with the specified
        name and retrieves an object representing the new collection created.

        Args:
            name (str): The name of the collection.
            reuse_existing (bool): `True` to reuse an existing collection.
            validation (Optional[dict]): A dict, containing the keys `level`
                                         with the validation level and
                                         `schema` with a dict or a string
                                         representation of a JSON schema
                                         specification.
            **kwargs: Only the deprecated ``reuse`` keyword is recognized;
                      it overrides ``reuse_existing`` and emits a
                      :class:`DeprecationWarning`.

        Returns:
            mysqlx.Collection: Collection object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``reuse_existing`` is False
                                              and collection exists or the
                                              collection name or validation
                                              is invalid.
            :class:`mysqlx.NotSupportedError`: If schema validation is not
                                               supported by the server.

        .. versionchanged:: 8.0.21
        """
        if not name:
            raise ProgrammingError("Collection name is invalid")

        if "reuse" in kwargs:
            warnings.warn(
                "'reuse' is deprecated since 8.0.21. "
                "Please use 'reuse_existing' instead",
                DeprecationWarning,
            )
            reuse_existing = kwargs["reuse"]

        collection = Collection(self, name)
        fields: Dict[str, Any] = {"schema": self._name, "name": name}

        if validation is not None:
            fields["options"] = ("validation", _parse_validation_options(validation))

        try:
            self._connection.execute_nonquery(
                "mysqlx", "create_collection", True, fields
            )
        except OperationalError as err:
            if err.errno == ER_X_CMD_NUM_ARGUMENTS:
                raise NotSupportedError(
                    "Your MySQL server does not support the requested "
                    "operation. Please update to MySQL 8.0.19 or a later "
                    "version"
                ) from err
            if err.errno == ER_TABLE_EXISTS_ERROR:
                # An existing collection is only an error when the caller
                # did not ask to reuse it.
                if not reuse_existing:
                    raise ProgrammingError(
                        f"Collection '{name}' already exists"
                    ) from err
            else:
                raise ProgrammingError(err.msg, err.errno) from err

        return collection

    def modify_collection(
        self, name: str, validation: Optional[Dict[str, Union[str, Dict]]] = None
    ) -> None:
        """Modifies a collection using a JSON schema validation.

        Args:
            name (str): The name of the collection.
            validation (Optional[dict]): A dict, containing the keys `level`
                                         with the validation level and
                                         `schema` with a dict or a string
                                         representation of a JSON schema
                                         specification. Although optional in
                                         the signature, it is required.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the collection name or
                                              validation is invalid.
            :class:`mysqlx.NotSupportedError`: If schema validation is not
                                               supported by the server.

        .. versionadded:: 8.0.21
        """
        if not name:
            raise ProgrammingError("Collection name is invalid")

        # Unlike ``create_collection``, ``None`` is rejected here.
        options = _parse_validation_options(validation)

        fields = {
            "schema": self._name,
            "name": name,
            "options": ("validation", options),
        }

        try:
            self._connection.execute_nonquery(
                "mysqlx", "modify_collection_options", True, fields
            )
        except OperationalError as err:
            if err.errno == ER_X_INVALID_ADMIN_COMMAND:
                raise NotSupportedError(
                    "Your MySQL server does not support the requested "
                    "operation. Please update to MySQL 8.0.19 or a later "
                    "version"
                ) from err
            raise ProgrammingError(err.msg, err.errno) from err


class Collection(DatabaseObject):
    """Represents a collection of documents on a schema.

    Args:
        schema (mysqlx.Schema): The Schema object.
        name (str): The collection name.
    """

    def exists_in_database(self) -> bool:
        """Verifies if this object exists in the database.

        Returns:
            bool: `True` if object exists in database.
        """
        sql = _COUNT_TABLES_QUERY.format(escape(self._schema.name), escape(self._name))
        return self._connection.execute_sql_scalar(sql) == 1

    def find(self, condition: Optional[str] = None) -> FindStatement:
        """Retrieves documents from a collection.

        Args:
            condition (Optional[str]): The string with the filter expression
                                       of the documents to be retrieved.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        stmt = FindStatement(self, condition)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def add(self, *values: DbDoc) -> AddStatement:
        """Adds a list of documents to a collection.

        Args:
            *values: The document list to be added into the collection.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        return AddStatement(self).add(*values)

    def remove(self, condition: str) -> RemoveStatement:
        """Removes documents based on the ``condition``.

        Args:
            condition (str): The string with the filter expression of
                             the documents to be removed.

        Returns:
            mysqlx.RemoveStatement: RemoveStatement object.

        .. versionchanged:: 8.0.12
           The ``condition`` parameter is now mandatory.
        """
        stmt = RemoveStatement(self, condition)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def modify(self, condition: str) -> ModifyStatement:
        """Modifies documents based on the ``condition``.

        Args:
            condition (str): The string with the filter expression of
                             the documents to be modified.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        .. versionchanged:: 8.0.12
           The ``condition`` parameter is now mandatory.
        """
        stmt = ModifyStatement(self, condition)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def count(self) -> int:
        """Counts the documents in the collection.

        Returns:
            int: The total of documents in the collection.

        Raises:
            :class:`mysqlx.OperationalError`: If the collection does not
                exist in the schema, or on any other server error.
        """
        sql = _COUNT_QUERY.format(
            quote_identifier(self._schema.name), quote_identifier(self._name)
        )
        try:
            res = self._connection.execute_sql_scalar(sql)
        except OperationalError as err:
            # Replace the generic "no such table" error with a message that
            # names the collection; other errors propagate unchanged.
            if err.errno == ER_NO_SUCH_TABLE:
                raise OperationalError(
                    f"Collection '{self._name}' does not exist in schema "
                    f"'{self._schema.name}'"
                ) from err
            raise
        return res

    def create_index(
        self, index_name: str, fields_desc: Dict[str, Any]
    ) -> CreateCollectionIndexStatement:
        """Creates a collection index.

        Args:
            index_name (str): Index name.
            fields_desc (dict): A dictionary containing the fields members that
                                constraints the index to be created. It must
                                have the form as shown in the following::

                                   {"fields": [{"field": member_path,
                                                "type": member_type,
                                                "required": member_required,
                                                "array": array,
                                                "collation": collation,
                                                "options": options,
                                                "srid": srid},
                                                # {... more members,
                                                #      repeated as many times
                                                #      as needed}
                                                ],
                                    "type": type}

        Returns:
            mysqlx.CreateCollectionIndexStatement: The statement object.
        """
        return CreateCollectionIndexStatement(self, index_name, fields_desc)

    def drop_index(self, index_name: str) -> None:
        """Drops a collection index.

        Args:
            index_name (str): Index name.
        """
        self._connection.execute_nonquery(
            "mysqlx",
            "drop_collection_index",
            False,
            {
                "schema": self._schema.name,
                "collection": self._name,
                "name": index_name,
            },
        )

    def replace_one(self, doc_id: str, doc: Union[Dict, DbDoc]) -> "Result":
        """Replaces the Document matching the document ID with a new document
        provided.

        Args:
            doc_id (str): Document ID
            doc (:class:`mysqlx.DbDoc` or `dict`): New Document

        Returns:
            mysqlx.Result: Result object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``doc`` carries an ``_id``
                different from ``doc_id``.
        """
        if "_id" in doc and doc["_id"] != doc_id:
            raise ProgrammingError(
                "Replacement document has an _id that is different than the "
                "matched document"
            )
        return self.modify("_id = :id").set("$", doc).bind("id", doc_id).execute()

    def add_or_replace_one(self, doc_id: str, doc: Union[Dict, DbDoc]) -> "Result":
        """Upserts the Document matching the document ID with a new document
        provided.

        Args:
            doc_id (str): Document ID
            doc (:class:`mysqlx.DbDoc` or dict): New Document

        Returns:
            mysqlx.Result: Result object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``doc`` carries an ``_id``
                different from ``doc_id``.
        """
        if "_id" in doc and doc["_id"] != doc_id:
            raise ProgrammingError(
                "Replacement document has an _id that is different than the "
                "matched document"
            )
        if not isinstance(doc, DbDoc):
            doc = DbDoc(doc)
        # NOTE(review): ``DbDoc.copy(doc_id)`` is assumed to return a copy
        # carrying ``doc_id`` as its _id - verify in dbdoc.
        return self.add(doc.copy(doc_id)).upsert(True).execute()

    def get_one(self, doc_id: str) -> DbDoc:
        """Returns a Document matching the Document ID.

        Args:
            doc_id (str): Document ID

        Returns:
            mysqlx.DbDoc: The Document matching the Document ID.
        """
        result = self.find("_id = :id").bind("id", doc_id).execute()
        doc = result.fetch_one()
        # NOTE(review): presumably consumes the rest of the active result so
        # the connection is ready for the next command - TODO confirm.
        self._connection.fetch_active_result()
        return doc

    def remove_one(self, doc_id: str) -> "Result":
        """Removes a Document matching the Document ID.

        Args:
            doc_id (str): Document ID

        Returns:
            mysqlx.Result: Result object.
        """
        return self.remove("_id = :id").bind("id", doc_id).execute()


class Table(DatabaseObject):
    """Represents a database table on a schema.

    Provides access to the table through standard INSERT/SELECT/UPDATE/DELETE
    statements.

    Args:
        schema (mysqlx.Schema): The Schema object.
        name (str): The table name.
    """

    def exists_in_database(self) -> bool:
        """Verifies if this object exists in the database.

        Returns:
            bool: `True` if object exists in database.
        """
        sql = _COUNT_TABLES_QUERY.format(escape(self._schema.name), escape(self._name))
        return self._connection.execute_sql_scalar(sql) == 1

    def select(self, *fields: str) -> SelectStatement:
        """Creates a new :class:`mysqlx.SelectStatement` object.

        Args:
            *fields: The fields to be retrieved.

        Returns:
            mysqlx.SelectStatement: SelectStatement object
        """
        stmt = SelectStatement(self, *fields)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def insert(self, *fields: Any) -> InsertStatement:
        """Creates a new :class:`mysqlx.InsertStatement` object.

        Args:
            *fields: The fields to be inserted.

        Returns:
            mysqlx.InsertStatement: InsertStatement object
        """
        stmt = InsertStatement(self, *fields)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def update(self) -> UpdateStatement:
        """Creates a new :class:`mysqlx.UpdateStatement` object.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object
        """
        stmt = UpdateStatement(self)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def delete(self) -> DeleteStatement:
        """Creates a new :class:`mysqlx.DeleteStatement` object.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object

        .. versionchanged:: 8.0.12
           The ``condition`` parameter was removed.
        """
        stmt = DeleteStatement(self)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def count(self) -> int:
        """Counts the rows in the table.

        Returns:
            int: The total of rows in the table.

        Raises:
            :class:`mysqlx.OperationalError`: If the table does not exist in
                the schema, or on any other server error.
        """
        sql = _COUNT_QUERY.format(
            quote_identifier(self._schema.name), quote_identifier(self._name)
        )
        try:
            res = self._connection.execute_sql_scalar(sql)
        except OperationalError as err:
            # Replace the generic "no such table" error with a message that
            # names the table; other errors propagate unchanged.
            if err.errno == ER_NO_SUCH_TABLE:
                raise OperationalError(
                    f"Table '{self._name}' does not exist in schema "
                    f"'{self._schema.name}'"
                ) from err
            raise
        return res

    def is_view(self) -> bool:
        """Determine if the underlying object is a view or not.

        Returns:
            bool: `True` if the underlying object is a view.
        """
        sql = _COUNT_VIEWS_QUERY.format(escape(self._schema.name), escape(self._name))
        return self._connection.execute_sql_scalar(sql) == 1


class View(Table):
    """Represents a database view on a schema.

    Inherits the statement factories (select/insert/update/delete) and
    ``count`` from :class:`mysqlx.Table`; existence is checked against
    ``information_schema.views``.

    Args:
        schema (mysqlx.Schema): The Schema object.
        name (str): The view name.
    """

    def exists_in_database(self) -> bool:
        """Verifies if this object exists in the database.

        Returns:
            bool: `True` if object exists in database.
        """
        sql = _COUNT_VIEWS_QUERY.format(escape(self._schema.name), escape(self._name))
        return self._connection.execute_sql_scalar(sql) == 1