// Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"). You may not // use this file except in compliance with the License. A copy of the // License is located at // // http://aws.amazon.com/apache2.0/ // // or in the "license" file accompanying this file. This file is distributed // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, // either express or implied. See the License for the specific language governing // permissions and limitations under the License. // Package websocketutil contains methods for interacting with websocket connections. package websocketutil import ( "fmt" "net/http" "net/http/httptest" "net/url" "testing" "github.com/aws/amazon-ssm-agent/agent/appconfig" "github.com/aws/amazon-ssm-agent/agent/mocks/log" "github.com/aws/amazon-ssm-agent/agent/network" mgsConfig "github.com/aws/amazon-ssm-agent/agent/session/config" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, } var dialerInput = &websocket.Dialer{ TLSClientConfig: network.GetDefaultTLSConfig(log.NewMockLog(), appconfig.DefaultConfig()), Proxy: http.ProxyFromEnvironment, WriteBufferSize: mgsConfig.ControlChannelWriteBufferSizeLimit, } func handlerToBeTested(w http.ResponseWriter, req *http.Request) { conn, err := upgrader.Upgrade(w, req, nil) if err != nil { http.Error(w, fmt.Sprintf("cannot upgrade: %v", err), http.StatusInternalServerError) } mt, p, err := conn.ReadMessage() if err != nil { return } conn.WriteMessage(mt, []byte("hello "+string(p))) } func TestWebsocketUtilOpenCloseConnection(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(handlerToBeTested)) u, _ := url.Parse(srv.URL) u.Scheme = "ws" var log = log.NewMockLog() appConfig := appconfig.SsmagentConfig{} var ws = NewWebsocketUtil(log, appConfig, nil) conn, _ := ws.OpenConnection(u.String(), http.Header{}) assert.NotNil(t, conn, "Open connection failed.") err := ws.CloseConnection(conn) assert.Nil(t, err, "Error closing the websocket connection.") } func TestWebsocketUtilOpenCloseConnectionWithWriteBuffer(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(handlerToBeTested)) u, _ := url.Parse(srv.URL) u.Scheme = "ws" var log = log.NewMockLog() appConfig := appconfig.SsmagentConfig{} var ws = NewWebsocketUtil(log, appConfig, dialerInput) conn, _ := ws.OpenConnection(u.String(), http.Header{}) assert.NotNil(t, conn, "Open connection failed.") err := ws.CloseConnection(conn) assert.Nil(t, err, "Error closing the websocket connection.") } func TestWebsocketUtilOpenConnectionInvalidUrl(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(handlerToBeTested)) u, _ := url.Parse(srv.URL) u.Scheme = "ws" var log = log.NewMockLog() appConfig := appconfig.SsmagentConfig{} var ws = NewWebsocketUtil(log, appConfig, nil) conn, _ := ws.OpenConnection("InvalidUrl", http.Header{}) assert.Nil(t, conn, "Open connection failed.") } func TestSendMessage(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(handlerToBeTested)) u, _ := url.Parse(srv.URL) u.Scheme = "ws" var log = log.NewMockLog() appConfig := appconfig.SsmagentConfig{} var ws = NewWebsocketUtil(log, appConfig, nil) conn, _ := ws.OpenConnection(u.String(), http.Header{}) assert.NotNil(t, conn, "Open connection failed.") err := conn.WriteMessage(websocket.TextMessage, []byte("testing testing")) assert.Nil(t, err) } func TestSendMessageWithWriteBuffer(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(handlerToBeTested)) u, _ := url.Parse(srv.URL) u.Scheme = "ws" var log = log.NewMockLog() appConfig := appconfig.SsmagentConfig{} var ws = NewWebsocketUtil(log, appConfig, dialerInput) conn, _ := ws.OpenConnection(u.String(), http.Header{}) assert.NotNil(t, conn, "Open connection failed.") err := conn.WriteMessage(websocket.TextMessage, []byte("testing testing")) assert.Nil(t, err) }