/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*
 * Modifications Copyright OpenSearch Contributors. See
 * GitHub history for details.
 */

package org.opensearch.common.util.concurrent;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.logging.HeaderWarning;
import org.opensearch.common.settings.Settings;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.sameInstance;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

public class ThreadContextTests extends OpenSearchTestCase {

    /**
     * Verifies that stashing hides the caller's request header and transient while the header
     * configured through {@code request.headers.default} stays visible, and that closing the
     * stored context makes the caller's values visible again.
     */
    public void testStashContext() {
        final Settings settings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext context = new ThreadContext(settings);
        context.putHeader("foo", "bar");
        context.putTransient("ctx.foo", 1);

        // sanity check: everything is visible before the stash
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));

        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            // caller-supplied values are hidden, the configured default is not
            assertNull(context.getHeader("foo"));
            assertNull(context.getTransient("ctx.foo"));
            assertEquals("1", context.getHeader("default"));
        }

        // closing the stored context brings the caller's values back
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));
    }

    /**
     * Verifies that values stored with {@code putPersistent} remain readable inside a stashed
     * context (where the request header and transient are hidden) and after it is closed, and
     * that a configured default request header is not exposed as a persistent value.
     */
    public void testStashContextWithPersistentHeaders() {
        final Settings settings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext context = new ThreadContext(settings);
        context.putHeader("foo", "bar");
        context.putTransient("ctx.foo", 1);
        context.putPersistent("persistent_foo", "baz");
        context.putPersistent("ctx.persistent_foo", 10);

        // sanity check before the stash
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));

        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            // request header and transient are hidden, the default header is kept
            assertNull(context.getHeader("foo"));
            assertNull(context.getTransient("ctx.foo"));
            assertEquals("1", context.getHeader("default"));

            // persistent values are still readable while stashed
            assertEquals("baz", context.getPersistent("persistent_foo"));
            assertEquals(Integer.valueOf(10), context.getPersistent("ctx.persistent_foo"));
            assertNull(context.getPersistent("default"));
        }

        // everything is visible again once the stash is closed
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));

        assertEquals("baz", context.getPersistent("persistent_foo"));
        assertEquals(Integer.valueOf(10), context.getPersistent("ctx.persistent_foo"));
        assertNull(context.getPersistent("default"));
    }

    /**
     * Exercises {@code newStoredContext(boolean, ...)} with a list of transient names to clear.
     * <p>
     * The test runs three scenarios against the same {@link ThreadContext}:
     * <ol>
     *   <li>flag {@code false}: the listed transient is cleared inside the block, everything else
     *       stays visible, and nothing added inside the block (including response headers)
     *       survives the restore;</li>
     *   <li>a transient name that was never set stays absent inside the block and after it;</li>
     *   <li>flag {@code true}: response headers added inside the block are still present after
     *       the restore.</li>
     * </ol>
     */
    public void testNewContextWithClearedTransients() {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        threadContext.putTransient("foo", "bar");
        threadContext.putTransient("bar", "baz");
        threadContext.putHeader("foo", "bar");
        threadContext.putHeader("baz", "bar");
        threadContext.addResponseHeader("foo", "bar");
        threadContext.addResponseHeader("bar", "qux");

        // this is missing or null
        if (randomBoolean()) {
            threadContext.putTransient("acme", null);
        }

        // foo is the only existing transient header that is cleared
        // (the candidate lists cover a duplicate name, a single name, and a name with no value)
        try (
            ThreadContext.StoredContext stashed = threadContext.newStoredContext(
                false,
                randomFrom(Arrays.asList("foo", "foo"), Arrays.asList("foo"), Arrays.asList("foo", "acme"))
            )
        ) {
            // only the requested transient header is cleared
            assertNull(threadContext.getTransient("foo"));
            // missing header is still missing
            assertNull(threadContext.getTransient("acme"));
            // other headers are preserved
            assertEquals("baz", threadContext.getTransient("bar"));
            assertEquals("bar", threadContext.getHeader("foo"));
            assertEquals("bar", threadContext.getHeader("baz"));
            assertEquals("bar", threadContext.getResponseHeaders().get("foo").get(0));
            assertEquals("qux", threadContext.getResponseHeaders().get("bar").get(0));

            // try override stashed header
            threadContext.putTransient("foo", "acme");
            assertEquals("acme", threadContext.getTransient("foo"));
            // add new headers
            threadContext.putTransient("baz", "bar");
            threadContext.putHeader("bar", "baz");
            threadContext.addResponseHeader("baz", "bar");
            threadContext.addResponseHeader("foo", "baz");
        }

        // original is restored (it is not overridden)
        assertEquals("bar", threadContext.getTransient("foo"));
        // headers added inside the stash are NOT preserved
        assertNull(threadContext.getTransient("baz"));
        assertNull(threadContext.getHeader("bar"));
        assertNull(threadContext.getResponseHeaders().get("baz"));
        // original headers are restored
        assertEquals("bar", threadContext.getHeader("foo"));
        assertEquals("bar", threadContext.getHeader("baz"));
        assertEquals("bar", threadContext.getResponseHeaders().get("foo").get(0));
        // the extra "foo" response header added inside the block is gone as well
        assertEquals(1, threadContext.getResponseHeaders().get("foo").size());
        assertEquals("qux", threadContext.getResponseHeaders().get("bar").get(0));

        // test stashed missing header stays missing
        // (the response-header flag is irrelevant to this scenario, hence random)
        try (
            ThreadContext.StoredContext stashed = threadContext.newStoredContext(
                randomBoolean(),
                randomFrom(Arrays.asList("acme", "acme"), Arrays.asList("acme"))
            )
        ) {
            assertNull(threadContext.getTransient("acme"));
            threadContext.putTransient("acme", "foo");
        }
        assertNull(threadContext.getTransient("acme"));

        // test preserved response headers
        try (
            ThreadContext.StoredContext stashed = threadContext.newStoredContext(
                true,
                randomFrom(Arrays.asList("foo", "foo"), Arrays.asList("foo"), Arrays.asList("foo", "acme"))
            )
        ) {
            threadContext.addResponseHeader("baz", "bar");
            threadContext.addResponseHeader("foo", "baz");
        }
        // with the flag set to true, both response headers added inside the block are still present
        assertEquals("bar", threadContext.getResponseHeaders().get("foo").get(0));
        assertEquals("baz", threadContext.getResponseHeaders().get("foo").get(1));
        assertEquals(2, threadContext.getResponseHeaders().get("foo").size());
        assertEquals("bar", threadContext.getResponseHeaders().get("baz").get(0));
        assertEquals(1, threadContext.getResponseHeaders().get("baz").size());
    }

    /**
     * Verifies that stashing drops an ordinary transient but keeps the {@code TASK_ID} transient
     * visible, and that closing the stored context restores the caller's original values.
     */
    public void testStashContextWithPreservedTransients() {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        threadContext.putTransient("foo", "bar");
        threadContext.putTransient(TASK_ID, 1);
        // close the stored context instead of discarding it, like the other stash tests in this class
        try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
            assertNull(threadContext.getTransient("foo"));
            assertEquals(1, (int) threadContext.getTransient(TASK_ID));
        }
        // once the stash is closed the original context is back in place
        assertEquals("bar", threadContext.getTransient("foo"));
        assertEquals(1, (int) threadContext.getTransient(TASK_ID));
    }

    /**
     * Verifies that {@code stashWithOrigin} exposes the given origin under
     * {@link ThreadContext#ACTION_ORIGIN_TRANSIENT_NAME} only while the stored context is open,
     * and that any values set beforehand are visible again afterwards.
     */
    public void testStashWithOrigin() {
        final String expectedOrigin = randomAlphaOfLengthBetween(4, 16);
        final ThreadContext context = new ThreadContext(Settings.EMPTY);

        // optionally seed the context so the restore path has something to bring back
        final boolean populated = randomBoolean();
        if (populated) {
            context.putTransient("foo", "bar");
            context.putHeader("foo", "bar");
        }

        // no origin is set before the stash
        assertNull(context.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME));

        try (ThreadContext.StoredContext ignored = context.stashWithOrigin(expectedOrigin)) {
            assertEquals(expectedOrigin, context.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME));
            assertNull(context.getTransient("foo"));
            // NOTE(review): a "bar" transient is never set in this test, so this check is trivially
            // true - possibly getHeader("foo") was intended; confirm
            assertNull(context.getTransient("bar"));
        }

        // the origin disappears together with the stash
        assertNull(context.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME));

        if (populated) {
            assertEquals("bar", context.getTransient("foo"));
            assertEquals("bar", context.getHeader("foo"));
        }
    }

    /**
     * Verifies {@code stashAndMergeHeaders}: inside the block an existing request header keeps its
     * value over the merged one, a new merged header becomes visible, and the transient is hidden.
     * After the block the merged-only header is gone and the original values are back.
     */
    public void testStashAndMerge() {
        final Settings settings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext context = new ThreadContext(settings);
        context.putHeader("foo", "bar");
        context.putTransient("ctx.foo", 1);

        // sanity check before merging
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));

        final Map<String, String> extraHeaders = new HashMap<>();
        extraHeaders.put("foo", "baz");
        extraHeaders.put("simon", "says");

        try (ThreadContext.StoredContext ignored = context.stashAndMergeHeaders(extraHeaders)) {
            // the existing "foo" value is kept rather than replaced by the merged "baz"
            assertEquals("bar", context.getHeader("foo"));
            assertEquals("says", context.getHeader("simon"));
            assertNull(context.getTransient("ctx.foo"));
            assertEquals("1", context.getHeader("default"));
        }

        // the merged-only header does not leak out of the block
        assertNull(context.getHeader("simon"));
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));
    }

    /**
     * Verifies that a context captured with {@code newStoredContext(false)} can be reinstated
     * later through either {@code restore()} or {@code close()}, discarding a header that was
     * added after the capture.
     */
    public void testStoreContext() {
        final Settings settings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext context = new ThreadContext(settings);
        context.putHeader("foo", "bar");
        context.putTransient("ctx.foo", 1);
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));

        // capture the current state, then add a header the capture does not know about
        final ThreadContext.StoredContext snapshot = context.newStoredContext(false);
        context.putHeader("foo.bar", "baz");

        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            assertNull(context.getHeader("foo"));
            assertNull(context.getTransient("ctx.foo"));
            assertEquals("1", context.getHeader("default"));
        }

        // leaving the stash returns to the state that includes the late header
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));
        assertEquals("baz", context.getHeader("foo.bar"));

        // both ways of reinstating the capture are expected to behave the same
        if (randomBoolean()) {
            snapshot.restore();
        } else {
            snapshot.close();
        }

        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));
        // the header added after the capture is gone
        assertNull(context.getHeader("foo.bar"));
    }

    /**
     * Exercises {@code newRestorableContext(boolean)} from inside a stashed context.
     * <p>
     * With the flag set to {@code true}, invoking the supplier reinstates the captured headers and
     * transients and the response headers contain both the value added inside the stash
     * ({@code "boom"}) and the captured one ({@code "baaaam"}). With the flag set to
     * {@code false}, only the captured response header is visible. In both cases closing the
     * supplied context returns to the stashed state, and closing the stash returns to the
     * original state.
     */
    public void testRestorableContext() {
        Settings build = Settings.builder().put("request.headers.default", "1").build();
        ThreadContext threadContext = new ThreadContext(build);
        threadContext.putHeader("foo", "bar");
        threadContext.putTransient("ctx.foo", 1);
        threadContext.addResponseHeader("resp.header", "baaaam");
        // first pass: supplier created with the flag set to true
        Supplier<ThreadContext.StoredContext> contextSupplier = threadContext.newRestorableContext(true);

        try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
            assertNull(threadContext.getHeader("foo"));
            assertEquals("1", threadContext.getHeader("default"));
            // response header added while stashed, i.e. after the supplier was created
            threadContext.addResponseHeader("resp.header", "boom");
            try (ThreadContext.StoredContext tmp = contextSupplier.get()) {
                assertEquals("bar", threadContext.getHeader("foo"));
                assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
                assertEquals("1", threadContext.getHeader("default"));
                // both response headers are visible, the one from the stash first
                assertEquals(2, threadContext.getResponseHeaders().get("resp.header").size());
                assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
                assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(1));
            }
            // back in the stashed state: only the header added inside the stash remains
            assertNull(threadContext.getHeader("foo"));
            assertNull(threadContext.getTransient("ctx.foo"));
            assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
            assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
        }
        // back in the original state
        assertEquals("bar", threadContext.getHeader("foo"));
        assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
        assertEquals("1", threadContext.getHeader("default"));
        assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
        assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));

        // second pass: supplier created with the flag set to false
        contextSupplier = threadContext.newRestorableContext(false);

        try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
            assertNull(threadContext.getHeader("foo"));
            assertEquals("1", threadContext.getHeader("default"));
            threadContext.addResponseHeader("resp.header", "boom");
            try (ThreadContext.StoredContext tmp = contextSupplier.get()) {
                assertEquals("bar", threadContext.getHeader("foo"));
                assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
                assertEquals("1", threadContext.getHeader("default"));
                // only the captured response header is visible this time
                assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
                assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
            }
            assertNull(threadContext.getHeader("foo"));
            assertNull(threadContext.getTransient("ctx.foo"));
            assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
            assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0));
        }

        assertEquals("bar", threadContext.getHeader("foo"));
        assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
        assertEquals("1", threadContext.getHeader("default"));
        assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
        assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
    }

    /**
     * Verifies response-header de-duplication: adding the same value twice under one key yields a
     * single entry (both for plain values and for formatted warnings added with a
     * value-extracting function), while distinct values under the same key are all kept.
     */
    public void testResponseHeaders() {
        final boolean withThirdWarning = randomBoolean();

        final ThreadContext context = new ThreadContext(Settings.EMPTY);

        context.addResponseHeader("foo", "bar");
        // pretend that another thread created the same response
        if (randomBoolean()) {
            context.addResponseHeader("foo", "bar");
        }

        final String formattedWarning = HeaderWarning.formatWarning("qux");
        context.addResponseHeader("baz", formattedWarning, s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false));
        // pretend that another thread created the same response at a different time
        if (randomBoolean()) {
            final String formattedAgain = HeaderWarning.formatWarning("qux");
            context.addResponseHeader("baz", formattedAgain, s -> HeaderWarning.extractWarningValueFromWarningHeader(s, false));
        }

        context.addResponseHeader("Warning", "One is the loneliest number");
        context.addResponseHeader("Warning", "Two can be as bad as one");
        if (withThirdWarning) {
            context.addResponseHeader("Warning", "No is the saddest experience");
        }

        final Map<String, List<String>> headers = context.getResponseHeaders();
        final List<String> fooValues = headers.get("foo");
        final List<String> bazValues = headers.get("baz");
        final List<String> warningValues = headers.get("Warning");

        // duplicates collapse into a single entry
        assertThat(fooValues, hasSize(1));
        assertThat(bazValues, hasSize(1));
        assertEquals("bar", fooValues.get(0));
        assertEquals(formattedWarning, bazValues.get(0));

        // distinct warnings are all retained
        assertThat(warningValues, hasSize(withThirdWarning ? 3 : 2));
        assertThat(warningValues, hasItem(equalTo("One is the loneliest number")));
        assertThat(warningValues, hasItem(equalTo("Two can be as bad as one")));

        if (withThirdWarning) {
            assertThat(warningValues, hasItem(equalTo("No is the saddest experience")));
        }
    }

    /**
     * Verifies that {@code copyHeaders} accepts an empty entry set and that entries from a
     * non-empty set become readable through {@code getHeader}.
     */
    public void testCopyHeaders() {
        final Settings settings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext context = new ThreadContext(settings);

        final Map<String, String> noHeaders = Collections.emptyMap();
        final Map<String, String> singleHeader = Collections.singletonMap("foo", "bar");

        // copying nothing must be accepted
        context.copyHeaders(noHeaders.entrySet());
        context.copyHeaders(singleHeader.entrySet());

        assertEquals("bar", context.getHeader("foo"));
    }

    /**
     * Verifies a write/read round trip on the same {@link ThreadContext}: request headers and
     * response headers written with {@code writeTo} come back through {@code readHeaders},
     * transients do not, and a repeated response-header value is present only once.
     */
    public void testSerialize() throws IOException {
        final Settings settings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext context = new ThreadContext(settings);
        context.putHeader("foo", "bar");
        context.putTransient("ctx.foo", 1);
        context.addResponseHeader("Warning", "123456");
        // occasionally add the same warning twice; the expected count below stays at two
        if (rarely()) {
            context.addResponseHeader("Warning", "123456");
        }
        context.addResponseHeader("Warning", "234567");

        final BytesStreamOutput serialized = new BytesStreamOutput();
        context.writeTo(serialized);

        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            // the stashed context starts out without the caller's values
            assertNull(context.getHeader("foo"));
            assertNull(context.getTransient("ctx.foo"));
            assertTrue(context.getResponseHeaders().isEmpty());
            assertEquals("1", context.getHeader("default"));

            context.readHeaders(serialized.bytes().streamInput());
            assertEquals("bar", context.getHeader("foo"));
            // the transient is not read back
            assertNull(context.getTransient("ctx.foo"));

            final Map<String, List<String>> headers = context.getResponseHeaders();
            final List<String> warningValues = headers.get("Warning");

            assertThat(headers.keySet(), hasSize(1));
            assertThat(warningValues, hasSize(2));
            assertThat(warningValues, hasItem(equalTo("123456")));
            assertThat(warningValues, hasItem(equalTo("234567")));
        }

        // the original context is back after the stash is closed
        assertEquals("bar", context.getHeader("foo"));
        assertEquals(Integer.valueOf(1), context.getTransient("ctx.foo"));
        assertEquals("1", context.getHeader("default"));
    }

    /**
     * Verifies that headers written by one {@link ThreadContext} can be read by another one built
     * from different settings: the serialized {@code default} header value takes precedence over
     * the reader's own configured default, transients are not transferred, and response headers
     * arrive de-duplicated.
     */
    public void testSerializeInDifferentContext() throws IOException {
        final BytesStreamOutput serialized = new BytesStreamOutput();

        // writing side, configured with default=1
        final Settings writerSettings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext writer = new ThreadContext(writerSettings);
        writer.putHeader("foo", "bar");
        writer.putTransient("ctx.foo", 1);
        writer.addResponseHeader("Warning", "123456");
        if (rarely()) {
            writer.addResponseHeader("Warning", "123456");
        }
        writer.addResponseHeader("Warning", "234567");

        assertEquals("bar", writer.getHeader("foo"));
        assertNotNull(writer.getTransient("ctx.foo"));
        assertEquals("1", writer.getHeader("default"));
        assertThat(writer.getResponseHeaders().keySet(), hasSize(1));
        writer.writeTo(serialized);

        // reading side, configured with default=5
        final Settings readerSettings = Settings.builder().put("request.headers.default", "5").build();
        final ThreadContext reader = new ThreadContext(readerSettings);
        reader.readHeaders(serialized.bytes().streamInput());

        assertEquals("bar", reader.getHeader("foo"));
        assertNull(reader.getTransient("ctx.foo"));
        // the value that was written wins over the reader's configured "5"
        assertEquals("1", reader.getHeader("default"));

        final Map<String, List<String>> headers = reader.getResponseHeaders();
        final List<String> warningValues = headers.get("Warning");

        assertThat(headers.keySet(), hasSize(1));
        assertThat(warningValues, hasSize(2));
        assertThat(warningValues, hasItem(equalTo("123456")));
        assertThat(warningValues, hasItem(equalTo("234567")));
    }

    /**
     * Verifies that when the writing {@link ThreadContext} has no configured default header, the
     * reading context keeps its own configured default after {@code readHeaders}, while the
     * explicitly set request header is transferred and the transient is not.
     */
    public void testSerializeInDifferentContextNoDefaults() throws IOException {
        final BytesStreamOutput serialized = new BytesStreamOutput();

        // writing side without any configured defaults
        final ThreadContext writer = new ThreadContext(Settings.EMPTY);
        writer.putHeader("foo", "bar");
        writer.putTransient("ctx.foo", 1);

        assertEquals("bar", writer.getHeader("foo"));
        assertNotNull(writer.getTransient("ctx.foo"));
        assertNull(writer.getHeader("default"));
        writer.writeTo(serialized);

        // reading side, configured with default=5
        final Settings readerSettings = Settings.builder().put("request.headers.default", "5").build();
        final ThreadContext otherThreadContext = new ThreadContext(readerSettings);
        otherThreadContext.readHeaders(serialized.bytes().streamInput());

        assertEquals("bar", otherThreadContext.getHeader("foo"));
        assertNull(otherThreadContext.getTransient("ctx.foo"));
        // nothing was written for "default", so the reader's own value is still visible
        assertEquals("5", otherThreadContext.getHeader("default"));
    }

    /**
     * Verifies that an explicitly set request header overrides the value configured through
     * {@code request.headers.default}.
     */
    public void testCanResetDefault() {
        final Settings settings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext context = new ThreadContext(settings);

        // explicit value for the same key as the configured default
        context.putHeader("default", "2");

        assertEquals("2", context.getHeader("default"));
    }

    /**
     * Verifies how {@code stashAndMergeHeaders} treats a merged header that shares its key with a
     * configured default: the merged value overrides the configured default, but an explicitly
     * set request header overrides the merged value.
     */
    public void testStashAndMergeWithModifiedDefaults() {
        // case 1: only the configured default exists, so the merged value is visible
        final Settings firstSettings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext firstContext = new ThreadContext(firstSettings);
        final Map<String, String> firstMerge = new HashMap<>();
        firstMerge.put("default", "2");
        try (ThreadContext.StoredContext ignored = firstContext.stashAndMergeHeaders(firstMerge)) {
            assertEquals("2", firstContext.getHeader("default"));
        }

        // case 2: an explicit header was set beforehand, so it is the one that stays visible
        final Settings secondSettings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext secondContext = new ThreadContext(secondSettings);
        secondContext.putHeader("default", "4");
        final Map<String, String> secondMerge = new HashMap<>();
        secondMerge.put("default", "2");
        try (ThreadContext.StoredContext ignored = secondContext.stashAndMergeHeaders(secondMerge)) {
            assertEquals("4", secondContext.getHeader("default"));
        }
    }

    /**
     * Verifies that a runnable wrapped with {@code preserveContext} sees the header that was set
     * when it was wrapped, while that header is not visible on the calling context before or
     * after the runnable executes.
     */
    public void testPreserveContext() throws IOException {
        final ThreadContext context = new ThreadContext(Settings.EMPTY);
        final Runnable wrapped;

        // wrap the runnable while the "foo" header is set inside a temporary stash
        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            context.putHeader("foo", "bar");
            wrapped = context.preserveContext(sometimesAbstractRunnable(() -> assertEquals("bar", context.getHeader("foo"))));
        }

        // outside of the runnable the header is not visible
        assertNull(context.getHeader("foo"));

        // the assertion inside the runnable checks that it is visible there
        wrapped.run();

        // and it is hidden again once the runnable has finished
        assertNull(context.getHeader("foo"));
    }

    /**
     * Verifies that passing an already wrapped runnable to {@code preserveContext} a second time
     * returns the very same instance, so the runnable still executes with the header value that
     * was set at the first wrapping.
     */
    public void testPreserveContextKeepsOriginalContextWhenCalledTwice() throws IOException {
        final ThreadContext context = new ThreadContext(Settings.EMPTY);

        // first wrapping happens while "foo" is "bar"
        final Runnable firstWrap;
        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            context.putHeader("foo", "bar");
            firstWrap = context.preserveContext(sometimesAbstractRunnable(() -> assertEquals("bar", context.getHeader("foo"))));
        }

        // second wrapping attempt happens while "foo" is "zot"
        final Runnable secondWrap;
        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            context.putHeader("foo", "zot");
            secondWrap = context.preserveContext(firstWrap);
        }

        // the runnable's own assertion confirms it still sees "bar"
        secondWrap.run();

        // the second wrapping handed back the first wrapper unchanged
        assertThat(secondWrap, sameInstance(firstWrap));
    }

    /**
     * Verifies that the calling thread's context is left untouched when a runnable wrapped with
     * {@code preserveContext} throws.
     * <p>
     * First part: an {@link AbstractRunnable} whose {@code doRun} throws. Its {@code doRun},
     * {@code onFailure} and {@code onAfter} callbacks all assert that they see the header,
     * transient and system-context flag that were set when the runnable was wrapped;
     * {@code onAfter} additionally asserts that it sees the transient stored by
     * {@code onFailure}. None of these values are visible on the caller's context before or
     * after the run.
     * <p>
     * Second part: a plain {@link Runnable} that throws. The exception reaches the caller, and
     * the caller's context is again the default context afterwards.
     */
    public void testPreservesThreadsOriginalContextOnRunException() throws IOException {
        ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        Runnable withContext;

        // create an abstract runnable, add headers and transient objects and verify in the methods
        try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
            threadContext.putHeader("foo", "bar");
            // randomly mark the context as a system context; the callbacks assert the flag matches
            boolean systemContext = randomBoolean();
            if (systemContext) {
                threadContext.markAsSystemContext();
            }
            threadContext.putTransient("foo", "bar_transient");
            withContext = threadContext.preserveContext(new AbstractRunnable() {

                @Override
                public void onAfter() {
                    assertEquals(systemContext, threadContext.isSystemContext());
                    assertEquals("bar", threadContext.getHeader("foo"));
                    assertEquals("bar_transient", threadContext.getTransient("foo"));
                    // the transient stored by onFailure must be visible here
                    assertNotNull(threadContext.getTransient("failure"));
                    assertEquals("exception from doRun", ((RuntimeException) threadContext.getTransient("failure")).getMessage());
                    assertFalse(threadContext.isDefaultContext());
                    threadContext.putTransient("after", "after");
                }

                @Override
                public void onFailure(Exception e) {
                    assertEquals(systemContext, threadContext.isSystemContext());
                    assertEquals("exception from doRun", e.getMessage());
                    assertEquals("bar", threadContext.getHeader("foo"));
                    assertEquals("bar_transient", threadContext.getTransient("foo"));
                    assertFalse(threadContext.isDefaultContext());
                    // record the failure so onAfter can check that it sees it
                    threadContext.putTransient("failure", e);
                }

                @Override
                protected void doRun() throws Exception {
                    assertEquals(systemContext, threadContext.isSystemContext());
                    assertEquals("bar", threadContext.getHeader("foo"));
                    assertEquals("bar_transient", threadContext.getTransient("foo"));
                    assertFalse(threadContext.isDefaultContext());
                    throw new RuntimeException("exception from doRun");
                }
            });
        }

        // We don't see the header outside of the runnable
        assertNull(threadContext.getHeader("foo"));
        assertNull(threadContext.getTransient("foo"));
        assertNull(threadContext.getTransient("failure"));
        assertNull(threadContext.getTransient("after"));
        assertTrue(threadContext.isDefaultContext());

        // But we do inside of it
        withContext.run();

        // verify not seen after
        assertNull(threadContext.getHeader("foo"));
        assertNull(threadContext.getTransient("foo"));
        assertNull(threadContext.getTransient("failure"));
        assertNull(threadContext.getTransient("after"));
        assertTrue(threadContext.isDefaultContext());

        // repeat with regular runnable
        try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
            threadContext.putHeader("foo", "bar");
            threadContext.putTransient("foo", "bar_transient");
            withContext = threadContext.preserveContext(() -> {
                assertEquals("bar", threadContext.getHeader("foo"));
                assertEquals("bar_transient", threadContext.getTransient("foo"));
                assertFalse(threadContext.isDefaultContext());
                threadContext.putTransient("run", true);
                throw new RuntimeException("exception from run");
            });
        }

        assertNull(threadContext.getHeader("foo"));
        assertNull(threadContext.getTransient("foo"));
        assertNull(threadContext.getTransient("run"));
        assertTrue(threadContext.isDefaultContext());

        // copy into a final local so it can be used as a method reference
        final Runnable runnable = withContext;
        RuntimeException e = expectThrows(RuntimeException.class, runnable::run);
        assertEquals("exception from run", e.getMessage());
        // the transient set inside the runnable did not leak to the caller
        assertNull(threadContext.getHeader("foo"));
        assertNull(threadContext.getTransient("foo"));
        assertNull(threadContext.getTransient("run"));
        assertTrue(threadContext.isDefaultContext());
    }

    /**
     * Verifies that when {@code onFailure} of a wrapped {@link AbstractRunnable} itself throws,
     * that exception reaches the caller with the {@code doRun} exception as its cause, and the
     * caller's context is the default context both before and after the run.
     */
    public void testPreservesThreadsOriginalContextOnFailureException() throws IOException {
        final ThreadContext context = new ThreadContext(Settings.EMPTY);
        final Runnable wrapped;

        // wrap a runnable whose failure handler rethrows
        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            context.putHeader("foo", "bar");
            context.putTransient("foo", "bar_transient");
            wrapped = context.preserveContext(new AbstractRunnable() {
                @Override
                public void onFailure(Exception failure) {
                    throw new RuntimeException("from onFailure", failure);
                }

                @Override
                protected void doRun() throws Exception {
                    // the values set at wrapping time are visible while running
                    assertEquals("bar", context.getHeader("foo"));
                    assertEquals("bar_transient", context.getTransient("foo"));
                    assertFalse(context.isDefaultContext());
                    throw new RuntimeException("from doRun");
                }
            });
        }

        // nothing from the wrapped context is visible to the caller beforehand
        assertNull(context.getHeader("foo"));
        assertNull(context.getTransient("foo"));
        assertTrue(context.isDefaultContext());

        // the exception thrown by onFailure propagates, carrying the doRun exception as cause
        final RuntimeException thrown = expectThrows(RuntimeException.class, wrapped::run);
        assertEquals("from onFailure", thrown.getMessage());
        assertEquals("from doRun", thrown.getCause().getMessage());

        // and the caller's context is still untouched afterwards
        assertNull(context.getHeader("foo"));
        assertNull(context.getTransient("foo"));
        assertTrue(context.isDefaultContext());
    }

    /**
     * Checks that a runnable wrapped via {@code preserveContext} sees the headers and transients
     * that were set when it was wrapped, and that the calling thread's context is back to the
     * default one afterwards even though {@code onAfter} throws.
     */
    public void testPreservesThreadsOriginalContextOnAfterException() throws IOException {
        final ThreadContext context = new ThreadContext(Settings.EMPTY);
        Runnable preserved;

        // Inside a stashed context: populate it and wrap a runnable whose doRun succeeds
        // but whose onAfter throws.
        try (ThreadContext.StoredContext ignored = context.stashContext()) {
            context.putHeader("foo", "bar");
            context.putTransient("foo", "bar_transient");

            final AbstractRunnable throwsAfter = new AbstractRunnable() {

                @Override
                protected void doRun() throws Exception {
                    // the values put above are expected to be visible while running
                    assertEquals("bar", context.getHeader("foo"));
                    assertEquals("bar_transient", context.getTransient("foo"));
                    assertFalse(context.isDefaultContext());
                }

                @Override
                public void onFailure(Exception e) {
                    throw new RuntimeException("from onFailure", e);
                }

                @Override
                public void onAfter() {
                    throw new RuntimeException("from onAfter");
                }
            };
            preserved = context.preserveContext(throwsAfter);
        }

        // Once the stash is closed, none of the values are visible on this thread
        assertNull(context.getHeader("foo"));
        assertNull(context.getTransient("foo"));
        assertTrue(context.isDefaultContext());

        // Running the wrapped runnable surfaces the onAfter exception; it carries no cause,
        // i.e. it is not the wrapping exception that onFailure would have produced
        RuntimeException thrown = expectThrows(RuntimeException.class, preserved::run);
        assertEquals("from onAfter", thrown.getMessage());
        assertNull(thrown.getCause());

        // The failed run must leave this thread's context untouched
        assertNull(context.getHeader("foo"));
        assertNull(context.getTransient("foo"));
        assertTrue(context.isDefaultContext());
    }

    /**
     * Checks that the system-context flag is off by default, is turned on by
     * {@code markAsSystemContext()}, and is off again once the enclosing stashed context is closed.
     */
    public void testMarkAsSystemContext() throws IOException {
        final ThreadContext ctx = new ThreadContext(Settings.EMPTY);

        // a fresh context is not a system context
        assertFalse(ctx.isSystemContext());

        try (ThreadContext.StoredContext ignored = ctx.stashContext()) {
            // stashing alone does not flip the flag
            assertFalse(ctx.isSystemContext());

            ctx.markAsSystemContext();
            assertTrue(ctx.isSystemContext());
        }

        // closing the stored context drops the system marker again
        assertFalse(ctx.isSystemContext());
    }

    /**
     * Exercises the map-based {@code putHeader}: an empty map is accepted, a new key becomes
     * readable through {@code getHeader}, and putting an already-present key is rejected with an
     * {@link IllegalArgumentException}.
     */
    public void testPutHeaders() {
        final Settings settings = Settings.builder().put("request.headers.default", "1").build();
        final ThreadContext context = new ThreadContext(settings);

        final Map<String, String> noHeaders = Collections.emptyMap();
        final Map<String, String> fooBar = Collections.singletonMap("foo", "bar");
        final Map<String, String> fooBoom = Collections.singletonMap("foo", "boom");

        context.putHeader(noHeaders);
        context.putHeader(fooBar);
        assertEquals("bar", context.getHeader("foo"));

        // a second value for the same key must be refused
        IllegalArgumentException duplicate = expectThrows(IllegalArgumentException.class, () -> context.putHeader(fooBoom));
        assertEquals("value for key [foo] already present", duplicate.getMessage());
    }

    /**
     * Randomly returns either the given runnable as-is or an {@link AbstractRunnable} whose
     * {@code doRun} delegates to it and whose {@code onFailure} rethrows the exception wrapped
     * in a {@link RuntimeException}.
     *
     * @param r the runnable to return or wrap
     * @return {@code r} itself, or an {@link AbstractRunnable} delegating to {@code r}
     */
    private Runnable sometimesAbstractRunnable(Runnable r) {
        final boolean returnUnwrapped = random().nextBoolean();
        if (returnUnwrapped == false) {
            return new AbstractRunnable() {
                @Override
                protected void doRun() throws Exception {
                    r.run();
                }

                @Override
                public void onFailure(Exception e) {
                    throw new RuntimeException(e);
                }
            };
        }
        return r;
    }
}