/* * SPDX-License-Identifier: Apache-2.0 * * The OpenSearch Contributors require contributions made to * this file be licensed under the Apache-2.0 license or a * compatible open source license. * * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ package org.opensearch.ad.ratelimit; import static org.opensearch.ad.settings.AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY; import java.time.Clock; import java.time.Duration; import java.util.ArrayDeque; import java.util.Locale; import java.util.Optional; import java.util.Random; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.ActionListener; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.caching.CacheProvider; import org.opensearch.ad.ml.EntityColdStarter; import org.opensearch.ad.ml.EntityModel; import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.util.ExceptionUtil; /** * A queue for HCAD model training (a.k.a. cold start). As model training is a * pretty expensive operation, we pull cold start requests from the queue in a * serial fashion. Each detector has an equal chance of being pulled. The equal * probability is achieved by putting model training requests for different * detectors into different segments and pulling requests from segments in a * round-robin fashion. * */ public class EntityColdStartWorker extends SingleRequestWorker<EntityRequest> { private static final Logger LOG = LogManager.getLogger(EntityColdStartWorker.class); public static final String WORKER_NAME = "cold-start"; private final EntityColdStarter entityColdStarter; private final CacheProvider cacheProvider; public EntityColdStartWorker( long heapSizeInBytes, int singleRequestSizeInBytes, Setting<Float> maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, ADCircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, Clock clock, float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, Duration executionTtl, EntityColdStarter entityColdStarter, Duration stateTtl, NodeStateManager nodeStateManager, CacheProvider cacheProvider ) { super( WORKER_NAME, heapSizeInBytes, singleRequestSizeInBytes, maxHeapPercentForQueueSetting, clusterService, random, adCircuitBreakerService, threadPool, settings, maxQueuedTaskRatio, clock, mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, ENTITY_COLD_START_QUEUE_CONCURRENCY, executionTtl, stateTtl, nodeStateManager ); this.entityColdStarter = entityColdStarter; this.cacheProvider = cacheProvider; } @Override protected void executeRequest(EntityRequest coldStartRequest, ActionListener<Void> listener) { String detectorId = coldStartRequest.getId(); Optional<String> modelId = coldStartRequest.getModelId(); if (false == modelId.isPresent()) { String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); LOG.warn(error); listener.onFailure(new RuntimeException(error)); return; } ModelState<EntityModel> modelState = new ModelState<>( new EntityModel(coldStartRequest.getEntity(), new ArrayDeque<>(), null), modelId.get(), detectorId, ModelType.ENTITY.getName(), clock, 0 ); ActionListener<Void> coldStartListener = ActionListener.wrap(r -> { nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> { try { if (!detectorOptional.isPresent()) { LOG .error( new ParameterizedMessage( "fail to load trained model [{}] to cache due to the detector not being found.", modelState.getModelId() ) ); return; } AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); EntityModel model = modelState.getModel(); // load to cache if cold start succeeds if (model != null && model.getTrcf() != null) { cacheProvider.get().hostIfPossible(detector, modelState); } } finally { listener.onResponse(null); } }, listener::onFailure)); }, e -> { try { if (ExceptionUtil.isOverloaded(e)) { LOG.error("OpenSearch is overloaded"); setCoolDownStart(); } nodeStateManager.setException(detectorId, e); } finally { listener.onFailure(e); } }); entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, modelState, coldStartListener); } }