/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file intrin_rule_spirv.cc */ #include #include #include #include #include #include namespace tvm { namespace codegen { namespace spirv { // num_signature means number of arguments used to query signature template PrimExpr CallGLSLIntrin(PrimExpr e, const Array& args) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); for (PrimExpr arg : args) { cargs.push_back(arg); } return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs); } template PrimExpr CallGLSLIntrin(PrimExpr e) { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); return CallGLSLIntrin(e, call->args); } template inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) { return CallGLSLIntrin(e); } namespace intrin { using tir::FLowerIntrinsic; TVM_REGISTER_OP("tir.floor") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.ceil") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.round") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.trunc") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.fabs") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.exp").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.sin").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.cos").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.log").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.log2") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.sqrt") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.tanh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); // WebGPU rules. TVM_REGISTER_OP("tir.floor") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.ceil") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.round") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.trunc") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.fabs") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.exp").set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.log").set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.sqrt") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); TVM_REGISTER_OP("tir.tanh") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); } // namespace intrin namespace legalize { using tir::FLegalize; TVM_REGISTER_OP("tir.clz").set_attr( "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 1); PrimExpr arg = call->args[0]; PrimExpr msb; if (arg.dtype().bits() == 64) { // SPIR-V FindUMsb intrinsic only supports 32 bit input auto int32 = DataType::Int(32); PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32); PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg); PrimExpr msb_hi = CallGLSLIntrin(e, {arg_hi32}); PrimExpr msb_lo = CallGLSLIntrin(e, {arg_lo32}); msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32); } else if (arg.dtype().bits() == 32) { msb = CallGLSLIntrin(e); } else { LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; } return PrimExpr(arg.dtype().bits() - 1) - msb; }); } // namespace legalize } // namespace spirv } // namespace codegen } // namespace tvm