# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # 'License'); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import sys import numpy as np import pytest import tvm from tvm import autotvm from tvm import te from tvm.topi import testing from tvm.topi.utils import get_const_tuple, simplify from tvm.topi import nn def compute_plus_one_rank3(shape): X = te.placeholder(shape, name="X", dtype="float32") Y = te.compute(shape, lambda i, j, k: X[i, j, k] + 1, name="Compute_Y") return X, Y def schedule_plus_one_rank3(X, Y): s = te.create_schedule(Y.op) # Xt = s.cache_read(X, "texture", [Y]) # Xt = s.cache_read(X, "global", [Y]) Xt = s.cache_read(X, "global.texture", [Y]) # copy to texture stage x, y, c = s[Xt].op.axis s[Xt].bind(x, te.thread_axis("blockIdx.x")) s[Xt].bind(y, te.thread_axis("threadIdx.x")) s[Xt].vectorize(c) # the compute stage x, y, c = s[Y].op.axis xo, yo, xi, yi = s[Y].tile(x, y, 4, 4) s[Y].bind(xo, te.thread_axis("blockIdx.x")) s[Y].bind(yo, te.thread_axis("threadIdx.x")) s[Y].vectorize(c) return s def compute_plus_one_rank5(shape): X = te.placeholder(shape, name="X", dtype="float32") Y = te.compute(shape, lambda i, j, k, l, m: X[i, j, k, l, m] + 1, name="Compute_Y") return X, Y def schedule_plus_one_rank5(X, Y): s = te.create_schedule(Y.op) Xt = s.cache_read(X, "global.texture", [Y]) # copy to texture stage a, b, c, d, e = s[Xt].op.axis abc = s[Xt].fuse(a, b, c) s[Xt].bind(abc, te.thread_axis("blockIdx.x")) s[Xt].bind(d, te.thread_axis("threadIdx.x")) s[Xt].vectorize(e) # the compute stage a, b, c, d, e = s[Y].op.axis abc = s[Y].fuse(a, b, c) xo, yo, xi, yi = s[Y].tile(abc, d, 4, 4) s[Y].bind(xo, te.thread_axis("blockIdx.x")) s[Y].bind(yo, te.thread_axis("threadIdx.x")) s[Y].vectorize(e) return s def compute_matmul(shape): A = te.placeholder(shape, name="A", dtype="float32") B = te.placeholder(shape, name="B", dtype="float32") k = te.reduce_axis((0, shape[1]), name="k") C = te.compute( (shape[0] * shape[2], shape[0] * shape[2]), lambda i, j: te.sum( A[i // shape[2], k, i % shape[2]].astype("float32") * B[j // shape[2], k, j % shape[2]].astype("float32"), axis=[k], ), name="Compute_MatMul", ) return A, B, C def schedule_matmul(A, B, C, local=False): s = te.create_schedule(C.op) At = s.cache_read(A, "global.texture", [C]) Bt = s.cache_read(B, "global.texture", [C]) if local: Al = s.cache_read(At, "local", [C]) Bl = s.cache_read(Bt, "local", [C]) Cl = s.cache_write(C, "local") bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") def copy_to_texture(stage): _io, _k, _ii = s[stage].op.axis s[stage].vectorize(_ii) s[stage].bind(_io, bx) s[stage].bind(_k, tx) copy_to_texture(At) copy_to_texture(Bt) # copy to global stage _i, _j = s[C].op.axis xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) s[C].unroll(xi) s[C].vectorize(yi) s[C].bind(xo, te.thread_axis("blockIdx.x")) s[C].bind(yo, te.thread_axis("threadIdx.x")) # the compute stage s[Cl].compute_at(s[C], yo) (_k,) = Cl.op.reduce_axis _x, _y = s[Cl].op.axis s[Cl].reorder(_k, _x, _y) s[Cl].unroll(_x) s[Cl].vectorize(_y) if local: s[Al].compute_at(s[Cl], _k) s[Al].vectorize(s[Al].op.axis[-1]) s[Bl].compute_at(s[Cl], _k) s[Bl].vectorize(s[Bl].op.axis[-1]) return s def compute_matmul_inner(shape): A = te.placeholder(shape, name="A", dtype="float32") B = te.placeholder(shape, name="B", dtype="float32") k = te.reduce_axis((0, shape[1] * shape[2]), name="k") # (M, K) x (N, K) # (32, 256) x (32, 256) # (32, 64, 4) x (32, 64, 4) C = te.compute( (shape[0], shape[0]), lambda i, j: te.sum( A[i, k // shape[2], k % shape[2]].astype("float32") * B[j, k // shape[2], k % shape[2]].astype("float32"), axis=[k], ), name="Compute_MatMul", ) return A, B, C def schedule_matmul_inner(A, B, C, local=False): s = te.create_schedule(C.op) At = s.cache_read(A, "global.texture", [C]) Bt = s.cache_read(B, "global.texture", [C]) if local: Al = s.cache_read(At, "local", [C]) Bl = s.cache_read(Bt, "local", [C]) Cl = s.cache_write(C, "local") bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") def copy_to_texture(stage): _i, _ko, _ki = s[stage].op.axis s[stage].vectorize(_ki) s[stage].bind(_i, bx) s[stage].bind(_ko, tx) copy_to_texture(At) copy_to_texture(Bt) # copy to global stage _i, _j = s[C].op.axis xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) s[C].unroll(xi) s[C].vectorize(yi) s[C].bind(xo, te.thread_axis("blockIdx.x")) s[C].bind(yo, te.thread_axis("threadIdx.x")) # the compute stage s[Cl].compute_at(s[C], yo) (_k,) = Cl.op.reduce_axis _x, _y = s[Cl].op.axis s[Cl].reorder(_x, _y, _k) s[Cl].unroll(_x) # TODO(csullivan): consider whether the below error is worth resolving # s[Cl].vectorize(_y) # error if local: s[Al].compute_at(s[Cl], _x) s[Al].vectorize(s[Al].op.axis[-1]) s[Bl].compute_at(s[Cl], _x) s[Bl].vectorize(s[Bl].op.axis[-1]) return s def compute_matmul_vector_accumulator(shapeA, shapeB): # A x B # (K/4, M, K%4) x (K, N/4, N%4) = (M, N) # (32, 64, 4) x (128, 16, 4) = (64, 64) A = te.placeholder(shapeA, name="A", dtype="float32") B = te.placeholder(shapeB, name="B", dtype="float32") k = te.reduce_axis((0, shapeB[0]), name="k") C = te.compute( (shapeA[1], shapeB[1] * shapeB[2]), lambda i, j: te.sum( A[k // shapeA[-1], i, k % shapeA[-1]].astype("float32") * B[k, j // shapeB[-1], j % shapeB[-1]].astype("float32"), axis=[k], ), name="Compute_MatMul", ) return A, B, C def schedule_matmul_vector_accumulator(A, B, C, local=False): s = te.create_schedule(C.op) At = s.cache_read(A, "global.texture", [C]) Bt = s.cache_read(B, "global.texture", [C]) if local: Al = s.cache_read(At, "local", [C]) Bl = s.cache_read(Bt, "local", [C]) Cl = s.cache_write(C, "local") def copy_to_texture(stage): _y, _x, _v = s[stage].op.axis # TODO(csullivan): removing this vectorize results in numerical errors, autovectorize s[stage].vectorize(_v) s[stage].bind(_y, te.thread_axis("blockIdx.x")) s[stage].bind(_x, te.thread_axis("threadIdx.x")) copy_to_texture(At) copy_to_texture(Bt) # copy to global stage _i, _j = s[C].op.axis xo, yo, xi, yi = s[C].tile(_i, _j, 4, 4) s[C].unroll(xi) s[C].vectorize(yi) s[C].bind(xo, te.thread_axis("blockIdx.x")) s[C].bind(yo, te.thread_axis("threadIdx.x")) # the compute stage s[Cl].compute_at(s[C], yo) (_k,) = Cl.op.reduce_axis _a, _b = s[Cl].op.axis _ko, _ki = s[Cl].split(_k, factor=4) s[Cl].reorder(_ko, _a, _ki, _b) s[Cl].unroll(_ki) s[Cl].unroll(_a) s[Cl].vectorize(_b) if local: s[Al].compute_at(s[Cl], _a) _aa, _ka, _ba = s[Al].op.axis # TODO(csullivan)[BEFORE PR]: removing this vectorize command causes a crash. This needs to be autovectorized. s[Al].vectorize(_ba) s[Bl].compute_at(s[Cl], _ko) _ab, _kb, _bb = s[Bl].op.axis s[Bl].vectorize(_bb) s[Bl].unroll(_ab) return s def compute_conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape): # conv2d( [N, C, H, W, c] , [1, 1, C, K, k] data = te.placeholder(input_shape, name="data", dtype="float32") filt = te.placeholder(filter_shape, name="filter", dtype="float32") c = te.reduce_axis((0, input_shape[1]), name="C") c4 = te.reduce_axis((0, input_shape[-1]), name="c4") kh = te.reduce_axis((0, filter_shape[0]), name="kh") kw = te.reduce_axis((0, filter_shape[1]), name="kw") conv = te.compute( (input_shape[0], filter_shape[-2], input_shape[2], input_shape[3], filter_shape[-1]), lambda n, ko, i, j, ki: te.sum( data[n, c, i, j, c4].astype("float32") * filt[kh, kw, c * input_shape[-1] + c4, ko, ki].astype("float32"), axis=[kh, kw, c, c4], ), # name="Compute_conv2d_1x1_NCHWc_RSCKk", name="conv2d_1x1", ) return data, filt, conv def schedule_conv2d_1x1_NCHWc_RSCKk(data, filt, conv): # inputs: (1, 128//4, 56, 56, 4), (1, 1, 128, 128//4, 4) # outputs: s = te.create_schedule(conv.op) A, B, C = data, filt, conv At = s.cache_read(A, "global.texture", [C]) Bt = s.cache_read(B, "global.texture", [C]) Al = s.cache_read(At, "local", [C]) Bl = s.cache_read(Bt, "local", [C]) Cl = s.cache_write(C, "local") def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) block, thread = s[stage].split(fused, factor=32) s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) copy_to_texture(At) copy_to_texture(Bt) _n, _ko, _h, _w, _ki = s[C].op.axis s[C].vectorize(_ki) s[C].bind(_n, te.thread_axis("blockIdx.x")) s[C].bind(_ko, te.thread_axis("threadIdx.x")) s[Cl].compute_at(s[C], _w) _nl, _kol, _hl, _wl, _kil = s[Cl].op.axis _khl, _kwl, _cl, _cl4 = s[Cl].op.reduce_axis _clo, _cli = s[Cl].split(_cl, factor=4) s[Cl].reorder(_clo, _cli, _cl4, _kil) s[Cl].unroll(_cli) s[Cl].unroll(_cl4) s[Cl].vectorize(_kil) s[Al].compute_at(s[Cl], _cli) s[Al].vectorize(s[Al].op.axis[-1]) s[Bl].compute_at(s[Cl], _kwl) s[Bl].vectorize(s[Bl].op.axis[-1]) return s def compute_conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape): # input_shape = [W, C, H, N, c] -> [W, C, H*N, c] # filter_shape = [C, R, S, K, k] -> [C, R*S*K, k] # output_shape: [WK, HN, k] -> [W, K, H, N, k] data = te.placeholder(input_shape, name="data", dtype="float32") filt = te.placeholder(filter_shape, name="filter", dtype="float32") packed_data = te.compute( (input_shape[0], input_shape[1], input_shape[2] * input_shape[3], input_shape[4]), lambda i, j, k, l: data[i, j, k // input_shape[3], k % input_shape[3], l], name="packed_data", ) # Logical transformation of Nd -> 3d tensor # CRSKk -> C|RSK|k # r = rsk // SK # sk = rsk % SK # s = sk // K == (rsk % SK) // K == (rsk // K) % S # k = sk % K == (rsk % SK) % K == rsk % K packed_filter = te.compute( (filter_shape[0], filter_shape[1] * filter_shape[2] * filter_shape[3], filter_shape[4]), lambda i, j, k: filt[ i, j // (filter_shape[3] * filter_shape[2]), (j // filter_shape[3]) % filter_shape[2], j % filter_shape[3], k, ], name="packed_filter", ) c = te.reduce_axis((0, input_shape[1]), name="C") c4 = te.reduce_axis((0, input_shape[-1]), name="c4") r = te.reduce_axis((0, filter_shape[1]), name="r") s = te.reduce_axis((0, filter_shape[2]), name="s") conv = te.compute( (input_shape[0], filter_shape[3], input_shape[2], input_shape[3], filter_shape[4]), lambda w, ko, h, n, ki: te.sum( packed_data[w, c, h * input_shape[3] + n, c4].astype("float32") * packed_filter[ c * input_shape[-1] + c4, ((r * filter_shape[2]) + s) * filter_shape[3] + ko, ki ].astype("float32"), axis=[r, s, c, c4], ), name="conv2d_1x1", ) return data, filt, packed_data, packed_filter, conv def schedule_conv2d_1x1_WCHNc_CRSKk(data, filt, packed_data, packed_filter, conv): # data: [W, C, H*N, c] # filter: [C, R*S*K, k] # output: [W, K, H, N, k] # conv2d( [N, C, H, W, c] , [1, 1, C, K, k] # inputs: (1, 128//4, 56, 56, 4), (1, 1, 128, 128//4, 4) # data: (56, 128//4, 56*1, 4) = (56, 32, 56, 4) # filt: (128, 1*1*128//4, 4) = (128, 32, 4) # conv: (56, 32, 56, 1, 4) s = te.create_schedule(conv.op) cfg = autotvm.get_config() s[packed_data].compute_inline() s[packed_filter].compute_inline() A, B, C = packed_data, packed_filter, conv At = s.cache_read(A, "global.texture", [C]) Bt = s.cache_read(B, "global.texture", [C]) Al = s.cache_read(At, "local", [C]) Bl = s.cache_read(Bt, "local", [C]) Cl = s.cache_write(C, "local") def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) block, thread = s[stage].split(fused, factor=32) s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) copy_to_texture(At) copy_to_texture(Bt) _w, _ko, _h, _n, _ki = s[C].op.axis kernel_scope, _n = s[C].split(_n, nparts=1) cfg.define_split("tile_f", _ko, num_outputs=4) cfg.define_split("tile_w", _w, num_outputs=4) cfg.define_split("tile_h", _h, num_outputs=4) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) bk, vk, tk, ki = cfg["tile_f"].apply(s, C, _ko) bw, vw, tw, wi = cfg["tile_w"].apply(s, C, _w) bh, vh, th, hi = cfg["tile_h"].apply(s, C, _h) s[C].reorder(bh, _n, vh, th, hi) bhn = s[C].fuse(bh, _n) s[C].bind(bk, te.thread_axis("blockIdx.z")) s[C].bind(bhn, te.thread_axis("blockIdx.y")) s[C].bind(bw, te.thread_axis("blockIdx.x")) s[C].bind(vk, te.thread_axis("vthread")) s[C].bind(vh, te.thread_axis("vthread")) s[C].bind(vw, te.thread_axis("vthread")) s[C].bind(tk, te.thread_axis("threadIdx.z")) s[C].bind(th, te.thread_axis("threadIdx.y")) s[C].bind(tw, te.thread_axis("threadIdx.x")) s[C].reorder(bw, bk, bhn, vw, vk, vh, tw, tk, th, ki, hi, wi, _ki) s[C].vectorize(_ki) # TODO(csullivan): Try uneven workgroup split # _wo, _wi = s[C].split(_w, factor=4) # #_hno, _hni = s[C].split(_hn, factor=8) # #s[C].reorder(_wo, _wi, _ko, _hno, _hni, _ki) # s[C].reorder(_wo, _ko, _hn, _ki, _wi) # s[C].unroll(_wi) # # mace: # # const int out_ch_blk = get_global_id(0); # # const int out_w_blk = get_global_id(1); # # const int out_hb = get_global_id(2); # bx = te.thread_axis("blockIdx.x") # by = te.thread_axis("blockIdx.y") # bz = te.thread_axis("blockIdx.z") # s[C].bind(_ko, bx) # s[C].bind(_wo, by) # s[C].bind(_hn, bz) # s[Cl].compute_at(s[C], _hn) s[Cl].compute_at(s[C], th) _wl, _kol, _hl, _nl, _kil = s[Cl].op.axis _khl, _kwl, _cl, _cl4 = s[Cl].op.reduce_axis cfg.define_split("tile_c", _cl, num_outputs=2) cfg.define_split("tile_kh", _khl, num_outputs=2) cfg.define_split("tile_kw", _kwl, num_outputs=2) _clo, _cli = cfg["tile_c"].apply(s, Cl, _cl) _khlo, _khli = cfg["tile_kh"].apply(s, Cl, _khl) _kwlo, _kwli = cfg["tile_kw"].apply(s, Cl, _kwl) # s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x) s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli, _kol, _hl, _nl, _kil, _wl) # s[Cl].reorder(_clo, _khlo, _kwlo, _cli, _cl4, _khli, _kwli) # s[Cl].reorder(_cl, _cl4, _kil, _wl) s[Cl].unroll(_cl4) s[Cl].unroll(_wl) s[Cl].vectorize(_kil) _wla, _cla, _hnla, _cl4a = s[Al].op.axis s[Al].compute_at(s[Cl], _cli) s[Al].vectorize(_cl4a) s[Al].unroll(_wla) _clb, _rskolb, _kilb = s[Bl].op.axis s[Bl].compute_at(s[Cl], _cli) s[Bl].vectorize(_kilb) s[Bl].unroll(_clb) s[C].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) WO, K, HO, N, K4 = get_const_tuple(C.shape) RSC, _, _ = get_const_tuple(B.shape) cfg.add_flop(2 * N * K * K4 * HO * WO * RSC) return s def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None): """Convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 assert isinstance(dilation, int) or len(dilation) == 2 if isinstance(stride, int): stride_h = stride_w = stride else: stride_h, stride_w = stride if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w) ) out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) # compute graph pad_before = [0, 0, pad_top, pad_left, 0] pad_after = [0, 0, pad_down, pad_right, 0] temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") rcc = te.reduce_axis((0, in_channel_chunk), name="rc") rcb = te.reduce_axis((0, in_channel_block), name="rc") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") # NCHWc x KCRSk # texture: NCH|W|c # texture: K|CRS|k # c = crs//RS # rs = crs % RS # r = rs // W == (crs // S) % R # s = rs % W == crs % S Filter = te.compute( (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), lambda ffc, crs, ffb: Filter[ ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb ], name="packed_filter", ) return te.compute( (batch, num_filter_chunk, out_height, out_width, num_filter_block), lambda nn, ffc, yy, xx, ffb: te.sum( temp[ nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb ].astype(out_dtype) * Filter[ ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb ].astype(out_dtype), axis=[rcc, rcb, ry, rx], ), tag="conv2d_nchwc_kcrsk_texture", ) def schedule_conv2d_NCHWc_KCRSk(cfg, s, conv): """schedule optimized for batch size = 1""" ##### space definition begin ##### n, fc, y, x, fb = s[conv].op.axis rcc, rcb, ry, rx = s[conv].op.reduce_axis cfg.define_split("tile_fc", fc, num_outputs=4) cfg.define_split("tile_y", y, num_outputs=4) cfg.define_split("tile_x", x, num_outputs=4) cfg.define_split("tile_rcc", rcc, num_outputs=2) cfg.define_split("tile_ry", ry, num_outputs=2) cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) pad_data, flattened_kernel = s[conv].op.input_tensors kernel = s[flattened_kernel].op.input_tensors[0] s[flattened_kernel].compute_inline() s[pad_data].compute_inline() if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() kernel = flattened_kernel if conv.op in s.outputs: output = conv OL = s.cache_write(conv, "local") else: output = s.outputs[0].output(0) s[conv].set_scope("local") OL = conv # create cache stage AT = s.cache_read(pad_data, "global.texture", [OL]) WT = s.cache_read(kernel, "global.texture", [OL]) def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) block, thread = s[stage].split(fused, factor=32) s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) copy_to_texture(AT) copy_to_texture(WT) AA = s.cache_read(AT, "shared", [OL]) WW = s.cache_read(WT, "shared", [OL]) # tile and bind spatial axes n, fc, y, x, fb = s[output].op.axis kernel_scope, n = s[output].split(n, nparts=1) bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) bf = s[output].fuse(n, bf) s[output].bind(bf, te.thread_axis("blockIdx.z")) s[output].bind(by, te.thread_axis("blockIdx.y")) s[output].bind(bx, te.thread_axis("blockIdx.x")) s[output].bind(vf, te.thread_axis("vthread")) s[output].bind(vy, te.thread_axis("vthread")) s[output].bind(vx, te.thread_axis("vthread")) s[output].bind(tf, te.thread_axis("threadIdx.z")) s[output].bind(ty, te.thread_axis("threadIdx.y")) s[output].bind(tx, te.thread_axis("threadIdx.x")) s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) s[output].vectorize(fb) s[OL].compute_at(s[output], tx) # tile reduction axes n, fc, y, x, fb = s[OL].op.axis rcc, rcb, ry, rx = s[OL].op.reduce_axis rco, rci = cfg["tile_rcc"].apply(s, OL, rcc) ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) # TODO(csullivan): check position of rcb s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) s[OL].vectorize(fb) s[OL].unroll(rcb) s[AA].compute_at(s[OL], rxo) s[WW].compute_at(s[OL], rxo) # cooperative fetching for load in [AA, WW]: if load == WW: n, fyx, v = s[load].op.axis fused = s[load].fuse(n, fyx) else: n, f, y, x, v = s[load].op.axis fused = s[load].fuse(n, f, y, x) tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) s[load].bind(tz, te.thread_axis("threadIdx.z")) s[load].bind(ty, te.thread_axis("threadIdx.y")) s[load].bind(tx, te.thread_axis("threadIdx.x")) s[load].vectorize(v) # unroll s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) N, OCC, OH, OW, OCB = get_const_tuple(output.shape) _, ICKHKW, _ = get_const_tuple(kernel.shape) if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) def compute_conv2d_NCHWc_KCRSk_acc32(Input, Filter, stride, padding, dilation, out_dtype=None): """Convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 assert isinstance(dilation, int) or len(dilation) == 2 if isinstance(stride, int): stride_h = stride_w = stride else: stride_h, stride_w = stride if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w) ) out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) # compute graph pad_before = [0, 0, pad_top, pad_left, 0] pad_after = [0, 0, pad_down, pad_right, 0] temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") rcc = te.reduce_axis((0, in_channel_chunk), name="rc") rcb = te.reduce_axis((0, in_channel_block), name="rc") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") # NCHWc x KCRSk # texture: NCH|W|c # texture: K|CRS|k # c = crs//RS # rs = crs % RS # r = rs // W == (crs // S) % R # s = rs % W == crs % S Filter = te.compute( (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), lambda ffc, crs, ffb: Filter[ ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb ], name="packed_filter", ) conv = te.compute( (batch, num_filter_chunk, out_height, out_width, num_filter_block), lambda nn, ffc, yy, xx, ffb: te.sum( ( temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] * Filter[ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb] ).astype(out_dtype), axis=[rcc, rcb, ry, rx], ), tag="conv2d_nchwc_kcrsk_texture", ) output = te.compute(conv.shape, lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype("float32")) return output def schedule_conv2d_NCHWc_KCRSk_acc32(cfg, s, output): """schedule optimized for batch size = 1""" conv = output.op.input_tensors[0] ##### space definition begin ##### n, fc, y, x, fb = s[conv].op.axis rcc, rcb, ry, rx = s[conv].op.reduce_axis cfg.define_split("tile_fc", fc, num_outputs=4) cfg.define_split("tile_y", y, num_outputs=4) cfg.define_split("tile_x", x, num_outputs=4) cfg.define_split("tile_rcc", rcc, num_outputs=2) cfg.define_split("tile_ry", ry, num_outputs=2) cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) pad_data, flattened_kernel = s[conv].op.input_tensors kernel = s[flattened_kernel].op.input_tensors[0] s[flattened_kernel].compute_inline() s[pad_data].compute_inline() if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() kernel = flattened_kernel if conv.op in s.outputs: output = conv OL = s.cache_write(conv, "local") else: output = s.outputs[0].output(0) s[conv].set_scope("local") OL = conv # create cache stage AT = s.cache_read(pad_data, "global.texture", [OL]) WT = s.cache_read(kernel, "global.texture", [OL]) def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) block, thread = s[stage].split(fused, factor=32) s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) copy_to_texture(AT) copy_to_texture(WT) AA = s.cache_read(AT, "shared", [OL]) WW = s.cache_read(WT, "shared", [OL]) # tile and bind spatial axes n, fc, y, x, fb = s[output].op.axis kernel_scope, n = s[output].split(n, nparts=1) bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) bf = s[output].fuse(n, bf) s[output].bind(bf, te.thread_axis("blockIdx.z")) s[output].bind(by, te.thread_axis("blockIdx.y")) s[output].bind(bx, te.thread_axis("blockIdx.x")) s[output].bind(vf, te.thread_axis("vthread")) s[output].bind(vy, te.thread_axis("vthread")) s[output].bind(vx, te.thread_axis("vthread")) s[output].bind(tf, te.thread_axis("threadIdx.z")) s[output].bind(ty, te.thread_axis("threadIdx.y")) s[output].bind(tx, te.thread_axis("threadIdx.x")) s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) s[output].vectorize(fb) s[OL].compute_at(s[output], tx) # tile reduction axes n, fc, y, x, fb = s[OL].op.axis rcc, rcb, ry, rx = s[OL].op.reduce_axis rco, rci = cfg["tile_rcc"].apply(s, OL, rcc) ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) # TODO(csullivan): check position of rcb s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) s[OL].vectorize(fb) s[OL].unroll(rcb) s[AA].compute_at(s[OL], rxo) s[WW].compute_at(s[OL], rxo) # cooperative fetching for load in [AA, WW]: if load == WW: n, fyx, v = s[load].op.axis fused = s[load].fuse(n, fyx) else: n, f, y, x, v = s[load].op.axis fused = s[load].fuse(n, f, y, x) tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) s[load].bind(tz, te.thread_axis("threadIdx.z")) s[load].bind(ty, te.thread_axis("threadIdx.y")) s[load].bind(tx, te.thread_axis("threadIdx.x")) s[load].vectorize(v) # unroll s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) N, OCC, OH, OW, OCB = get_const_tuple(output.shape) _, ICKHKW, _ = get_const_tuple(kernel.shape) if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) def compute_depthwise_conv2d_NCHWc_KCRSk_acc32( Input, Filter, stride, padding, dilation, out_dtype=None ): """Depthwise convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 assert isinstance(dilation, int) or len(dilation) == 2 if isinstance(stride, int): stride_h = stride_w = stride else: stride_h, stride_w = stride if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation batch, channel_chunk, in_height, in_width, channel_block = Input.shape _, channel_multiplier, kernel_h, kernel_w, _ = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( padding, (dilated_kernel_h, dilated_kernel_w) ) out_channel_chunk = simplify(channel_chunk * channel_multiplier) out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) # compute graph pad_before = [0, 0, pad_top, pad_left, 0] pad_after = [0, 0, pad_down, pad_right, 0] temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] # NCHWc x CMRS # texture: NCH|W|c # texture: C|MRS|c # output: N # m = mrs//RS # rs = mrs % RS # r = rs // W == (mrs // S) % R # s = rs % W == mrs % S Filter = te.compute( (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), lambda ffc, mrs, ffb: Filter[ ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb ], name="packed_filter", ) conv = te.compute( (batch, out_channel_chunk, out_height, out_width, channel_block), lambda nn, ffc, yy, xx, ffb: te.sum( ( temp[ nn, ffc // channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb, ] * Filter[ ffc // channel_multiplier, ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, ffb, ] ).astype(out_dtype), axis=[ry, rx], ), tag="depthwise_conv2d_nchwc_kcrsk_texture", ) return te.compute( conv.shape, lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype("float32") ) def schedule_depthwise_conv2d_NCHWc_KCRSk_acc32(cfg, s, output): """schedule optimized for batch size = 1""" conv = output.op.input_tensors[0] ##### space definition begin ##### n, fc, y, x, fb = s[conv].op.axis ry, rx = s[conv].op.reduce_axis cfg.define_split("tile_fc", fc, num_outputs=4) cfg.define_split("tile_y", y, num_outputs=4) cfg.define_split("tile_x", x, num_outputs=4) cfg.define_split("tile_ry", ry, num_outputs=2) cfg.define_split("tile_rx", rx, num_outputs=2) cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) pad_data, flattened_kernel = s[conv].op.input_tensors kernel = s[flattened_kernel].op.input_tensors[0] s[flattened_kernel].compute_inline() s[pad_data].compute_inline() if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() kernel = flattened_kernel if conv.op in s.outputs: output = conv OL = s.cache_write(conv, "local") else: output = s.outputs[0].output(0) s[conv].set_scope("local") OL = conv # create cache stage AT = s.cache_read(pad_data, "global.texture", [OL]) WT = s.cache_read(kernel, "global.texture", [OL]) def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) block, thread = s[stage].split(fused, factor=32) s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) copy_to_texture(AT) copy_to_texture(WT) AA = s.cache_read(AT, "shared", [OL]) WW = s.cache_read(WT, "shared", [OL]) # tile and bind spatial axes n, fc, y, x, fb = s[output].op.axis kernel_scope, n = s[output].split(n, nparts=1) bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) bf = s[output].fuse(n, bf) s[output].bind(bf, te.thread_axis("blockIdx.z")) s[output].bind(by, te.thread_axis("blockIdx.y")) s[output].bind(bx, te.thread_axis("blockIdx.x")) s[output].bind(vf, te.thread_axis("vthread")) s[output].bind(vy, te.thread_axis("vthread")) s[output].bind(vx, te.thread_axis("vthread")) s[output].bind(tf, te.thread_axis("threadIdx.z")) s[output].bind(ty, te.thread_axis("threadIdx.y")) s[output].bind(tx, te.thread_axis("threadIdx.x")) s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) s[output].vectorize(fb) s[OL].compute_at(s[output], tx) # tile reduction axes n, fc, y, x, fb = s[OL].op.axis ry, rx = s[OL].op.reduce_axis ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) s[OL].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb) s[OL].vectorize(fb) # s[OL].unroll() s[AA].compute_at(s[OL], rxo) s[WW].compute_at(s[OL], rxo) # cooperative fetching for load in [AA, WW]: if load == WW: n, fyx, v = s[load].op.axis fused = s[load].fuse(n, fyx) else: n, f, y, x, v = s[load].op.axis fused = s[load].fuse(n, f, y, x) tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) s[load].bind(tz, te.thread_axis("threadIdx.z")) s[load].bind(ty, te.thread_axis("threadIdx.y")) s[load].bind(tx, te.thread_axis("threadIdx.x")) s[load].vectorize(v) # unroll s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) N, OCC, OH, OW, OCB = get_const_tuple(output.shape) ICC, MKHKW, ICB = get_const_tuple(kernel.shape) M = (OCC * OCB) // (ICC * ICB) KHKW = MKHKW // M if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW) def scheduler(compute, schedule, *args, **kwargs): placeholders = compute(*args) s = schedule(*placeholders, **kwargs) return s, placeholders def conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape): placeholders = compute_conv2d_1x1_NCHWc_RSCKk(input_shape, filter_shape) s = schedule_conv2d_1x1_NCHWc_RSCKk(*placeholders) return s, placeholders def conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape): placeholders = compute_conv2d_1x1_WCHNc_CRSKk(input_shape, filter_shape) s = schedule_conv2d_1x1_WCHNc_CRSKk(*placeholders) return s, (placeholders[0], placeholders[1], placeholders[-1]) def conv2d_NCHWc_KCRSk(input_shape, filter_shape): data = te.placeholder(input_shape, name="data", dtype="float32") filt = te.placeholder(filter_shape, name="filter", dtype="float32") conv = compute_conv2d_NCHWc_KCRSk(data, filt, [1, 1], [0, 0], [1, 1], "float32") cfg = autotvm.get_config() s = te.create_schedule([x.op for x in [conv]]) schedule_conv2d_NCHWc_KCRSk(cfg, s, conv) return s, (data, filt, conv) def conv2d_NCHWc_KCRSk_fp32_acc(input_shape, filter_shape): data = te.placeholder(input_shape, name="data", dtype="float32") filt = te.placeholder(filter_shape, name="filter", dtype="float32") output = compute_conv2d_NCHWc_KCRSk_acc32(data, filt, [1, 1], [0, 0], [1, 1], "float32") cfg = autotvm.get_config() s = te.create_schedule([x.op for x in [output]]) schedule_conv2d_NCHWc_KCRSk_acc32(cfg, s, output) return s, (data, filt, output) def depthwise_conv2d_NCHWc_KCRSk_acc32(input_shape, filter_shape): data = te.placeholder(input_shape, name="data", dtype="float32") filt = te.placeholder(filter_shape, name="filter", dtype="float32") output = compute_depthwise_conv2d_NCHWc_KCRSk_acc32( data, filt, [1, 1], [0, 0], [1, 1], "float32" ) cfg = autotvm.get_config() s = te.create_schedule([x.op for x in [output]]) schedule_depthwise_conv2d_NCHWc_KCRSk_acc32(cfg, s, output) return s, (data, filt, output) def ref_convolution(data, kernel, stride, pad): import mxnet as mx groups = 1 kernel_size = (kernel.shape[2], kernel.shape[3]) num_filter = kernel.shape[0] ref_res = mx.nd.Convolution( data=mx.nd.array(data), weight=mx.nd.array(kernel), bias=None, no_bias=True, kernel=kernel_size, stride=stride, pad=pad, num_filter=num_filter, num_group=groups, ) return ref_res.asnumpy() def ref_depthwise_convolution(data, kernel, stride, pad): import mxnet as mx groups = kernel.shape[0] kernel_size = (kernel.shape[2], kernel.shape[3]) num_filter = kernel.shape[0] multiplier = kernel.shape[1] ref_res = mx.nd.Convolution( data=mx.nd.array(data), weight=mx.nd.array(kernel), bias=None, no_bias=True, kernel=kernel_size, stride=stride, pad=pad, num_filter=num_filter, num_group=groups, ) return ref_res.asnumpy() def validate(workload, target, dev, input_shapes, *args, **kwargs): s, placeholders = workload(*input_shapes, *args, **kwargs) func = tvm.driver.build(s, [*placeholders], target=target, name="TestFunction") args_tvm = [] args_np = [] for var in placeholders[:-1]: var_np = np.random.uniform(size=[i.value for i in var.shape]).astype(var.dtype) args_np.append(var_np) args_tvm.append(tvm.nd.array(var_np, dev)) args_tvm.append( tvm.nd.array( np.zeros([i.value for i in placeholders[-1].shape], dtype=placeholders[-1].dtype), dev ) ) func(*args_tvm) if "plus_one" in workload.__name__: np_result = args_np[0] + 1.0 elif "matmul" in workload.__name__: if "inner" in workload.__name__: np_result = np.matmul( args_np[0].reshape(32, 256), args_np[1].reshape(32, 256).transpose(1, 0) ) elif "accum" in workload.__name__: np_result = np.matmul( args_np[0].transpose((1, 0, 2)).reshape(64, 128), args_np[1].reshape(128, 64) ) else: np_result = np.matmul( args_np[0].transpose((0, 2, 1)).reshape(128, 64), args_np[1].transpose(1, 0, 2).reshape(64, 128), ) elif "conv2d_1x1_NCHWc_RSCKk" in workload.__name__: vec_length = args_np[1].shape[-1] # nchwc -> nchw args_np[0] = ( args_np[0] .transpose((0, 1, 4, 2, 3)) .reshape( args_np[0].shape[0], args_np[0].shape[1] * args_np[0].shape[-1], args_np[0].shape[2], args_np[0].shape[3], ) ) # rsckk -> rsck -> kcrs args_np[1] = ( args_np[1] .reshape( args_np[1].shape[0], args_np[1].shape[1], args_np[1].shape[2], args_np[1].shape[3] * args_np[1].shape[4], ) .transpose((3, 2, 0, 1)) ) np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) # nkhw -> nkhwk np_result = np_result.reshape( np_result.shape[0], np_result.shape[1] // vec_length, vec_length, np_result.shape[2], np_result.shape[3], ).transpose(0, 1, 3, 4, 2) elif "conv2d_1x1_WCHNc_CRSKk" in workload.__name__: vec_length = args_np[1].shape[-1] # wchnc -> nchw args_np[0] = ( args_np[0] .transpose((3, 1, 4, 2, 0)) .reshape( args_np[0].shape[3], args_np[0].shape[1] * args_np[0].shape[-1], args_np[0].shape[2], args_np[0].shape[0], ) ) # crskk -> crsk -> kcrs args_np[1] = ( args_np[1] .reshape( args_np[1].shape[0], args_np[1].shape[1], args_np[1].shape[2], args_np[1].shape[3] * args_np[1].shape[4], ) .transpose((3, 0, 1, 2)) ) np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) # nkhw -> nkkhw -> wkhnk np_result = np_result.reshape( np_result.shape[0], np_result.shape[1] // vec_length, vec_length, np_result.shape[2], np_result.shape[3], ).transpose(4, 1, 3, 0, 2) elif "NCHW_KCRS" in workload.__name__: np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) elif "NCHWc_KCRSk" in workload.__name__: vec_length = args_np[1].shape[-1] # nchwc -> nchw args_np[0] = ( args_np[0] .transpose((0, 1, 4, 2, 3)) .reshape( args_np[0].shape[0], args_np[0].shape[1] * args_np[0].shape[-1], args_np[0].shape[2], args_np[0].shape[3], ) ) # kcrsk/cmrsc -> kcrs/cmrs args_np[1] = ( args_np[1] .transpose((0, 4, 1, 2, 3)) .reshape( args_np[1].shape[0] * args_np[1].shape[4], args_np[1].shape[1], args_np[1].shape[2], args_np[1].shape[3], ) ) if "depthwise" in workload.__name__: # np_result = testing.depthwise_conv2d_python_nchw(args_np[0], args_np[1], 1, "VALID") np_result = ref_depthwise_convolution(args_np[0], args_np[1], [], []) else: # np_result = testing.conv2d_nchw_python(args_np[0], args_np[1], 1, 0) np_result = ref_convolution(args_np[0], args_np[1], [], []) # nkhw -> nkhwk np_result = np_result.reshape( np_result.shape[0], np_result.shape[1] // vec_length, vec_length, np_result.shape[2], np_result.shape[3], ).transpose(0, 1, 3, 4, 2) np.testing.assert_allclose(args_tvm[-1].asnumpy(), np_result, rtol=1e-2, atol=1e-2) class BaseSingleShapeValidator: @tvm.testing.parametrize_targets("opencl") def test_unary(self, test_func, input_shape, target, dev): validate(test_func, target, dev, [input_shape]) class TestPlusOneRank3(BaseSingleShapeValidator): input_shape = tvm.testing.parameter((32, 32, 4)) def plus_one(input_shape): return scheduler(compute_plus_one_rank3, schedule_plus_one_rank3, input_shape) test_func = tvm.testing.parameter(plus_one) class TestPlusOneRank5(BaseSingleShapeValidator): input_shape = tvm.testing.parameter((32, 2, 4, 4, 4)) def plus_one(input_shape): return scheduler(compute_plus_one_rank5, schedule_plus_one_rank5, input_shape) test_func = tvm.testing.parameter(plus_one) class TestMatmul: input_shape = tvm.testing.parameter((32, 64, 4)) local = tvm.testing.parameter(False, True) def matmul(input_shape, local): return scheduler(compute_matmul, schedule_matmul, input_shape, local=local) def matmul_inner(input_shape, local): return scheduler(compute_matmul_inner, schedule_matmul_inner, input_shape, local=local) test_func = tvm.testing.parameter(matmul, matmul_inner) @tvm.testing.parametrize_targets("opencl") def test_matmul(self, test_func, input_shape, local, target, dev): validate(test_func, target, dev, [input_shape], local=local) class TestMatmulVectorAccumulator: shapeA = tvm.testing.parameter((32, 64, 4)) shapeB = tvm.testing.parameter((128, 16, 4)) local = tvm.testing.parameter(False, True) def matmul_vector_accumulator(shapeA, shapeB, local): return scheduler( compute_matmul_vector_accumulator, schedule_matmul_vector_accumulator, shapeA, shapeB, local=local, ) test_func = tvm.testing.parameter(matmul_vector_accumulator) @tvm.testing.parametrize_targets("opencl") def test_matmul_vec_acc(self, test_func, shapeA, shapeB, local, target, dev): validate(test_func, target, dev, [shapeA, shapeB], local=local) class BaseConv2DValidator: @tvm.testing.parametrize_targets("opencl") def test_conv2d(self, test_func, input_shapes, target, dev): validate(test_func, target, dev, input_shapes) class TestConv2dNCHWcRSCKk(BaseConv2DValidator): input_shapes = tvm.testing.parameter([(1, 32, 56, 56, 4), (1, 1, 128, 32, 4)]) test_func = tvm.testing.parameter(conv2d_1x1_NCHWc_RSCKk) class TestConv2dWCHNcCRSKk(BaseConv2DValidator): input_shapes = tvm.testing.parameter([(56, 32, 56, 1, 4), (128, 1, 1, 32, 4)]) test_func = tvm.testing.parameter(conv2d_1x1_WCHNc_CRSKk) class TestConv2dNCHWcKCRSk(BaseConv2DValidator): input_shapes = tvm.testing.parameter( [(1, 32, 56, 56, 4), (32, 128, 1, 1, 4)], [(1, 32, 112, 112, 4), (32, 128, 3, 3, 4)] ) test_func = tvm.testing.parameter(conv2d_NCHWc_KCRSk, conv2d_NCHWc_KCRSk_fp32_acc) class TestDepthwiseConv2dNCHWcKCRSk(BaseConv2DValidator): input_shapes = tvm.testing.parameter([(1, 24, 257, 257, 4), (24, 1, 3, 3, 4)]) test_func = tvm.testing.parameter(depthwise_conv2d_NCHWc_KCRSk_acc32) if __name__ == "__main__": sys.exit(pytest.main(sys.argv))