/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.tasks;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchShardTask;
import org.opensearch.common.lifecycle.AbstractLifecycleComponent;
import org.opensearch.common.metrics.CounterMetric;
import org.opensearch.threadpool.Scheduler;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/**
 * This monitoring service is responsible to track long-running(defined by a threshold) cancelled tasks as part of
 * node stats.
 */
public class TaskCancellationMonitoringService extends AbstractLifecycleComponent implements TaskManager.TaskEventListeners {

    private static final Logger logger = LogManager.getLogger(TaskCancellationMonitoringService.class);
    private final static List<Class<? extends CancellableTask>> TASKS_TO_TRACK = Arrays.asList(SearchShardTask.class);

    private volatile Scheduler.Cancellable scheduledFuture;
    private final ThreadPool threadPool;
    private final TaskManager taskManager;
    /**
     * This is to keep track of currently running cancelled tasks. This is needed to accurately calculate cumulative
     * sum(from genesis) of cancelled tasks which have been running beyond a threshold and avoid double count
     * problem.
     * For example:
     * A task M was cancelled at some point of time and continues to run for long. This Monitoring service sees this
     * M for the first time and adds it as part of stats. In next iteration of monitoring service, it might see
     * this M(if still running) again, but using below map we will not double count this task as part of our cumulative
     * metric.
     */
    private final Map<Long, Boolean> cancelledTaskTracker;
    /**
     *  This map holds statistics for each cancellable task type.
     */
    private final Map<Class<? extends CancellableTask>, TaskCancellationStatsHolder> cancellationStatsHolder;
    private final TaskCancellationMonitoringSettings taskCancellationMonitoringSettings;

    public TaskCancellationMonitoringService(
        ThreadPool threadPool,
        TaskManager taskManager,
        TaskCancellationMonitoringSettings taskCancellationMonitoringSettings
    ) {
        this.threadPool = threadPool;
        this.taskManager = taskManager;
        this.taskCancellationMonitoringSettings = taskCancellationMonitoringSettings;
        this.cancelledTaskTracker = new ConcurrentHashMap<>();
        cancellationStatsHolder = TASKS_TO_TRACK.stream()
            .collect(Collectors.toConcurrentMap(task -> task, task -> new TaskCancellationStatsHolder()));
        taskManager.addTaskEventListeners(this);
    }

    void doRun() {
        if (!taskCancellationMonitoringSettings.isEnabled() || this.cancelledTaskTracker.isEmpty()) {
            return;
        }
        Map<Class<? extends CancellableTask>, List<CancellableTask>> taskCancellationListByType = getCurrentRunningTasksPostCancellation();
        taskCancellationListByType.forEach((key, value) -> {
            long uniqueTasksRunningCount = value.stream().filter(task -> {
                if (this.cancelledTaskTracker.containsKey(task.getId()) && !this.cancelledTaskTracker.get(task.getId())) {
                    // Mark it as seen by the stats logic.
                    this.cancelledTaskTracker.put(task.getId(), true);
                    return true;
                } else {
                    return false;
                }
            }).count();
            cancellationStatsHolder.get(key).totalLongRunningCancelledTaskCount.inc(uniqueTasksRunningCount);
        });
    }

    @Override
    protected void doStart() {
        scheduledFuture = threadPool.scheduleWithFixedDelay(() -> {
            try {
                doRun();
            } catch (Exception e) {
                logger.debug("Exception occurred in Task monitoring service", e);
            }
        }, taskCancellationMonitoringSettings.getInterval(), ThreadPool.Names.GENERIC);
    }

    @Override
    protected void doStop() {
        if (scheduledFuture != null) {
            scheduledFuture.cancel();
        }
    }

    @Override
    protected void doClose() throws IOException {

    }

    // For testing
    protected Map<Long, Boolean> getCancelledTaskTracker() {
        return this.cancelledTaskTracker;
    }

    /**
     * Invoked when a task is completed. This helps us to disable monitoring service when there are no cancelled tasks
     * running to avoid wasteful work.
     * @param task task which got completed.
     */
    @Override
    public void onTaskCompleted(Task task) {
        if (!TASKS_TO_TRACK.contains(task.getClass())) {
            return;
        }
        this.cancelledTaskTracker.entrySet().removeIf(entry -> entry.getKey() == task.getId());
    }

    /**
     * Invoked when a task is cancelled. This is to keep track of tasks being cancelled. More importantly also helps
     * us to enable this monitoring service only when needed.
     * @param task task which got cancelled.
     */
    @Override
    public void onTaskCancelled(CancellableTask task) {
        if (!TASKS_TO_TRACK.contains(task.getClass())) {
            return;
        }
        // Add task to tracker and mark it as not seen(false) yet by the stats logic.
        this.cancelledTaskTracker.putIfAbsent(task.getId(), false);
    }

    public TaskCancellationStats stats() {
        Map<Class<? extends CancellableTask>, List<CancellableTask>> currentRunningCancelledTasks =
            getCurrentRunningTasksPostCancellation();
        return new TaskCancellationStats(
            new SearchShardTaskCancellationStats(
                Optional.of(currentRunningCancelledTasks).map(mapper -> mapper.get(SearchShardTask.class)).map(List::size).orElse(0),
                cancellationStatsHolder.get(SearchShardTask.class).totalLongRunningCancelledTaskCount.count()
            )
        );
    }

    private Map<Class<? extends CancellableTask>, List<CancellableTask>> getCurrentRunningTasksPostCancellation() {
        long currentTimeInNanos = System.nanoTime();

        return taskManager.getCancellableTasks()
            .values()
            .stream()
            .filter(task -> TASKS_TO_TRACK.contains(task.getClass()))
            .filter(CancellableTask::isCancelled)
            .filter(task -> {
                long runningTimeSinceCancellationSeconds = TimeUnit.NANOSECONDS.toSeconds(
                    currentTimeInNanos - task.getCancellationStartTimeNanos()
                );
                return runningTimeSinceCancellationSeconds >= taskCancellationMonitoringSettings.getDuration().getSeconds();
            })
            .collect(Collectors.groupingBy(CancellableTask::getClass, Collectors.toList()));
    }

    /**
     * Holds stats related to monitoring service
     */
    public static class TaskCancellationStatsHolder {
        CounterMetric totalLongRunningCancelledTaskCount = new CounterMetric();
    }
}