// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 // Package acm provides a client to make API requests to AWS Certificate Manager. package acm import ( "context" "fmt" "strings" "sync" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/acm" "github.com/dustin/go-humanize/english" "golang.org/x/sync/errgroup" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" ) const ( waitForFindValidAliasesTimeout = 10 * time.Second ) type api interface { DescribeCertificateWithContext(ctx aws.Context, input *acm.DescribeCertificateInput, opts ...request.Option) (*acm.DescribeCertificateOutput, error) } // ACM wraps an AWS Certificate Manager client. type ACM struct { client api } // New returns an ACM struct configured against the input session. func New(s *session.Session) *ACM { return &ACM{ client: acm.New(s), } } // ValidateCertAliases validates if aliases are all valid against the provided ACM certificates. func (a *ACM) ValidateCertAliases(aliases []string, certs []string) error { validAliases := make(map[string]bool) domainsOfCert := make(map[string][]string) ctx, cancelWait := context.WithTimeout(context.Background(), waitForFindValidAliasesTimeout) defer cancelWait() g, ctx := errgroup.WithContext(ctx) var mux sync.Mutex for i := range certs { cert := certs[i] g.Go(func() error { domains, err := a.validDomainsOfCert(ctx, cert) if err != nil { return err } validCertAliases := filterValidAliases(domains, aliases) mux.Lock() defer mux.Unlock() domainsOfCert[cert] = domains for _, alias := range validCertAliases { validAliases[alias] = true } return nil }) } if err := g.Wait(); err != nil { return err } for _, alias := range aliases { if !validAliases[alias] { return &errInValidAliasAgainstCert{ certs: certs, alias: alias, domainsOfCert: domainsOfCert, } } } return nil } func (a *ACM) validDomainsOfCert(ctx context.Context, cert string) ([]string, error) { resp, err := a.client.DescribeCertificateWithContext(ctx, &acm.DescribeCertificateInput{ CertificateArn: aws.String(cert), }) if err != nil { return nil, fmt.Errorf("describe certificate %s: %w", cert, err) } var domainsOfCert []*string domainsOfCert = append(domainsOfCert, resp.Certificate.SubjectAlternativeNames...) return aws.StringValueSlice(domainsOfCert), err } func filterValidAliases(domains []string, aliases []string) []string { domainSet := make(map[string]bool, len(domains)) for _, v := range domains { domainSet[v] = true } var validAliases []string for _, alias := range aliases { // See https://docs.aws.amazon.com/acm/latest/userguide/acm-certificate.html wildCardMatchedAlias := "*" + alias[strings.Index(alias, "."):] if domainSet[alias] || domainSet[wildCardMatchedAlias] { validAliases = append(validAliases, alias) } } return validAliases } type errInValidAliasAgainstCert struct { certs []string alias string domainsOfCert map[string][]string } func (e *errInValidAliasAgainstCert) Error() string { return fmt.Sprintf("%s is not a valid domain against %s", e.alias, strings.Join(e.certs, ",")) } func (e *errInValidAliasAgainstCert) RecommendActions() string { var logMsg string logMsg = fmt.Sprintf("Please use aliases that are protected by %s your imported:\n", english.Plural(len(e.certs), "certificate", "")) for cert, sans := range e.domainsOfCert { logMsg += fmt.Sprintf("%q: %s\n", cert, english.WordSeries(sans, ",")) } return logMsg }