/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file ad_simplify.cc
 * \brief Simplify tensor compute generated by tensor-level autodiff.
 *
 * The major simplification we do in this file is to eliminate
 * the Jacobian tensor created by autodiff.
 *
 * The Jacobian tensor is sparse because one output element usually relates
 * to a small portion of the inputs. For example, an element-wise function has a one-to-one
 * mapping between input tensor and output tensor, thus the Jacobian is diagonal.
 *
 * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
 * \alpha and \beta are vectors representing the indices of In and Out respectively.
 * I.e., the non-zero Jacobian indices are a linear combination of the input indices.
 * Thereby we solve linear equations of \beta = A \alpha,
 * as well as linear inequalities of their domain ranges.
 *
 * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
 * arXiv preprint arXiv:1711.01348, 2017. for more details.
 *
 * Implementation-wise, we extract the equations in the compute definition via
 * NonzeronessCondition, replace the compute expression with solved new axes, and create a
 * selection node (non-zero-condition ? new_compute_expression : 0).
* * Due to TVM's restriction, we also lift the reduction to the top of the compute stage. * */ #include #include #include #include #include #include #include #include #include #include #include "ad_utils.h" namespace tvm { namespace te { using arith::DivMode; using arith::kFloorDiv; using arith::kSimplifyRewriteCanonicalRewrite; using arith::kTruncDiv; // Combine all expressions from the container using &&. template PrimExpr All(const container& c) { PrimExpr res; for (const auto& e : c) { if (res.get()) { res = res && e; } else { res = e; } } if (res.get()) { return res; } else { return const_true(); } } Map IterVarsToMap(const Array& itervars) { Map res; for (const IterVar& v : itervars) { res.Set(v->var, v->dom); } return res; } // Given a map from vars to ranges create an array of itervars Array IterVarsFromMap(const Array& vars, const Map& vranges, IterVarType iter_type = kDataPar, std::string thread_tag = "") { Array res; for (const Var& v : vars) { ICHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map " << vranges; res.push_back(IterVar(vranges[v], v, iter_type, thread_tag)); } return res; } Array IterVarsToVars(const Array& itervars) { Array res; for (const IterVar& v : itervars) { res.push_back(v->var); } return res; } template bool is_const_value(const PrimExpr& e, ValueType value) { static_assert(std::is_integral::value, "Comparison to non-integer values is forbidden."); if (const tir::IntImmNode* i = e.as()) { return i->value == value; } else if (const tir::FloatImmNode* i = e.as()) { return i->value == value; } else if (const tir::CastNode* c = e.as()) { return is_const_value(c->value, value); } else if (const tir::BroadcastNode* b = e.as()) { return is_const_value(b->value, value); } else { return false; } } // Return true if this combiner is just a sum. 
bool IsSumCombiner(const CommReducer& combiner, const Map& vranges) { arith::Analyzer analyzer; analyzer.Bind(vranges); if (combiner->result.size() != 1) { return false; } if (!is_const_value( analyzer.Simplify(combiner->identity_element[0], kSimplifyRewriteCanonicalRewrite), 0)) { return false; } PrimExpr combiner_result = analyzer.Simplify(combiner->result[0], kSimplifyRewriteCanonicalRewrite); return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) || tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]); } bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, const Map& vranges) { arith::Analyzer analyzer; analyzer.Bind(vranges); if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index], kSimplifyRewriteCanonicalRewrite), 0)) { return false; } PrimExpr zero = make_zero(combiner->result[value_index].dtype()); PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero}, {combiner->rhs[value_index], zero}}); in = analyzer.Simplify(in, kSimplifyRewriteCanonicalRewrite); return is_const_value(in, 0); } struct NonzeroConditionResult { PrimExpr cond; PrimExpr value; PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); } friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) { return os << r.to_expr(); } }; // The implementation of NonzeroCondition // transform expression to cond ? 
value : 0 class NonzeroConditionFunctor : public ExprFunctor { public: NonzeroConditionResult NonzeroCondition(const PrimExpr& e) { if (e.dtype().is_bool()) { // Boolean expressions are non-zero whenever they are true themselves return {e, const_true()}; } else { return VisitExpr(e); } } // Most of the cases are implemented using helpers below result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef(op)); } result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef(op)); } result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef(op)); } result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef(op)); } result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef(op)); } result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef(op)); } result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef(op)); } result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef
(op)); } result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef(op)); } result_type VisitExpr_(const FloorDivNode* op) final { return BinOpDivLike_(GetRef(op)); } result_type VisitExpr_(const FloorModNode* op) final { return BinOpDivLike_(GetRef(op)); } result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef(op)); } result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef(op)); } result_type VisitExpr_(const CastNode* op) final { auto nz_a = NonzeroCondition(op->value); return {nz_a.cond, Cast(op->dtype, nz_a.value)}; } result_type VisitExpr_(const SelectNode* op) final { PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value; auto nz_a = NonzeroCondition(true_val); auto nz_b = NonzeroCondition(false_val); // If the false part is zero, we can get rid of the select if (is_const_value(nz_b.value, 0)) { PrimExpr new_cond = analyzer_.Simplify(nz_a.cond && cond, kSimplifyRewriteCanonicalRewrite); return {new_cond, nz_a.value}; } // If the true part is zero, we can also get rid of the select if (is_const_value(nz_a.value, 0)) { PrimExpr new_cond = analyzer_.Simplify(nz_b.cond && !cond, kSimplifyRewriteCanonicalRewrite); return {new_cond, nz_b.value}; } // Otherwise we retain the select and combine the conditions into this PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), kSimplifyRewriteCanonicalRewrite); if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { return {new_cond, GetRef(op)}; } else { return {new_cond, Select(cond, nz_a.value, nz_b.value)}; } } result_type VisitExpr_(const CallNode* op) final { if (op->op.same_as(op_if_then_else_)) { PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2]; auto nz_a = NonzeroCondition(true_val); auto nz_b = NonzeroCondition(false_val); // We don't have as much freedom here as in the select case // since the `if` must be preserved in any case PrimExpr new_cond = 
analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), kSimplifyRewriteCanonicalRewrite); if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { return {new_cond, GetRef(op)}; } else { return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)}; } } else { return Default_(GetRef(op)); } } result_type VisitExpr_(const ProducerLoadNode* op) final { return Default_(GetRef(op)); } NonzeroConditionResult Default_(const PrimExpr& e) { // This is always correct, so it's the default return {const_true(), e}; } template NonzeroConditionResult Const_(const T& op) { if (op->value == 0) { return {const_false(), op}; } else { return {const_true(), op}; } } template NonzeroConditionResult BinOpAddLike_(const T& op) { auto nz_a = NonzeroCondition(op->a); auto nz_b = NonzeroCondition(op->b); // For addition and similar ops the result may be nonzero if either of the arguments is // nonzero, so we combine the conditions with Or. if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) { // If the conditions are the same, we don't need Or if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { return {nz_a.cond, op}; } else { return {nz_a.cond, T(nz_a.value, nz_b.value)}; } } else { // Otherwise use Or PrimExpr new_cond = analyzer_.Simplify(nz_a.cond || nz_b.cond, kSimplifyRewriteCanonicalRewrite); // A little optimization: if the combined condition is the same as one of the inner // conditions, we don't need to guard the inner value with a select, otherwise // we create a select in the `to_expr` call. PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? 
nz_b.value : nz_b.to_expr(); PrimExpr new_expr = T(new_a, new_b); return {new_cond, new_expr}; } } template NonzeroConditionResult BinOpMulLike_(const T& op) { auto nz_a = NonzeroCondition(op->a); auto nz_b = NonzeroCondition(op->b); // For multiplication and similar ops the result may be nonzero if // both the arguments are nonzero, so we combine with And. PrimExpr new_cond = analyzer_.Simplify(nz_a.cond && nz_b.cond, kSimplifyRewriteCanonicalRewrite); if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { return {new_cond, op}; } else { return {new_cond, T(nz_a.value, nz_b.value)}; } } template NonzeroConditionResult BinOpDivLike_(const T& op) { auto nz_a = NonzeroCondition(op->a); // For Div we simply use the condition of the numerator. if (nz_a.value.same_as(op->a)) { return {nz_a.cond, op}; } else { return {nz_a.cond, T(nz_a.value, op->b)}; } } private: arith::Analyzer analyzer_; const Op& op_if_then_else_ = Op::Get("tir.if_then_else"); }; inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) { return NonzeroConditionFunctor().NonzeroCondition(expr); } struct FactorOutAtomicFormulasResult { std::vector atomic_formulas; PrimExpr rest; PrimExpr to_expr() const { PrimExpr res = rest; for (const PrimExpr& e : atomic_formulas) { res = And(e, res); } return res; } Array to_array() const { Array res = atomic_formulas; res.push_back(rest); return res; } }; // The implementation of FactorOutAtomicFormulas class FactorOutAtomicFormulasFunctor : public ExprFunctor { public: result_type Atomic_(const PrimExpr& e) { // For atomic expressions the result is the expr itself with True as the residual return {{e}, make_const(e.dtype(), 1)}; } // This is basically the list of expression kinds that are considered atomic result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef(op)); } result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef(op)); } result_type VisitExpr_(const IntImmNode* op) final { return 
Atomic_(GetRef(op)); } result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef(op)); } result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef(op)); } result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef(op)); } result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef(op)); } result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef(op)); } result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef(op)); } result_type VisitExpr_(const SelectNode* op) final { // Select can be rewritten through other logical ops PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value); return VisitExpr(expr); } result_type VisitExpr_(const NotNode* op) final { // Not should be moved down if (const OrNode* or_expr = op->a.as()) { PrimExpr expr = !or_expr->a && !or_expr->b; return VisitExpr(expr); } else if (const AndNode* and_expr = op->a.as()) { PrimExpr expr = !and_expr->a || !and_expr->b; return VisitExpr(expr); } else if (const SelectNode* sel_expr = op->a.as()) { PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) && (sel_expr->condition || !sel_expr->false_value)); return VisitExpr(expr); } return Atomic_(GetRef(op)); } result_type VisitExpr_(const AndNode* op) final { auto res_a = VisitExpr(op->a); auto res_b = VisitExpr(op->b); // For the And case we return the union of the sets of atomic formulas std::unordered_set res_a_set; res_a_set.reserve(res_a.atomic_formulas.size()); std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), std::inserter(res_a_set, res_a_set.end())); std::vector res = res_a.atomic_formulas; for (const auto& e : res_b.atomic_formulas) { if (res_a_set.find(e) == res_a_set.end()) { res.emplace_back(e); } } // And the residuals are combined with && return {res, res_a.rest && res_b.rest}; } result_type VisitExpr_(const MulNode* op) final { // Since we work with bools, for multiplication we do the same thing as for 
And PrimExpr e_and = op->a && op->b; return VisitExpr(e_and); } result_type VisitExpr_(const OrNode* op) final { auto res_a = VisitExpr(op->a); auto res_b = VisitExpr(op->b); std::unordered_set res_a_set{ res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()}; std::unordered_set res_b_set{ res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()}; // For the Or case we intersect the sets of atomic formulas std::unordered_set res_set; std::vector res; res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); for (const auto& res_b_formula : res_b.atomic_formulas) { if (res_a_set.count(res_b_formula)) { res_set.insert(res_b_formula); res.push_back(res_b_formula); } } // Computing the residual is more complex: we have to compute the sets of atomic formulas // which are left behind, and then combine them with the residuals into the new residual. std::vector new_cond_a; new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size()); for (const auto& formula : res_a.atomic_formulas) { if (!res_set.count(formula)) new_cond_a.emplace_back(formula); } std::vector new_cond_b; new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size()); for (const auto& formula : res_b.atomic_formulas) { if (!res_set.count(formula)) new_cond_b.emplace_back(formula); } res_a.atomic_formulas = std::move(new_cond_a); res_b.atomic_formulas = std::move(new_cond_b); PrimExpr new_rest = res_a.to_expr() || res_b.to_expr(); return {res, new_rest}; } }; // Transform the given formula into a conjunction of atomic formulas (represented as an array) // and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b, // etc), i.e. formulas which are not logical operators (||, &&, !) on the top level. 
FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) { ICHECK(e.dtype().is_bool()); return FactorOutAtomicFormulasFunctor().VisitExpr(e); } struct EliminateDivModResult { PrimExpr expr; Map substitution; Array new_variables; Array conditions; Map ranges; }; inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { if (mode == kTruncDiv) { return truncmod(a, b); } else { ICHECK_EQ(mode, kFloorDiv); return floormod(a, b); } } inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { if (mode == kTruncDiv) { return truncdiv(a, b); } else { ICHECK_EQ(mode, kFloorDiv); return floordiv(a, b); } } class EliminateDivModMutator : public ExprMutator { public: Map substitution; Array new_variables; Array conditions; Map ranges; explicit EliminateDivModMutator(Map ranges) : ranges(std::move(ranges)) {} virtual PrimExpr VisitExpr_(const DivNode* op) { const IntImmNode* imm = op->b.as(); if (imm && imm->value != 0) { if (imm->value < 0) { // x / -c == -(x/c) for truncated division return make_zero(op->dtype) - VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value))); } // Try to find the already existing variables for this expression auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value)); if (it != expr_to_vars_.end()) { return it->second.first; } // Otherwise recursively mutate the left hand side, and create new variables PrimExpr mutated_a = VisitExpr(op->a); if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) { return var_pair_opt.value().first; } else { return truncdiv(mutated_a, op->b); } } return div(VisitExpr(op->a), VisitExpr(op->b)); } virtual PrimExpr VisitExpr_(const ModNode* op) { const IntImmNode* imm = op->b.as(); if (imm && imm->value != 0) { if (imm->value < 0) { // x % -c == x % c for truncated division return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value))); } // Try to find the already existing variables for this expression auto it = 
expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value)); if (it != expr_to_vars_.end()) { return it->second.second; } // Otherwise recursively mutate the left hand side, and create new variables PrimExpr mutated_a = VisitExpr(op->a); if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) { return var_pair_opt.value().second; } else { return truncmod(mutated_a, op->b); } } return truncmod(VisitExpr(op->a), VisitExpr(op->b)); } virtual PrimExpr VisitExpr_(const FloorDivNode* op) { const IntImmNode* imm = op->b.as(); if (imm && imm->value != 0) { if (imm->value < 0) { // x / -c == (-x) / c for flooring division return VisitExpr( floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value))); } // Try to find the already existing variables for this expression auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value)); if (it != expr_to_vars_.end()) { return it->second.first; } // Otherwise recursively mutate the left hand side, and create new variables PrimExpr mutated_a = VisitExpr(op->a); if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) { return var_pair_opt.value().first; } else { return floordiv(mutated_a, op->b); } } return floordiv(VisitExpr(op->a), VisitExpr(op->b)); } virtual PrimExpr VisitExpr_(const FloorModNode* op) { const IntImmNode* imm = op->b.as(); if (imm && imm->value != 0) { if (imm->value < 0) { // x % -c == -(-x % c) for flooring division return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value))); } // Try to find the already existing variables for this expression auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value)); if (it != expr_to_vars_.end()) { return it->second.second; } // Otherwise recursively mutate the left hand side, and create new variables PrimExpr mutated_a = VisitExpr(op->a); if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) { 
return var_pair_opt.value().second; } else { return floormod(mutated_a, op->b); } } return floormod(VisitExpr(op->a), VisitExpr(op->b)); } private: dmlc::optional> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut, int64_t val, DivMode mode) { using tresult = dmlc::optional>; // Try to find the variables using the mutated expressions if (!e.same_as(mut)) { auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val)); if (it != expr_to_vars_.end()) { return tresult(it->second); } } PrimExpr val_e = make_const(e.dtype(), val); idx_ += 1; // Convert `ranges` to IntSets std::unordered_map var_intsets; for (const auto& p : ranges) { var_intsets[p.first.get()] = IntSet::FromRange(p.second); } // Infer ranges for the expressions we want to replace with variables Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range()); Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range()); // We don't want to add unbounded variables if (!div_range.get() || !mod_range.get()) { LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode) << " because its bounds cannot be inferred"; return tresult(); } if (!mod_range.get()) { LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode) << " because its bounds cannot be inferred"; return tresult(); } // Create new variables for the expressions auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype()); auto mod = Var((mode == kTruncDiv ? 
"tmod" : "fmod") + std::to_string(idx_), e.dtype()); new_variables.push_back(div); new_variables.push_back(mod); // Note that we have to perform substitution to mut because mut may contain new variables substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode)); substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode)); ranges.Set(div, div_range); ranges.Set(mod, mod_range); // This additional condition works as a definition for the new variables conditions.push_back(mut == div * val_e + mod); if (!analyzer_.CanProve(mod_range->extent <= val_e)) { // If we use the C/C++ definition of mod, there may be multiple values of `mod` // satisfying the added condition if the expr `e` may change its sign, so we // have to add another condition. LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because " << ModImpl(e, val_e, mode) << " probably may change its sign"; conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0)); } auto p = std::make_pair(div, mod); expr_to_vars_[std::make_tuple(mode, e, val)] = p; if (!e.same_as(mut)) { expr_to_vars_[std::make_tuple(mode, mut, val)] = p; } return tresult(p); } class TupleEqual_ { public: bool operator()(const std::tuple& lhs, const std::tuple& rhs) const { return std::get<0>(lhs) == std::get<0>(rhs) && tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) && std::get<2>(lhs) == std::get<2>(rhs); } }; class TupleHasher_ { public: size_t operator()(const std::tuple& key) const { return ((std::hash()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >> 1) ^ (std::hash()(std::get<2>(key)) << 1); } }; // A counter for naming new variables int idx_{0}; // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod) // such that `div = e / n` and `mod = e % n` std::unordered_map, std::pair, TupleHasher_, TupleEqual_> expr_to_vars_; arith::Analyzer analyzer_; }; // Replace every subexpr of the form e/const and e % const with a new variable. 
// Syntactically equal expressions will be mapped to the same variable. EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map ranges) { EliminateDivModResult res; EliminateDivModMutator mutator(ranges); res.expr = mutator(expr); res.conditions = std::move(mutator.conditions); res.new_variables = std::move(mutator.new_variables); res.substitution = std::move(mutator.substitution); res.ranges = std::move(mutator.ranges); return res; } arith::IntConstraintsTransform EliminateDivModFromDomainConditions( const arith::IntConstraints& domain) { auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges); Map new_vranges = elim_res.ranges; Array new_axis = Concat(domain->variables, elim_res.new_variables); PrimExpr new_cond = elim_res.expr && All(elim_res.conditions); arith::IntConstraints new_domain(new_axis, new_vranges, FactorOutAtomicFormulas(new_cond).to_array()); Map src_to_dst; Map dst_to_src = elim_res.substitution; for (const Var& v : domain->variables) { src_to_dst.Set(v, v); dst_to_src.Set(v, v); } return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src); } inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) { Map identity_map; for (const Var& v : domain->variables) { identity_map.Set(v, v); } return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map); } // Simplify an iteration domain. 
arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains, bool eliminate_div_mod) { arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains); if (eliminate_div_mod) { transf = transf + EliminateDivModFromDomainConditions(transf->dst); } // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably // should find a better terminating criterion (like stop when the domain volume stops decreasing) // Also 2 steps seems to be slightly better than 3 for (size_t i = 0; i < 2; ++i) { transf = transf + arith::SolveLinearEquations(transf->dst); transf = transf + arith::SolveInequalitiesDeskewRange(transf->dst); } return transf; } // Use the condition of a reduction op to simplify its domain (axis) PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map& outer_vranges) { if (const ReduceNode* red = expr.as()) { Array vars = IterVarsToVars(red->axis); Map vranges = Merge(outer_vranges, IterVarsToMap(red->axis)); Array relations = FactorOutAtomicFormulas(red->condition).to_array(); arith::IntConstraints domain(vars, vranges, relations); auto res = SimplifyDomain(domain); Array new_source; for (const PrimExpr& src : red->source) { new_source.push_back(Substitute(src, res->src_to_dst)); } Array new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce); // Perform simplification mainly to remove a possibly empty reduction. 
arith::Analyzer analyzer; return analyzer.Simplify(Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index, red->init), kSimplifyRewriteCanonicalRewrite); } else { return expr; } } // Extract from cond an implication of cond not containing vars std::pair ImplicationNotContainingVars( const PrimExpr& cond, const std::unordered_set& vars) { ICHECK(cond.dtype().is_bool()) << "The type of cond must be bool"; // TODO(sgrechanik-h): NOTs could be pushed down using De Morgan laws // before running this function but this case didn't seem to be important enough. if (const AndNode* op = cond.as()) { auto pair_a = ImplicationNotContainingVars(op->a, vars); auto pair_b = ImplicationNotContainingVars(op->b, vars); return {pair_a.first && pair_b.first, pair_a.second && pair_b.second}; } else if (const OrNode* op = cond.as()) { auto pair_a = ImplicationNotContainingVars(op->a, vars); auto pair_b = ImplicationNotContainingVars(op->b, vars); return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) && (pair_b.first || pair_a.second) && (pair_a.second || pair_b.second)}; } else if (!tir::UsesVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { return {cond, const_true()}; } else { return {const_true(), cond}; } } // Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out // (in)equalities which do not depend on the reduction variables. 
std::pair LiftConditionsThroughReduction(const PrimExpr& cond, const Array& red_axis, const Array& outer_axis) { // Factor out atomics so that we can consider this as a system of inequalities auto factor_atomic_res = FactorOutAtomicFormulas(cond); Array atomics = factor_atomic_res.atomic_formulas; const PrimExpr& rest = factor_atomic_res.rest; Array allvars; for (const IterVar& v : red_axis) { allvars.push_back(v->var); } for (const IterVar& v : outer_axis) { allvars.push_back(v->var); } auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis)); // start from reduction vars, so that input vars don't depend on them arith::IntConstraints ineq_to_solve(allvars, vranges, atomics); auto res_ineq = arith::SolveLinearInequalities(ineq_to_solve); atomics = arith::AsConditions(allvars, res_ineq.first, res_ineq.second); // Append the rest part PrimExpr rewritten_cond = All(atomics) && rest; std::unordered_set vset; for (const IterVar& v : red_axis) { vset.insert(v->var.get()); } // The outer (first) condition does not contain reduction vars, // the inner (second) condition is everything else auto res = ImplicationNotContainingVars(rewritten_cond, vset); return res; } // Convert an array of itervars to an array of inequalities Array IterVarsToInequalities(const Array& itervars) { Array res; for (const IterVar& v : itervars) { res.push_back(GE(v->var, v->dom->min)); res.push_back(LT(v->var, v->dom->min + v->dom->extent)); } return res; } class RemoveRedundantInequalitiesMutator : public ExprMutator { public: explicit RemoveRedundantInequalitiesMutator(Array known) { for (const PrimExpr& cond : known) { known_.push_back(analyzer_.Simplify(cond, kSimplifyRewriteCanonicalRewrite)); } } virtual PrimExpr VisitExpr_(const SelectNode* op) { bool has_side_effect = (SideEffect(GetRef(op)) > CallEffectKind::kReadState); PrimExpr new_cond = analyzer_.Simplify(VisitExpr(op->condition), kSimplifyRewriteCanonicalRewrite); if (is_one(new_cond) && !has_side_effect) { return 
VisitExpr(op->true_value); } else if (is_zero(new_cond) && !has_side_effect) { return VisitExpr(op->false_value); } else { Array new_known = known_; for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { new_known.push_back(atomic); } RemoveRedundantInequalitiesMutator new_mutator(new_known); // Note that we mutate only the true value with the new mutator // TODO(sgrechanik-h): Update known conditions for the false value as well return Select(new_cond, new_mutator(op->true_value), VisitExpr(op->false_value)); } } virtual PrimExpr VisitExpr_(const CallNode* op) { if (op->op.same_as(op_if_then_else_)) { PrimExpr new_cond = analyzer_.Simplify(VisitExpr(op->args[0]), kSimplifyRewriteCanonicalRewrite); if (is_one(new_cond)) { return VisitExpr(op->args[1]); } else if (is_zero(new_cond)) { return VisitExpr(op->args[2]); } else { Array new_known = known_; for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { new_known.push_back(atomic); } RemoveRedundantInequalitiesMutator new_mutator(new_known); // Note that we mutate only the true value with the new mutator // TODO(sgrechanik-h): Update known conditions for the false value as well return if_then_else(new_cond, new_mutator(op->args[1]), VisitExpr(op->args[2])); } } else { return ExprMutator::VisitExpr_(op); } } virtual PrimExpr VisitExpr_(const ReduceNode* op) { Array known_with_axes = known_; ICHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented"; for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) { known_with_axes.push_back(axis_cond); } RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes); PrimExpr new_cond = mutator_with_axes(op->condition); Array new_known = known_with_axes; for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) { new_known.push_back(atomic); } RemoveRedundantInequalitiesMutator new_mutator(new_known); Array new_source; for (const PrimExpr& src : 
op->source) { new_source.push_back(new_mutator(src)); } return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index, op->init); } virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef(op)); } virtual PrimExpr VisitExpr_(const NENode* op) { return MutateAtomic_(GetRef(op)); } virtual PrimExpr VisitExpr_(const LTNode* op) { return MutateAtomic_(GetRef(op)); } virtual PrimExpr VisitExpr_(const LENode* op) { return MutateAtomic_(GetRef(op)); } virtual PrimExpr VisitExpr_(const GTNode* op) { return MutateAtomic_(GetRef(op)); } virtual PrimExpr VisitExpr_(const GENode* op) { return MutateAtomic_(GetRef(op)); } virtual PrimExpr VisitExpr_(const AndNode* op) { return VisitExpr(op->a) && VisitExpr(op->b); } private: PrimExpr MutateAtomic_(const PrimExpr& e) { PrimExpr simplified = analyzer_.Simplify(e, kSimplifyRewriteCanonicalRewrite); for (const PrimExpr& other : known_) { if (ExprDeepEqual()(simplified, other)) { return const_true(); } } return simplified; } Array known_; arith::Analyzer analyzer_; const Op& op_if_then_else_ = Op::Get("tir.if_then_else"); }; // Propagate information from conditions and remove redundant inequalities inline PrimExpr RemoveRedundantInequalities(const PrimExpr& expr, const Array& known) { return RemoveRedundantInequalitiesMutator(known)(expr); } // Extract the given expr under the given condition as a separate tensor if the volume of the // extracted tensor will be less than the volume of the outer_axis PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond, const Array& outer_axis, const Map& vranges) { // solve cond, e.g., (jac_i0 == i) && (jac_i1 == j) arith::IntConstraints domain_to_solve(outer_axis, vranges, FactorOutAtomicFormulas(cond).to_array()); auto res = SimplifyDomain(domain_to_solve); arith::Analyzer analyzer; analyzer.Bind(res->dst->ranges); PrimExpr new_expr = analyzer.Simplify(Substitute(expr, res->src_to_dst), kSimplifyRewriteCanonicalRewrite); // TODO(yzhliu): This is 
mostly done to simplify if_then_else // which is not realized by the canonical simplifier new_expr = RemoveRedundantInequalities(new_expr, res->dst->relations); // Keep only those variables of the new vars which are used in the new_expr Array used_res_variables; for (const Var& var : res->dst->variables) { if (tir::UsesVar(new_expr, [&var](const VarNode* var_) { return var_ == var.get(); })) { ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred."; used_res_variables.push_back(var); } } // If the expression does not use vars then it is probably better to keep it inlined if (used_res_variables.empty()) { // We can return the new_expr here instead of the old expr because it doesn't use variables // otherwise we would need to replace the new vars or create a let-expression return new_expr; } // If it's already tensor[...] then it will probably be useless to further simplify it. if (new_expr.as()) { return expr; } // Compute volumes before and after PrimExpr old_volume = make_const(DataType::Int(64), 1); for (const Var& var : outer_axis) { ICHECK(vranges.count(var)) << "Range of " << var << " was not provided."; old_volume = old_volume * vranges[var]->extent; } PrimExpr new_volume = make_const(DataType::Int(64), 1); for (const Var& var : used_res_variables) { new_volume = new_volume * res->dst->ranges[var]->extent; } // if we can prove that the old volume is not greater than the new volume then // prefer the old expression. 
arith::Analyzer ana_vranges; ana_vranges.Bind(vranges); if (ana_vranges.CanProve(old_volume <= new_volume)) { return expr; } Tensor tensor = TensorFromExpr(new_expr, IterVarsFromMap(used_res_variables, res->dst->ranges), "extracted_tensor"); Array args; for (const Var& var : used_res_variables) { args.push_back(res->dst_to_src[var]); } return ProducerLoad(tensor, args); } class ReductionAsTensorAccessMutator : public ExprMutator { public: explicit ReductionAsTensorAccessMutator(const Array& outer_axis, Map vranges, std::string name = "extracted_reduction") : outer_axis_(outer_axis), vranges_(std::move(vranges)), name_(std::move(name)) {} PrimExpr VisitExpr_(const ReduceNode* op) final { ReductionAsTensorAccessMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_), Merge(vranges_, IterVarsToMap(op->axis)), name_); ICHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented"; Array new_source; for (const PrimExpr& src : op->source) { new_source.push_back(new_mutator(src)); } PrimExpr new_reduce = Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index, op->init); Array undefined_vars = UndefinedVars(new_reduce); std::unordered_set undefined_var_set; for (const Var& var : undefined_vars) { undefined_var_set.insert(var.get()); } // Vars of the tensor we are going to create for this reduction Array vars; for (const Var& v : outer_axis_) { // We take variables from the outer_axis_ which are also present in the new reduction if (undefined_var_set.count(v.get())) { vars.push_back(v); } } auto new_axis_vmap_pair = CloneIterVars(IterVarsFromMap(vars, vranges_)); Array new_axis = new_axis_vmap_pair.first; arith::Analyzer analyzer; analyzer.Bind(IterVarsToMap(new_axis)); new_reduce = analyzer.Simplify(Substitute(new_reduce, new_axis_vmap_pair.second), kSimplifyRewriteCanonicalRewrite); Tensor tensor = TensorFromExpr(new_reduce, new_axis, name_, tag_, attrs_); Array args; for (const Var& v : vars) { args.push_back(v); 
} return ProducerLoad(tensor, args); } private: Array outer_axis_; Map vranges_; std::string name_; std::string tag_; Map attrs_; }; // Extract reductions as separate tensors. inline PrimExpr ReductionAsTensorAccess(const PrimExpr& expr, const Array& outer_axis, const Map& vranges) { return ReductionAsTensorAccessMutator(outer_axis, vranges)(expr); } PrimExpr LiftReductions(const PrimExpr& expr, const Array& outer_axis, const Map& vranges) { if (const ReduceNode* red = expr.as()) { Array new_outer_axis = Concat(IterVarsToVars(red->axis), outer_axis); Map new_vranges = Merge(vranges, IterVarsToMap(red->axis)); Array new_source; for (const PrimExpr& src : red->source) { new_source.push_back(ReductionAsTensorAccess(src, new_outer_axis, new_vranges)); } PrimExpr new_condition = ReductionAsTensorAccess(red->condition, new_outer_axis, new_vranges); return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index, red->init); } else { return ReductionAsTensorAccess(expr, outer_axis, vranges); } } PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const Array& axis, const Map& vranges) { PrimExpr result; Map combined_vranges = Merge(vranges, IterVarsToMap(axis)); arith::Analyzer analyzer; analyzer.Bind(combined_vranges); // Simplify the original expression first, mostly to simplify combiners PrimExpr expr = analyzer.Simplify(expr_orig, kSimplifyRewriteCanonicalRewrite); if (const ReduceNode* red = expr.as()) { ICHECK(red->init.empty()) << "Derivative of Reduction with initialization is not implemented"; // TODO(sgrechanik-h): There are some other operations which behave like sum bool is_sum = IsSumCombiner(red->combiner, vranges); if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) { PrimExpr new_red = expr; // Here we simplify the reduction PrimExpr cond = red->condition; Array source = red->source; // If it is a summation then we can lift nonzeroness conditions from the source // and add them to the 
reduction conditions if (is_sum) { auto nz = NonzeronessCondition(red->source[red->value_index]); cond = nz.cond && cond; source.Set(0, nz.value); } new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index, red->init); new_red = SimplifyReductionDomain(new_red, combined_vranges); // Update original red pointer for later use. red = new_red.as(); // If the reduction disappears completely then transform the result as a non-reduction if (!red) { return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges); } PrimExpr new_outer_cond, new_reduce_cond; Array new_source = red->source; // Partially lift conditions from the reduce condition std::tie(new_outer_cond, new_reduce_cond) = LiftConditionsThroughReduction(red->condition, red->axis, axis); // If it's not sum then we haven't yet lifted nonzeroness cond from the source if (!is_sum) { PrimExpr outer_nz_cond, nz_cond, nz_source; auto nz = NonzeronessCondition(red->source[red->value_index]); // Append conditions from the reduction nz_cond = new_reduce_cond && nz.cond; nz_source = nz.value; std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis); new_outer_cond = new_outer_cond && outer_nz_cond; new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype()))); } PrimExpr new_reduce = Reduce(red->combiner, new_source, red->axis, new_reduce_cond, red->value_index, red->init); new_reduce = TrySimplifyCompute(new_reduce, new_outer_cond, IterVarsToVars(axis), combined_vranges); result = Select(new_outer_cond, new_reduce, make_zero(new_reduce.dtype())); } else { return SimplifyReductionDomain(expr, combined_vranges); } } else { auto nz = NonzeronessCondition(expr); PrimExpr new_expr = TrySimplifyCompute(nz.value, nz.cond, IterVarsToVars(axis), combined_vranges); result = Select(nz.cond, new_expr, make_zero(new_expr.dtype())); } // Note that RemoveRedundantInequalities can sometimes propagate equalities which // other simplifiers cannot, like (i 
% 3) == 0. Array axis_conds = IterVarsToInequalities(axis); result = RemoveRedundantInequalities(result, axis_conds); // Currently in TVM reductions are only allowed at the top level of compute, // we need to extract intermediate inlined reduction as a separate stage (tensor). // Sometimes TrySimplifyCompute doesn't perform lift / extraction, // so there may be some non-top reductions left, take care of them. result = LiftReductions(result, IterVarsToVars(axis), combined_vranges); return analyzer.Simplify(result, kSimplifyRewriteCanonicalRewrite); } Tensor RemoveJacobianAndLiftNonzeroCond(const Tensor& tensor, const Map& vranges) { auto transform_func = [&vranges](const PrimExpr& expr, const Array& axis) { return RemoveJacobianAndLiftNonzeroCondImpl(expr, axis, vranges); }; return TransformTensorBody(tensor, transform_func); } } // namespace te } // namespace tvm