// Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. //go:build !lambda.norpc // +build !lambda.norpc package lambda import ( "context" "encoding/json" "errors" "io" "os" "strconv" "strings" "testing" "time" "github.com/aws/aws-lambda-go/lambda/messages" "github.com/aws/aws-lambda-go/lambdacontext" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type testWrapperHandler func(ctx context.Context, input []byte) (interface{}, error) func (h testWrapperHandler) Invoke(ctx context.Context, payload []byte) ([]byte, error) { response, err := h(ctx, payload) if err != nil { return nil, err } responseBytes, err := json.Marshal(response) if err != nil { return nil, err } return responseBytes, nil } // verify testWrapperHandler implements Handler var _ Handler = (testWrapperHandler)(nil) func TestInvoke(t *testing.T) { srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { if deadline, ok := ctx.Deadline(); ok { return deadline.UnixNano(), nil } return nil, errors.New("!?!?!?!?!") }, )) deadline := time.Now() var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{ Deadline: messages.InvokeRequest_Timestamp{ Seconds: deadline.Unix(), Nanos: int64(deadline.Nanosecond()), }}, &response) assert.NoError(t, err) var responseValue int64 assert.NoError(t, json.Unmarshal(response.Payload, &responseValue)) assert.Equal(t, deadline.UnixNano(), responseValue) } func TestInvokeWithContext(t *testing.T) { key := struct{}{} srv := NewFunction(&handlerOptions{ handlerFunc: func(ctx context.Context, _ []byte) (io.Reader, error) { assert.Equal(t, "dummy", ctx.Value(key)) if deadline, ok := ctx.Deadline(); ok { return strings.NewReader(strconv.FormatInt(deadline.UnixNano(), 10)), nil } return nil, errors.New("!?!?!?!?!") }, baseContext: context.WithValue(context.Background(), key, "dummy"), }) deadline := time.Now() var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{ Deadline: messages.InvokeRequest_Timestamp{ Seconds: deadline.Unix(), Nanos: int64(deadline.Nanosecond()), }}, &response) assert.NoError(t, err) var responseValue int64 assert.NoError(t, json.Unmarshal(response.Payload, &responseValue)) assert.Equal(t, deadline.UnixNano(), responseValue) } type CustomError struct{} func (e CustomError) Error() string { return "Something bad happened!" } func TestCustomError(t *testing.T) { srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { return nil, CustomError{} }, )) var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{}, &response) assert.NoError(t, err) assert.Nil(t, response.Payload) assert.Equal(t, "Something bad happened!", response.Error.Message) assert.Equal(t, "CustomError", response.Error.Type) } type CustomError2 struct{} func (e *CustomError2) Error() string { return "Something bad happened!" } func TestCustomErrorRef(t *testing.T) { srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { return nil, &CustomError2{} }, )) var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{}, &response) assert.NoError(t, err) assert.Nil(t, response.Payload) assert.Equal(t, "Something bad happened!", response.Error.Message) assert.Equal(t, "CustomError2", response.Error.Type) } func TestContextPlumbing(t *testing.T) { srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { lc, _ := lambdacontext.FromContext(ctx) return lc, nil }, )) var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{ CognitoIdentityId: "dummyident", CognitoIdentityPoolId: "dummypool", ClientContext: []byte(`{ "Client": { "app_title": "dummytitle", "installation_id": "dummyinstallid", "app_version_code": "dummycode", "app_package_name": "dummyname" } }`), RequestId: "dummyid", InvokedFunctionArn: "dummyarn", }, &response) assert.NoError(t, err) assert.NotNil(t, response.Payload) expected := ` { "AwsRequestID": "dummyid", "InvokedFunctionArn": "dummyarn", "Identity": { "CognitoIdentityID": "dummyident", "CognitoIdentityPoolID": "dummypool" }, "ClientContext": { "Client": { "installation_id": "dummyinstallid", "app_title": "dummytitle", "app_version_code": "dummycode", "app_package_name": "dummyname" }, "env": null, "custom": null } } ` assert.JSONEq(t, expected, string(response.Payload)) } func TestXAmznTraceID(t *testing.T) { type XRayResponse struct { Env string Ctx string } srv := NewFunction(testWrapperHandler( func(ctx context.Context, input []byte) (interface{}, error) { return &XRayResponse{ Env: os.Getenv("_X_AMZN_TRACE_ID"), Ctx: ctx.Value("x-amzn-trace-id").(string), }, nil }, )) sequence := []struct { Input string Expected string }{ { "", `{"Env": "", "Ctx": ""}`, }, { "dummyid", `{"Env": "dummyid", "Ctx": "dummyid"}`, }, { "", `{"Env": "", "Ctx": ""}`, }, { "123dummyid", `{"Env": "123dummyid", "Ctx": "123dummyid"}`, }, { "", `{"Env": "", "Ctx": ""}`, }, { "", `{"Env": "", "Ctx": ""}`, }, { "567", `{"Env": "567", "Ctx": "567"}`, }, { "hihihi", `{"Env": "hihihi", "Ctx": "hihihi"}`, }, } for i, test := range sequence { var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{XAmznTraceId: test.Input}, &response) require.NoError(t, err, "failed test sequence[%d]", i) assert.JSONEq(t, test.Expected, string(response.Payload), "failed test sequence[%d]", i) } } type closeableResponse struct { reader io.Reader closed bool } func (c *closeableResponse) Read(p []byte) (int, error) { return c.reader.Read(p) } func (c *closeableResponse) Close() error { c.closed = true return nil } type readerError struct { err error } func (r *readerError) Read(_ []byte) (int, error) { return 0, r.err } func TestRPCModeInvokeClosesCloserIfResponseIsCloser(t *testing.T) { handlerResource := &closeableResponse{ reader: strings.NewReader(""), closed: false, } srv := NewFunction(newHandler(func() (interface{}, error) { return handlerResource, nil })) var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{}, &response) require.NoError(t, err) assert.Equal(t, "", string(response.Payload)) assert.True(t, handlerResource.closed) } func TestRPCModeInvokeReaderErrorPropogated(t *testing.T) { handlerResource := &closeableResponse{ reader: &readerError{errors.New("yolo")}, closed: false, } srv := NewFunction(newHandler(func() (interface{}, error) { return handlerResource, nil })) var response messages.InvokeResponse err := srv.Invoke(&messages.InvokeRequest{}, &response) require.NoError(t, err) assert.Equal(t, "", string(response.Payload)) assert.Equal(t, "yolo", response.Error.Message) assert.True(t, handlerResource.closed) }