// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT-0 package extension import ( "bytes" "context" "encoding/json" "fmt" "io/ioutil" "net/http" ) // RegisterResponse is the body of the response for /register type RegisterResponse struct { FunctionName string `json:"functionName"` FunctionVersion string `json:"functionVersion"` Handler string `json:"handler"` } // NextEventResponse is the response for /event/next type NextEventResponse struct { EventType EventType `json:"eventType"` DeadlineMs int64 `json:"deadlineMs"` RequestID string `json:"requestId"` InvokedFunctionArn string `json:"invokedFunctionArn"` Tracing Tracing `json:"tracing"` } // Tracing is part of the response for /event/next type Tracing struct { Type string `json:"type"` Value string `json:"value"` } // StatusResponse is the body of the response for /init/error and /exit/error type StatusResponse struct { Status string `json:"status"` } // EventType represents the type of events recieved from /event/next type EventType string const ( // Invoke is a lambda invoke Invoke EventType = "INVOKE" // Shutdown is a shutdown event for the environment Shutdown EventType = "SHUTDOWN" extensionNameHeader = "Lambda-Extension-Name" extensionIdentiferHeader = "Lambda-Extension-Identifier" extensionErrorType = "Lambda-Extension-Function-Error-Type" ) // Client is a simple client for the Lambda Extensions API type Client struct { baseURL string httpClient *http.Client ExtensionID string } // NewClient returns a Lambda Extensions API client func NewClient(awsLambdaRuntimeAPI string) *Client { baseURL := fmt.Sprintf("http://%s/2020-01-01/extension", awsLambdaRuntimeAPI) return &Client{ baseURL: baseURL, httpClient: &http.Client{}, } } // Register will register the extension with the Extensions API func (e *Client) Register(ctx context.Context, filename string) (*RegisterResponse, error) { const action = "/register" url := e.baseURL + action reqBody, err := json.Marshal(map[string]interface{}{ "events": []EventType{Invoke, Shutdown}, }) if err != nil { return nil, err } httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBody)) if err != nil { return nil, err } httpReq.Header.Set(extensionNameHeader, filename) httpRes, err := e.httpClient.Do(httpReq) if err != nil { return nil, err } if httpRes.StatusCode != 200 { return nil, fmt.Errorf("request failed with status %s", httpRes.Status) } defer httpRes.Body.Close() body, err := ioutil.ReadAll(httpRes.Body) if err != nil { return nil, err } res := RegisterResponse{} err = json.Unmarshal(body, &res) if err != nil { return nil, err } e.ExtensionID = httpRes.Header.Get(extensionIdentiferHeader) fmt.Println("Extensionid:", e.ExtensionID) return &res, nil } // NextEvent blocks while long polling for the next lambda invoke or shutdown func (e *Client) NextEvent(ctx context.Context) (*NextEventResponse, error) { const action = "/event/next" url := e.baseURL + action httpReq, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err } httpReq.Header.Set(extensionIdentiferHeader, e.ExtensionID) httpRes, err := e.httpClient.Do(httpReq) if err != nil { return nil, err } if httpRes.StatusCode != 200 { return nil, fmt.Errorf("request failed with status %s", httpRes.Status) } defer httpRes.Body.Close() body, err := ioutil.ReadAll(httpRes.Body) if err != nil { return nil, err } res := NextEventResponse{} err = json.Unmarshal(body, &res) if err != nil { return nil, err } return &res, nil } // InitError reports an initialization error to the platform. Call it when you registered but failed to initialize func (e *Client) InitError(ctx context.Context, errorType string) (*StatusResponse, error) { const action = "/init/error" url := e.baseURL + action httpReq, err := http.NewRequestWithContext(ctx, "POST", url, nil) if err != nil { return nil, err } httpReq.Header.Set(extensionIdentiferHeader, e.ExtensionID) httpReq.Header.Set(extensionErrorType, errorType) httpRes, err := e.httpClient.Do(httpReq) if err != nil { return nil, err } if httpRes.StatusCode != 200 { return nil, fmt.Errorf("request failed with status %s", httpRes.Status) } defer httpRes.Body.Close() body, err := ioutil.ReadAll(httpRes.Body) if err != nil { return nil, err } res := StatusResponse{} err = json.Unmarshal(body, &res) if err != nil { return nil, err } return &res, nil } // ExitError reports an error to the platform before exiting. Call it when you encounter an unexpected failure func (e *Client) ExitError(ctx context.Context, errorType string) (*StatusResponse, error) { const action = "/exit/error" url := e.baseURL + action httpReq, err := http.NewRequestWithContext(ctx, "POST", url, nil) if err != nil { return nil, err } httpReq.Header.Set(extensionIdentiferHeader, e.ExtensionID) httpReq.Header.Set(extensionErrorType, errorType) httpRes, err := e.httpClient.Do(httpReq) if err != nil { return nil, err } if httpRes.StatusCode != 200 { return nil, fmt.Errorf("request failed with status %s", httpRes.Status) } defer httpRes.Body.Close() body, err := ioutil.ReadAll(httpRes.Body) if err != nil { return nil, err } res := StatusResponse{} err = json.Unmarshal(body, &res) if err != nil { return nil, err } return &res, nil }