//----------------------------------------------------------------------------- // // Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"). // You may not use this file except in compliance with the License. // A copy of the License is located at // // http://aws.amazon.com/apache2.0 // // or in the "license" file accompanying this file. This file is distributed // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either // express or implied. See the License for the specific language governing // permissions and limitations under the License. // //----------------------------------------------------------------------------- using System; using System.Threading.Tasks; using System.Data; using System.Data.Common; using System.Data.SqlClient; using Amazon.XRay.Recorder.Handlers.SqlServer; using Amazon.XRay.Recorder.Core.Internal.Utils; using Amazon.XRay.Recorder.Core.Internal.Entities; using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; using Amazon.XRay.Recorder.Core; using Amazon.XRay.Recorder.UnitTests.Tools; using System.Configuration; namespace Amazon.XRay.Recorder.UnitTests { [TestClass] public class DbCommandInterceptorTests : TestBase { private DbCommandStub _command = new DbCommandStub(); private const string _userId = "admin"; private const string _connectionString = "Data Source=xyz.com,3306;User ID=" + _userId + ";Password=Secret.123;"; private const string _sanitizedConnectionString = "Data Source=xyz.com,3306;User ID=" + _userId; private const string _trustedConnectionString = "Server=myServerAddress;Database=myDataBase;Trusted_Connection=True;"; private const string _collectedTrustedConnectionString = "Data Source=myServerAddress;Initial Catalog=myDataBase;Integrated Security=True"; private const string _collectSqlQueriesKey = "CollectSqlQueries"; [TestInitialize] public void TestInitialize() { var connectionMock = new Mock(); connectionMock.Setup(c => c.DataSource).Returns("xyz.com,3306"); connectionMock.Setup(c => c.Database).Returns("master"); connectionMock.Setup(c => c.ServerVersion).Returns("13.0.5026.0"); connectionMock.Setup(c => c.ConnectionString).Returns(_connectionString); _command.Connection = connectionMock.Object; _command.CommandText = "SELECT a.* FROM dbo.Accounts a ..."; } [TestCleanup] public new void TestCleanup() { base.TestCleanup(); #if NETFRAMEWORK ConfigurationManager.AppSettings[_collectSqlQueriesKey] = string.Empty; AppSettings.Reset(); #endif AWSXRayRecorder.Instance.Dispose(); } [TestMethod] public void Intercept_DoesNot_CollectQueries_When_NotEnabled() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions() }; #else var recorder = new AWSXRayRecorder(); #endif recorder.BeginSegment("test"); var interceptor = new DbCommandInterceptor(recorder); // act interceptor.Intercept(() => 0, _command); // assert var segment = AWSXRayRecorder.Instance.TraceContext.GetEntity(); AssertNotCollected(recorder); recorder.EndSegment(); } [TestMethod] public async Task InterceptAsync_DoesNot_CollectQueries_When_NotEnabled() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions() }; #else var recorder = new AWSXRayRecorder(); #endif recorder.BeginSegment("test"); var interceptor = new DbCommandInterceptor(recorder); // act await interceptor.InterceptAsync(() => Task.FromResult(0), _command); // assert var segment = AWSXRayRecorder.Instance.TraceContext.GetEntity(); AssertNotCollected(recorder); recorder.EndSegment(); } [TestMethod] public void Intercept_CollectsQueries_When_DisabledGlobally_And_EnabledLocally() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions { CollectSqlQueries = false } }; #else ConfigurationManager.AppSettings[_collectSqlQueriesKey] = "false"; AppSettings.Reset(); var recorder = new AWSXRayRecorder(); #endif recorder.BeginSegment("test"); var interceptor = new DbCommandInterceptor(recorder, collectSqlQueries: true); // act interceptor.Intercept(() => 0, _command); // assert AssertCollected(recorder); recorder.EndSegment(); } [TestMethod] public async Task InterceptAsync_CollectsQueries_When_DisabledGlobally_And_EnabledLocally() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions { CollectSqlQueries = false } }; #else ConfigurationManager.AppSettings[_collectSqlQueriesKey] = "false"; AppSettings.Reset(); var recorder = new AWSXRayRecorder(); #endif recorder.BeginSegment("test"); var interceptor = new DbCommandInterceptor(recorder, collectSqlQueries: true); // act await interceptor.InterceptAsync(() => Task.FromResult(0), _command); // assert AssertCollected(recorder); recorder.EndSegment(); } [TestMethod] public void Intercept_CollectsQueries_When_EnabledGlobally() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions { CollectSqlQueries = true } }; #else ConfigurationManager.AppSettings[_collectSqlQueriesKey] = "true"; AppSettings.Reset(); var recorder = new AWSXRayRecorder(); #endif var interceptor = new DbCommandInterceptor(recorder); recorder.BeginSegment("test"); // act interceptor.Intercept(() => 0, _command); // assert AssertCollected(recorder); recorder.EndSegment(); } [TestMethod] public async Task InterceptAsync_CollectsQueries_When_EnabledGlobally() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions { CollectSqlQueries = true } }; #else ConfigurationManager.AppSettings[_collectSqlQueriesKey] = "true"; AppSettings.Reset(); var recorder = new AWSXRayRecorder(); #endif var interceptor = new DbCommandInterceptor(recorder); recorder.BeginSegment("test"); // act await interceptor.InterceptAsync(() => Task.FromResult(0), _command); // assert AssertCollected(recorder); recorder.EndSegment(); } [TestMethod] public void Intercept_DoesNot_CollectQueries_When_EnabledGlobally_And_DisabledLocally() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions { CollectSqlQueries = true } }; #else ConfigurationManager.AppSettings[_collectSqlQueriesKey] = "true"; AppSettings.Reset(); var recorder = new AWSXRayRecorder(); #endif var interceptor = new DbCommandInterceptor(recorder, collectSqlQueries: false); recorder.BeginSegment("test"); // act interceptor.Intercept(() => 0, _command); // assert AssertNotCollected(recorder); recorder.EndSegment(); } [TestMethod] public async Task InterceptAsync_DoesNot_CollectQueries_When_EnabledGlobally_And_DisabledLocally() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions { CollectSqlQueries = true } }; #else ConfigurationManager.AppSettings[_collectSqlQueriesKey] = "true"; AppSettings.Reset(); var recorder = new AWSXRayRecorder(); #endif var interceptor = new DbCommandInterceptor(recorder, collectSqlQueries: false); recorder.BeginSegment("test"); // act await interceptor.InterceptAsync(() => Task.FromResult(0), _command); // assert AssertNotCollected(recorder); recorder.EndSegment(); } [TestMethod] public void TestTrustedConnection_DoesNotCollectUserID() { // arrange #if !NETFRAMEWORK var recorder = new AWSXRayRecorder { XRayOptions = new XRayOptions() }; #else var recorder = new AWSXRayRecorder(); #endif InitializeMockTrustedConnection(); recorder.BeginSegment("test"); var interceptor = new DbCommandInterceptor(recorder); // act interceptor.Intercept(() => 0, _command); // assert var segment = AWSXRayRecorder.Instance.TraceContext.GetEntity(); AssertNotCollected(recorder, true); recorder.EndSegment(); } private void InitializeMockTrustedConnection() { var connectionMock = new Mock(); connectionMock.Setup(c => c.DataSource).Returns("xyz.com,3306"); connectionMock.Setup(c => c.Database).Returns("master"); connectionMock.Setup(c => c.ServerVersion).Returns("13.0.5026.0"); connectionMock.Setup(c => c.ConnectionString).Returns(_trustedConnectionString); _command.Connection = connectionMock.Object; _command.CommandText = "SELECT a.* FROM dbo.Accounts a ..."; } private void AssertNotCollected(AWSXRayRecorder recorder, bool isTrustedConn = false) { var segment = recorder.TraceContext.GetEntity().Subsegments[0]; AssertExpectedSqlInformation(segment, isTrustedConn); if (!isTrustedConn) { Assert.AreEqual(4, segment.Sql.Count); } else { Assert.AreEqual(3, segment.Sql.Count); } Assert.IsFalse(segment.Sql.ContainsKey("sanitized_query")); } private void AssertCollected(AWSXRayRecorder recorder, bool isTrustedConn = false) { var segment = recorder.TraceContext.GetEntity().Subsegments[0]; AssertExpectedSqlInformation(segment, isTrustedConn); if (!isTrustedConn) { Assert.AreEqual(5, segment.Sql.Count); } else { Assert.AreEqual(4, segment.Sql.Count); } Assert.AreEqual(_command.CommandText, segment.Sql["sanitized_query"]); } private void AssertExpectedSqlInformation(Subsegment segment, bool isTrustedConn = false) { Assert.IsNotNull(segment); Assert.IsNotNull(segment.Sql); Assert.AreEqual("sqlserver", segment.Sql["database_type"]); Assert.AreEqual(_command.Connection.ServerVersion, segment.Sql["database_version"]); if (!isTrustedConn) { Assert.AreEqual(_userId, segment.Sql["user"]); Assert.AreEqual(_sanitizedConnectionString, segment.Sql["connection_string"]); } else { Assert.AreEqual(_collectedTrustedConnectionString, segment.Sql["connection_string"]); } } } }