# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import functools
import logging
import os
import time
from typing import Optional

import boto3
import ipaddr
from botocore import config
from botocore.exceptions import ClientError

from util.exception import APIException

logger = logging.getLogger()
logger.setLevel(logging.INFO)

solution_version = os.environ.get("SOLUTION_VERSION", "v1.0.0")
solution_id = os.environ.get("SOLUTION_ID", "SO8025")
user_agent_config = {
    "user_agent_extra": f"AwsSolution/{solution_id}/{solution_version}"
}
default_config = config.Config(**user_agent_config)

default_region = os.environ.get("AWS_REGION")
loghub_vpc_id = os.environ.get("DEFAULT_VPC_ID")
loghub_sg_id = os.environ.get("DEFAULT_SG_ID")
loghub_private_subnet_ids_str = os.environ.get("DEFAULT_PRIVATE_SUBNET_IDS")

ec2 = boto3.client("ec2", config=default_config)


def handle_error(func):
    """Decorator for exception handling.

    ``APIException`` is logged and re-raised unchanged. Any other exception
    is logged (with traceback) and replaced by a generic ``RuntimeError``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except APIException as e:
            logger.error(e)
            raise
        except Exception as e:
            # Keep the traceback in the log: the message below points there.
            logger.exception(e)
            raise RuntimeError(
                "Unknown exception, please check Lambda log for more details"
            ) from e

    return wrapper


class ClusterAutoImportManager:
    """Validate and create the network plumbing between the Loghub VPC and
    the VPC of an OpenSearch (AOS) domain: security group rule, NACL entry,
    VPC peering connection and routes.

    Currently do not consider processing ipv6 scenarios
    """

    Anywhere_ipv4 = "0.0.0.0/0"
    HTTPS_PORT = 443

    def __init__(
        self,
        tags: Optional[list],
        ec2: boto3.Session.client,
        es_resp,
        loghub_vpc_id: str,
        loghub_sg_id: str,
        loghub_private_subnet_ids_str: str,
    ):
        """Collect VPC information for the AOS domain and for Loghub.

        Args:
            tags: optional iterable of ``{"key": ..., "value": ...}`` dicts;
                converted to the ``{"Key": ..., "Value": ...}`` form.
            ec2: EC2 client used for every EC2 API call of this instance.
            es_resp: response of ``es.describe_elasticsearch_domain``; its
                ``DomainStatus.VPCOptions`` is read.
            loghub_vpc_id: id of the Loghub VPC.
            loghub_sg_id: id of the Loghub processing security group.
            loghub_private_subnet_ids_str: comma separated subnet ids.

        Raises:
            APIException: a VPC cannot be found, its subnets cannot be
                listed, or the two VPC CIDR blocks overlap.
            ClientError: propagated from ``describe_vpcs``.
        """
        self.tags = [
            {"Key": tag["key"], "Value": tag["value"]} for tag in tags or []
        ]
        self.ec2 = ec2

        # aos_vpc_id, aos_subnet_ids and aos_sg_ids come from the
        # 'es.describe_elasticsearch_domain' response.
        es_vpc = es_resp["DomainStatus"]["VPCOptions"]
        self.aos_vpc_id = es_vpc["VPCId"]
        self.aos_subnet_ids = es_vpc["SubnetIds"]
        self.aos_sg_ids = es_vpc["SecurityGroupIds"]

        # The network ACL id is set when 'validate_nacl()' is called.
        self.aos_network_acl_id = None

        self.loghub_vpc_id = loghub_vpc_id
        self.loghub_sg_id = loghub_sg_id
        self.loghub_private_subnet_ids_str = loghub_private_subnet_ids_str
        self.loghub_private_subnet_ids = self.loghub_private_subnet_ids_str.split(
            ","
        )

        # Peering state is initialised in every scenario so that later
        # attribute reads never fail, whichever branch below is taken.
        self.vpc_peering_connection_id = None
        self.vpc_peering_connection_status = None
        self.vpc_peering_retry = 0

        self.is_same_vpc = self.aos_vpc_id == self.loghub_vpc_id
        if not self.is_same_vpc:
            self.loghub_vpc_subnet_ids = self.get_vpc_subnets(
                vpc_id=self.loghub_vpc_id
            )
            self.aos_cidr_block = self._get_vpc_cidr_block(self.aos_vpc_id)
            self.loghub_cidr_block = self._get_vpc_cidr_block(self.loghub_vpc_id)
            # Peering is impossible between VPCs with overlapping CIDRs.
            if not self.check_cidr_overlaps(
                self.aos_cidr_block, self.loghub_cidr_block
            ):
                raise APIException("VPC CIDR is conflict with AOS VPC!")
        else:
            # Same VPC: both sides share one CIDR block.
            self.aos_cidr_block = self._get_vpc_cidr_block(self.aos_vpc_id)
            self.loghub_cidr_block = self.aos_cidr_block

        # aos_vpc_subnet_ids comes from the 'ec2.describe_subnets' API.
        self.aos_vpc_subnet_ids = self.get_vpc_subnets(vpc_id=self.aos_vpc_id)

    def _get_vpc_cidr_block(self, vpc_id: str) -> str:
        """Return the ``CidrBlock`` of the first VPC returned for ``vpc_id``.

        Raises:
            APIException: ``describe_vpcs`` returned no VPC.
            ClientError: logged and re-raised from ``describe_vpcs``.
        """
        try:
            response = self.ec2.describe_vpcs(VpcIds=[vpc_id], DryRun=False)
        except ClientError as e:
            logger.error(e)
            raise
        vpcs = response.get("Vpcs")
        if not vpcs:
            raise APIException(f"the VPC is Not Found, id is {vpc_id}")
        return vpcs[0]["CidrBlock"]

    def check_cidr_overlaps(self, aos_cidr: str, loghub_cidr: str):
        """Using ipaddr lib to check: True is Pass (the CIDRs do not overlap)."""
        aos_network = ipaddr.IPNetwork(aos_cidr)
        loghub_network = ipaddr.IPNetwork(loghub_cidr)
        return not aos_network.overlaps(
            loghub_network
        ) and not loghub_network.overlaps(aos_network)

    def validate_same_vpc(self):
        """Check if it is the same vpc, return True if the same."""
        return self.is_same_vpc

    @classmethod
    def _sg_rule_allows_https(cls, sg_rule: dict) -> bool:
        """Return True when ``sg_rule`` is an inbound rule covering TCP 443.

        A rule for all protocols ("-1") covers every port; its
        FromPort/ToPort values are not a usable range, so they are ignored.
        """
        if sg_rule.get("IsEgress") is not False:
            return False
        protocol = sg_rule.get("IpProtocol")
        if protocol == "-1":
            return True
        if protocol != "tcp":
            return False
        from_port = sg_rule.get("FromPort")
        to_port = sg_rule.get("ToPort")
        if from_port is None or to_port is None:
            return False
        return from_port <= cls.HTTPS_PORT <= to_port

    def validate_sg(self):
        """Check whether the AOS security groups allow access from Loghub.

        Uses the 'ec2.describe_security_group_rules' API. A rule passes when:
            1. sg_rule['IsEgress'] is False (inbound rule)
            2. sg_rule['IpProtocol'] is tcp with a port range containing 443,
               or -1 (all protocols, all ports)
            3. its source is:
               3.1. same VPC: the loghub_sg_id (ReferencedGroupInfo) or
                    0.0.0.0/0
               3.2. different VPC: loghub_cidr_block or 0.0.0.0/0

        Returns:
            True when at least one rule passes, otherwise False.
        """
        response = self.ec2.describe_security_group_rules(
            Filters=[
                {
                    "Name": "group-id",
                    "Values": self.aos_sg_ids,
                },
            ],
            DryRun=False,
        )
        same_vpc = self.validate_same_vpc()
        for sg_rule in response.get("SecurityGroupRules", []):
            if not self._sg_rule_allows_https(sg_rule):
                continue
            if same_vpc:
                if "ReferencedGroupInfo" in sg_rule:
                    if sg_rule["ReferencedGroupInfo"]["GroupId"] == self.loghub_sg_id:
                        return True
                elif sg_rule.get("CidrIpv4") == self.Anywhere_ipv4:
                    return True
            elif sg_rule.get("CidrIpv4") in (
                self.loghub_cidr_block,
                self.Anywhere_ipv4,
            ):
                return True
        return False

    def _nacl_entry_allows_https(self, entry: dict) -> bool:
        """Return True when a NACL entry is an inbound allow rule for TCP 443
        (or all protocols) from anywhere or from the Loghub CIDR."""
        if entry["Egress"] is not False:
            return False
        # IPv6 entries have no "CidrBlock" key; they never match here.
        if entry.get("CidrBlock") not in (
            self.Anywhere_ipv4,
            self.loghub_cidr_block,
        ):
            return False
        if entry["Protocol"] not in ("-1", "6"):
            return False
        if entry["RuleAction"] != "allow":
            return False
        if "PortRange" in entry:
            port_range = entry["PortRange"]
            return port_range["From"] <= self.HTTPS_PORT <= port_range["To"]
        # Without a port range only an all-protocol entry is accepted.
        return entry["Protocol"] == "-1"

    def validate_nacl(self):
        """Check whether a network ACL of the AOS VPC allows inbound 443.

        Uses the 'ec2.describe_network_acls' API. As a side effect,
        ``self.aos_network_acl_id`` is set to the id of the last network ACL
        inspected (the matching one when the check passes).

        Returns:
            True when an allowing entry is found, otherwise False.
        """
        response = self.ec2.describe_network_acls(
            Filters=[
                {
                    "Name": "association.subnet-id",
                    "Values": self.aos_vpc_subnet_ids,
                },
                {
                    "Name": "vpc-id",
                    "Values": [
                        self.aos_vpc_id,
                    ],
                },
            ],
            DryRun=False,
        )
        for nacl in response.get("NetworkAcls", []):
            if "NetworkAclId" in nacl:
                self.aos_network_acl_id = nacl["NetworkAclId"]
            if any(
                self._nacl_entry_allows_https(entry)
                for entry in nacl.get("Entries", [])
            ):
                return True
        return False

    def validate_aos_vpc_routing(self):
        """Validate the route tables of the AOS subnets towards Loghub."""
        return self.validate_routing(
            self.aos_vpc_id, self.aos_subnet_ids, self.loghub_cidr_block
        )

    def validate_loghub_vpc_routing(self):
        """Validate the route tables of the Loghub subnets towards AOS."""
        return self.validate_routing(
            self.loghub_vpc_id, self.loghub_private_subnet_ids, self.aos_cidr_block
        )

    def validate_routing(self, vpc_id, vpc_subnet_ids, cidr_block):
        """Check if routing table contains vpc_peering_connection_id.

        1. No need to check in the same vpc scenario, it returns True.
        2. Calls 'get_vpc_peering_connections' to obtain
           vpc_peering_connection_id; returns False when there is none.
        3. Uses the 'ec2.describe_route_tables' API and counts the active
           routes that use the peering connection; the verification passes
           when that count equals the number of subnets.
        4. In different VPC scenarios, both 'validate_aos_vpc_routing()' and
           'validate_loghub_vpc_routing()' need to be called.
        """
        if self.validate_same_vpc():
            return True
        if self.get_vpc_peering_connections() is None:
            return False
        response = self.ec2.describe_route_tables(
            Filters=[
                {"Name": "vpc-id", "Values": [vpc_id]},
                {
                    "Name": "association.subnet-id",
                    "Values": vpc_subnet_ids,
                },
                {
                    "Name": "route.vpc-peering-connection-id",
                    "Values": [self.vpc_peering_connection_id],
                },
                {
                    "Name": "route.destination-cidr-block",
                    "Values": [cidr_block],
                },
            ],
            DryRun=False,
        )
        count = 0
        for route_table in response.get("RouteTables", []):
            for route in route_table.get("Routes", []):
                if (
                    route.get("VpcPeeringConnectionId")
                    == self.vpc_peering_connection_id
                    and route.get("State") == "active"
                ):
                    count += 1
        return count == len(vpc_subnet_ids)

    def get_vpc_subnets(self, vpc_id: str):
        """Obtain the subnet_ids of vpc from "ec2.describe_subnets" API.

        Raises:
            APIException: the response carries no "Subnets" key.
        """
        filters = [
            {
                "Name": "vpc-id",
                "Values": [
                    vpc_id,
                ],
            },
        ]
        response = self.ec2.describe_subnets(Filters=filters, DryRun=False)
        if "Subnets" not in response:
            raise APIException(f"Please check the subnets of vpc, vpc id is {vpc_id}")
        return [subnet["SubnetId"] for subnet in response["Subnets"]]

    def get_vpc_peering_connections(self):
        """Obtain the vpc_peering_connection_id from the
        "ec2.describe_vpc_peering_connections" API by aos_vpc_id and
        loghub_vpc_id.

        Only connections requested by the Loghub VPC and accepted by the AOS
        VPC in status active / provisioning / pending-acceptance are
        considered. Updates ``self.vpc_peering_connection_id`` and
        ``self.vpc_peering_connection_status`` when one is found.

        Returns:
            The stored peering connection id (None when nothing was found,
            or in the same vpc scenario).
        """
        if not self.validate_same_vpc():
            response = self.ec2.describe_vpc_peering_connections(
                Filters=[
                    {
                        "Name": "requester-vpc-info.vpc-id",
                        "Values": [
                            self.loghub_vpc_id,
                        ],
                    },
                    {
                        "Name": "accepter-vpc-info.vpc-id",
                        "Values": [
                            self.aos_vpc_id,
                        ],
                    },
                    {
                        "Name": "status-code",
                        "Values": ["active", "provisioning", "pending-acceptance"],
                    },
                ],
                DryRun=False,
            )
            connections = response.get("VpcPeeringConnections")
            if connections:
                first = connections[0]
                self.vpc_peering_connection_id = first["VpcPeeringConnectionId"]
                self.vpc_peering_connection_status = first["Status"]
        return self.vpc_peering_connection_id

    def create_sg_rule(self):
        """Add an ingress rule to the first AOS security group so that the
        Loghub processing layer can access port 443.

        1. same vpc: allow loghub_sg_id
        2. not the same vpc: allow loghub_cidr_block

        Returns:
            True when the API response reports ``Return`` as True.
        """
        permission = {
            "FromPort": self.HTTPS_PORT,
            "IpProtocol": "tcp",
            "ToPort": self.HTTPS_PORT,
        }
        if self.validate_same_vpc():
            permission["UserIdGroupPairs"] = [
                {
                    "Description": "Loghub Processing Rule",
                    "GroupId": self.loghub_sg_id,
                    "VpcId": self.aos_vpc_id,
                },
            ]
        else:
            permission["IpRanges"] = [
                {
                    "CidrIp": self.loghub_cidr_block,
                    "Description": "Loghub Processing Rule",
                },
            ]
        response = self.ec2.authorize_security_group_ingress(
            GroupId=self.aos_sg_ids[0],
            IpPermissions=[permission],
            DryRun=False,
        )
        return response.get("Return") is True

    def create_nacl_entry(self, rule_number: int = 666):
        """Add an inbound rule to the NACL allowing the Loghub CIDR to access
        port 443.

        This method only needs to be called when the vpc of loghub is
        different from that of aos. ``validate_nacl()`` must have run first,
        because it is what sets ``self.aos_network_acl_id``.

        Args:
            rule_number: rule number of the new entry (default 666).
        """
        self.ec2.create_network_acl_entry(
            CidrBlock=self.loghub_cidr_block,
            DryRun=False,
            Egress=False,
            NetworkAclId=self.aos_network_acl_id,
            PortRange={"From": self.HTTPS_PORT, "To": self.HTTPS_PORT},
            Protocol="6",
            RuleAction="allow",
            RuleNumber=rule_number,
        )

    def create_vpc_peering_connection(self):
        """Create vpc peering between Loghub vpc and aos vpc, then accept it.

        Does nothing in the same vpc scenario.
        """
        if self.validate_same_vpc():
            return
        tag_specifications = []
        if self.tags:
            tag_specifications.append(
                {
                    "ResourceType": "vpc-peering-connection",
                    "Tags": self.tags,
                }
            )
        response = self.ec2.create_vpc_peering_connection(
            DryRun=False,
            PeerVpcId=self.aos_vpc_id,
            VpcId=self.loghub_vpc_id,
            TagSpecifications=tag_specifications,
        )
        if "VpcPeeringConnection" in response:
            vpc_peering_connection_id = response["VpcPeeringConnection"][
                "VpcPeeringConnectionId"
            ]
            # Remember the id so route creation can use it right away.
            self.vpc_peering_connection_id = vpc_peering_connection_id
            self.accept_vpc_peering_connection(vpc_peering_connection_id)

    def accept_vpc_peering_connection(self, vpc_peering_connection_id: str):
        """Accept a peering connection, retrying while it is not yet visible.

        On 'InvalidVpcPeeringConnectionID.NotFound' the call sleeps one
        second and retries; the error is re-raised once
        ``self.vpc_peering_retry`` reaches 4. Other errors are re-raised
        immediately.
        """
        try:
            response = self.ec2.accept_vpc_peering_connection(
                DryRun=False, VpcPeeringConnectionId=vpc_peering_connection_id
            )
            if "VpcPeeringConnection" in response:
                status = response["VpcPeeringConnection"]["Status"]
                logger.info(
                    f"accept vpc peering, status is {status},VpcPeeringConnectionId is {vpc_peering_connection_id}"
                )
        except ClientError as ex:
            if (
                ex.response["Error"]["Code"]
                != "InvalidVpcPeeringConnectionID.NotFound"
            ):
                raise
            time.sleep(1)
            self.vpc_peering_retry = self.vpc_peering_retry + 1
            if self.vpc_peering_retry == 4:
                raise
            return self.accept_vpc_peering_connection(vpc_peering_connection_id)

    def create_aos_route(self):
        """Create routes to the Loghub CIDR in the route tables of aos."""
        aos_route_table_ids = self.get_route_table_ids(
            self.aos_vpc_id, self.aos_subnet_ids
        )
        for aos_route_table_id in aos_route_table_ids:
            self.create_route(self.loghub_cidr_block, aos_route_table_id)

    def create_loghub_route(self):
        """Create routes to the AOS CIDR in the route tables of loghub."""
        loghub_route_table_ids = self.get_route_table_ids(
            self.loghub_vpc_id, self.loghub_private_subnet_ids
        )
        for loghub_route_table_id in loghub_route_table_ids:
            self.create_route(self.aos_cidr_block, loghub_route_table_id)

    def create_route(self, cidr_block, route_table_id):
        """Create a route in the route table for the vpc peering connection.

        Returns:
            True when the route was created or already exists, otherwise
            False.
        """
        try:
            response = self.ec2.create_route(
                DestinationCidrBlock=cidr_block,
                DryRun=False,
                RouteTableId=route_table_id,
                VpcPeeringConnectionId=self.vpc_peering_connection_id,
            )
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "RouteAlreadyExists":
                return True
            raise
        return bool(response) and response.get("Return") is True

    @staticmethod
    def _collect_route_table_ids(response: dict) -> list:
        """Return the distinct route table ids found in the associations of a
        'describe_route_tables' response, in first-seen order."""
        route_table_ids = [
            association["RouteTableId"]
            for route_table in response.get("RouteTables", [])
            for association in route_table.get("Associations", [])
        ]
        # One table may carry several associations; keep each id once.
        return list(dict.fromkeys(route_table_ids))

    def get_route_table_ids(self, vpc_id: str, subnet_ids: list):
        """Obtain the route table ids from the 'describe_route_tables' API by
        vpc_id and subnet_ids.

        Falls back to the main route table of the VPC when no route table is
        explicitly associated with the given subnets.
        """
        response = self.ec2.describe_route_tables(
            Filters=[
                {"Name": "vpc-id", "Values": [vpc_id]},
                {
                    "Name": "association.subnet-id",
                    "Values": subnet_ids,
                },
            ],
            DryRun=False,
        )
        route_table_ids = self._collect_route_table_ids(response)
        if not route_table_ids:
            return self.get_main_route_table_id(vpc_id)
        return route_table_ids

    def get_main_route_table_id(self, vpc_id: str):
        """Obtain the main route table id(s) of the VPC from the
        'describe_route_tables' API.

        Raises:
            APIException: no main route table was found.
        """
        response = self.ec2.describe_route_tables(
            Filters=[
                {"Name": "vpc-id", "Values": [vpc_id]},
                {
                    "Name": "association.main",
                    "Values": ["true"],
                },
            ],
            DryRun=False,
        )
        route_table_ids = self._collect_route_table_ids(response)
        if not route_table_ids:
            raise APIException(
                f"Failed to import the AOS, please check the main route table of VPC, we can't do peering. VPC id is {vpc_id}"
            )
        return route_table_ids

    def check_all(self):
        """Run every validation and create whatever is missing, in order:
        security group rule, NACL entry, peering connection, AOS routes,
        Loghub routes."""
        if not self.validate_sg():
            self.create_sg_rule()
        if not self.validate_nacl():
            self.create_nacl_entry()
        if not self.get_vpc_peering_connections():
            self.create_vpc_peering_connection()
        elif self.vpc_peering_connection_status == "pending-acceptance":
            self.accept_vpc_peering_connection(self.vpc_peering_connection_id)
        if not self.validate_aos_vpc_routing():
            self.create_aos_route()
        if not self.validate_loghub_vpc_routing():
            self.create_loghub_route()

    def check_all_aos_cidr_overlaps(
        self, region=default_region, existed_aos_list=None
    ) -> bool:
        """Check the AOS VPC CIDR against the VPCs of already imported domains.

        Using ipaddr lib to check: True is Pass.

        Args:
            region: region of the initial "es" client; a new client is
                created whenever a listed domain is in a different region.
            existed_aos_list: dicts with "domainName" and "region" keys.

        Returns:
            True when nothing conflicts. The check is skipped (True) when the
            AOS domain is in the Loghub VPC, when the list is empty, or when
            a listed domain already lives in the same VPC as this one.

        Raises:
            APIException: a listed domain or its VPC is not found, or a CIDR
                overlap is detected.
            ClientError: any other AWS API error.
        """
        if self.aos_vpc_id == self.loghub_vpc_id:
            return True
        if not existed_aos_list:
            return True

        es = boto3.client("es", region_name=region, config=default_config)
        vpc_ids = []
        for aos in existed_aos_list:
            domain_name = aos["domainName"]
            existed_aos_region = aos["region"]
            try:
                if region != existed_aos_region:
                    region = existed_aos_region
                    es = boto3.client(
                        "es", region_name=region, config=default_config
                    )
                es_resp = es.describe_elasticsearch_domain(DomainName=domain_name)
            except ClientError as e:
                if e.response["Error"]["Code"] == "ResourceNotFoundException":
                    raise APIException("OpenSearch Domain Not Found") from e
                raise
            es_vpc = es_resp["DomainStatus"]["VPCOptions"]
            if self.aos_vpc_id == es_vpc["VPCId"]:
                # A domain in this very VPC is already imported: no new
                # peering is involved, so there is nothing to compare.
                return True
            vpc_ids.append(es_vpc["VPCId"])

        try:
            response = self.ec2.describe_vpcs(VpcIds=vpc_ids, DryRun=False)
        except ClientError as e:
            logger.error(e)
            raise
        vpcs = response.get("Vpcs")
        if not vpcs:
            raise APIException(f"the VPC is Not Found, id list is {vpc_ids}")
        for vpc in vpcs:
            if not self.check_cidr_overlaps(self.aos_cidr_block, vpc["CidrBlock"]):
                raise APIException(
                    "We can't import AOS, its CIDR conflicts with imported AOS!"
                )
        return True