# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test Meta Schedule Runner"""

import itertools
import sys
import time
from typing import Any, List

import numpy as np
import pytest

import tvm
from tvm._ffi import register_func
from tvm.meta_schedule.arg_info import TensorInfo
from tvm.meta_schedule.builder import BuilderInput, LocalBuilder
from tvm.meta_schedule.runner import (
    EvaluatorConfig,
    LocalRunner,
    PyRunner,
    RPCConfig,
    RPCRunner,
    RunnerFuture,
    RunnerInput,
)
from tvm.meta_schedule.runner.rpc_runner import (
    default_alloc_argument as rpc_default_alloc_argument,
    T_ARG_INFO_JSON_OBJ_LIST,
    T_ARGUMENT_LIST,
)
from tvm.meta_schedule.runner.local_runner import (
    default_alloc_argument as local_default_alloc_argument,
)
from tvm.meta_schedule.testing import LocalRPC
from tvm.meta_schedule.utils import get_global_func_with_default_on_worker
from tvm.rpc import RPCSession
from tvm.runtime import Device, Module
from tvm.script import tir as T
from tvm.target import Target
import tvm.testing
from tvm.tir import FloatImm

MATMUL_N = 16
MATMUL_M = 32

# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking


@tvm.script.ir_module
class MatmulModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=no-self-argument
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (16, 16), "float32")
        B = T.match_buffer(b, (16, 16), "float32")
        C = T.match_buffer(c, (16, 16), "float32")
        for i, j, k in T.grid(16, 16, 16):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class MatmulReluModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, d: T.handle) -> None:  # pylint: disable=no-self-argument
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (16, 16), "float32")
        B = T.match_buffer(b, (16, 16), "float32")
        D = T.match_buffer(d, (16, 16), "float32")
        C = T.alloc_buffer((16, 16), "float32")
        for i, j, k in T.grid(16, 16, 16):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
        for i, j in T.grid(16, 16):
            with T.block("relu"):
                vi, vj = T.axis.remap("SS", [i, j])
                D[vi, vj] = T.max(C[vi, vj], 0.0)


@tvm.script.ir_module
class BatchMatmulModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=no-self-argument
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [16, 32, 32])
        B = T.match_buffer(b, [16, 32, 32])
        C = T.match_buffer(c, [16, 32, 32])
        for n, i, j, k in T.grid(16, 32, 32, 32):
            with T.block("update"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                with T.init():
                    C[vn, vi, vj] = 0.0
                C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]


@tvm.script.ir_module
class AddModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=no-self-argument
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [32], "float32")
        B = T.match_buffer(b, [32], "float32")
        C = T.match_buffer(c, [32], "float32")
        for i in range(32):
            with T.block("add"):
                vi = T.axis.S(32, i)
                C[vi] = A[vi] + B[vi]


# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring


def _clean_build(artifact_path: str) -> None:
    f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None)
    if f_clean_build is not None:
        f_clean_build(artifact_path)
    else:
        raise RuntimeError("Unable to find remove_build_dir function.")


def test_meta_schedule_rpc_single_run():
    """Test meta schedule rpc runner for a single run"""
    # Build the module
    mod = MatmulModule
    builder = LocalBuilder()
    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None
    runner_input = RunnerInput(
        builder_result.artifact_path,
        "llvm",
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
    )
    with LocalRPC() as rpc:
        rpc_config = RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        evaluator_config = EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )
        runner = RPCRunner(rpc_config, evaluator_config)
        # Run the module
        (runner_future,) = runner.run([runner_input])
        runner_result = runner_future.result()
    assert runner_result.error_msg is None
    for result in runner_result.run_secs:
        if isinstance(result, FloatImm):
            result = result.value
        assert isinstance(result, float)
        assert result >= 0.0
    _clean_build(builder_result.artifact_path)


def test_meta_schedule_local_single_run():
    """Test meta schedule local runner for a single run"""
    # Build the module
    mod = MatmulModule
    builder = LocalBuilder()
    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None
    runner_input = RunnerInput(
        builder_result.artifact_path,
        "llvm",
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
    )
    evaluator_config = EvaluatorConfig(
        number=1,
        repeat=1,
        min_repeat_ms=0,
        enable_cpu_cache_flush=False,
    )
    runner = LocalRunner(timeout_sec=100, evaluator_config=evaluator_config)
    # Run the module
    (runner_future,) = runner.run([runner_input])
    runner_result = runner_future.result()
    assert runner_result.error_msg is None
    for result in runner_result.run_secs:
        if isinstance(result, FloatImm):
            result = result.value
        assert isinstance(result, float)
        assert result >= 0.0
    _clean_build(builder_result.artifact_path)


def test_meta_schedule_rpc_multiple_runs():
    """Test meta schedule rpc runner for multiple runs"""
    # Build the module
    mods = [
        MatmulModule,
        MatmulReluModule,
        BatchMatmulModule,
    ]
    builder = LocalBuilder()
    builder_inputs = [BuilderInput(mod, Target("llvm")) for mod in mods]
    builder_results = builder.build(builder_inputs)
    for builder_result in builder_results:
        assert builder_result.artifact_path is not None
        assert builder_result.error_msg is None
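    # Argument shapes for MatmulModule, MatmulReluModule and BatchMatmulModule,
    # in the same order as the `mods` list above.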
    args_infos = [
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
        [
            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
        ],
    ]
    runner_inputs = [
        RunnerInput(builder_results[i].artifact_path, "llvm", args_infos[i])
        for i in range(len(mods))
    ]
    with LocalRPC() as rpc:
        rpc_config = RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        evaluator_config = EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )
        runner = RPCRunner(rpc_config, evaluator_config)
        # Run the module
        runner_futures = runner.run(runner_inputs)
        runner_results = [runner_future.result() for runner_future in runner_futures]
    for runner_result in runner_results:
        assert runner_result.error_msg is None
        for result in runner_result.run_secs:
            if isinstance(result, FloatImm):
                result = result.value
            assert isinstance(result, float)
            assert result >= 0.0
    for builder_result in builder_results:
        _clean_build(builder_result.artifact_path)


def test_meta_schedule_local_multiple_runs():
    """Test meta schedule local runner for multiple runs"""
    # Build the module
    mods = [
        MatmulModule,
        MatmulReluModule,
        BatchMatmulModule,
    ]
    builder = LocalBuilder()
    builder_inputs = [BuilderInput(mod, Target("llvm")) for mod in mods]
    builder_results = builder.build(builder_inputs)
    for builder_result in builder_results:
        assert builder_result.artifact_path is not None
        assert builder_result.error_msg is None
    args_infos = [
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
        [
            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
            TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
        ],
    ]
    runner_inputs = [
        RunnerInput(builder_results[i].artifact_path, "llvm", args_infos[i])
        for i in range(len(mods))
    ]
    evaluator_config = EvaluatorConfig(
        number=1,
        repeat=1,
        min_repeat_ms=0,
        enable_cpu_cache_flush=False,
    )
    runner = LocalRunner(timeout_sec=100, evaluator_config=evaluator_config)
    # Run the module
    runner_futures = runner.run(runner_inputs)
    runner_results = [runner_future.result() for runner_future in runner_futures]
    for runner_result in runner_results:
        assert runner_result.error_msg is None
        for result in runner_result.run_secs:
            if isinstance(result, FloatImm):
                result = result.value
            assert isinstance(result, float)
            assert result >= 0.0
    for builder_result in builder_results:
        _clean_build(builder_result.artifact_path)


def test_meta_schedule_py_runner():
    """Test meta schedule PyRunner"""

    class TestRunner(PyRunner):
        def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
            raise ValueError("TestRunner")

    runner = TestRunner()
    with pytest.raises(ValueError, match="TestRunner"):
        runner.run([])


def test_meta_schedule_rpc_runner_time_out():
    """Test meta schedule RPC Runner time out"""

    def initializer():
        @register_func("meta_schedule.runner.test_time_out")
        def timeout_session_creator(  # pylint: disable=unused-variable
            rpc_config: RPCConfig,  # pylint: disable=unused-argument
        ) -> RPCSession:
            time.sleep(2)
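    # The artifact path below is a placeholder: session creation is expected to
    # time out before the runner ever tries to load the module.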
    runner_input = RunnerInput(
        "test",
        "llvm",
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
    )
    with LocalRPC() as rpc:
        rpc_config = RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=1,
        )
        evaluator_config = EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )
        runner = RPCRunner(
            rpc_config,
            evaluator_config,
            initializer=initializer,
            f_create_session="meta_schedule.runner.test_time_out",
        )
        # Run the module
        (runner_future,) = runner.run([runner_input])
        runner_result = runner_future.result()
    assert runner_result.error_msg is not None and runner_result.error_msg.startswith(
        "RPCRunner: Timeout, killed after"
    )
    assert runner_result.run_secs is None


def test_meta_schedule_local_runner_time_out():
    """Test meta schedule Local Runner time out"""
    mod = MatmulModule
    builder = LocalBuilder()
    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None
    runner_input = RunnerInput(
        builder_result.artifact_path,
        "llvm",
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
    )

    def initializer():
        @register_func("meta_schedule.runner.test_time_out")
        def timeout_session_creator(  # pylint: disable=unused-variable
            device: Device,  # pylint: disable=unused-argument
            args_info: T_ARG_INFO_JSON_OBJ_LIST,  # pylint: disable=unused-argument
            alloc_repeat: int,  # pylint: disable=unused-argument
        ) -> RPCSession:
            time.sleep(2)

    evaluator_config = EvaluatorConfig(
        number=1,
        repeat=1,
        min_repeat_ms=0,
        enable_cpu_cache_flush=False,
    )
    runner = LocalRunner(
        timeout_sec=1,
        evaluator_config=evaluator_config,
        initializer=initializer,
        f_alloc_argument="meta_schedule.runner.test_time_out",
    )
    # Run the module
    (runner_future,) = runner.run([runner_input])
    runner_result = runner_future.result()
    assert runner_result.error_msg is not None and runner_result.error_msg.startswith(
        "LocalRunner: Timeout, killed after"
    )
    assert runner_result.run_secs is None
    _clean_build(builder_result.artifact_path)


def test_meta_schedule_rpc_runner_exception():
    """Test meta schedule RPC Runner exception"""

    def initializer():
        @register_func("meta_schedule.runner.test_exception")
        def exception_session_creator(  # pylint: disable=unused-variable
            rpc_config: RPCConfig,  # pylint: disable=unused-argument
        ) -> RPCSession:
            raise Exception("Test")

    runner_input = RunnerInput(
        "test",
        "llvm",
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
    )
    with LocalRPC() as rpc:
        rpc_config = RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        evaluator_config = EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )
        runner = RPCRunner(
            rpc_config,
            evaluator_config,
            initializer=initializer,
            f_create_session="meta_schedule.runner.test_exception",
        )
        (runner_future,) = runner.run([runner_input])
        runner_result = runner_future.result()
    assert runner_result.error_msg is not None and runner_result.error_msg.startswith(
        "RPCRunner: An exception occurred\n"
    )
    assert runner_result.run_secs is None
def test_meta_schedule_local_runner_exception():
    """Test meta schedule Local Runner exception"""
    mod = MatmulModule
    builder = LocalBuilder()
    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None
    runner_input = RunnerInput(
        builder_result.artifact_path,
        "llvm",
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
    )

    def initializer():
        @register_func("meta_schedule.runner.test_exception")
        def timeout_session_creator(  # pylint: disable=unused-variable
            device: Device,  # pylint: disable=unused-argument
            args_info: T_ARG_INFO_JSON_OBJ_LIST,  # pylint: disable=unused-argument
            alloc_repeat: int,  # pylint: disable=unused-argument
        ) -> RPCSession:
            raise Exception("Test")

    evaluator_config = EvaluatorConfig(
        number=1,
        repeat=1,
        min_repeat_ms=0,
        enable_cpu_cache_flush=False,
    )
    runner = LocalRunner(
        evaluator_config=evaluator_config,
        initializer=initializer,
        f_alloc_argument="meta_schedule.runner.test_exception",
    )
    # Run the module
    (runner_future,) = runner.run([runner_input])
    runner_result = runner_future.result()
    assert runner_result.error_msg is not None and runner_result.error_msg.startswith(
        "LocalRunner: An exception occurred\n"
    )
    assert runner_result.run_secs is None
    _clean_build(builder_result.artifact_path)


def test_meta_schedule_runner_matmul_test():
    """Test meta schedule runner with matmul module"""

    def _check_correct_matmul(
        args_before: List[np.ndarray],
        args_after: List[np.ndarray],
    ) -> None:
        a_before, b_before, c_before = args_before
        a_after, b_after, c_after = args_after
        c_before = np.matmul(a_before, b_before)
        assert (a_before == a_after).all()
        assert (b_before == b_after).all()
        tvm.testing.assert_allclose(c_before, c_after, rtol=1e-5)

    def test_alloc_argument(
        session: RPCSession,
        device: Device,
        args_info: Any,
        alloc_repeat: int,
    ) -> List[Any]:
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_before = []  # type: ignore
        repeated_args = rpc_default_alloc_argument(session, device, args_info, alloc_repeat)
        for args in repeated_args:
            repeated_args_before.append([arg.numpy() for arg in args])  # type: ignore
        return repeated_args

    def test_run_evaluator(
        session: RPCSession,  # pylint: disable=unused-argument
        rt_mod: Module,
        device: Device,
        evaluator_config: EvaluatorConfig,
        repeated_args: List[Any],
    ) -> List[float]:
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_after = []
        evaluator = rt_mod.time_evaluator(
            func_name=rt_mod.entry_name,
            dev=device,
            number=evaluator_config.number,
            repeat=evaluator_config.repeat,
            min_repeat_ms=evaluator_config.min_repeat_ms,
            f_preproc="cache_flush_cpu_non_first_arg"
            if evaluator_config.enable_cpu_cache_flush
            else "",
        )
        repeated_costs: List[List[float]] = []
        for args in repeated_args:
            device.sync()
            profile_result = evaluator(*args)
            repeated_costs.append(profile_result.results)
            repeated_args_after.append([arg.numpy() for arg in args])
        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
        for args_before, args_after in zip(
            repeated_args_before,  # type: ignore
            repeated_args_after,
        ):
            _check_correct_matmul(args_before, args_after)
        del repeated_args_before  # type: ignore
        return costs

    # Build the module
    mod = MatmulModule
    builder = LocalBuilder()
    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None
    runner_input = RunnerInput(
        builder_result.artifact_path,
        "llvm",
        [
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            TensorInfo("float32", (MATMUL_N, MATMUL_N)),
        ],
    )
    with LocalRPC() as rpc:
        rpc_config = RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        evaluator_config = EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )
        runner = RPCRunner(
            rpc_config,
            evaluator_config,
            f_alloc_argument=test_alloc_argument,
            f_run_evaluator=test_run_evaluator,
        )
        # Run the module
        (runner_future,) = runner.run([runner_input])
        runner_result = runner_future.result()
    assert runner_result.error_msg is None
    for result in runner_result.run_secs:
        if isinstance(result, FloatImm):
            result = result.value
        assert isinstance(result, float)
        assert result >= 0.0
    _clean_build(builder_result.artifact_path)


def test_meta_schedule_runner_add_test():
    """Test meta schedule runner with add module"""

    def _check_correct_add(args_before: List[np.ndarray], args_after: List[np.ndarray]) -> None:
        a_before, b_before, c_before = args_before
        a_after, b_after, c_after = args_after
        c_before = a_before + b_before
        assert (a_before == a_after).all()
        assert (b_before == b_after).all()
        assert (c_before == c_after).all()

    def test_alloc_argument(
        session: RPCSession,
        device: Device,
        args_info: Any,
        alloc_repeat: int,
    ) -> List[Any]:
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_before = []  # type: ignore
        repeated_args = rpc_default_alloc_argument(
            session,
            device,
            args_info,
            alloc_repeat,
        )
        for args in repeated_args:
            repeated_args_before.append([arg.numpy() for arg in args])  # type: ignore
        return repeated_args

    def test_run_evaluator(
        session: RPCSession,  # pylint: disable=unused-argument
        rt_mod: Module,
        device: Device,
        evaluator_config: EvaluatorConfig,
        repeated_args: List[Any],
    ) -> List[float]:
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_after = []
        evaluator = rt_mod.time_evaluator(
            func_name=rt_mod.entry_name,
            dev=device,
            number=evaluator_config.number,
            repeat=evaluator_config.repeat,
            min_repeat_ms=evaluator_config.min_repeat_ms,
            f_preproc="cache_flush_cpu_non_first_arg"
            if evaluator_config.enable_cpu_cache_flush
            else "",
        )
        repeated_costs: List[List[float]] = []
        for args in repeated_args:
            device.sync()
            profile_result = evaluator(*args)
            repeated_costs.append(profile_result.results)
            repeated_args_after.append([arg.numpy() for arg in args])
        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
        for args_before, args_after in zip(
            repeated_args_before,  # type: ignore
            repeated_args_after,
        ):
            _check_correct_add(args_before, args_after)
        del repeated_args_before  # type: ignore
        return costs

    # Build the module
    mod = AddModule
    builder = LocalBuilder()
    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None
    runner_input = RunnerInput(
        builder_result.artifact_path,
        "llvm",
        [
            TensorInfo("float32", [MATMUL_M]),
            TensorInfo("float32", [MATMUL_M]),
            TensorInfo("float32", [MATMUL_M]),
        ],
    )
    with LocalRPC() as rpc:
        rpc_config = RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        evaluator_config = EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )
        runner = RPCRunner(
            rpc_config,
            evaluator_config,
            f_alloc_argument=test_alloc_argument,
            f_run_evaluator=test_run_evaluator,
        )
        # Run the module
        (runner_future,) = runner.run([runner_input])
        runner_result = runner_future.result()
    assert runner_result.error_msg is None
    for result in runner_result.run_secs:
        if isinstance(result, FloatImm):
            result = result.value
        assert isinstance(result, float)
        assert result >= 0.0
    _clean_build(builder_result.artifact_path)


def test_meta_schedule_local_runner_add_test():
    """Test meta schedule local runner with add module"""

    def _check_correct_add(args_before: List[np.ndarray], args_after: List[np.ndarray]) -> None:
        a_before, b_before, c_before = args_before
        a_after, b_after, c_after = args_after
        c_before = a_before + b_before
        assert (a_before == a_after).all()
        assert (b_before == b_after).all()
        assert (c_before == c_after).all()

    def test_alloc_argument(
        device: Device,
        args_info: T_ARG_INFO_JSON_OBJ_LIST,  # pylint: disable=unused-argument
        alloc_repeat: int,
    ) -> List[T_ARGUMENT_LIST]:
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_before = []
        repeated_args = local_default_alloc_argument(device, args_info, alloc_repeat)
        for args in repeated_args:
            repeated_args_before.append([arg.numpy() for arg in args])
        return repeated_args

    def test_run_evaluator(
        rt_mod: Module,
        device: Device,
        evaluator_config: EvaluatorConfig,
        repeated_args: List[Any],
    ) -> List[float]:
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_after = []
        evaluator = rt_mod.time_evaluator(
            func_name=rt_mod.entry_name,
            dev=device,
            number=evaluator_config.number,
            repeat=evaluator_config.repeat,
            min_repeat_ms=evaluator_config.min_repeat_ms,
            f_preproc="cache_flush_cpu_non_first_arg"
            if evaluator_config.enable_cpu_cache_flush
            else "",
        )
        repeated_costs: List[List[float]] = []
        for args in repeated_args:
            device.sync()
            profile_result = evaluator(*args)
            repeated_costs.append(profile_result.results)
            repeated_args_after.append([arg.numpy() for arg in args])
        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
        for args_before, args_after in zip(repeated_args_before, repeated_args_after):
            _check_correct_add(args_before, args_after)
        del repeated_args_before
        return costs

    # Build the module
    mod = AddModule
    builder = LocalBuilder()
    (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None
    runner_input = RunnerInput(
        builder_result.artifact_path,
        "llvm",
        [
            TensorInfo("float32", [MATMUL_M]),
            TensorInfo("float32", [MATMUL_M]),
            TensorInfo("float32", [MATMUL_M]),
        ],
    )
    evaluator_config = EvaluatorConfig(
        number=1,
        repeat=1,
        min_repeat_ms=0,
        enable_cpu_cache_flush=False,
    )
    runner = LocalRunner(
        timeout_sec=100,
        evaluator_config=evaluator_config,
        f_alloc_argument=test_alloc_argument,
        f_run_evaluator=test_run_evaluator,
    )
    # Run the module
    (runner_future,) = runner.run([runner_input])
    runner_result = runner_future.result()
    assert runner_result.error_msg is None
    for result in runner_result.run_secs:
        if isinstance(result, FloatImm):
            result = result.value
        assert isinstance(result, float)
        assert result >= 0.0
    _clean_build(builder_result.artifact_path)


if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))