/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /* * \file webgpu_runtime.cc * \brief WebGPU runtime based on the TVM JS. */ // configurations for tvm logging #define TVM_LOG_STACK_TRACE 0 #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 #define DMLC_USE_LOGGING_LIBRARY #include #include #include #include #include #include #include #include "../../src/runtime/meta_data.h" #include "../../src/runtime/vulkan/vulkan_shader.h" #include "../../src/runtime/workspace_pool.h" namespace tvm { namespace runtime { /*! \brief Thread local workspace */ class WebGPUThreadEntry { public: /*! \brief thread local pool*/ WorkspacePool pool; /*! \brief constructor */ WebGPUThreadEntry(); // get the threadlocal workspace static WebGPUThreadEntry* ThreadLocal(); }; // All the implementations are redirectly to the JS side. class WebGPUDeviceAPI : public DeviceAPI { public: WebGPUDeviceAPI() { auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUDeviceAPI"); CHECK(fp != nullptr) << "Cannot find wasm.WebGPUContext in the env"; auto getter = TypedPackedFunc(*fp); alloc_space_ = getter("deviceAllocDataSpace"); free_space_ = getter("deviceFreeDataSpace"); copy_to_gpu_ = getter("deviceCopyToGPU"); copy_from_gpu_ = getter("deviceCopyFromGPU"); copy_within_gpu_ = getter("deviceCopyWithinGPU"); } void SetDevice(Device dev) final {} void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final { if (kind == kExist) { *rv = 1; } } void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { double ptr_number = alloc_space_(nbytes); return reinterpret_cast(static_cast(ptr_number)); } void FreeDataSpace(Device dev, void* ptr) final { return free_space_(ptr); } protected: void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) final { if (static_cast(dev_from.device_type) == kDLWebGPU && static_cast(dev_to.device_type) == kDLWebGPU) { CHECK_EQ(dev_from.device_id, dev_to.device_id); copy_within_gpu_(const_cast(from), from_offset, to, to_offset, size); } else if (static_cast(dev_from.device_type) == kDLWebGPU && dev_to.device_type == kDLCPU) { void* to_ptr = static_cast(to) + to_offset; copy_from_gpu_(const_cast(from), from_offset, to_ptr, size); } else if (dev_from.device_type == kDLCPU && static_cast(dev_to.device_type) == kDLWebGPU) { void* from_ptr = static_cast(const_cast(from)) + from_offset; copy_to_gpu_(from_ptr, to, to_offset, size); } else { LOG(FATAL) << "expect copy from/to WebGPU or between WebGPU"; } } public: TVMStreamHandle CreateStream(Device dev) final { LOG(FATAL) << "Not implemented"; return nullptr; } void FreeStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; return; } void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { LOG(FATAL) << "Not implemented"; return; } void StreamSync(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; return; } void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final { return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size); } void FreeWorkspace(Device dev, void* data) final { WebGPUThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data); } static WebGPUDeviceAPI* Global() { static WebGPUDeviceAPI* inst = new WebGPUDeviceAPI(); return inst; } private: // NOTE: js return number as double. TypedPackedFunc alloc_space_; TypedPackedFunc free_space_; TypedPackedFunc copy_to_gpu_; TypedPackedFunc copy_from_gpu_; TypedPackedFunc copy_within_gpu_; }; typedef dmlc::ThreadLocalStore WebGPUThreadStore; WebGPUThreadEntry::WebGPUThreadEntry() : pool(static_cast(kDLWebGPU), WebGPUDeviceAPI::Global()) {} WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore::Get(); } class WebGPUModuleNode final : public runtime::ModuleNode { public: explicit WebGPUModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) { auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUCreateShader"); CHECK(fp != nullptr); create_shader_ = *fp; } const char* type_key() const final { return "webgpu"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { auto it = smap_.find(name); if (it != smap_.end()) { FunctionInfo info = fmap_.at(name); info.name = name; std::ostringstream os; dmlc::JSONWriter writer(&os); info.Save(&writer); TVMByteArray arr; arr.data = reinterpret_cast(it->second.data.data()); arr.size = it->second.data.size() * sizeof(it->second.data[0]); return create_shader_(os.str(), arr); } else { return PackedFunc(nullptr); } } void SaveToFile(const std::string& file_name, const std::string& format) final { LOG(FATAL) << "Not implemented"; } void SaveToBinary(dmlc::Stream* stream) final { LOG(FATAL) << "Not implemented"; } std::string GetSource(const std::string& format) final { // can only return source code. return source_; } private: // function information table. std::unordered_map smap_; // function information table. std::unordered_map fmap_; // The source std::string source_; // Callback to get the GPU function. TypedPackedFunc create_shader_; }; Module WebGPUModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::unordered_map smap; std::unordered_map fmap; std::string fmt; stream->Read(&fmt); stream->Read(&fmap); stream->Read(&smap); return Module(make_object(smap, fmap, "")); } // for now webgpu is hosted via a vulkan module. TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary); TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = WebGPUDeviceAPI::Global(); *rv = static_cast(ptr); }); } // namespace runtime } // namespace tvm