//go:build !func_test // Package aws contains functionality that wraps the AWS SDK package aws import ( "context" "errors" "fmt" "os" "time" "github.com/aws-cloudformation/rain/internal/config" "github.com/aws-cloudformation/rain/internal/console" "github.com/aws-cloudformation/rain/internal/console/spinner" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/middleware" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" smithymiddleware "github.com/aws/smithy-go/middleware" ) // MFAProvider is called by the AWS SDK when an MFA token number // is required during authentication func MFAProvider() (string, error) { spinner.Pause() defer func() { fmt.Println() spinner.Resume() }() return console.Ask("MFA Token:"), nil } var awsCfg *aws.Config var creds aws.Credentials var defaultSessionName = fmt.Sprintf("%s-%s", config.NAME, config.VERSION) var lastSessionName = defaultSessionName func loadConfig(ctx context.Context, sessionName string) *aws.Config { // Credential configs var configs = make([]func(*awsconfig.LoadOptions) error, 0) // Uncomment for testing against a local endpoint //configs = append(configs, awsconfig.WithEndpointResolver(aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { // return aws.Endpoint{URL: "http://localhost:8000"}, nil //}))) // Add user-agent configs = append(configs, awsconfig.WithAPIOptions( []func(*smithymiddleware.Stack) error{ middleware.AddUserAgentKeyValue(config.NAME, config.VERSION), middleware.AddSDKAgentKeyValue(middleware.ApplicationIdentifier, config.NAME, config.VERSION), }, )) // Add MFA provider and Rain session name configs = append(configs, awsconfig.WithAssumeRoleCredentialOptions(func(options *stscreds.AssumeRoleOptions) { options.RoleSessionName = sessionName options.TokenProvider = MFAProvider })) // Supplied profile if config.Profile != "" { configs = append(configs, awsconfig.WithSharedConfigProfile(config.Profile)) } else if p := os.Getenv("AWS_PROFILE"); p != "" { config.Profile = p } // Supplied region if config.Region != "" { configs = append(configs, awsconfig.WithRegion(config.Region)) } else if r := os.Getenv("AWS_DEFAULT_REGION"); r != "" { config.Region = r } cfg, err := awsconfig.LoadDefaultConfig(context.Background(), configs...) if err != nil { panic(errors.New("unable to find valid credentials")) } if cfg.Region == "" { panic(errors.New("a region was not specified. You can run 'aws configure' or choose a profile with a region")) } // Check for validity creds, err = cfg.Credentials.Retrieve(context.Background()) if err != nil { config.Debugf("Error retreiving creds: %s", err.Error()) panic(errors.New("could not establish AWS credentials; please run 'aws configure' or choose a profile")) } return &cfg } // Config loads an aws.Config based on current settings func Config() aws.Config { return NamedConfig(defaultSessionName) } // NamedConfig loads an aws.Config based on current settings // with configurable session name func NamedConfig(sessionName string) aws.Config { message := "Loading AWS config" if creds.CanExpire && time.Until(creds.Expires) < time.Minute { // Check for expiry message = "Refreshing AWS credentials" awsCfg = nil } else if lastSessionName != sessionName { message = "Reloading AWS credentials" awsCfg = nil } if awsCfg == nil { spinner.Push(message) awsCfg = loadConfig(context.Background(), sessionName) spinner.Pop() } return *awsCfg } // SetRegion is used to set the current AWS region func SetRegion(region string) { awsCfg.Region = region }