/* * SPDX-License-Identifier: Apache-2.0 * * The OpenSearch Contributors require contributions made to * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ /* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /* * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ package org.opensearch.cluster.service; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.cluster.metadata.ProcessClusterEventTimeoutException; import org.opensearch.common.Priority; import org.opensearch.common.collect.Tuple; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor; import org.junit.Before; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Semaphore; import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasToString; public class TaskBatcherTests extends TaskExecutorTests { protected TestTaskBatcher taskBatcher; @Before public void setUpBatchingTaskExecutor() throws Exception { taskBatcher = new TestTaskBatcher(logger, threadExecutor); } static class TestTaskBatcher extends TaskBatcher { TestTaskBatcher(Logger logger, PrioritizedOpenSearchThreadPoolExecutor threadExecutor) { super(logger, threadExecutor, getMockListener()); } @Override protected void run(Object batchingKey, List<? extends BatchedTask> tasks, String tasksSummary) { List<UpdateTask> updateTasks = (List) tasks; ((TestExecutor) batchingKey).execute(updateTasks.stream().map(t -> t.task).collect(Collectors.toList())); updateTasks.forEach(updateTask -> updateTask.listener.processed(updateTask.source)); } @Override protected void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout) { threadPool.generic() .execute( () -> tasks.forEach( task -> ((UpdateTask) task).listener.onFailure( task.source, new ProcessClusterEventTimeoutException(timeout, task.source) ) ) ); } class UpdateTask extends BatchedTask { final TestListener listener; UpdateTask(Priority priority, String source, Object task, TestListener listener, TestExecutor<?> executor) { super(priority, source, executor, task); this.listener = listener; } @Override public String describeTasks(List<? extends BatchedTask> tasks) { return ((TestExecutor<Object>) batchingKey).describeTasks( tasks.stream().map(BatchedTask::getTask).collect(Collectors.toList()) ); } } } @Override protected void submitTask(String source, TestTask testTask) { submitTask(source, testTask, testTask, testTask, testTask); } private <T> void submitTask(String source, T task, ClusterStateTaskConfig config, TestExecutor<T> executor, TestListener listener) { submitTasks(source, Collections.singletonMap(task, listener), config, executor); } private <T> void submitTasks( final String source, final Map<T, TestListener> tasks, final ClusterStateTaskConfig config, final TestExecutor<T> executor ) { List<TestTaskBatcher.UpdateTask> safeTasks = tasks.entrySet() .stream() .map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), e.getValue(), executor)) .collect(Collectors.toList()); taskBatcher.submitTasks(safeTasks, config.timeout()); } @Override public void testTimedOutTaskCleanedUp() throws Exception { super.testTimedOutTaskCleanedUp(); synchronized (taskBatcher.tasksPerBatchingKey) { assertTrue("expected empty map but was " + taskBatcher.tasksPerBatchingKey, taskBatcher.tasksPerBatchingKey.isEmpty()); } } public void testOneExecutorDoesntStarveAnother() throws InterruptedException { final List<String> executionOrder = Collections.synchronizedList(new ArrayList<>()); final Semaphore allowProcessing = new Semaphore(0); final Semaphore startedProcessing = new Semaphore(0); class TaskExecutor implements TestExecutor<String> { @Override public void execute(List<String> tasks) { executionOrder.addAll(tasks); // do this first, so startedProcessing can be used as a notification that this is done. startedProcessing.release(tasks.size()); try { allowProcessing.acquire(tasks.size()); } catch (InterruptedException e) { throw new RuntimeException(e); } } } TaskExecutor executorA = new TaskExecutor(); TaskExecutor executorB = new TaskExecutor(); final ClusterStateTaskConfig config = ClusterStateTaskConfig.build(Priority.NORMAL); final TestListener noopListener = (source, e) -> { throw new AssertionError(e); }; // this blocks the cluster state queue, so we can set it up right submitTask("0", "A0", config, executorA, noopListener); // wait to be processed startedProcessing.acquire(1); assertThat(executionOrder, equalTo(Arrays.asList("A0"))); // these will be the first batch submitTask("1", "A1", config, executorA, noopListener); submitTask("2", "A2", config, executorA, noopListener); // release the first 0 task, but not the second allowProcessing.release(1); startedProcessing.acquire(2); assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2"))); // setup the queue with pending tasks for another executor same priority submitTask("3", "B3", config, executorB, noopListener); submitTask("4", "B4", config, executorB, noopListener); submitTask("5", "A5", config, executorA, noopListener); submitTask("6", "A6", config, executorA, noopListener); // now release the processing allowProcessing.release(6); // wait for last task to be processed startedProcessing.acquire(4); assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2", "B3", "B4", "A5", "A6"))); } static class TaskExecutor implements TestExecutor<Integer> { List<Integer> tasks = new ArrayList<>(); @Override public void execute(List<Integer> tasks) { this.tasks.addAll(tasks); } } // test that for a single thread, tasks are executed in the order // that they are submitted public void testTasksAreExecutedInOrder() throws BrokenBarrierException, InterruptedException { int numberOfThreads = randomIntBetween(2, 8); TaskExecutor[] executors = new TaskExecutor[numberOfThreads]; for (int i = 0; i < numberOfThreads; i++) { executors[i] = new TaskExecutor(); } int tasksSubmittedPerThread = randomIntBetween(2, 1024); CopyOnWriteArrayList<Tuple<String, Throwable>> failures = new CopyOnWriteArrayList<>(); CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread); final TestListener listener = new TestListener() { @Override public void onFailure(String source, Exception e) { logger.error(() -> new ParameterizedMessage("unexpected failure: [{}]", source), e); failures.add(new Tuple<>(source, e)); updateLatch.countDown(); } @Override public void processed(String source) { updateLatch.countDown(); } }; CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads); for (int i = 0; i < numberOfThreads; i++) { final int index = i; Thread thread = new Thread(() -> { try { barrier.await(); for (int j = 0; j < tasksSubmittedPerThread; j++) { submitTask( "[" + index + "][" + j + "]", j, ClusterStateTaskConfig.build(randomFrom(Priority.values())), executors[index], listener ); } barrier.await(); } catch (InterruptedException | BrokenBarrierException e) { throw new AssertionError(e); } }); thread.start(); } // wait for all threads to be ready barrier.await(); // wait for all threads to finish barrier.await(); updateLatch.await(); assertThat(failures, empty()); for (int i = 0; i < numberOfThreads; i++) { assertEquals(tasksSubmittedPerThread, executors[i].tasks.size()); for (int j = 0; j < tasksSubmittedPerThread; j++) { assertNotNull(executors[i].tasks.get(j)); assertEquals("cluster state update task executed out of order", j, (int) executors[i].tasks.get(j)); } } } public void testNoTasksAreDroppedInParallelSubmission() throws BrokenBarrierException, InterruptedException { int numberOfThreads = randomIntBetween(2, 8); TaskExecutor[] executors = new TaskExecutor[numberOfThreads]; for (int i = 0; i < numberOfThreads; i++) { executors[i] = new TaskExecutor(); } int tasksSubmittedPerThread = randomIntBetween(2, 1024); CopyOnWriteArrayList<Tuple<String, Throwable>> failures = new CopyOnWriteArrayList<>(); CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread); final TestListener listener = new TestListener() { @Override public void onFailure(String source, Exception e) { logger.error(() -> new ParameterizedMessage("unexpected failure: [{}]", source), e); failures.add(new Tuple<>(source, e)); updateLatch.countDown(); } @Override public void processed(String source) { updateLatch.countDown(); } }; CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads); for (int i = 0; i < numberOfThreads; i++) { final int index = i; Thread thread = new Thread(() -> { try { barrier.await(); CyclicBarrier tasksBarrier = new CyclicBarrier(1 + tasksSubmittedPerThread); for (int j = 0; j < tasksSubmittedPerThread; j++) { int taskNumber = j; Thread taskThread = new Thread(() -> { try { tasksBarrier.await(); submitTask( "[" + index + "][" + taskNumber + "]", taskNumber, ClusterStateTaskConfig.build(randomFrom(Priority.values())), executors[index], listener ); tasksBarrier.await(); } catch (InterruptedException | BrokenBarrierException e) { throw new AssertionError(e); } }); // submit tasks per batchingKey in parallel taskThread.start(); } // wait for all task threads to be ready tasksBarrier.await(); // wait for all task threads to finish tasksBarrier.await(); barrier.await(); } catch (InterruptedException | BrokenBarrierException e) { throw new AssertionError(e); } }); thread.start(); } // wait for all executor threads to be ready barrier.await(); // wait for all executor threads to finish barrier.await(); updateLatch.await(); assertThat(failures, empty()); for (int i = 0; i < numberOfThreads; i++) { // assert that total executed tasks is same for every executor as we initiated assertEquals(tasksSubmittedPerThread, executors[i].tasks.size()); } } public void testSingleBatchSubmission() throws InterruptedException { Map<Integer, TestListener> tasks = new HashMap<>(); final int numOfTasks = randomInt(10); final CountDownLatch latch = new CountDownLatch(numOfTasks); Set<Integer> usedKeys = new HashSet<>(numOfTasks); for (int i = 0; i < numOfTasks; i++) { int key = randomValueOtherThanMany(k -> usedKeys.contains(k), () -> randomInt(1024)); tasks.put(key, new TestListener() { @Override public void processed(String source) { latch.countDown(); } @Override public void onFailure(String source, Exception e) { throw new AssertionError(e); } }); usedKeys.add(key); } assert usedKeys.size() == numOfTasks; TestExecutor<Integer> executor = taskList -> { assertThat(taskList.size(), equalTo(tasks.size())); assertThat(taskList.stream().collect(Collectors.toSet()), equalTo(tasks.keySet())); }; submitTasks("test", tasks, ClusterStateTaskConfig.build(Priority.LANGUID), executor); latch.await(); } public void testDuplicateSubmission() throws InterruptedException { final CountDownLatch latch = new CountDownLatch(2); try (BlockingTask blockingTask = new BlockingTask(Priority.IMMEDIATE)) { submitTask("blocking", blockingTask); TestExecutor<SimpleTask> executor = tasks -> {}; SimpleTask task1 = new SimpleTask(1); TestListener listener = new TestListener() { @Override public void processed(String source) { latch.countDown(); } @Override public void onFailure(String source, Exception e) { throw new AssertionError(e); } }; submitTask("first time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); // submitting same task1 again, it should throw exception. final IllegalStateException e = expectThrows( IllegalStateException.class, () -> submitTask("second time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener) ); assertThat(e, hasToString(containsString("task [1] with source [second time] is already queued"))); // inserting new task with same data, this should pass as it is new object and reference is different. SimpleTask task2 = new SimpleTask(1); // equals method returns true for both task assertTrue(task1.equals(task2)); // references of both tasks are different. assertFalse(task1 == task2); // submitting this task should be allowed, as it is new object. submitTask("third time a charm", task2, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); // submitting same task2 again, it should throw exception, since it was submitted last time final IllegalStateException e2 = expectThrows( IllegalStateException.class, () -> submitTask("second time", task2, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener) ); assertThat(e2, hasToString(containsString("task [1] with source [second time] is already queued"))); assertThat(latch.getCount(), equalTo(2L)); } latch.await(); } public void testDuplicateSubmissionAfterTimeout() throws InterruptedException { final CountDownLatch latch = new CountDownLatch(2); final CountDownLatch timeOutLatch = new CountDownLatch(1); try (BlockingTask blockingTask = new BlockingTask(Priority.IMMEDIATE)) { submitTask("blocking", blockingTask); TestExecutor<SimpleTask> executor = tasks -> {}; SimpleTask task1 = new SimpleTask(1); TestListener listener = new TestListener() { @Override public void processed(String source) { latch.countDown(); } @Override public void onFailure(String source, Exception e) { if (e instanceof ProcessClusterEventTimeoutException) { timeOutLatch.countDown(); } else { throw new AssertionError(e); } } }; submitTask("first time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); ArrayList<TaskBatcher.BatchedTask> tasks = new ArrayList(); tasks.add( taskBatcher.new UpdateTask( ClusterStateTaskConfig.build(Priority.NORMAL).priority(), "first time", task1, listener, executor ) ); // task1 got timed out, it will be removed from map. taskBatcher.onTimeoutInternal(tasks, TimeValue.ZERO); timeOutLatch.await(); // wait for task to get timeout // submitting same task1 again, it should get submitted, since last task was timeout. submitTask("first time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); assertThat(latch.getCount(), equalTo(2L)); } latch.await(); } public void testDuplicateSubmissionAfterExecution() throws InterruptedException { final CountDownLatch firstTaskLatch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(2); TestExecutor<SimpleTask> executor = tasks -> {}; SimpleTask task1 = new SimpleTask(1); TestListener listener = new TestListener() { @Override public void processed(String source) { firstTaskLatch.countDown(); latch.countDown(); } @Override public void onFailure(String source, Exception e) { if (e instanceof ProcessClusterEventTimeoutException) { latch.countDown(); } else { throw new AssertionError(e); } } }; submitTask("first time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); firstTaskLatch.await(); // wait till task is not executed // submitting same task1 again, it should get submitted, since last task was executed. submitTask("first time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); latch.await(); // wait till all tasks are not completed. } protected static TaskBatcherListener getMockListener() { return new TaskBatcherListener() { @Override public void onBeginSubmit(List<? extends TaskBatcher.BatchedTask> tasks) { // No Op } @Override public void onSubmitFailure(List<? extends TaskBatcher.BatchedTask> tasks) { // No Op } @Override public void onBeginProcessing(List<? extends TaskBatcher.BatchedTask> tasks) { // No Op } @Override public void onTimeout(List<? extends TaskBatcher.BatchedTask> tasks) { // No Op } }; } private static class SimpleTask { private final int id; private SimpleTask(int id) { this.id = id; } @Override public int hashCode() { return this.id; } @Override public boolean equals(Object obj) { return ((SimpleTask) obj).getId() == this.id; } public int getId() { return id; } @Override public String toString() { return Integer.toString(id); } } }