# * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# * SPDX-License-Identifier: MIT-0
# *
# * Permission is hereby granted, free of charge, to any person obtaining a copy of this
# * software and associated documentation files (the "Software"), to deal in the Software
# * without restriction, including without limitation the rights to use, copy, modify,
# * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
# * permit persons to whom the Software is furnished to do so.
# *
# * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# Template format version. Quoted so that generic YAML parsers keep it a string
# instead of coercing the ISO-looking value to a date.
AWSTemplateFormatVersion: '2010-09-09'
Description: SageMaker 1P container Blacklist
Parameters:
  # Comma-separated image names. The processing Lambda compares each entry with
  # the last "/"-separated segment of AppSpecification.ImageUri.
  # CommaDelimitedList is the documented CloudFormation type for a list of strings.
  ProcessingContainers:
    Type: CommaDelimitedList
    Description: List of blacklist containers for processing
    Default: ""
  # Comma-separated image names. The training Lambda compares each entry with
  # the last "/"-separated segment of AlgorithmSpecification.TrainingImage.
  TrainingContainers:
    Type: CommaDelimitedList
    Description: List of blacklist containers for training
    Default: ""

Resources:
  ### IAM Section
  # Managed policy referenced by LambdaRole: CloudWatch Logs output plus the two
  # SageMaker stop calls.
  LambdaPolicy:
    Type: 'AWS::IAM::ManagedPolicy'
    Properties:
      PolicyDocument:
        # Quoted so the policy language version stays a string, not a YAML date.
        Version: '2012-10-17'
        Statement:
          # Create log groups in this account and region.
          - Effect: Allow
            Action:
              - 'logs:CreateLogGroup'
            Resource:
              - !Sub 'arn:aws:logs:${AWS::Region}:${AWS::AccountId}:*'
          # Write log streams and events to log groups under /aws/lambda/.
          - Effect: Allow
            Action:
              - 'logs:CreateLogStream'
              - 'logs:PutLogEvents'
            Resource:
              - !Sub 'arn:aws:logs:${AWS::Region}:${AWS::AccountId}:log-group:/aws/lambda/*'
          # Stop processing and training jobs in this account and region.
          - Effect: Allow
            Action:
              - 'sagemaker:StopTrainingJob'
              - 'sagemaker:StopProcessingJob'
            Resource:
              - !Sub 'arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:processing-job/*'
              - !Sub 'arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:training-job/*'

  # Execution role used by both stop-job Lambda functions. Only the Lambda
  # service may assume it; its permissions come from LambdaPolicy.
  LambdaRole:
    Type: 'AWS::IAM::Role'
    Properties:
      AssumeRolePolicyDocument:
        # Quoted so the policy language version stays a string, not a YAML date.
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Action: 'sts:AssumeRole'
            Principal:
              Service:
                - lambda.amazonaws.com
      Path: /
      ManagedPolicyArns:
        - !Ref LambdaPolicy

  ### Lambda Section
  # Target of EventBridgeRuleStopProcessing. Stops an in-progress Processing job
  # whose image name appears in the ProcessingContainers parameter.
  LambdaStopProcessingJob:
    Type: 'AWS::Lambda::Function'
    Properties:
      Description: Stop Processing Job
      Handler: index.lambda_handler
      # python3.8 is a deprecated Lambda runtime; use a currently supported one.
      Runtime: python3.12
      MemorySize: 128
      Timeout: 900
      Role: !GetAtt LambdaRole.Arn
      Environment:
        Variables:
          # Comma-joined blacklist; the handler splits it again.
          containers: !Join [",", !Ref ProcessingContainers]
      Code:
        ZipFile: |
          # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
          # SPDX-License-Identifier: MIT-0
          import boto3
          import json
          import os

          sagemaker_client = boto3.client("sagemaker")

          def lambda_handler(event, context):
              """Stop the Processing job in `event` if its image is blacklisted.

              Returns a dict with statusCode 200 and a JSON-encoded message.
              """
              # Drop blank entries: an empty parameter arrives as "" and
              # "".split(",") is [""], which would blacklist an empty name.
              blacklist_containers = [
                  name.strip()
                  for name in os.environ.get("containers", "").split(",")
                  if name.strip()
              ]

              print("Blacklisted containers: {}".format(blacklist_containers))

              message = ""
              detail = event.get("detail") or {}
              image_uri = (detail.get("AppSpecification") or {}).get("ImageUri")

              if image_uri:
                  # Compare on the last path segment of the image URI.
                  container = image_uri.split("/")[-1]

                  if container in blacklist_containers:
                      if detail.get("ProcessingJobStatus") == "InProgress":
                          response = sagemaker_client.stop_processing_job(
                              ProcessingJobName=detail["ProcessingJobName"]
                          )

                          print(response)
                          message = "Job with image {} stopped".format(container)
                  else:
                      print("Image {} allowed".format(container))
                      message = "Job not stopped"

              return {
                  'statusCode': 200,
                  'body': json.dumps(message)
              }

  # Target of EventBridgeRuleStopTraining. Stops a Training job that is still
  # starting when its image name appears in the TrainingContainers parameter.
  LambdaStopTrainingJob:
    Type: 'AWS::Lambda::Function'
    Properties:
      Description: Stop Training Job
      Handler: index.lambda_handler
      # python3.8 is a deprecated Lambda runtime; use a currently supported one.
      Runtime: python3.12
      MemorySize: 128
      Timeout: 900
      Role: !GetAtt LambdaRole.Arn
      Environment:
        Variables:
          # Comma-joined blacklist; the handler splits it again.
          containers: !Join [",", !Ref TrainingContainers]
      Code:
        ZipFile: |
          # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
          # SPDX-License-Identifier: MIT-0
          import boto3
          import json
          import os

          sagemaker_client = boto3.client("sagemaker")

          def lambda_handler(event, context):
              """Stop the Training job in `event` if its image is blacklisted.

              Returns a dict with statusCode 200 and a JSON-encoded message.
              """
              # Drop blank entries: an empty parameter arrives as "" and
              # "".split(",") is [""], which would blacklist an empty name.
              blacklist_containers = [
                  name.strip()
                  for name in os.environ.get("containers", "").split(",")
                  if name.strip()
              ]

              print("Blacklisted containers: {}".format(blacklist_containers))

              message = ""
              detail = event.get("detail") or {}
              image_uri = (detail.get("AlgorithmSpecification") or {}).get("TrainingImage")

              if image_uri:
                  # Compare on the last path segment of the image URI.
                  container = image_uri.split("/")[-1]

                  if container in blacklist_containers:
                      # Only act while the job is InProgress with secondary status Starting.
                      if (detail.get("TrainingJobStatus") == "InProgress"
                              and detail.get("SecondaryStatus") == "Starting"):
                          response = sagemaker_client.stop_training_job(
                              TrainingJobName=detail["TrainingJobName"]
                          )

                          print(response)
                          message = "Job with image {} stopped".format(container)
                  else:
                      print("Image {} allowed".format(container))
                      message = "Job not stopped"

              return {
                  'statusCode': 200,
                  'body': json.dumps(message)
              }

  ### EventBridge Section
  # Sends SageMaker Processing job state-change events with status InProgress
  # to the LambdaStopProcessingJob function.
  EventBridgeRuleStopProcessing:
    Type: 'AWS::Events::Rule'
    Properties:
      Name: ProcessingContainerRule
      Description: Stop SageMaker Processing Job
      State: ENABLED
      EventPattern:
        source: [aws.sagemaker]
        detail-type: ['SageMaker Processing Job State Change']
        detail:
          ProcessingJobStatus: [InProgress]
      Targets:
        - Id: Target0
          Arn:
            Fn::GetAtt: [LambdaStopProcessingJob, Arn]

  # Sends SageMaker Training job state-change events with status InProgress and
  # secondary status Starting to the LambdaStopTrainingJob function.
  EventBridgeRuleStopTraining:
    Type: 'AWS::Events::Rule'
    Properties:
      Name: TrainingContainerRule
      Description: Stop SageMaker Training Job
      EventBusName: default
      State: ENABLED
      EventPattern:
        source: [aws.sagemaker]
        detail-type: ['SageMaker Training Job State Change']
        detail:
          TrainingJobStatus: [InProgress]
          SecondaryStatus: [Starting]
      Targets:
        - Id: Target1
          Arn:
            Fn::GetAtt: [LambdaStopTrainingJob, Arn]

  # Lets EventBridge invoke LambdaStopProcessingJob, limited to events coming
  # from EventBridgeRuleStopProcessing.
  LambdaPermissionStopProcessingJob:
    Type: 'AWS::Lambda::Permission'
    Properties:
      FunctionName: !Ref LambdaStopProcessingJob
      Action: 'lambda:InvokeFunction'
      Principal: events.amazonaws.com
      SourceArn:
        Fn::GetAtt: [EventBridgeRuleStopProcessing, Arn]

  # Lets EventBridge invoke LambdaStopTrainingJob, limited to events coming
  # from EventBridgeRuleStopTraining.
  LambdaPermissionStopTrainingJob:
    Type: 'AWS::Lambda::Permission'
    Properties:
      FunctionName: !Ref LambdaStopTrainingJob
      Action: 'lambda:InvokeFunction'
      Principal: events.amazonaws.com
      SourceArn:
        Fn::GetAtt: [EventBridgeRuleStopTraining, Arn]