# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

# Skip the whole module when the Vela package is not importable; this must
# run before the tvm imports below.
pytest.importorskip("ethosu.vela")
import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te
from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute
import tvm.relay.backend.contrib.ethosu.op as ethosu_ops


def test_ethosu_conv2d():
    """Lower a single ethosu_conv2d Relay function to TE and check its interface.

    Verifies that the lowered result has exactly one output and four inputs,
    that the convolution compute recovered from the output is named
    "ethosu_conv2d", and that the set of input shapes matches the shapes of
    the Relay variables fed to the operator.
    """
    # Relay variables that become the free variables of the function.
    ifm = relay.var("ifm", shape=(1, 10, 20, 30), dtype="uint8")
    weight = relay.var("weight", shape=(40, 3, 3, 30), dtype="uint8")
    scale_bias = relay.var("scale_bias", shape=(40, 10), dtype="uint8")
    lut = relay.var("lut", shape=(), dtype="uint8")

    conv = ethosu_ops.ethosu_conv2d(
        ifm,
        weight,
        scale_bias,
        lut,
        ifm_scale=0.5,
        ifm_zero_point=10,
        weight_zero_point=12,
        ofm_scale=0.25,
        ofm_zero_point=14,
        ofm_channels=40,
        padding=(1, 1, 1, 1),
        kernel_shape=(3, 3),
        strides=(1, 1),
        dilation=(1, 1),
    )

    # Wrap the operator in a function, then run type inference on the module.
    func = relay.Function(relay.analysis.free_vars(conv), conv)
    module = relay.transform.InferType()(tvm.IRModule.from_expr(func))

    lowered = lower_to_te(module["main"])
    assert len(lowered.outputs) == 1
    assert len(lowered.inputs) == 4

    conv2d_compute = Convolution2DCompute.from_output(lowered.outputs[0])
    assert conv2d_compute.conv2d.name == "ethosu_conv2d"

    # Compare shapes as a set so the check does not depend on input order.
    input_shapes = {tuple(dim.value for dim in tensor.shape) for tensor in lowered.inputs}
    assert input_shapes == {(40, 10), (1, 10, 20, 30), (40, 3, 3, 30), ()}


if __name__ == "__main__":
    pytest.main([__file__])