# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

AWSTemplateFormatVersion: 2010-09-09
Description: |
  Create a SageMaker Studio domain, applications and a default user profile

Metadata:
  # Controls how the template parameters are grouped, ordered and labelled in
  # the CloudFormation console. Every parameter that has a label below is also
  # listed in a group, so none of them is shown ungrouped.
  AWS::CloudFormation::Interface:
    ParameterGroups:
      - Label:
          default: Data Science Environment
        Parameters:
          - EnvName
          - EnvType
      - Label:
          default: Amazon SageMaker Studio
        Parameters:
          - DomainName
          - UserProfileName
          - AuthMode
          - StartKernelGatewayApps
      - Label:
          default: Network and Storage Configuration
        Parameters:
          - VPCId
          - SageMakerStudioSubnetIds
          - SageMakerSecurityGroupIds
          - NetworkAccessType
          - SageMakerStudioStorageKMSKeyId
      - Label:
          default: Permission
        Parameters:
          - SageMakerExecutionRoleArn
          - SetupLambdaExecutionRoleArn

    ParameterLabels:
      EnvName:
        default: Environment name
      EnvType:
        default: Environment stage name (dev, test, prod)
      DomainName:
        default: Domain name
      UserProfileName:
        default: User profile name
      StartKernelGatewayApps:
        default: Select YES to start the Data Science and Data Wrangler kernels on KernelGateway
      AuthMode:
        default: Authentication mode
      VPCId:
        default: VPC
      SageMakerStudioSubnetIds:
        default: Subnet(s)
      SageMakerSecurityGroupIds:
        default: Security group(s)
      SageMakerStudioStorageKMSKeyId:
        default: Storage encryption key
      NetworkAccessType:
        default: Network access for SageMaker Studio
      SageMakerExecutionRoleArn:
        default: Execution role
      SetupLambdaExecutionRoleArn:
        default: Execution role for setup Lambda function

Outputs:
  # Publishes the Studio domain id under a fixed, stack-independent export name.
  SageMakerDomainId:
    Description: SageMaker Domain Id
    Export:
      Name: ds-sagemaker-domain-id
    Value:
      Fn::GetAtt:
        - SageMakerStudioDomain
        - DomainId

Mappings:
  # Per-region ARNs of the SageMaker images used by the KernelGateway apps in
  # this template. For every region:
  #   datascience  -> .../image/datascience-1.0
  #   datawrangler -> .../image/sagemaker-data-wrangler-1.0
  RegionMap:
    us-east-1:
      datascience: "arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:us-east-1:663277389841:image/sagemaker-data-wrangler-1.0"
    us-east-2:
      datascience: "arn:aws:sagemaker:us-east-2:429704687514:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:us-east-2:415577184552:image/sagemaker-data-wrangler-1.0"
    us-west-1:
      datascience: "arn:aws:sagemaker:us-west-1:742091327244:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:us-west-1:926135532090:image/sagemaker-data-wrangler-1.0"
    us-west-2:
      datascience: "arn:aws:sagemaker:us-west-2:236514542706:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:us-west-2:174368400705:image/sagemaker-data-wrangler-1.0"
    af-south-1:
      datascience: "arn:aws:sagemaker:af-south-1:559312083959:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:af-south-1:143210264188:image/sagemaker-data-wrangler-1.0"
    ap-east-1:
      datascience: "arn:aws:sagemaker:ap-east-1:493642496378:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:ap-east-1:707077482487:image/sagemaker-data-wrangler-1.0"
    ap-south-1:
      datascience: "arn:aws:sagemaker:ap-south-1:394103062818:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:ap-south-1:089933028263:image/sagemaker-data-wrangler-1.0"
    ap-northeast-2:
      datascience: "arn:aws:sagemaker:ap-northeast-2:806072073708:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:ap-northeast-2:131546521161:image/sagemaker-data-wrangler-1.0"
    ap-southeast-1:
      datascience: "arn:aws:sagemaker:ap-southeast-1:492261229750:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:ap-southeast-1:119527597002:image/sagemaker-data-wrangler-1.0"
    ap-southeast-2:
      datascience: "arn:aws:sagemaker:ap-southeast-2:452832661640:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:ap-southeast-2:422173101802:image/sagemaker-data-wrangler-1.0"
    ap-northeast-1:
      datascience: "arn:aws:sagemaker:ap-northeast-1:102112518831:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:ap-northeast-1:649008135260:image/sagemaker-data-wrangler-1.0"
    ca-central-1:
      datascience: "arn:aws:sagemaker:ca-central-1:310906938811:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:ca-central-1:557239378090:image/sagemaker-data-wrangler-1.0"
    eu-central-1:
      datascience: "arn:aws:sagemaker:eu-central-1:936697816551:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:eu-central-1:024640144536:image/sagemaker-data-wrangler-1.0"
    eu-west-1:
      datascience: "arn:aws:sagemaker:eu-west-1:470317259841:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:eu-west-1:245179582081:image/sagemaker-data-wrangler-1.0"
    eu-west-2:
      datascience: "arn:aws:sagemaker:eu-west-2:712779665605:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:eu-west-2:894491911112:image/sagemaker-data-wrangler-1.0"
    eu-west-3:
      datascience: "arn:aws:sagemaker:eu-west-3:615547856133:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:eu-west-3:807237891255:image/sagemaker-data-wrangler-1.0"
    eu-north-1:
      datascience: "arn:aws:sagemaker:eu-north-1:243637512696:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:eu-north-1:054986407534:image/sagemaker-data-wrangler-1.0"
    eu-south-1:
      # These two ARNs were previously stored under each other's key, so the
      # apps in eu-south-1 were started with the wrong image.
      datascience: "arn:aws:sagemaker:eu-south-1:592751261982:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:eu-south-1:488287956546:image/sagemaker-data-wrangler-1.0"
    sa-east-1:
      datascience: "arn:aws:sagemaker:sa-east-1:782484402741:image/datascience-1.0"
      datawrangler: "arn:aws:sagemaker:sa-east-1:424196993095:image/sagemaker-data-wrangler-1.0"

Parameters:
  # --- Environment naming: both values are used in generated resource names and tags ---
  EnvName:
    Type: String
    Description: Please specify your data science environment name.  Used as a suffix for environment resource names.
    AllowedPattern: '[a-z0-9\-]*'

  EnvType:
    Type: String
    Description: System Environment (e.g. dev, test, prod). Used as a suffix for environment resource names.
    Default: dev

  # --- SageMaker Studio: an empty name triggers the matching Generate*NameCondition ---
  DomainName:
    Type: String
    Description: SageMaker Studio domain name. Leave empty to auto generate.
    Default: ""

  UserProfileName:
    Type: String
    Description: The user profile name for the SageMaker Studio. Leave empty to auto generate.
    Default: ""

  # The values stay quoted so that YAML does not read them as booleans.
  StartKernelGatewayApps:
    Type: String
    Description: Start the KernelGateway Apps (Data Science and Data Wrangler)
    AllowedValues: ['YES', 'NO']
    Default: "NO"

  # --- Network and storage ---
  VPCId:
    Type: AWS::EC2::VPC::Id
    Description: Choose a VPC for SageMaker Studio and SageMaker workloads

  SageMakerStudioSubnetIds:
    Type: List<AWS::EC2::Subnet::Id>
    Description: Choose subnets

  SageMakerSecurityGroupIds:
    Type: List<AWS::EC2::SecurityGroup::Id>
    Description: Choose security groups for SageMaker Studio and SageMaker workloads

  # An empty value leaves KmsKeyId unset on the domain (see SageMakerEFSKMSKeyCondition).
  SageMakerStudioStorageKMSKeyId:
    Type: String
    Description: SageMaker uses an AWS managed CMK to encrypt your EFS and EBS file systems by default. To use a customer managed CMK, enter its key Id.
    Default: ""

  NetworkAccessType:
    Type: String
    Description: Choose how SageMaker Studio accesses resources over the Network
    AllowedValues: [PublicInternetOnly, VpcOnly]
    Default: VpcOnly

  AuthMode:
    Type: String
    Description: The mode of authentication that members use to access the domain
    AllowedValues: [IAM, SSO]
    Default: IAM

  # --- Permissions ---
  SageMakerExecutionRoleArn:
    Type: String
    Description: The ARN of the SageMaker execution role

  SetupLambdaExecutionRoleArn:
    Type: String
    Description: The ARN of the execution role for the Lambda function for SageMaker Studio setup


Conditions:
  # True when the name parameter was left empty, so a name is generated instead.
  GenerateDomainNameCondition: !Equals
    - !Ref DomainName
    - ''
  GenerateProfileNameCondition: !Equals
    - !Ref UserProfileName
    - ''
  # True when a customer managed KMS key id was supplied.
  SageMakerEFSKMSKeyCondition: !Not
    - !Equals
      - !Ref SageMakerStudioStorageKMSKeyId
      - ''
  # Exactly one of the following two holds, because the parameter only allows 'YES' or 'NO'.
  KernelGatewayAppsCondition: !Equals
    - !Ref StartKernelGatewayApps
    - 'YES'
  NoKernelGatewayAppsCondition: !Equals
    - !Ref StartKernelGatewayApps
    - 'NO'

Resources:
  # SageMaker Studio domain in the chosen VPC and subnets. The name is generated
  # from the environment values when DomainName is empty, and KmsKeyId is only
  # set when a key id was supplied.
  SageMakerStudioDomain:
    Type: AWS::SageMaker::Domain
    Properties:
      DomainName: !If
        - GenerateDomainNameCondition
        - !Sub "${EnvName}-${EnvType}-${AWS::Region}-sagemaker-domain"
        - !Ref DomainName
      AuthMode: !Ref AuthMode
      AppNetworkAccessType: !Ref NetworkAccessType
      VpcId: !Ref VPCId
      SubnetIds: !Ref SageMakerStudioSubnetIds
      KmsKeyId: !If
        - SageMakerEFSKMSKeyCondition
        - !Ref SageMakerStudioStorageKMSKeyId
        - !Ref 'AWS::NoValue'
      DefaultUserSettings:
        ExecutionRole: !Ref SageMakerExecutionRoleArn
        SecurityGroups: !Ref SageMakerSecurityGroupIds
      Tags:
        - Key: EnvironmentName
          Value: !Ref EnvName
        - Key: EnvironmentType
          Value: !Ref EnvType

  # Default user profile in the Studio domain. The name is generated from the
  # environment values when UserProfileName is empty.
  SageMakerUserProfile:
    Type: AWS::SageMaker::UserProfile
    Properties:
      DomainId:
        Fn::GetAtt:
          - SageMakerStudioDomain
          - DomainId
      UserProfileName: !If
        - GenerateProfileNameCondition
        - !Sub "${EnvName}-${EnvType}-${AWS::Region}-user-profile"
        - !Ref UserProfileName
      UserSettings:
        SecurityGroups: !Ref SageMakerSecurityGroupIds
        ExecutionRole: !Ref SageMakerExecutionRoleArn
      Tags:
        - Key: EnvironmentName
          Value: !Ref EnvName
        - Key: EnvironmentType
          Value: !Ref EnvType

  # JupyterServer app named "default" for the user profile. The profile name is
  # recomputed with the same expression SageMakerUserProfile uses, so the
  # explicit DependsOn provides the ordering.
  JupyterApp:
    Type: AWS::SageMaker::App
    DependsOn:
      - SageMakerUserProfile
    Properties:
      AppType: JupyterServer
      AppName: default
      DomainId:
        Fn::GetAtt:
          - SageMakerStudioDomain
          - DomainId
      UserProfileName: !If
        - GenerateProfileNameCondition
        - !Sub "${EnvName}-${EnvType}-${AWS::Region}-user-profile"
        - !Ref UserProfileName
      Tags:
        - Key: EnvironmentName
          Value: !Ref EnvName
        - Key: EnvironmentType
          Value: !Ref EnvType
  
  # KernelGateway app on ml.t3.medium using the region's "datascience" image
  # from RegionMap. Created only when StartKernelGatewayApps is 'YES'.
  DataScienceApp:
    Type: AWS::SageMaker::App
    Condition: KernelGatewayAppsCondition
    DependsOn:
      - SageMakerUserProfile
    Properties:
      AppType: KernelGateway
      AppName: sm-mlops-datascience-ml-t3-medium
      DomainId:
        Fn::GetAtt:
          - SageMakerStudioDomain
          - DomainId
      UserProfileName: !If
        - GenerateProfileNameCondition
        - !Sub "${EnvName}-${EnvType}-${AWS::Region}-user-profile"
        - !Ref UserProfileName
      ResourceSpec:
        InstanceType: ml.t3.medium
        SageMakerImageArn: !FindInMap [RegionMap, !Ref 'AWS::Region', datascience]
      Tags:
        - Key: EnvironmentName
          Value: !Ref EnvName
        - Key: EnvironmentType
          Value: !Ref EnvType

  # KernelGateway app on ml.m5.4xlarge using the region's "datawrangler" image
  # from RegionMap. Created only when StartKernelGatewayApps is 'YES'.
  DataWranglerApp:
    Type: AWS::SageMaker::App
    Condition: KernelGatewayAppsCondition
    DependsOn:
      - SageMakerUserProfile
    Properties:
      AppType: KernelGateway
      AppName: sm-mlops-datawrangler-ml-m5-4xlarge
      DomainId:
        Fn::GetAtt:
          - SageMakerStudioDomain
          - DomainId
      UserProfileName: !If
        - GenerateProfileNameCondition
        - !Sub "${EnvName}-${EnvType}-${AWS::Region}-user-profile"
        - !Ref UserProfileName
      ResourceSpec:
        InstanceType: ml.m5.4xlarge
        SageMakerImageArn: !FindInMap [RegionMap, !Ref 'AWS::Region', datawrangler]
      Tags:
        - Key: EnvironmentName
          Value: !Ref EnvName
        - Key: EnvironmentType
          Value: !Ref EnvType

  # Custom resource backed by EnableSageMakerProjectsLambda. It passes the
  # SageMaker execution role ARN to the function as the ExecutionRole property.
  EnableSageMakerProjects:
    Type: Custom::ResourceForEnablingSageMakerProjects
    Properties:
      ExecutionRole: !Ref SageMakerExecutionRoleArn
      ServiceToken:
        Fn::GetAtt:
          - EnableSageMakerProjectsLambda
          - Arn

  # Clean-up hook used when StartKernelGatewayApps is 'NO'. It passes the domain
  # id to DeleteKernelGatewayAppsLambda, which deletes the domain's
  # KernelGateway apps when this resource receives a Delete request. Because it
  # depends on the user profile, CloudFormation removes it before the profile.
  DeleteKernelGatewayAppsNoKGCondition:
    Type: Custom::DeleteKernelGatewayApps
    Condition: NoKernelGatewayAppsCondition
    DependsOn:
      - SageMakerUserProfile
    Properties:
      DomainId:
        Fn::GetAtt:
          - SageMakerStudioDomain
          - DomainId
      ServiceToken:
        Fn::GetAtt:
          - DeleteKernelGatewayAppsLambda
          - Arn

  # Clean-up hook used when StartKernelGatewayApps is 'YES'. Same function and
  # properties as the 'NO' variant, but it depends on the two KernelGateway app
  # resources, so CloudFormation removes it before those apps.
  DeleteKernelGatewayAppsKGCondition:
    Type: Custom::DeleteKernelGatewayApps
    Condition: KernelGatewayAppsCondition
    DependsOn: [DataScienceApp, DataWranglerApp]
    Properties:
      DomainId:
        Fn::GetAtt:
          - SageMakerStudioDomain
          - DomainId
      ServiceToken:
        Fn::GetAtt:
          - DeleteKernelGatewayAppsLambda
          - Arn

  # Lambda behind the Custom::DeleteKernelGatewayApps resources.
  #   Create/Update: records the domain id as the physical resource id.
  #   Delete:        deletes all KernelGateway apps of that domain and waits
  #                  until none is left.
  # The wait is bounded by the remaining Lambda execution time: if apps are
  # still present shortly before the timeout, the handler reports FAILED to
  # CloudFormation instead of timing out without sending any response.
  # Inline code (ZipFile) is limited to 4096 characters, so keep it compact.
  DeleteKernelGatewayAppsLambda:
    Type: AWS::Lambda::Function
    Properties:
      Code:
        ZipFile: |
          # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
          # SPDX-License-Identifier: MIT-0

          import time
          import boto3
          import logging
          import cfnresponse

          sm_client = boto3.client('sagemaker')
          logger = logging.getLogger(__name__)
          # Without an explicit level the INFO progress messages below are dropped
          logger.setLevel(logging.INFO)

          POLL_INTERVAL_SEC = 5
          # Stop waiting this long before the Lambda timeout so FAILED can still be reported
          SAFETY_MARGIN_MS = 15000

          def list_kernel_gateway_apps(domain_id):
              """Return all KernelGateway apps of the domain that are not yet 'Deleted'."""
              apps = []
              for p in sm_client.get_paginator('list_apps').paginate(DomainIdEquals=domain_id):
                  apps += [a for a in p['Apps'] if a['AppType'] == 'KernelGateway' and a['Status'] != 'Deleted']
              return apps

          def delete_apps(domain_id, context=None):
              """Delete every KernelGateway app of the domain and wait until all are gone.

              Returns silently when the domain cannot be described. Raises TimeoutError
              when a Lambda context is given and apps remain close to the function timeout.
              """
              logger.info(f'Start deleting apps for domain id: {domain_id}')

              try:
                  sm_client.describe_domain(DomainId=domain_id)
              except Exception:
                  # Best effort: a missing or invalid domain means there is nothing to clean up
                  logger.info(f'Cannot retrieve {domain_id}', exc_info=True)
                  return

              for a in list_kernel_gateway_apps(domain_id):
                  # Apps already in 'Deleting' only need to be waited for
                  if a['Status'] != 'Deleting':
                      sm_client.delete_app(DomainId=a['DomainId'], UserProfileName=a['UserProfileName'], AppType=a['AppType'], AppName=a['AppName'])

              while True:
                  apps = len(list_kernel_gateway_apps(domain_id))
                  logger.info(f'Number of active KernelGateway apps: {str(apps)}')
                  if not apps:
                      break
                  if context is not None and context.get_remaining_time_in_millis() < SAFETY_MARGIN_MS:
                      raise TimeoutError(f'{apps} KernelGateway app(s) of {domain_id} not deleted before the Lambda timeout')
                  time.sleep(POLL_INTERVAL_SEC)

              logger.info(f'KernelGateway apps for {domain_id} deleted')

          def lambda_handler(event, context):
              """CloudFormation custom resource handler; the physical id is the domain id."""
              response_data = {}
              physicalResourceId = event.get('PhysicalResourceId')

              try:
                  if event['RequestType'] in ['Create', 'Update']:
                      physicalResourceId = event.get('ResourceProperties')['DomainId']
                  elif event['RequestType'] == 'Delete':
                      delete_apps(physicalResourceId, context)

                  cfnresponse.send(event, context, cfnresponse.SUCCESS, response_data, physicalResourceId=physicalResourceId)

              except Exception as exception:
                  logger.exception(exception)
                  cfnresponse.send(event, context, cfnresponse.FAILED, response_data, physicalResourceId=physicalResourceId, reason=str(exception))

      Description: Delete KernelGateway apps to clean up
      Handler: index.lambda_handler
      MemorySize: 128
      Role: !Ref SetupLambdaExecutionRoleArn
      # python3.8 is a deprecated Lambda runtime; the inline code needs nothing version specific
      Runtime: python3.12
      # Maximum Lambda timeout, because app deletion can take several minutes
      Timeout: 900
      Tags:
        - Key: EnvironmentName
          Value: !Ref EnvName
        - Key: EnvironmentType
          Value: !Ref EnvType

  # Lambda behind the EnableSageMakerProjects custom resource. On Create it
  # enables the SageMaker Service Catalog portfolio for the account and
  # associates the given execution role with the accepted "Amazon SageMaker"
  # portfolio. Update and Delete requests are acknowledged without any action.
  # NOTE(review): a changed ExecutionRole on Update is therefore not
  # associated - confirm whether that is intended.
  EnableSageMakerProjectsLambda:
    Type: AWS::Lambda::Function
    DependsOn: SageMakerStudioDomain
    Properties:
      Code:
        ZipFile: |
          # Function: EnableSagemakerProjects
          # Purpose:  Enables Sagemaker Projects
          import boto3
          import cfnresponse

          client = boto3.client('sagemaker')
          sc_client = boto3.client('servicecatalog')

          def lambda_handler(event, context):
              """Enable SageMaker Projects on Create; other request types are no-ops."""
              # Reuse the id CloudFormation already knows so an Update is not seen as a
              # replacement. On Create it is None and cfnresponse picks its default id.
              physical_id = event.get('PhysicalResourceId')
              try:
                  if event.get('RequestType') == 'Create':
                      enable_sm_projects(event['ResourceProperties']['ExecutionRole'])
                  cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId=physical_id)
              except Exception as exception:
                  print(exception)
                  cfnresponse.send(event, context, cfnresponse.FAILED, {}, physicalResourceId=physical_id, reason=str(exception))

          def enable_sm_projects(studio_role_arn):
              """Accept the SageMaker portfolio share and associate the Studio role with it.

              Raises RuntimeError when no accepted portfolio from provider
              'Amazon SageMaker' exists.
              """
              # enable Project on account level (accepts portfolio share)
              client.enable_sagemaker_servicecatalog_portfolio()

              # find the SageMaker portfolio; paginate so it is not missed beyond the first page
              portfolio_id = ''
              for page in sc_client.get_paginator('list_accepted_portfolio_shares').paginate():
                  for portfolio in page['PortfolioDetails']:
                      # ProviderName is read with .get so a portfolio without it is skipped
                      if portfolio.get('ProviderName') == 'Amazon SageMaker':
                          portfolio_id = portfolio['Id']

              if not portfolio_id:
                  raise RuntimeError('No accepted Service Catalog portfolio from provider "Amazon SageMaker" found')

              # associate studio role with portfolio
              sc_client.associate_principal_with_portfolio(
                  PortfolioId=portfolio_id,
                  PrincipalARN=studio_role_arn,
                  PrincipalType='IAM'
              )
      Description: Enable Sagemaker Projects
      Handler: index.lambda_handler
      MemorySize: 128
      Role: !Ref SetupLambdaExecutionRoleArn
      # python3.8 is a deprecated Lambda runtime; the inline code needs nothing version specific
      Runtime: python3.12
      Timeout: 60
      Tags:
        - Key: EnvironmentName
          Value: !Ref EnvName
        - Key: EnvironmentType
          Value: !Ref EnvType

  # SSM parameter that stores the Studio domain id under an
  # environment-specific name (<EnvName>-<EnvType>-sagemaker-domain-id).
  SageMakerDomainIdSSM:
    Type: AWS::SSM::Parameter
    Properties:
      Type: String
      Name: !Sub '${EnvName}-${EnvType}-sagemaker-domain-id'
      Description: !Sub "SageMaker Studio domain id for ${SageMakerStudioDomain.DomainArn}"
      Value:
        Fn::GetAtt:
          - SageMakerStudioDomain
          - DomainId