# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
# mypy: ignore-errors
import sys

import pytest
import tvm
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV, Trace

# pylint: disable=no-member,invalid-name,unused-variable


# Workload under test: block "B" writes A * 2.0 into an intermediate buffer,
# block "C" then writes B + 1.0 into the output buffer.
@T.prim_func
def elementwise(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0


# Reference result the tests compare against after applying compute_inline
# to block "B" of `elementwise`: a single block "C" computing A * 2.0 + 1.0.
@T.prim_func
def elementwise_inlined(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0


# pylint: enable=no-member,invalid-name,unused-variable


def _make_get_block(name, output):
    """Create a "GetBlock" instruction with attrs ``[name, "main"]`` and `output` as its result."""
    kind = InstructionKind.get("GetBlock")
    return Instruction(kind=kind, inputs=[], attrs=[name, "main"], outputs=[output])


def _make_get_loops(input, outputs):  # pylint: disable=redefined-builtin
    """Create a "GetLoops" instruction taking `input` and producing `outputs`."""
    kind = InstructionKind.get("GetLoops")
    return Instruction(kind=kind, inputs=[input], attrs=[], outputs=outputs)


def _make_compute_inline(input):  # pylint: disable=redefined-builtin
    """Create a "ComputeInline" instruction taking `input`; it has no outputs."""
    kind = InstructionKind.get("ComputeInline")
    return Instruction(kind=kind, inputs=[input], attrs=[], outputs=[])


def _make_enter_postproc():
    """Create an "EnterPostproc" instruction; it has no inputs, attrs or outputs."""
    kind = InstructionKind.get("EnterPostproc")
    return Instruction(kind=kind, inputs=[], attrs=[], outputs=[])


def _make_trace_1(b0, l1, l2):  # pylint: disable=invalid-name
    """Trace with no decisions: get block "block" as `b0`, then its loops as `l1`, `l2`."""
    steps = [
        _make_get_block(name="block", output=b0),
        _make_get_loops(input=b0, outputs=[l1, l2]),
    ]
    return Trace(insts=steps, decisions={})


def _make_trace_2(b0):  # pylint: disable=invalid-name
    """Trace with no decisions: get block "B" as `b0`, then compute-inline it."""
    steps = [
        _make_get_block(name="B", output=b0),
        _make_compute_inline(input=b0),
    ]
    return Trace(insts=steps, decisions={})


def _make_trace_3(b0, b1, add_postproc):  # pylint: disable=invalid-name
    """Trace with no decisions: inline "B", get "C"; optionally enter postproc and inline "C"."""
    # Both variants share the first three instructions; the postproc variant
    # only appends to them.
    insts = [
        _make_get_block(name="B", output=b0),
        _make_compute_inline(input=b0),
        _make_get_block(name="C", output=b1),
    ]
    if add_postproc:
        insts.extend(
            [
                _make_enter_postproc(),
                _make_compute_inline(input=b1),
            ]
        )
    return Trace(insts=insts, decisions={})


def test_trace_construct_1():
    """A freshly built trace prints its two instructions and holds no decisions."""
    trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
    expected_lines = (
        'b0 = sch.get_block(name="block", func_name="main")',
        "l1, l2 = sch.get_loops(block=b0)",
    )
    assert str(trace) == "\n".join(expected_lines)
    assert len(trace.insts) == 2
    assert len(trace.decisions) == 0


def test_trace_construct_get_decision_1():
    """get_decision returns None for both instructions of a decision-free trace."""
    trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
    for index in (0, 1):
        assert trace.get_decision(trace.insts[index]) is None


def test_trace_construct_append_1():
    """Appending an instruction adds one more line to the printed trace."""
    trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
    extra_inst = _make_get_block("block2", BlockRV())
    trace.append(inst=extra_inst)
    expected_lines = (
        'b0 = sch.get_block(name="block", func_name="main")',
        "l1, l2 = sch.get_loops(block=b0)",
        'b3 = sch.get_block(name="block2", func_name="main")',
    )
    assert str(trace) == "\n".join(expected_lines)


def test_trace_construct_pop_1():
    """pop() returns the last instruction and drops it from the printed trace."""
    trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
    expected_inst = trace.insts[-1]
    popped = trace.pop()
    assert popped.same_as(expected_inst)
    assert str(trace) == 'b0 = sch.get_block(name="block", func_name="main")'


def test_trace_construct_pop_2():
    """pop() on an empty trace returns None and the trace still prints as empty."""
    trace = Trace([], {})
    assert str(trace) == ""
    popped = trace.pop()
    assert popped is None
    assert str(trace) == ""


def test_trace_apply_to_schedule():
    """Applying the inline trace to `elementwise` yields `elementwise_inlined`."""
    trace = _make_trace_2(BlockRV())
    sch = tir.Schedule(elementwise, debug_mask="all")
    trace.apply_to_schedule(sch, remove_postproc=False, decision_provider=None)
    tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"])


def test_trace_as_json_1():
    """as_json() yields the instruction records followed by an empty second list."""
    trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
    expected_insts = [
        ["GetBlock", [], ["block", "main"], ["b0"]],
        ["GetLoops", ["b0"], [], ["l1", "l2"]],
    ]
    obj = trace.as_json()
    assert obj == [expected_insts, []]


def test_trace_simplified_1():
    """simplified(remove_postproc=True) leaves only the first two printed instructions."""
    trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True)
    full_script = "\n".join(
        (
            'b0 = sch.get_block(name="B", func_name="main")',
            "sch.compute_inline(block=b0)",
            'b1 = sch.get_block(name="C", func_name="main")',
            "sch.enter_postproc()",
            "sch.compute_inline(block=b1)",
        )
    )
    assert str(trace) == full_script
    trace = trace.simplified(remove_postproc=True)
    simplified_script = "\n".join(
        (
            'b0 = sch.get_block(name="B", func_name="main")',
            "sch.compute_inline(block=b0)",
        )
    )
    assert str(trace) == simplified_script


def test_trace_simplified_2():
    """simplified(remove_postproc=False) leaves the printed trace unchanged."""
    trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True)
    full_script = "\n".join(
        (
            'b0 = sch.get_block(name="B", func_name="main")',
            "sch.compute_inline(block=b0)",
            'b1 = sch.get_block(name="C", func_name="main")',
            "sch.enter_postproc()",
            "sch.compute_inline(block=b1)",
        )
    )
    assert str(trace) == full_script
    trace = trace.simplified(remove_postproc=False)
    # The same five lines are expected both before and after simplification.
    assert str(trace) == full_script


def test_apply_json_to_schedule_1():
    """Replaying the JSON form of the inline trace yields `elementwise_inlined`."""
    trace = _make_trace_2(BlockRV())
    json_obj = trace.as_json()
    sch = tir.Schedule(elementwise, debug_mask="all")
    Trace.apply_json_to_schedule(json_obj, sch)
    tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"])


if __name__ == "__main__":
    sys.exit(pytest.main([__file__, *sys.argv[1:]]))