# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import os

from os import path

from tvm import auto_scheduler
from tvm.driver import tvmc


def _get_tasks(model):
    """Load ``model`` through the TVMC frontends and extract its auto-scheduler tasks.

    The tasks are extracted for the ``"llvm"`` target.

    Returns
    -------
    tuple
        ``(tasks, weights)`` as produced by
        ``tvmc.autotuner.autoscheduler_get_tuning_tasks``.
    """
    loaded = tvmc.frontends.load_model(model)
    extracted_tasks, task_weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(
        loaded.mod, loaded.params, "llvm"
    )
    return extracted_tasks, task_weights


def _autoscheduler_test_helper(model, tmpdir_name, early_stopping=1, prior_records=None):
    """Run a short auto-scheduler tuning session and sanity-check its output.

    The model is loaded via the TVMC frontends and tuned with ``tvmc.tune`` for
    the ``"llvm"`` target with the auto-scheduler enabled and 2 trials. Tuning
    records are written to ``autoscheduler.json`` inside ``tmpdir_name``.

    Parameters
    ----------
    model
        Model reference accepted by ``tvmc.frontends.load_model``.
    tmpdir_name
        Directory in which the tuning log is created.
    early_stopping
        Forwarded unchanged to ``tvmc.tune`` (defaults to 1).
    prior_records
        Forwarded unchanged to ``tvmc.tune`` (defaults to ``None``).

    Returns
    -------
    The path of the tuning log file that was produced.
    """
    loaded_model = tvmc.frontends.load_model(model)
    records_path = os.path.join(tmpdir_name, "autoscheduler.json")

    tune_options = dict(
        target="llvm",
        tuning_records=records_path,
        prior_records=prior_records,
        early_stopping=early_stopping,
        enable_autoscheduler=True,
        trials=2,
        hardware_params=auto_scheduler.HardwareParams(num_cores=4, target="llvm"),
    )
    tvmc.tune(loaded_model, **tune_options)

    # The tuning run must have left a log file behind.
    assert os.path.exists(records_path), "autoscheduler log file should exist"

    # The produced log must be loadable as a history-best dispatch context.
    with auto_scheduler.ApplyHistoryBest(records_path) as history_best:
        loaded_ok = isinstance(history_best, auto_scheduler.dispatcher.ApplyHistoryBest)
        assert loaded_ok, "unable to load the best results of tuning"

    return records_path


def test_get_tuning_tasks(keras_simple):
    """Task extraction yields a non-empty list made up solely of ``SearchTask`` objects."""
    pytest.importorskip("tensorflow")

    extracted_tasks, _weights = _get_tasks(keras_simple)
    search_task_cls = auto_scheduler.SearchTask

    # Exact-type checks (not isinstance) are intentional: subclasses do not count.
    assert type(extracted_tasks) is list
    assert len(extracted_tasks) > 0
    assert all(type(task) is search_task_cls for task in extracted_tasks)


def test_tune_tasks(keras_simple, tmpdir_factory):
    """A tuning run with the helper's default settings completes and passes its checks."""
    pytest.importorskip("tensorflow")

    _autoscheduler_test_helper(keras_simple, tmpdir_factory.mktemp("data"))


def test_tune_tasks__tuning_records(keras_simple, tmpdir_factory):
    """Tune twice, handing the first run's log to the second run as ``prior_records``."""
    pytest.importorskip("tensorflow")

    work_dir = tmpdir_factory.mktemp("data")

    # Phase 1: a plain tuning run whose only purpose is to produce a log.
    first_run_log = _autoscheduler_test_helper(keras_simple, work_dir)

    # Phase 2: tune again in the same directory with the phase-1 log supplied
    # as prior records, so a previous log is guaranteed to exist.
    _autoscheduler_test_helper(keras_simple, work_dir, prior_records=first_run_log)


def test_tune_tasks__no_early_stopping(keras_simple, tmpdir_factory):
    """A tuning run still succeeds when ``early_stopping`` is passed as ``None``."""
    pytest.importorskip("tensorflow")

    work_dir = tmpdir_factory.mktemp("data")
    _autoscheduler_test_helper(model=keras_simple, tmpdir_name=work_dir, early_stopping=None)