# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the
# License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import time
from datetime import datetime, timezone

from assertpy import assert_that, soft_assertions
from slurm_plugin.common import TaskController
from slurm_plugin.task_executor import TaskExecutor


def test_task_executor():
    """Queue ten tasks on a three-worker executor and check each future yields its input plus one."""

    def get_task(value):
        def task():
            return value + 1

        return task

    task_executor = TaskExecutor(worker_pool_size=3, max_backlog=10)
    try:
        futures = {value: task_executor.queue_task(get_task(value)) for value in range(10, 20)}

        # Soft assertions collect every mismatch instead of stopping at the first one.
        with soft_assertions():
            for value, future in futures.items():
                assert_that(future.result()).is_equal_to(value + 1)
    finally:
        # Shut down even when an assertion fails so the executor is not left running.
        task_executor.shutdown()


def test_exceeding_max_backlog():
    """Queuing a second task while the only backlog slot is taken must raise MaximumBacklogExceededError."""

    def get_task(value):
        def task():
            time.sleep(value)
            return value + 1

        return task

    task_executor = TaskExecutor(worker_pool_size=1, max_backlog=1)
    try:
        # The first task sleeps long enough to still be in flight when the second one is offered.
        future = task_executor.queue_task(get_task(10))
        assert_that(task_executor.queue_task).raises(TaskExecutor.MaximumBacklogExceededError).when_called_with(
            get_task(20)
        )

        # The rejected task must not disturb the one that was accepted.
        assert_that(future.result()).is_equal_to(11)
    finally:
        # Shut down even when an assertion fails so the executor is not left running.
        task_executor.shutdown()


def test_that_shutdown_does_not_block():
    """shutdown(wait=True) must return well before a task's 600-second wait_unless_shutdown() would expire.

    The test also expects the interrupted task's future to end with
    TaskController.TaskShutdownError and its done-callback to have fired.
    """

    def get_task(value):
        def task():
            task_executor.wait_unless_shutdown(value)
            return value + 1

        return task

    def callback(*args):
        nonlocal callback_called
        callback_called = True

    task_executor = TaskExecutor(worker_pool_size=1, max_backlog=1)

    callback_called = False
    start_wait = datetime.now(tz=timezone.utc)
    future = task_executor.queue_task(get_task(600))
    future.add_done_callback(callback)
    task_executor.shutdown(wait=True)
    delta = (datetime.now(tz=timezone.utc) - start_wait).total_seconds()
    assert_that(delta).is_less_than(300)

    # assertpy's raises() only records the expected exception type; unless it is chained to
    # when_called_with() the callable is never invoked and nothing is checked. Inspect the
    # exception stored on the future directly so the expectation is actually enforced.
    assert_that(future.exception()).is_instance_of(TaskController.TaskShutdownError)
    assert_that(callback_called).is_true()