/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "./te_compiler_cache.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>
#include <tvm/topi/tags.h>

#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../op/memory/memory.h"
#include "../transforms/pass_utils.h"
#include "utils.h"

namespace tvm {
namespace relay {
namespace tec {

TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);

LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
  auto n = make_object<LoweredOutputNode>();
  n->outputs = std::move(outputs);
  n->implementation = std::move(impl);
  data_ = std::move(n);
}

CCacheKey::CCacheKey(Function source_func, Target target) {
  auto n = make_object<CCacheKeyNode>();
  n->source_func = std::move(source_func);
  n->target = std::move(target);
  data_ = std::move(n);
}

CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
                       tir::PrimFunc prim_func, tvm::Array<Integer> shape_func_param_states,
                       IRModule funcs) {
  auto n = make_object<CachedFuncNode>();
  n->target = target;
  n->prim_fn_var = prim_fn_var;
  n->inputs = inputs;
  n->outputs = outputs;
  n->schedule = schedule;
  n->prim_func = prim_func;
  n->shape_func_param_states = shape_func_param_states;
  n->funcs = funcs;
  data_ = std::move(n);
}

Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
  // for now, we always use int32 shape when possible
  // even if the result of shape inference becomes int64.
  Array<IndexExpr> res;
  for (IndexExpr val : shape) {
    const int64_t* pval = tir::as_const_int(val);
    if (pval != nullptr) {
#ifndef TVM_INDEX_DEFAULT_I64
      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max())
          << "dimension must be less than int32_t's max value";
      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min())
          << "dimension must be greater than int32_t's min value";
      res.push_back(IntImm(DataType::Int(32), *pval));
#else
      res.push_back(val);
#endif  // TVM_INDEX_DEFAULT_I64
    } else if (val->IsInstance<tir::AnyNode>()) {
      // currently all 'any' we meet in shape function are non-negative.
      res.push_back(val.as<tir::AnyNode>()->ToSizeVar());
    } else {
      res.push_back(val);
    }
  }
  return res;
}

// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
 public:
  explicit ScheduleBuilder(Target target, bool create_schedule = true)
      : target_(target),
        device_copy_op_(Op::Get("device_copy")),
        create_schedule_(create_schedule) {
    // Whether to use auto_scheduler schedule.
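    // Both flags are read from the current compilation context and decide which scheduling
    // path Create() takes below: auto_scheduler, meta_schedule, or the default TOPI schedule.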
    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
    use_meta_schedule_ = backend::IsMetaScheduleEnabled();
  }

  CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
    Array<tvm::te::Tensor> fn_inputs;
    for (Var param : relay_func->params) {
      Array<tvm::te::Tensor> inputs;
      for (const auto& ttype : FlattenTupleType(param->checked_type())) {
        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
        fn_inputs.push_back(tensor);
        inputs.push_back(tensor);
      }
      memo_[param] = inputs;
    }
    readable_name_stream_ << "fused";
    auto outputs = this->VisitExpr(relay_func->body);
    auto candidate_name = readable_name_stream_.str();

    constexpr static size_t kMaxFuncNameLength = 80;
    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
    // whenever the value of kMaxFuncNameLength changes
    if (candidate_name.size() > kMaxFuncNameLength) {
      std::stringstream truncated_name;
      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
      candidate_name = truncated_name.str();
    }

    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
    // no other GlobalVar ctors should appear inside the lowering machinery.
    auto prim_fn_var = GlobalVar(renamer(candidate_name));
    prim_fn_var->checked_type_ = relay_func->checked_type();

    // Fusion over tupled results may leave identity relationships
    // between inputs and outputs, and those should not be scheduled.
    // Hence schedule only non PlaceholderOp outputs.
    tvm::Array<te::Tensor> tensor_outs;
    for (const auto& tensor : outputs) {
      if (!tensor->op.as<te::PlaceholderOpNode>()) {
        tensor_outs.push_back(tensor);
      }
    }

    te::Schedule schedule{nullptr};
    tir::PrimFunc prim_func{nullptr};
    // No need to register schedule for device copy op.
    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
      if (use_auto_scheduler_) {
        const auto* fauto_schedule =
            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
        ICHECK(fauto_schedule != nullptr)
            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
        if (obj.defined()) {
          schedule = Downcast<te::Schedule>(obj);
        }
      }
      if (use_meta_schedule_) {
        const auto* f_create_func = runtime::Registry::Get("te.CreatePrimFuncFromOutputs");
        const auto* f_meta_schedule =
            runtime::Registry::Get("meta_schedule.MetaScheduleContextQueryInsideWithScope");
        ICHECK(f_create_func) << "te.CreatePrimFuncFromOutputs is not registered";
        ICHECK(f_meta_schedule)
            << "meta_schedule.MetaScheduleContextQueryInsideWithScope is not registered";
        prim_func = (*f_create_func)(tensor_outs);
        Optional<ObjectRef> opt_mod_or_base_func = (*f_meta_schedule)(
            prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}),
            Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
        if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
          prim_func = GetRef<tir::PrimFunc>(result);
        } else {
          prim_func = tir::PrimFunc(nullptr);
        }
      }
      // Use the TOPI schedule if specified by the user, or if the function has no
      // auto_scheduler schedule.
      if (!schedule.defined() && !prim_func.defined()) {
        ICHECK(anchor_implementation_.defined());
        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
      }
      if (schedule.defined()) {
        for (const auto& scalar : scalars_) {
          if (schedule->Contain(scalar)) {
            schedule[scalar].compute_inline();
          }
        }
      }
    }

    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {});
  }

  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
    LOG(FATAL) << "Unexpected free variable " << PrettyPrint(GetRef<Var>(op));
    return {};
  }

  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
    using tir::make_const;
    ICHECK(op->is_scalar());
    void* data = op->data->data;
    DataType dtype = DataType(op->data->dtype);
    auto value = te::compute(
        {},
        [&](const Array<tvm::tir::Var>&) {
          if (dtype == DataType::Int(32)) {
            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
          } else if (dtype == DataType::Int(64)) {
            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
          } else if (dtype == DataType::Float(32)) {
            return make_const(dtype, static_cast<const float*>(data)[0]);
          } else if (dtype == DataType::Float(64)) {
            return make_const(dtype, static_cast<const double*>(data)[0]);
          } else if (dtype == DataType::Bool()) {
            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
          } else {
            LOG(FATAL) << "not handled";
            return tvm::PrimExpr();
          }
        },
        "compile_engine_const", topi::kBroadcast);
    scalars_.push_back(value->op);
    return {value};
  }

  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

    Array<te::Tensor> inputs;
    int count_tuple = 0;
    for (Expr arg : call_node->args) {
      if (arg->checked_type().as<TupleTypeNode>()) {
        ++count_tuple;
      }
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
    }
    if (count_tuple) {
      ICHECK_EQ(call_node->args.size(), 1U)
          << "Only functions with a single tuple input are allowed, but " << count_tuple
          << " were provided.";
    }

    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);

    Array<te::Tensor> outputs;
    OpImplementation impl;
    // TODO(mbs): device_copy cleanup
    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
    LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
    outputs = lowered_out->outputs;
    impl = lowered_out->implementation;

    if (create_schedule_) {
      int op_pattern = fpattern[op];
      if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
        ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
            << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
            << " anchor=" << anchor_op_ << " current=" << op;
      }
      if (op_pattern >= anchor_op_pattern_) {
        anchor_op_ = op;
        anchor_attrs_ = call_node->attrs;
        anchor_op_pattern_ = op_pattern;
        anchor_implementation_ = impl;
      }
    }
    if (outputs.size() != 1) {
      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
      ICHECK(tuple_type) << "Expected output to be a tuple type "
                         << PrettyPrint(call_node->checked_type());
      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
    }
    // TODO(mbs): device_copy cleanup
    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
    readable_name_stream_ << '_' << op->name;
    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
    LOG(FATAL) << "Primitive Functions can not contain nested functions.";
    return Array<te::Tensor>();
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
    Array<te::Tensor> val = VisitExpr(op->value);
    ICHECK(!memo_.count(op->var));
    memo_[op->var] = val;
    return VisitExpr(op->body);
  }
  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      // TODO(mbs): Generalize to be equivalent to FlattenTupleType.
      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
      Array<te::Tensor> res = VisitExpr(field);
      ICHECK_EQ(res.size(), 1);
      fields.push_back(res[0]);
    }
    return fields;
  }

  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
    const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
    Array<te::Tensor> tuple = VisitExpr(op->tuple);
    ICHECK_EQ(tuple_type->fields.size(), tuple.size());
    ICHECK_GE(op->index, 0);
    ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
    return {tuple[op->index]};
  }

 private:
  tvm::Target target_;
  Op anchor_op_;
  Attrs anchor_attrs_;
  int anchor_op_pattern_{0};
  OpImplementation anchor_implementation_;
  std::ostringstream readable_name_stream_;
  Array<te::Operation> scalars_;
  bool use_auto_scheduler_;
  bool use_meta_schedule_;
  // Cache device copy op for equivalence checking to reduce registry lookup
  // overhead for each invocation of call node when retrieving schedules.
  const Op& device_copy_op_;
  bool create_schedule_;
};

/*!
 * \brief Create schedule for target.
 * \param source_func The primitive function to be lowered.
 * \param target The target we want to create schedule for.
 * \return Pair of schedule and cache.
 *  The funcs field in cache is not yet populated.
 */
CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
                       std::function<std::string(std::string)> renamer) {
  return ScheduleBuilder(target).Create(source_func, renamer);
}

// Creates shape function from functor.
class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
 public:
  MakeShapeFunc() {}

  CachedFunc Create(const Function& prim_func, const Target& target,
                    std::function<std::string(std::string)> renamer) {
    VLOG_CONTEXT << "MakeShapeFunc";
    TShapeDataDependent shape_func_param_states;

    for (auto param : prim_func->params) {
      param_states_[param] = kNoNeed;
      Array<tvm::te::Tensor> data_inputs;
      Array<tvm::te::Tensor> shape_inputs;

      for (const auto& ttype : FlattenTupleType(param->checked_type())) {
        // Add data placeholder (in case we discover we need it below)
        Shape shape = GetShape(ttype->shape);
        tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype);
        data_inputs.push_back(data_tensor);
        // Add shape placeholder (in case we discover we need it below)
        int64_t ndim = shape.size();
        Shape sshape;
        if (ndim > 0) {
          sshape.push_back(tvm::Integer(ndim));
        }
        tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64));
        shape_inputs.push_back(shape_tensor);
      }
      param_data_[param] = data_inputs;
      param_shapes_[param] = shape_inputs;
    }

    // Setup the name.
    readable_name_stream_ << "shape_func";

    // Create the tensor expressions representing the output shapes.
    Array<te::Tensor> outputs = VisitExpr(prim_func->body);

    // Generate a name.
    auto candidate_name = readable_name_stream_.str();

    constexpr static size_t kMaxFuncNameLength = 80;
    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
    // whenever the value of kMaxFuncNameLength changes
    if (candidate_name.size() > kMaxFuncNameLength) {
      std::stringstream truncated_name;
      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
      candidate_name = truncated_name.str();
    }

    // Set all the inputs correctly, and accumulate their types from the p.o.v. of the
    // shape function rather than the primitive it is derived for.
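    // Each parameter's usage state is a bitmask: kNeedInputData means the shape function reads
    // the argument's values, kNeedInputShape means it only reads the argument's shape. Only the
    // placeholders whose bits were set while visiting the body are kept as inputs here.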
    Array<te::Tensor> inputs;
    Array<Type> shape_function_arg_types;
    for (auto param : prim_func->params) {
      int state = param_states_[param];
      shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
      if (state & kNeedInputData) {
        // Pass the primitive arguments directly (though in flattened form and on the host)
        for (auto t : param_data_[param]) {
          inputs.push_back(t);
          shape_function_arg_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
        }
      }
      if (state & kNeedInputShape) {
        // Pass the shapes of the primitive arguments (also on the host)
        for (auto t : param_shapes_[param]) {
          inputs.push_back(t);
          shape_function_arg_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
        }
      }
    }

    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
    // no other GlobalVar ctors should appear inside the lowering machinery.
    auto func_name = renamer(candidate_name);
    auto prim_fn_gvar = GlobalVar(func_name);

    // Gather the result types, again from the p.o.v. of the shape function rather than
    // the primitive it is derived for.
    Array<Type> shape_function_res_types;
    for (const auto& t : outputs) {
      shape_function_res_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
    }

    // Assign the shape function its true type.
    FuncType type(shape_function_arg_types, TupleType(shape_function_res_types),
                  /*type_params=*/{}, /*type_constraints=*/{});
    VLOG(1) << "shape function '" << prim_fn_gvar->name_hint << "' has type:" << std::endl
            << PrettyPrint(type) << std::endl
            << "corresponding to primitive of type:" << std::endl
            << PrettyPrint(prim_func->checked_type());
    prim_fn_gvar->checked_type_ = std::move(type);

    // generate schedule for shape func
    Array<te::Operation> out_ops;
    for (auto t : outputs) {
      out_ops.push_back(t->op);
    }
    te::Schedule schedule = te::create_schedule(out_ops);
    tvm::te::AutoInlineInjective(schedule);
    for (const auto& scalar : scalars_) {
      auto scalar_op = scalar->op;
      if (schedule->Contain(scalar_op)) {
        schedule[scalar_op].compute_inline();
      }
    }

    Array<te::Tensor> all_args = Array<te::Tensor>(inputs);
    for (te::Tensor arg : outputs) {
      all_args.push_back(arg);
    }

    using tvm::transform::PassContext;
    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());

    std::unordered_map<te::Tensor, tir::Buffer> binds;
    IRModule lowered_module = tvm::LowerSchedule(schedule, all_args, func_name, binds);

    // Unfortunately the above machinery creates its own GlobalVars instead of using *the*
    // GlobalVar we established above. Fix this before the confusion spreads any further.
    // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name.
    IRModule fixed_lowered_module;
    for (const auto& kv : lowered_module->functions) {
      GlobalVar global_var =
          kv.first->name_hint == prim_fn_gvar->name_hint ? prim_fn_gvar : kv.first;
      fixed_lowered_module->Add(global_var, kv.second);
    }
    return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, tir::PrimFunc{nullptr},
                      shape_func_param_states, fixed_lowered_module);
  }

  Array<te::Tensor> VisitExpr(const Expr& expr) final {
    if (expr.as<VarNode>()) {
      // Do not memoize vars because shape functions could use either the data
      // or the shape of a var each time.
      return ExprFunctor::VisitExpr(expr);
    }
    // For other cases, do memoized visit.
    return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
  }

  Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
    auto var = GetRef<Var>(var_node);
    auto it = param_arg_map_.find(var);
    if (it != param_arg_map_.end()) {
      // This var is a parameter of a nested function. Visit the corresponding argument in the
      // function call site.
      return VisitExpr(it->second);
    }
    if (param_states_.find(var) == param_states_.end()) {
      LOG(FATAL) << "Unexpected free variable " << PrettyPrint(var);
      return {};
    } else {
      ICHECK(data_dependents_per_input_.size());
      auto data_dependent = data_dependents_per_input_.back();
      if (data_dependent) {
        param_states_[var] |= kNeedInputData;
        return param_data_[var];
      } else {
        param_states_[var] |= kNeedInputShape;
        return param_shapes_[var];
      }
    }
  }

  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
    using tir::make_const;
    ICHECK(data_dependents_per_input_.size());
    bool data_dependent = data_dependents_per_input_.back();
    if (!op->is_scalar()) {
      // This is a constant weight, extract the shape of the weight tensor.
      // This can not be data dependent.
      CHECK(!data_dependent);
      auto ttype = op->checked_type().as<TensorTypeNode>();
      int ndim = static_cast<int>(ttype->shape.size());
      Array<PrimExpr> out_shape{ndim};
      te::Tensor value = tvm::te::compute(
          out_shape,
          [&](const Array<tvm::tir::Var>& indices) {
            auto idx = indices[0];
            PrimExpr ret = make_const(DataType::Int(64), 0);
            for (int i = 0; i < ndim; i++) {
              ret = tvm::if_then_else(idx == i, ttype->shape[i], ret);
            }
            return ret;
          },
          "shape_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    }
    if (data_dependent) {
      void* data = op->data->data;
      DataType dtype = DataType(op->data->dtype);
      auto value = tvm::te::compute(
          {},
          [&](const Array<tvm::tir::Var>&) {
            if (dtype == DataType::Int(32)) {
              return make_const(dtype, static_cast<const int32_t*>(data)[0]);
            } else if (dtype == DataType::Int(64)) {
              return make_const(dtype, static_cast<const int64_t*>(data)[0]);
            } else if (dtype == DataType::Float(32)) {
              return make_const(dtype, static_cast<const float*>(data)[0]);
            } else if (dtype == DataType::Float(64)) {
              return make_const(dtype, static_cast<const double*>(data)[0]);
            } else if (dtype == DataType::Bool()) {
              return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
            } else {
              LOG(FATAL) << "not handled";
              return tvm::PrimExpr();
            }
          },
          "data_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    } else {
      auto value = tvm::te::compute(
          {}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
          "shape_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    }
  }

  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
    VLOG(1) << "considering call:" << std::endl << PrettyPrint(GetRef<Call>(call_node));
    if (auto* func = call_node->op.as<FunctionNode>()) {
      VLOG(1) << "user function";
      for (size_t i = 0; i < func->params.size(); ++i) {
        param_arg_map_[func->params[i]] = call_node->args[i];
      }
      return VisitExpr(func->body);
    }

    static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
    static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);
    ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back())
        << "Error in op fusion: output of the shape func is fed to a "
        << "data-dependent shape func";
    ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
    ICHECK_GT(tshape_data_dependent.count(op), 0)
        << "Internal error, cannot find TShapeDataDependent for " << op->name;

    Array<Integer> dep_spec = tshape_data_dependent[op];
    if (dep_spec.size() == 1) {
      // This is for cases when data dependence is specified per op
      // Replicate 0 or 1 flag to all arguments
      for (size_t i = 1; i < call_node->args.size(); ++i) {
        dep_spec.push_back(dep_spec[0]);
      }
    }

    // Visit all inputs
    Array<te::Tensor> inputs;
    int count_tuple = 0;
    for (size_t i = 0; i < call_node->args.size(); ++i) {
      Expr arg = call_node->args[i];
      if (arg->checked_type().as<TupleTypeNode>()) {
        ++count_tuple;
      }
      data_dependents_per_input_.push_back(dep_spec[i]->value != 0);
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
      data_dependents_per_input_.pop_back();
    }
    if (count_tuple) {
      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
    }

    // Get output ndims
    auto ret_type = call_node->checked_type();
    Array<IndexExpr> out_ndims;
    for (const auto& ttype : FlattenTupleType(ret_type)) {
      out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
    }

    // Call shape function
    Array<te::Tensor> outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
    VLOG(1) << "shape function for '" << op->name << "' with inputs:" << std::endl
            << inputs << std::endl
            << "yielded outputs:" << std::endl
            << outputs;
    readable_name_stream_ << "_" << op->name;
    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
    LOG(FATAL) << "Nested functions are not allowed to be visited.";
    return Array<te::Tensor>();
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
    Array<te::Tensor> val = VisitExpr(op->value);
    ICHECK(!memo_.count(op->var));
    memo_[op->var] = val;
    return VisitExpr(op->body);
  }

  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      ICHECK(field->checked_type().as<TensorTypeNode>())
          << "Expected a Tuple of Tensor, but got " << PrettyPrint(field->checked_type());
      Array<te::Tensor> res = VisitExpr(field);
      ICHECK_EQ(res.size(), 1);
      fields.push_back(res[0]);
    }
    return fields;
  }

  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
    Array<te::Tensor> input_shapes = VisitExpr(op->tuple);
    Array<te::Tensor> out;
    out.push_back(input_shapes[op->index]);
    return out;
  }

 private:
  /*! \brief String stream for function name */
  std::ostringstream readable_name_stream_;
  /*! \brief Map from parameter to its shape function usage state */
  std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> param_states_;
  /*! \brief Map from parameter to list of data placeholder */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_data_;
  /*! \brief Map from parameter to list of shape placeholder */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_shapes_;
  /*! \brief Stack of data dependencies for shape function, specified per each op input */
  std::vector<bool> data_dependents_per_input_;
  /*! \brief Scalars used in the shape function */
  Array<te::Tensor> scalars_;
  /*! \brief Map from parameters of a nested function to corresponding arguments in a function
   * call site.
   */
  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_arg_map_;
};

CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
                        std::function<std::string(std::string)> renamer) {
  return MakeShapeFunc().Create(prim_func, target, renamer);
}

/*!
 * \brief Get unique name from name.
 * \param name The original name.
 * \return Updated name which is unique.
 */
std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>* name_map_) {
  for (size_t i = 0; i < name.length(); ++i) {
    if (name[i] == '.') name[i] = '_';
  }
  while (true) {
    auto it = name_map_->find(name);
    if (it == name_map_->end()) {
      (*name_map_)[name] = 1;
      return name;
    } else {
      std::ostringstream os;
      os << name << "_" << it->second;
      ++(it->second);
      name = os.str();
    }
  }
  return name;
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
  return ScheduleBuilder(tvm::Target("ext_dev"), /*create_schedule=*/false)
      .Create(prim_func, [&](std::string name) { return name; });
});

}  // namespace tec
}  // namespace relay
}  // namespace tvm