# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for tvmc's auto-scheduler path: task extraction and ``tvmc.tune``."""
import pytest
import os

from os import path

from tvm import auto_scheduler
from tvm.driver import tvmc


def _get_tasks(model):
    """Extract auto-scheduler tuning tasks for ``model`` on the ``llvm`` target.

    Parameters
    ----------
    model
        A model accepted by ``tvmc.frontends.load_model`` (the tests pass the
        ``keras_simple`` fixture; presumably a model file path).

    Returns
    -------
    tuple
        ``(tasks, weights)`` exactly as returned by
        ``tvmc.autotuner.autoscheduler_get_tuning_tasks``.
    """
    tvmc_model = tvmc.frontends.load_model(model)
    tasks, weights = tvmc.autotuner.autoscheduler_get_tuning_tasks(
        tvmc_model.mod, tvmc_model.params, "llvm"
    )
    return tasks, weights


def _autoscheduler_test_helper(model, tmpdir_name, early_stopping=1, prior_records=None):
    """Run a short auto-scheduler tuning session and sanity-check its log.

    The session runs with 2 trials on ``llvm`` and writes its records to
    ``<tmpdir_name>/autoscheduler.json``. The helper asserts that this log
    exists and can be loaded by ``auto_scheduler.ApplyHistoryBest``.

    Parameters
    ----------
    model
        A model accepted by ``tvmc.frontends.load_model``.
    tmpdir_name
        Directory that receives the tuning log.
    early_stopping : int or None
        Forwarded to ``tvmc.tune``. ``None`` is exercised by
        ``test_tune_tasks__no_early_stopping``.
    prior_records : str or None
        Optional path to an existing tuning log, forwarded to ``tvmc.tune``.

    Returns
    -------
    str
        Path of the tuning log that was written.
    """
    tvmc_model = tvmc.frontends.load_model(model)
    log_file = os.path.join(tmpdir_name, "autoscheduler.json")

    # Describe the hardware explicitly rather than relying on detection
    # (presumably to keep the tiny search stable across CI hosts).
    hardware_params = auto_scheduler.HardwareParams(num_cores=4, target="llvm")

    tvmc.tune(
        tvmc_model,
        target="llvm",
        tuning_records=log_file,
        prior_records=prior_records,
        early_stopping=early_stopping,
        enable_autoscheduler=True,
        trials=2,
        hardware_params=hardware_params,
    )

    # The tuning run must have produced a log file...
    assert path.exists(log_file), "autoscheduler log file should exist"

    # ...and that log must be loadable as a history-best dispatch context.
    with auto_scheduler.ApplyHistoryBest(log_file) as best:
        assert isinstance(
            best, auto_scheduler.dispatcher.ApplyHistoryBest
        ), "unable to load the best results of tuning"

    return log_file


def test_get_tuning_tasks(keras_simple):
    """Task extraction yields a non-empty list of ``SearchTask`` objects."""
    pytest.importorskip("tensorflow")

    tasks, _ = _get_tasks(keras_simple)
    expected_task_type = auto_scheduler.SearchTask

    assert type(tasks) is list
    assert len(tasks) > 0
    # Exact type match (not isinstance) is intentional: every element must be
    # a SearchTask itself.
    assert all(type(x) is expected_task_type for x in tasks)


def test_tune_tasks(keras_simple, tmpdir_factory):
    """A basic auto-scheduler tuning run produces a loadable log."""
    pytest.importorskip("tensorflow")

    tmpdir_name = tmpdir_factory.mktemp("data")
    _autoscheduler_test_helper(keras_simple, tmpdir_name)


def test_tune_tasks__tuning_records(keras_simple, tmpdir_factory):
    """A second tuning run can start from the records of a first run."""
    pytest.importorskip("tensorflow")

    tmpdir_name = tmpdir_factory.mktemp("data")
    output_log_phase_1 = _autoscheduler_test_helper(keras_simple, tmpdir_name)

    # Exercises transfer learning by making sure a previous log exists.
    # Both phases use the same directory, so the prior log and the new
    # tuning log are the same file path.
    _autoscheduler_test_helper(keras_simple, tmpdir_name, prior_records=output_log_phase_1)


def test_tune_tasks__no_early_stopping(keras_simple, tmpdir_factory):
    """Tuning also works when early stopping is disabled (``None``)."""
    pytest.importorskip("tensorflow")

    tmpdir_name = tmpdir_factory.mktemp("data")
    _autoscheduler_test_helper(keras_simple, tmpdir_name, early_stopping=None)