//-----------------------------------------------------------------------------
//
// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License").
// You may not use this file except in compliance with the License.
// A copy of the License is located at
//
// http://aws.amazon.com/apache2.0
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
//
//-----------------------------------------------------------------------------
using System;
using System.Data.Common;
using System.Text.RegularExpressions;
using Amazon.Runtime.Internal.Util;
using Amazon.XRay.Recorder.Core;
#if NETFRAMEWORK
using Amazon.XRay.Recorder.Core.Internal.Utils;
#endif
namespace Amazon.XRay.Recorder.Handlers.EntityFramework
{
/// <summary>
/// Utilities for EFInterceptor: opens and closes X-Ray subsegments around database
/// commands and records SQL metadata on the current subsegment.
/// </summary>
internal static class EFUtil
{
    private const string DefaultDbTypeEntityFramework = "entityframework";
    private const string SqlServerCompact35 = "sqlservercompact35";
    private const string SqlServerCompact40 = "sqlservercompact40";
    private const string MicrosoftSqlClient = "microsoft.data.sqlclient";
    private const string SystemSqlClient = "system.data.sqlclient";
    private const string SqlServer = "sqlserver";

    // Connection-string keys that may carry the password. "Pwd" is the common short synonym of
    // "Password"; DbConnectionStringBuilder matches keys case-insensitively.
    private static readonly string[] PasswordKeys = { "Password", "Pwd" };

    private static readonly string[] UserIdFormatOptions = { "user id", "username", "user", "userid", "uid" }; // case insensitive

    // Tokens searched for (in this order) inside the lower-cased connection type name.
    private static readonly string[] DatabaseTypes = { "sqlserver", "sqlite", "postgresql", "mysql", "firebirdsql",
                                                       "inmemory", "cosmosdb", "oracle", "filecontextcore",
                                                       "jet", "teradata", "openedge", "ibm", "mycat", "vfp" };

    // Matches a trailing ",<digits>", "|<digits>" or ":<digits>" suffix.
    // NOTE(review): inside a character class '|' is a literal, not alternation - confirm this is intended.
    private static readonly Regex _portNumberRegex = new Regex(@"[,|:]\d+$");

    private static readonly AWSXRayRecorder _recorder = AWSXRayRecorder.Instance;
    private static readonly Logger _logger = Logger.GetLogger(typeof(EFUtil));

    /// <summary>
    /// Process command to begin subsegment.
    /// </summary>
    /// <param name="command">Instance of <see cref="DbCommand"/>.</param>
    /// <param name="collectSqlQueriesOverride">Nullable to indicate whether to collect sql query text or not.</param>
    internal static void ProcessBeginCommand(DbCommand command, bool? collectSqlQueriesOverride)
    {
        _recorder.BeginSubsegment(BuildSubsegmentName(command));
        _recorder.SetNamespace("remote");
        CollectSqlInformation(command, collectSqlQueriesOverride);
    }

    /// <summary>
    /// Process to end subsegment.
    /// </summary>
    internal static void ProcessEndCommand()
    {
        _recorder.EndSubsegment();
    }

    /// <summary>
    /// Process exception: records it on the recorder and then ends the subsegment.
    /// </summary>
    /// <param name="exception">Instance of <see cref="Exception"/>.</param>
    internal static void ProcessCommandError(Exception exception)
    {
        _recorder.AddException(exception);
        _recorder.EndSubsegment();
    }

    /// <summary>
    /// Builds the name of the subsegment in the format database@datasource.
    /// </summary>
    /// <param name="command">Instance of <see cref="DbCommand"/>.</param>
    /// <returns>Returns the formed subsegment name as a string.</returns>
    private static string BuildSubsegmentName(DbCommand command)
        => command.Connection.Database + "@" + RemovePortNumberFromDataSource(command.Connection.DataSource);

    /// <summary>
    /// Records the SQL information on the current subsegment: database type, connection string
    /// with the password removed, user (when present), query text (when enabled) and database version.
    /// </summary>
    /// <param name="command">Instance of <see cref="DbCommand"/>.</param>
    /// <param name="collectSqlQueriesOverride">Nullable to indicate whether to collect sql query text or not.</param>
    private static void CollectSqlInformation(DbCommand command, bool? collectSqlQueriesOverride)
    {
        // Get database type from DbCommand
        _recorder.AddSqlInformation("database_type", GetDataBaseType(command));

        var connectionStringBuilder = new DbConnectionStringBuilder
        {
            ConnectionString = command.Connection.ConnectionString
        };

        // Remove sensitive information from connection string before it is recorded.
        foreach (string passwordKey in PasswordKeys)
        {
            connectionStringBuilder.Remove(passwordKey);
        }

        _recorder.AddSqlInformation("connection_string", connectionStringBuilder.ToString());

        // Do a pre-check for UserID since in the case of TrustedConnection, a UserID may not be available.
        string userId = GetUserId(connectionStringBuilder);
        if (userId != null)
        {
            _recorder.AddSqlInformation("user", userId);
        }

        if (ShouldCollectSqlText(collectSqlQueriesOverride))
        {
            _recorder.AddSqlInformation("sanitized_query", command.CommandText);
        }

        _recorder.AddSqlInformation("database_version", command.Connection.ServerVersion);
    }

    /// <summary>
    /// Extract database_type from <see cref="DbCommand"/>.
    /// </summary>
    /// <param name="command">Instance of <see cref="DbCommand"/>.</param>
    /// <returns>
    /// Type of database: a known provider token when the connection type name contains one,
    /// the default value when the type name is unavailable, otherwise the lower-cased type name.
    /// </returns>
    internal static string GetDataBaseType(DbCommand command)
    {
        // Invariant lower-casing keeps the matching below independent of the current culture
        // (a culture-sensitive ToLower can map 'I' to a dotless 'i' and break the comparisons).
        string typeString = command?.Connection?.GetType().FullName?.ToLowerInvariant();

        if (string.IsNullOrEmpty(typeString))
        {
            _logger.DebugFormat("Can't extract database type from connection, setting it as default: ({0})", DefaultDbTypeEntityFramework);
            return DefaultDbTypeEntityFramework;
        }

        if (typeString.Contains(MicrosoftSqlClient) || typeString.Contains(SystemSqlClient))
        {
            return SqlServer;
        }

        // The compact editions are checked before the generic list, whose first entry is "sqlserver".
        if (typeString.Contains(SqlServerCompact35))
        {
            return SqlServerCompact35;
        }

        if (typeString.Contains(SqlServerCompact40))
        {
            return SqlServerCompact40;
        }

        foreach (string databaseType in DatabaseTypes)
        {
            if (typeString.Contains(databaseType))
            {
                return databaseType;
            }
        }

        return typeString;
    }

    /// <summary>
    /// Extract user id from <see cref="DbConnectionStringBuilder"/>.
    /// </summary>
    /// <param name="builder">Instance of <see cref="DbConnectionStringBuilder"/>.</param>
    /// <returns>The value of the first recognised user id key, or null when none is present.</returns>
    internal static string GetUserId(DbConnectionStringBuilder builder)
    {
        foreach (string key in UserIdFormatOptions)
        {
            if (builder.TryGetValue(key, out object value))
            {
                // ToString instead of a hard cast so a non-string value cannot raise InvalidCastException.
                return value?.ToString();
            }
        }

        return null;
    }

    /// <summary>
    /// Removes the port number from data source.
    /// </summary>
    /// <param name="dataSource">The data source.</param>
    /// <returns>The data source string with port number removed; an empty string when the input is null or empty.</returns>
    private static string RemovePortNumberFromDataSource(string dataSource)
    {
        // Regex.Replace rejects a null input; treat a missing data source as empty instead.
        if (string.IsNullOrEmpty(dataSource))
        {
            return string.Empty;
        }

        return _portNumberRegex.Replace(dataSource, string.Empty);
    }

#if !NETFRAMEWORK
    /// <summary>
    /// Decides whether the query text is recorded: the override when supplied,
    /// otherwise the recorder's XRayOptions.CollectSqlQueries setting.
    /// </summary>
    private static bool ShouldCollectSqlText(bool? collectSqlQueriesOverride)
        => collectSqlQueriesOverride ?? _recorder.XRayOptions.CollectSqlQueries;
#else
    /// <summary>
    /// Decides whether the query text is recorded: the override when supplied,
    /// otherwise the AppSettings.CollectSqlQueries setting.
    /// </summary>
    private static bool ShouldCollectSqlText(bool? collectSqlQueriesOverride)
        => collectSqlQueriesOverride ?? AppSettings.CollectSqlQueries;
#endif
}
}