// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT-0

package aws.proserve.bcs.cem.service;

import aws.proserve.bcs.dr.aws.AwsSecurityGroup;
import aws.proserve.bcs.dr.ce.CloudEndureConstants;
import aws.proserve.bcs.dr.project.Project;
import aws.proserve.bcs.dr.vpc.Cidr;
import aws.proserve.bcs.dr.vpc.Filters;
import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
import com.amazonaws.services.ec2.model.DescribeNetworkInterfacesRequest;
import com.amazonaws.services.ec2.model.DescribeNetworkInterfacesResult;
import com.amazonaws.services.ec2.model.DescribeRouteTablesRequest;
import com.amazonaws.services.ec2.model.DescribeSecurityGroupsRequest;
import com.amazonaws.services.ec2.model.DescribeSecurityGroupsResult;
import com.amazonaws.services.ec2.model.DescribeSubnetsRequest;
import com.amazonaws.services.ec2.model.DescribeSubnetsResult;
import com.amazonaws.services.ec2.model.Subnet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.inject.Named;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;

@Named
class CemNetworkService {
    private final Logger log = LoggerFactory.getLogger(getClass());
    private final Random rnd = new Random(System.currentTimeMillis());

    /**
     * @return A map from machine name to security groups.
     */
    Map<String, List<AwsSecurityGroup>> findSecurityGroups(Project project) {
        final var ec2 = AmazonEC2ClientBuilder.standard()
                .withRegion(project.getTargetRegion().toAwsRegion())
                .build();

        final var pairs = new ArrayList<Object[]>();
        final var vpcId = project.getCemProject().getFirst().getVpcId();
        final var describeRequest = new DescribeSecurityGroupsRequest()
                .withFilters(Filters.vpcId(vpcId));
        DescribeSecurityGroupsResult result;
        do {
            result = ec2.describeSecurityGroups(describeRequest);
            describeRequest.setNextToken(result.getNextToken());

            for (var group : result.getSecurityGroups()) {
                pairs.addAll(group.getTags().stream()
                        .filter(t -> t.getKey().equals(CloudEndureConstants.TAG_MACHINE))
                        .map(t -> t.getValue().split(","))
                        .flatMap(Arrays::stream)
                        .map(s -> new Object[]{s, new AwsSecurityGroup(group.getGroupId(), group.getGroupName())})
                        .collect(Collectors.toList()));
            }
        } while (result.getNextToken() != null);
        return pairs.stream().collect(Collectors.groupingBy(s -> (String) s[0],
                Collectors.mapping(s -> (AwsSecurityGroup) s[1], Collectors.toList())));
    }

    Subnet findSubnet(Project project, boolean publicSubnet) {
        final var ec2 = AmazonEC2ClientBuilder.standard()
                .withRegion(project.getTargetRegion().toAwsRegion())
                .build();

        final var vpcId = project.getCemProject().getFirst().getVpcId();
        final var describeRequest = new DescribeSubnetsRequest().withFilters(Filters.vpcId(vpcId));
        DescribeSubnetsResult result;
        do {
            result = ec2.describeSubnets(describeRequest);
            describeRequest.setNextToken(result.getNextToken());

            for (var subnet : result.getSubnets()) {
                final var tables = ec2.describeRouteTables(new DescribeRouteTablesRequest()
                        .withFilters(Filters.associatedSubnetId(subnet.getSubnetId()))).getRouteTables();
                if (tables.isEmpty()) {
                    if (!publicSubnet) { // no route table association, private subnet
                        log.info("Found private subnet {} with empty route table", subnet.getSubnetId());
                        return subnet;
                    }
                } else {
                    boolean routeIgw = false;
                    for (var table : tables) {
                        for (var route : table.getRoutes()) {
                            if (route.getGatewayId() != null && route.getGatewayId().startsWith("igw-")) { // route to IGW, public subnet
                                routeIgw = true;
                                if (publicSubnet) {
                                    log.info("Found public subnet {} with IGW", subnet.getSubnetId());
                                    return subnet;
                                }
                            }
                        }
                    }

                    if (!routeIgw && !publicSubnet) { // no route to IGW, private subnet
                        log.info("Found private subnet {} with no route to IGW", subnet.getSubnetId());
                        return subnet;
                    }
                }
            }
        } while (result.getNextToken() == null);
        throw new IllegalStateException("Unable to find a subnet [publicSubnet = " + publicSubnet + "]");
    }

    private List<String> findUnusedAddress(Cidr cidr, List<String> addresses, int count) {
        final var unused = new ArrayList<String>();
        final var size = cidr.getSize();
        for (int i = 0; i < size; i++) {
            final var address = cidr.findAddress(rnd.nextInt((int) size));
            if (addresses.contains(address)) {
                continue;
            }

            unused.add(address);

            if (unused.size() > count) {
                return unused;
            }
        }
        return unused;
    }

    List<String> findIpAddress(Project project, Subnet subnet, int count) {
        final var ec2 = AmazonEC2ClientBuilder.standard()
                .withRegion(project.getTargetRegion().toAwsRegion())
                .build();

        final var vpcId = project.getCemProject().getFirst().getVpcId();
        final var addresses = new ArrayList<String>();
        final var request = new DescribeNetworkInterfacesRequest()
                .withFilters(Filters.vpcId(vpcId));
        DescribeNetworkInterfacesResult describeResult;
        do {
            describeResult = ec2.describeNetworkInterfaces(request);
            request.setNextToken(describeResult.getNextToken());

            for (var i : describeResult.getNetworkInterfaces()) {
                addresses.add(i.getPrivateIpAddress());
            }
        } while (describeResult.getNextToken() != null);
        return findUnusedAddress(new Cidr(subnet.getCidrBlock()), addresses, count);
    }
}