/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file src/relay/ir/op_strategy.cc * \brief The Relay operator Strategy and related data structure. */ #include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(OpImplementationNode); TVM_REGISTER_NODE_TYPE(OpSpecializationNode); TVM_REGISTER_NODE_TYPE(OpStrategyNode); Array OpImplementation::Compute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return (*this)->fcompute(attrs, inputs, out_type); } te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array& outs, const Target& target) { return (*this)->fschedule(attrs, outs, target); } void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute, tvm::relay::FTVMSchedule fschedule, String name, int plevel) { auto n = make_object(); n->fcompute = fcompute; n->fschedule = fschedule; n->name = std::move(name); n->plevel = plevel; (*this)->implementations.push_back(OpImplementation(n)); } void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, int plevel) { auto curr_cond = te::SpecializedCondition::Current(); auto self = this->operator->(); Array specializations = self->specializations; OpSpecialization op_spec; for (OpSpecialization op_spec : specializations) { if (op_spec->condition == curr_cond) { op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel); return; } } ObjectPtr n = make_object(); n->condition = curr_cond; op_spec = OpSpecialization(n); op_spec.AddImplementation(fcompute, fschedule, std::move(name), plevel); self->specializations.push_back(op_spec); } TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute") .set_body([](TVMArgs args, TVMRetValue* rv) { OpImplementation imp = args[0]; Attrs attrs = args[1]; Array inputs = args[2]; Type out_type = args[3]; *rv = imp.Compute(attrs, inputs, out_type); }); TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule") .set_body([](TVMArgs args, TVMRetValue* rv) { OpImplementation imp = args[0]; Attrs attrs = args[1]; Array outs = args[2]; Target target = args[3]; *rv = imp.Schedule(attrs, outs, target); }); TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectPtr n = make_object(); *rv = OpStrategy(n); }); TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation") .set_body([](TVMArgs args, TVMRetValue* rv) { OpStrategy strategy = args[0]; FTVMCompute compute = args[1]; FTVMSchedule schedule = args[2]; std::string name = args[3]; int plevel = args[4]; strategy.AddImplementation(compute, schedule, name, plevel); }); } // namespace relay } // namespace tvm