# Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import botocore.session from botocore.exceptions import ClientError from awscli.testutils import mock, unittest from awscli.compat import StringIO from awscli.customizations.configservice.subscribe import SubscribeCommand, \ S3BucketHelper, SNSTopicHelper class TestS3BucketHelper(unittest.TestCase): def setUp(self): self.session = botocore.session.get_session() self.s3_client = mock.Mock(self.session.create_client('s3')) self.helper = S3BucketHelper(self.s3_client) self.error_response = { 'Error': { 'Code': '404', 'Message': 'Not Found' } } self.bucket_no_exists_error = ClientError( self.error_response, 'HeadBucket' ) def test_correct_prefix_returned(self): name = 'MyBucket/MyPrefix' bucket, prefix = self.helper.prepare_bucket(name) # Ensure the returned bucket and key are as expected self.assertEqual(bucket, 'MyBucket') self.assertEqual(prefix, 'MyPrefix') def test_bucket_exists(self): name = 'MyBucket' bucket, prefix = self.helper.prepare_bucket(name) # A new bucket should not have been created because no error was thrown self.assertFalse(self.s3_client.create_bucket.called) # Ensure the returned bucket and key are as expected self.assertEqual(bucket, name) self.assertEqual(prefix, '') def test_bucket_no_exist(self): name = 'MyBucket/MyPrefix' self.s3_client.head_bucket.side_effect = self.bucket_no_exists_error self.s3_client.meta.region_name = 'us-east-1' bucket, prefix = self.helper.prepare_bucket(name) # Ensure that the create bucket was called with the proper args. self.s3_client.create_bucket.assert_called_with( Bucket='MyBucket' ) # Ensure the returned bucket and key are as expected self.assertEqual(bucket, 'MyBucket') self.assertEqual(prefix, 'MyPrefix') def test_bucket_no_exist_with_location_constraint(self): name = 'MyBucket/MyPrefix' self.s3_client.head_bucket.side_effect = self.bucket_no_exists_error self.s3_client.meta.region_name = 'us-west-2' bucket, prefix = self.helper.prepare_bucket(name) # Ensure that the create bucket was called with the proper args. self.s3_client.create_bucket.assert_called_with( Bucket='MyBucket', CreateBucketConfiguration={'LocationConstraint': 'us-west-2'} ) # Ensure the returned bucket and key are as expected self.assertEqual(bucket, 'MyBucket') self.assertEqual(prefix, 'MyPrefix') def test_bucket_client_exception_non_404(self): name = 'MyBucket/MyPrefix' self.error_response['Error']['Code'] = '403' self.error_response['Error']['Message'] = 'Forbidden' forbidden_error = ClientError(self.error_response, 'HeadBucket') self.s3_client.head_bucket.side_effect = forbidden_error self.s3_client._endpoint.region_name = 'us-east-1' bucket, prefix = self.helper.prepare_bucket(name) # A new bucket should not have been created because a 404 error # was not thrown self.assertFalse(self.s3_client.create_bucket.called) # Ensure the returned bucket and key are as expected self.assertEqual(bucket, 'MyBucket') self.assertEqual(prefix, 'MyPrefix') def test_output_use_existing_bucket(self): name = 'MyBucket/MyPrefix' with mock.patch('sys.stdout', StringIO()) as mock_stdout: self.helper.prepare_bucket(name) self.assertIn( 'Using existing S3 bucket: MyBucket', mock_stdout.getvalue()) def test_output_create_bucket(self): name = 'MyBucket/MyPrefix' self.s3_client.head_bucket.side_effect = self.bucket_no_exists_error self.s3_client._endpoint.region_name = 'us-east-1' with mock.patch('sys.stdout', StringIO()) as mock_stdout: self.helper.prepare_bucket(name) self.assertIn( 'Using new S3 bucket: MyBucket', mock_stdout.getvalue()) class TestSNSTopicHelper(unittest.TestCase): def setUp(self): self.session = botocore.session.get_session() self.sns_client = mock.Mock(self.session.create_client( 'sns', 'us-east-1')) self.helper = SNSTopicHelper(self.sns_client) def test_sns_topic_by_name(self): name = 'mysnstopic' self.sns_client.create_topic.return_value = {'TopicArn': 'myARN'} sns_arn = self.helper.prepare_topic(name) # Ensure that the topic was create and returned the expected arn self.assertTrue(self.sns_client.create_topic.called) self.assertEqual(sns_arn, 'myARN') def test_sns_topic_by_arn(self): name = 'arn:aws:sns:us-east-1:934212987125:config' sns_arn = self.helper.prepare_topic(name) # Ensure that the topic was not created and returned the expected arn self.assertFalse(self.sns_client.create_topic.called) self.assertEqual(sns_arn, name) def test_output_existing_topic(self): name = 'mysnstopic' self.sns_client.create_topic.return_value = {'TopicArn': 'myARN'} with mock.patch('sys.stdout', StringIO()) as mock_stdout: self.helper.prepare_topic(name) self.assertIn( 'Using new SNS topic: myARN', mock_stdout.getvalue()) def test_output_new_topic(self): name = 'arn:aws:sns:us-east-1:934212987125:config' with mock.patch('sys.stdout', StringIO()) as mock_stdout: self.helper.prepare_topic(name) self.assertIn( 'Using existing SNS topic: %s' % name, mock_stdout.getvalue()) class TestSubscribeCommand(unittest.TestCase): def setUp(self): self.session = botocore.session.get_session() # Set up the client mocks. self.s3_client = mock.Mock(self.session.create_client('s3')) self.sns_client = mock.Mock(self.session.create_client( 'sns', 'us-east-1')) self.config_client = mock.Mock(self.session.create_client( 'config', 'us-east-1')) self.config_client.describe_configuration_recorders.return_value = \ {'ConfigurationRecorders': []} self.config_client.describe_delivery_channels.return_value = \ {'DeliveryChannels': []} self.session = mock.Mock(self.session) self.session.create_client.side_effect = [ self.s3_client, self.sns_client, self.config_client ] self.parsed_args = mock.Mock() self.parsed_args.s3_bucket = 'MyBucket/MyPrefix' self.parsed_args.sns_topic = \ 'arn:aws:sns:us-east-1:934212987125:config' self.parsed_args.iam_role = 'arn:aws:iam::1234556789:role/config' self.parsed_globals = mock.Mock() self.cmd = SubscribeCommand(self.session) def test_setup_clients(self): self.parsed_globals.verify_ssl = True self.parsed_globals.region = 'us-east-1' self.parsed_globals.endpoint_url = 'http://myendpoint.com' self.cmd._run_main(self.parsed_args, self.parsed_globals) # Check to see that the clients were created correctly self.session.create_client.assert_any_call( 's3', verify=self.parsed_globals.verify_ssl, region_name=self.parsed_globals.region, ) self.session.create_client.assert_any_call( 'sns', verify=self.parsed_globals.verify_ssl, region_name=self.parsed_globals.region, ) self.session.create_client.assert_any_call( 'config', verify=self.parsed_globals.verify_ssl, region_name=self.parsed_globals.region, endpoint_url=self.parsed_globals.endpoint_url ) def test_subscribe(self): self.cmd._run_main(self.parsed_args, self.parsed_globals) # Check the call made to put configuration recorder. self.config_client.put_configuration_recorder.assert_called_with( ConfigurationRecorder={ 'name': 'default', 'roleARN': self.parsed_args.iam_role } ) # Check the call made to put delivery channel. self.config_client.put_delivery_channel.assert_called_with( DeliveryChannel={ 'name': 'default', 's3BucketName': 'MyBucket', 'snsTopicARN': self.parsed_args.sns_topic, 's3KeyPrefix': 'MyPrefix' } ) # Check the call made to start configuration recorder. self.config_client.start_configuration_recorder.assert_called_with( ConfigurationRecorderName='default' ) # Check that the describe delivery channel and configuration recorder # methods were called. self.assertTrue( self.config_client.describe_configuration_recorders.called ) self.assertTrue(self.config_client.describe_delivery_channels.called)