# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for the ``submit`` click command in ``smspark.cli``."""
import pathlib
import tempfile
from dataclasses import dataclass
from typing import Any, Callable, Generator, List, Type, Union
from unittest.mock import patch

import click
import pytest
from click.testing import CliRunner
from smspark.cli import submit, submit_main
from smspark.errors import InputError
from smspark.job import ProcessingJobManager

# Prefix every successful test case expects at the front of the submitted command.
_SPARK_SUBMIT_PREFIX = "spark-submit --master yarn --deploy-mode client "


@pytest.fixture
def tempdir() -> Generator[tempfile.TemporaryDirectory, None, None]:
    """Yield a ``TemporaryDirectory`` object that is cleaned up after the test."""
    directory = tempfile.TemporaryDirectory()
    # Entering the object returns only its name, so yield the object itself.
    with directory:
        yield directory


@pytest.fixture
def empty_tempdir_path() -> Generator[pathlib.Path, None, None]:
    """Yield the resolved path of a fresh, empty temporary directory."""
    with tempfile.TemporaryDirectory() as dir_name:
        yield pathlib.Path(dir_name).resolve()


@pytest.fixture
def jars_dir(tempdir: tempfile.TemporaryDirectory) -> str:
    """Return the name of the temporary directory in which the jar fixtures are created."""
    return tempdir.name


@pytest.fixture
def jar_file(jars_dir: str) -> Generator[pathlib.Path, None, None]:
    """Yield the resolved path of a temporary ``1*.jar`` file inside ``jars_dir``."""
    with tempfile.NamedTemporaryFile(dir=jars_dir, prefix="1", suffix=".jar") as tmp:
        yield pathlib.Path(tmp.name).resolve()


@pytest.fixture
def other_jar_file(jars_dir: str) -> Generator[pathlib.Path, None, None]:
    """Yield the resolved path of a temporary ``2*.jar`` file inside ``jars_dir``."""
    with tempfile.NamedTemporaryFile(dir=jars_dir, prefix="2", suffix=".jar") as tmp:
        yield pathlib.Path(tmp.name).resolve()


@dataclass
class SubmitTest:
    """Represents data for one submit test.

    Attributes:
        name: Human-readable description, used as the pytest id.
        args: Command-line arguments passed to ``submit``. May contain the
            ``{jar_file}``, ``{other_jar_file}`` and ``{empty_tempdir_path}``
            placeholders, which ``test_submit`` fills in from fixtures.
        expected_cmd: Either the command string the job manager's ``run`` is
            expected to receive (placeholders allowed), or the exception type
            the invocation is expected to fail with.
    """

    name: str
    args: str
    expected_cmd: Union[str, Type[BaseException]]


def _passing_case(name: str, args: str) -> SubmitTest:
    """Build a case whose expected command is the spark-submit prefix followed by ``args``."""
    return SubmitTest(name=name, args=args, expected_cmd=_SPARK_SUBMIT_PREFIX + args)


def _file_option_cases(option: str) -> List[SubmitTest]:
    """Build the path-handling cases for one file-list option.

    Args:
        option: The option under test, e.g. ``"--jars"``.

    Returns:
        The passing cases followed by the failing ones. Each name is prefixed
        with ``option`` so ids stay unique when several options are tested.
    """
    passing = [
        ("single local jar should pass", "{jar_file}"),
        ("list of local jars should pass", "{jar_file},{other_jar_file}"),
        ("s3 url to jar should pass", "s3://bucket/to/jar1.jar"),
        ("s3a url to jar should pass", "s3a://bucket/to/jar1.jar"),
        ("multiple s3 urls to jar should pass", "s3://bucket/to/jar1.jar,s3://bucket/to/jar2.jar"),
        ("mixed s3 urls to jars and local paths should pass", "s3://bucket/to/jar1.jar,{jar_file}"),
    ]
    failing = [
        ("relative paths should fail", "relative/path/to/jar.jar"),
        ("nonexistent paths should fail", "/path/to/nonexistent/file"),
        ("directory with no files should fail", "{empty_tempdir_path}"),
    ]

    cases = [_passing_case(f"{option}: {name}", f"{option} {value} app.jar") for name, value in passing]
    cases.extend(
        SubmitTest(name=f"{option}: {name}", args=f"{option} {value} app.jar", expected_cmd=InputError)
        for name, value in failing
    )
    return cases


def get_test_cases() -> List[SubmitTest]:
    """Return every ``submit`` test case.

    The order is: general argument-handling cases, then the path-handling
    cases for ``--jars``, ``--files`` and ``--py-files``, then the quoting case.
    """
    test_cases = [
        SubmitTest(
            name="missing APP arg should fail",
            args="",
            expected_cmd=click.exceptions.MissingParameter,
        ),
        SubmitTest(
            name="invalid spark options should fail",
            args="--invalid-spark-option opt arg.py",
            expected_cmd=click.exceptions.NoSuchOption,
        ),
        _passing_case("happy path should pass", "app.py"),
        _passing_case("valid spark option should pass", "--class com.app.Main app.jar"),
    ]

    for option in ["--jars", "--files", "--py-files"]:
        test_cases.extend(_file_option_cases(option))

    # Quote tests:
    quoted_args = (
        "--jars {jar_file} myscript.py"
        " --query-bbox 'BBOX(geometry,0.1,-0.2,3.3,4.5)'"
        " --query-start-date '2020-05-01 00:00:00'"
        " --query-end-date '2020-05-31 23:59:59'"
        " --data-location-uri s3://123456789012-us-west-2/path"
        " --bucket-region us-west-2"
        " --output-folder-path /opt/ml/processing/output/out"
        " --output-stage gamma"
        " --storage-bucket-uri s3://123456789012-us-west-2/"
    )
    test_cases.append(_passing_case("quotes are handled correctly", quoted_args))

    return test_cases


test_cases = get_test_cases()


@patch("smspark.cli.ProcessingJobManager")
@pytest.mark.parametrize("test_case", test_cases, ids=[submit_test.name for submit_test in test_cases])
def test_submit(
    patched_processing_job_manager: ProcessingJobManager,
    test_case: SubmitTest,
    jar_file: pathlib.Path,
    other_jar_file: pathlib.Path,
    empty_tempdir_path: pathlib.Path,
) -> None:
    """Invoke ``submit`` with the case's arguments and check the outcome.

    For a string ``expected_cmd`` the invocation must succeed and the patched
    job manager's ``run`` must be called exactly once with that command; for an
    exception type the invocation must fail with an instance of that type.

    ``patched_processing_job_manager`` is the mock that ``patch`` substitutes
    for ``smspark.cli.ProcessingJobManager``.
    """
    runner = CliRunner()
    # One mapping for both templates so args and expected command cannot drift apart.
    substitutions = {
        "jar_file": jar_file,
        "other_jar_file": other_jar_file,
        "empty_tempdir_path": empty_tempdir_path,
    }
    args = test_case.args.format(**substitutions)

    result = runner.invoke(submit, args, standalone_mode=False)

    # happy
    if isinstance(test_case.expected_cmd, str):
        expected_cmd = test_case.expected_cmd.format(**substitutions)
        assert result.exception is None, result.output
        assert result.exit_code == 0
        patched_processing_job_manager.assert_called_once()
        patched_processing_job_manager.return_value.run.assert_called_once_with(expected_cmd, None, None)
    # sad
    else:
        assert result.exit_code != 0, result.output
        assert isinstance(result.exception, test_case.expected_cmd)