/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file expr_functor.cc */ #include #include "functor_common.h" namespace tvm { namespace tir { void ExprVisitor::VisitExpr_(const VarNode* op) {} void ExprVisitor::VisitExpr_(const SizeVarNode* op) { this->VisitExpr_(static_cast(op)); } void ExprVisitor::VisitExpr_(const AnyNode* op) {} void ExprVisitor::VisitExpr_(const LoadNode* op) { this->VisitExpr(op->index); this->VisitExpr(op->predicate); } void ExprVisitor::VisitExpr_(const BufferLoadNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } void ExprVisitor::VisitExpr_(const LetNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->body); } void ExprVisitor::VisitExpr_(const CallNode* op) { VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); } #define DEFINE_BINOP_VISIT_(OP) \ void ExprVisitor::VisitExpr_(const OP* op) { \ this->VisitExpr(op->a); \ this->VisitExpr(op->b); \ } DEFINE_BINOP_VISIT_(AddNode); DEFINE_BINOP_VISIT_(SubNode); DEFINE_BINOP_VISIT_(MulNode); DEFINE_BINOP_VISIT_(DivNode); DEFINE_BINOP_VISIT_(ModNode); DEFINE_BINOP_VISIT_(FloorDivNode); DEFINE_BINOP_VISIT_(FloorModNode); DEFINE_BINOP_VISIT_(MinNode); DEFINE_BINOP_VISIT_(MaxNode); DEFINE_BINOP_VISIT_(EQNode); DEFINE_BINOP_VISIT_(NENode); DEFINE_BINOP_VISIT_(LTNode); DEFINE_BINOP_VISIT_(LENode); DEFINE_BINOP_VISIT_(GTNode); DEFINE_BINOP_VISIT_(GENode); DEFINE_BINOP_VISIT_(AndNode); DEFINE_BINOP_VISIT_(OrNode); void ExprVisitor::VisitExpr_(const IntImmNode* op) {} void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} void ExprVisitor::VisitExpr_(const StringImmNode* op) {} void ExprVisitor::VisitExpr_(const ReduceNode* op) { VisitArray(op->axis, [this](const IterVar& r) { this->VisitExpr(r->dom->min); this->VisitExpr(r->dom->extent); }); VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); }); if (!op->init.empty()) { VisitArray(op->init, [this](const PrimExpr& e) { this->VisitExpr(e); }); } this->VisitExpr(op->condition); } void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); } void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); } void ExprVisitor::VisitExpr_(const SelectNode* op) { this->VisitExpr(op->condition); this->VisitExpr(op->true_value); this->VisitExpr(op->false_value); } void ExprVisitor::VisitExpr_(const RampNode* op) { this->VisitExpr(op->base); this->VisitExpr(op->stride); } void ExprVisitor::VisitExpr_(const ShuffleNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); }); } void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); } PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { return this->VisitExpr_(static_cast(op)); } PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return GetRef(op); } PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { PrimExpr index = this->VisitExpr(op->index); PrimExpr predicate = this->VisitExpr(op->predicate); if (index.same_as(op->index) && predicate.same_as(op->predicate)) { return GetRef(op); } else { return Load(op->dtype, op->buffer_var, index, predicate); } } PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; Array indices = MutateArray(op->indices, fmutate); if (indices.same_as(op->indices)) { return GetRef(op); } else { return BufferLoad(op->buffer, indices); } } PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; Array indices = MutateArray(op->indices, fmutate); if (indices.same_as(op->indices)) { return GetRef(op); } else { return ProducerLoad(op->producer, indices); } } PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return Let(op->var, value, body); } } PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; Array args = MutateArray(op->args, fmutate); if (args.same_as(op->args)) { return GetRef(op); } else { return Call(op->dtype, op->op, args); } } #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef(op); } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) #define DEFINE_BIOP_EXPR_MUTATE_(OP) \ PrimExpr ExprMutator::VisitExpr_(const OP##Node* op) { \ PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ if (a.same_as(op->a) && b.same_as(op->b)) { \ return GetRef(op); \ } else { \ return OP(a, b); \ } \ } DEFINE_BIOP_EXPR_MUTATE_(Add); DEFINE_BIOP_EXPR_MUTATE_(Sub); DEFINE_BIOP_EXPR_MUTATE_(Mul); DEFINE_BIOP_EXPR_MUTATE_(Div); DEFINE_BIOP_EXPR_MUTATE_(Mod); DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); DEFINE_BIOP_EXPR_MUTATE_(FloorMod); DEFINE_BIOP_EXPR_MUTATE_(Min); DEFINE_BIOP_EXPR_MUTATE_(Max); DEFINE_BIOP_EXPR_MUTATE_(EQ); DEFINE_BIOP_EXPR_MUTATE_(NE); DEFINE_BIOP_EXPR_MUTATE_(LT); DEFINE_BIOP_EXPR_MUTATE_(LE); DEFINE_BIOP_EXPR_MUTATE_(GT); DEFINE_BIOP_EXPR_MUTATE_(GE); DEFINE_BIOP_EXPR_MUTATE_(And); DEFINE_BIOP_EXPR_MUTATE_(Or); PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { auto fitervar = [this](const IterVar& v) { Range r = v->dom; PrimExpr min = this->VisitExpr(r->min); PrimExpr extent = this->VisitExpr(r->extent); if (min.same_as(r->min) && extent.same_as(r->extent)) { return v; } else { return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag); } }; Array axis = MutateArray(op->axis, fitervar); auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; Array source = MutateArray(op->source, fexpr); Array init = MutateArray(op->init, fexpr); PrimExpr condition = this->VisitExpr(op->condition); if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition) && init.same_as(op->init)) { return GetRef(op); } else { return Reduce(op->combiner, source, axis, condition, op->value_index, init); } } PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { return Cast(op->dtype, value); } } PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { PrimExpr a = this->VisitExpr(op->a); if (a.same_as(op->a)) { return GetRef(op); } else { return Not(a); } } PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr true_value = this->VisitExpr(op->true_value); PrimExpr false_value = this->VisitExpr(op->false_value); if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { return GetRef(op); } else { return Select(condition, true_value, false_value); } } PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); if (base.same_as(op->base) && stride.same_as(op->stride)) { return GetRef(op); } else { return Ramp(base, stride, op->lanes); } } PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { return Broadcast(value, op->lanes); } } PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; auto vectors = MutateArray(op->vectors, fexpr); if (vectors.same_as(op->vectors)) { return GetRef(op); } else { return Shuffle(vectors, op->indices); } } } // namespace tir } // namespace tvm