// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"). You may not // use this file except in compliance with the License. A copy of the // License is located at // // http://aws.amazon.com/apache2.0/ // // or in the "license" file accompanying this file. This file is distributed // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, // either express or implied. See the License for the specific language governing // permissions and limitations under the License. // Package messagebus logic to send message and get reply over IPC package messagebus import ( "encoding/json" "errors" "testing" "github.com/aws/amazon-ssm-agent/agent/log" logmocks "github.com/aws/amazon-ssm-agent/agent/mocks/log" "github.com/aws/amazon-ssm-agent/common/channel" channelmocks "github.com/aws/amazon-ssm-agent/common/channel/mocks" "github.com/aws/amazon-ssm-agent/common/message" contextmocks "github.com/aws/amazon-ssm-agent/core/app/context/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) var ( pid = 1000 workerType = message.LongRunning workerName = "worker-name" ) type MessageBusTestSuite struct { suite.Suite mockLog log.T mockHealthChannel *channelmocks.IChannel mockTerminateChannel *channelmocks.IChannel mockContext *contextmocks.ICoreAgentContext messageBus *MessageBus } func (suite *MessageBusTestSuite) SetupTest() { mockLog := logmocks.NewMockLog() suite.mockLog = mockLog suite.mockContext = &contextmocks.ICoreAgentContext{} suite.mockContext.On("With", mock.Anything).Return(suite.mockContext) suite.mockContext.On("Log").Return(mockLog) suite.mockHealthChannel = &channelmocks.IChannel{} suite.mockTerminateChannel = &channelmocks.IChannel{} channels := make(map[message.TopicType]channel.IChannel) channels[message.GetWorkerHealthRequest] = suite.mockHealthChannel channels[message.TerminateWorkerRequest] = suite.mockTerminateChannel suite.messageBus = &MessageBus{ context: suite.mockContext, surveyChannels: channels, } } // Execute the test suite func TestMessageBusTestSuite(t *testing.T) { suite.Run(t, new(MessageBusTestSuite)) } func (suite *MessageBusTestSuite) TestStart_Successful() { suite.mockHealthChannel.On("Initialize", mock.Anything).Return(nil) suite.mockHealthChannel.On("Listen", mock.Anything).Return(nil) suite.mockHealthChannel.On("SetOption", mock.Anything, mock.Anything).Return(nil) suite.mockTerminateChannel.On("Initialize", mock.Anything).Return(nil) suite.mockTerminateChannel.On("Listen", mock.Anything).Return(nil) suite.mockTerminateChannel.On("SetOption", mock.Anything, mock.Anything).Return(nil) err := suite.messageBus.Start() assert.Nil(suite.T(), err) suite.mockHealthChannel.AssertExpectations(suite.T()) suite.mockTerminateChannel.AssertExpectations(suite.T()) } func (suite *MessageBusTestSuite) TestStart_Fail() { suite.mockHealthChannel.On("Initialize", mock.Anything).Return(errors.New("failed")) err := suite.messageBus.Start() assert.NotNil(suite.T(), err) suite.mockHealthChannel.AssertExpectations(suite.T()) } func (suite *MessageBusTestSuite) TestStop_Successful() { suite.mockHealthChannel.On("Close").Return(nil) suite.mockTerminateChannel.On("Close").Return(nil) suite.messageBus.Stop() suite.mockHealthChannel.AssertExpectations(suite.T()) suite.mockTerminateChannel.AssertExpectations(suite.T()) } func (suite *MessageBusTestSuite) TestSendSurveyMessage_Successful() { healthResult, _ := message.CreateHealthResult( workerName, workerType, pid) resultString, _ := json.Marshal(healthResult) suite.mockHealthChannel.On("IsConnect").Return(true) suite.mockHealthChannel.On("Send", mock.Anything).Return(nil) suite.mockHealthChannel.On("Recv").Return(resultString, nil).Once() suite.mockHealthChannel.On("Recv").Return(nil, errors.New("stop")).Once() surveyMsg := &message.Message{ SchemaVersion: 1, Topic: message.GetWorkerHealthRequest, } results, err := suite.messageBus.SendSurveyMessage(surveyMsg) assert.Nil(suite.T(), err) assert.True(suite.T(), len(results) == 1) suite.mockHealthChannel.AssertExpectations(suite.T()) for _, result := range results { var payload message.HealthResultPayload json.Unmarshal(result.Payload, &payload) assert.Equal(suite.T(), payload.SchemaVersion, 1) assert.Equal(suite.T(), payload.Name, workerName) assert.Equal(suite.T(), payload.WorkerType, workerType) assert.Equal(suite.T(), payload.Pid, pid) } } func (suite *MessageBusTestSuite) TestSendSurveyMessage_Fail() { resultString := "can not deserialize" suite.mockHealthChannel.On("IsConnect").Return(true) suite.mockHealthChannel.On("Send", mock.Anything).Return(nil) suite.mockHealthChannel.On("Recv").Return([]byte(resultString), nil).Once() suite.mockHealthChannel.On("Recv").Return(nil, errors.New("stop")).Once() surveyMsg := &message.Message{ SchemaVersion: 1, Topic: message.GetWorkerHealthRequest, } results, err := suite.messageBus.SendSurveyMessage(surveyMsg) assert.Nil(suite.T(), err) assert.True(suite.T(), len(results) == 0) suite.mockHealthChannel.AssertExpectations(suite.T()) }