# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the TIR HoistIfThenElse pass."""
import tvm
from tvm import te
from tvm import relay
import numpy as np
import pytest
from tvm.testing import enabled_targets

# Pass configuration used by every test that enables block-scope hoisting.
# NOTE(review): the key is spelled "hosting" everywhere in this file; it is
# presumably the pass's registered option name, so it is reproduced verbatim.
_BLOCK_SCOPE_CONFIG = {"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}


def _hoist(func, block_scope=False):
    """Run HoistIfThenElse on ``func`` and return the transformed body.

    Parameters
    ----------
    func : tvm.tir.PrimFunc
        Function to transform.  It is wrapped in a fresh IRModule on each call.

    block_scope : bool
        When True the pass runs inside a PassContext carrying
        ``_BLOCK_SCOPE_CONFIG``; otherwise it runs in the ambient context.

    Returns
    -------
    body : tvm.tir.Stmt
        Body of the module's "main" function after the pass.
    """
    mod = tvm.IRModule.from_expr(func)
    if not block_scope:
        return tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
    with tvm.transform.PassContext(config=_BLOCK_SCOPE_CONFIG):
        return tvm.tir.transform.HoistIfThenElse()(mod)["main"].body


def _hoist_stmt(stmt, block_scope=False):
    """Wrap ``stmt`` in a parameterless PrimFunc and run :func:`_hoist` on it."""
    return _hoist(tvm.tir.PrimFunc([], stmt), block_scope)


def verify_structure(stmt, expected_struct):
    """Check the nesting of For / IfThenElse / AttrStmt nodes in ``stmt``.

    Every such node is reduced to a label:

    * ``("tir.For", loop_var_name)``
    * ``("tir.IfThenElse", names)`` where ``names`` is the tuple of Var names
      met while post-order visiting the condition
    * ``("tir.AttrStmt", attr_key, int(value))``

    and mapped to a tuple with the label of each direct child (the loop or
    attribute body, or the then/else cases).  A child that is missing or is
    not one of the three node kinds is recorded as ``None``.  Nodes producing
    the same label share one entry; the node visited last wins.

    Parameters
    ----------
    stmt : tvm.tir.Stmt
        Statement to inspect.

    expected_struct : dict
        Expected mapping from label to tuple of child labels.

    Raises
    ------
    AssertionError
        If the extracted mapping differs from ``expected_struct``.
    """
    node_dict = {}

    def _condition_var_names(cond):
        # A list local to each condition replaces the former module-level
        # accumulator, so no state is shared between nodes or between calls.
        names = []

        def _collect(op):
            if isinstance(op, tvm.tir.Var):
                names.append(op.name)

        tvm.tir.stmt_functor.post_order_visit(cond, _collect)
        return tuple(names)

    def _visit(op):
        if isinstance(op, tvm.tir.IfThenElse):
            children = (op.then_case, op.else_case)
            label = ("tir.IfThenElse", _condition_var_names(op.condition))
        elif isinstance(op, tvm.tir.For):
            children = (op.body,)
            label = ("tir.For", op.loop_var.name)
        elif isinstance(op, tvm.tir.AttrStmt):
            children = (op.body,)
            label = ("tir.AttrStmt", op.attr_key, int(op.value))
        else:
            return
        node_dict[op] = (children, label)

    tvm.tir.stmt_functor.post_order_visit(stmt, _visit)

    struct = {}
    for children, label in node_dict.values():
        struct[label] = tuple(
            node_dict[child][1] if child in node_dict else None for child in children
        )

    assert struct == expected_struct, "Structure mismatch: expect %s but got %s" % (
        expected_struct,
        struct,
    )


def test_hoist_top_for():
    """A condition on ``i`` inside i/j/k loops ends up directly under the ``i`` loop."""
    ib = tvm.tir.ir_builder.create()
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope(ib.likely(i < 2)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))

    new_stmt = _hoist_stmt(ib.get())
    expected_struct = {
        ("tir.For", "k"): (None,),
        ("tir.For", "j"): (("tir.For", "k"),),
        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), ("tir.For", "j")),
        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
    }
    verify_structure(new_stmt, expected_struct)


def test_hoist_multi_var_if():
    """A condition on ``i + j`` ends up under the ``j`` loop, wrapping the ``k`` loop."""
    ib = tvm.tir.ir_builder.create()
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope(ib.likely(i + j < 2)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))

    new_stmt = _hoist_stmt(ib.get())
    expected_struct = {
        ("tir.For", "k"): (None,),
        ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")),
        ("tir.For", "j"): (("tir.IfThenElse", ("i", "j")),),
        ("tir.For", "i"): (("tir.For", "j"),),
    }
    verify_structure(new_stmt, expected_struct)


def test_hoist_no_match_for():
    """With a store next to the ``k`` loop, the ``i`` condition only moves above ``k``."""
    ib = tvm.tir.ir_builder.create()
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")
    data = ib.pointer("float32", name="data")

    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            data[i * 3 + j] = data[i * 3 + j] + 0.5
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope(ib.likely(i < 2)):
                    ib.emit(tvm.tir.Evaluate(m))
                with ib.else_scope():
                    ib.emit(tvm.tir.Evaluate(n))

    new_stmt = _hoist_stmt(ib.get())
    expected_struct = {
        ("tir.For", "k"): (None,),
        ("tir.IfThenElse", ("i",)): (("tir.For", "k"), ("tir.For", "k")),
        ("tir.For", "j"): (None,),
        ("tir.For", "i"): (("tir.For", "j"),),
    }
    verify_structure(new_stmt, expected_struct)


def test_no_else():
    """An ``if`` without an else branch is hoisted and keeps an empty else case."""
    ib = tvm.tir.ir_builder.create()
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope(ib.likely(i < 2)):
                    ib.emit(tvm.tir.Evaluate(m))

    new_stmt = _hoist_stmt(ib.get())
    expected_struct = {
        ("tir.For", "k"): (None,),
        ("tir.For", "j"): (("tir.For", "k"),),
        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None),
        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
    }
    verify_structure(new_stmt, expected_struct)


def test_attr_stmt():
    """Hoisting inside loops wrapped by two ``thread_extent`` attribute statements."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0

    new_stmt = _hoist_stmt(ib.get())
    expected_struct = {
        ("tir.For", "k"): (None,),
        ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")),
        ("tir.For", "j"): (("tir.IfThenElse", ("i", "j")),),
        ("tir.For", "i"): (("tir.For", "j"),),
        ("tir.AttrStmt", "thread_extent", 64): (("tir.For", "i"),),
        ("tir.AttrStmt", "thread_extent", 32): (("tir.AttrStmt", "thread_extent", 64),),
    }
    verify_structure(new_stmt, expected_struct)


def test_nested_for():
    """Two conditions at different loop depths are each hoisted to their own level."""
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")

    with ib.for_range(0, 5, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.if_scope(i >= 3):
                data[i * 3 + j] = data[i * 3 + j] + 0.5
                with ib.for_range(0, 15, "k") as k:
                    with ib.for_range(0, 20, "l") as l:
                        with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
                        with ib.else_scope():
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5

    new_stmt = _hoist_stmt(ib.get())
    expected_struct = {
        ("tir.For", "l"): (None,),
        ("tir.For", "k"): (("tir.For", "l"),),
        ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")),
        ("tir.For", "j"): (None,),
        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None),
        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
    }
    verify_structure(new_stmt, expected_struct)


def test_if_block():
    """Two sibling loop nests: several ifs in one loop body, and a loop-invariant if.

    NOTE(review): the expected labels contain a ``j`` loop whose body is another
    ``j`` loop even though the second nest is requested as i/j/k; presumably the
    IR builder renames the repeated "i" loop -- confirm against
    ``ir_builder.for_range``.
    """
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")
    n = te.var("n")

    with ib.for_range(0, 5, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.if_scope(i >= 3):
                data[i * 3 + j] = data[i * 3 + j] + 0.5
                with ib.for_range(0, 15, "k") as k:
                    with ib.for_range(0, 20, "l") as l:
                        with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
                        with ib.else_scope():
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5
                        with ib.if_scope(j < 5):
                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1

    with ib.for_range(0, 5, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.for_range(0, 15, "k") as k:
                with ib.if_scope(n >= 3):
                    data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6

    new_stmt = _hoist_stmt(ib.get())
    expected_struct = {
        ("tir.IfThenElse", ("i", "j")): (None, None),
        ("tir.IfThenElse", ("j",)): (None, None),
        ("tir.For", "l"): (None,),
        ("tir.For", "k"): (None,),
        ("tir.For", "j"): (("tir.For", "j"),),
        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None),
        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
        ("tir.IfThenElse", ("n",)): (("tir.For", "j"), None),
    }
    verify_structure(new_stmt, expected_struct)


def test_multi_if():
    """Directly nested ifs on ``i`` and ``j`` are hoisted under their respective loops."""
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")

    with ib.for_range(0, 10, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.for_range(0, 10, "k") as k:
                with ib.if_scope(i >= 3):
                    with ib.if_scope(j >= 3):
                        data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5

    new_stmt = _hoist_stmt(ib.get())
    expected_struct = {
        ("tir.For", "k"): (None,),
        ("tir.IfThenElse", ("j",)): (("tir.For", "k"), None),
        ("tir.For", "j"): (("tir.IfThenElse", ("j",)),),
        ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None),
        ("tir.For", "i"): (("tir.IfThenElse", ("i",)),),
    }
    verify_structure(new_stmt, expected_struct)


def test_no_hoisting_1():
    """A condition on the innermost loop var ``k`` is left in place in both modes."""
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")

    with ib.for_range(0, 10, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.for_range(0, 10, "k") as k:
                with ib.if_scope(k >= 3):
                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_no_hoisting_2():
    """An if followed by another statement in the loop body is left in place."""
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data")

    with ib.for_range(0, 10, "i") as i:
        with ib.for_range(0, 10, "j") as j:
            with ib.for_range(0, 10, "k") as k:
                with ib.if_scope(i >= 3):
                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3
                data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_no_hoisting_3():
    """``tx``/``bx`` extents re-declared inside the ``k`` loop: nothing is hoisted."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
                ib.scope_attr(bx, "thread_extent", dshape_inner[1])
                with ib.if_scope(tx < 3):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_no_hoisting_4():
    """``tx`` extent declared only inside the ``k`` loop: nothing is hoisted."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
                with ib.if_scope(tx < 3):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_no_hoisting_5():
    """``bx`` re-declared under ``j`` and ``tx`` under ``k``: nothing is hoisted."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            ib.scope_attr(bx, "thread_extent", dshape_inner[1])
            with ib.for_range(0, n, "k") as k:
                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
                with ib.if_scope(tx < 3):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_no_hoisting_6():
    """A condition on ``tx + k`` (innermost loop var) is left in place in both modes."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope((tx + k) < 3):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_no_hoisting_7():
    """Ifs on ``tx + j`` and ``tx + k`` sit directly under ``j`` and ``k``: unchanged."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.if_scope((tx + j) < 9):
                with ib.for_range(0, n, "k") as k:
                    with ib.if_scope((tx + k) < 3):
                        data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_hoisting_block_scope_1():
    """rfactor reduction bound to threads: changed only with block-scope hoisting."""
    n = te.size_var("n")
    m = te.size_var("m")
    A = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
    s = te.create_schedule(B.op)
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
    BF = s.rfactor(B, ki)
    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
    s[B.op].bind(xo, te.thread_axis("blockIdx.x"))
    s[B.op].bind(xi, te.thread_axis("threadIdx.y"))
    s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
    func = tvm.driver.build_module.schedule_to_module(s, [A, B], "main", None)["main"]
    stmt = func.body

    tvm.ir.assert_structural_equal(_hoist(func), stmt)
    assert not tvm.ir.structural_equal(_hoist(func, block_scope=True), stmt)


def test_hoisting_block_scope_2():
    """``bx`` extent declared inside the ``k`` loop: changed only with block scope."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                ib.scope_attr(bx, "thread_extent", dshape[1])
                with ib.if_scope(tx < 3):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    assert not tvm.ir.structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_hoisting_block_scope_3():
    """``tx``/``bx`` re-declared under the ``j`` loop: changed only with block scope."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    dshape_inner = (33, 63)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            ib.scope_attr(tx, "thread_extent", dshape_inner[0])
            ib.scope_attr(bx, "thread_extent", dshape_inner[1])
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope(tx < 3):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    assert not tvm.ir.structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_hoisting_block_scope_4():
    """Split/parallel/vectorized schedule: changed only with block-scope hoisting."""
    nn = 1024
    n = tvm.runtime.convert(nn)
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    AA = te.compute((n,), lambda *i: A(*i), name="A")
    BB = te.compute((n,), lambda *i: B(*i), name="B")
    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name="T")
    C = te.compute(A.shape, lambda *i: T(*i), name="C")
    s = te.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    xo1, xo2 = s[C].split(xo, factor=13)
    s[C].parallel(xo2)
    s[C].pragma(xo1, "parallel_launch_point")
    s[C].pragma(xo2, "parallel_stride_pattern")
    s[C].pragma(xo2, "parallel_barrier_when_finish")
    s[C].vectorize(xi)
    func = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None)["main"]
    stmt = func.body

    tvm.ir.assert_structural_equal(_hoist(func), stmt)
    assert not tvm.ir.structural_equal(_hoist(func, block_scope=True), stmt)


def test_hoisting_block_scope_5():
    """Condition reading ``data[g]``: the default pass changes the statement, and
    re-running on that result with block scope enabled changes nothing further."""
    ib = tvm.tir.ir_builder.create()
    data = ib.pointer("float32", name="data", scope="global")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")
    g = te.var("g")

    ib.scope_attr(data, "storage_scope", "global")
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope(data[g] < 3):
                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3
                with ib.else_scope():
                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3

    stmt = ib.get()
    new_stmt = _hoist_stmt(stmt)
    assert not tvm.ir.structural_equal(new_stmt, stmt)

    # Feed the already-transformed statement back through the pass.
    stmt = new_stmt
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_hoisting_block_scope_6():
    """A condition on ``tx + n`` is changed only with block-scope hoisting enabled."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope((tx + n) < 3):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    assert not tvm.ir.structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


def test_hoisting_block_scope_7():
    """A condition on ``tx + i`` is changed only with block-scope hoisting enabled."""
    ib = tvm.tir.ir_builder.create()
    dshape = (32, 64)
    data = ib.pointer("float32", name="data")
    l = te.var("l")
    m = te.var("m")
    n = te.var("n")

    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", dshape[0])
    ib.scope_attr(bx, "thread_extent", dshape[1])
    with ib.for_range(0, l, "i") as i:
        with ib.for_range(0, m, "j") as j:
            with ib.for_range(0, n, "k") as k:
                with ib.if_scope((tx + i) < 3):
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
                with ib.else_scope():
                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3

    stmt = ib.get()
    tvm.ir.assert_structural_equal(_hoist_stmt(stmt), stmt)
    assert not tvm.ir.structural_equal(_hoist_stmt(stmt, block_scope=True), stmt)


@pytest.mark.skip()
def test_hoisting_op_conv():
    """Build a conv2d with and without block-scope hoisting and compare mean run time.

    The test is unconditionally skipped by its decorator.

    NOTE(review): only the second run calls ``m.set_input(**params)``; ``params``
    is also handed to ``build`` in both runs.
    """
    dtype = "float32"
    dshape = (1, 80, 73, 73)
    kshape = (192, 80, 3, 3)
    padding = (1, 1)
    groups = 1
    dilation = (1, 1)
    kernel_size = (3, 3)
    channels = 192
    scale = 1
    x = relay.var("x", shape=dshape, dtype=dtype)
    w = relay.var("w", shape=kshape, dtype=dtype)
    y = relay.nn.conv2d(
        x,
        w,
        padding=padding,
        dilation=dilation,
        groups=groups,
        channels=channels,
        kernel_size=kernel_size,
    )

    func = relay.Function([x, w], y)
    mod = tvm.IRModule()
    mod["main"] = func
    mod = relay.transform.InferType()(mod)

    data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
    kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)

    params = {"w": tvm.nd.array(kernel)}
    for target, dev in enabled_targets():
        # Baseline: default pass configuration.
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build_module.build(mod, target=target, params=params)
            m = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
            data_tvm = tvm.nd.array(data)
            m.set_input("x", data_tvm)
            m.run()
            e = m.module.time_evaluator("run", dev, number=300, repeat=3)
            t1 = e(data_tvm).results
            t1 = np.array(t1) * 1000
            print("{} ms".format(t1.mean()))

        # Same build with block-scope hoisting switched on.
        with tvm.transform.PassContext(opt_level=3, config=_BLOCK_SCOPE_CONFIG):
            lib = relay.build_module.build(mod, target=target, params=params)
            m = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
            data_tvm = tvm.nd.array(data)
            m.set_input("x", data_tvm)
            m.set_input(**params)
            m.run()
            e = m.module.time_evaluator("run", dev, number=300, repeat=3)
            t2 = e(data_tvm).results
            t2 = np.array(t2) * 1000
            print("{} ms".format(t2.mean()))

        tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1)


if __name__ == "__main__":
    pytest.main([__file__])