/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file tvm/target/virtual_device.cc * \brief A compile time representation for where data is to be stored at runtime, and how to * compile code to compute it. */ #include #include #include namespace tvm { TVM_REGISTER_NODE_TYPE(VirtualDeviceNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = ref.as(); p->stream << "VirtualDevice("; if (node->IsFullyUnconstrained()) { p->stream << "?"; } else { bool need_sep = false; if (node->device_type() != kInvalidDeviceType) { p->stream << "device_type=" << node->device_type(); need_sep = true; } if (node->virtual_device_id >= 0) { if (need_sep) { p->stream << ", "; } p->stream << "virtual_device_id=" << node->virtual_device_id; need_sep = true; } if (node->target.defined()) { if (need_sep) { p->stream << ", "; } p->stream << "target=" << node->target->ToDebugString(); need_sep = true; } if (!node->memory_scope.empty()) { if (need_sep) { p->stream << ", "; } p->stream << "memory_scope='" << node->memory_scope << "'"; } } p->stream << ")"; }); VirtualDevice::VirtualDevice(DLDeviceType device_type, int virtual_device_id, Target target, MemoryScope memory_scope) { ICHECK(!target.defined() || device_type == target->kind->device_type) << "target " << target->ToDebugString() << " has device type " << target->kind->device_type << " but virtual device has device type " << device_type; auto node = make_object(); node->device_type_int = device_type; node->virtual_device_id = virtual_device_id; node->target = std::move(target); node->memory_scope = std::move(memory_scope); data_ = std::move(node); } /* static */ VirtualDevice VirtualDevice::FullyUnconstrained() { static const VirtualDevice unconstrained{}; return unconstrained; } /* static */ Optional VirtualDevice::Join(const VirtualDevice& lhs, const VirtualDevice& rhs) { if (lhs == rhs) { return lhs; } DLDeviceType joined_device_type; if (lhs->device_type() != kInvalidDeviceType) { joined_device_type = lhs->device_type(); if (rhs->device_type() != kInvalidDeviceType && lhs->device_type() != rhs->device_type()) { return {}; } } else { joined_device_type = rhs->device_type(); } int joined_virtual_device_id; if (lhs->virtual_device_id >= 0) { joined_virtual_device_id = lhs->virtual_device_id; if (rhs->virtual_device_id >= 0 && lhs->virtual_device_id != rhs->virtual_device_id) { return {}; } } else { joined_virtual_device_id = rhs->virtual_device_id; } Target joined_target; if (lhs->target.defined()) { joined_target = lhs->target; if (rhs->target.defined() && lhs->target != rhs->target) { return {}; } } else { joined_target = rhs->target; } MemoryScope joined_memory_scope; if (!lhs->memory_scope.empty()) { joined_memory_scope = lhs->memory_scope; if (!rhs->memory_scope.empty() && lhs->memory_scope != rhs->memory_scope) { return {}; } } else { joined_memory_scope = rhs->memory_scope; } return VirtualDevice(joined_device_type, joined_virtual_device_id, joined_target, joined_memory_scope); } /* static */ VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevice& rhs) { if (lhs == rhs) { return lhs; } DLDeviceType defaulted_device_type; if (lhs->device_type() != kInvalidDeviceType) { defaulted_device_type = lhs->device_type(); } else { defaulted_device_type = rhs->device_type(); } int defaulted_virtual_device_id; if (lhs->virtual_device_id >= 0) { defaulted_virtual_device_id = lhs->virtual_device_id; } else { defaulted_virtual_device_id = rhs->virtual_device_id; } Target defaulted_target; if (lhs->target.defined()) { defaulted_target = lhs->target; } else { // We can only default to the rhs's target if it is consistent with the device type if (rhs->target.defined() && rhs->target->kind->device_type == defaulted_device_type) { defaulted_target = rhs->target; } // else: leave as null } MemoryScope defaulted_memory_scope; if (!lhs->memory_scope.empty()) { defaulted_memory_scope = lhs->memory_scope; } else { defaulted_memory_scope = rhs->memory_scope; } return VirtualDevice(defaulted_device_type, defaulted_virtual_device_id, defaulted_target, defaulted_memory_scope); } VirtualDevice VirtualDeviceCache::Make(DLDeviceType device_type, int virtual_device_id, Target target, MemoryScope memory_scope) { VirtualDevice prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); auto itr = cache_.find(prototype); if (itr == cache_.end()) { cache_.emplace(prototype); return prototype; } else { ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); if (prototype->target.defined()) { ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); } return *itr; } } VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) { return Make(virtual_device->device_type(), virtual_device->virtual_device_id, virtual_device->target, virtual_device->memory_scope); } TVM_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope") .set_body_typed(VirtualDevice::ForDeviceTargetAndMemoryScope); } // namespace tvm