"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
"""

import os
import shlex
import subprocess
import unittest

from graph_notebook.configuration.generate_config import AuthModeEnum, Configuration
from graph_notebook.configuration.get_config import get_config


class TestGenerateConfigurationMain(unittest.TestCase):
    """Exercise the ``generate_config`` module through its command-line entry point.

    Every test runs ``python -m graph_notebook.configuration.generate_config`` in a
    child process, checks that it exits with status 0, loads the file it wrote with
    ``get_config`` and compares the result to the expected ``Configuration``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Define the hosts, port, output path and interpreter shared by all tests."""
        cls.generic_host = 'blah'
        cls.neptune_host_reg = 'instance.cluster.us-west-2.neptune.amazonaws.com'
        cls.neptune_host_cn = 'instance.cluster.neptune.cn-north-1.amazonaws.com.cn'
        cls.neptune_host_custom = 'localhost'
        cls.port = 8182
        cls.test_file_path = os.path.join(os.path.abspath(os.path.curdir), 'test_generate_from_main.json')
        cls.custom_hosts_list = ['localhost']
        # Environment variable to let ToD hosts specify where the python command
        # is that is being used for testing.
        cls.python_cmd = os.environ.get('PYTHON_CMD', 'python3')

    def tearDown(self) -> None:
        """Delete the generated configuration file so tests do not see each other's output."""
        try:
            os.remove(self.test_file_path)
        except FileNotFoundError:
            # The test failed before the file was written; nothing to clean up.
            pass

    def test_generate_configuration_main_defaults_neptune_reg(self):
        expected_config = Configuration(self.neptune_host_reg, self.port, auth_mode=AuthModeEnum.DEFAULT,
                                        load_from_s3_arn='', ssl=True)
        self.generate_config_from_main_and_test(expected_config, host_type='neptune')

    def test_generate_configuration_main_defaults_neptune_cn(self):
        expected_config = Configuration(self.neptune_host_cn, self.port, auth_mode=AuthModeEnum.DEFAULT,
                                        load_from_s3_arn='', ssl=True)
        self.generate_config_from_main_and_test(expected_config, host_type='neptune')

    def test_generate_configuration_main_defaults_generic(self):
        expected_config = Configuration(self.generic_host, self.port, ssl=True)
        self.generate_config_from_main_and_test(expected_config)

    def test_generate_configuration_main_override_defaults_neptune_reg(self):
        expected_config = Configuration(self.neptune_host_reg, self.port, auth_mode=AuthModeEnum.IAM,
                                        load_from_s3_arn='loader_arn', ssl=False)
        self.generate_config_from_main_and_test(expected_config, host_type='neptune')

    def test_generate_configuration_main_override_defaults_neptune_no_verify(self):
        expected_config = Configuration(self.neptune_host_reg, self.port, auth_mode=AuthModeEnum.IAM,
                                        load_from_s3_arn='loader_arn', ssl=True, ssl_verify=False)
        self.generate_config_from_main_and_test(expected_config, host_type='neptune')

    def test_generate_configuration_main_override_defaults_neptune_cn(self):
        expected_config = Configuration(self.neptune_host_cn, self.port, auth_mode=AuthModeEnum.IAM,
                                        load_from_s3_arn='loader_arn', ssl=False)
        self.generate_config_from_main_and_test(expected_config, host_type='neptune')

    def test_generate_configuration_main_override_defaults_generic(self):
        expected_config = Configuration(self.generic_host, self.port, ssl=False)
        self.generate_config_from_main_and_test(expected_config)

    def test_generate_configuration_main_empty_args_neptune(self):
        expected_config = Configuration(self.neptune_host_reg, self.port)
        cli_args = [
            '--host', f'{expected_config.host}',
            '--port', f'{expected_config.port}',
            '--auth_mode', '',
            '--ssl', '',
            '--load_from_s3_arn', '',
            f'--config_destination={self.test_file_path}',
        ]
        self._assert_main_generates(expected_config, cli_args)

    def test_generate_configuration_main_empty_args_custom(self):
        expected_config = Configuration(self.neptune_host_custom, self.port, neptune_hosts=self.custom_hosts_list)
        cli_args = [
            '--host', f'{expected_config.host}',
            '--port', f'{expected_config.port}',
            '--auth_mode', '',
            '--ssl', '',
            '--load_from_s3_arn', '',
            f'--config_destination={self.test_file_path}',
            '--neptune_hosts', f'{self.custom_hosts_list[0]}',
        ]
        # The same custom host list is handed to get_config when reading the file back.
        self._assert_main_generates(expected_config, cli_args, neptune_hosts=self.custom_hosts_list)

    def test_generate_configuration_main_empty_args_generic(self):
        expected_config = Configuration(self.generic_host, self.port)
        cli_args = [
            '--host', f'{expected_config.host}',
            '--port', f'{expected_config.port}',
            '--ssl', '',
            f'--config_destination={self.test_file_path}',
        ]
        self._assert_main_generates(expected_config, cli_args)

    def generate_config_from_main_and_test(self, source_config: Configuration, host_type=None):
        """Generate a config file from ``source_config`` via the CLI and verify the round trip.

        This will run the main method that our install script runs on a Sagemaker notebook.
        The return code should be 0, but more importantly, we need to assert that the
        Configuration object we get from the resulting file is what we expect.

        :param source_config: configuration whose fields are passed as command-line flags
            and which the reloaded configuration must equal.
        :param host_type: ``'neptune'`` additionally passes ``--auth_mode`` and
            ``--load_from_s3_arn``; any other value uses the generic flag set.
        """
        cli_args = ['--host', f'{source_config.host}', '--port', f'{source_config.port}']
        if host_type == 'neptune':
            cli_args += [
                '--auth_mode', f'{source_config.auth_mode.value}',
                '--ssl', f'{source_config.ssl}',
                '--ssl-verify', f'{source_config.ssl_verify}',
                '--load_from_s3_arn', f'{source_config.load_from_s3_arn}',
                '--proxy_host', f'{source_config.proxy_host}',
                '--proxy_port', f'{source_config.proxy_port}',
            ]
        else:
            cli_args += [
                '--proxy_host', f'{source_config.proxy_host}',
                '--proxy_port', f'{source_config.proxy_port}',
                '--ssl', f'{source_config.ssl}',
                '--ssl-verify', f'{source_config.ssl_verify}',
            ]
        cli_args.append(f'--config_destination={self.test_file_path}')
        self._assert_main_generates(source_config, cli_args)

    def _run_generate_config(self, cli_args) -> int:
        """Run the ``generate_config`` module with ``cli_args`` and return its exit code.

        The command is passed to ``subprocess.run`` as an argument list, so no shell is
        involved and argument values need no quoting or escaping.
        """
        # python_cmd is split into words as the shell used to do, so a value carrying
        # interpreter options keeps working. Non-POSIX splitting is used on Windows so
        # backslashes in an interpreter path are not consumed as escape characters.
        # NOTE(review): shell expansions inside PYTHON_CMD ($VAR, ~) are not performed.
        interpreter = shlex.split(self.python_cmd, posix=(os.name != 'nt'))
        command = [*interpreter, '-m', 'graph_notebook.configuration.generate_config', *cli_args]
        return subprocess.run(command, check=False).returncode

    def _assert_main_generates(self, expected_config: Configuration, cli_args, **get_config_kwargs):
        """Assert the CLI exits with 0 and that the written file loads as ``expected_config``.

        ``get_config_kwargs`` are forwarded to ``get_config`` when reading the file back.
        """
        exit_code = self._run_generate_config(cli_args)
        self.assertEqual(0, exit_code)
        config = get_config(self.test_file_path, **get_config_kwargs)
        self.assertEqual(expected_config.to_dict(), config.to_dict())