//go:build go1.7 // +build go1.7 package ec2metadata_test import ( "bytes" "fmt" "io" "io/ioutil" "net/http" "net/http/httptest" "path" "reflect" "strings" "sync" "sync/atomic" "testing" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/awstesting/unit" "github.com/aws/aws-sdk-go/internal/sdktesting" ) const instanceIdentityDocument = `{ "devpayProductCodes" : null, "marketplaceProductCodes" : [ "1abc2defghijklm3nopqrs4tu" ], "availabilityZone" : "us-east-1d", "privateIp" : "10.158.112.84", "version" : "2010-08-31", "region" : "us-east-1", "instanceId" : "i-1234567890abcdef0", "billingProducts" : null, "instanceType" : "t1.micro", "accountId" : "123456789012", "pendingTime" : "2015-11-19T16:32:11Z", "imageId" : "ami-5fb8c835", "kernelId" : "aki-919dcaf8", "ramdiskId" : null, "architecture" : "x86_64" }` const validIamInfo = `{ "Code" : "Success", "LastUpdated" : "2016-03-17T12:27:32Z", "InstanceProfileArn" : "arn:aws:iam::123456789012:instance-profile/my-instance-profile", "InstanceProfileId" : "AIPAABCDEFGHIJKLMN123" }` const unsuccessfulIamInfo = `{ "Code" : "Failed", "LastUpdated" : "2016-03-17T12:27:32Z", "InstanceProfileArn" : "arn:aws:iam::123456789012:instance-profile/my-instance-profile", "InstanceProfileId" : "AIPAABCDEFGHIJKLMN123" }` const ( ttlHeader = "x-aws-ec2-metadata-token-ttl-seconds" tokenHeader = "x-aws-ec2-metadata-token" ) type testType int const ( SecureTestType testType = iota InsecureTestType BadRequestTestType NotFoundRequestTestType InvalidTokenRequestTestType ServerErrorForTokenTestType PageNotFoundForTokenTestType PageNotFoundWith401TestType ThrottleErrorForTokenNoFallbackTestType ) type testServer struct { t *testing.T tokens []string activeToken atomic.Value data string } type operationListProvider struct { operationsPerformed []string } func getTokenRequiredParams(t *testing.T, fn http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if e, a := "PUT", r.Method; e != a { t.Errorf("expect %v, http method got %v", e, a) http.Error(w, "wrong method", 400) return } if len(r.Header.Get(ttlHeader)) == 0 { t.Errorf("expect ttl header to be present in the request headers, got none") http.Error(w, "wrong method", 400) return } fn(w, r) } } func newTestServer(t *testing.T, testType testType, testServer *testServer) *httptest.Server { mux := http.NewServeMux() switch testType { case SecureTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.secureGetTokenHandler)) mux.HandleFunc("/", testServer.secureGetLatestHandler) case InsecureTestType: mux.HandleFunc("/latest/api/token", testServer.insecureGetTokenHandler) mux.HandleFunc("/", testServer.insecureGetLatestHandler) case BadRequestTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.badRequestGetTokenHandler)) mux.HandleFunc("/", testServer.badRequestGetLatestHandler) case NotFoundRequestTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.secureGetTokenHandler)) mux.HandleFunc("/", testServer.notFoundRequestGetLatestHandler) case InvalidTokenRequestTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.secureGetTokenHandler)) mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler) case ServerErrorForTokenTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.serverErrorGetTokenHandler)) mux.HandleFunc("/", testServer.insecureGetLatestHandler) case PageNotFoundForTokenTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler)) mux.HandleFunc("/", testServer.insecureGetLatestHandler) case PageNotFoundWith401TestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler)) mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler) case ThrottleErrorForTokenNoFallbackTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.throtleErrorGetTokenHandler)) mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler) } return httptest.NewServer(mux) } func (s *testServer) secureGetTokenHandler(w http.ResponseWriter, r *http.Request) { token := s.tokens[0] // set the active token s.activeToken.Store(token) // rotate the token if len(s.tokens) > 1 { s.tokens = s.tokens[1:] } // set the header and response body w.Header().Set(ttlHeader, r.Header.Get(ttlHeader)) if activeToken, ok := s.activeToken.Load().(string); ok { w.Write([]byte(activeToken)) } else { s.t.Fatalf("Expected activeToken to be of type string, got %v", activeToken) } } func (s *testServer) secureGetLatestHandler(w http.ResponseWriter, r *http.Request) { if s.activeToken.Load() == nil { s.t.Errorf("expect token to have been requested, was not") http.Error(w, "", 401) return } if e, a := s.activeToken.Load(), r.Header.Get(tokenHeader); e != a { s.t.Errorf("expect %v token, got %v", e, a) http.Error(w, "", 401) return } w.Header().Set(ttlHeader, r.Header.Get(ttlHeader)) w.Write([]byte(s.data)) } func (s *testServer) insecureGetTokenHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "", 404) } func (s *testServer) insecureGetLatestHandler(w http.ResponseWriter, r *http.Request) { if len(r.Header.Get(tokenHeader)) != 0 { s.t.Errorf("Request token found, expected none") http.Error(w, "", 400) return } w.Write([]byte(s.data)) } func (s *testServer) badRequestGetTokenHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "", 400) } func (s *testServer) badRequestGetLatestHandler(w http.ResponseWriter, r *http.Request) { s.t.Errorf("Expected no call to this handler, incorrect behavior found") } func (s *testServer) notFoundRequestGetLatestHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "not found error", 404) } func (s *testServer) serverErrorGetTokenHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "", 403) } func (s *testServer) pageNotFoundGetTokenHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Page not found error", 404) } func (s *testServer) unauthorizedGetLatestHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "", 401) } func (s *testServer) throtleErrorGetTokenHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "", 429) } func (opListProvider *operationListProvider) addToOperationPerformedList(r *request.Request) { opListProvider.operationsPerformed = append(opListProvider.operationsPerformed, r.Operation.Name) } func TestEndpoint(t *testing.T) { restoreEnvFn := sdktesting.StashEnv() defer restoreEnvFn() c := ec2metadata.New(unit.Session) op := &request.Operation{ Name: "GetMetadata", HTTPMethod: "GET", HTTPPath: path.Join("/latest", "meta-data", "testpath"), } req := c.NewRequest(op, nil, nil) if e, a := "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String(); e != a { t.Errorf("expect %v, got %v", e, a) } } func TestGetMetadata(t *testing.T) { cases := map[string]struct { tokens []string NewServer func(t *testing.T, tokens []string) *httptest.Server expectedData string expectedError string expectedOperationsAttempted []string enableImdsFallback *bool }{ "Insecure server success case": { NewServer: func(t *testing.T, tokens []string) *httptest.Server { testType := InsecureTestType Ts := &testServer{ t: t, tokens: tokens, data: "IMDSProfileForGoSDK", } return newTestServer(t, testType, Ts) }, expectedData: "IMDSProfileForGoSDK", expectedOperationsAttempted: []string{"GetToken", "GetMetadata", "GetMetadata"}, }, "Secure server success case": { tokens: []string{"firstToken", "secondToken", "thirdToken"}, NewServer: func(t *testing.T, tokens []string) *httptest.Server { testType := SecureTestType Ts := &testServer{ t: t, tokens: tokens, data: "IMDSProfileForGoSDK", } return newTestServer(t, testType, Ts) }, expectedData: "IMDSProfileForGoSDK", expectedError: "", expectedOperationsAttempted: []string{"GetToken", "GetMetadata", "GetMetadata"}, }, "Bad token request case": { tokens: []string{"firstToken", "secondToken", "thirdToken"}, NewServer: func(t *testing.T, tokens []string) *httptest.Server { testType := BadRequestTestType Ts := &testServer{ t: t, tokens: tokens, data: "IMDSProfileForGoSDK", } return newTestServer(t, testType, Ts) }, expectedError: "400", expectedOperationsAttempted: []string{"GetToken", "GetToken"}, }, "Not found no retry request case": { tokens: []string{"firstToken", "secondToken", "thirdToken"}, NewServer: func(t *testing.T, tokens []string) *httptest.Server { testType := NotFoundRequestTestType Ts := &testServer{ t: t, tokens: tokens, data: "IMDSProfileForGoSDK", } return newTestServer(t, testType, Ts) }, expectedError: "404", expectedOperationsAttempted: []string{"GetToken", "GetMetadata", "GetMetadata"}, }, "invalid token request case": { tokens: []string{"firstToken", "secondToken", "thirdToken"}, NewServer: func(t *testing.T, tokens []string) *httptest.Server { testType := InvalidTokenRequestTestType Ts := &testServer{ t: t, tokens: tokens, data: "IMDSProfileForGoSDK", } return newTestServer(t, testType, Ts) }, expectedError: "401", expectedOperationsAttempted: []string{"GetToken", "GetMetadata", "GetToken", "GetMetadata"}, }, "ServerErrorForTokenTestType": { NewServer: func(t *testing.T, tokens []string) *httptest.Server { testType := ServerErrorForTokenTestType Ts := &testServer{ t: t, tokens: []string{}, data: "IMDSProfileForGoSDK", } return newTestServer(t, testType, Ts) }, expectedData: "IMDSProfileForGoSDK", expectedOperationsAttempted: []string{"GetToken", "GetMetadata", "GetMetadata"}, }, "No fallback to IMDSv1": { NewServer: func(t *testing.T, tokens []string) *httptest.Server { testType := ThrottleErrorForTokenNoFallbackTestType Ts := &testServer{ t: t, tokens: []string{}, data: "IMDSProfileForGoSDK", } return newTestServer(t, testType, Ts) }, expectedError: "failed to get IMDSv2 token and fallback to IMDSv1 is disabled", // 2 attempts + 2 retries per/attempt expectedOperationsAttempted: []string{"GetToken", "GetToken", "GetToken", "GetToken", "GetToken", "GetToken"}, enableImdsFallback: aws.Bool(false), }, } for name, x := range cases { t.Run(name, func(t *testing.T) { server := x.NewServer(t, x.tokens) defer server.Close() op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), EC2MetadataEnableFallback: x.enableImdsFallback, }) c.Handlers.CompleteAttempt.PushBack(op.addToOperationPerformedList) tokenCounter := -1 c.Handlers.Send.PushBack(func(r *request.Request) { switch r.Operation.Name { case "GetToken": tokenCounter++ case "GetMetadata": curToken := r.HTTPRequest.Header.Get("x-aws-ec2-metadata-token") if len(curToken) != 0 && curToken != x.tokens[tokenCounter] { t.Errorf("expect %v token, got %v", x.tokens[tokenCounter], curToken) } } }) resp, err := c.GetMetadata("some/path") // token should stay alive, since default duration is 26000 seconds resp, err = c.GetMetadata("some/path") if len(x.expectedError) != 0 { if err == nil { t.Fatalf("expect %v error, got none", x.expectedError) } if e, a := x.expectedError, err.Error(); !strings.Contains(a, e) { t.Fatalf("expect %v error, got %v", e, a) } } else if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := x.expectedData, resp; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := x.expectedOperationsAttempted, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Errorf("expect %v operations, got %v", e, a) } }) } } func TestGetUserData_Error(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reader := strings.NewReader(` 404 - Not Found

404 - Not Found

`) w.Header().Set("Content-Type", "text/html") w.Header().Set("Content-Length", fmt.Sprintf("%d", reader.Len())) w.WriteHeader(http.StatusNotFound) io.Copy(w, reader) })) defer server.Close() c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) resp, err := c.GetUserData() if err == nil { t.Fatalf("expect error") } if len(resp) != 0 { t.Fatalf("expect empty, got %v", resp) } if requestFailedError, ok := err.(awserr.RequestFailure); ok { if e, a := http.StatusNotFound, requestFailedError.StatusCode(); e != a { t.Fatalf("expect %v, got %v", e, a) } } } func TestGetRegion(t *testing.T) { cases := map[string]struct { NewServer func(t *testing.T) *httptest.Server expectedData string expectedError string expectedOperationsPerformed []string }{ "Insecure server success case": { NewServer: func(t *testing.T) *httptest.Server { testType := InsecureTestType Ts := &testServer{ t: t, data: instanceIdentityDocument, } return newTestServer(t, testType, Ts) }, expectedData: "us-east-1", expectedOperationsPerformed: []string{"GetToken", "GetDynamicData"}, }, "Secure server success case": { NewServer: func(t *testing.T) *httptest.Server { testType := SecureTestType Ts := &testServer{ t: t, tokens: []string{"firstToken", "secondToken", "thirdToken"}, data: instanceIdentityDocument, } return newTestServer(t, testType, Ts) }, expectedData: "us-east-1", expectedOperationsPerformed: []string{"GetToken", "GetDynamicData"}, }, "Bad request case": { NewServer: func(t *testing.T) *httptest.Server { testType := BadRequestTestType Ts := &testServer{ t: t, tokens: []string{"firstToken", "secondToken", "thirdToken"}, data: instanceIdentityDocument, } return newTestServer(t, testType, Ts) }, expectedError: "400", expectedOperationsPerformed: []string{"GetToken", "GetDynamicData"}, }, "ServerErrorForTokenTestType": { NewServer: func(t *testing.T) *httptest.Server { testType := ServerErrorForTokenTestType Ts := &testServer{ t: t, tokens: []string{}, data: instanceIdentityDocument, } return newTestServer(t, testType, Ts) }, expectedData: "us-east-1", expectedOperationsPerformed: []string{"GetToken", "GetDynamicData"}, }, } for name, x := range cases { t.Run(name, func(t *testing.T) { server := x.NewServer(t) defer server.Close() op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.Region() if len(x.expectedError) != 0 { if err == nil { t.Fatalf("expect %v error, got none", x.expectedError) } if e, a := x.expectedError, err.Error(); !strings.Contains(a, e) { t.Fatalf("expect %v error, got %v", e, a) } } else if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := x.expectedData, resp; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := x.expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } }) } } func TestMetadataIAMInfo_success(t *testing.T) { cases := map[string]struct { NewServer func(t *testing.T) *httptest.Server expectedData string expectedError string expectedOperationsPerformed []string }{ "Insecure server success case": { NewServer: func(t *testing.T) *httptest.Server { testType := InsecureTestType Ts := &testServer{ t: t, data: validIamInfo, } return newTestServer(t, testType, Ts) }, expectedData: validIamInfo, expectedOperationsPerformed: []string{"GetToken", "GetMetadata"}, }, "Secure server success case": { NewServer: func(t *testing.T) *httptest.Server { testType := SecureTestType Ts := &testServer{ t: t, tokens: []string{"firstToken", "secondToken", "thirdToken"}, data: validIamInfo, } return newTestServer(t, testType, Ts) }, expectedData: validIamInfo, expectedOperationsPerformed: []string{"GetToken", "GetMetadata"}, }, } for name, x := range cases { t.Run(name, func(t *testing.T) { server := x.NewServer(t) defer server.Close() op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) iamInfo, err := c.IAMInfo() if len(x.expectedError) != 0 { if err == nil { t.Fatalf("expect %v error, got none", x.expectedError) } if e, a := x.expectedError, err.Error(); !strings.Contains(a, e) { t.Fatalf("expect %v error, got %v", e, a) } } else if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := "Success", iamInfo.Code; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := "arn:aws:iam::123456789012:instance-profile/my-instance-profile", iamInfo.InstanceProfileArn; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := "AIPAABCDEFGHIJKLMN123", iamInfo.InstanceProfileID; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := x.expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } }) } } func TestMetadataIAMInfo_failure(t *testing.T) { cases := map[string]struct { NewServer func(t *testing.T) *httptest.Server expectedData string expectedError string expectedOperationsPerformed []string }{ "Insecure server success case": { NewServer: func(t *testing.T) *httptest.Server { testType := InsecureTestType Ts := &testServer{ t: t, tokens: nil, data: unsuccessfulIamInfo, } return newTestServer(t, testType, Ts) }, expectedData: unsuccessfulIamInfo, expectedOperationsPerformed: []string{"GetToken", "GetMetadata"}, }, "Secure server success case": { NewServer: func(t *testing.T) *httptest.Server { testType := SecureTestType Ts := &testServer{ t: t, tokens: []string{"firstToken", "secondToken", "thirdToken"}, data: unsuccessfulIamInfo, } return newTestServer(t, testType, Ts) }, expectedData: unsuccessfulIamInfo, expectedOperationsPerformed: []string{"GetToken", "GetMetadata"}, }, } for name, x := range cases { t.Run(name, func(t *testing.T) { server := x.NewServer(t) defer server.Close() op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) iamInfo, err := c.IAMInfo() if err == nil { t.Fatalf("expect error") } if e, a := "", iamInfo.Code; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := "", iamInfo.InstanceProfileArn; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := "", iamInfo.InstanceProfileID; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := x.expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } }) } } func TestMetadataNotAvailable(t *testing.T) { c := ec2metadata.New(unit.Session) c.Handlers.Send.Clear() c.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &http.Response{ StatusCode: int(0), Status: http.StatusText(int(0)), Body: ioutil.NopCloser(bytes.NewReader([]byte{})), } r.Error = awserr.New(request.ErrCodeRequestError, "send request failed", nil) r.Retryable = aws.Bool(true) // network errors are retryable }) if c.Available() { t.Fatalf("expect not available") } } func TestMetadataErrorResponse(t *testing.T) { c := ec2metadata.New(unit.Session) c.Handlers.Send.Clear() c.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &http.Response{ StatusCode: http.StatusBadRequest, Status: http.StatusText(http.StatusBadRequest), Body: ioutil.NopCloser(strings.NewReader("error message text")), } r.Retryable = aws.Bool(false) // network errors are retryable }) data, err := c.GetMetadata("uri/path") if e, a := "error message text", err.Error(); !strings.Contains(a, e) { t.Fatalf("expect %v to be in %v", e, a) } if len(data) != 0 { t.Fatalf("expect empty, got %v", data) } } func TestEC2RoleProviderInstanceIdentity(t *testing.T) { cases := map[string]struct { NewServer func(t *testing.T) *httptest.Server expectedData string expectedOperationsPerformed []string }{ "Insecure server success case": { NewServer: func(t *testing.T) *httptest.Server { testType := InsecureTestType Ts := &testServer{ t: t, tokens: nil, data: instanceIdentityDocument, } return newTestServer(t, testType, Ts) }, expectedData: instanceIdentityDocument, expectedOperationsPerformed: []string{"GetToken", "GetDynamicData"}, }, "Secure server success case": { NewServer: func(t *testing.T) *httptest.Server { testType := SecureTestType Ts := &testServer{ t: t, tokens: []string{"firstToken", "secondToken", "thirdToken"}, data: instanceIdentityDocument, } return newTestServer(t, testType, Ts) }, expectedData: instanceIdentityDocument, expectedOperationsPerformed: []string{"GetToken", "GetDynamicData"}, }, } for name, x := range cases { t.Run(name, func(t *testing.T) { server := x.NewServer(t) defer server.Close() op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) doc, err := c.GetInstanceIdentityDocument() if err != nil { t.Fatalf("expected no error, got %v", err) } if e, a := doc.AccountID, "123456789012"; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := doc.AvailabilityZone, "us-east-1d"; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := doc.Region, "us-east-1"; e != a { t.Fatalf("expect %v, got %v", e, a) } if e, a := x.expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } }) } } func TestEC2MetadataRetryFailure(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/latest/api/token", func(w http.ResponseWriter, r *http.Request) { if r.Method == "PUT" && r.Header.Get(ttlHeader) != "" { w.Header().Set(ttlHeader, "200") http.Error(w, "service unavailable", http.StatusServiceUnavailable) return } http.Error(w, "bad request", http.StatusBadRequest) }) // meta-data endpoint for this test, just returns the token mux.HandleFunc("/latest/meta-data/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("profile_name")) }) server := httptest.NewServer(mux) defer server.Close() c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) c.Handlers.AfterRetry.PushBack(func(i *request.Request) { t.Logf("%v received, retrying operation %v", i.HTTPResponse.StatusCode, i.Operation.Name) }) c.Handlers.Complete.PushBack(func(i *request.Request) { t.Logf("%v operation exited with status %v", i.Operation.Name, i.HTTPResponse.StatusCode) }) resp, err := c.GetMetadata("some/path") if err != nil { t.Fatalf("Expected none, got error %v", err) } if resp != "profile_name" { t.Fatalf("Expected response to be profile_name, got %v", resp) } resp, err = c.GetMetadata("some/path") if err != nil { t.Fatalf("Expected none, got error %v", err) } if resp != "profile_name" { t.Fatalf("Expected response to be profile_name, got %v", resp) } } func TestEC2MetadataRetryOnce(t *testing.T) { var secureDataFlow bool var retry = true mux := http.NewServeMux() mux.HandleFunc("/latest/api/token", func(w http.ResponseWriter, r *http.Request) { if r.Method == "PUT" && r.Header.Get(ttlHeader) != "" { w.Header().Set(ttlHeader, "200") for retry { retry = false http.Error(w, "service unavailable", http.StatusServiceUnavailable) return } w.Write([]byte("token")) secureDataFlow = true return } http.Error(w, "bad request", http.StatusBadRequest) }) // meta-data endpoint for this test, just returns the token mux.HandleFunc("/latest/meta-data/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(r.Header.Get(tokenHeader))) }) var tokenRetryCount int server := httptest.NewServer(mux) defer server.Close() c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) // Handler on client that logs if retried c.Handlers.AfterRetry.PushBack(func(i *request.Request) { t.Logf("%v received, retrying operation %v", i.HTTPResponse.StatusCode, i.Operation.Name) tokenRetryCount++ }) _, err := c.GetMetadata("some/path") if tokenRetryCount != 1 { t.Fatalf("Expected number of retries for fetching token to be 1, got %v", tokenRetryCount) } if !secureDataFlow { t.Fatalf("Expected secure data flow to be %v, got %v", secureDataFlow, !secureDataFlow) } if err != nil { t.Fatalf("Expected none, got error %v", err) } } func TestEC2Metadata_Concurrency(t *testing.T) { ts := &testServer{ t: t, tokens: []string{"firstToken"}, data: "IMDSProfileForSDKGo", } server := newTestServer(t, SecureTestType, ts) defer server.Close() c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) var wg sync.WaitGroup wg.Add(10) for i := 0; i < 10; i++ { go func() { defer wg.Done() for j := 0; j < 10; j++ { resp, err := c.GetMetadata("some/data") if err != nil { t.Errorf("expect no error, got %v", err) } if e, a := "IMDSProfileForSDKGo", resp; e != a { t.Errorf("expect %v, got %v", e, a) } } }() } wg.Wait() } func TestRequestOnMetadata(t *testing.T) { ts := &testServer{ t: t, tokens: []string{"firstToken", "secondToken"}, data: "profile_name", } server := newTestServer(t, SecureTestType, ts) defer server.Close() c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) req := c.NewRequest(&request.Operation{ Name: "Ec2Metadata request", HTTPMethod: "GET", HTTPPath: "/latest/foo", Paginator: nil, BeforePresignFn: nil, }, nil, nil) op := &operationListProvider{} c.Handlers.Complete.PushBack(op.addToOperationPerformedList) err := req.Send() if err != nil { t.Fatalf("expect no error, got %v", err) } if len(op.operationsPerformed) < 1 { t.Fatalf("Expected atleast one operation GetToken to be called on EC2Metadata client") return } if op.operationsPerformed[0] != "GetToken" { t.Fatalf("Expected GetToken operation to be called") } } func TestExhaustiveRetryToFetchToken(t *testing.T) { ts := &testServer{ t: t, tokens: []string{"firstToken", "secondToken"}, data: "IMDSProfileForSDKGo", } server := newTestServer(t, PageNotFoundForTokenTestType, ts) defer server.Close() op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.GetMetadata("/some/path") if err != nil { t.Fatalf("Expected no error, got %v", err) } if e, a := "IMDSProfileForSDKGo", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } resp, err = c.GetMetadata("/some/path") if err != nil { t.Fatalf("Expected no error, got %v", err) } if e, a := "IMDSProfileForSDKGo", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } resp, err = c.GetMetadata("/some/path") if err != nil { t.Fatalf("Expected no error, got %v", err) } if e, a := "IMDSProfileForSDKGo", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } resp, err = c.GetMetadata("/some/path") expectedOperationsPerformed := []string{"GetToken", "GetMetadata", "GetMetadata", "GetMetadata", "GetMetadata"} if err != nil { t.Fatalf("Expected no error, got %v", err) } if e, a := "IMDSProfileForSDKGo", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } if e, a := expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } } func TestExhaustiveRetryWith401(t *testing.T) { ts := &testServer{ t: t, tokens: []string{"firstToken", "secondToken"}, data: "IMDSProfileForSDKGo", } server := newTestServer(t, PageNotFoundWith401TestType, ts) defer server.Close() op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.GetMetadata("/some/path") if err == nil { t.Fatalf("Expected %v error, got none", err) } if e, a := "", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } resp, err = c.GetMetadata("/some/path") if err == nil { t.Fatalf("Expected %v error, got none", err) } if e, a := "", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } resp, err = c.GetMetadata("/some/path") if err == nil { t.Fatalf("Expected %v error, got none", err) } if e, a := "", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } resp, err = c.GetMetadata("/some/path") expectedOperationsPerformed := []string{"GetToken", "GetMetadata", "GetToken", "GetMetadata", "GetToken", "GetMetadata", "GetToken", "GetMetadata"} if err == nil { t.Fatalf("Expected %v error, got none", err) } if e, a := "", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } if e, a := expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } } func TestRequestTimeOut(t *testing.T) { mux := http.NewServeMux() done := make(chan bool) mux.HandleFunc("/latest/api/token", func(w http.ResponseWriter, r *http.Request) { // wait to read from channel done <-done }) mux.HandleFunc("/latest/", func(w http.ResponseWriter, r *http.Request) { if len(r.Header.Get(tokenHeader)) != 0 { http.Error(w, "", 400) return } w.Write([]byte("IMDSProfileForSDKGo")) }) server := httptest.NewServer(mux) defer server.Close() defer close(done) op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) // for test, change the timeout to 100 ms c.Config.HTTPClient.Timeout = 100 * time.Millisecond c.Handlers.Complete.PushBack(op.addToOperationPerformedList) start := time.Now() resp, err := c.GetMetadata("/some/path") if e, a := 1*time.Second, time.Since(start); e < a { t.Fatalf("expected duration of test to be less than %v, got %v", e, a) } if e, a := "IMDSProfileForSDKGo", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } if err != nil { t.Fatalf("Expected no error, got %v", err) } expectedOperationsPerformed := []string{"GetToken", "GetMetadata"} if e, a := expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } start = time.Now() resp, err = c.GetMetadata("/some/path") if e, a := 1*time.Second, time.Since(start); e < a { t.Fatalf("expected duration of test to be less than %v, got %v", e, a) } if e, a := "IMDSProfileForSDKGo", resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } if err != nil { t.Fatalf("Expected no error, got %v", err) } expectedOperationsPerformed = []string{"GetToken", "GetMetadata", "GetToken", "GetMetadata"} if e, a := expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } } func TestTokenExpiredBehavior(t *testing.T) { tokens := []string{"firstToken", "secondToken", "thirdToken"} var activeToken string mux := http.NewServeMux() mux.HandleFunc("/latest/api/token", func(w http.ResponseWriter, r *http.Request) { if r.Method == "PUT" && r.Header.Get(ttlHeader) != "" { // set ttl to 0, so TTL is expired. w.Header().Set(ttlHeader, "0") activeToken = tokens[0] if len(tokens) > 1 { tokens = tokens[1:] } w.Write([]byte(activeToken)) return } http.Error(w, "bad request", http.StatusBadRequest) }) // meta-data endpoint for this test, just returns the token mux.HandleFunc("/latest/meta-data/", func(w http.ResponseWriter, r *http.Request) { w.Header().Set(ttlHeader, r.Header.Get(ttlHeader)) w.Write([]byte(r.Header.Get(tokenHeader))) }) server := httptest.NewServer(mux) defer server.Close() op := &operationListProvider{} c := ec2metadata.New(unit.Session, &aws.Config{ Endpoint: aws.String(server.URL), }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.GetMetadata("/some/path") if err != nil { t.Fatalf("Expected no error, got %v", err) } if e, a := activeToken, resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } // store the token received before var firstToken = activeToken resp, err = c.GetMetadata("/some/path") if err != nil { t.Fatalf("Expected no error, got %v", err) } if e, a := activeToken, resp; e != a { t.Fatalf("Expected %v, got %v", e, a) } // Since TTL is 0, we should have received a new token if firstToken == activeToken { t.Fatalf("Expected token should have expired, and not the same") } expectedOperationsPerformed := []string{"GetToken", "GetMetadata", "GetToken", "GetMetadata"} if e, a := expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) { t.Fatalf("expect %v operations, got %v", e, a) } }