//go:build codegen // +build codegen package api import ( "bytes" "encoding/json" "fmt" "os" "sort" "text/template" ) // SmokeTestSuite defines the test suite for smoke tests. type SmokeTestSuite struct { Version int `json:"version"` DefaultRegion string `json:"defaultRegion"` TestCases []SmokeTestCase `json:"testCases"` } // SmokeTestCase provides the definition for a integration smoke test case. type SmokeTestCase struct { OpName string `json:"operationName"` Input map[string]interface{} `json:"input"` ExpectErr bool `json:"errorExpectedFromService"` } var smokeTestsCustomizations = map[string]func(*SmokeTestSuite) error{ "sts": stsSmokeTestCustomization, "waf": wafSmokeTestCustomization, "wafregional": wafRegionalSmokeTestCustomization, "iotdataplane": iotDataPlaneSmokeTestCustomization, } func iotDataPlaneSmokeTestCustomization(suite *SmokeTestSuite) error { suite.TestCases = []SmokeTestCase{} return nil } func wafSmokeTestCustomization(suite *SmokeTestSuite) error { return filterWAFCreateSqlInjectionMatchSet(suite) } func wafRegionalSmokeTestCustomization(suite *SmokeTestSuite) error { return filterWAFCreateSqlInjectionMatchSet(suite) } func filterWAFCreateSqlInjectionMatchSet(suite *SmokeTestSuite) error { const createSqlInjectionMatchSetOp = "CreateSqlInjectionMatchSet" var testCases []SmokeTestCase for _, testCase := range suite.TestCases { if testCase.OpName == createSqlInjectionMatchSetOp { continue } testCases = append(testCases, testCase) } suite.TestCases = testCases return nil } func stsSmokeTestCustomization(suite *SmokeTestSuite) error { const getSessionTokenOp = "GetSessionToken" const getCallerIdentityOp = "GetCallerIdentity" opTestMap := make(map[string][]SmokeTestCase) for _, testCase := range suite.TestCases { opTestMap[testCase.OpName] = append(opTestMap[testCase.OpName], testCase) } if _, ok := opTestMap[getSessionTokenOp]; ok { delete(opTestMap, getSessionTokenOp) } if _, ok := opTestMap[getCallerIdentityOp]; !ok { opTestMap[getCallerIdentityOp] = append(opTestMap[getCallerIdentityOp], SmokeTestCase{ OpName: getCallerIdentityOp, Input: map[string]interface{}{}, ExpectErr: false, }) } var testCases []SmokeTestCase var keys []string for name := range opTestMap { keys = append(keys, name) } sort.Strings(keys) for _, name := range keys { testCases = append(testCases, opTestMap[name]...) } suite.TestCases = testCases return nil } // BuildInputShape returns the Go code as a string for initializing the test // case's input shape. func (c SmokeTestCase) BuildInputShape(ref *ShapeRef) string { b := NewShapeValueBuilder() return fmt.Sprintf("&%s{\n%s\n}", b.GoType(ref, true), b.BuildShape(ref, c.Input, false), ) } // AttachSmokeTests attaches the smoke test cases to the API model. func (a *API) AttachSmokeTests(filename string) error { f, err := os.Open(filename) if err != nil { return fmt.Errorf("failed to open smoke tests %s, err: %v", filename, err) } defer f.Close() if err := json.NewDecoder(f).Decode(&a.SmokeTests); err != nil { return fmt.Errorf("failed to decode smoke tests %s, err: %v", filename, err) } if v := a.SmokeTests.Version; v != 1 { return fmt.Errorf("invalid smoke test version, %d", v) } if fn, ok := smokeTestsCustomizations[a.PackageName()]; ok { if err := fn(&a.SmokeTests); err != nil { return err } } return nil } // APISmokeTestsGoCode returns the Go Code string for the smoke tests. func (a *API) APISmokeTestsGoCode() string { w := bytes.NewBuffer(nil) a.resetImports() a.AddImport("context") a.AddImport("testing") a.AddImport("time") a.AddSDKImport("aws") a.AddSDKImport("aws/request") a.AddSDKImport("aws/awserr") a.AddSDKImport("aws/request") a.AddSDKImport("awstesting/integration") a.AddImport(a.ImportPath()) smokeTests := struct { API *API SmokeTestSuite }{ API: a, SmokeTestSuite: a.SmokeTests, } if err := smokeTestTmpl.Execute(w, smokeTests); err != nil { panic(fmt.Sprintf("failed to create smoke tests, %v", err)) } ignoreImports := ` var _ aws.Config var _ awserr.Error var _ request.Request ` return a.importsGoCode() + ignoreImports + w.String() } var smokeTestTmpl = template.Must(template.New(`smokeTestTmpl`).Parse(` {{- range $i, $testCase := $.TestCases }} {{- $op := index $.API.Operations $testCase.OpName }} {{- if $op }} func TestInteg_{{ printf "%02d" $i }}_{{ $op.ExportedName }}(t *testing.T) { ctx, cancelFn := context.WithTimeout(context.Background(), 5 *time.Second) defer cancelFn() sess := integration.SessionWithDefaultRegion("{{ $.DefaultRegion }}") svc := {{ $.API.PackageName }}.New(sess) params := {{ $testCase.BuildInputShape $op.InputRef }} _, err := svc.{{ $op.ExportedName }}WithContext(ctx, params, func(r *request.Request) { r.Handlers.Validate.RemoveByName("core.ValidateParametersHandler") }) {{- if $testCase.ExpectErr }} if err == nil { t.Fatalf("expect request to fail") } aerr, ok := err.(awserr.RequestFailure) if !ok { t.Fatalf("expect awserr, was %T", err) } if len(aerr.Code()) == 0 { t.Errorf("expect non-empty error code") } if len(aerr.Message()) == 0 { t.Errorf("expect non-empty error message") } if v := aerr.Code(); v == request.ErrCodeSerialization { t.Errorf("expect API error code got serialization failure") } {{- else }} if err != nil { t.Errorf("expect no error, got %v", err) } {{- end }} } {{- end }} {{- end }} `))