# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities for testing external code generation"""

import os
import sys

import pytest

import tvm

# Imported explicitly: importing the `tvm` package alone does not guarantee
# that the `tvm.testing` / `tvm.contrib.graph_executor` submodules are loaded.
import tvm.testing
from tvm import relay, runtime
from tvm.contrib import graph_executor, utils

# NOTE(review): check_aot_executor_result() re-imports these names lazily from
# `aot.aot_test_utils` "to avoid breaking test with USE_MICRO=OFF"; this eager
# import may defeat that intent -- confirm whether it is still required.
from tests.python.relay.aot.aot_test_utils import AOTTestModel, compile_and_run


skip_windows = pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now")
skip_micro = pytest.mark.skipif(
    tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON",
    reason="MicroTVM support not enabled. Set USE_MICRO=ON in config.cmake to enable.",
)


def parametrize_external_codegen_checks(test):
    """Parametrize over the various check_result functions which are available"""
    return pytest.mark.parametrize(
        "check_result",
        [
            pytest.param(check_aot_executor_result, marks=[skip_windows, skip_micro]),
            pytest.param(check_graph_executor_result, marks=[skip_windows]),
            pytest.param(check_vm_result, marks=[skip_windows]),
        ],
    )(test)


def parametrize_external_json_codegen_checks(test):
    """Parametrize over the various check_result functions which are available for JSON"""
    return pytest.mark.parametrize(
        "check_result",
        [
            pytest.param(check_graph_executor_result, marks=[skip_windows]),
            pytest.param(check_vm_result, marks=[skip_windows]),
        ],
    )(test)


def update_lib(lib):
    """Export ``lib`` as a shared library in a temporary directory and reload it.

    The library is written as ``lib.so`` with the compile options ``-O2``,
    ``-std=c++14`` and an include path pointing at ``src/runtime/contrib``
    located four directory levels above this file.

    Parameters
    ----------
    lib :
        Object exposing ``export_library``.

    Returns
    -------
    The module returned by ``tvm.runtime.load_module`` for the exported file.
    """
    test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
    source_dir = os.path.join(test_dir, "..", "..", "..", "..")
    contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")

    tmp_path = utils.tempdir()
    lib_path = tmp_path.relpath("lib.so")
    lib.export_library(
        lib_path,
        fcompile=False,
        options=["-O2", "-std=c++14", "-I" + contrib_path],
    )
    return tvm.runtime.load_module(lib_path)


def check_vm_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()):
    """Compile ``mod`` for the Relay VM, run it and compare against ``result``.

    The executable is saved, its library is round-tripped through
    :func:`update_lib`, and the reloaded executable is run on ``device`` with
    ``map_inputs`` passed as keyword arguments.

    Parameters
    ----------
    mod :
        Module handed to ``relay.vm.compile``.
    map_inputs : dict
        Input name to value mapping, expanded as keyword arguments to ``vm.run``.
    out_shape :
        Not used by this checker; accepted so all ``check_*`` helpers share
        one signature.
    result :
        Expected output.
    tol : float
        Used as both ``rtol`` and ``atol`` for the comparison.
    target : str
        Compilation target.
    device :
        Device the virtual machine runs on.
    """
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        exe = relay.vm.compile(mod, target=target)
    code, lib = exe.save()
    lib = update_lib(lib)
    exe = runtime.vm.Executable.load_exec(code, lib)
    vm = runtime.vm.VirtualMachine(exe, device)
    out = vm.run(**map_inputs)
    tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol)


def check_graph_executor_result(
    mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()
):
    """Build ``mod`` for the graph executor, run it and compare against ``result``.

    The built library is round-tripped through :func:`update_lib` before the
    graph executor is created; output 0 is then compared with ``result``.

    Parameters
    ----------
    mod :
        Module handed to ``relay.build``.
    map_inputs : dict
        Input name to value mapping; each entry is passed to ``set_input``.
    out_shape :
        Shape of the buffer allocated to receive output 0.
    result :
        Expected output.
    tol : float
        Used as both ``rtol`` and ``atol`` for the comparison.
    target : str
        Compilation target.
    device :
        Device the executor runs on and the output buffer is allocated on.
    """
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        executor_factory = relay.build(mod, target=target)
    lib = update_lib(executor_factory.lib)
    rt_mod = graph_executor.create(executor_factory.graph_json, lib, device)

    for name, data in map_inputs.items():
        rt_mod.set_input(name, data)
    rt_mod.run()
    out = tvm.nd.empty(out_shape, device=device)
    out = rt_mod.get_output(0, out)

    tvm.testing.assert_allclose(out.numpy(), result, rtol=tol, atol=tol)


def check_aot_executor_result(
    mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu()
):
    """Compile and run ``mod`` through the AOT test harness.

    Builds an ``AOTTestModel`` from ``mod``, ``map_inputs`` and ``[result]``
    and hands it to ``compile_and_run`` with the default runner, the
    ``"packed"`` interface API and the unpacked API disabled.

    ``out_shape``, ``tol``, ``target`` and ``device`` are not used by this
    checker; they are accepted so all ``check_*`` helpers share one signature.
    """
    # Late import to avoid breaking test with USE_MICRO=OFF.
    from aot.aot_test_utils import AOTTestModel, AOT_DEFAULT_RUNNER, compile_and_run

    interface_api = "packed"
    use_unpacked_api = False
    test_runner = AOT_DEFAULT_RUNNER
    compile_and_run(
        AOTTestModel(module=mod, inputs=map_inputs, outputs=[result]),
        test_runner,
        interface_api,
        use_unpacked_api,
    )


def set_external_func_attr(func, compiler, ext_symbol):
    """Attach the external-codegen attributes to ``func`` and return the result.

    Sets ``Primitive`` to an ``int32`` ``IntImm`` of 1, ``Compiler`` to
    ``compiler`` and ``global_symbol`` to ``ext_symbol`` via successive
    ``with_attr`` calls.
    """
    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    func = func.with_attr("Compiler", compiler)
    func = func.with_attr("global_symbol", ext_symbol)
    return func