//-----------------------------------------------------------------------------
//
// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License").
// You may not use this file except in compliance with the License.
// A copy of the License is located at
//
// http://aws.amazon.com/apache2.0
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
//
//-----------------------------------------------------------------------------
#if !NET45
using Amazon.Runtime.Internal.Util;
using Amazon.XRay.Recorder.Core;
using Amazon.XRay.Recorder.Core.Internal.Entities;
using Amazon.XRay.Recorder.Core.Sampling;
using Amazon.XRay.Recorder.Core.Strategies;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.HttpOverrides.Internal;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
using System;
using System.Collections.Generic;
using System.Net;
using System.Text;
namespace Amazon.XRay.Recorder.AutoInstrumentation.Utils
{
/// <summary>
/// Provides methods to set up the segment naming strategy and to process ASP.NET Core incoming
/// requests, responses and exceptions for X-Ray tracing.
/// </summary>
public class AspNetCoreRequestUtil
{
    private static AWSXRayRecorder _recorder;
    private static readonly Logger _logger = Logger.GetLogger(typeof(AspNetCoreRequestUtil));
    private const string SchemeDelimiter = "://";
    private const string XForwardedForHeader = "X-Forwarded-For";

    private static SegmentNamingStrategy SegmentNamingStrategy { get; set; }

    /// <summary>
    /// Set up the segment naming strategy used to name segments for incoming requests.
    /// </summary>
    /// <param name="segmentNamingStrategy">Strategy used to derive a segment name from a request.</param>
    /// <exception cref="ArgumentNullException"><paramref name="segmentNamingStrategy"/> is null.</exception>
    internal static void SetSegmentNamingStrategy(SegmentNamingStrategy segmentNamingStrategy)
    {
        SegmentNamingStrategy = segmentNamingStrategy ?? throw new ArgumentNullException(nameof(segmentNamingStrategy));
    }

    /// <summary>
    /// Set the recorder used to begin/end segments and record HTTP information.
    /// </summary>
    /// <param name="recorder">The recorder instance to use.</param>
    /// <exception cref="ArgumentNullException"><paramref name="recorder"/> is null.</exception>
    internal static void SetAWSXRayRecorder(AWSXRayRecorder recorder)
    {
        _recorder = recorder ?? throw new ArgumentNullException(nameof(recorder));
    }

    /// <summary>
    /// Process an incoming http request.
    /// Reads (or creates) the trace header, makes a sampling decision when one is needed, and begins
    /// a segment (or a subsegment when running in Lambda). It then records the request attributes
    /// and, if the caller requested a sampling decision, echoes the trace header on the response.
    /// </summary>
    /// <param name="httpContext">The current HTTP context.</param>
    internal static void ProcessRequest(HttpContext httpContext)
    {
        HttpRequest request = httpContext.Request;
        TraceHeader traceHeader = GetOrCreateTraceHeader(request);

        // Compute the name once; it is used both as sampling input and as the segment name.
        string segmentName = SegmentNamingStrategy.GetSegmentName(request);

        // Capture this before the sampling decision below overwrites traceHeader.Sampled.
        bool isSampleDecisionRequested = traceHeader.Sampled == SampleDecision.Requested;
        string ruleName = null;

        // Make sample decision
        if (traceHeader.Sampled == SampleDecision.Unknown || isSampleDecisionRequested)
        {
            string host = request.Host.Host;
            string url = request.Path;
            string method = request.Method;
            var samplingInput = new SamplingInput(host, url, method, segmentName, _recorder.Origin);
            SamplingResponse sampleResponse = _recorder.SamplingStrategy.ShouldTrace(samplingInput);
            traceHeader.Sampled = sampleResponse.SampleDecision;
            ruleName = sampleResponse.RuleName;
        }

        if (AWSXRayRecorder.IsLambda())
        {
            // In Lambda only a subsegment is opened here.
            _recorder.BeginSubsegment(segmentName);
        }
        else
        {
            // Final rule name and sampling decision for the new segment.
            var samplingResponse = new SamplingResponse(ruleName, traceHeader.Sampled);
            _recorder.BeginSegment(segmentName, traceHeader.RootTraceId, traceHeader.ParentId, samplingResponse);
        }

        // Consult the same recorder that receives the HTTP information.
        if (!_recorder.IsTracingDisabled())
        {
            _recorder.AddHttpInformation("request", PopulateRequestAttributes(request));
        }

        // Mark the segment as auto-instrumented
        AgentUtil.AddAutoInstrumentationMark();

        if (isSampleDecisionRequested)
        {
            // It's recommended not to modify response headers after _next.Invoke(), so set it now.
            // Use the indexer rather than Add: Add throws if the header is already present.
            httpContext.Response.Headers[TraceHeader.HeaderKey] = traceHeader.ToString();
        }
    }

    /// <summary>
    /// Parses the trace header from the request. If it is missing or invalid, builds a new one with a
    /// fresh root trace id, no parent id and an unknown sampling decision.
    /// </summary>
    /// <param name="request">The incoming request.</param>
    /// <returns>The parsed or newly created trace header.</returns>
    private static TraceHeader GetOrCreateTraceHeader(HttpRequest request)
    {
        string headerString = null;
        if (request.Headers.TryGetValue(TraceHeader.HeaderKey, out StringValues headerValue) && headerValue.Count >= 1)
        {
            headerString = headerValue[0];
        }

        if (TraceHeader.TryParse(headerString, out TraceHeader traceHeader))
        {
            return traceHeader;
        }

        _logger.DebugFormat("Trace header doesn't exist or not valid : ({0}). Injecting a new one.", headerString);
        return new TraceHeader
        {
            RootTraceId = TraceId.NewId(),
            ParentId = null,
            Sampled = SampleDecision.Unknown
        };
    }

    /// <summary>
    /// Builds the "request" HTTP attributes: url, method, client_ip, optionally x_forwarded_for,
    /// and user_agent when the header is present.
    /// </summary>
    /// <param name="request">The incoming request.</param>
    /// <returns>A new attribute dictionary.</returns>
    private static Dictionary<string, object> PopulateRequestAttributes(HttpRequest request)
    {
        var requestAttributes = new Dictionary<string, object>
        {
            ["url"] = GetUrl(request),
            ["method"] = request.Method
        };

        string xForwardedFor = GetXForwardedFor(request);
        if (xForwardedFor == null)
        {
            requestAttributes["client_ip"] = GetClientIpAddress(request);
        }
        else
        {
            requestAttributes["client_ip"] = xForwardedFor;

            // If it's outer Proxy, add "X-Forwarded-For: true" in the trace context.
            if (IsOuterProxy(request))
            {
                requestAttributes["x_forwarded_for"] = true;
            }
        }

        if (request.Headers.TryGetValue(HeaderNames.UserAgent, out StringValues userAgent))
        {
            requestAttributes["user_agent"] = userAgent.ToString();
        }

        return requestAttributes;
    }

    /// <summary>
    /// Returns true when the X-Forwarded-For header lists more than one comma-separated entry.
    /// Multiple header lines also count, because <see cref="StringValues.ToString"/> joins them with commas.
    /// </summary>
    private static bool IsOuterProxy(HttpRequest request)
    {
        return request.Headers.TryGetValue(XForwardedForHeader, out StringValues headerValue)
            && headerValue.ToString().IndexOf(',') >= 0;
    }

    /// <summary>
    /// Returns the remote IP address of the connection, or null if it is unavailable.
    /// </summary>
    private static string GetClientIpAddress(HttpRequest request)
    {
        return request.HttpContext.Connection.RemoteIpAddress?.ToString();
    }

    /// <summary>
    /// Get the client IP from the first entry of the X-Forwarded-For header.
    /// </summary>
    /// <returns>
    /// The IP address of the first entry, or null if the header is absent or that entry is not a
    /// parsable IP end point.
    /// </returns>
    private static string GetXForwardedFor(HttpRequest request)
    {
        if (!request.Headers.TryGetValue(XForwardedForHeader, out StringValues headerValue))
        {
            return null;
        }

        // The first entry may be "IP:port"; ExtractIpAddress strips the port.
        string firstEndPoint = headerValue.ToString().Split(',')[0].Trim();
        string clientIp = ExtractIpAddress(firstEndPoint);
        return string.IsNullOrEmpty(clientIp) ? null : clientIp;
    }

    /// <summary>
    /// Extract the IP address from an "IP:Port number" end point.
    /// IPV6 formats: [xx:xx:xx:xx:xx:xx:xx:xx]:port number, [xx:xx:xx:xx:xx:xx:xx:xx], xx:xx:xx:xx:xx:xx:xx:xx.
    /// IPV4 formats: x.x.x.x:port number, x.x.x.x.
    /// </summary>
    /// <returns>The address as a string, or null when <paramref name="endPoint"/> cannot be parsed.</returns>
    private static string ExtractIpAddress(string endPoint)
    {
        return IPEndPointParser.TryParse(endPoint, out IPEndPoint ipEndPoint)
            ? ipEndPoint.Address?.ToString()
            : null;
    }

    /// <summary>
    /// Reconstructs the request URL as scheme://host + path base + path + query string.
    /// </summary>
    /// <returns>The URL, or null when <paramref name="request"/> is null.</returns>
    private static string GetUrl(HttpRequest request)
    {
        if (request == null)
        {
            _logger.DebugFormat("HTTPRequest instance is null. Cannot get URL from the request, Setting url to null");
            return null;
        }

        var scheme = request.Scheme ?? string.Empty;
        var host = request.Host.Value ?? string.Empty;
        var pathBase = request.PathBase.Value ?? string.Empty;
        var path = request.Path.Value ?? string.Empty;
        var queryString = request.QueryString.Value ?? string.Empty;

        // PERF: Calculate string length to allocate correct buffer size for StringBuilder.
        var length = scheme.Length + SchemeDelimiter.Length + host.Length
            + pathBase.Length + path.Length + queryString.Length;

        return new StringBuilder(length)
            .Append(scheme)
            .Append(SchemeDelimiter)
            .Append(host)
            .Append(pathBase)
            .Append(path)
            .Append(queryString)
            .ToString();
    }

    /// <summary>
    /// Process an http response.
    /// Records response attributes (unless tracing is disabled), then ends the subsegment (Lambda)
    /// or the segment opened by <see cref="ProcessRequest"/>.
    /// </summary>
    /// <param name="httpContext">The current HTTP context.</param>
    internal static void ProcessResponse(HttpContext httpContext)
    {
        // Consult the same recorder that receives the HTTP information.
        if (!_recorder.IsTracingDisabled())
        {
            _recorder.AddHttpInformation("response", PopulateResponseAttributes(httpContext.Response));
        }

        if (AWSXRayRecorder.IsLambda())
        {
            _recorder.EndSubsegment();
        }
        else
        {
            _recorder.EndSegment();
        }
    }

    /// <summary>
    /// Marks the current entity from the response status code and builds the "response" attributes
    /// (status, and content_length when known).
    /// </summary>
    private static Dictionary<string, object> PopulateResponseAttributes(HttpResponse response)
    {
        int statusCode = response.StatusCode;
        AgentUtil.MarkEntityFromStatus(statusCode);

        var responseAttributes = new Dictionary<string, object>
        {
            ["status"] = statusCode
        };

        long? contentLength = response.Headers.ContentLength;
        if (contentLength != null)
        {
            responseAttributes["content_length"] = contentLength.Value;
        }

        return responseAttributes;
    }

    /// <summary>
    /// Records <paramref name="exception"/> on the current entity via the recorder.
    /// </summary>
    internal static void ProcessException(Exception exception)
    {
        _recorder.AddException(exception);
    }
}
}
#endif