/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #include <random> #include "../utils.h" namespace tvm { namespace tir { struct PrimeTable { /*! \brief The table contains prime numbers in [2, kMaxPrime) */ static constexpr const int32_t kMaxPrime = 65536; /*! \brief The exact number of prime numbers in the table */ static constexpr const int32_t kNumPrimes = 6542; /*! * \brief For each number in [2, kMaxPrime), the index of its min factor. * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. */ int32_t min_factor_idx[kMaxPrime]; /*! \brief The prime numbers in [2, kMaxPrime) */ std::vector<int32_t> primes; /*! * \brief The power of each prime number. * pow_table[i, j] stores the result of pow(prime[i], j + 1) */ std::vector<std::vector<int32_t>> pow_tab; /*! \brief Get a global instance of the prime table */ static const PrimeTable* Global() { static const PrimeTable table; return &table; } /*! \brief Constructor, pre-computes all info in the prime table */ PrimeTable() { constexpr const int64_t int_max = std::numeric_limits<int32_t>::max(); // Euler's sieve: prime number in linear time for (int32_t i = 0; i < kMaxPrime; ++i) { min_factor_idx[i] = -1; } primes.reserve(kNumPrimes); for (int32_t x = 2; x < kMaxPrime; ++x) { if (min_factor_idx[x] == -1) { min_factor_idx[x] = primes.size(); primes.push_back(x); } for (size_t i = 0; i < primes.size(); ++i) { int64_t factor = primes[i]; int64_t y = x * factor; if (y >= kMaxPrime) { break; } min_factor_idx[y] = i; if (x % factor == 0) { break; } } } ICHECK_EQ(static_cast<int32_t>(primes.size()), static_cast<int32_t>(kNumPrimes)); // Calculate the power table for each prime number pow_tab.reserve(primes.size()); for (int32_t prime : primes) { std::vector<int32_t> tab; tab.reserve(32); for (int64_t pow = prime; pow <= int_max; pow *= prime) { tab.push_back(pow); } tab.shrink_to_fit(); pow_tab.emplace_back(std::move(tab)); } } /*! * \brief Factorize a number n, and return in a cryptic format * \param n The number to be factorized * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] * For each pair (i, j), we define * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) * (primes[i], j) if i != -1 * Then the factorization is * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) */ std::vector<std::pair<int32_t, int32_t>> Factorize(int32_t n) const { std::vector<std::pair<int32_t, int32_t>> result; result.reserve(16); int32_t i = 0, n_primes = primes.size(); // Phase 1: n >= kMaxPrime for (int32_t j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { } if (j != 0) { result.emplace_back(i, j); } } // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number if (n >= kMaxPrime) { result.emplace_back(-1, n); return result; } // Phase 2: n < kMaxPrime for (int32_t j; n > 1;) { int32_t i = min_factor_idx[n]; for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { } result.emplace_back(i, j); } return result; } }; int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive) { CHECK(min_inclusive < max_exclusive) << "ValueError: max_exclusive must be greater than min_inclusive."; if (min_inclusive + 1 == max_exclusive) { return min_inclusive; } support::LinearCongruentialEngine rand_(rand_state); std::uniform_int_distribution<int32_t> dist(min_inclusive, max_exclusive - 1); return dist(rand_); } std::vector<int32_t> SampleWithoutReplacement( support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k) { if (k == 1) { return {SampleInt(rand_state, 0, n)}; } if (k == 2) { int32_t result0 = SampleInt(rand_state, 0, n); int32_t result1 = SampleInt(rand_state, 0, n - 1); if (result1 >= result0) { result1 += 1; } return {result0, result1}; } std::vector<int32_t> order(n); for (int32_t i = 0; i < n; ++i) { order[i] = i; } for (int32_t i = 0; i < k; ++i) { int32_t j = SampleInt(rand_state, i, n); if (i != j) { std::swap(order[i], order[j]); } } return {order.begin(), order.begin() + k}; } int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array<Integer>& candidates, const Array<FloatImm>& probs, Optional<Integer>* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; int32_t i = -1; int32_t n = candidates.size(); if (decision->defined()) { const auto* int_imm = decision->as<IntImmNode>(); i = int_imm->value; CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { std::vector<double> weights = support::AsVector<FloatImm, double>(probs); std::discrete_distribution<int32_t> dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n << ", but decision is: " << i; } *decision = Integer(i); // decision is guaranteed not to be nullptr. return candidates[i]; } std::function<int32_t()> MakeMultinomialSampler( support::LinearCongruentialEngine::TRandState* rand_state, const std::vector<double>& weights) { ICHECK(!weights.empty()); std::vector<double> sums; sums.reserve(weights.size()); double sum = 0.0; for (double w : weights) { sums.push_back(sum += w); } return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(), dist = std::uniform_real_distribution<double>(0.0, sum), sums = std::move(sums)]() mutable -> int32_t { support::LinearCongruentialEngine rand_(&rng); double p = dist(rand_); int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); int32_t n = sums.size(); CHECK_LE(0, idx); CHECK_LE(idx, n); return (idx == n) ? (n - 1) : idx; }; } std::vector<int64_t> SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits) { CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; // Handle special case that we can potentially accelerate if (n_splits == 1) { return {extent}; } if (extent == 1) { return std::vector<int64_t>(n_splits, 1); } // Enumerate each pair (i, j), we define // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) // (primes[i], j) if i != -1 // Then the factorization is // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) const PrimeTable* prime_tab = PrimeTable::Global(); std::vector<std::pair<int32_t, int32_t>> factorized = prime_tab->Factorize(extent); if (n_splits == 2) { // n_splits = 2, this can be taken special care of, // because general reservoir sampling can be avoided to accelerate the sampling int32_t result0 = 1; int32_t result1 = 1; for (const std::pair<int32_t, int32_t>& ij : factorized) { // Case 1: (a, p) = (j, 1), where j is a prime number if (ij.first == -1) { (SampleInt(rand_state, 0, 2) ? result1 : result0) *= ij.second; continue; } // Case 2: (a = primes[i], p = 1) int32_t p = ij.second; const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; int32_t x1 = SampleInt(rand_state, 0, p + 1); int32_t x2 = p - x1; if (x1 != 0) { result0 *= pow[x1]; } if (x2 != 0) { result1 *= pow[x2]; } } return {result0, result1}; } // Data range: // 2 <= extent <= 2^31 - 1 // 3 <= n_splits <= max tiling splits // 1 <= p <= 31 std::vector<int64_t> result(n_splits, 1); for (const std::pair<int32_t, int32_t>& ij : factorized) { // Handle special cases to accelerate sampling // Case 1: (a, p) = (j, 1), where j is a prime number if (ij.first == -1) { result[SampleInt(rand_state, 0, n_splits)] *= ij.second; continue; } // Case 2: (a = primes[i], p = 1) int32_t p = ij.second; if (p == 1) { result[SampleInt(rand_state, 0, n_splits)] *= prime_tab->primes[ij.first]; continue; } // The general case. We have to sample uniformly from the solution of: // x_1 + x_2 + ... + x_{n_splits} = p // where x_i >= 0 // Data range: // 2 <= p <= 31 // 3 <= n_splits <= max tiling splits std::vector<int32_t> sampled = SampleWithoutReplacement(rand_state, p + n_splits - 1, n_splits - 1); std::sort(sampled.begin(), sampled.end()); sampled.push_back(p + n_splits - 1); const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; for (int32_t i = 0, last = -1; i < n_splits; ++i) { int32_t x = sampled[i] - last - 1; last = sampled[i]; if (x != 0) { result[i] *= pow[x]; } } } return result; } std::vector<int64_t> SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits, int32_t max_innermost_factor) { if (max_innermost_factor == -1) { return SamplePerfectTile(rand_state, extent, n_splits); } CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; std::vector<int32_t> innermost_candidates; innermost_candidates.reserve(max_innermost_factor); for (int32_t i = 1; i <= max_innermost_factor; ++i) { if (extent % i == 0) { innermost_candidates.push_back(i); } } // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. // We should do multiple factorization to weight the choices. However, it would lead to slower // sampling speed. On the other hand, considering potential tricks we might do on the innermost // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add // more heuristics in the future int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; std::vector<int64_t> result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1); result.push_back(innermost); return result; } std::vector<int64_t> SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional<Array<Integer>>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); int64_t extent = GetLoopIntExtent(loop); std::vector<int64_t> result; if (extent == -1) { // Case 1. Handle loops with non-constant length result = std::vector<int64_t>(n_splits, 1); result[0] = -1; } else if (decision->defined()) { // Case 2. Use previous decision result = support::AsVector<Integer, int64_t>(decision->value()); int n = result.size(); ICHECK_GE(n, 2); int64_t len = extent; for (int i = n - 1; i > 0; --i) { int64_t& l = result[i]; // A previous decision could become invalid because of the change of outer tiles // To handle this case properly, we check if the tiling strategy is still perfect. // If not, we use a trivial default solution (1, 1, ..., 1, L) for rest of the tiles if (len % l != 0) { l = len; } len /= l; } result[0] = len; } else { // Case 3. Use fresh new sampling result result = SamplePerfectTile(rand_state, extent, n_splits, max_innermost_factor); ICHECK_LE(result.back(), max_innermost_factor); } *decision = support::AsArray<int64_t, Integer>(result); return result; } /******** InstructionKind Registration ********/ struct SampleCategoricalTraits : public UnpackedInstTraits<SampleCategoricalTraits> { static constexpr const char* kName = "SampleCategorical"; static constexpr bool kIsPure = true; private: static constexpr size_t kNumInputs = 0; static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 1; static ExprRV UnpackedApplyToSchedule(Schedule sch, // Array<Integer> candidates, // Array<FloatImm> probs, // Optional<Integer> decision) { return sch->SampleCategorical(candidates, probs, decision); } static String UnpackedAsPython(Array<String> outputs, // Array<Integer> candidates, // Array<FloatImm> probs, // Optional<Integer> decision) { PythonAPICall py("sample_categorical"); py.Input("candidates", candidates); py.Input("probs", probs); py.Decision(decision); py.SingleOutput(outputs); return py.Str(); } template <typename> friend struct ::tvm::tir::UnpackedInstTraits; }; struct SamplePerfectTileTraits : public UnpackedInstTraits<SamplePerfectTileTraits> { static constexpr const char* kName = "SamplePerfectTile"; static constexpr bool kIsPure = true; private: static constexpr size_t kNumInputs = 1; static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 1; static Array<ExprRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, Integer max_innermost_factor, Optional<Array<Integer>> decision) { return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); } static String UnpackedAsPython(Array<String> outputs, String loop_rv, Integer n, Integer max_innermost_factor, Optional<Array<Integer>> decision) { PythonAPICall py("sample_perfect_tile"); py.Input("loop", loop_rv); py.Input("n", n->value); py.Input("max_innermost_factor", max_innermost_factor->value); py.Decision(decision); py.OutputList(outputs); return py.Str(); } template <typename> friend struct ::tvm::tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); } // namespace tir } // namespace tvm