# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import pytest import tvm from tvm import te import numpy @pytest.fixture def mod_without_attrs(): ib = tvm.tir.ir_builder.create() A = tvm.tir.decl_buffer(name="A", shape=[1]) stmt = ib.get() return tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], stmt)) @pytest.fixture def mod(mod_without_attrs): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( mod_without_attrs ) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) return mod def test_fails_if_not_global_symbol(mod_without_attrs): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))( mod_without_attrs ) with pytest.raises(tvm.TVMError, match="Expect PrimFunc to have the global_symbol attribute"): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] def test_fails_if_no_target(mod_without_attrs): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod_without_attrs) with pytest.raises(tvm.TVMError, match="Require the target attribute"): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] @tvm.testing.parametrize_targets("c", "llvm", "cuda") def test_device_setup(mod, target, dev): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 assert f.params[0].name == "A" assert f.body.node == "default" assert f.body.attr_key == "device_id" assert f.body.value == 0 assert f.body.body.node == "default" assert f.body.body.attr_key == "device_type" assert f.body.body.value == dev.device_type def test_no_buffers_no_device_setup(): ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], stmt)) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 assert f.params[0].name == "A" def test_argument_mapping(mod): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 assert f.params[0].name == "A" def test_argument_mapping_multiple(): ib = tvm.tir.ir_builder.create() A = tvm.tir.decl_buffer(name="A", shape=[1]) B = tvm.tir.decl_buffer(name="B", shape=[1]) stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, B], stmt)) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 2 assert f.params[0].name == "A" assert f.params[1].name == "B" def test_argument_mapping_multiple_matching(): ib = tvm.tir.ir_builder.create() A = tvm.tir.decl_buffer(name="A", shape=[1]) B = tvm.tir.decl_buffer(name="B", shape=[1]) stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, A], stmt)) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 2 assert f.params[0].name == "A" assert f.params[1].name == "A" def test_body(): ib = tvm.tir.ir_builder.create() A = tvm.tir.decl_buffer(name="A", shape=[1]) B = tvm.tir.decl_buffer(name="B", shape=[1]) C = ib.buffer_ptr(A) stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, B, C], stmt)) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 3 assert f.params[0].name == "A" assert f.params[1].name == "B" assert f.params[2].name == "A" if __name__ == "__main__": pytest.main([__file__])