/* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.knn; import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManager; import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpHost; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; import org.apache.hc.core5.http.nio.ssl.TlsStrategy; import org.apache.hc.core5.reactor.ssl.TlsDetails; import org.apache.hc.core5.ssl.SSLContextBuilder; import org.apache.hc.core5.util.Timeout; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.client.RestClientBuilder; import org.opensearch.common.Strings; import org.opensearch.common.io.PathUtils; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.MediaType; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.core.rest.RestStatus; import org.opensearch.search.SearchHit; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.junit.After; import org.opensearch.commons.rest.SecureRestClientBuilder; import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; import static org.opensearch.knn.TestUtils.KNN_BWC_PREFIX; import static org.opensearch.knn.TestUtils.OPENDISTRO_SECURITY; import static org.opensearch.knn.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX; import static org.opensearch.knn.TestUtils.SECURITY_AUDITLOG_PREFIX; import static org.opensearch.knn.TestUtils.SKIP_DELETE_MODEL_INDEX; import static org.opensearch.knn.common.KNNConstants.MODELS; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH; /** * ODFE integration test base class to support both security disabled and enabled ODFE cluster. */ public abstract class ODFERestTestCase extends OpenSearchRestTestCase { private final Set IMMUTABLE_INDEX_PREFIXES = Set.of(KNN_BWC_PREFIX, SECURITY_AUDITLOG_PREFIX, OPENSEARCH_SYSTEM_INDEX_PREFIX); protected boolean isHttps() { boolean isHttps = Optional.ofNullable(System.getProperty("https")).map("true"::equalsIgnoreCase).orElse(false); if (isHttps) { // currently only external cluster is supported for security enabled testing if (!Optional.ofNullable(System.getProperty("tests.rest.cluster")).isPresent()) { throw new RuntimeException("cluster url should be provided for security enabled testing"); } } return isHttps; } @Override protected String getProtocol() { return isHttps() ? "https" : "http"; } @Override protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { RestClientBuilder builder = RestClient.builder(hosts); if (isHttps()) { String keystore = settings.get(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH); if (Objects.nonNull(keystore)) { URI uri; try { uri = this.getClass().getClassLoader().getResource("security/sample.pem").toURI(); } catch (URISyntaxException e) { throw new RuntimeException(e); } Path configPath = PathUtils.get(uri).getParent().toAbsolutePath(); return new SecureRestClientBuilder(settings, configPath).build(); } else { configureHttpsClient(builder, settings); boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); builder.setStrictDeprecationMode(strictDeprecationMode); return builder.build(); } } else { configureClient(builder, settings); } return builder.build(); } protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException { Map headers = ThreadContext.buildDefaultHeaders(settings); Header[] defaultHeaders = new Header[headers.size()]; int i = 0; for (Map.Entry entry : headers.entrySet()) { defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); } builder.setDefaultHeaders(defaultHeaders); builder.setHttpClientConfigCallback(httpClientBuilder -> { String userName = Optional.ofNullable(System.getProperty("user")) .orElseThrow(() -> new RuntimeException("user name is missing")); String password = Optional.ofNullable(System.getProperty("password")) .orElseThrow(() -> new RuntimeException("password is missing")); BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); final AuthScope anyScope = new AuthScope(null, -1); credentialsProvider.setCredentials(anyScope, new UsernamePasswordCredentials(userName, password.toCharArray())); try { final TlsStrategy tlsStrategy = ClientTlsStrategyBuilder.create() .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) .setSslContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()) // See https://issues.apache.org/jira/browse/HTTPCLIENT-2219 .setTlsDetailsFactory(sslEngine -> new TlsDetails(sslEngine.getSession(), sslEngine.getApplicationProtocol())) .build(); final PoolingAsyncClientConnectionManager connectionManager = PoolingAsyncClientConnectionManagerBuilder.create() .setMaxConnPerRoute(DEFAULT_MAX_CONN_PER_ROUTE) .setMaxConnTotal(DEFAULT_MAX_CONN_TOTAL) .setTlsStrategy(tlsStrategy) .build(); return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(connectionManager); } catch (Exception e) { throw new RuntimeException(e); } }); final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT); final TimeValue socketTimeout = TimeValue.parseTimeValue( socketTimeoutString == null ? "60s" : socketTimeoutString, CLIENT_SOCKET_TIMEOUT ); builder.setRequestConfigCallback(conf -> { Timeout timeout = Timeout.ofMilliseconds(Math.toIntExact(socketTimeout.getMillis())); conf.setConnectTimeout(timeout); conf.setResponseTimeout(timeout); return conf; }); if (settings.hasValue(CLIENT_PATH_PREFIX)) { builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX)); } } /** * wipeAllIndices won't work since it cannot delete security index. Use wipeAllODFEIndices instead. */ @Override protected boolean preserveIndicesUponCompletion() { return true; } @SuppressWarnings("unchecked") @After protected void wipeAllODFEIndices() throws Exception { Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); MediaType mediaType = MediaType.fromMediaType(response.getEntity().getContentType()); try ( XContentParser parser = mediaType.xContent() .createParser( NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, response.getEntity().getContent() ) ) { XContentParser.Token token = parser.nextToken(); List> parserList = null; if (token == XContentParser.Token.START_ARRAY) { parserList = parser.listOrderedMap().stream().map(obj -> (Map) obj).collect(Collectors.toList()); } else { parserList = Collections.singletonList(parser.mapOrdered()); } for (Map index : parserList) { final String indexName = (String) index.get("index"); if (isIndexCleanupRequired(indexName)) { wipeIndexContent(indexName); continue; } if (!skipDeleteIndex(indexName)) { adminClient().performRequest(new Request("DELETE", "/" + indexName)); } } } } private boolean isIndexCleanupRequired(final String index) { return MODEL_INDEX_NAME.equals(index) && !getSkipDeleteModelIndexFlag(); } private void wipeIndexContent(String indexName) throws IOException, ParseException { deleteModels(getModelIds()); deleteAllDocs(indexName); } private List getModelIds() throws IOException, ParseException { final String restURIGetModels = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); final Response response = adminClient().performRequest(new Request("GET", restURIGetModels)); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); final String responseBody = EntityUtils.toString(response.getEntity()); assertNotNull(responseBody); final XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody); final SearchResponse searchResponse = SearchResponse.fromXContent(parser); return Arrays.stream(searchResponse.getHits().getHits()).map(SearchHit::getId).collect(Collectors.toList()); } private void deleteModels(final List modelIds) throws IOException { for (final String testModelID : modelIds) { final String restURIGetModel = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); final Response getModelResponse = adminClient().performRequest(new Request("GET", restURIGetModel)); if (RestStatus.OK != RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode())) { continue; } final String restURIDeleteModel = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); adminClient().performRequest(new Request("DELETE", restURIDeleteModel)); } } private void deleteAllDocs(final String indexName) throws IOException { final String restURIDeleteByQuery = String.join("/", indexName, "_delete_by_query"); final Request request = new Request("POST", restURIDeleteByQuery); final XContentBuilder matchAllDocsQuery = XContentFactory.jsonBuilder() .startObject() .startObject("query") .startObject("match_all") .endObject() .endObject() .endObject(); request.setJsonEntity(Strings.toString(matchAllDocsQuery)); adminClient().performRequest(request); } private boolean getSkipDeleteModelIndexFlag() { return Boolean.parseBoolean(System.getProperty(SKIP_DELETE_MODEL_INDEX, "false")); } private boolean skipDeleteModelIndex(String indexName) { return (MODEL_INDEX_NAME.equals(indexName) && getSkipDeleteModelIndexFlag()); } private boolean skipDeleteIndex(String indexName) { if (indexName != null && !OPENDISTRO_SECURITY.equals(indexName) && IMMUTABLE_INDEX_PREFIXES.stream().noneMatch(indexName::startsWith) && !skipDeleteModelIndex(indexName)) { return false; } return true; } @Override protected Settings restAdminSettings() { return Settings.builder() // disable the warning exception for admin client since it's only used for cleanup. .put("strictDeprecationMode", false) .put("http.port", 9200) .put(OPENSEARCH_SECURITY_SSL_HTTP_ENABLED, isHttps()) .put(OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH, "sample.pem") .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH, "test-kirk.jks") .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD, "changeit") .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD, "changeit") .build(); } }