// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
package server
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"github.com/gorilla/mux"
)
const (
	// notFoundResponse is the body passed to http.Error when replying 404 Not Found.
	notFoundResponse = "\n404 - Not Found\n404 - Not Found\n"

	// BadRequestResponse represents the IMDSv2 response body used when a request has
	// missing or invalid parameters.
	BadRequestResponse = "\n400 - Bad Request\n400 - Bad Request\n"

	// UnauthorizedResponse represents the IMDSv2 response body used when access is
	// unauthorized.
	UnauthorizedResponse = "\n401 - Unauthorized\n401 - Unauthorized\n"
)
// Routes holds the path template of every route registered on router. It is
// populated by listRoutes when ListenAndServe is called.
var Routes []string

// router is the package-wide gorilla/mux router that HandleFunc and
// HandleFuncPrefix register handlers on and that ListenAndServe serves.
var router = mux.NewRouter()
// HandlerType is the request-handler signature accepted by HandleFunc and
// HandleFuncPrefix; it has the same shape as an http.HandlerFunc.
type HandlerType func(w http.ResponseWriter, r *http.Request)
// HandleFunc registers requestHandler on the package router for the path
// template pattern.
func HandleFunc(pattern string, requestHandler HandlerType) {
	route := router.NewRoute()
	route.Path(pattern).HandlerFunc(requestHandler)
}
// HandleFuncPrefix registers requestHandler on the package router for every path
// that begins with the prefix pattern.
func HandleFuncPrefix(pattern string, requestHandler HandlerType) {
	prefixRoute := router.NewRoute().PathPrefix(pattern)
	prefixRoute.Handler(http.HandlerFunc(requestHandler))
}
// listRoutes appends the path template of every route registered on router to
// Routes.
//
// A route whose template cannot be obtained (GetPathTemplate returns an error,
// e.g. because the route defines no path) is logged and skipped. Returning that
// error would abort router.Walk, and every route registered after it would then be
// silently missing from Routes. Any error from the walk itself is logged rather
// than discarded.
func listRoutes() {
	walkErr := router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
		tpl, err := route.GetPathTemplate()
		if err != nil {
			log.Printf("Skipping route without a path template: %s", err)
			return nil
		}
		Routes = append(Routes, tpl)
		return nil
	})
	if walkErr != nil {
		log.Printf("Error while listing routes: %s", walkErr)
	}
}
// ListenAndServe records the registered routes in Routes, then serves every pattern
// set up via HandleFunc/HandleFuncPrefix on hostname:port. Trailing slashes are
// stripped from request paths before routing.
//
// It blocks while serving and panics with the error returned by
// http.ListenAndServe.
//
// A bare IPv6 literal hostname (e.g. "::1") is wrapped in brackets so the listen
// address is well-formed. A hostname that is already bracketed is used as given.
func ListenAndServe(hostname string, port string) {
	listRoutes()
	listenHost := hostname
	if strings.Contains(listenHost, ":") && !strings.HasPrefix(listenHost, "[") {
		// "::1" + ":" + port would be rejected with "too many colons in address".
		listenHost = "[" + listenHost + "]"
	}
	host := fmt.Sprint(listenHost, ":", port)
	if err := http.ListenAndServe(host, trailingSlashMiddleware(router)); err != nil {
		panic(err)
	}
}
// FormatAndReturnJSONResponse marshals data to tab-indented JSON and writes it as an
// application/json response.
//
// To match IMDS output, lists nested inside the document are written on a single
// line (via removeIndentFromLists) while everything else stays indented.
//
// If data cannot be marshaled, the error is logged and the client receives a
// 500 Internal Server Error instead of the whole process exiting.
func FormatAndReturnJSONResponse(res http.ResponseWriter, data interface{}) {
	metadataPrettyJSON, err := json.MarshalIndent(data, "", "\t")
	if err != nil {
		// One unmarshalable payload should fail this request, not terminate the
		// server (log.Fatalf would call os.Exit).
		log.Printf("Error while attempting to format data %v for response: %s", data, err)
		http.Error(res, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	// In order to align with IMDS formatting, it is necessary to indent the response
	// EXCEPT FOR values of type list, ex: marketplaceProductCodes
	metadataPrettyJSON = removeIndentFromLists(metadataPrettyJSON)
	res.Header().Set("Content-Type", "application/json")
	res.Write(metadataPrettyJSON)
	log.Println("Returned JSON mock response successfully.")
}
// FormatAndReturnTextResponse writes data unchanged as a text/plain response body,
// then logs a success message.
func FormatAndReturnTextResponse(res http.ResponseWriter, data string) {
	headers := res.Header()
	headers.Set("Content-Type", "text/plain")
	body := []byte(data)
	res.Write(body)
	log.Println("Returned text mock response successfully.")
}
// FormatAndReturnOctetResponse writes data unchanged as an
// application/octet-stream response body, then logs a success message.
func FormatAndReturnOctetResponse(res http.ResponseWriter, data string) {
	payload := []byte(data)
	res.Header().Set("Content-Type", "application/octet-stream")
	res.Write(payload)
	log.Println("Returned octet stream response successfully.")
}
// FormatAndReturnJSONTextResponse marshals data to compact JSON and writes it as a
// text/plain response.
//
// If data cannot be marshaled, the error is logged and the client receives a
// 500 Internal Server Error. The process keeps running, so one bad payload does not
// take the whole mock server down (previously log.Fatalf exited the process).
func FormatAndReturnJSONTextResponse(res http.ResponseWriter, data interface{}) {
	metadataJSON, err := json.Marshal(data)
	if err != nil {
		log.Printf("Error while attempting to format data %v for response: %s", data, err)
		http.Error(res, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	res.Header().Set("Content-Type", "text/plain")
	res.Write(metadataJSON)
	log.Println("Returned JSON text/plain mock response successfully.")
}
// ReturnNotFoundResponse replies with HTTP 404 Not Found, using notFoundResponse
// as the body written by http.Error.
func ReturnNotFoundResponse(w http.ResponseWriter) {
	const code = http.StatusNotFound
	http.Error(w, notFoundResponse, code)
}
// ReturnBadRequestResponse replies with HTTP 400 Bad Request, using
// BadRequestResponse as the body written by http.Error.
func ReturnBadRequestResponse(w http.ResponseWriter) {
	const code = http.StatusBadRequest
	http.Error(w, BadRequestResponse, code)
}
// ReturnUnauthorizedResponse replies with HTTP 401 Unauthorized, using
// UnauthorizedResponse as the body written by http.Error.
func ReturnUnauthorizedResponse(w http.ResponseWriter) {
	const code = http.StatusUnauthorized
	http.Error(w, UnauthorizedResponse, code)
}
// trailingSlashMiddleware wraps pathHandler so that a single trailing "/" is removed
// from the request path before the request is forwarded. The root path "/" is left
// untouched so it remains routable.
func trailingSlashMiddleware(pathHandler http.Handler) http.Handler {
	stripSlash := func(w http.ResponseWriter, r *http.Request) {
		if p := r.URL.Path; p != "/" && strings.HasSuffix(p, "/") {
			r.URL.Path = p[:len(p)-1]
		}
		pathHandler.ServeHTTP(w, r)
	}
	return http.HandlerFunc(stripSlash)
}
// removeIndentFromLists takes a JSON encoding and
// removes indentation from list elements.
//
// It expects tab-indented json.MarshalIndent output (as produced in
// FormatAndReturnJSONResponse). It drops the tab and newline characters that
// appear inside every list nested within the document, so
//
//	"marketplaceProductCodes": [
//		"4i20ezfza3p7xx2kt2g8weu2u"
//	],
//
// becomes
//
//	"marketplaceProductCodes": ["4i20ezfza3p7xx2kt2g8weu2u"],
//
// Everything outside of lists, including the byte that follows each closing
// bracket, is copied unchanged. Brackets inside JSON string values are not treated
// as lists.
//
// If the encoding is itself a list (first byte '['), e.g. scheduled maintenance
// events, it is returned as is: only lists WITHIN the JSON blob are processed.
func removeIndentFromLists(bytes []byte) []byte {
	if len(bytes) == 0 || bytes[0] == '[' {
		return bytes
	}
	out := make([]byte, 0, len(bytes))
	listDepth := 0    // number of currently open lists
	inString := false // inside a JSON string literal
	escaped := false  // previous byte inside the string was a backslash
	for _, c := range bytes {
		switch {
		case inString:
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
		case c == '"':
			inString = true
		case c == '[':
			listDepth++
		case c == ']':
			if listDepth > 0 {
				listDepth--
			}
		case listDepth > 0 && (c == '\t' || c == '\n'):
			// Formatting whitespace inside a list: drop it.
			continue
		}
		out = append(out, c)
	}
	return out
}