# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for ``tvm.tir.Buffer``: declaration, ``access_ptr``, ``vload`` and broadcasting."""
import pytest
import tvm
import tvm.testing
from tvm import te
from tvm.tir import Buffer
import numpy as np


def test_buffer():
    """Declared buffers are ``tvm.tir.Buffer`` objects keeping their dtype and shape."""
    m = te.size_var("m")
    n = te.size_var("n")
    # Local renamed from the ambiguous ``l``; the TIR variable is still called "l".
    l_dim = te.size_var("l")
    Ab = tvm.tir.decl_buffer((m, n), "float32")
    Bb = tvm.tir.decl_buffer((n, l_dim), "float32")

    assert isinstance(Ab, tvm.tir.Buffer)
    assert Ab.dtype == "float32"
    assert tuple(Ab.shape) == (m, n)

    # The second buffer was previously declared but never checked.
    assert isinstance(Bb, tvm.tir.Buffer)
    assert Bb.dtype == "float32"
    assert tuple(Bb.shape) == (n, l_dim)


def test_buffer_access_ptr():
    """Check the extent, dtype and access-mask arguments of ``access_ptr`` on a strided buffer."""
    m = te.size_var("m")
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1])

    aptr = Ab.access_ptr("rw")
    assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m)
    assert aptr.args[0].dtype == Ab.dtype
    assert aptr.args[4].value == (Buffer.READ | Buffer.WRITE)

    aptr = Ab.access_ptr("w")
    assert aptr.args[4].value == Buffer.WRITE


def test_buffer_access_ptr_offset():
    """Check that ``offset`` is simplified and stored as the third ``access_ptr`` argument."""
    m = te.size_var("m")
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((m, n), "float32")
    read_write = Buffer.READ | Buffer.WRITE

    aptr = Ab.access_ptr("rw", offset=100)
    tvm.testing.assert_prim_expr_equal(aptr.args[2], 100)
    assert aptr.args[4].value == read_write

    # NOTE(review): "int32" is the *name* of this size var, not a dtype argument.
    v = te.size_var("int32")
    aptr = Ab.access_ptr("rw", offset=100 + 100 + v)
    tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v)
    assert aptr.args[4].value == read_write

    aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern("int32", "test_call", 100 + 100 + v))
    tvm.testing.assert_prim_expr_equal(
        aptr.args[2], tvm.tir.call_extern("int32", "test_call", 200 + v)
    )
    assert aptr.args[4].value == read_write


def test_buffer_access_ptr_extent():
    """Check the extent argument of ``access_ptr`` with and without offset and strides."""
    m = te.size_var("m")
    n = te.size_var("n")

    Ab = tvm.tir.decl_buffer((m, n), "float32")
    aptr = Ab.access_ptr("rw")
    assert tvm.ir.structural_equal(aptr.args[3], m * n)
    aptr = Ab.access_ptr("rw", offset=100)
    assert tvm.ir.structural_equal(aptr.args[3], m * n - 100)

    Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1])
    aptr = Ab.access_ptr("rw", offset=100)
    assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100)


def test_buffer_vload():
    """Check that ``vload`` folds ``elem_offset`` into the flattened load index."""
    m = te.size_var("m")
    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100)
    load = Ab.vload([2, 3])
    tvm.testing.assert_prim_expr_equal(load.index, n * 2 + 103)


def test_buffer_vload_nullptr():
    """Check that the lowering pipeline raises ``TVMError`` for a loop whose extent is a load.

    The loop extent is a ``BufferLoad`` of a buffer that no function parameter or
    allocation provides; running the passes below is expected to raise.
    """
    var = tvm.tir.Var("v", dtype="int32")
    buf = tvm.tir.decl_buffer((1,), name="buf")
    buf_load = tvm.tir.expr.BufferLoad(buffer=buf, indices=tvm.runtime.convert([0]))
    buf_load_stmt = tvm.tir.stmt.Evaluate(buf_load)
    for_loop = tvm.tir.stmt.For(
        loop_var=var, kind=0, min_val=0, extent=buf_load, body=buf_load_stmt
    )
    buf_func = tvm.tir.PrimFunc(params={}, body=for_loop)
    mod = tvm.IRModule({"main": buf_func})
    pipeline = tvm.transform.Sequential(
        [
            tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
            tvm.tir.transform.CompactBufferAllocation(),
            tvm.tir.transform.FlattenBuffer(),
        ]
    )

    # Trigger nullptr buffer bug by pass
    with pytest.raises(tvm.error.TVMError) as cm:
        pipeline(mod)

    # The message check must run after the ``with`` block: statements following the
    # raising call inside it are never executed. pytest exposes the caught exception
    # as ``ExceptionInfo.value`` (the former ``cm.execption`` was a misspelling).
    assert "(n != nullptr) is false" in str(cm.value)


def test_buffer_index_merge_mult_mod():
    """Check that ``vload`` index expressions built from div/mod pairs are simplified."""
    m = te.size_var("m")
    n = te.size_var("n")
    s = te.size_var("s")
    k0 = te.size_var("k0")
    k1 = te.size_var("k1")
    A = tvm.tir.decl_buffer((m, n), "float32")
    A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))

    def assert_simplified_equal(index_simplified, index_direct):
        """Assert both loads are structurally equal, reporting both on failure."""
        assert tvm.ir.structural_equal(
            index_simplified, index_direct
        ), f"index_simplified={index_simplified}, index_direct={index_direct}"

    idxd = tvm.tir.indexdiv
    idxm = tvm.tir.indexmod

    # Test Case1
    index_simplified = A_stride.vload(
        (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1)
    )
    index_direct = A_stride.vload((0, k0))
    assert_simplified_equal(index_simplified, index_direct)

    # Test Case2
    index_simplified = A.vload(
        (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1))
    )
    index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
    assert_simplified_equal(index_simplified, index_direct)

    # Test Case3
    index_simplified = A.vload(
        (
            idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + idxd(idxm(k0, idxd(k1, s)), n),
            idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + idxm(idxm(k0, idxd(k1, s)), n),
        )
    )
    index_direct = A.vload((0, k0))
    assert_simplified_equal(index_simplified, index_direct)

    # Test Case4 (not able to simplify)
    index_simplified = A.vload(
        (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))
    )
    index_direct = A.vload(
        (0, idxd(idxm(k0, idxd(k1, s)), n) * n + (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
    )
    assert_simplified_equal(index_simplified, index_direct)

    # Test Case5
    B = tvm.tir.decl_buffer((1, 14, 14, 1024))
    i = te.size_var("i")
    j = te.size_var("j")
    k = te.size_var("k")
    flat = i * 50176 + j * 28672 + k

    index_simplified = B.vload(
        (
            idxd(idxd(idxd(flat, 1024), 14), 14),
            idxm(idxd(idxd(flat, 1024), 14), 14),
            idxm(idxd(flat, 1024), 14),
            idxm(flat, 1024),
        )
    )
    index_direct = B.vload((0, 0, 0, flat))
    assert_simplified_equal(index_simplified, index_direct)


@tvm.testing.requires_llvm
def test_buffer_broadcast():
    """Build an add over ``auto_broadcast`` buffers and compare with NumPy broadcasting."""
    m0, m1, m2 = te.size_var("m0"), te.size_var("m1"), te.size_var("m2")
    n0, n1, n2 = te.size_var("n0"), te.size_var("n1"), te.size_var("n2")
    o0, o1, o2 = te.size_var("o0"), te.size_var("o1"), te.size_var("o2")

    A = te.placeholder((m0, m1, m2), name="A")
    B = te.placeholder((n0, n1, n2), name="B")

    C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name="C")

    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
    Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
    s = te.create_schedule(C.op)

    def check():
        """Run the built kernel on (2, 4, 3) + (2, 1, 1) inputs."""
        fadd = tvm.build(s, [A, B, C], target="llvm", name="bcast_add", binds={A: Ab, B: Bb})
        dev = tvm.cpu(0)
        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), dev)
        b = tvm.nd.array(np.random.uniform(size=(2, 1, 1)).astype(B.dtype), dev)
        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), dev)
        fadd(a, b, c)
        tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())

    check()


@tvm.testing.requires_llvm
def test_buffer_broadcast_expr():
    """Build a broadcast add whose output shape contains an expression (``o1 // x``)."""
    n0, m0, x = te.size_var("n0"), te.size_var("m0"), te.size_var("x")
    n1, m1 = te.size_var("n1"), te.size_var("m1")
    o0, o1 = te.size_var("o0"), te.size_var("o1")

    A = te.placeholder((m0, n0), name="A")
    B = te.placeholder((m1, n1), name="B")
    C = te.compute((o0, o1 // x), lambda i, j: A[i, j] + B[i, j], name="C")

    Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
    Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
    Cc = tvm.tir.decl_buffer(C.shape, C.dtype, name="Cc", buffer_type="auto_broadcast")
    s = te.create_schedule(C.op)

    def run_and_compare(fadd, a_shape):
        """Call ``fadd`` with ``o1=4, x=1`` on random inputs and compare against NumPy."""
        dev = tvm.cpu(0)
        a = tvm.nd.array(np.random.uniform(size=a_shape).astype(A.dtype), dev)
        b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), dev)
        c = tvm.nd.array(np.zeros((2, 4), dtype=C.dtype), dev)
        fadd(a, b, c, 4, 1)
        tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())

    def check_stride():
        """Explicit buffer binds; ``A`` has the same shape as ``B``."""
        fadd = tvm.build(
            s, [A, B, C, o1, x], target="llvm", name="bcast_add", binds={A: Ab, B: Bb, C: Cc}
        )
        run_and_compare(fadd, (2, 4))

    def check_no_stride():
        """Explicit buffer binds; ``A`` is broadcast along its first axis."""
        fadd = tvm.build(
            s, [A, B, C, o1, x], target="llvm", name="bcast_add", binds={A: Ab, B: Bb, C: Cc}
        )
        run_and_compare(fadd, (1, 4))

    def check_auto_bind():
        """Same inputs as ``check_no_stride`` but without passing ``binds``."""
        # Let build bind buffers
        fadd = tvm.build(s, [A, B, C, o1, x], target="llvm", name="bcast_add")
        run_and_compare(fadd, (1, 4))

    check_stride()
    check_no_stride()
    check_auto_bind()


if __name__ == "__main__":
    pytest.main([__file__])