/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #ifndef TVM_SUPPORT_ARRAY_H_ #define TVM_SUPPORT_ARRAY_H_ #include #include #include namespace tvm { namespace support { /*! * \brief Checks if two arrays contain the same objects * \tparam T The type of objects in the array * \param a The first array * \param b The second array * \return A boolean indicating if they are the same */ template inline bool ArrayWithSameContent(const Array& a, const Array& b) { if (a.size() != b.size()) { return false; } int n = a.size(); for (int i = 0; i < n; ++i) { if (!a[i].same_as(b[i])) { return false; } } return true; } /*! * \brief Checks if two arrays contain the same objects * \tparam T The type of objects in the array * \param a The first array * \param b The second array * \return A boolean indicating if they are the same */ template inline bool ArrayWithSameContent(const std::vector& a, const std::vector& b) { if (a.size() != b.size()) { return false; } int n = a.size(); for (int i = 0; i < n; ++i) { if (a[i] != b[i]) { return false; } } return true; } /*! * \brief Convert a tvm::runtime::Array to std::vector * \tparam TSrc The type of elements in the source Array * \tparam TDst The type of elements in the result vector * \return The result vector */ template inline std::vector AsVector(const Array& vec); /*! * \brief Convert a std::vector to tvm::runtime::Array * \tparam TSrc The type of elements in the source vector * \tparam TDst The type of elements in the result Array * \return The result vector */ template inline Array AsArray(const std::vector& vec); /*! * \brief Get the shape tuple as array * \param shape The shape tuple * \return An array of the shape tuple */ inline Array AsArray(const ShapeTuple& shape) { Array result; result.reserve(shape->size); for (ShapeTuple::index_type i : shape) { result.push_back(Integer(i)); } return result; } /********** Implementation details of AsVector **********/ namespace details { template struct AsVectorImpl {}; template struct AsVectorImpl { inline std::vector operator()(const Array& vec) const { return std::vector(vec.begin(), vec.end()); } }; template struct AsVectorImpl { inline std::vector operator()(const Array& vec) const { std::vector results; for (const TSrcObjectRef& x : vec) { const auto* n = x.template as(); ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); results.push_back(n->value); } return results; } }; template struct AsVectorImpl { inline std::vector operator()(const Array& vec) const { std::vector results; for (const TSrcObjectRef& x : vec) { const auto* n = x.template as(); ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); results.push_back(n->value); } return results; } }; template struct AsVectorImpl { inline std::vector operator()(const Array& array) const { std::vector results; for (const TSrcObjectRef& x : array) { const auto* n = x.template as(); ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); results.push_back(n->value); } return results; } }; } // namespace details /********** Implementation details of AsArray **********/ namespace details { template struct AsArrayImpl {}; template struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { return Array(vec.begin(), vec.end()); } }; template struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); for (int x : vec) { result.push_back(Integer(x)); } return result; } }; template struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); for (int64_t x : vec) { result.push_back(Integer(x)); } return result; } }; template struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); for (double x : vec) { result.push_back(FloatImm(tvm::DataType::Float(64), x)); } return result; } }; } // namespace details template inline std::vector AsVector(const Array& vec) { return details::AsVectorImpl()(vec); } template inline Array AsArray(const std::vector& vec) { return details::AsArrayImpl()(vec); } } // namespace support } // namespace tvm #endif // TVM_SUPPORT_ARRAY_H_