/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "../src/arith/pattern_match.h"

// NOTE(review): the two include targets below were missing from the file as
// received; gtest supplies TEST and tir/analysis.h is presumed to supply
// tir::ExprDeepEqual -- confirm against the build.
#include <gtest/gtest.h>
#include <tvm/tir/analysis.h>

/*!
 * \brief Exercise the pattern matcher on arithmetic, comparison, logical,
 *        bitwise, select, if_then_else, cast, ramp and broadcast expressions.
 *
 * Each ICHECK asserts either that a pattern matches a concrete expression or
 * that it is rejected; several blocks additionally inspect the value bound
 * to a pattern variable via Eval() after a successful match.
 */
TEST(Pattern, Basic) {
  using namespace tvm;
  using namespace tvm::tir;
  using namespace tvm::arith;
  tvm::tir::Var x("x"), y("y"), z("z");
  // Pattern variables: expressions, a data type, and an integer lane count.
  arith::PVar<PrimExpr> px, py, pz;
  arith::PVar<DataType> pt;
  arith::PVar<int> planes;

  // arithmetics
  auto r = 1 + (y + 1);
  // A pattern variable that occurs more than once has to bind to the same
  // expression at every occurrence, so these two patterns are rejected.
  ICHECK(!(px + (px + px)).Match(r));
  ICHECK(!(px + (py + py)).Match(r));
  ICHECK((px + (py + pz)).Match(r));
  auto pattern = px + (py + pz);
  ICHECK(pattern.Match(r));
  {
    // px binds to the constant 1 on both sides, py binds to y.
    ICHECK((px + (py + px)).Match(r));
    auto rr = (px + py).Eval();

    ICHECK(tir::ExprDeepEqual()(rr, 1 + y));
    ICHECK(tir::ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y));
  }
  {
    ICHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1))));
    ICHECK(tir::ExprDeepEqual()(px.Eval(), x + 1));
  }
  // min pattern must not match a max expression.
  ICHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1))));

  ICHECK((px + min(py, px)).Match(z + min(y, z)));
  ICHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2)));
  ICHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2)));
  ICHECK((px - floormod(py, px * PConst<PrimExpr>(2))).Match(x - floormod(2, x * 2)));

  // logicals
  ICHECK((px == pz).Match(x == 1));
  ICHECK((px != pz).Match(x != 1));
  ICHECK((px > py).Match(x > y));
  ICHECK((px < py).Match(x < y));
  ICHECK((px <= py).Match(x <= y));
  ICHECK((px >= py).Match(x >= y));
  ICHECK((px >= py && px < pz).Match(x >= y && x < z));
  ICHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
  {
    ICHECK(select(px >= pz, py, py + pz).Match(tir::Select((x + 1) >= 1, y, y + 1)));
    ICHECK(tir::ExprDeepEqual()(px.Eval(), x + 1));
  }

  // bit intrinsics
  {
    ICHECK((px >> pz).Match(x >> 1));
    ICHECK(is_const_int(pz.Eval(), 1));
  }
  // A right-shift pattern must not match a left-shift expression.
  ICHECK(!(px >> pz).Match(x << 1));
  ICHECK((px << pz).Match(x << 1));
  ICHECK((px & pz).Match(x & 1));
  ICHECK((px | pz).Match(x | 1));
  ICHECK((px ^ pz).Match(x ^ 1));
  ICHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2)))));

  // select
  {
    ICHECK(select(px > pz, py, py + pz).Match(tir::Select(x > 1, y, y + 1)));
    ICHECK(is_const_int(pz.Eval(), 1));
  }
  // pz would have to be both 2 (condition) and 1 (false branch).
  ICHECK(!select(px > pz, py, py + pz).Match(tir::Select(x > 2, y, y + 1)));
  // py would have to be both y and y + 1.
  ICHECK(!select(px > pz, py, py).Match(tir::Select(x > 2, y, y + 1)));
  {
    ICHECK(select(px, py, pz).Match(tir::Select(x > 2, y, y + 1)));
    ICHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1));
  }

  // if_then_else
  {
    ICHECK(if_then_else(px > pz, py, py + pz).Match(if_then_else(x > 1, y, y + 1)));
    ICHECK(is_const_int(pz.Eval(), 1));
  }

  // cast pattern
  {
    // A constant Int(32) target type does not match a cast to Float(64).
    ICHECK(!cast(PConst<DataType>(DataType::Int(32)), px).Match(tir::Cast(DataType::Float(64), x)));
    ICHECK(cast(pt, px).Match(tir::Cast(DataType::Float(64), x)));
    ICHECK(pt.Eval() == DataType::Float(64));
    auto zz = cast(pt, px).Eval();
    ICHECK((cast(pt, px) - cast(pt, py))
               .Match(tir::Cast(DataType::Float(64), x) - tir::Cast(DataType::Int(64), x)));
    auto expr = tir::Cast(DataType::Int(32), tir::Cast(DataType::Float(64), x));
    // pt cannot bind to Int(32) and Float(64) at the same time.
    ICHECK(!(cast(pt, cast(pt, px))).Match(expr));
  }

  // ramp pattern
  {
    ICHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 1, 10)));
    ICHECK(planes.Eval() == 10);
    // Stride 2 does not match the constant stride 1 in the pattern.
    ICHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 2, 10)));
  }

  // broadcast pattern
  {
    ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10)));
    ICHECK(planes.Eval() == 10);
    ICHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10)));
  }
}
/*!
 * \brief Check matching with pattern variables typed as IntImm and tir::Var
 *        rather than a general PrimExpr.
 */
TEST(Pattern, IntImm) {
  using namespace tvm;
  tir::Var tx, ty;
  arith::PVar<IntImm> c;
  arith::PVar<tir::Var> v;
  {
    // We can match integer and Var, both of which are
    // special case container of Expr
    ICHECK((v * c).Match(tx * 3));
    ICHECK_EQ(c.Eval()->value, 3);
    ICHECK((v * 3).Match(tx * 3));
  }
  // cannot match c to ty
  ICHECK(!(v * c).Match(tx * ty));
  // cannot match tx + 1 to v
  ICHECK(!(v * c).Match((tx + 1) * 3));
}

/*!
 * \brief Check pattern variables that only match expressions of a given
 *        data type, for both scalar and vectorized expressions.
 */
TEST(Pattern, MatchWithType) {
  using namespace tvm;
  // match expr with specified dtype
  arith::PVarWithDataType<PrimExpr, arith::PConst<DataType>> pat(DataType::Float(32));
  tir::Var x("x", DataType::Float(32));
  tir::Var y("y", DataType::Float(32));
  tir::Var x_int("x", DataType::Int(32));
  tir::Var y_int("y", DataType::Int(32));
  ICHECK(pat.Match(x + y * 2.0f));
  // Same expression shape, but Int(32) instead of Float(32).
  ICHECK(!pat.Match(x_int + y_int * 2));

  // match vectorized expr with specified element dtype
  arith::PVecDataType vec_ty(DataType::Float(32));
  arith::PVarWithDataType<PrimExpr, arith::PVecDataType> vpat(vec_ty);
  tir::Var vx = tir::Var("x", DataType::Float(32, 8));
  tir::Var vy("y", DataType::Float(32, 8));
  tir::Var vx_int("x", DataType::Int(32, 8));
  tir::Var vy_int("y", DataType::Int(32, 8));
  ICHECK(vpat.Match(vx + vy * tir::Broadcast(2.0f, 8)));
  // Eight-lane Int(32) vectors do not satisfy the Float(32) element type.
  ICHECK(!vpat.Match(vx_int + vy_int * tir::Broadcast(2, 8)));
}