# The format version is a string literal; quoting stops generic YAML loaders
# from typing it as a date.
AWSTemplateFormatVersion: '2010-09-09'
Description: Creates tags for tgw attachments and propagations in tgw route tables

# Inputs forwarded to the custom resource that tags and configures the
# transit gateway.
Parameters:
  TgwId:
    Type: String
    Description: ID of the transit gateway whose attachments and route tables are processed
  DmzVpcId:
    Type: String
    Description: ID of the DMZ VPC whose tgw attachment is associated with the DMZ route table
  SpokeVpcIds:
    Type: CommaDelimitedList
    Description: List of spoke VPC IDs with tgw attachments
  SpokeVpcNames:
    Type: CommaDelimitedList
    Description: List of spoke VPC names used to tag the tgw attachments, in the same order as SpokeVpcIds

Resources:
  # Execution role assumed by the Lambda service for the TGWCreateTags function.
  # The inline policy grants only the EC2 transit-gateway/tagging actions the
  # function code calls, plus the CloudWatch Logs actions needed for its output.
  LambdaRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Principal:
              Service: lambda.amazonaws.com
            Action:
              - sts:AssumeRole
      Path: /
      Policies:
        - PolicyName: tgw-tagging-access
          PolicyDocument:
            Version: '2012-10-17'
            Statement:
              # These actions operate on resources discovered at run time, so
              # they are not scoped to specific ARNs here.
              - Effect: Allow
                Action:
                  - ec2:DescribeTransitGatewayAttachments
                  - ec2:DescribeTransitGatewayRouteTables
                  - ec2:CreateTags
                  - ec2:AssociateTransitGatewayRouteTable
                  - ec2:DisassociateTransitGatewayRouteTable
                  - ec2:EnableTransitGatewayRouteTablePropagation
                Resource: '*'
              - Effect: Allow
                Action:
                  - logs:CreateLogGroup
                  - logs:CreateLogStream
                  - logs:PutLogEvents
                Resource: '*'
  
  # Custom resource backed by the TGWCreateTags function. Every property other
  # than ServiceToken is read by the function from the request's
  # ResourceProperties.
  TGWTagsCustomResource:
    Type: AWS::CloudFormation::CustomResource
    Properties:
      ServiceToken:
        Fn::GetAtt: [TGWCreateTags, Arn]
      # Transit gateway and DMZ inputs.
      TGWId:
        Ref: TgwId
      DmzVpcId:
        Ref: DmzVpcId
      # Route table ID exported by another stack under this export name.
      TgwDMZRouteDomainId:
        Fn::ImportValue: TgwDMZRouteDomainId
      # Parallel lists: the name at index i tags the attachment of the VPC at index i.
      SpokeVpcIds:
        Ref: SpokeVpcIds
      SpokeVpcNames:
        Ref: SpokeVpcNames

  # Lambda function backing TGWTagsCustomResource.
  # On Create/Update it:
  #   1. tags each VPC attachment of the transit gateway whose VPC ID appears in
  #      SpokeVpcIds with the matching entry of SpokeVpcNames;
  #   2. tags every default-association route table as TgwSpokesRouteDomain;
  #   3. for the route table identified by TgwDMZRouteDomainId, moves the
  #      DmzVpcId attachment from the default route table to it (best effort)
  #      and enables propagation from every VPC attachment.
  # On Delete it does nothing and reports success.
  # Exactly one response is sent to CloudFormation per invocation.
  TGWCreateTags:
    Type: AWS::Lambda::Function
    Properties:
      Runtime: python3.12
      Role: !GetAtt LambdaRole.Arn
      # Leaves room for the 60 second wait after a route table disassociation.
      Timeout: 360
      Handler: index.handler
      Code:
        ZipFile: |
          import time
          import boto3
          import cfnresponse
          from botocore.exceptions import ClientError

          DUPLICATE = 'TransitGatewayRouteTablePropagation.Duplicate'

          def handler(event, context):
            """Custom resource entry point; sends exactly one response."""
            status, reason = cfnresponse.SUCCESS, None
            try:
              if event.get("RequestType") != 'Delete':
                configure(event.get("ResourceProperties"))
            except Exception as e:
              print('Error configuring transit gateway: {}'.format(e))
              status, reason = cfnresponse.FAILED, str(e)
            cfnresponse.send(event, context, status, {},
                             event.get("PhysicalResourceId"), reason=reason)

          def configure(props):
            """Tag attachments/route tables and set up the DMZ route table."""
            vpc_ids = props.get("SpokeVpcIds")
            vpc_names = props.get("SpokeVpcNames")
            dmz_rt_id = props.get("TgwDMZRouteDomainId")
            dmz_vpc_id = props.get("DmzVpcId")
            if len(vpc_names) < len(vpc_ids):
              raise ValueError('SpokeVpcNames has fewer entries than SpokeVpcIds')
            names = dict(zip(vpc_ids, vpc_names))
            print(names)
            ec2 = boto3.client('ec2')
            flt = [{'Name': 'transit-gateway-id', 'Values': [props.get("TGWId")]}]
            found = ec2.describe_transit_gateway_attachments(Filters=flt)
            attachments = [a for a in found['TransitGatewayAttachments']
                           if a['ResourceType'].lower() == 'vpc']
            for att in attachments:
              # ResourceId is the VPC ID (ResourceOwnerId is an account ID).
              if att['ResourceId'] in names:
                print(ec2.create_tags(
                  Resources=[att['TransitGatewayAttachmentId']],
                  Tags=[{'Key': 'Name', 'Value': names[att['ResourceId']]}]))
            found = ec2.describe_transit_gateway_route_tables(Filters=flt)
            tables = found['TransitGatewayRouteTables']
            default_rt = ''
            for table in tables:
              if table['DefaultAssociationRouteTable']:
                default_rt = table['TransitGatewayRouteTableId']
            for table in tables:
              rt_id = table['TransitGatewayRouteTableId']
              if table['DefaultAssociationRouteTable']:
                print(ec2.create_tags(
                  Resources=[rt_id],
                  Tags=[{'Key': 'Name', 'Value': 'TgwSpokesRouteDomain'}]))
              elif rt_id == dmz_rt_id:
                print('nondefault')
                for att in attachments:
                  att_id = att['TransitGatewayAttachmentId']
                  if att['ResourceId'] == dmz_vpc_id:
                    move_association(ec2, att_id, default_rt, rt_id)
                  enable_propagation(ec2, att_id, rt_id)

          def move_association(ec2, att_id, old_rt, new_rt):
            """Best effort: errors (e.g. already moved) are logged, not raised."""
            try:
              print(ec2.disassociate_transit_gateway_route_table(
                TransitGatewayRouteTableId=old_rt,
                TransitGatewayAttachmentId=att_id))
              time.sleep(60)  # wait for the disassociation to complete
              print(ec2.associate_transit_gateway_route_table(
                TransitGatewayRouteTableId=new_rt,
                TransitGatewayAttachmentId=att_id))
            except Exception as e:
              print('association not changed for {}: {}'.format(att_id, e))

          def enable_propagation(ec2, att_id, rt_id):
            """Enable propagation; an existing propagation is not an error."""
            try:
              print(ec2.enable_transit_gateway_route_table_propagation(
                TransitGatewayRouteTableId=rt_id,
                TransitGatewayAttachmentId=att_id))
            except ClientError as e:
              if e.response['Error']['Code'] != DUPLICATE:
                raise
              print('propagation already exists for {}'.format(att_id))