/* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.ml.rest; import static org.opensearch.ml.common.MLTask.FUNCTION_NAME_FIELD; import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import java.io.IOException; import java.util.Map; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.utils.TestHelper; public class RestMLCustomModelActionIT extends MLCommonsRestTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); private MLRegisterModelInput registerModelInput; @Before public void setup() { registerModelInput = createRegisterModelInput(); } @Ignore public void testCustomModelWorkflow() throws IOException, InterruptedException { // register model String taskId = registerModel(TestHelper.toJsonString(registerModelInput)); waitForTask(taskId, MLTaskState.COMPLETED); getTask(client(), taskId, response -> { String algorithm = (String) response.get(FUNCTION_NAME_FIELD); assertEquals(registerModelInput.getFunctionName().name(), algorithm); assertNotNull(response.get(MODEL_ID_FIELD)); assertEquals(MLTaskState.COMPLETED.name(), response.get(STATE_FIELD)); String modelId = (String) response.get(MODEL_ID_FIELD); try { // deploy model String deployTaskId = deployModel(modelId); waitForTask(deployTaskId, MLTaskState.COMPLETED); getTask(client(), deployTaskId, deployTaskResponse -> { assertEquals(modelId, deployTaskResponse.get(MODEL_ID_FIELD)); assertEquals(MLTaskState.COMPLETED.name(), response.get(STATE_FIELD)); }); Thread.sleep(300); // profile getModelProfile(modelId, verifyTextEmbeddingModelDeployed()); // predict predictTextEmbedding(modelId); // undeploy model Map<String, Object> result = undeployModel(modelId); for (Map.Entry<String, Object> entry : result.entrySet()) { Map stats = (Map) ((Map) entry.getValue()).get("stats"); assertEquals("undeployed", stats.get(modelId)); } } catch (IOException | InterruptedException e) { throw new RuntimeException(e); } }); } }