# The template format version is a string identifier, not a date; it is quoted
# so YAML loaders that auto-detect dates keep it as text.
AWSTemplateFormatVersion: '2010-09-09'
Description: CloudFormation to provision a VPN
Parameters:
  # Forwarded to the CGWInvoke custom resource as ipADDR.
  OnPremIP:
    Type: String
    Default: "1.2.3.4"
    Description: Enter IP Address for On-Premises Firewall
  # Forwarded to the CGWInvoke custom resource as bgpASN.
  OnPremBGP:
    Type: Number
    Default: 65000
    Description: Enter BGP for On-Premises Network
  # Forwarded to the CGWInvoke custom resource as deviceNAME; the handler
  # appends a random suffix before using it.
  OnPremDevice:
    Type: String
    Default: myFirewall
    Description: Enter Name for On-Premises for Firewall
  # Forwarded to the CGWInvoke custom resource as certARN. No default, so a
  # value must be supplied at deploy time.
  PrivCert:
    Type: String
    Description: Enter Private Certificate ARN
    AllowedPattern: 'arn:aws:acm:.*'
    ConstraintDescription: >-
      Private Certificate ARN must be of the form
      arn:aws:acm:<region>:<account id>:certificate/<certificate id>
  # Used as AmazonSideAsn on the virtual private gateway (VGW).
  AWSBGP:
    Type: Number
    Default: 64512
    Description: Enter BGP ASN for AWS' side
  # Diffie-Hellman group applied to Phase 1 (IKE) of both tunnels; passed to
  # the VPNInvoke custom resource as dhGROUPONE. Only group 19 is permitted.
  DHGroupOne:
    Description: Enter Phase 1 Diffie-Hellman Group
    Type: Number
    Default: 19
    AllowedValues: [19]
  # Diffie-Hellman group applied to Phase 2 (IPsec) of both tunnels; passed to
  # the VPNInvoke custom resource as dhGROUPTWO. Only group 19 is permitted.
  DHGroupTwo:
    Description: Enter Phase 2 Diffie-Hellman Group
    Type: Number
    Default: 19
    AllowedValues: [19]
  # Tunnel cipher and integrity settings. Each value is handed to the VPNInvoke
  # custom resource (CryptoP1, CryptoP2, IntP1, IntP2) and applied to both
  # tunnels of the VPN connection.
  Phase1EncryptionAlgorithms:
    Type: String
    Default: AES256-GCM-16
    AllowedValues:
      - AES128-GCM-16
      - AES256-GCM-16
    Description: Enter Phase 1 Encryption Algorithms (AES128-GCM-16 | AES256-GCM-16)
  Phase2EncryptionAlgorithms:
    Type: String
    Default: AES256-GCM-16
    AllowedValues:
      - AES128-GCM-16
      - AES256-GCM-16
    Description: Enter Phase 2 Encryption Algorithms (AES128-GCM-16 | AES256-GCM-16)
  Phase1IntegrityAlgorithms:
    Type: String
    Default: SHA2-512
    AllowedValues:
      - SHA2-256
      - SHA2-384
      - SHA2-512
    Description: Enter Phase 1 Integrity Algorithms (SHA2-256 | SHA2-384 | SHA2-512)
  Phase2IntegrityAlgorithms:
    Type: String
    Default: SHA2-512
    AllowedValues:
      - SHA2-256
      - SHA2-384
      - SHA2-512
    Description: Enter Phase 2 Integrity Algorithms (SHA2-256 | SHA2-384 | SHA2-512)
  # VPC that the virtual private gateway is attached to (see VGWAttach).
  # The AWS-specific parameter type makes CloudFormation check that the value
  # is an existing VPC ID instead of accepting arbitrary text that would only
  # fail later at attachment time.
  VPC:
    Description: Enter a VPC ID to use with the VPN (for Virtual Private Gateway)
    Type: AWS::EC2::VPC::Id
Resources:
# Setup VPN
  # Site-to-site VPN connection between the virtual private gateway (VGW) and
  # the customer gateway whose ID is returned by the CGWInvoke custom resource.
  VPN:
    Type: AWS::EC2::VPNConnection
    Properties:
      Type: ipsec.1
      CustomerGatewayId: !GetAtt CGWInvoke.CGWID
      VpnGatewayId: !Ref VGW
      # Written as a real boolean rather than the string 'false'.
      StaticRoutesOnly: false
      Tags:
        - Key: Framework
          Value: PALZ
  # Lambda behind the Custom::CGWInvoke resource. It creates a certificate-based
  # customer gateway (CGW) and reports the gateway ID both as the CGWID
  # attribute and as the custom resource's physical ID, so that a later Delete
  # request identifies exactly which gateway to remove.
  CGWLambda:
    Type: AWS::Lambda::Function
    Properties:
      Description: Lambda to create CGW
      Handler: index.handler
      MemorySize: 128
      Role: !GetAtt CGWLambdaRole.Arn
      # python3.7 is a retired Lambda runtime; new functions cannot be created
      # on it.
      Runtime: python3.12
      Timeout: 900
      Environment:
        Variables:
          Region: !Sub ${AWS::Region}
      Code:
        ZipFile: |
          import random
          import string
          import time

          import boto3
          import cfnresponse as cf

          c = boto3.client('ec2')
          iam_c = boto3.client('iam')

          # Seconds between state polls while a gateway is created or deleted.
          POLL_SECONDS = 10
          # Stop polling this long before the Lambda deadline so a response can
          # still reach CloudFormation instead of leaving the stack hanging.
          SAFETY_MARGIN_MS = 30000
          # Physical ID reported while no gateway has been created.
          NO_CGW_ID = 'createdCGW'

          def rndStr():
              """Return 8 random lowercase letters used to make device names unique."""
              return ''.join(random.choice(string.ascii_lowercase) for _ in range(8))

          def create_cgw(bgp, ip, cert, dev_name):
              """Create a certificate-based customer gateway and return the API response."""
              return c.create_customer_gateway(
                  BgpAsn=bgp,
                  PublicIp=ip,
                  CertificateArn=cert,
                  Type='ipsec.1',
                  DeviceName=dev_name,
                  DryRun=False
              )

          def delete_cgw(cgw_id):
              """Delete the customer gateway with the given ID."""
              return c.delete_customer_gateway(CustomerGatewayId=cgw_id, DryRun=False)

          def cgw_status(cgw_id):
              """Return the current State string of the given customer gateway."""
              response = c.describe_customer_gateways(
                  CustomerGatewayIds=[cgw_id],
                  DryRun=False
              )
              return response['CustomerGateways'][0]['State']

          def wait_for_state(cgw_id, wanted, context):
              """Poll until the gateway reaches `wanted`; raise TimeoutError near the deadline."""
              while cgw_status(cgw_id) != wanted:
                  if context.get_remaining_time_in_millis() < SAFETY_MARGIN_MS:
                      raise TimeoutError('Timed out waiting for ' + cgw_id + ' to become ' + wanted)
                  print('Waiting for ' + cgw_id + ' to become ' + wanted)
                  time.sleep(POLL_SECONDS)

          def check_role():
              """Return the VPN service-linked role, or False when it does not exist."""
              try:
                  return iam_c.get_role(RoleName='AWSServiceRoleForVPCS2SVPN')
              except iam_c.exceptions.NoSuchEntityException:
                  return False

          def create_role():
              """Create the service-linked role for site-to-site VPN."""
              return iam_c.create_service_linked_role(AWSServiceName='s2svpn.amazonaws.com')

          def provision_cgw(props, context):
              """Create a gateway from the resource properties and return its ID once available."""
              dev_name = props['deviceNAME'] + '-' + rndStr()
              new_cgw = create_cgw(int(props['bgpASN']), props['ipADDR'], props['certARN'], dev_name)
              cgw_id = new_cgw['CustomerGateway']['CustomerGatewayId']
              try:
                  wait_for_state(cgw_id, 'available', context)
              except Exception:
                  # This ID is never reported to CloudFormation, so no Delete
                  # request will ever name it; remove the orphan here.
                  try:
                      delete_cgw(cgw_id)
                  except Exception as cleanup_error:
                      print('Cleanup of ' + cgw_id + ' failed: ' + str(cleanup_error))
                  raise
              return cgw_id

          def handler(event, context):
              """Custom resource entry point; sends exactly one response per request."""
              print(event)
              responseData = {}
              # Create events carry no PhysicalResourceId; use the placeholder
              # until a gateway exists.
              physical_id = event.get('PhysicalResourceId', NO_CGW_ID)

              try:
                  if event['RequestType'] == 'Delete':
                      if not str(physical_id).startswith('cgw-'):
                          # Creation never produced a gateway: nothing to remove.
                          cf.send(event, context, cf.SUCCESS, responseData, physical_id)
                          return
                      print('Deleting CGW')
                      try:
                          delete_cgw(physical_id)
                          wait_for_state(physical_id, 'deleted', context)
                          print('CGW deleted')
                          responseData['CGWID'] = physical_id
                      except Exception as e:
                          # Best effort: a gateway that cannot be deleted must
                          # not block stack deletion, so still report SUCCESS.
                          print('CGW delete failed: ' + str(e))
                          responseData['error'] = str(e)
                      cf.send(event, context, cf.SUCCESS, responseData, physical_id)
                      return

                  # Create and Update both provision a new gateway. On Update the
                  # new physical ID makes CloudFormation send a Delete for the
                  # old gateway during cleanup.
                  if not check_role():
                      create_role()
                  print('Creating CGW')
                  physical_id = provision_cgw(event['ResourceProperties'], context)
                  responseData['CGWID'] = physical_id
                  cf.send(event, context, cf.SUCCESS, responseData, physical_id)
              except Exception as e:
                  print('CGW request failed: ' + str(e))
                  responseData['error'] = str(e)
                  cf.send(event, context, cf.FAILED, responseData, physical_id)
  # Execution role for CGWLambda: basic Lambda logging plus only the EC2 and IAM
  # calls the handler makes.
  CGWLambdaRole:
    Type: AWS::IAM::Role
    Properties:
      Path: '/'
      AssumeRolePolicyDocument:
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Principal:
              Service: lambda.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole
      Policies:
        - PolicyName: CGWPermissions
          PolicyDocument:
            # Quoted so the policy version stays a string, matching the
            # assume-role document above.
            Version: '2012-10-17'
            Statement:
              - Sid: CGWFullControl
                Effect: Allow
                Action:
                  - ec2:CreateCustomerGateway
                  - ec2:DescribeCustomerGateways
                  - ec2:DeleteCustomerGateway
                  - iam:GetRole
                Resource:
                  - '*'
              # The handler only ever creates the site-to-site VPN service-linked
              # role, so creation is limited to that service.
              - Sid: CreateVpnServiceLinkedRole
                Effect: Allow
                Action:
                  - iam:CreateServiceLinkedRole
                Resource:
                  - '*'
                Condition:
                  StringEquals:
                    'iam:AWSServiceName': s2svpn.amazonaws.com
  # Custom resource that invokes CGWLambda. The handler reads bgpASN, ipADDR,
  # certARN and deviceNAME from these properties; FunctionName and Region are
  # passed along but not read by the handler code.
  CGWInvoke:
    Type: Custom::CGWInvoke
    Version: "1.0"
    Properties:
      ServiceToken: !GetAtt CGWLambda.Arn
      FunctionName: !Ref CGWLambda
      Region: !Ref AWS::Region
      deviceNAME: !Ref OnPremDevice
      ipADDR: !Ref OnPremIP
      bgpASN: !Ref OnPremBGP
      certARN: !Ref PrivCert
# Modify VPN
  # Lambda behind the Custom::VPNInvoke resource. On Create and Update it waits
  # for the VPN connection to become available and then applies the requested
  # IKEv2 encryption, integrity and Diffie-Hellman settings to each of the two
  # tunnels in turn. Delete requests are acknowledged without touching the VPN.
  VPNLambda:
    Type: AWS::Lambda::Function
    Properties:
      Description: Lambda to update VPN
      Handler: index.handler
      MemorySize: 128
      Role: !GetAtt VPNLambdaRole.Arn
      # python3.7 is a retired Lambda runtime; new functions cannot be created
      # on it.
      Runtime: python3.12
      Timeout: 900
      Environment:
        Variables:
          Region: !Sub ${AWS::Region}
      Code:
        ZipFile: |
          import time

          import boto3
          import cfnresponse

          client = boto3.client('ec2')

          # Seconds between state polls while the VPN connection is changing.
          POLL_SECONDS = 10
          # Stop polling this long before the Lambda deadline so a response can
          # still reach CloudFormation instead of leaving the stack hanging.
          SAFETY_MARGIN_MS = 30000

          def get_vpn(vpn_id):
              """Return the describe_vpn_connections response for one VPN connection."""
              print('Getting VPN Details')
              return client.describe_vpn_connections(
                  VpnConnectionIds=[vpn_id],
                  DryRun=False
              )

          def vpn_status(vpn_id):
              """Return the current State string of the VPN connection."""
              print('Checking VPN Status')
              response = client.describe_vpn_connections(
                  VpnConnectionIds=[vpn_id],
                  DryRun=False
              )
              return response['VpnConnections'][0]['State']

          def wait_until_available(vpn_id, context, message):
              """Poll until the VPN is 'available'; raise TimeoutError near the deadline."""
              while vpn_status(vpn_id) != 'available':
                  if context.get_remaining_time_in_millis() < SAFETY_MARGIN_MS:
                      raise TimeoutError('Timed out waiting for ' + vpn_id + ' to become available')
                  print(message)
                  time.sleep(POLL_SECONDS)

          def update_vpn(vpn_id, tunnel_ip, dh1, dh2, c1, c2, i1, i2):
              """Apply IKEv2 plus the given algorithms and DH groups to one tunnel."""
              return client.modify_vpn_tunnel_options(
                  VpnConnectionId=vpn_id,
                  VpnTunnelOutsideIpAddress=tunnel_ip,
                  TunnelOptions={
                      'Phase1EncryptionAlgorithms': [{'Value': c1}],
                      'Phase2EncryptionAlgorithms': [{'Value': c2}],
                      'Phase1IntegrityAlgorithms': [{'Value': i1}],
                      'Phase2IntegrityAlgorithms': [{'Value': i2}],
                      'Phase1DHGroupNumbers': [{'Value': dh1}],
                      'Phase2DHGroupNumbers': [{'Value': dh2}],
                      'IKEVersions': [{'Value': 'ikev2'}]
                  },
                  DryRun=False
              )

          def handler(event, context):
              """Custom resource entry point; sends exactly one response per request."""
              print(event)
              responseData = {}
              try:
                  ev = event['RequestType']
                  if ev != 'Create' and ev != 'Update':
                      # Delete: nothing to undo. Answer before describing the
                      # VPN so a missing connection cannot fail stack deletion.
                      cfnresponse.send(event, context, cfnresponse.SUCCESS, responseData,
                                       event.get('PhysicalResourceId', 'NoChangeVPN'))
                      return

                  props = event['ResourceProperties']
                  vpn_id = props['vpnID']
                  dh_group_one = int(props['dhGROUPONE'])
                  dh_group_two = int(props['dhGROUPTWO'])

                  wait_until_available(vpn_id, context, 'Waiting for VPN to be in Available state')

                  # Read the tunnel addresses only once the connection is
                  # available.
                  print('Modifying VPN for ' + vpn_id)
                  telemetry = get_vpn(vpn_id)['VpnConnections'][0]['VgwTelemetry']
                  tunnel_one_ip = telemetry[0]['OutsideIpAddress']
                  tunnel_two_ip = telemetry[1]['OutsideIpAddress']
                  responseData['IPOne'] = tunnel_one_ip
                  responseData['IPTwo'] = tunnel_two_ip

                  # Tunnels are modified one at a time; the connection must be
                  # back in 'available' before the next modification starts.
                  for label, tunnel_ip in (('One', tunnel_one_ip), ('Two', tunnel_two_ip)):
                      print('Updating Tunnel ' + label + ' - ' + tunnel_ip)
                      update_vpn(vpn_id, tunnel_ip, dh_group_one, dh_group_two,
                                 props['CryptoP1'], props['CryptoP2'],
                                 props['IntP1'], props['IntP2'])
                      wait_until_available(vpn_id, context, 'Waiting for Tunnel update to be completed')
                      print('Successfully updated Tunnel ' + label + ' - ' + tunnel_ip)

                  cfnresponse.send(event, context, cfnresponse.SUCCESS, responseData, 'updatedVPN')
              except Exception as e:
                  print('VPN request failed: ' + str(e))
                  responseData = {'error': str(e)}
                  cfnresponse.send(event, context, cfnresponse.FAILED, responseData, 'updatedVPN')
  # Execution role for VPNLambda: basic Lambda logging plus the two EC2 calls
  # the handler makes (describe the VPN connection, modify tunnel options).
  VPNLambdaRole:
    Type: AWS::IAM::Role
    Properties:
      Path: '/'
      AssumeRolePolicyDocument:
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Principal:
              Service: lambda.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole
      Policies:
        - PolicyName: VPNModifyPermissions
          PolicyDocument:
            # Quoted so the policy version stays a string, matching the
            # assume-role document above.
            Version: '2012-10-17'
            Statement:
              - Sid: ModifyVPN
                Effect: Allow
                Action:
                  - ec2:ModifyVpnTunnelOptions
                  - ec2:DescribeVpnConnections
                Resource:
                  - '*'
  # Custom resource that invokes VPNLambda for the VPN connection. The handler
  # reads vpnID, dhGROUPONE, dhGROUPTWO, CryptoP1, CryptoP2, IntP1 and IntP2;
  # FunctionName and Region are passed along but not read by the handler code.
  VPNInvoke:
    Type: Custom::VPNInvoke
    Version: "1.0"
    DependsOn: VPN
    Properties:
      ServiceToken: !GetAtt VPNLambda.Arn
      FunctionName: !Ref VPNLambda
      Region: !Ref AWS::Region
      vpnID: !Ref VPN
      CryptoP1: !Ref Phase1EncryptionAlgorithms
      IntP1: !Ref Phase1IntegrityAlgorithms
      dhGROUPONE: !Ref DHGroupOne
      CryptoP2: !Ref Phase2EncryptionAlgorithms
      IntP2: !Ref Phase2IntegrityAlgorithms
      dhGROUPTWO: !Ref DHGroupTwo
  # Virtual private gateway: the AWS-side endpoint referenced by the VPN
  # connection, using the ASN supplied in the AWSBGP parameter.
  VGW:
    Type: AWS::EC2::VPNGateway
    Properties:
      Type: ipsec.1
      AmazonSideAsn: !Ref AWSBGP
      Tags:
        - Value: PALZ
          Key: Framework
  # Attaches the virtual private gateway to the VPC given in the VPC parameter.
  VGWAttach:
    Type: AWS::EC2::VPCGatewayAttachment
    Properties:
      VpnGatewayId: !Ref VGW
      VpcId: !Ref VPC
Outputs:
  # Exposes the Ref value of the VPN resource as a stack output.
  VPNID:
    Value: !Ref VPN
    Description: ID of the VPN Connection that has been established