# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Provides utilities for SageMaker Pipeline CLI."""
from __future__ import absolute_import

import ast


def get_pipeline_driver(module_name, passed_args=None):
    """Import a pipeline module and return the result of its ``get_pipeline()``.

    The target module is expected to define a module-level ``get_pipeline``
    callable; it is invoked with keyword arguments parsed from ``passed_args``.

    Args:
        module_name: Dotted import path of the module that defines the pipeline.
        passed_args: Optional string holding a Python literal (parsed with
            ``convert_struct``) whose entries are forwarded to
            ``get_pipeline`` as keyword arguments. ``None`` or an empty string
            means no arguments are forwarded.

    Returns:
        Whatever the module's ``get_pipeline`` returns -- presumably the
        SageMaker Workflow pipeline object.
    """
    # A non-empty fromlist makes __import__ return the leaf module of a dotted
    # path (e.g. "a.b.c" -> module c) rather than the top-level package.
    pipeline_module = __import__(module_name, fromlist=["get_pipeline"])
    pipeline_kwargs = convert_struct(passed_args)
    build_pipeline = pipeline_module.get_pipeline
    return build_pipeline(**pipeline_kwargs)


def convert_struct(str_struct=None):
    """Parse a string containing a Python literal into the matching object.

    Parsing uses ``ast.literal_eval``, so only literal syntax (dicts, lists,
    tuples, strings, numbers, booleans, ``None``, ...) is accepted. Arbitrary
    expressions are rejected rather than executed.

    Args:
        str_struct: String holding a Python literal, e.g. ``"{'region': 'us-east-1'}"``.
            ``None`` or an empty string is treated as "nothing supplied".

    Returns:
        The parsed object, or an empty dict when ``str_struct`` is falsy.
    """
    if not str_struct:
        return {}
    return ast.literal_eval(str_struct)

# def get_pipeline_custom_tags(module_name, args, tags):
#     """Gets the custom tags for pipeline

#     Returns:
#         Custom tags to be added to the pipeline
#     """
#     try:
#         _imports = __import__(module_name, fromlist=["get_pipeline_custom_tags"])
#         kwargs = convert_struct(args)
#         return _imports.get_pipeline_custom_tags(tags, kwargs['region'], kwargs['sagemaker_project_arn'])
#     except Exception as e:
#         print(f"Error getting project tags: {e}")
#     return tags