/*
Copyright 2019 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package banner
import (
"bytes"
"context"
"net/http"
"net/url"
"strings"
"text/template"
"time"
"github.com/aws-samples/inverting-proxy/agent/metrics"
)
// Header names consulted or rewritten by this package, plus the template texts
// used to render the page that replaces a wrapped HTML response. The header
// names are written in canonical MIME form because some code indexes
// http.Header maps with them directly rather than via Header.Get.
const (
acceptHeader = "Accept"
cacheControlHeader = "Cache-Control"
contentDispositionHeader = "Content-Disposition"
contentEncodingHeader = "Content-Encoding"
contentTypeHeader = "Content-Type"
dateHeader = "Date"
expiresHeader = "Expires"
refererHeader = "Referer"
pragmaHeader = "Pragma"
secFetchDestHeader = "Sec-Fetch-Dest"
secFetchModeHeader = "Sec-Fetch-Mode"
xFrameOptionsHeader = "X-Frame-Options"
// frameWrapperTemplate is parsed into frameWrapperTmpl and rendered by
// getBanner as the body that replaces the wrapped handler's HTML response.
// getBanner supplies TargetURL, Banner, BannerHeight and FavIconLink; because
// this is a text/template, those values are inserted without HTML escaping.
// NOTE(review): only FavIconLink and Banner are referenced here. TargetURL is
// request-derived, so if it is ever placed into HTML it must be escaped.
// Confirm this template text is complete.
frameWrapperTemplate = `
{{.FavIconLink}}
{{.Banner}}
`
// favIconLinkTemplate is parsed into favIconLinkTmpl and rendered by
// getFavIconLink with a FavIconURL value.
// NOTE(review): this template is empty, so getFavIconLink yields "" even when
// a favicon URL is configured. Confirm this is intended.
favIconLinkTemplate = ``
)
// Parsed templates, built once at package initialization. template.Must
// panics if either template text fails to parse, so a malformed template is
// caught at startup rather than on the first request.
var (
	frameWrapperTmpl = template.Must(template.New("frame-wrapper").Parse(frameWrapperTemplate))
	favIconLinkTmpl  = template.Must(template.New("fav-icon").Parse(favIconLinkTemplate))
)
// isHTMLRequest reports whether r is eligible to receive the injected banner
// page: it must be a GET request whose Accept header mentions text/html.
func isHTMLRequest(r *http.Request) bool {
	// We err on the side of not injecting a banner, since doing so might
	// interfere with the app's semantics. We cannot know what response a POST
	// (or any other non-GET) request expects, so only GET is eligible.
	isGet := r.Method == http.MethodGet
	// The injected response is HTML, so the client must have explicitly said
	// it accepts HTML.
	acceptsHTML := strings.Contains(r.Header.Get(acceptHeader), "text/html")
	return isGet && acceptsHTML
}
// isFrameableHTMLResponse reports whether a response with the given status and
// headers is an HTML page that may be wrapped in the banner frame. All of the
// following must hold:
//   - the status is 200 OK;
//   - no Content-Disposition value requests an attachment (a download);
//   - some Content-Type value is text/html or application/xhtml+xml.
//
// Both headers are matched case-insensitively, because media types
// (RFC 9110) and disposition types (RFC 6266) are case-insensitive.
// For example, "Text/HTML" is still HTML, and "Attachment" is still a download.
func isFrameableHTMLResponse(statusCode int, responseHeader http.Header) bool {
	if statusCode != http.StatusOK {
		return false
	}
	for _, contentDisposition := range responseHeader[contentDispositionHeader] {
		if strings.Contains(strings.ToLower(contentDisposition), "attachment") {
			return false
		}
	}
	for _, contentType := range responseHeader[contentTypeHeader] {
		ct := strings.ToLower(contentType)
		if strings.Contains(ct, "text/html") || strings.Contains(ct, "application/xhtml+xml") {
			return true
		}
	}
	return false
}
// isAlreadyFramed reports whether r appears to be loaded from inside a frame,
// in which case the banner wrapper should not be injected again.
func isAlreadyFramed(r *http.Request) bool {
	// When the browser's fetch metadata says the page is framed, believe it.
	switch {
	case r.Header.Get(secFetchModeHeader) == "nested-navigate":
		return true
	case r.Header.Get(secFetchDestHeader) == "iframe":
		return true
	}
	// Otherwise, treat a Referer naming this same host and path as framed.
	// Presumably that referer is the wrapper page served for this URL.
	referer := r.Header.Get(refererHeader)
	if referer == "" {
		return false
	}
	parsed, err := url.Parse(referer)
	if err != nil {
		return false
	}
	return parsed.Host == r.Host && parsed.Path == r.URL.Path
}
// bannerResponseWriter wraps the http.ResponseWriter of an HTML GET request
// (see Proxy). In WriteHeader it decides whether to pass the wrapped handler's
// response through or to replace its body with the rendered banner page.
type bannerResponseWriter struct {
// wrapped is the underlying writer that headers and bytes are sent to.
wrapped http.ResponseWriter
// bannerHTML, bannerHeight and targetURL are rendered into the frame
// wrapper template by getBanner (targetURL via its String form).
bannerHTML string
bannerHeight string
targetURL *url.URL
// favIconURL is rendered by getFavIconLink; when empty, no favicon link is
// produced.
favIconURL string
// metricHandler receives the status code when this writer produces the
// response itself (the banner page, or a 500 when rendering fails).
metricHandler *metrics.MetricHandler
// wroteHeader records that WriteHeader has run; later calls are no-ops.
wroteHeader bool
// writeBytes is set when the wrapped handler's body should be forwarded;
// while false, Write discards the body but reports success.
writeBytes bool
// isAlreadyFramed is true when the request looked framed (see
// isAlreadyFramed); such responses are passed through without the banner.
isAlreadyFramed bool
}
// Header implements http.ResponseWriter by exposing the wrapped writer's
// header map. Header edits made by the wrapped handler and by WriteHeader
// therefore land directly on the real response.
func (w *bannerResponseWriter) Header() http.Header {
	underlying := w.wrapped
	return underlying.Header()
}
// setNotCacheable overwrites the caching-related headers in h so the response
// is not reused: Cache-Control forbids storing and requires revalidation,
// Pragma is "no-cache", Expires is the zero time, and Date is the current UTC
// time. All dates use the HTTP date format.
func setNotCacheable(h http.Header) {
	httpDate := func(t time.Time) string { return t.UTC().Format(http.TimeFormat) }
	for _, kv := range [...][2]string{
		{cacheControlHeader, "no-cache, no-store, max-age=0, must-revalidate"},
		{dateHeader, httpDate(time.Now())},
		{expiresHeader, httpDate(time.Time{})},
		{pragmaHeader, "no-cache"},
	} {
		h.Set(kv[0], kv[1])
	}
}
// setXFrameOptionsSameOrigin sets X-Frame-Options to "sameorigin",
// restricting framing of the response to same-origin pages. Any value(s) the
// wrapped handler had set are replaced, because Header.Set already discards
// existing values for the key.
func setXFrameOptionsSameOrigin(h http.Header) {
	h.Set(xFrameOptionsHeader, "sameorigin")
}
// getFavIconLink renders favIconLinkTmpl with the configured favicon URL and
// returns the result. When no favicon URL is configured, it returns "" without
// rendering anything. Template execution errors are returned unchanged.
func (w *bannerResponseWriter) getFavIconLink() (string, error) {
	iconURL := w.favIconURL
	if iconURL == "" {
		return "", nil
	}
	out := new(bytes.Buffer)
	data := &struct{ FavIconURL string }{FavIconURL: iconURL}
	if err := favIconLinkTmpl.Execute(out, data); err != nil {
		return "", err
	}
	return out.String(), nil
}
// getBanner renders frameWrapperTmpl, the page that replaces the wrapped
// handler's HTML body. The template receives the request's target URL, the
// configured banner HTML and height, and the given favicon link markup.
// On a template execution error it returns an empty slice and that error.
func (w *bannerResponseWriter) getBanner(favIconLink string) ([]byte, error) {
	data := struct {
		TargetURL    string
		Banner       string
		BannerHeight string
		FavIconLink  string
	}{
		w.targetURL.String(),
		w.bannerHTML,
		w.bannerHeight,
		favIconLink,
	}
	out := new(bytes.Buffer)
	if err := frameWrapperTmpl.Execute(out, &data); err != nil {
		return []byte{}, err
	}
	return out.Bytes(), nil
}
// WriteHeader implements http.ResponseWriter. Only the first call has any
// effect. It decides how the wrapped handler's response is delivered:
//   - Responses that are not frameable HTML (see isFrameableHTMLResponse) are
//     passed through with their headers untouched.
//   - Frameable HTML responses are marked non-cacheable and same-origin
//     frameable. If the request already looked framed, the body is then
//     passed through.
//   - Otherwise the wrapped handler's body is discarded, and the rendered
//     frame-wrapper page (favicon link plus banner) is sent instead.
//
// If rendering the wrapper page fails, a 500 carrying the error text is sent
// to the client, and the wrapped handler's body is discarded.
func (w *bannerResponseWriter) WriteHeader(statusCode int) {
	if w.wroteHeader {
		return
	}
	w.wroteHeader = true
	if !isFrameableHTMLResponse(statusCode, w.Header()) {
		w.wrapped.WriteHeader(statusCode)
		w.writeBytes = true
		return
	}
	setNotCacheable(w.Header())
	setXFrameOptionsSameOrigin(w.Header())
	if w.isAlreadyFramed {
		w.wrapped.WriteHeader(statusCode)
		w.writeBytes = true
		return
	}
	// From here on the wrapped handler's body is replaced, so headers that
	// describe that body no longer apply. A stale Content-Length in particular
	// would not match the replacement body's length.
	w.Header().Del(contentEncodingHeader)
	w.Header().Del("Content-Length")

	// fail reports err to the client. It must write to the wrapped writer.
	// Writing to w itself would be swallowed: WriteHeader returns early once
	// wroteHeader is set, and Write discards bytes while writeBytes is false.
	// That would leave the client with an empty implicit 200.
	fail := func(err error) {
		sc := http.StatusInternalServerError
		http.Error(w.wrapped, err.Error(), sc)
		w.metricHandler.WriteResponseCodeMetric(sc)
	}
	favIconLink, err := w.getFavIconLink()
	if err != nil {
		fail(err)
		return
	}
	banner, err := w.getBanner(favIconLink)
	if err != nil {
		fail(err)
		return
	}
	w.metricHandler.WriteResponseCodeMetric(statusCode)
	w.wrapped.WriteHeader(statusCode)
	// WriteHeader cannot return an error, so a failed write here (e.g. a
	// client that went away) is not reported. The wrapped handler's own body
	// is discarded either way, because writeBytes stays false.
	w.wrapped.Write(banner)
}
// Write implements http.ResponseWriter. If no status has been written yet, it
// first triggers an implicit WriteHeader(http.StatusOK). Bytes are forwarded to
// the wrapped writer only when WriteHeader chose to pass the body through.
// Otherwise they are dropped, and the call still reports len(bs) bytes written
// and a nil error.
func (w *bannerResponseWriter) Write(bs []byte) (int, error) {
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}
	if w.writeBytes {
		return w.wrapped.Write(bs)
	}
	return len(bs), nil
}
// Proxy builds an HTTP handler that serves every request with wrapped. The
// handler is registered under "/" on a ServeMux, so ServeMux's usual path
// handling applies first. For GET requests that accept text/html, the response
// writer is wrapped in a bannerResponseWriter, so a 200 HTML response can be
// replaced by the banner frame page. Other requests reach wrapped with the
// original writer. ctx is not used, and the returned error is always nil.
func Proxy(ctx context.Context, wrapped http.Handler, bannerHTML, bannerHeight, favIconURL string, metricHandler *metrics.MetricHandler) (http.Handler, error) {
	serve := func(rw http.ResponseWriter, req *http.Request) {
		if isHTMLRequest(req) {
			rw = &bannerResponseWriter{
				wrapped:         rw,
				targetURL:       req.URL,
				bannerHTML:      bannerHTML,
				bannerHeight:    bannerHeight,
				favIconURL:      favIconURL,
				isAlreadyFramed: isAlreadyFramed(req),
				metricHandler:   metricHandler,
			}
		}
		wrapped.ServeHTTP(rw, req)
	}
	mux := http.NewServeMux()
	mux.Handle("/", http.HandlerFunc(serve))
	return mux, nil
}