# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Functions for checking AST nodes for matches."""
from __future__ import absolute_import

import ast

from sagemaker.cli.compatibility.v2.modifiers import parsing


def matches_any(node, name_to_namespaces_dict):
    """Checks an ``ast.Call`` node against every name/namespaces pair in a mapping.

    Each entry of the mapping is tried in turn with
    ``matches_name_or_namespaces``; the search stops at the first hit.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        name_to_namespaces_dict (dict[str, tuple]): a mapping of function names
            to the namespaces in which each name is expected to live.

    Returns:
        bool: True if at least one entry matches the node, otherwise False
            (including when the mapping is empty).
    """
    for name, namespaces in name_to_namespaces_dict.items():
        if matches_name_or_namespaces(node, name, namespaces):
            return True
    return False


def matches_name_or_namespaces(node, name, namespaces):
    """Checks if an ``ast.Call`` node calls ``name``, either bare or via a namespace.

    A bare call (``name(...)``) matches regardless of ``namespaces``. An
    attribute call (``something.name(...)``) matches only if ``something``
    corresponds to one of the given namespaces.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        name (str): the function name.
        namespaces (tuple): the possible namespaces to match to.

    Returns:
        bool: True if the node matches the name, or the name within any of
            the namespaces; otherwise False.
    """
    if matches_name(node, name):
        return True

    # Namespaces are only inspected once the attribute name itself matches;
    # ``matches_namespace`` relies on ``node.func`` being an ``ast.Attribute``.
    if matches_attr(node, name):
        for namespace in namespaces:
            if matches_namespace(node, namespace):
                return True

    return False


def matches_name(node, name):
    """Checks if an ``ast.Call`` node is a bare call to ``name``.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        name (str): the function name.

    Returns:
        bool: True if ``node.func`` is an ``ast.Name`` node whose ``id``
            equals ``name``.
    """
    func = node.func
    if not isinstance(func, ast.Name):
        return False
    return func.id == name


def matches_attr(node, name):
    """Checks if an ``ast.Call`` node is an attribute call whose attribute is ``name``.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        name (str): the function name.

    Returns:
        bool: True if ``node.func`` is an ``ast.Attribute`` node whose
            ``attr`` equals ``name``.
    """
    func = node.func
    if not isinstance(func, ast.Attribute):
        return False
    return func.attr == name


def matches_namespace(node, namespace):
    """Checks if the object an ``ast.Call`` node is called on matches a namespace.

    The dotted ``namespace`` is compared, right to left, against the chain of
    attribute accesses that precedes the called attribute. If the chain ends
    in a plain name before the namespace is exhausted, only the trailing
    components of the namespace are compared; for example,
    ``tensorflow.TensorFlow()`` matches the namespace ``sagemaker.tensorflow``.

    ``node.func`` is accessed as an ``ast.Attribute`` (``node.func.value``),
    so callers are expected to have checked that first.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        namespace (str): the namespace, with components separated by ".".

    Returns:
        bool: True if the node's namespace matches the given namespace.
    """
    # Walk the namespace from its last component to its first.
    remaining = iter(reversed(namespace.split(".")))
    expected = next(remaining)
    current = node.func.value

    for upcoming in remaining:
        # Stop descending once the chain is no longer an attribute access;
        # the final check below decides the outcome.
        if not isinstance(current, ast.Attribute):
            break
        if current.attr != expected:
            return False
        expected, current = upcoming, current.value

    return isinstance(current, ast.Name) and current.id == expected


def has_arg(node, arg):
    """Checks if the call has the given argument.

    The argument is looked up with ``parsing.arg_value``; a ``KeyError`` from
    that lookup, or a value of ``None``, counts as the argument being absent.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        arg (str): the name of the argument.

    Returns:
        bool: True if the node has the given argument with a non-``None`` value.
    """
    try:
        value = parsing.arg_value(node, arg)
    except KeyError:
        return False
    return value is not None