// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"). You may not // use this file except in compliance with the License. A copy of the // License is located at // // http://aws.amazon.com/apache2.0/ // // or in the "license" file accompanying this file. This file is distributed // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, // either express or implied. See the License for the specific language governing // permissions and limitations under the License. // Package port implements session port plugin. package port import ( "errors" "testing" "time" "github.com/aws/amazon-ssm-agent/agent/appconfig" "github.com/aws/amazon-ssm-agent/agent/context" "github.com/aws/amazon-ssm-agent/agent/contracts" iohandlermocks "github.com/aws/amazon-ssm-agent/agent/framework/processor/executer/iohandler/mock" "github.com/aws/amazon-ssm-agent/agent/log" contextmocks "github.com/aws/amazon-ssm-agent/agent/mocks/context" logmocks "github.com/aws/amazon-ssm-agent/agent/mocks/log" taskmocks "github.com/aws/amazon-ssm-agent/agent/mocks/task" mgsContracts "github.com/aws/amazon-ssm-agent/agent/session/contracts" dataChannelMock "github.com/aws/amazon-ssm-agent/agent/session/datachannel/mocks" portSessionMock "github.com/aws/amazon-ssm-agent/agent/session/plugins/port/mocks" "github.com/aws/amazon-ssm-agent/agent/task" "github.com/aws/amazon-ssm-agent/common/identity" identityMock "github.com/aws/amazon-ssm-agent/common/identity/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "github.com/twinj/uuid" ) var ( mockLog = logmocks.NewMockLog() configuration = contracts.Configuration{Properties: map[string]interface{}{"portNumber": port}, SessionId: sessionId} configurationPF = contracts.Configuration{Properties: map[string]interface{}{"portNumber": port, "type": "LocalPortForwarding"}, SessionId: sessionId} configurationWithRemoteHost = contracts.Configuration{Properties: map[string]interface{}{"portNumber": port, "host": remoteHost, "type": "LocalPortForwarding"}, SessionId: sessionId} payload = []byte("testPayload") messageId = "dd01e56b-ff48-483e-a508-b5f073f31b16" schemaVersion = uint32(1) createdDate = uint64(1503434274948) clientVersion = "1.2.0" sessionId = "sessionId" port = "8080" localhost = "localhost" remoteHost = "https://remote.server.com" ) type PortTestSuite struct { suite.Suite mockContext *contextmocks.Mock mockLog log.T mockCancelFlag *taskmocks.MockCancelFlag mockDataChannel *dataChannelMock.IDataChannel mockIohandler *iohandlermocks.MockIOHandler mockPortSession *portSessionMock.IPortSession plugin *PortPlugin } // Testing initializeParameters func TestInitializeParametersWhenPortTypeIsNil(t *testing.T) { mockDataChannel := &dataChannelMock.IDataChannel{} mockDataChannel.On("GetClientVersion").Return(clientVersion) portPlugin := &PortPlugin{ context: contextmocks.NewMockDefault(), dataChannel: mockDataChannel, cancelled: make(chan struct{}), } portPlugin.initializeParameters(configuration) assert.IsType(t, &BasicPortSession{}, portPlugin.session) basicPortSession := portPlugin.session.(*BasicPortSession) assert.Equal(t, "", basicPortSession.destinationAddress) mockDataChannel.AssertExpectations(t) } func TestInitializeParametersWhenPortTypeIsLocalPortForwarding(t *testing.T) { mockDataChannel := &dataChannelMock.IDataChannel{} mockDataChannel.On("GetClientVersion").Return(clientVersion) portPlugin := &PortPlugin{ context: contextmocks.NewMockDefault(), dataChannel: mockDataChannel, cancelled: make(chan struct{}), } portPlugin.initializeParameters(configurationPF) assert.IsType(t, &MuxPortSession{}, portPlugin.session) muxPortSession := portPlugin.session.(*MuxPortSession) assert.Equal(t, "", muxPortSession.destinationAddress) mockDataChannel.AssertExpectations(t) } func TestInitializeParametersWhenPortTypeIsLocalPortForwardingAndOldClient(t *testing.T) { mockDataChannel := &dataChannelMock.IDataChannel{} mockDataChannel.On("GetClientVersion").Return("1.0.0") portPlugin := &PortPlugin{ context: contextmocks.NewMockDefault(), dataChannel: mockDataChannel, cancelled: make(chan struct{}), } portPlugin.initializeParameters(configurationPF) assert.IsType(t, &BasicPortSession{}, portPlugin.session) basicPortSession := portPlugin.session.(*BasicPortSession) assert.Equal(t, "", basicPortSession.destinationAddress) mockDataChannel.AssertExpectations(t) } func TestInitializeParametersWhenHostIsProvided(t *testing.T) { mockDataChannel := &dataChannelMock.IDataChannel{} mockDataChannel.On("GetClientVersion").Return(clientVersion) portPlugin := &PortPlugin{ context: contextmocks.NewMockDefault(), dataChannel: mockDataChannel, cancelled: make(chan struct{}), } mockIdentity := &identityMock.IAgentIdentityInner{} newEC2Identity = func(log log.T, _ *appconfig.SsmagentConfig) identity.IAgentIdentityInner { return mockIdentity } newECSIdentity = newEC2Identity mockIdentity.On("IsIdentityEnvironment").Return(true) mockMetadata := &identityMock.IMetadataIdentity{} getMetadataIdentity = func(agentIdentity identity.IAgentIdentityInner) (identity.IMetadataIdentity, bool) { return mockMetadata, true } mockMetadata.On("VpcPrimaryCIDRBlock").Return(map[string][]string{"ipv4": {"172.31.0.0/16"}, "ipv6": {"2600:1f18:64ad::/56"}}, nil) address := "127.0.0.1" lookupHost = func(host string) ([]string, error) { if host == remoteHost { return []string{address}, nil } return []string{host}, nil } portPlugin.initializeParameters(configurationWithRemoteHost) assert.IsType(t, &MuxPortSession{}, portPlugin.session) muxPortSession := portPlugin.session.(*MuxPortSession) assert.Equal(t, []string{address}, muxPortSession.addressList) assert.Equal(t, remoteHost, muxPortSession.host) mockDataChannel.AssertExpectations(t) } func (suite *PortTestSuite) SetupTest() { mockContext := contextmocks.NewMockDefault() mockCancelFlag := &taskmocks.MockCancelFlag{} mockDataChannel := &dataChannelMock.IDataChannel{} mockIohandler := new(iohandlermocks.MockIOHandler) mockPortSession := &portSessionMock.IPortSession{} suite.mockContext = mockContext suite.mockCancelFlag = mockCancelFlag suite.mockLog = mockLog suite.mockDataChannel = mockDataChannel suite.mockIohandler = mockIohandler suite.mockPortSession = mockPortSession suite.plugin = &PortPlugin{ context: mockContext, dataChannel: mockDataChannel, cancelled: make(chan struct{}), } mockIdentity := &identityMock.IAgentIdentityInner{} newEC2Identity = func(log log.T, _ *appconfig.SsmagentConfig) identity.IAgentIdentityInner { return mockIdentity } newECSIdentity = newEC2Identity mockIdentity.On("IsIdentityEnvironment").Return(true) mockMetadata := &identityMock.IMetadataIdentity{} getMetadataIdentity = func(agentIdentity identity.IAgentIdentityInner) (identity.IMetadataIdentity, bool) { return mockMetadata, true } mockMetadata.On("VpcPrimaryCIDRBlock").Return(map[string][]string{"ipv4": {"172.31.0.0/16"}, "ipv6": {"2600:1f18:64ad::/56"}}, nil) } // Testing Name func (suite *PortTestSuite) TestName() { rst := suite.plugin.name() assert.Equal(suite.T(), rst, appconfig.PluginNamePort) } // Testing GetPluginParameters func (suite *PortTestSuite) TestGetPluginParameters() { config := map[string]interface{}{"portNumber": "22", "type": "LocalPortForwarding"} assert.Equal(suite.T(), suite.plugin.GetPluginParameters(config), config) } // Testing Execute func (suite *PortTestSuite) TestExecuteWhenCancelFlagIsShutDown() { suite.mockCancelFlag.On("ShutDown").Return(true) suite.mockIohandler.On("MarkAsShutdown").Return(nil) suite.plugin.Execute( configuration, suite.mockCancelFlag, suite.mockIohandler, suite.mockDataChannel) suite.mockCancelFlag.AssertExpectations(suite.T()) suite.mockIohandler.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestExecuteWhenCancelFlagIsCancelled() { suite.mockCancelFlag.On("Canceled").Return(true) suite.mockCancelFlag.On("ShutDown").Return(false) suite.mockIohandler.On("MarkAsCancelled").Return(nil) suite.plugin.Execute( configuration, suite.mockCancelFlag, suite.mockIohandler, suite.mockDataChannel) suite.mockCancelFlag.AssertExpectations(suite.T()) suite.mockIohandler.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestExecuteWithInvalidPortNumber() { suite.mockCancelFlag.On("Canceled").Return(false) suite.mockCancelFlag.On("ShutDown").Return(false) suite.mockIohandler.On("SetStatus", contracts.ResultStatusFailed).Return(nil) suite.mockIohandler.On("SetExitCode", 1).Return(nil) suite.mockIohandler.On("SetOutput", mock.Anything).Return() suite.plugin.Execute( contracts.Configuration{Properties: map[string]interface{}{"portNumber": ""}, SessionId: "sessionId"}, suite.mockCancelFlag, suite.mockIohandler, suite.mockDataChannel) suite.mockCancelFlag.AssertExpectations(suite.T()) suite.mockIohandler.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestExecuteWhenInitializeSessionReturnsError() { suite.mockCancelFlag.On("Canceled").Return(false) suite.mockCancelFlag.On("ShutDown").Return(false) suite.mockIohandler.On("SetStatus", contracts.ResultStatusFailed).Return(nil) suite.mockIohandler.On("SetExitCode", 1).Return(nil) suite.mockIohandler.On("SetOutput", mock.Anything).Return() suite.mockDataChannel.On("GetClientVersion").Return(clientVersion) GetSession = func(context context.T, parameters PortParameters, addresses []string, cancelled chan struct{}, clientVersion string, sessionId string) (IPortSession, error) { return nil, errors.New("failed to initialize session") } suite.plugin.Execute( configuration, suite.mockCancelFlag, suite.mockIohandler, suite.mockDataChannel) suite.mockCancelFlag.AssertExpectations(suite.T()) suite.mockIohandler.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestExecute() { suite.mockCancelFlag.On("Canceled").Return(false) suite.mockCancelFlag.On("ShutDown").Return(false) suite.mockCancelFlag.On("Wait").Return(task.Completed) suite.mockIohandler.On("SetExitCode", 0).Return(nil) suite.mockIohandler.On("SetStatus", contracts.ResultStatusSuccess).Return() suite.mockDataChannel.On("GetClientVersion").Return(clientVersion) suite.mockPortSession.On("InitializeSession", mock.Anything).Return(nil) suite.mockPortSession.On("WritePump", suite.mockDataChannel).WaitUntil(time.After(time.Second)).Return(0) suite.mockPortSession.On("Stop").Return() GetSession = func(context context.T, parameters PortParameters, addresses []string, cancelled chan struct{}, clientVersion string, sessionId string) (IPortSession, error) { return suite.mockPortSession, nil } suite.plugin.Execute( configuration, suite.mockCancelFlag, suite.mockIohandler, suite.mockDataChannel) suite.mockCancelFlag.AssertExpectations(suite.T()) suite.mockIohandler.AssertExpectations(suite.T()) suite.mockDataChannel.AssertExpectations(suite.T()) suite.mockPortSession.AssertExpectations(suite.T()) } // Testing InputStreamHandler func (suite *PortTestSuite) TestInputStreamHandler() { suite.plugin.session = suite.mockPortSession suite.mockPortSession.On("HandleStreamMessage", getAgentMessage(uint32(mgsContracts.Output), payload)).Return(nil) suite.mockPortSession.On("IsConnectionAvailable").Return(true) suite.plugin.InputStreamMessageHandler(suite.mockLog, getAgentMessage(uint32(mgsContracts.Output), payload)) suite.mockPortSession.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestInputStreamHandlerSessionNotReady() { suite.plugin.InputStreamMessageHandler(suite.mockLog, getAgentMessage(uint32(mgsContracts.Output), payload)) suite.mockPortSession.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestInputStreamHandlerConnectionNotReady() { suite.plugin.session = suite.mockPortSession suite.mockPortSession.On("IsConnectionAvailable").Return(false) suite.plugin.InputStreamMessageHandler(suite.mockLog, getAgentMessage(uint32(mgsContracts.Output), payload)) suite.mockPortSession.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestValidateParametersWhenInvalidPort() { err := suite.plugin.validateParameters(PortParameters{PortNumber: ""}, configuration) assert.Contains(suite.T(), err.Error(), "Port number is empty in session properties.") } func (suite *PortTestSuite) TestValidateParametersWhenVPCHostNotAllowed() { mockContext := &contextmocks.Mock{} suite.plugin.context = mockContext mockContext.On("AppConfig").Return(appconfig.SsmagentConfig{Mgs: appconfig.MgsConfig{DeniedPortForwardingRemoteIPs: []string{"169.254.169.254", "fd00:ec2::254", "169.254.169.253", "fd00:ec2::253", "169.254.169.123", "169.254.169.250"}}}) mockContext.On("Log").Return(mockLog) err := suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "172.31.0.2"}, configuration) assert.Contains(suite.T(), err.Error(), "Forwarding to IP address 172.31.0.2 is forbidden.") err = suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "2600:1f18:64ad::2"}, configuration) assert.Contains(suite.T(), err.Error(), "Forwarding to IP address 2600:1f18:64ad::2 is forbidden.") mockContext.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestValidateParametersWhenDefaultDenylistHostNotAllowed() { mockContext := &contextmocks.Mock{} suite.plugin.context = mockContext mockContext.On("AppConfig").Return(appconfig.SsmagentConfig{Mgs: appconfig.MgsConfig{DeniedPortForwardingRemoteIPs: []string{"169.254.169.254", "fd00:ec2::254", "169.254.169.253", "fd00:ec2::253", "169.254.169.123", "169.254.169.250"}}}) mockContext.On("Log").Return(mockLog) err := suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "169.254.169.253"}, configuration) assert.Contains(suite.T(), err.Error(), "Forwarding to IP address 169.254.169.253 is forbidden.") err = suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "fd00:ec2::253"}, configuration) assert.Contains(suite.T(), err.Error(), "Forwarding to IP address fd00:ec2::253 is forbidden.") err = suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "169.254.169.254"}, configuration) assert.Contains(suite.T(), err.Error(), "Forwarding to IP address 169.254.169.254 is forbidden.") err = suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "fd00:ec2::253"}, configuration) assert.Contains(suite.T(), err.Error(), "Forwarding to IP address fd00:ec2::253 is forbidden.") err = suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "169.254.169.250"}, configuration) assert.Contains(suite.T(), err.Error(), "Forwarding to IP address 169.254.169.250 is forbidden.") err = suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "169.254.169.123"}, configuration) assert.Contains(suite.T(), err.Error(), "Forwarding to IP address 169.254.169.123 is forbidden.") mockContext.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestValidateParametersWhenValidHostAndPort() { mockContext := &contextmocks.Mock{} suite.plugin.context = mockContext mockContext.On("AppConfig").Return(appconfig.SsmagentConfig{Mgs: appconfig.MgsConfig{DeniedPortForwardingRemoteIPs: []string{"169.254.169.254", "fd00:ec2::254", "169.254.169.253", "fd00:ec2::253"}}}) mockContext.On("Log").Return(mockLog) err := suite.plugin.validateParameters(PortParameters{PortNumber: "80", Host: "127.0.0.1"}, configuration) assert.Nil(suite.T(), err) mockContext.AssertExpectations(suite.T()) } func (suite *PortTestSuite) TestCalculateAddressMethod() { expected := []string{"172.31.0.0", "172.31.0.1", "172.31.0.2", "172.31.0.3", "172.31.255.255", "2600:1f18:64ad::", "2600:1f18:64ad::1", "2600:1f18:64ad::2", "2600:1f18:64ad::3", "2600:1f18:64ad:ff:ffff:ffff:ffff:ffff"} ipaddresses := map[string][]string{"ipv4": {"172.31.0.0/16"}, "ipv6": {"2600:1f18:64ad::/56"}} actual := calculateAddress(ipaddresses) assert.Equal(suite.T(), expected, actual) } // Execute the test suite func TestPortTestSuite(t *testing.T) { suite.Run(t, new(PortTestSuite)) } // getAgentMessage constructs and returns AgentMessage with given sequenceNumber, messageType & payload func getAgentMessage(payloadType uint32, payload []byte) mgsContracts.AgentMessage { messageUUID, _ := uuid.Parse(messageId) agentMessage := mgsContracts.AgentMessage{ MessageType: mgsContracts.InputStreamDataMessage, SchemaVersion: schemaVersion, CreatedDate: createdDate, SequenceNumber: 1, Flags: 2, MessageId: messageUUID, PayloadType: payloadType, Payload: payload, } return agentMessage }