/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file quantize.cc
 *
 * \brief transform a graph to a low-bit graph
 * for compression and acceleration.
 */
#include "./quantize.h"

#include <dmlc/thread_local.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <stack>

namespace tvm {
namespace relay {
namespace quantize {

TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);

/*!
 * \brief Type relation of the simulated_quantize operator.
 *
 * `types` holds five entries: data, dom_scale, clip_min, clip_max and the
 * output. The three scalar inputs are constrained to rank-0 float32 tensors
 * and the output type is made identical to the data type.
 *
 * \param types The input types followed by the output type (5 entries).
 * \param num_inputs The number of inputs (unused here).
 * \param attrs The call attributes; must be SimulatedQuantizeAttrs.
 * \param reporter The reporter used to assign the solved types.
 * \return false when types[0] is not a TensorTypeNode, so nothing can be
 *         assigned from it; true once all assignments have been reported.
 */
bool SimulatedQuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                          const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 5);
  const auto param = attrs.as<SimulatedQuantizeAttrs>();
  ICHECK(param != nullptr);

  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }

  ICHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";

  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // dom_scale
  reporter->Assign(types[2], TensorType({}, DataType::Float(32)));  // clip_min
  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // clip_max
  reporter->Assign(types[4], types[0]);                             // output
  return true;
}

RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
    .describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE)
    .set_num_inputs(4)
    .add_argument("data", "Tensor", "The input data.")
    .add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar")
    .add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
    .add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
    .set_attrs_type<SimulatedQuantizeAttrs>()
    .set_support_level(11)
    .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);

/*!
 * \brief Packed-function constructor of a simulated_quantize call.
 *
 * Copies kind, sign and rounding into a fresh SimulatedQuantizeAttrs and
 * returns a Call of the operator on {data, dom_scale, clip_min, clip_max}.
 */
TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize")
    .set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign,
                       String rounding) {
      auto attrs = make_object<SimulatedQuantizeAttrs>();
      attrs->kind = kind;
      attrs->sign = sign;
      attrs->rounding = rounding;
      static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
      return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
    });

/*! \brief Entry to hold the QConfig context stack. */
struct TVMQConfigThreadLocalEntry {
  /*! \brief The default config, returned when the stack is empty */
  QConfig default_config;

  /*! \brief The stack of currently entered config scopes */
  std::stack<QConfig> context_stack;

  TVMQConfigThreadLocalEntry() : default_config(make_object<QConfigNode>()) {}
};

/*! \brief Thread local store to hold the QConfig context stack. */
using TVMQConfigThreadLocalStore = dmlc::ThreadLocalStore<TVMQConfigThreadLocalEntry>;

/*!
 * \brief Push a config onto the thread-local context stack.
 * \param build_config The config that becomes the current one.
 */
void QConfig::EnterQConfigScope(const QConfig& build_config) {
  TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
  entry->context_stack.push(build_config);
}

/*!
 * \brief Pop the most recently entered config from the thread-local stack.
 *
 * Fails an ICHECK when no scope is active: calling pop() on an empty
 * std::stack is undefined behavior, so an unbalanced exit is reported
 * explicitly instead.
 */
void QConfig::ExitQConfigScope() {
  TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
  ICHECK(!entry->context_stack.empty())
      << "ExitQConfigScope called without a matching EnterQConfigScope";
  entry->context_stack.pop();
}

/*!
 * \brief Get the config that is current for the calling thread.
 * \return The top of the context stack, or the default config when the
 *         stack is empty.
 */
QConfig& QConfig::Current() {
  TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
  if (!entry->context_stack.empty()) {
    return entry->context_stack.top();
  }
  return entry->default_config;
}

TVM_REGISTER_NODE_TYPE(QConfigNode);

// Prints a QConfigNode as "qconfig(name=value, ...)"; every field uses the
// same "name=value" form.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<QConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* op = static_cast<const QConfigNode*>(ref.get());
      p->stream << "qconfig(";
      p->stream << "nbit_input=" << op->nbit_input << ", ";
      p->stream << "nbit_weight=" << op->nbit_weight << ", ";
      p->stream << "nbit_activation=" << op->nbit_activation << ", ";
      p->stream << "calibrate_mode=" << op->calibrate_mode << ", ";
      p->stream << "global_scale=" << op->global_scale << ", ";
      p->stream << "weight_scale=" << op->weight_scale << ", ";
      p->stream << "skip_conv_layers=" << op->skip_conv_layers << ", ";
      p->stream << "skip_dense_layer=" << op->skip_dense_layer << ", ";
      p->stream << "do_simulation=" << op->do_simulation << ", ";
      p->stream << "round_for_shift=" << op->round_for_shift << ", ";
      p->stream << "debug_enabled_ops=" << op->debug_enabled_ops << ", ";
      p->stream << "rounding=" << op->rounding << ", ";
      p->stream << "partition_conversions=" << op->partition_conversions;
      p->stream << ")";
    });

TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig").set_body_typed([]() -> QConfig {
  return QConfig::Current();
});

TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope")
    .set_body_typed(QConfig::EnterQConfigScope);

TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope").set_body_typed(QConfig::ExitQConfigScope);

}  // namespace quantize
}  // namespace relay
}  // namespace tvm