/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef TVM_RUNTIME_THREAD_MAP_H_
#define TVM_RUNTIME_THREAD_MAP_H_

#include <functional>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace runtime {

/*! \brief Container to hold one value per thread
 *
 * Similar to thread_local, but intended for use as a non-static or
 * non-block variable, such as class member variables.  All member
 * functions are thread-safe to call.  If only the current thread's
 * value is accessed, no additional synchronization is required.  If
 * another thread's stored values are accessed, external
 * synchronization may be required.
 *
 * Calls that only require access to already-existing values will not
 * block each other.  Calls that require constructing a new value will
 * block any other calls.
 *
 * \tparam T The object type to be held.  For instantiation of
 * ThreadMap and for calls to ThreadMap::Get, only a forward
 * declaration is required.  For calls to ThreadMap::GetOrMake, a
 * full class definition is required.
 */
template <typename T>
class ThreadMap {
 public:
  ThreadMap() {}

  /*! \brief Return the current thread's stored object, if it exists.
   *
   * \return If it exists, a pointer to the stored object.  Otherwise,
   * returns nullptr.
   */
  const T* Get() const { return this->Get(std::this_thread::get_id()); }

  /*! \brief Return the stored object for a given thread, if it exists.
   *
   * Only a shared (reader) lock is taken, so concurrent lookups do
   * not block each other.
   *
   * \param id The thread whose object should be returned.
   *
   * \return If it exists, a pointer to the stored object.  Otherwise,
   * returns nullptr.
   */
  const T* Get(std::thread::id id) const {
    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
    auto res = values_.find(id);
    if (res == values_.end()) {
      return nullptr;
    } else {
      return res->second.get();
    }
  }

  /*! \brief Return the current thread's stored object, if it exists.
   *
   * \return If it exists, a pointer to the stored object.  Otherwise,
   * returns nullptr.
   */
  T* Get() {
    // Delegate to the const overload; casting the result back is safe
    // because *this is known to be non-const here.
    return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get());
  }

  /*! \brief Return the stored object for a given thread, if it exists.
   *
   * \param id The thread whose object should be returned.
   *
   * \return If it exists, a pointer to the stored object.  Otherwise,
   * returns nullptr.
   */
  T* Get(std::thread::id id) {
    // Delegate to the const overload; casting the result back is safe
    // because *this is known to be non-const here.
    return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get(id));
  }

  /*! \brief Return the current thread's stored object, making it if
   * necessary.
   *
   * Since this method can modify the stored map, there is no
   * const overload available.
   *
   * \tparam Params Types of the stored object's constructor arguments
   *
   * \param params Arguments to the stored object's constructor.  Only
   * used if the current thread does not currently exist in the map.
   *
   * \return A reference to the stored object
   */
  template <typename... Params>
  T& GetOrMake(Params&&... params) {
    return GetOrMake(std::this_thread::get_id(), std::forward<Params>(params)...);
  }

  /*! \brief Return the stored object for a given thread, making it if
   * necessary
   *
   * Since this method can modify the stored map, there is no
   * const overload available.
   *
   * \tparam Params Types of the stored object's constructor arguments
   *
   * \param id The thread whose object should be returned.
   *
   * \param params Arguments to the stored object's constructor.  Only
   * used if the specified thread does not currently exist in the map.
   *
   * \return A reference to the stored object
   */
  template <typename... Params>
  T& GetOrMake(std::thread::id id, Params&&... params) {
    // Try to get stored value first, which would only require shared
    // access.
    if (T* output = Get(id)) {
      return *output;
    }

    // Not in map, need exclusive lock to write
    std::unique_lock<std::shared_timed_mutex> lock(mutex_);

    // Check again, in case another thread got the unique lock first
    // and already constructed the object.
    auto res = values_.find(id);
    if (res != values_.end()) {
      return *res->second;
    }

    // No value exists, make one and return it.
    std::unique_ptr<T>& new_val = values_[id] =
        std::make_unique<T>(std::forward<Params>(params)...);
    return *new_val;
  }

  /*! \brief Clears all values held by the ThreadMap
   *
   * Calling Clear() invalidates any pointers/references previously
   * returned by Get/GetOrMake.
   *
   */
  void Clear() {
    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
    values_.clear();
  }

 private:
  //! \brief Mutex to protect values_
  mutable std::shared_timed_mutex mutex_;

  //! \brief Map containing stored values
  std::unordered_map<std::thread::id, std::unique_ptr<T>> values_;
};

}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_THREAD_MAP_H_