/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file auto_scheduler/auto_schedule.cc * \brief The user interface and tuning options of the TVM auto-scheduler. */ #include #include #include "utils.h" namespace tvm { namespace auto_scheduler { TVM_REGISTER_NODE_TYPE(TuningOptionsNode); TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramBuilder builder, ProgramRunner runner, Optional> measure_callbacks) { auto node = make_object(); node->num_measure_trials = num_measure_trials; node->early_stopping = early_stopping; node->num_measures_per_round = num_measures_per_round; node->verbose = verbose; node->builder = std::move(builder); node->runner = std::move(runner); node->measure_callbacks = std::move(measure_callbacks); data_ = std::move(node); } std::pair> AutoSchedule(SearchPolicy search_policy, TuningOptions tuning_options) { // Create a ProgramMeasurer to handle the schedule build and performance measure ProgramMeasurer measurer = ProgramMeasurer(tuning_options->builder, tuning_options->runner, tuning_options->measure_callbacks, tuning_options->verbose); // Search for the best schedule State state = search_policy->Search(tuning_options->num_measure_trials, tuning_options->early_stopping, tuning_options->num_measures_per_round, measurer); if (state.defined()) { return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps); } else { StdCout(tuning_options->verbose) << "No valid state found in this search round. Check if it has traversed all of the " << "search space." << std::endl; // Return the default schedule return {te::Schedule(search_policy->search_task->compute_dag->ops), search_policy->search_task->compute_dag->tensors}; } } TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions") .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramBuilder builder, ProgramRunner runner, Optional> measure_callbacks) { return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose, builder, runner, measure_callbacks); }); TVM_REGISTER_GLOBAL("auto_scheduler.AutoSchedule") .set_body_typed([](SearchPolicy search_policy, TuningOptions tuning_options) { te::Schedule sch; Array return_tensors; std::tie(sch, return_tensors) = AutoSchedule(search_policy, tuning_options); return Array{sch, return_tensors}; }); } // namespace auto_scheduler } // namespace tvm