# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import tvm from tvm.ir import Range from tvm.script import tir as T @T.prim_func def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): with T.block("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @T.prim_func def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) for i, j in T.grid(32, 32): with T.block("init"): vi, vj = T.axis.remap("SS", [i, j]) for ii, jj in T.grid(4, 4): C[vi * 4 + ii, vj * 4 + jj] = T.float32(0) for k in range(0, 32): with T.block("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) for ii, jj, kk in T.grid(4, 4, 4): C[vi * 4 + ii, vj * 4 + jj] = ( C[vi * 4 + ii, vj * 4 + jj] + A[vi * 4 + ii, vk * 4 + kk] * B[vj * 4 + jj, vk * 4 + kk] ) @T.prim_func def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) with T.block() as []: for i, j in T.grid(128, 128): with T.block(): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.float32(1) for i, j in T.grid(128, 128): with T.block(): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) with T.block() as []: with T.block() as []: B[0, 0] = A[0, 0] + T.float32(1) for i, j in T.grid(128, 128): with T.block(): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + T.float32(1) @T.prim_func def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) with T.block() as []: for i, j in T.grid(128, 128): with T.block(): vi, vj = T.axis.remap("SS", [i, j]) T.reads(A[vi, vj]) B[vi, vj] = A[vi, vj] + T.float32(1) for i, j in T.grid(128, 128): with T.block(): vi, vj = T.axis.remap("SS", [i, j]) T.writes(C[vi, vj]) C[vi, vj] = B[vi, vj] + T.float32(1) def test_complete_matmul(): func = matmul A, B, C = [func.buffer_map[x] for x in func.params] block = func.body.block.body.body.body.body.block assert isinstance(block, tvm.tir.Block) vi, vj, vk = [x.var for x in block.iter_vars] access_A = tvm.tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vk, 1)]) access_B = tvm.tir.BufferRegion(B, [Range.from_min_extent(vj, 1), Range.from_min_extent(vk, 1)]) access_C = tvm.tir.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)]) tvm.ir.assert_structural_equal(block.reads, [access_C, access_A, access_B]) tvm.ir.assert_structural_equal(block.writes, [access_C]) def test_complete_matmul_original(): func = matmul_original A, B, C = [func.buffer_map[x] for x in func.params] block1 = func.body.block.body.body.body[0].block assert isinstance(block1, tvm.tir.Block) vi, vj = [x.var for x in block1.iter_vars] access_C = tvm.tir.BufferRegion( C, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vj * 4, 4)] ) tvm.ir.assert_structural_equal(block1.reads, []) tvm.ir.assert_structural_equal(block1.writes, [access_C]) block2 = func.body.block.body.body.body[1].body.block assert isinstance(block2, tvm.tir.Block) vi, vj, vk = [x.var for x in block2.iter_vars] access_A = tvm.tir.BufferRegion( A, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vk * 4, 4)] ) access_B = tvm.tir.BufferRegion( B, [Range.from_min_extent(vj * 4, 4), Range.from_min_extent(vk * 4, 4)] ) access_C = tvm.tir.BufferRegion( C, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vj * 4, 4)] ) tvm.ir.assert_structural_equal(block2.reads, [access_C, access_A, access_B]) tvm.ir.assert_structural_equal(block2.writes, [access_C]) def _check_elementwise(func): A, B, C = [func.buffer_map[x] for x in func.params] block1 = func.body.block.body[0].body.body.block assert isinstance(block1, tvm.tir.Block) vi, vj = [x.var for x in block1.iter_vars] tvm.ir.assert_structural_equal( block1.reads, [tvm.tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) tvm.ir.assert_structural_equal( block1.writes, [tvm.tir.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) block2 = func.body.block.body[1].body.body.block assert isinstance(block2, tvm.tir.Block) vi, vj = [x.var for x in block2.iter_vars] tvm.ir.assert_structural_equal( block2.reads, [tvm.tir.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) tvm.ir.assert_structural_equal( block2.writes, [tvm.tir.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) def test_complete_with_root(): _check_elementwise(elementwise_with_root) def test_complete_part_region(): _check_elementwise(func_with_part_access_region) @T.prim_func def func_with_bufferslice_indices(data: T.handle, index: T.handle) -> None: data_buf = T.match_buffer(data, (16, 16), "float32") index_buf = T.match_buffer(index, (1,), "int32") out_buf = T.alloc_buffer((16, 16), "float32") for i, j in T.grid(16, 16): with T.block(): vi, vj = T.axis.remap("SS", [i, j]) out_buf[vi, vj] = data_buf[vi, index_buf[0]] @T.prim_func def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) with T.block("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(16, 16): with T.block(): vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[vi, 0:16], index_buf[0]]) T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[vi, index_buf[0]] @T.prim_func def func_with_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None: data_buf = T.match_buffer(data, (16, 16), "float32") index_buf = T.match_buffer(index, (1,), "int32") out_buf = T.alloc_buffer((16, 16), "float32") for i, j in T.grid(16, 16): with T.block(): vi, vj = T.axis.remap("SS", [i, j]) out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] @T.prim_func def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None: index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) with T.block("root"): T.reads([]) T.writes([]) out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) for i0, i1 in T.grid(16, 16): with T.block(): vi, vj = T.axis.remap("SS", [i0, i1]) T.reads([data_buf[0:16, 0:16], index_buf[0]]) T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] def test_complete_buffer_indices(): new_func = tvm.script.from_source(func_with_bufferslice_indices.script()) tvm.ir.assert_structural_equal(new_func, expected_bufferslice_indices) new_func = tvm.script.from_source(func_with_recursive_bufferslice_indices.script()) tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices) @T.prim_func def match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): with T.block(): A0 = T.match_buffer(A[i, 0:16], (16)) with T.block(): for j in range(0, 16): with T.block() as []: A1 = T.match_buffer(A0[j], ()) A1[()] = 1.0 @T.prim_func def expected_match_buffer_func(a: T.handle) -> None: A = T.match_buffer(a, (16, 16)) for i in range(0, 16): with T.block(): T.reads([]) T.writes(A[i, 0:16]) A0 = T.match_buffer(A[i, 0:16], (16)) with T.block(): T.reads([]) T.writes(A0[0:16]) for j in range(0, 16): with T.block() as []: T.reads([]) T.writes(A0[j]) A1 = T.match_buffer(A0[j], ()) A1[()] = 1.0 def test_complete_match_buffer(): tvm.ir.assert_structural_equal(match_buffer_func, expected_match_buffer_func) @T.prim_func def alloc_buffer_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [2, 2], dtype="float32") B = T.match_buffer(b, [2, 2], dtype="float32") C = T.alloc_buffer([2, 2], dtype="float32") A[(0, 0)] = T.float32(2) C[(0, 0)] = A[(0, 0)] + B[(0, 0)] B[(0, 0)] = C[(0, 0)] @T.prim_func def expect_alloc_buffer_func(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) B = T.match_buffer(b, [2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) with T.block("root"): T.reads([]) T.writes([]) C = T.alloc_buffer([2, 2], dtype="float32", elem_offset=0, align=128, offset_factor=1) A[(0, 0)] = T.float32(2) C[(0, 0)] = A[(0, 0)] + B[(0, 0)] B[(0, 0)] = C[(0, 0)] def test_complete_alloc_buffer(): rt_func = tvm.script.from_source(alloc_buffer_func.script(show_meta=True)) tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func) @T.prim_func def load_var() -> None: d = T.var("float32") d[1] = d[1] if __name__ == "__main__": test_complete_matmul() test_complete_matmul_original() test_complete_with_root() test_complete_part_region() test_complete_buffer_indices() test_complete_match_buffer() test_complete_alloc_buffer()