// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"). You may // not use this file except in compliance with the License. A copy of the // License is located at // // http://aws.amazon.com/apache2.0/ // // or in the "license" file accompanying this file. This file is distributed // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either // express or implied. See the License for the specific language governing // permissions and limitations under the License. package server import ( "encoding/json" "fmt" "log" "net/http" "strings" "github.com/gorilla/mux" ) const notFoundResponse = ` 404 - Not Found

404 - Not Found

` // BadRequestResponse represents the IMDSv2 response in the event of missing or invalid parameters in the request const BadRequestResponse = ` 400 - Bad Request

400 - Bad Request

` // UnauthorizedResponse represents the IMDSv2 response in the event of unauthorized access const UnauthorizedResponse = ` 401 - Unauthorized

401 - Unauthorized

` var ( // Routes represents the list of routes served by the http server Routes []string router = mux.NewRouter() ) // HandlerType represents the function passed as an argument to HandleFunc type HandlerType func(http.ResponseWriter, *http.Request) // HandleFunc registers the handler function for the given pattern func HandleFunc(pattern string, requestHandler HandlerType) { router.HandleFunc(pattern, requestHandler) } // HandleFuncPrefix registers the handler function for the given prefix pattern func HandleFuncPrefix(pattern string, requestHandler HandlerType) { router.PathPrefix(pattern).HandlerFunc(requestHandler) } func listRoutes() { router.Walk(func(route *mux.Route, r *mux.Router, ancestors []*mux.Route) error { t, err := route.GetPathTemplate() if err != nil { return err } Routes = append(Routes, t) return nil }) } // ListenAndServe serves all patterns setup via their respective handlers func ListenAndServe(hostname string, port string) { listRoutes() host := fmt.Sprint(hostname, ":", port) if err := http.ListenAndServe(host, trailingSlashMiddleware(router)); err != nil { panic(err) } } // FormatAndReturnJSONResponse formats the given data into JSON and returns the response func FormatAndReturnJSONResponse(res http.ResponseWriter, data interface{}) { res.Header().Set("Content-Type", "application/json") var err error var metadataPrettyJSON []byte if metadataPrettyJSON, err = json.MarshalIndent(data, "", "\t"); err != nil { log.Fatalf("Error while attempting to format data %s for response: %s", data, err) } // In order to align with IMDS formatting, it is necessary to indent the response // EXCEPT FOR values of type list, ex: marketplaceProductCodes metadataPrettyJSON = removeIndentFromLists(metadataPrettyJSON) res.Write(metadataPrettyJSON) log.Println("Returned JSON mock response successfully.") return } // FormatAndReturnTextResponse formats the given data as plaintext and returns the response func FormatAndReturnTextResponse(res http.ResponseWriter, data 
string) { res.Header().Set("Content-Type", "text/plain") res.Write([]byte(data)) log.Println("Returned text mock response successfully.") return } // FormatAndReturnOctetResponse formats the given data into an octet stream and returns the response func FormatAndReturnOctetResponse(res http.ResponseWriter, data string) { res.Header().Set("Content-Type", "application/octet-stream") res.Write([]byte(data)) log.Println("Returned octet stream response successfully.") return } // FormatAndReturnJSONTextResponse formats the given data into JSON and returns a plaintext response func FormatAndReturnJSONTextResponse(res http.ResponseWriter, data interface{}) { res.Header().Set("Content-Type", "text/plain") var err error var metadataPrettyJSON []byte if metadataPrettyJSON, err = json.Marshal(data); err != nil { log.Fatalf("Error while attempting to format data %s for response: %s", data, err) } res.Write(metadataPrettyJSON) log.Println("Returned JSON text/plain mock response successfully.") return } // ReturnNotFoundResponse returns response with 404 Not Found func ReturnNotFoundResponse(w http.ResponseWriter) { http.Error(w, notFoundResponse, http.StatusNotFound) return } // ReturnBadRequestResponse returns response with 400 Bad Request func ReturnBadRequestResponse(w http.ResponseWriter) { http.Error(w, BadRequestResponse, http.StatusBadRequest) return } // ReturnUnauthorizedResponse returns response with 401 Unauthorized func ReturnUnauthorizedResponse(w http.ResponseWriter) { http.Error(w, UnauthorizedResponse, http.StatusUnauthorized) return } // trailingSlashMiddleware will remove trailing slashes and forward the request to the path's handler func trailingSlashMiddleware(pathHandler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // support "/" as a valid path if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { r.URL.Path = strings.TrimSuffix(r.URL.Path, "/") } pathHandler.ServeHTTP(w, r) }) } // 
// removeIndentFromLists takes an indented JSON encoding and removes the
// indentation (tab and newline characters) found inside list values, so that
// every list is rendered on a single line, ex:
//
//	"codes": [
//		"4i20ezfza3p7xx2kt2g8weu2u"
//	]
//
// becomes
//
//	"codes": ["4i20ezfza3p7xx2kt2g8weu2u"]
//
// If the encoding is itself a list (its first byte is '['), the input is
// returned unchanged: lists are only collapsed when they are elements WITHIN
// the JSON blob. Everything outside of lists is copied through untouched.
func removeIndentFromLists(bytes []byte) []byte {
	if len(bytes) == 0 || bytes[0] == '[' {
		// nothing to do, or the JSON encoding is a list itself,
		// ex: scheduled maintenance events
		return bytes
	}

	out := make([]byte, 0, len(bytes))
	var (
		listDepth int  // number of currently open '[' outside of strings
		inString  bool // true while scanning the inside of a JSON string
		escaped   bool // true when the previous string byte was a backslash
	)
	for _, c := range bytes {
		if inString {
			// String contents are copied verbatim; brackets inside a string
			// must not be mistaken for list delimiters.
			out = append(out, c)
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}

		switch c {
		case '"':
			inString = true
		case '[':
			listDepth++
		case ']':
			if listDepth > 0 {
				listDepth--
			}
		case '\t', '\n':
			if listDepth > 0 {
				// indentation inside a list: drop it
				continue
			}
		}
		out = append(out, c)
	}
	return out
}