/* * SPDX-License-Identifier: Apache-2.0 * * The OpenSearch Contributors require contributions made to * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ /* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /* * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ package org.opensearch.common.util.concurrent; import org.opensearch.common.lease.Releasable; import org.opensearch.test.OpenSearchTestCase; import org.hamcrest.Matchers; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map.Entry; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; public class KeyedLockTests extends OpenSearchTestCase { public void testIfMapEmptyAfterLotsOfAcquireAndReleases() throws InterruptedException { ConcurrentHashMap<String, Integer> counter = new ConcurrentHashMap<>(); ConcurrentHashMap<String, AtomicInteger> safeCounter = new ConcurrentHashMap<>(); KeyedLock<String> connectionLock = new KeyedLock<>(randomBoolean()); String[] names = new String[randomIntBetween(1, 40)]; for (int i = 0; i < names.length; i++) { names[i] = randomRealisticUnicodeOfLengthBetween(10, 20); } int numThreads = randomIntBetween(3, 10); final CountDownLatch startLatch = new CountDownLatch(1 + numThreads); AcquireAndReleaseThread[] threads = new AcquireAndReleaseThread[numThreads]; for (int i = 0; i < numThreads; i++) { threads[i] = new AcquireAndReleaseThread(startLatch, connectionLock, names, counter, safeCounter); } for (int i = 0; i < numThreads; i++) { threads[i].start(); } startLatch.countDown(); for (int i = 0; i < numThreads; i++) { threads[i].join(); } assertThat(connectionLock.hasLockedKeys(), equalTo(false)); Set<Entry<String, Integer>> entrySet = counter.entrySet(); assertThat(counter.size(), equalTo(safeCounter.size())); for (Entry<String, Integer> entry : entrySet) { AtomicInteger atomicInteger = safeCounter.get(entry.getKey()); assertThat(atomicInteger, not(Matchers.nullValue())); assertThat(atomicInteger.get(), equalTo(entry.getValue())); } } public void testHasLockedKeys() { KeyedLock<String> lock = new KeyedLock<>(); assertFalse(lock.hasLockedKeys()); Releasable foo = lock.acquire("foo"); assertTrue(lock.hasLockedKeys()); foo.close(); assertFalse(lock.hasLockedKeys()); } public void testTryAcquire() throws InterruptedException { KeyedLock<String> lock = new KeyedLock<>(); Releasable foo = lock.tryAcquire("foo"); Releasable second = lock.tryAcquire("foo"); assertTrue(lock.hasLockedKeys()); foo.close(); assertTrue(lock.hasLockedKeys()); second.close(); assertFalse(lock.hasLockedKeys()); // lock again Releasable acquire = lock.tryAcquire("foo"); assertNotNull(acquire); final AtomicBoolean check = new AtomicBoolean(false); CountDownLatch latch = new CountDownLatch(1); Thread thread = new Thread(() -> { latch.countDown(); try (Releasable ignore = lock.acquire("foo")) { assertTrue(check.get()); } }); thread.start(); latch.await(); check.set(true); acquire.close(); foo.close(); thread.join(); } public void testLockIsReentrant() throws InterruptedException { KeyedLock<String> lock = new KeyedLock<>(); Releasable foo = lock.acquire("foo"); assertTrue(lock.isHeldByCurrentThread("foo")); assertFalse(lock.isHeldByCurrentThread("bar")); Releasable foo2 = lock.acquire("foo"); AtomicInteger test = new AtomicInteger(0); CountDownLatch latch = new CountDownLatch(1); Thread t = new Thread(() -> { latch.countDown(); try (Releasable r = lock.acquire("foo")) { test.incrementAndGet(); } }); t.start(); latch.await(); Thread.yield(); assertEquals(0, test.get()); List<Releasable> list = Arrays.asList(foo, foo2); Collections.shuffle(list, random()); list.get(0).close(); Thread.yield(); assertEquals(0, test.get()); list.get(1).close(); t.join(); assertEquals(1, test.get()); assertFalse(lock.hasLockedKeys()); } public static class AcquireAndReleaseThread extends Thread { private CountDownLatch startLatch; KeyedLock<String> connectionLock; String[] names; ConcurrentHashMap<String, Integer> counter; ConcurrentHashMap<String, AtomicInteger> safeCounter; final int numRuns = scaledRandomIntBetween(5000, 50000); public AcquireAndReleaseThread( CountDownLatch startLatch, KeyedLock<String> connectionLock, String[] names, ConcurrentHashMap<String, Integer> counter, ConcurrentHashMap<String, AtomicInteger> safeCounter ) { this.startLatch = startLatch; this.connectionLock = connectionLock; this.names = names; this.counter = counter; this.safeCounter = safeCounter; } @Override public void run() { startLatch.countDown(); try { startLatch.await(); } catch (InterruptedException e) { throw new RuntimeException(e); } for (int i = 0; i < numRuns; i++) { String curName = names[randomInt(names.length - 1)]; assert connectionLock.isHeldByCurrentThread(curName) == false; Releasable lock; if (randomIntBetween(0, 10) < 4) { int tries = 0; boolean stepOut = false; while ((lock = connectionLock.tryAcquire(curName)) == null) { assertFalse(connectionLock.isHeldByCurrentThread(curName)); if (tries++ == 10) { stepOut = true; break; } } if (stepOut) { break; } } else { lock = connectionLock.acquire(curName); } try (Releasable ignore = lock) { assert connectionLock.isHeldByCurrentThread(curName); assert connectionLock.isHeldByCurrentThread(curName + "bla") == false; if (randomBoolean()) { try (Releasable reentrantIgnored = connectionLock.acquire(curName)) { // just acquire this and make sure we can :) Thread.yield(); } } Integer integer = counter.get(curName); if (integer == null) { counter.put(curName, 1); } else { counter.put(curName, integer.intValue() + 1); } } AtomicInteger atomicInteger = new AtomicInteger(0); AtomicInteger value = safeCounter.putIfAbsent(curName, atomicInteger); if (value == null) { atomicInteger.incrementAndGet(); } else { value.incrementAndGet(); } } } } }