/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/arithmetic/pattern_match.h * * \brief Internal tool for expression-template based pattern matching. * * It helps to simplify pattern matching and rewrites. * All the patterns are generated via expression template during compile time, * so the result code should be as efficient as manually written pattern match code. * * The code below shows how to use the pattern matcher. * * \code * * // max(x + z, y + z) => max(x, y) + z * arith::PVar x, y, z; * * // The following code tries to match the declared pattern. * // Match will fill the result of match into PVar if successful. * // Note that z occurs twice in the pattern, * // an equality check is performed to ensure each occurance of z * // is equivalent to each other. * if (max(x + z, y + z).Match(expr)) { * // Eval evaluates a pattern with the current matched value. * // The filled value is valid until the next call to Match. * return (max(x, y) + z).Eval(); * } * * tvm::tir::Var tx, ty; * arith::PVar c; * arith::PVar v; * // We can match integer and Var, both of which are * // special case container of Expr * ICHECK((v * c).Match(tx * 3)); * ICHECK_EQ(c.Eval()->value, 3); * // cannot match c to ty * ICHECK(!(v * c).Match(tx * ty)); * * \endcode * * \note The pattern matcher is not threadsafe, * do not use the same PVar in multiple threads. * * Please be aware that the filled value in a PVar * can be overriden in the next call to Match. */ #ifndef TVM_ARITH_PATTERN_MATCH_H_ #define TVM_ARITH_PATTERN_MATCH_H_ #include #include #include #include #include #include "const_fold.h" namespace tvm { namespace arith { /*! * \brief Base class of all the patterns. * * There are two major member functions supported by each pattern. * - Match: checks if value matches the pattern. * - Eval: construct a new value based on matched values in PVar. * * We use curiously recurring template pattern to construct * expression templates. * * \tparam Derived The type of the derived class. */ template class Pattern { public: /*! * \brief Nested storage type in the expression. * * Depending on the Derived class, * Nested can be Derived (nest by value) or * const Derived& (nest by reference). * * The trick of Nested typedef originates from Eigen. * * \note We use nest by value for intermediate expressions, * and nest by reference for PVars. */ using Nested = Derived; /*! * \brief Check if value matches the current pattern. * * This call also populates the PVars with matched value. * The values in PVars are valid until the next call to Match. * * \return whether value matches the pattern. */ template bool Match(const NodeType& value) const { derived().InitMatch_(); return derived().Match_(value); } /*! \return Derived instance of current class. */ const Derived& derived() const { return *static_cast(this); } }; /*! * \brief Default deep equality checker * \tparam T the comparison point. */ template class PEqualChecker { public: bool operator()(const T& lhs, const T& rhs) const { return lhs == rhs; } }; template <> class PEqualChecker { public: bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { if (lhs.same_as(rhs)) return true; return tir::ExprDeepEqual()(lhs, rhs); } }; template <> class PEqualChecker { public: bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; template <> class PEqualChecker { public: bool operator()(const FloatImm& lhs, const FloatImm& rhs) const { return std::fabs(lhs->value - rhs->value) < 1e-20; } }; template <> class PEqualChecker { public: bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); } }; /*! * \brief Pattern variable container. * * PVar is used as a "hole" in the pattern that can be matched. * * \tparam T the type of the hole. * * \note PVar is not thread safe. * Do not use the same PVar in multiple threads. */ template class PVar : public Pattern> { public: // Store PVars by reference in the expression. using Nested = const PVar&; void InitMatch_() const { filled_ = false; } bool Match_(const T& value) const { if (!filled_) { value_ = value; filled_ = true; return true; } else { return PEqualChecker()(value_, value); } } template ::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { return Match_(GetRef(ptr)); } else { return false; } } T Eval() const { ICHECK(filled_); return value_; } protected: /*! \brief The matched value */ mutable T value_; /*! \brief whether the variable has been filled */ mutable bool filled_{false}; }; /*! * \brief Wrapper for pattern variable container with extra match logic. * * \tparam Derived the type of derived class. * \tparam T the type of the hole. */ template class PVarWithCheck : public arith::Pattern> { public: // Store by reference in the expression. using Nested = const PVarWithCheck&; void InitMatch_() const { pvar_.InitMatch_(); } bool Match_(const T& value) const { if (!static_cast(this)->Match_(value)) return false; return pvar_.Match_(value); } template ::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { return Match_(GetRef(ptr)); } else { return false; } } T Eval() const { return pvar_.Eval(); } protected: arith::PVar pvar_; }; /*! * \brief Pattern variable container with expr type check. * * \tparam T the type of the hole. * \tparam DType the Pattern type of dtype. */ template ::value>> class PVarWithDataType : public PVarWithCheck, T> { public: explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {} bool Match_(const T& value) const { return dtype_.Match_(value->dtype); } protected: typename DType::Nested dtype_; }; /*! * \brief Pattern variable container for data type with lanes. */ class PVecDataType : public PVarWithCheck { public: /*! \brief construct vector dtype placeholder with element type check */ explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {} bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); } protected: DataType elem_dtype_; }; /*! * \brief Constant Pattern variable container. * * \tparam T the type of the hole. */ template class PConst : public Pattern> { public: PConst(T value) // NOLINT(*) : value_(value) {} void InitMatch_() const {} bool Match_(const T& value) const { return PEqualChecker()(value_, value); } T Eval() const { return value_; } private: const T value_; }; /*! * \brief Pattern binary expression. * \tparam OpType The AST noderef type. * \tparam TA The pattern type of the first operand. * \tparam TB The pattern type of the second operand. */ template class PBinaryExpr : public Pattern> { public: PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {} void InitMatch_() const { a_.InitMatch_(); b_.InitMatch_(); } bool Match_(const ObjectRef& node) const { using NodeType = typename OpType::ContainerType; if (const NodeType* ptr = node.as()) { if (!a_.Match_(ptr->a)) return false; if (!b_.Match_(ptr->b)) return false; return true; } else { return false; } } PrimExpr Eval() const { PrimExpr lhs = a_.Eval(); PrimExpr rhs = b_.Eval(); PrimExpr ret = TryConstFold(lhs, rhs); if (ret.defined()) return ret; return OpType(lhs, rhs); } private: typename TA::Nested a_; typename TB::Nested b_; }; template class PConstWithTypeLike : public Pattern> { public: PConstWithTypeLike(const TA& ref, int64_t value) : ref_(ref), value_(value) {} void InitMatch_() const {} bool Match_(const ObjectRef& node) const { if (const tir::IntImmNode* ptr = node.as()) { return ptr->value == value_; } else { return false; } } PrimExpr Eval() const { return tir::make_const(ref_.Eval().dtype(), value_); } private: typename TA::Nested ref_; int64_t value_; }; #define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \ template \ inline PBinaryExpr FuncName(const Pattern& a, const Pattern& b) { \ CheckStep; \ return PBinaryExpr(a.derived(), b.derived()); \ } \ template \ inline PBinaryExpr> FuncName(const Pattern& a, \ int64_t b) { \ CheckStep; \ return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ } \ template \ inline PBinaryExpr, TA> FuncName(int64_t b, \ const Pattern& a) { \ CheckStep; \ return FuncName(PConstWithTypeLike(a.derived(), b), a); \ } #define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) // raise ambiguity error for operator overload of / and % TVM_PATTERN_BINARY_OP_EX(operator/, tir::Div, DivAmbiguityError(a)); TVM_PATTERN_BINARY_OP_EX(operator%, tir::Mod, DivAmbiguityError(a)); // arithmetic expressions TVM_PATTERN_BINARY_OP(operator+, tir::Add); TVM_PATTERN_BINARY_OP(operator-, tir::Sub); TVM_PATTERN_BINARY_OP(operator*, tir::Mul); TVM_PATTERN_BINARY_OP(min, tir::Min); TVM_PATTERN_BINARY_OP(max, tir::Max); TVM_PATTERN_BINARY_OP(div, tir::Div); TVM_PATTERN_BINARY_OP(truncdiv, tir::Div); TVM_PATTERN_BINARY_OP(truncmod, tir::Mod); TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDiv); TVM_PATTERN_BINARY_OP(floormod, tir::FloorMod); // logical expressions TVM_PATTERN_BINARY_OP(operator>, tir::GT); TVM_PATTERN_BINARY_OP(operator>=, tir::GE); TVM_PATTERN_BINARY_OP(operator<, tir::LT); TVM_PATTERN_BINARY_OP(operator<=, tir::LE); TVM_PATTERN_BINARY_OP(operator==, tir::EQ); TVM_PATTERN_BINARY_OP(operator!=, tir::NE); TVM_PATTERN_BINARY_OP(operator&&, tir::And); TVM_PATTERN_BINARY_OP(operator||, tir::Or); /*! * \brief Pattern not expression. * \tparam TA The pattern type of the true operand. */ template class PNotExpr : public Pattern> { public: explicit PNotExpr(const TA& value) : value_(value) {} void InitMatch_() const { value_.InitMatch_(); } bool Match_(const ObjectRef& node) const { if (const tir::NotNode* ptr = node.as()) { if (!value_.Match_(ptr->a)) return false; return true; } else { return false; } } PrimExpr Eval() const { return tir::Not(value_.Eval()); } private: typename TA::Nested value_; }; template inline PNotExpr operator!(const Pattern& value) { return PNotExpr(value.derived()); } // select /*! * \brief Pattern select expression. * \tparam TCond The pattern type of the condition. * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ template class PSelectExpr : public Pattern> { public: PSelectExpr(const TCond& condition, const TA& true_value, const TB& false_value) : condition_(condition), true_value_(true_value), false_value_(false_value) {} void InitMatch_() const { condition_.InitMatch_(); true_value_.InitMatch_(); false_value_.InitMatch_(); } bool Match_(const ObjectRef& node) const { if (const tir::SelectNode* ptr = node.as()) { if (!condition_.Match_(ptr->condition)) return false; if (!true_value_.Match_(ptr->true_value)) return false; if (!false_value_.Match_(ptr->false_value)) return false; return true; } else { return false; } } PrimExpr Eval() const { return tir::Select(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } private: typename TCond::Nested condition_; typename TA::Nested true_value_; typename TB::Nested false_value_; }; /*! * \brief Construct a select pattern. * * \param condition The condition expression. * \param true_value The value when condition is true. * \param true_value The value when condition is false. * * \return The result pattern. * * \tparam TCond The pattern type of the condition. * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ template inline PSelectExpr select(const Pattern& condition, const Pattern& true_value, const Pattern& false_value) { return PSelectExpr(condition.derived(), true_value.derived(), false_value.derived()); } /*! * \brief Pattern cast expression. * \tparam DType The Pattern type of dtype. * \tparam TA The pattern type of the first operand. */ template class PCastExpr : public Pattern> { public: PCastExpr(const DType& dtype, const TA& value) : dtype_(dtype), value_(value) {} void InitMatch_() const { dtype_.InitMatch_(); value_.InitMatch_(); } bool Match_(const ObjectRef& node) const { if (const tir::CastNode* ptr = node.as()) { if (!dtype_.Match_(ptr->dtype)) return false; if (!value_.Match_(ptr->value)) return false; return true; } else { return false; } } PrimExpr Eval() const { return tir::Cast(dtype_.Eval(), value_.Eval()); } private: typename DType::Nested dtype_; typename TA::Nested value_; }; /*! * \brief Construct a cast pattern. * * \param dtype The target data type, can be PVar or PConst. * \param value The input type. * * \return The result pattern. * * \tparam DType The pattern type of type. * \tparam TA The pattern type of value. */ template inline PCastExpr cast(const Pattern& dtype, const Pattern& value) { return PCastExpr(dtype.derived(), value.derived()); } /*! * \brief Pattern ramp expression. * \tparam TBase The pattern type of the base. * \tparam TStride The pattern type of the stride. * \tparam TLanes The pattern type of the lanes. */ template class PRampExpr : public Pattern> { public: PRampExpr(const TBase& base, const TStride& stride, const TLanes& lanes) : base_(base), stride_(stride), lanes_(lanes) {} void InitMatch_() const { base_.InitMatch_(); stride_.InitMatch_(); lanes_.InitMatch_(); } bool Match_(const ObjectRef& node) const { if (const tir::RampNode* ptr = node.as()) { if (!base_.Match_(ptr->base)) return false; if (!stride_.Match_(ptr->stride)) return false; if (!lanes_.Match_(ptr->lanes)) return false; return true; } else { return false; } } PrimExpr Eval() const { return tir::Ramp(base_.Eval(), stride_.Eval(), lanes_.Eval()); } private: typename TBase::Nested base_; typename TStride::Nested stride_; typename TLanes::Nested lanes_; }; /*! * \brief Construct a ramp pattern. * * \param base The base pattern. * \param stride The stride pattern. * \param lanes The lanes pattern. * * \return The result pattern. * * \tparam TBase The pattern type of the base. * \tparam TStride The pattern type of the stride. * \tparam TLanes The pattern type of the lanes. */ template inline PRampExpr ramp(const Pattern& base, const Pattern& stride, const Pattern& lanes) { return PRampExpr(base.derived(), stride.derived(), lanes.derived()); } template inline PRampExpr, PConst> ramp(const Pattern& base, int stride, int lanes) { return PRampExpr, PConst>( base.derived(), PConstWithTypeLike(base.derived(), stride), PConst(lanes)); } /*! * \brief Pattern broadcast expression. * \tparam TA The pattern type of the value. * \tparam TLanes The pattern type of the lanes. */ template class PBroadcastExpr : public Pattern> { public: PBroadcastExpr(const TA& value, const TLanes& lanes) : value_(value), lanes_(lanes) {} void InitMatch_() const { value_.InitMatch_(); lanes_.InitMatch_(); } bool Match_(const ObjectRef& node) const { if (const tir::BroadcastNode* ptr = node.as()) { if (!value_.Match_(ptr->value)) return false; if (!lanes_.Match_(ptr->lanes)) return false; return true; } else { return false; } } PrimExpr Eval() const { return tir::Broadcast(value_.Eval(), lanes_.Eval()); } private: typename TA::Nested value_; typename TLanes::Nested lanes_; }; /*! * \brief Construct a broadcast pattern. * * \param value The value pattern. * \param lanes The lanes pattern. * * \return The result pattern. * * \tparam TA The pattern type of the value. * \tparam TLanes The pattern type of the lanes. */ template inline PBroadcastExpr broadcast(const Pattern& value, const Pattern& lanes) { return PBroadcastExpr(value.derived(), lanes.derived()); } // internal namespace namespace detail { // implementation details for CallExpr template struct tuple_for_each_dispatcher { template static void run(F& f, const TTuple& tuple) { // NOLINT(*) f(I, std::get(tuple)); tuple_for_each_dispatcher<(I + 1) == std::tuple_size::value, (I + 1), F>::run(f, tuple); } }; template struct tuple_for_each_dispatcher { template static void run(F& f, const TTuple& tuple) {} // NOLINT(*) }; template inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*) tuple_for_each_dispatcher::value == 0, 0, F>::run(f, tuple); } struct PCallExprInitMatchFunctor { template void operator()(size_t i, const T& pattern) const { pattern.InitMatch_(); } }; struct PCallExprMatchFunctor { const tir::CallNode* call_; bool matched_{true}; explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {} template void operator()(size_t i, const T& pattern) { matched_ = matched_ && pattern.Match_(call_->args[i]); } }; struct PCallExprEvalArgsFunctor { Array args_; template void operator()(size_t i, const T& pattern) { args_.push_back(pattern.Eval()); } }; } // namespace detail /*! * \brief Pattern CallExpr expression. * \tparam Op The operator functor class. * \tparam TArgs The arguments. * \note Op functor contains the name of the function and * the implementation of Eval. */ template class PCallExpr : public Pattern> { public: explicit PCallExpr(const TArgs&... args) : args_(args...) {} void InitMatch_() const { detail::PCallExprInitMatchFunctor finit; detail::tuple_for_each(finit, args_); } bool Match_(const ObjectRef& node) const { if (const tir::CallNode* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; if (!ptr->op.same_as(Op::GetOp())) return false; detail::PCallExprMatchFunctor fmatch(ptr); detail::tuple_for_each(fmatch, args_); return fmatch.matched_; } else { return false; } } PrimExpr Eval() const { detail::PCallExprEvalArgsFunctor feval_args; detail::tuple_for_each(feval_args, args_); return Op::Eval(feval_args.args_); } private: std::tuple args_; }; // arithemetic intrinsics #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(Array args) { \ return tir::Call(args[0].dtype(), GetOp(), args); \ } \ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ }; \ template \ inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ return PCallExpr(a.derived(), b.derived()); \ } TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, shift_left); TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, shift_right); TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, bitwise_and); TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); // unary intrinsics #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(Array args) { \ return tir::Call(args[0].dtype(), GetOp(), args); \ } \ static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ }; \ template \ inline PCallExpr FuncName(const Pattern& a) { \ return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(Array args) { return tir::Call(args[1].dtype(), GetOp(), args); } static const Op& GetOp() { return tir::builtin::if_then_else(); } }; /*! * \brief Construct a if_then_else pattern. * * \param cond The condition expression. * \param true_value The value when condition is true. * \param true_value The value when condition is false. * * \return The result pattern. * * \tparam TCond The pattern type of the condition. * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ template inline PCallExpr if_then_else(const Pattern& cond, const Pattern& true_value, const Pattern& false_value) { return PCallExpr(cond.derived(), true_value.derived(), false_value.derived()); } } // namespace arith } // namespace tvm #endif // TVM_ARITH_PATTERN_MATCH_H_