#################################################################################
#   Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.          #
#                                                                               #
#   Licensed under the Apache License, Version 2.0 (the "License").             #
#   You may not use this file except in compliance with the License.            #
#   You may obtain a copy of the License at                                     #
#                                                                               #
#       http://www.apache.org/licenses/LICENSE-2.0                              #
#                                                                               #
#   Unless required by applicable law or agreed to in writing, software         #
#   distributed under the License is distributed on an "AS IS" BASIS,           #
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.    #
#   See the License for the specific language governing permissions and         #
#   limitations under the License.                                              #
#################################################################################
from unittest import mock, TestCase
from unittest.mock import patch, MagicMock

from ude_gym_bridge.gym_environment_adapter import GymEnvironmentAdapter

from gym.spaces.space import Space


@mock.patch("gym.make")
class GymEnvironmentAdapterTest(TestCase):
    """Unit tests for GymEnvironmentAdapter.

    ``gym.make`` is patched for every test method, so each test receives the
    patch object as ``gym_make_mock`` and the adapter under test wraps
    ``gym_make_mock.return_value`` instead of a real gym environment.
    """

    def setUp(self) -> None:
        """No shared fixtures are needed; each test builds its own adapter."""

    def test_initialization(self, gym_make_mock):
        """Construction with an explicit agent name.

        Expects one side channel to be created, the wrapped environment to be
        reset once, and the agent name / environment to be stored.
        """
        env_name = "test_env"
        gym_env_mock_obj = gym_make_mock.return_value

        with patch(
            "ude_gym_bridge.gym_environment_adapter.SingleSideChannel"
        ) as side_channel_mock:
            gym_env_adapter = GymEnvironmentAdapter(env_name, agent_name="agent")
            side_channel_mock.assert_called_once()

        gym_env_mock_obj.reset.assert_called_once()
        assert gym_env_adapter._agent_name == "agent"
        assert gym_env_adapter.env == gym_env_mock_obj

    def test_initialization_without_agent_name(self, gym_make_mock):
        """Construction without an agent name falls back to ``"agent0"``."""
        env_name = "test_env"
        gym_env_mock_obj = gym_make_mock.return_value

        gym_env_adapter = GymEnvironmentAdapter(env_name)

        gym_env_mock_obj.reset.assert_called_once()
        assert gym_env_adapter._agent_name == "agent0"
        assert gym_env_adapter.env == gym_env_mock_obj

    def test_setters(self, gym_make_mock):
        """Assigning ``env_name`` only takes effect on the next ``reset()``.

        Until ``reset()`` is called the adapter must keep reporting the old
        name and environment; afterwards the old environment is closed and the
        newly made one is reset and exposed.
        """
        env_name = "test_env"
        gym_env_mock_obj = gym_make_mock.return_value

        gym_env_adapter = GymEnvironmentAdapter(env_name)

        gym_env_mock_obj.reset.assert_called_once()
        assert gym_env_adapter._agent_name == "agent0"
        assert gym_env_adapter.env == gym_env_mock_obj

        new_gym_name = "new_test_env"
        new_env_mock = MagicMock()
        # From here on gym.make hands out a different environment object.
        gym_make_mock.return_value = new_env_mock

        gym_env_adapter.env_name = new_gym_name
        # The change is pending: nothing observable has switched yet.
        assert gym_env_adapter.env_name != new_gym_name
        assert gym_env_adapter.env != new_env_mock

        gym_env_adapter.reset()

        gym_env_mock_obj.close.assert_called_once()
        new_env_mock.reset.assert_called_once()
        assert gym_env_adapter.env_name == new_gym_name
        assert gym_env_adapter.env == new_env_mock

    def test_step(self, gym_make_mock):
        """``step`` unwraps the agent's action and re-keys the result by agent.

        The wrapped environment must receive the bare action, and the adapter
        must return per-agent dicts for observation, reward, done and the
        action taken, followed by the untouched ``info``.
        """
        env_name = "test_env"
        gym_env_mock_obj = gym_make_mock.return_value
        agent_name = "agent0"
        agent_action = 1
        action_dict = {agent_name: agent_action}

        next_state = "next_state"
        done = False
        reward = 42
        info = {}
        gym_env_step_return = (next_state, reward, done, info)
        gym_env_mock_obj.step.return_value = gym_env_step_return

        gym_env_adapter = GymEnvironmentAdapter(env_name, agent_name=agent_name)
        ret_step_val = gym_env_adapter.step(action_dict=action_dict)

        expected_return = (
            {agent_name: next_state},
            {agent_name: reward},
            {agent_name: done},
            {agent_name: agent_action},
            info,
        )
        assert ret_step_val == expected_return
        gym_env_mock_obj.step.assert_called_once_with(agent_action)

    def test_reset(self, gym_make_mock):
        """``reset`` returns the observation keyed by agent plus an empty dict."""
        env_name = "test_env"
        gym_env_mock_obj = gym_make_mock.return_value
        agent_name = "agent0"

        next_state = "next_state"
        gym_env_mock_obj.reset.return_value = next_state

        gym_env_adapter = GymEnvironmentAdapter(env_name, agent_name=agent_name)
        ret_reset_val = gym_env_adapter.reset()

        expected_return = ({agent_name: next_state}, {})
        assert ret_reset_val == expected_return
        # One reset happens during construction, one from the explicit call.
        assert gym_env_mock_obj.reset.call_count == 2

    def test_close(self, gym_make_mock):
        """``close`` is forwarded to the wrapped environment exactly once."""
        env_name = "test_env"
        gym_env_mock_obj = gym_make_mock.return_value

        gym_env_adapter = GymEnvironmentAdapter(env_name)
        gym_env_adapter.close()

        gym_env_mock_obj.close.assert_called_once()

    def test_observation_space(self, gym_make_mock):
        """``observation_space`` exposes the env's space keyed by agent name."""
        env_name = "test_env"
        gym_env_mock_obj = gym_make_mock.return_value
        agent_name = "agent0"

        observation_space = Space([3, 4])
        gym_env_mock_obj.observation_space = observation_space

        gym_env_adapter = GymEnvironmentAdapter(env_name, agent_name=agent_name)
        ret_observation_space_val = gym_env_adapter.observation_space

        expected_return = {agent_name: observation_space}
        assert ret_observation_space_val == expected_return

    def test_action_space(self, gym_make_mock):
        """``action_space`` exposes the env's space keyed by agent name."""
        env_name = "test_env"
        gym_env_mock_obj = gym_make_mock.return_value
        agent_name = "agent0"

        action_space = Space([3, 4])
        gym_env_mock_obj.action_space = action_space

        gym_env_adapter = GymEnvironmentAdapter(env_name, agent_name=agent_name)
        ret_action_space_val = gym_env_adapter.action_space

        expected_return = {agent_name: action_space}
        assert ret_action_space_val == expected_return

    def test_side_channel(self, gym_make_mock):
        """``side_channel`` returns the SingleSideChannel built at construction."""
        env_name = "test_env"

        with patch(
            "ude_gym_bridge.gym_environment_adapter.SingleSideChannel"
        ) as side_channel_mock:
            gym_env_adapter = GymEnvironmentAdapter(env_name)
            side_channel_mock.assert_called_once()

        assert gym_env_adapter.side_channel == side_channel_mock.return_value

    def test_on_received(self, gym_make_mock):
        """An ``"env"`` side-channel message switches the env on next reset.

        Receiving the message alone must not change the active environment or
        its name; the following ``reset()`` closes the old environment, and
        the newly made one is reset and becomes active.
        """
        env_name = "CartPole-v0"
        gym_env_mock_obj = gym_make_mock.return_value

        gym_env_adapter = GymEnvironmentAdapter(env_name)
        assert gym_env_adapter.env == gym_env_mock_obj

        new_env_mock = MagicMock()
        gym_make_mock.return_value = new_env_mock
        assert gym_env_adapter.env != new_env_mock

        new_env_name = "CartPole-v1"
        gym_env_adapter.on_received(
            side_channel=gym_env_adapter.side_channel,
            key="env",
            value=new_env_name,
        )
        # The request is pending until reset(): nothing has switched yet.
        assert gym_env_adapter.env == gym_env_mock_obj
        assert gym_env_adapter.env != new_env_mock
        assert gym_env_adapter.env_name == env_name

        gym_env_adapter.reset()

        gym_env_mock_obj.close.assert_called_once()
        assert gym_env_adapter.env == new_env_mock
        assert gym_env_adapter.env_name == new_env_name
        # Must be `reset.assert_called_once()`: the previous spelling
        # `reset_assert_called_once()` merely invoked an auto-created mock
        # attribute and therefore verified nothing.
        new_env_mock.reset.assert_called_once()

    def test_on_received_bad_env_name(self, gym_make_mock):
        """An ``"env"`` message with an unusable name leaves the env untouched.

        After the message and a subsequent ``reset()`` the original
        environment must still be active and must not have been closed.
        """
        env_name = "CartPole-v0"
        gym_env_mock_obj = gym_make_mock.return_value

        gym_env_adapter = GymEnvironmentAdapter(env_name)
        assert gym_env_adapter.env == gym_env_mock_obj

        new_env_mock = MagicMock()
        gym_make_mock.return_value = new_env_mock
        assert gym_env_adapter.env != new_env_mock

        new_env_name = "bad-test-env"
        gym_env_adapter.on_received(
            side_channel=gym_env_adapter.side_channel,
            key="env",
            value=new_env_name,
        )
        assert gym_env_adapter.env == gym_env_mock_obj
        assert gym_env_adapter.env != new_env_mock
        assert gym_env_adapter.env_name == env_name

        gym_env_adapter.reset()

        gym_env_mock_obj.close.assert_not_called()
        assert gym_env_adapter.env == gym_env_mock_obj
        assert gym_env_adapter.env_name == env_name