# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
from tvm import topi
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
import tvm.testing

import pytest

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_add():
    num_thread = 8

    def check_cuda(dtype, n, lanes):
        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return
        if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version):
            print("skip because gpu does not support int8")
            return
        A = te.placeholder((n,), name="A", dtype="%sx%d" % (dtype, lanes))
        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
        s[B].bind(xo, bx)
        s[B].bind(xi, tx)
        fun = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n, lanes)))
        c = tvm.nd.empty((n,), B.dtype, dev)
        fun(a, c)
        tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)

    check_cuda("float32", 64, 2)
    check_cuda("float32", 64, 3)
    check_cuda("float32", 64, 4)
    check_cuda("int8", 64, 2)
    check_cuda("int8", 64, 3)
    check_cuda("int8", 64, 4)
    check_cuda("uint8", 64, 2)
    check_cuda("uint8", 64, 3)
    check_cuda("uint8", 64, 4)
    check_cuda("float16", 64, 2)
    check_cuda("float16", 64, 4)
    check_cuda("float16", 64, 6)
    check_cuda("float16", 64, 8)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_bf16_vectorize_add():
    if not have_bf16(tvm.cuda(0).compute_version):
        print("skip because gpu does not support bf16")
        return
    num_thread = 8

    def np_float2np_bf16(arr):
        """Convert a numpy array of float to a numpy array of bf16 in uint16"""
        orig = arr.view("<u4")
        # Round to nearest even: add the rounding bias before truncating the mantissa.
        bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
        return np.right_shift(orig + bias, 16).astype("uint16")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_const_float_to_half():
    # Compare a float16 tensor against a float16 constant element-wise.
    shape = (2, 3, 4)
    a = te.placeholder(shape, dtype="float16", name="a")
    b = tvm.tir.const(0.5, dtype="float16")
    c = te.compute(shape, lambda i, j, k: a[i, j, k] > b, name="c")
    s = te.create_schedule(c.op)
    axes = [axis for axis in c.op.axis]
    fused = s[c].fuse(*axes)
    bx, tx = s[c].split(fused, factor=64)
    s[c].bind(bx, te.thread_axis("blockIdx.x"))
    s[c].bind(tx, te.thread_axis("threadIdx.x"))
    func = tvm.build(s, [a, c], "cuda")

    dev = tvm.cuda(0)
    a_np = np.random.uniform(size=shape).astype(a.dtype)
    c_np = np.zeros(shape=shape, dtype=c.dtype)
    a = tvm.nd.array(a_np, dev)
    c = tvm.nd.array(c_np, dev)
    func(a, c)
    np.testing.assert_equal(c.numpy(), a_np > b.value)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_reduction():
    def check(device, dtype, m=32, n=32):
        if not tvm.testing.device_enabled(device):
            print("Skipping", device)
            return
        dev = tvm.device(device, 0)
        a = te.placeholder((m, n), name="a", dtype=dtype)
        b = te.placeholder((m, n), name="b", dtype=dtype)
        c = a + b
        d = a * b
        e = topi.elemwise_sum([c, d])
        g = topi.sum(e)
        with tvm.target.Target(device):
            sg = topi.cuda.schedule_reduce(g)
            func = tvm.build(sg, [a, b, g], device)
        a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
        b_np = np.random.uniform(size=(m, n)).astype(b.dtype)
        g_np = np.sum(np.add(a_np * b_np, a_np + b_np))
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(b_np, dev)
        g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev)
        func(a_nd, b_nd, g_nd)
        tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-3)

    check("cuda", "float32")
    check("rocm", "float32")
    check("cuda", "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_mix_threaded_and_normal_reduction():
    def check(device, dtype, m=32, n=32):
        if not tvm.testing.device_enabled(device):
            print("Skipping", device)
            return
        dev = tvm.device(device, 0)
        if dtype == "float16" and not have_fp16(dev.compute_version):
            print("Skip because gpu does not have fp16 support")
            return
        a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
        b = topi.sum(a)
        with tvm.target.Target(device):
            sb = tvm.te.create_schedule(b.op)
            i, _ = b.op.reduce_axis
            sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
            func = tvm.build(sb, [a, b], device)
        a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
        b_np = np.sum(a_np)
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3)

    check("cuda", "float32")
    check("rocm", "float32")
    check("cuda", "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_floordiv_with_vectorization():
    with tvm.target.cuda():
        # B[i] = A[floordiv(i, k)]
        n = 256
        k = 37
        A = te.placeholder((n,), name="A")
        B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name="B")
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], nparts=1)
        xio, xii = s[B].split(xi, factor=4)
        s[B].vectorize(xii)
        s[B].bind(xo, bx)
        s[B].bind(xio, tx)
        func = tvm.build(s, [A, B], "cuda")

        dev = tvm.cuda(0)
        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
        b_np = np.array([a_np[i // k] for i in range(0, n)])
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_floormod_with_vectorization():
    with tvm.target.cuda():
        # B[i] = A[floormod(i, k)]
        n = 256
        k = 37
        A = te.placeholder((n,), name="A")
        B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name="B")
        s = te.create_schedule(B.op)
        xo, xi = s[B].split(B.op.axis[0], nparts=1)
        xio, xii = s[B].split(xi, factor=4)
        s[B].vectorize(xii)
        s[B].bind(xo, bx)
        s[B].bind(xio, tx)
        func = tvm.build(s, [A, B], "cuda")

        dev = tvm.cuda(0)
        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
        b_np = np.array([a_np[i % k] for i in range(0, n)])
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), dev)
        func(a_nd, b_nd)
        tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_casts():
    def check(t0, t1, factor):
        if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.cuda(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return

        # compute
        n = 128
        A = te.placeholder((n,), dtype=t0, name="A")
        B = te.placeholder((n,), dtype=t1, name="B")
        C = te.compute((n,), lambda i: A[i] + topi.cast(B[i], A.dtype), name="C")

        # schedule
        s = tvm.te.create_schedule(C.op)
        ob, ib = s[C].split(s[C].op.axis[0], factor=factor)
        s[C].vectorize(ib)
        s[C].bind(ob, tx)
        func = tvm.build(s, [A, B, C], "cuda")

        # correctness
        dev = tvm.cuda(0)
        low, high = (0, 20) if t0.startswith("u") or t1.startswith("u") else (-10, 10)
        a_np = np.random.randint(low, high, size=n).astype(A.dtype)
        b_np = np.random.randint(low, high, size=n).astype(B.dtype)
        c_np = (a_np + b_np).astype(A.dtype)
        a_nd = tvm.nd.array(a_np, dev)
        b_nd = tvm.nd.array(b_np, dev)
        c_nd = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np.dtype), dev)
        func(a_nd, b_nd, c_nd)
        tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-3)

    def skip(t0, t1):
        if t0 == t1:
            return True
        # CUDA does not support casts between {u}int8 and fp16.
        skip_set = {"float16", "uint8", "int8"}
        if t0 in skip_set and t1 in skip_set:
            return True
        return False

    types_4 = [
        "float16",
        "float32",
        "int8",
        "uint8",
        "int16",
        "uint16",
        "int32",
        "uint32",
        "float64",
        "int64",
        "uint64",
    ]
    types_8 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
    for t0, t1 in [(x, y) for x in types_4 for y in types_4 if not skip(x, y)]:
        check(t0, t1, 4)
    for t0, t1 in [(x, y) for x in types_8 for y in types_8 if not skip(x, y)]:
        check(t0, t1, 8)
    check("int8", "uint8", 16)
    check("uint8", "int8", 16)


def sched(B):
    s = te.create_schedule(B.op)
    io, ii = s[B].split(s[B].op.axis[0], nparts=1)
    iio, iii = s[B].split(ii, nparts=32)
    _, iiii = s[B].split(iii, factor=4)
    s[B].vectorize(iiii)
    s[B].bind(io, bx)
    s[B].bind(iio, tx)
    return s


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_intrin1():
    test_funcs = [
        (tvm.tir.floor, lambda x: np.floor(x)),
        (tvm.tir.ceil, lambda x: np.ceil(x)),
        (tvm.tir.trunc, lambda x: np.trunc(x)),
        (tvm.tir.abs, lambda x: np.fabs(x)),
        (tvm.tir.round, lambda x: np.round(x)),
        (tvm.tir.exp, lambda x: np.exp(x)),
        (tvm.tir.exp2, lambda x: np.exp2(x)),
        (tvm.tir.exp10, lambda x: np.power(10, x)),
        (tvm.tir.log, lambda x: np.log(x)),
        (tvm.tir.log2, lambda x: np.log2(x)),
        (tvm.tir.log10, lambda x: np.log10(x)),
        (tvm.tir.tan, lambda x: np.tan(x)),
        (tvm.tir.cos, lambda x: np.cos(x)),
        (tvm.tir.cosh, lambda x: np.cosh(x)),
        (tvm.tir.sin, lambda x: np.sin(x)),
        (tvm.tir.sinh, lambda x: np.sinh(x)),
        (tvm.tir.atan, lambda x: np.arctan(x)),
        (tvm.tir.tanh, lambda x: np.tanh(x)),
        (tvm.tir.sqrt, lambda x: np.sqrt(x)),
    ]

    def run_test(tvm_intrin, np_func, dtype):
        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return
        # This set of intrinsics does not support fp16 yet.
        skip_set = {
            tvm.tir.abs,
            tvm.tir.round,
            tvm.tir.tan,
            tvm.tir.atan,
            tvm.tir.tanh,
            tvm.tir.cosh,
            tvm.tir.sinh,
        }
        if dtype == "float16" and tvm_intrin in skip_set:
            print("Skip because '{0}' does not support fp16 yet".format(tvm_intrin.__name__))
            return

        n = 128
        A = te.placeholder((n,), dtype=dtype, name="A")
        B = te.compute((n,), lambda *i: tvm_intrin(A(*i)), name="B")
        s = sched(B)
        f = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
        b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev)
        f(a, b)
        tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3)

    for func in test_funcs:
        run_test(*func, "float32")
        run_test(*func, "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_intrin2(dtype="float32"):
    c2 = tvm.tir.const(2, dtype=dtype)
    test_funcs = [
        (tvm.tir.power, lambda x: np.power(x, 2.0)),
        (tvm.tir.fmod, lambda x: np.fmod(x, 2.0)),
    ]

    def run_test(tvm_intrin, np_func):
        n = 128
        A = te.placeholder((n,), dtype=dtype, name="A")
        B = te.compute((n,), lambda i: tvm_intrin(A[i], c2), name="B")
        s = sched(B)
        f = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev)
        b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev)
        f(a, b)
        tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3)

    for func in test_funcs:
        run_test(*func)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_popcount():
    def ref_popcount(x):
        cnt = 0
        while x:
            x -= x & -x
            cnt += 1
        return cnt

    def run_test(dtype):
        n = 128
        A = te.placeholder((n,), dtype=dtype, name="A")
        B = te.compute((n,), lambda i: tvm.tir.popcount(A[i]), name="B")
        s = sched(B)
        f = tvm.build(s, [A, B], "cuda")
        dev = tvm.cuda(0)
        a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
        b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev)
        f(a, b)
        ref = np.vectorize(ref_popcount)(a.numpy())
        tvm.testing.assert_allclose(b.numpy(), ref)

    run_test("uint32")
    run_test("uint64")


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_load_permute_pad():
    def check_cuda(dtype, n, l, padding, lanes):
        if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return

        dev = tvm.cuda(0)
        A = tvm.te.placeholder((n, l), name="A", dtype=dtype)
        B = tvm.te.compute(
            (n // lanes, l + 2 * padding, lanes),
            lambda i, j, k: tvm.te.if_then_else(
                tvm.te.any(j < padding, j >= l + padding),
                tvm.runtime.convert(0).astype(dtype),
                A[i * lanes + k, j - padding],
            ),
            name="B",
        )
        s = te.create_schedule(B.op)
        block, thread, vectorize = s[B].op.axis
        s[B].bind(block, bx)
        s[B].bind(thread, tx)
        s[B].vectorize(vectorize)
        fun = tvm.build(s, [A, B], "cuda", name="vector_load_permute_pad")
        np_a = np.random.randint(low=-128, high=127, size=(n, l)).astype(A.dtype)
        a = tvm.nd.empty((n, l), A.dtype, dev).copyfrom(np_a)
        b = tvm.nd.empty((n // lanes, l + padding * 2, lanes), B.dtype, dev)
        fun(a, b)
        np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1)
        ref = np.pad(
            np_a_reshape, ((0, 0), (padding, padding), (0, 0)), mode="constant", constant_values=0
        )
        tvm.testing.assert_allclose(b.numpy(), ref)

    check_cuda("int8", 64, 16, 3, 2)
    check_cuda("uint8", 64, 16, 3, 2)
    check_cuda("int8", 64, 16, 3, 4)
    check_cuda("uint8", 64, 16, 3, 4)
    check_cuda("int32", 64, 16, 3, 4)
    check_cuda("float16", 64, 16, 3, 4)
    check_cuda("float32", 64, 16, 3, 4)


def vcf_check_common(s, args):
    N = 512
    # Check that every vectorized loop is lowered to a Ramp expression.
    stmt = tvm.lower(s, args)

    # Use this as a stack flag to show whether this stmt is inside a BroadcastNode
    inside_broadcast = [False]

    # Possible patterns:
    # Reduce init:        Store[Ramp] = Broadcast(0)
    # Shared memory copy: Store[Ramp] = Load[Ramp]
    # Compute:            Store[Ramp] = Load[Ramp] ... Broadcast[Load]

    def pre_visit(stmt):
        if isinstance(stmt, tvm.tir.Broadcast):
            inside_broadcast[0] = True
            # Check Broadcast[Imm numbers] or Broadcast[Load] patterns
            assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.Load))
        if isinstance(stmt, tvm.tir.Store):
            # Check Store[Ramp] pattern
            assert isinstance(stmt.index, tvm.tir.Ramp)
        if isinstance(stmt, tvm.tir.Load):
            # Check Broadcast[Load] or Load[Ramp] patterns
            assert inside_broadcast[0] or isinstance(stmt.index, tvm.tir.Ramp)
            # Skip the rest
            return stmt
        return None

    def post_visit(stmt):
        if isinstance(stmt, tvm.tir.Broadcast):
            inside_broadcast[0] = False
        return None

    tvm.tir.stmt_functor.ir_transform(stmt["main"].body, pre_visit, post_visit)

    tgt = tvm.target.cuda()
    mod = tvm.build(s, args, tgt)
    # To check that every vectorized loop is lowered to the expected instructions,
    # inspect the generated source:
    # print(mod.imported_modules[0].get_source())

    dev = tvm.device("cuda", 0)
    a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), dev)
    b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), dev)
    c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), dev)
    mod(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), np.dot(a.numpy(), b.numpy()), rtol=1e-5)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_cooperative_fetching_x():
    N = 512
    A = te.placeholder((N, N), name="A", dtype="float32")
    B = te.placeholder((N, N), name="B", dtype="float32")
    k = te.reduce_axis((0, N), name="k")
    C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
    s = te.create_schedule(C.op)
    i, j = s[C].op.axis
    k = s[C].op.reduce_axis[0]
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])

    i3, i4 = s[C].split(i, factor=4)
    i2, i3 = s[C].split(i3, factor=2)
    i1, i2 = s[C].split(i2, factor=8)
    i0, i1 = s[C].split(i1, factor=1)
    j3, j4 = s[C].split(j, factor=4)
    j2, j3 = s[C].split(j3, factor=2)
    j1, j2 = s[C].split(j2, factor=8)
    j0, j1 = s[C].split(j1, factor=2)
    k1, k2 = s[C].split(k, factor=8)
    k0, k1 = s[C].split(k1, factor=8)
    s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4)
    block_it = s[C].fuse(i0, j0)
    s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x"))
    vthread_it = s[C].fuse(i1, j1)
    s[C].bind(vthread_it, tvm.te.thread_axis("vthread"))
    thread_it = s[C].fuse(i2, j2)
    s[C].bind(thread_it, tvm.te.thread_axis("threadIdx.x"))
    s[C].vectorize(j4)

    s[AA].compute_at(s[C], k0)
    iaa, jaa = s[AA].op.axis
    s[BB].compute_at(s[C], k0)
    ibb, jbb = s[BB].op.axis
    aa_fused = s[AA].fuse(iaa, jaa)
    bb_fused = s[BB].fuse(ibb, jbb)
    aa1, aa2 = s[AA].split(aa_fused, factor=4)
    aa0, aa1 = s[AA].split(aa1, factor=64)
    bb1, bb2 = s[BB].split(bb_fused, factor=4)
    bb0, bb1 = s[BB].split(bb1, factor=64)
    s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.x"))
    s[AA].vectorize(aa2)
    s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.x"))
    s[BB].vectorize(bb2)

    vcf_check_common(s, [A, B, C])


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_cooperative_fetching_xy():
    N = 512
    A = te.placeholder((N, N), name="A")
    B = te.placeholder((N, N), name="B")
    k = te.reduce_axis((0, N), name="k")
    C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
    s = te.create_schedule(C.op)
    i, j = s[C].op.axis
    k = s[C].op.reduce_axis[0]
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])

    i3, i4 = s[C].split(i, factor=4)
    i2, i3 = s[C].split(i3, factor=2)
    i1, i2 = s[C].split(i2, factor=8)
    i0, i1 = s[C].split(i1, factor=1)
    j3, j4 = s[C].split(j, factor=4)
    j2, j3 = s[C].split(j3, factor=2)
    j1, j2 = s[C].split(j2, factor=8)
    j0, j1 = s[C].split(j1, factor=2)
    k1, k2 = s[C].split(k, factor=8)
    k0, k1 = s[C].split(k1, factor=8)
    s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4)
    block_it = s[C].fuse(i0, j0)
    s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x"))
    vthread_it = s[C].fuse(i1, j1)
    s[C].bind(vthread_it, tvm.te.thread_axis("vthread"))
    s[C].bind(i2, tvm.te.thread_axis("threadIdx.y"))
    s[C].bind(j2, tvm.te.thread_axis("threadIdx.x"))
    s[C].vectorize(j4)

    s[AA].compute_at(s[C], k0)
    iaa, jaa = s[AA].op.axis
    s[BB].compute_at(s[C], k0)
    ibb, jbb = s[BB].op.axis
    aa_fused = s[AA].fuse(iaa, jaa)
    bb_fused = s[BB].fuse(ibb, jbb)
    aa2, aa3 = s[AA].split(aa_fused, factor=4)
    aa1, aa2 = s[AA].split(aa2, factor=8)
    aa0, aa1 = s[AA].split(aa1, factor=8)
    bb2, bb3 = s[BB].split(bb_fused, factor=4)
    bb1, bb2 = s[BB].split(bb2, factor=8)
    bb0, bb1 = s[BB].split(bb1, factor=8)
    s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.y"))
    s[AA].bind(aa2, tvm.te.thread_axis("threadIdx.x"))
    s[AA].vectorize(aa3)
    s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.y"))
    s[BB].bind(bb2, tvm.te.thread_axis("threadIdx.x"))
    s[BB].vectorize(bb3)

    vcf_check_common(s, [A, B, C])


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_unrolled_vectorization():
    dtype = "float32"
    target = "cuda"

    # Compute declaration
    N = 128
    A = te.placeholder((N, N), name="A")
    B = te.placeholder((N, N), name="B")
    k = te.reduce_axis((0, N), name="k")
    C = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    # Schedule
    s = te.create_schedule([C.op])
    CC = s.cache_write(C, "local")
    i, j = s[C].op.axis
    bx, tx, ii, ji = s[C].tile(i, j, 1, 2)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    s[C].vectorize(ji)
    s[CC].compute_at(s[C], tx)
    i, j = s[CC].op.axis
    k = s[CC].op.reduce_axis[0]
    ko, ki = s[CC].split(k, 2)
    s[CC].unroll(ki)
    s[CC].vectorize(j)

    # Check correctness
    dev = tvm.device(target)
    a_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), device=dev)
    b_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), device=dev)
    c_tvm = tvm.nd.empty((N, N), device=dev)
    func_tvm = tvm.build(s, [A, B, C], target=target)
    func_tvm(a_tvm, b_tvm, c_tvm)
    c_np = c_tvm.numpy()
    tvm.testing.assert_allclose(c_np, N * np.ones((N, N)))


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_try_unaligned_vector_load():
    def get_compute(N, C_N, offset):
        A = te.placeholder((N,), name="A", dtype="float16")
        C = te.compute((C_N,), lambda i: A[i + offset], name="C")
        return N, C_N, A, C

    def get_compute_unaligned():
        return get_compute(3, 2, 1)

    def get_compute_aligned():
        return get_compute(4, 2, 2)

    def build(A, C, N, C_N):
        s = te.create_schedule(C.op)
        oi, ii = s[C].split(C.op.axis[0], factor=2)
        s[C].bind(oi, te.thread_axis("threadIdx.x"))
        s[C].vectorize(ii)  # BUG: misalignment

        tgt = tvm.target.Target(target="cuda", host="llvm")
        dev = tvm.device(tgt.kind.name, 0)
        f = tvm.build(s, [A, C], tgt, name="foo")
        kernel_source = f.imported_modules[0].get_source()

        a_data = np.arange(0, N).astype(A.dtype)
        a = tvm.nd.array(a_data, dev)
        c = tvm.nd.array(np.zeros(C_N, dtype=C.dtype), dev)
        f(a, c)
        return a_data, c.numpy(), kernel_source

    N, C_N, A, C = get_compute_unaligned()
    a_data, c, kernel_source = build(A, C, N, C_N)
    # (uint1*)(A + (1)) is an invalid (unaligned) vector load
    assert "A + (1)" not in kernel_source

    expected = a_data[1 : C_N + 1]
    assert np.allclose(c, expected), f"expected={expected}\nactual={c}"

    N, C_N, A, C = get_compute_aligned()
    a_data, c, kernel_source = build(A, C, N, C_N)
    # (uint1*)(A + (2)) is a valid vector load
    assert "A + (2)" in kernel_source

    expected = a_data[2 : C_N + 2]
    assert np.allclose(c, expected), f"expected={expected}\nactual={c}"


if __name__ == "__main__":
    pytest.main([__file__])