/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "vulkan_device.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../support/utils.h"
#include "vulkan_common.h"
#include "vulkan_device_api.h"
#include "vulkan_instance.h"

namespace tvm {
namespace runtime {
namespace vulkan {

VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance,
                                               const VulkanDevice& device) {
  ///////////////////////////////////////////////////////////////
  //           Query properties from Vulkan API                //
  ///////////////////////////////////////////////////////////////

  // Declare output locations for properties
  VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2};
  VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES};
  VkPhysicalDeviceSubgroupProperties subgroup = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES};

  // Need to do initial query in order to check the apiVersion.
  vkGetPhysicalDeviceProperties(device, &properties.properties);

  // Set up linked list for property query
  {
    void** pp_next = &properties.pNext;
    if (device.HasExtension("VK_KHR_driver_properties")) {
      *pp_next = &driver;
      pp_next = &driver.pNext;
    }
    if (properties.properties.apiVersion >= VK_API_VERSION_1_1) {
      *pp_next = &subgroup;
      pp_next = &subgroup.pNext;
    }
  }

  // Declare output locations for features
  VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
  VkPhysicalDevice8BitStorageFeatures storage_8bit = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES};
  VkPhysicalDevice16BitStorageFeatures storage_16bit = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES};
  VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES};

  // Set up linked list for feature query
  {
    void** pp_next = &features.pNext;
    if (device.HasExtension("VK_KHR_8bit_storage")) {
      *pp_next = &storage_8bit;
      pp_next = &storage_8bit.pNext;
    }
    if (device.HasExtension("VK_KHR_16bit_storage")) {
      *pp_next = &storage_16bit;
      pp_next = &storage_16bit.pNext;
    }
    if (device.HasExtension("VK_KHR_shader_float16_int8")) {
      *pp_next = &float16_int8;
      pp_next = &float16_int8.pNext;
    }
  }
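  // Note on the pNext chains built above: each extension struct is
  // appended by storing its address in the previous struct's pNext
  // pointer, forming a singly linked list that the driver walks when
  // filling in query results.  As an illustrative sketch (using the
  // names above; the walk itself happens inside the driver), the
  // property chain on a device with VK_KHR_driver_properties and
  // Vulkan >= 1.1 looks like:
  //
  //   properties.pNext -> &driver;  driver.pNext -> &subgroup;  subgroup.pNext -> nullptr
  //
  // and could be traversed generically with the Vulkan 1.1 header type:
  //
  //   for (auto* node = static_cast<VkBaseOutStructure*>(properties.pNext); node != nullptr;
  //        node = node->pNext) {
  //     // dispatch on node->sType, then fill in the matching struct
  //   }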
  if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) {
    // Preferred method, call to get all properties that can be queried.
    auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL(
        vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR"));
    vkGetPhysicalDeviceProperties2KHR(device, &properties);

    auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL(
        vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR"));
    vkGetPhysicalDeviceFeatures2KHR(device, &features);
  } else {
    // Fallback, get as many features as we can from the Vulkan 1.0 API.
    // The corresponding vkGetPhysicalDeviceProperties call was already
    // made earlier.
    vkGetPhysicalDeviceFeatures(device, &features.features);
  }

  ///////////////////////////////////////////////////////////////
  //      Fill member variables from Vulkan structures         //
  ///////////////////////////////////////////////////////////////

  supports_float16 = float16_int8.shaderFloat16;
  supports_float32 = true;
  supports_float64 = features.features.shaderFloat64;
  supports_int8 = float16_int8.shaderInt8;
  supports_int16 = features.features.shaderInt16;
  supports_int32 = true;
  supports_int64 = features.features.shaderInt64;
  supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess;
  supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess;
  supports_storage_buffer_storage_class =
      device.HasExtension("VK_KHR_storage_buffer_storage_class");

  // Support is available based on these extensions, but allow it to
  // be disabled based on an environment variable.
  supports_push_descriptor = device.HasExtension("VK_KHR_push_descriptor") &&
                             device.HasExtension("VK_KHR_descriptor_update_template") &&
                             !support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR");

  // Support is available based on these extensions, but allow it to
  // be disabled based on an environment variable.
  supports_dedicated_allocation =
      device.HasExtension("VK_KHR_get_memory_requirements2") &&
      device.HasExtension("VK_KHR_dedicated_allocation") &&
      !support::BoolEnvironmentVar("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION");

  // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically needed,
  // since it will be set so long as at least one queue has
  // VK_QUEUE_COMPUTE_BIT.  It is included here to avoid potential
  // future confusion.
  supported_subgroup_operations =
      (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0;
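  // supported_subgroup_operations is a VkSubgroupFeatureFlags bitmask.
  // A minimal sketch of how a consumer could test individual
  // capabilities (not code the runtime itself needs here):
  //
  //   bool has_ballot =
  //       (supported_subgroup_operations & VK_SUBGROUP_FEATURE_BALLOT_BIT) != 0;
  //   bool has_arithmetic =
  //       (supported_subgroup_operations & VK_SUBGROUP_FEATURE_ARITHMETIC_BIT) != 0;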
  max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations;

  // Even if we can't query it, the warp size must be at least 1.
  thread_warp_size = std::max(subgroup.subgroupSize, 1U);

  max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0];
  max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1];
  max_block_size_z = properties.properties.limits.maxComputeWorkGroupSize[2];
  max_push_constants_size = properties.properties.limits.maxPushConstantsSize;
  max_uniform_buffer_range = properties.properties.limits.maxUniformBufferRange;
  max_storage_buffer_range = properties.properties.limits.maxStorageBufferRange;
  max_per_stage_descriptor_storage_buffer =
      properties.properties.limits.maxPerStageDescriptorStorageBuffers;
  max_shared_memory_per_block = properties.properties.limits.maxComputeSharedMemorySize;
  device_name = properties.properties.deviceName;
  driver_version = properties.properties.driverVersion;

  if (device.HasExtension("VK_KHR_driver_properties")) {
    driver_name = driver.driverName;
  }

  switch (properties.properties.deviceType) {
    case VK_PHYSICAL_DEVICE_TYPE_OTHER:
      device_type = "other";
      break;
    case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
      device_type = "integrated";
      break;
    case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
      device_type = "discrete";
      break;
    case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
      device_type = "virtual";
      break;
    case VK_PHYSICAL_DEVICE_TYPE_CPU:
      device_type = "cpu";
      break;
    default:
      LOG(FATAL) << "Unknown vulkan device type: " << properties.properties.deviceType;
      break;
  }

  // By default, use the maximum API version that the driver allows,
  // so that any supported features can be used by TVM shaders.
  // However, if we can query the conformance version, then limit to
  // only using the api version that passes the vulkan conformance
  // tests.
  vulkan_api_version = properties.properties.apiVersion;
  if (device.HasExtension("VK_KHR_driver_properties")) {
    auto api_major = VK_VERSION_MAJOR(vulkan_api_version);
    auto api_minor = VK_VERSION_MINOR(vulkan_api_version);
    if ((api_major > driver.conformanceVersion.major) ||
        ((api_major == driver.conformanceVersion.major) &&
         (api_minor > driver.conformanceVersion.minor))) {
      vulkan_api_version =
          VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0);
    }
  }
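  // Vulkan packs versions into a single uint32_t: major in bits 31..22,
  // minor in bits 21..12, patch in bits 11..0.  This bit layout is what
  // makes ordered comparisons such as
  // `vulkan_api_version >= VK_API_VERSION_1_2` below valid.  An
  // illustrative compile-time check of that layout (no runtime effect):
  static_assert(VK_MAKE_VERSION(1, 2, 0) == ((1u << 22) | (2u << 12) | 0u),
                "Vulkan version fields are bit-packed as expected");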
  // From "Versions and Formats" section of Vulkan spec.
  max_spirv_version = 0x10000;
  if (vulkan_api_version >= VK_API_VERSION_1_2) {
    max_spirv_version = 0x10500;
  } else if (device.HasExtension("VK_KHR_spirv_1_4")) {
    max_spirv_version = 0x10400;
  } else if (vulkan_api_version >= VK_API_VERSION_1_1) {
    max_spirv_version = 0x10300;
  }
}

VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) {
  vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL(
      vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR"));
  vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL(
      vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR"));
  vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL(
      vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR"));
  vkCmdPushDescriptorSetWithTemplateKHR =
      (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL(
          vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR"));
}

VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2Functions(
    VkDevice device) {
  vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL(
      vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR"));
}

VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_device)
    : physical_device_(phy_device) {
  queue_family_index = SelectComputeQueueFamily();
  if (queue_family_index == uint32_t(-1)) {
    // The GPU doesn't support compute, so it cannot be used.
    return;
  }

  enabled_extensions = SelectEnabledExtensions();
  device_properties = VulkanDeviceProperties(instance, *this);
  CreateVkDevice(instance);

  // Currently, any exceptions thrown after this point will prevent
  // vkDestroyDevice from being called in the destructor.  If this
  // becomes an issue, the VulkanDevice could be split into two
  // classes, one of which strictly holds the VkDevice, and one which
  // holds the ancillary handles that TVM needs.

  vkGetDeviceQueue(device_, queue_family_index, 0, &queue);
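  // Vulkan reports buffer/memory compatibility as a bitmask: bit k of
  // VkMemoryRequirements::memoryTypeBits is set iff memory type k can
  // back the buffer.  There is no way to obtain this mask for a usage
  // flag combination without a concrete buffer, hence the short-lived
  // throwaway buffers created below.  Sketch of the decoded meaning
  // (illustrative only):
  //
  //   // req.memoryTypeBits == 0b0101  =>  memory types 0 and 2 are usable
  //   bool type_k_ok = (req.memoryTypeBits & (1u << k)) != 0;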
  // Find suitable memory types for staging and compute.
  VkBuffer buffer;
  VkMemoryRequirements req_staging, req_compute;
  VkBufferCreateInfo info;
  info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
  info.pNext = nullptr;
  info.flags = 0;
  info.size = 1024;
  info.queueFamilyIndexCount = 1;
  info.pQueueFamilyIndices = &queue_family_index;
  info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;

  // Get the staging requirement.
  info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
  VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer));
  vkGetBufferMemoryRequirements(device_, buffer, &req_staging);
  vkDestroyBuffer(device_, buffer, nullptr);

  // Get the compute requirement.
  info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
               VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
  VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer));
  vkGetBufferMemoryRequirements(device_, buffer, &req_compute);
  vkDestroyBuffer(device_, buffer, nullptr);

  // Query the physical device's memory properties and find a staging
  // memory type that is host visible; it does not need to be coherent.
  int win_rank = -1;
  VkPhysicalDeviceMemoryProperties prop;
  vkGetPhysicalDeviceMemoryProperties(physical_device_, &prop);

  for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
    VkMemoryType ty = prop.memoryTypes[k];
    size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
    // Must be host visible.
    if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
    // Must match the staging copy requirement.
    if (!(req_staging.memoryTypeBits & (1 << k))) continue;
    if (heap_size < 1024) continue;
    int rank = 0;
    rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
    if (rank > win_rank) {
      win_rank = rank;
      staging_mtype_index = k;
      coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
    }
  }
  ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";

  win_rank = -1;
  for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
    VkMemoryType ty = prop.memoryTypes[k];
    size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
    // Must be device local.
    if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
    // Must match the compute requirement.  (Note: req_compute, not
    // req_staging; the compute buffer additionally needs storage usage.)
    if (!(req_compute.memoryTypeBits & (1 << k))) continue;
    if (heap_size < 1024) continue;
    int rank = 0;
    // Prefer memory that is not host visible.
    rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
    if (rank > win_rank) {
      win_rank = rank;
      compute_mtype_index = k;
    }
  }
  ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";

  if (device_properties.supports_push_descriptor) {
    descriptor_template_khr_functions =
        std::make_unique<VulkanDescriptorTemplateKHRFunctions>(device_);
  }

  if (device_properties.supports_dedicated_allocation) {
    get_buffer_memory_requirements_2_functions =
        std::make_unique<VulkanGetBufferMemoryRequirements2Functions>(device_);
  }
}
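// Worked example of the staging-memory ranking in the constructor
// above, for a hypothetical device exposing two host-visible types
// (values invented for illustration):
//
//   type 0: HOST_VISIBLE | HOST_COHERENT               -> rank 0
//   type 1: HOST_VISIBLE | HOST_COHERENT | HOST_CACHED -> rank > 0
//
// Type 1 wins, since cached memory makes CPU readbacks much faster.
// coherent_staging records whether the winner is also HOST_COHERENT,
// i.e. whether explicit vkFlushMappedMemoryRanges /
// vkInvalidateMappedMemoryRanges calls can be skipped.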
VulkanDevice::~VulkanDevice() {
  // Need to clear anything that uses this device before calling
  // vkDestroyDevice.  This might be a sign that the VkDevice should be
  // held by a member variable rather than being owned directly by the
  // VulkanDevice.
  stream_per_thread.Clear();
  staging_buffer_per_thread.Clear();
  uniform_buffer_per_thread.Clear();

  if (device_) {
    vkDestroyDevice(device_, nullptr);
  }
}

VulkanDevice::VulkanDevice(VulkanDevice&& other) { do_swap(std::move(other)); }

VulkanDevice& VulkanDevice::operator=(VulkanDevice&& other) {
  do_swap(std::move(other));
  return *this;
}

void VulkanDevice::do_swap(VulkanDevice&& other) {
  if (this == &other) {
    return;
  }

  std::lock(queue_mutex, other.queue_mutex);
  std::lock_guard<std::mutex> lock_self(queue_mutex, std::adopt_lock);
  std::lock_guard<std::mutex> lock_other(other.queue_mutex, std::adopt_lock);

  std::swap(device_properties, other.device_properties);
  std::swap(staging_mtype_index, other.staging_mtype_index);
  std::swap(coherent_staging, other.coherent_staging);
  std::swap(descriptor_template_khr_functions, other.descriptor_template_khr_functions);
  std::swap(get_buffer_memory_requirements_2_functions,
            other.get_buffer_memory_requirements_2_functions);
  std::swap(compute_mtype_index, other.compute_mtype_index);
  std::swap(queue, other.queue);
  std::swap(queue_family_index, other.queue_family_index);
  std::swap(physical_device_, other.physical_device_);
  std::swap(enabled_extensions, other.enabled_extensions);
  std::swap(device_, other.device_);
}

bool VulkanDevice::SupportsCompute() const { return queue_family_index != uint32_t(-1); }

void VulkanDevice::QueueSubmit(VkSubmitInfo submit_info, VkFence fence) const {
  // Multiple streams (on different threads) use the same VulkanDevice
  // instance, so we need to externally synchronize accesses.
  std::lock_guard<std::mutex> lock(queue_mutex);
  VULKAN_CALL(vkQueueSubmit(queue, 1, &submit_info, fence));
}

uint32_t VulkanDevice::SelectComputeQueueFamily() const {
  // Get a queue family that supports compute.  We currently only use
  // one queue from one family.
  uint32_t queue_prop_count = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count, nullptr);
  std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
  vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count,
                                           dmlc::BeginPtr(queue_props));

  // Prefer compute-only queues.  On certain devices supporting this
  // (e.g. Mesa RADV), using compute-only queues gives better
  // responsiveness for other graphics workloads (e.g. desktop).
  for (uint32_t i = 0; i != queue_prop_count; ++i) {
    if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 &&
        (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) {
      return i;
    }
  }
  // Otherwise, accept one of the combined graphics+compute queue
  // families that were skipped above.
  for (uint32_t i = 0; i != queue_prop_count; ++i) {
    if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 &&
        (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) {
      return i;
    }
  }
  // No queue supports compute; this GPU cannot be used.
  return -1;
}
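// For reference, VkQueueFamilyProperties::queueFlags is a bitmask, so
// the two passes above partition the families.  On a typical discrete
// GPU the layout might be (an invented example; actual families vary
// by driver):
//
//   family 0: GRAPHICS | COMPUTE | TRANSFER  -> skipped by pass 1, accepted by pass 2
//   family 1: COMPUTE | TRANSFER             -> accepted by pass 1 (preferred)
//   family 2: TRANSFER                       -> never eligible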
std::vector<const char*> VulkanDevice::SelectEnabledExtensions() const {
  std::vector<const char*> required_extensions{};
  std::vector<const char*> optional_extensions{
      "VK_KHR_driver_properties",
      "VK_KHR_storage_buffer_storage_class",
      "VK_KHR_8bit_storage",
      "VK_KHR_16bit_storage",
      "VK_KHR_shader_float16_int8",
      "VK_KHR_push_descriptor",
      "VK_KHR_descriptor_update_template",
      "VK_KHR_get_memory_requirements2",
      "VK_KHR_dedicated_allocation",
      "VK_KHR_spirv_1_4",
  };

  uint32_t device_extension_prop_count;
  VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr,
                                                   &device_extension_prop_count, nullptr));
  std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
  VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
      physical_device_, nullptr, &device_extension_prop_count, device_extension_prop.data()));

  return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions);
}

bool VulkanDevice::HasExtension(const char* query) const {
  return std::any_of(enabled_extensions.begin(), enabled_extensions.end(),
                     [&](const char* extension) { return std::strcmp(query, extension) == 0; });
}

void VulkanDevice::CreateVkDevice(const VulkanInstance& instance) {
  // Enable all features we may use that the device supports.
  VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
  VkPhysicalDevice8BitStorageFeatures storage_8bit = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES};
  VkPhysicalDevice16BitStorageFeatures storage_16bit = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES};
  VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = {
      VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES};

  void** pp_next = &enabled_features.pNext;
  bool needs_float16_int8 = false;

  if (device_properties.supports_float16) {
    float16_int8.shaderFloat16 = true;
    needs_float16_int8 = true;
  }
  if (device_properties.supports_float64) {
    enabled_features.features.shaderFloat64 = true;
  }
  if (device_properties.supports_int8) {
    float16_int8.shaderInt8 = true;
    needs_float16_int8 = true;
  }
  if (device_properties.supports_int16) {
    enabled_features.features.shaderInt16 = true;
  }
  if (device_properties.supports_int64) {
    enabled_features.features.shaderInt64 = true;
  }
  if (device_properties.supports_8bit_buffer) {
    storage_8bit.storageBuffer8BitAccess = true;
    *pp_next = &storage_8bit;
    pp_next = &storage_8bit.pNext;
  }
  if (device_properties.supports_16bit_buffer) {
    storage_16bit.storageBuffer16BitAccess = true;
    *pp_next = &storage_16bit;
    pp_next = &storage_16bit.pNext;
  }
  if (needs_float16_int8) {
    *pp_next = &float16_int8;
    pp_next = &float16_int8.pNext;
  }

  float priority = 1.0f;

  struct VkDeviceQueueCreateInfo queue_create_info;
  queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  queue_create_info.pNext = nullptr;
  queue_create_info.flags = 0;
  queue_create_info.queueFamilyIndex = queue_family_index;
  queue_create_info.queueCount = 1;
  queue_create_info.pQueuePriorities = &priority;

  VkDeviceCreateInfo device_create_info;
  device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  device_create_info.pNext = nullptr;
  device_create_info.flags = 0;
  device_create_info.queueCreateInfoCount = 1;
  device_create_info.pQueueCreateInfos = &queue_create_info;
  device_create_info.enabledLayerCount = 0;
  device_create_info.ppEnabledLayerNames = nullptr;
  device_create_info.enabledExtensionCount = enabled_extensions.size();
  device_create_info.ppEnabledExtensionNames = enabled_extensions.data();
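  // vkCreateDevice accepts enabled features in one of two mutually
  // exclusive ways: the Vulkan 1.0 path points pEnabledFeatures at a
  // plain VkPhysicalDeviceFeatures, while the
  // VK_KHR_get_physical_device_properties2 path leaves pEnabledFeatures
  // null and instead chains a VkPhysicalDeviceFeatures2 (plus any
  // extension feature structs) through pNext.  Supplying both at once
  // is a validation error.  A minimal sketch of the extended path,
  // using the enabled_features chain built above:
  //
  //   create_info.pEnabledFeatures = nullptr;
  //   create_info.pNext = &enabled_features;  // pNext chain carries the extras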
  if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) {
    device_create_info.pEnabledFeatures = nullptr;
    device_create_info.pNext = &enabled_features;
  } else {
    device_create_info.pNext = nullptr;
    device_create_info.pEnabledFeatures = &enabled_features.features;
  }
  VULKAN_CALL(vkCreateDevice(physical_device_, &device_create_info, nullptr, &device_));
}

VulkanStream& VulkanDevice::ThreadLocalStream() {
  // Reuse the const overload; the thread-local lookup does not mutate
  // the device.
  return const_cast<VulkanStream&>(const_cast<const VulkanDevice*>(this)->ThreadLocalStream());
}

const VulkanStream& VulkanDevice::ThreadLocalStream() const {
  return stream_per_thread.GetOrMake(this);
}

VulkanStagingBuffer& VulkanDevice::ThreadLocalStagingBuffer(size_t min_size) {
  auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
  VulkanStagingBuffer& result =
      staging_buffer_per_thread.GetOrMake(*this, min_size, usage, staging_mtype_index);
  if (result.size < min_size) {
    result = VulkanStagingBuffer(*this, min_size, usage, staging_mtype_index);
  }
  return result;
}

void VulkanDevice::AllocateThreadLocalUniformBuffer(size_t min_size) {
  auto usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
  auto buffer_info = MakeBufferCreateInfo(min_size, usage);
  auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
  auto mem_type_index = FindMemoryType(*this, buffer_info, prop);

  VulkanUniformBuffer& result =
      uniform_buffer_per_thread.GetOrMake(*this, min_size, usage, mem_type_index);
  if (result.size < min_size) {
    result = VulkanUniformBuffer(*this, min_size, usage, mem_type_index);
  }
}

VulkanStagingBuffer& VulkanDevice::ThreadLocalUniformBuffer(size_t min_size) {
  VulkanStagingBuffer* buffer = uniform_buffer_per_thread.Get();
  ICHECK(buffer) << "Vulkan uniform buffer requested, but not previously allocated.";
  ICHECK_GE(buffer->size, min_size) << "Vulkan uniform buffer of size " << min_size
                                    << " requested, but only " << buffer->size
                                    << " was previously allocated.";
  return *buffer;
}

uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info,
                        VkMemoryPropertyFlags req_prop) {
  VkBuffer buffer;
  VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer));

  VkMemoryRequirements mem_reqs;
  vkGetBufferMemoryRequirements(device, buffer, &mem_reqs);
  // The throwaway buffer is only needed for the requirements query, so
  // destroy it before scanning the memory types.
  vkDestroyBuffer(device, buffer, nullptr);

  uint32_t type_bits = mem_reqs.memoryTypeBits;
  VkPhysicalDeviceMemoryProperties phy_mem_prop;
  vkGetPhysicalDeviceMemoryProperties(device, &phy_mem_prop);
  for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) {
    if ((type_bits & 1) == 1 &&
        (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) {
      return i;
    }
    type_bits >>= 1;
  }
  LOG(FATAL) << "Requested memory type not found";
  return 0;
}

VulkanHostVisibleBuffer* GetOrAllocate(
    int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index,
    std::unordered_map<size_t, std::unique_ptr<VulkanHostVisibleBuffer>>* buffers_ptr,
    bool sync_before_realloc) {
  auto& device = VulkanDeviceAPI::Global()->device(device_id);
  auto& buffers = *buffers_ptr;

  bool needs_alloc = !buffers[device_id] || (buffers[device_id]->size < size);
  bool is_realloc = buffers[device_id] && (buffers[device_id]->size < size);
  if (is_realloc && sync_before_realloc) {
    // Wait for in-flight GPU work to finish before the old buffer is freed.
    device.ThreadLocalStream().Synchronize();
  }

  if (needs_alloc) {
    auto new_buffer =
        std::make_unique<VulkanHostVisibleBuffer>(device, size, usage, mem_type_index);
    buffers[device_id] = std::move(new_buffer);
  }
  return buffers[device_id].get();
}

}  // namespace vulkan
}  // namespace runtime
}  // namespace tvm
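// Note on GetOrAllocate above: sync_before_realloc forces a
// ThreadLocalStream().Synchronize() before an existing buffer is
// replaced, so no in-flight GPU work still references the old VkBuffer
// when its unique_ptr is destroyed.  A hedged usage sketch, with a
// hypothetical pool map owned by the caller:
//
//   std::unordered_map<size_t, std::unique_ptr<VulkanHostVisibleBuffer>> pool;
//   uint32_t mtype = ...;  // memory type index chosen elsewhere
//   auto* buf = GetOrAllocate(/*device_id=*/0, /*size=*/4096,
//                             VK_BUFFER_USAGE_TRANSFER_SRC_BIT, mtype, &pool,
//                             /*sync_before_realloc=*/true);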