/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/utils.h
 * \brief Common utilities shared by the auto-scheduler: container helpers, string helpers,
 *        TVM expression helpers and verbose-logging helpers.
 */
#ifndef TVM_AUTO_SCHEDULER_UTILS_H_
#define TVM_AUTO_SCHEDULER_UTILS_H_

// NOTE(review): every #include directive below has lost its header name in this copy of the
// file, so the directives are ill-formed as written. Restore the header list from upstream.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace std {

/*!
 * \brief Hash function for std::pair: the hashes of `first` and `second` are combined with
 *        ::dmlc::HashCombine.
 * NOTE(review): the template header and the std::pair template arguments of this
 * specialization are missing in this copy; restore them from upstream.
 */
template struct hash> {
  std::size_t operator()(const std::pair& k) const {
    return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second));
  }
};

/*!
 * \brief Hash function for std::tuple: the hashes of elements 0, 1 and 2 are combined with
 *        nested ::dmlc::HashCombine calls (elements 0 and 1 first, then element 2).
 * NOTE(review): the template header and the std::tuple template arguments of this
 * specialization are missing in this copy; restore them from upstream.
 */
template struct hash> {
  std::size_t operator()(const std::tuple& k) const {
    return ::dmlc::HashCombine(
        ::dmlc::HashCombine(std::hash()(std::get<0>(k)), std::hash()(std::get<1>(k))),
        std::hash()(std::get<2>(k)));
  }
};

}  // namespace std

namespace tvm {
namespace auto_scheduler {

/********** Utilities for Array, std::vector, std::string **********/

/*!
\brief Get the first appearance index of elements in an Array */ template inline void GetIndices(const Array& array, const Array& to_locate, Array* indices) { for (const auto& v : to_locate) { auto it = std::find(array.begin(), array.end(), v); if (it != array.end()) { indices->push_back(it - array.begin()); } else { LOG(FATAL) << "Cannot find the item"; } } } /*! \brief Get the first appearance index of an element in an Array */ template inline int GetIndex(const Array& array, const T& to_locate) { for (size_t i = 0; i < array.size(); ++i) { if (array[i] == to_locate) { return i; } } LOG(FATAL) << "Cannot find the item"; return -1; } /*! \brief Delete the item in a std::vector if it exists. */ template inline void FindAndDeleteItem(std::vector* array, const T& to_delete) { auto iter = std::find(array->begin(), array->end(), to_delete); if (iter != array->end()) { array->erase(iter); } } /*! \brief Compute the product of all elements in a vector */ inline int64_t ElementProduct(const std::vector& array) { int64_t ret = 1; for (auto x : array) { ret *= x; } return ret; } /*! \brief Move elements from multiple vectors to one vector */ template std::vector& ConcatenateMove(std::vector* out, std::vector* in) { out->insert(out->end(), std::make_move_iterator(in->begin()), std::make_move_iterator(in->end())); return *out; } /*! \brief Move elements from multiple vectors to one vector */ template std::vector& ConcatenateMove(std::vector* out, std::vector* first, Args... args) { ConcatenateMove(out, first); ConcatenateMove(out, args...); return *out; } /*! \brief Get a random permutation of integers [0, n-1] */ template void RandomPermutation(int n, std::vector* out, G* gen) { out->assign(n, 0); std::iota(out->begin(), out->end(), 0); std::shuffle(out->begin(), out->end(), *gen); } /*! 
\brief Replace a sub-string to another sub-string in a string */ inline void StrReplace(std::string* base, const std::string& from, const std::string& to) { auto pos = base->find(from); while (pos != std::string::npos) { base->replace(pos, from.size(), to); pos = base->find(from, pos + to.size()); } } /*! \brief Return whether two int arrays are elementwise-equal */ inline bool IntArrayEqual(const Array& arr1, const Array& arr2) { if (arr1.size() != arr2.size()) { return false; } for (size_t i = 0; i < arr1.size(); ++i) { auto int1 = arr1[i].as(); auto int2 = arr2[i].as(); ICHECK(int1 != nullptr); ICHECK(int2 != nullptr); if (int1->value != int2->value) { return false; } } return true; } /********** Utilities for TVM Containers / ByteArray **********/ /*! \brief Compute mean of a FloatImm array */ inline double FloatArrayMean(const Array& float_array) { double sum = 0; if (float_array.empty()) { return 0.0; } for (const auto& x : float_array) { auto floatimm = x.as(); ICHECK(floatimm != nullptr); sum += floatimm->value; } return sum / float_array.size(); } /*! \brief Return whether a string starts with another substring */ inline bool StrStartsWith(const String& a, const String& b) { if (b.size() > a.size()) return false; return std::equal(a.c_str(), a.c_str() + b.size(), b.c_str()); } /*! \brief Return whether a string ends with another substring */ inline bool StrEndsWith(const String& a, const String& b) { if (b.size() > a.size()) return false; return std::equal(a.c_str() + a.size() - b.size(), a.c_str() + a.size(), b.c_str()); } /********** Other Utilities **********/ /*! \brief Get an int value from an Expr */ inline int64_t GetIntImm(const PrimExpr& expr) { auto pint = expr.as(); ICHECK(pint != nullptr) << "Expect an IntImm but get " << expr; return pint->value; } /*! 
\brief Compute the product of the constant extents of \p axes.
 * \return The product (1 for an empty array), or -1 if any extent is not an IntImm.
 */
inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
  // Integer literals: both the accumulator and the -1 sentinel are int64_t.
  int64_t ret = 1;
  for (const auto& x : axes) {
    if (const IntImmNode* imm = x->dom->extent.as<IntImmNode>()) {
      ret *= imm->value;
    } else {
      return -1;
    }
  }
  return ret;
}

/*!
 * \brief Clean the name of an iterator or an op to make it valid in python code.
 * \param str The original name.
 * \param prefix The name prefix to differentiate the same name (e.g., the same iterator names).
 * \return The cleaned name: '.' and '@' become '_', "outer" becomes "o" and "inner" becomes
 *         "i"; when \p prefix is non-empty the result is prefix + "_" + cleaned name.
 */
inline std::string CleanName(const std::string& str, const std::string& prefix = "") {
  std::string ret = str;
  StrReplace(&ret, ".", "_");
  StrReplace(&ret, "@", "_");
  StrReplace(&ret, "outer", "o");
  StrReplace(&ret, "inner", "i");
  if (prefix != "") {
    return prefix + "_" + ret;
  }
  return ret;
}

/*!
 * \brief An output stream with no stream buffer, so anything written to it is discarded.
 *        Copies are also constructed without a buffer.
 */
class NullStream : public std::ostream {
 public:
  NullStream() : std::ostream(nullptr) {}
  NullStream(const NullStream&) : std::ostream(nullptr) {}
  static NullStream& Global();
};

/*! \brief Swallow any value written to a NullStream and return the stream for chaining. */
template <class T>
NullStream& operator<<(NullStream& os, const T& /*value*/) {
  return os;
}

/*!
 * \brief Get std::cout with verbose control.
 * \return std::cout when \p verbose >= \p setting, otherwise NullStream::Global().
 */
inline std::ostream& StdCout(int verbose, int setting = 1) {
  return verbose >= setting ? std::cout : NullStream::Global();
}

/*! \brief Return a string made of \p str repeated \p times times (empty if times <= 0). */
inline std::string Chars(const char& str, int times) {
  std::stringstream ret;
  for (int i = 0; i < times; ++i) {
    ret << str;
  }
  return ret.str();
}

/*!
 * \brief Print the time elapsed since \p t_begin, in seconds with two decimals, through
 *        StdCout(verbose).
 */
inline void PrintTimeElapsed(std::chrono::time_point<std::chrono::high_resolution_clock> t_begin,
                             const std::string& info, int verbose) {
  double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
                        std::chrono::high_resolution_clock::now() - t_begin)
                        .count();
  StdCout(verbose) << "Time elapsed for " << info << ": " << std::fixed << std::setprecision(2)
                   << duration << " s" << std::endl;
}

/*!
* \brief Parse shape and axis names from layout string */ inline void ParseKernelLayout(const String& layout, Array* shape, std::vector* axes) { int32_t factor = 0; std::string axis = ""; for (char c : std::string(layout)) { if (c >= 'A' && c <= 'z') { axis += c; if (factor != 0) { shape->push_back(factor); factor = 0; } } else if (c >= '0' && c <= '9') { factor = factor * 10 + c - '0'; if (!axis.empty()) { axes->push_back(axis); axis = ""; } } else { LOG(FATAL) << "Invalid layout " << layout; } } if (!axis.empty()) { axes->push_back(axis); } } /*! \brief Get the base name before '_' of an axis */ inline std::string AxisBaseName(const std::string& str) { return str.substr(0, str.rfind("_")); } } // namespace auto_scheduler } // namespace tvm #endif // TVM_AUTO_SCHEDULER_UTILS_H_