//go:build example // +build example package rdsutils_test import ( "database/sql" "flag" "fmt" "net/url" "os" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/rds/rdsutils" ) // ExampleConnectionStringBuilder contains usage of assuming a role and using // that to build the auth token. // // For MySQL you many need to configure your MySQL driver with the TLS // certificate to ensure verified TLS handshake, // https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem // // Usage: // // ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306 func ExampleConnectionStringBuilder() { userPtr := flag.String("user", "", "user of the credentials") regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds") roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds") endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to") portPtr := flag.Int("port", 3306, "DB port to be connected to") tablePtr := flag.String("table", "test_table", "DB table to query against") dbNamePtr := flag.String("dbname", "", "DB name to query against") flag.Parse() // Check required flags. Will exit with status code 1 if // required field isn't set. if err := requiredFlags( userPtr, regionPtr, roleArnPtr, endpointPtr, portPtr, dbNamePtr, ); err != nil { fmt.Printf("Error: %v\n\n", err) flag.PrintDefaults() os.Exit(1) } sess := session.Must(session.NewSession()) creds := stscreds.NewCredentials(sess, *roleArnPtr) v := url.Values{} // required fields for DB connection v.Add("tls", "rds") v.Add("allowCleartextPasswords", "true") endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr) b := rdsutils.NewConnectionStringBuilder(endpoint, *regionPtr, *userPtr, *dbNamePtr, creds) connectStr, err := b.WithTCPFormat().WithParams(v).Build() const dbType = "mysql" db, err := sql.Open(dbType, connectStr) // if an error is encountered here, then most likely security groups are incorrect // in the database. if err != nil { panic(fmt.Errorf("failed to open connection to the database")) } rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr)) if err != nil { panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err)) } for rows.Next() { columns, err := rows.Columns() if err != nil { panic(fmt.Errorf("failed to read columns from row: %v", err)) } fmt.Printf("rows colums:\n%d\n", len(columns)) } } func requiredFlags(flags ...interface{}) error { for _, f := range flags { switch f.(type) { case nil: return fmt.Errorf("one or more required flags were not set") } } return nil }