/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file int_constraints.cc * \brief The integer constraints data structures. */ /* NOTE(review): in this copy the file is collapsed onto a few physical lines, so each line comment runs to the end of its physical line. The targets of the angle-bracket includes and all template argument lists (for example on Array, Map, ObjectPtr, make_object, static_cast, set_dispatch, std::vector and std::unordered_map) also appear to have been stripped - TODO restore from the canonical file before building. */ #include #include #include #include #include #include #include #include #include #include #include "../tir/transforms/ir_utils.h" namespace tvm { namespace arith { /* AsConditions: flattens per-variable bounds and extra relations into one list of conditions. Checks that variables and bounds have the same size and that every variable has an entry in bounds. For each variable v, taken in the order of variables, with lhs = coef * v, it appends EQ(lhs, rhs) for each entry of equal, GE(lhs, rhs) for each entry of lower and LE(lhs, rhs) for each entry of upper. The entries of relations are appended last, unchanged. */ Array AsConditions(const Array& variables, const Map& bounds, const Array& relations) { Array res; // use variables to keep the order of iteration // so as to get rid of any non-determinism. 
ICHECK_EQ(variables.size(), bounds.size()); for (const auto v : variables) { ICHECK(bounds.count(v)); const auto& bnds = bounds[v]; PrimExpr lhs = bnds->coef * v; for (const PrimExpr& rhs : bnds->equal) { res.push_back(tir::EQ(lhs, rhs)); } for (const PrimExpr& rhs : bnds->lower) { res.push_back(tir::GE(lhs, rhs)); } for (const PrimExpr& rhs : bnds->upper) { res.push_back(tir::LE(lhs, rhs)); } } for (const PrimExpr& e : relations) { res.push_back(e); } return res; } /* Constructs the bounds node. Fails the ICHECK unless the dtype of coef is int or uint, then moves coef, lower, equal and upper into a freshly made node. */ IntGroupBounds::IntGroupBounds(PrimExpr coef, Array lower, Array equal, Array upper) { ICHECK(coef.dtype().is_int() || coef.dtype().is_uint()) << "Coefficient in IntGroupBounds must be integers"; ObjectPtr node = make_object(); node->coef = std::move(coef); node->lower = std::move(lower); node->equal = std::move(equal); node->upper = std::move(upper); data_ = std::move(node); } /* Builds bounds from a range, with a coefficient of constant 1 in the dtype of r->min. An extent of one yields the single equality r->min. Otherwise the lower bound is r->min and the upper bound is the simplified r->min + r->extent - 1. */ IntGroupBounds IntGroupBounds::FromRange(const Range& r) { Analyzer analyzer; PrimExpr coef = tir::make_const(r->min.dtype(), 1); Array equal; Array lower; Array upper; if (tir::is_one(r->extent)) { equal.push_back(r->min); } else { lower.push_back(r->min); upper.push_back(analyzer.Simplify(r->min + r->extent - 1)); } return IntGroupBounds(coef, lower, equal, upper); } /* Returns new bounds that add the constraint from range r to the existing ones, keeping the same coef. The bounds derived from r are multiplied by coef and simplified, and are placed ahead of the existing equal, lower and upper entries. */ IntGroupBounds IntGroupBounds::operator+(const Range& r) { Analyzer analyzer; Array equal; Array lower; Array upper; const PrimExpr& coef = operator->()->coef; if (tir::is_one(r->extent)) { equal.push_back(analyzer.Simplify(r->min * coef)); } else { lower.push_back(analyzer.Simplify(r->min * coef)); upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * coef)); } for (const auto& eq : operator->()->equal) equal.push_back(eq); for (const auto& lb : operator->()->lower) lower.push_back(lb); for (const auto& ub : operator->()->upper) upper.push_back(ub); return IntGroupBounds(coef, lower, equal, upper); } /* Returns new bounds with tir::Substitute applied, using subst, to coef and, through tir::UpdateArray, to every entry of lower, equal and upper. */ IntGroupBounds IntGroupBounds::Substitute(const Map& subst) const { auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; return 
IntGroupBounds(tir::Substitute(operator->()->coef, subst), tir::UpdateArray(operator->()->lower, apply_fun), tir::UpdateArray(operator->()->equal, apply_fun), tir::UpdateArray(operator->()->upper, apply_fun)); } /* Derives a Range from these bounds. vranges_addl is bound into a local Analyzer and also converted to IntSets, which are used to over-approximate differences that depend on other variables. Entries of equal count as both lower and upper candidates. With exactly one lower and one upper candidate and a coef of one, the result is Range(lower, upper + 1). Otherwise every lower-upper pair is tried, and a later pair replaces the current best only when its over-approximated difference is provably strictly smaller. The result is Range::FromMinExtent(best_lower, best_diff_over + 1), or a default-constructed Range when there is no pair at all. */ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { Analyzer analyzer; analyzer.Bind(vranges_addl); std::unordered_map var_intsets; for (auto kv : vranges_addl) { var_intsets[kv.first.get()] = IntSet::FromRange(kv.second); } const Array& equal = operator->()->equal; const PrimExpr& coef = operator->()->coef; std::vector lowers(equal.begin(), equal.end()); std::vector uppers(equal.begin(), equal.end()); for (const auto& expr : operator->()->lower) { lowers.push_back(expr); } for (const auto& expr : operator->()->upper) { uppers.push_back(expr); } if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(coef)) { return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); } // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the // pair with the minimal difference between the upper and the lower. // Note that the bounds are for v, not for v*coef // The lower bound of the best pair so far PrimExpr best_lower; // The difference between the upper and the lower of the best pair, maybe overapproximation PrimExpr best_diff_over; for (const PrimExpr& low : lowers) { for (const PrimExpr& upp : uppers) { PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); // Since diff may depend on some other variables, we compute its overapproximation PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3); // low is the lower bound for v*coef, but we need the lower bound for v. // We use rounding-up division to compute it. Since we want to use a single formula PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3); // Compute another difference which may be more precise (or not). 
PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3); PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3); PrimExpr diff_over = analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1; // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { best_lower = low_divided; best_diff_over = diff_over; } } } if (!best_lower.defined()) { ICHECK(!best_diff_over.defined()); return Range(); } return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1)); } /* Registrations for IntGroupBounds: the node type, a global constructor function, FromRange, and a FindBestRange wrapper that accepts either one argument (the bounds) or two (the bounds plus the value forwarded to FindBestRange). */ TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode); TVM_REGISTER_GLOBAL("arith.IntGroupBounds") .set_body_typed([](PrimExpr coef, Array lower, Array equal, Array upper) { return IntGroupBounds(coef, lower, equal, upper); }); TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange); TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") .set_body([](TVMArgs args, TVMRetValue* ret) { ICHECK(args.size() == 1 || args.size() == 2); IntGroupBounds bounds = args[0]; if (args.size() == 1) { *ret = bounds.FindBestRange(); } else if (args.size() == 2) { *ret = bounds.FindBestRange(args[1]); } }); /* ReprPrinter dispatch that streams the coef, lower, equal and upper fields of the node. */ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "IntGroupBounds(coef=" << op->coef << ", lower=" << op->lower << ", equal=" << op->equal << ", upper=" << op->upper << ")"; }); /* Constructs the constraints node. Undefined variables or ranges are replaced by empty containers, relations must be defined, and every variable must have an int or uint dtype. */ IntConstraints::IntConstraints(Array variables, Map ranges, Array relations) { ObjectPtr node = make_object(); if (!variables.defined()) { variables = Array(); } if (!ranges.defined()) { ranges = Map(); } ICHECK(relations.defined()); for (const auto& var : variables) { ICHECK(var.dtype().is_int() || var.dtype().is_uint()) << "Variables in 
IntConstraints must be integers"; } node->variables = std::move(variables); node->ranges = std::move(ranges); node->relations = std::move(relations); data_ = std::move(node); } /* Registrations for IntConstraints: the node type, a global constructor function, and a ReprPrinter dispatch that streams variables, ranges and relations. */ TVM_REGISTER_NODE_TYPE(IntConstraintsNode); TVM_REGISTER_GLOBAL("arith.IntConstraints") .set_body_typed([](Array variables, Map ranges, Array relations) { return IntConstraints(variables, ranges, relations); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "IntConstraints(" << op->variables << ", " << op->ranges << ", " << op->relations << ")"; }); /* Constructs the transform node by moving src, dst, src_to_dst and dst_to_src into it. */ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst, Map src_to_dst, Map dst_to_src) { ObjectPtr node = make_object(); node->src = std::move(src); node->dst = std::move(dst); node->src_to_dst = std::move(src_to_dst); node->dst_to_src = std::move(dst_to_src); data_ = std::move(node); } /* Composes this transform with other, which must start where this one ends: other->src has to be the same object as the dst of this transform. Each dst_to_src entry of other is substituted with the dst_to_src map of this transform and simplified with the ranges of this src bound. Each src_to_dst entry of this transform is substituted with other->src_to_dst and simplified with the ranges of other->dst bound. The result goes from the src of this transform to other->dst. */ IntConstraintsTransform IntConstraintsTransform::operator+( const IntConstraintsTransform& other) const { ICHECK(other->src.same_as(operator->()->dst)); Map dst_to_src; Map src_to_dst; Analyzer ana_first; ana_first.Bind(operator->()->src->ranges); for (auto p : other->dst_to_src) { dst_to_src.Set(p.first, ana_first.Simplify(Substitute(p.second, operator->()->dst_to_src))); } Analyzer ana_second; ana_second.Bind(other->dst->ranges); for (auto p : operator->()->src_to_dst) { src_to_dst.Set(p.first, ana_second.Simplify(Substitute(p.second, other->src_to_dst))); } return IntConstraintsTransform(operator->()->src, other->dst, src_to_dst, dst_to_src); } /* Registrations for IntConstraintsTransform: the node type, a global constructor function, and a ReprPrinter dispatch that streams src, dst, src_to_dst and dst_to_src, each on its own tab-indented line. */ TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform") .set_body_typed([](IntConstraints src, IntConstraints dst, Map src_to_dst, Map dst_to_src) { return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = 
static_cast(node.get()); p->stream << "IntConstraintsTransform(" << "\n\t" << op->src << "\n\t" << op->dst << "\n\t" << op->src_to_dst << "\n\t" << op->dst_to_src << "\n)"; }); } // namespace arith } // namespace tvm