# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import sys

import tvm
from tvm.script import tir as T
from tvm.tir import stmt_functor
from tvm.tir.usmp import utils as usmp_utils
from tvm.target import Target


# Fixture module used by test_create_array_buffer_info: three operator
# PrimFuncs plus a runner function that calls them one after another.
# fmt: off
@tvm.script.ir_module
class LinearStructure:
    # Casts a uint8 input to int16 and subtracts the scalar held in placeholder_5.
    @T.prim_func
    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
        placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
        placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1)
        T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
        # body
        for ax0_ax1_fused_1 in T.serial(0, 224):
            for ax2_1, ax3_inner_1 in T.grid(224, 3):
                T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True)

    # Zero-pads the input into PaddedInput_7, accumulates a 7x7x3 reduction per
    # output channel into Conv2dOutput_7, then adds placeholder_67, applies
    # q_multiply_shift, clamps to [0, 255] and stores the result as uint8.
    @T.prim_func
    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
        placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
        placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
        placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
        # body
        PaddedInput_7 = T.allocate([157323], "int16", "global")
        for i0_i1_fused_7 in T.serial(0, 229):
            for i2_7, i3_7 in T.grid(229, 3):
                T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True)
        for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
            Conv2dOutput_7 = T.allocate([64], "int32", "global")
            for ff_3 in T.serial(0, 64):
                T.store(Conv2dOutput_7, ff_3, 0, True)
                for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
                    T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True)
            for ax3_inner_7 in T.serial(0, 64):
                T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)

    # Initialises tensor_2 to zero, takes the running max over 9 window
    # positions (out-of-range positions contribute 0), then casts to int16.
    @T.prim_func
    def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
        placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
        T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
        # body
        tensor_2 = T.allocate([200704], "uint8", "global")
        for ax0_ax1_fused_4 in T.serial(0, 56):
            for ax2_4 in T.serial(0, 56):
                for ax3_init in T.serial(0, 64):
                    T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
                for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
                    T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True)
        for ax0_ax1_fused_5 in T.serial(0, 56):
            for ax2_5, ax3_3 in T.grid(56, 64):
                T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)

    # Runner: allocates the two intermediate buffers (sid_9, sid_8) and calls
    # the three operators above in order through call_extern.
    @T.prim_func
    def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
        # body
        T.attr("default", "device_id", 0)
        T.attr("default", "device_type", 1)
        sid_9 = T.allocate([301056], "int8", "global")
        sid_8 = T.allocate([802816], "int8", "global")
        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32"))
        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32"))
        T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32"))
    __tvm_meta__ = None
# fmt: on


def test_create_pool_info():
    """Check PoolInfo construction with and without an explicit size hint."""
    target = Target("c")
    pool_info = usmp_utils.PoolInfo(
        pool_name="foo_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )
    assert pool_info.pool_name == "foo_workspace"
    assert dict(pool_info.target_access) == {target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}
    # default pool size constraint
    assert pool_info.size_hint_bytes == -1

    pool_info = usmp_utils.PoolInfo(
        pool_name="bar_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_ONLY_ACCESS},
        size_hint_bytes=1425,
    )
    assert pool_info.pool_name == "bar_workspace"
    assert dict(pool_info.target_access) == {target: usmp_utils.PoolInfo.READ_ONLY_ACCESS}
    assert pool_info.size_hint_bytes == 1425


def test_create_buffer_info():
    """Check BufferInfo construction via keyword and positional arguments."""
    global_ws_pool = usmp_utils.PoolInfo(
        pool_name="global_workspace",
        target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )
    buffer_info_obj = tvm.tir.usmp.BufferInfo(
        name_hint="buf1", size_bytes=256, pool_candidates=[global_ws_pool]
    )
    assert buffer_info_obj.name_hint == "buf1"
    assert buffer_info_obj.size_bytes == 256
    assert list(buffer_info_obj.pool_candidates) == [global_ws_pool]
    # default workspace alignment
    assert buffer_info_obj.alignment == 1

    # Positional form; the fourth argument is read back below as `alignment`.
    buffer_info_obj = tvm.tir.usmp.BufferInfo("buf2", 512, [global_ws_pool], 8)
    assert buffer_info_obj.name_hint == "buf2"
    assert buffer_info_obj.size_bytes == 512
    assert list(buffer_info_obj.pool_candidates) == [global_ws_pool]
    assert buffer_info_obj.alignment == 8


def test_create_pool_allocation():
    """Check that PoolAllocation exposes the pool and byte offset it was given."""
    pool_info = usmp_utils.PoolInfo(
        pool_name="foo_workspace",
        target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )
    pool_allocation = usmp_utils.PoolAllocation(pool_info=pool_info, byte_offset=64)
    assert pool_allocation.pool_info == pool_info
    assert pool_allocation.byte_offset == 64


def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
    """Helper to assign poolinfos to allocate nodes in a tir.PrimFunc.

    Parameters
    ----------
    primfunc : tvm.tir.PrimFunc
        Function whose body is rewritten.
    pool_infos : list
        Value stored under the CANDIDATE_MEMORY_POOL_ATTR annotation of every
        rebuilt Allocate node.

    Returns
    -------
    The result of ``primfunc.with_body`` applied to the transformed body.
    """

    def set_poolinfos(stmt):
        # Rebuild Allocate nodes with the same fields plus the pool annotation.
        if isinstance(stmt, tvm.tir.Allocate):
            return tvm.tir.Allocate(
                buffer_var=stmt.buffer_var,
                dtype=stmt.dtype,
                extents=stmt.extents,
                condition=stmt.condition,
                body=stmt.body,
                annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos},
            )
        # Any other statement yields None; presumably ir_transform then keeps
        # the statement unchanged -- confirm against the ir_transform docs.
        return None

    # set_poolinfos is passed as the third argument; the second is None.
    return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))


def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
    """Helper to assign poolinfos to allocate nodes in a IRModule.

    Builds a fresh IRModule containing only the PrimFuncs of ``mod``, each
    rewritten by ``_assign_poolinfos_to_allocates_in_primfunc``. Functions
    that are not PrimFuncs are not copied into the result.
    """
    ret = tvm.IRModule()
    for global_var, basefunc in mod.functions.items():
        if isinstance(basefunc, tvm.tir.PrimFunc):
            ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
    return ret


def _assign_targets_to_primfuncs_irmodule(mod, target):
    """Helper to assign target for PrimFunc in a IRModule.

    Builds a fresh IRModule containing only the PrimFuncs of ``mod``, each
    with its "target" attribute set to ``target``. Functions that are not
    PrimFuncs are not copied into the result.
    """
    ret = tvm.IRModule()
    for global_var, basefunc in mod.functions.items():
        if isinstance(basefunc, tvm.tir.PrimFunc):
            ret[global_var] = basefunc.with_attr("target", target)
    return ret


def test_create_array_buffer_info():
    """Check that every entry returned by CreateArrayBufferInfo is a key of the analysed map."""
    target = Target("c")
    global_ws_pool = usmp_utils.PoolInfo(
        pool_name="global_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )
    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
    # Annotate the fixture with a target and candidate pools before analysis.
    tir_mod = LinearStructure
    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
    tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
    main_func = tir_mod["tvmgen_default_run_model"]
    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
    buffer_info_array = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
    for buffer_info in buffer_info_array:
        assert buffer_info in buffer_info_analysis.buffer_info_stmts.keys()


if __name__ == "__main__":
    pytest.main([__file__] + sys.argv[1:])