/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tir/analysis/usmp/algo/greedy.cc
 * \brief This source contains greedy algorithms for planning
 * memory for USMP. There are two algorithms present here:
 * 1) greedy_by_size and 2) greedy_by_conflicts.
 *
 * greedy_by_size : this algorithm prioritizes placing the
 * largest buffers into the given pools first. The BufferInfo objects
 * are sorted by size and placed in the pools while adhering to
 * the size_hint constraints.
 *
 * greedy_by_conflicts : this algorithm prioritizes placing the
 * buffers with the most liveness conflicts into the given pools.
 * The BufferInfo objects are sorted by their number of conflicts
 * and placed in the pools while adhering to the size_hint constraints.
 */

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/usmp/utils.h>

namespace tvm {
namespace tir {
namespace usmp {
namespace algo {

/*!
 * \brief This is the base class for Greedy Algorithms where the sorting
 * is specialized in the extended classes based on the greedy criteria.
 */
class GreedyBase {
 public:
  GreedyBase() {}
  virtual ~GreedyBase() = default;
  /*!
   * \brief This function should be implemented by the extended classes to sort the BufferInfo
   * objects based on a criteria and then call PostSortAllocation.
   * \param buffer_info_arr The buffers to be planned.
   * \return A map from every BufferInfo to the pool and byte offset it was assigned.
   */
  virtual Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) = 0;

 protected:
  /*!
   * \brief Rounds up the offset to the next multiple of the given alignment.
   * \param non_aligned_byte_offset The offset to be aligned.
   * \param byte_alignment The required alignment in bytes. Values <= 1 impose no constraint
   * (this also avoids a division by zero for an alignment of 0).
   * \return The smallest multiple of byte_alignment that is >= non_aligned_byte_offset.
   */
  size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
                                    const int& byte_alignment) {
    if (byte_alignment <= 1) {
      return non_aligned_byte_offset;
    }
    const size_t alignment = static_cast<size_t>(byte_alignment);
    return ((non_aligned_byte_offset + alignment - 1) / alignment) * alignment;
  }

  /*!
   * \brief A helper function to check whether an offset is valid given the constraints.
   * \param candidate_pool The pool the buffer would be placed in.
   * \param next_offset The byte offset at which the buffer would start.
   * \param size_bytes The size of the buffer in bytes.
   * \return true if the pool is unbounded (size_hint_bytes == -1) or the buffer ends
   * within the pool's size hint.
   */
  bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
                        const size_t& size_bytes) {
    if (candidate_pool->size_hint_bytes == -1) {
      // this means pool is not bounded
      return true;
    }
    auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value);
    auto max_address = next_offset + size_bytes;
    return max_address <= pool_size;
  }

  /*!
   * \brief Selects a pool for placement in the given set of ordered pool candidates.
   * \param buf_info The buffer being placed; its pool_candidates define the preference order.
   * \param pool_offsets The pools that can still hold the buffer, with their offsets.
   * \return The first pool in buf_info->pool_candidates that is present in pool_offsets.
   * Fails with a CHECK error when no pool can hold the buffer.
   */
  PoolInfo SelectPlacementPool(
      const BufferInfo& buf_info,
      const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
    // Here the pool candidates are ordered when it is consumed by the algorithm.
    // This could be from order the user has specified. However, schedulers are
    // welcome to change the order for performance reasons.
    for (const auto& pool_info : buf_info->pool_candidates) {
      if (pool_offsets.count(pool_info)) {
        return pool_info;
      }
    }
    CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when "
                    "trying to allocate the buffer : "
                 << buf_info << "\n. Please increase the size_hints for memory pools.";
    return PoolInfo();
  }

  /*!
   * \brief This is the base allocation function that works on sorted BufferInfo objects based
   * on the greedy heuristic. The sorting algorithm has to be called before calling this.
   * \param buffer_info_vec The buffers in the order they should be placed.
   * \return A map from every BufferInfo to the pool and byte offset it was assigned.
   */
  Map<BufferInfo, PoolAllocation> PostSortAllocation(
      const std::vector<BufferInfo>& buffer_info_vec) {
    Map<BufferInfo, PoolAllocation> pool_allocations;
    for (const auto& buf_info : buffer_info_vec) {
      const size_t size_bytes = static_cast<size_t>(buf_info->size_bytes->value);
      std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
      for (const auto& pool_info : buf_info->pool_candidates) {
        // Mark pool candidates that satisfy the size constraints.
        if (IsValidPlacement(pool_info, 0, size_bytes)) {
          pool_offset_candidates[pool_info] = 0;
        }
      }

      for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
        auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
        // We only look at already allocated BufferInfo in-terms of conflicts.
        if (!pool_allocations.count(conflict_buf_info)) {
          continue;
        }
        auto pool_allocation = pool_allocations[conflict_buf_info];
        // Only pools that are still candidates matter. Indexing with operator[] here would
        // re-insert (with offset 0) a pool that was already rejected, either because it is too
        // small or because an earlier conflict pushed the buffer past its end. That would let
        // this buffer overlap that earlier conflicting buffer.
        auto candidate_it = pool_offset_candidates.find(pool_allocation->pool_info);
        if (candidate_it == pool_offset_candidates.end()) {
          continue;
        }
        size_t next_offset = static_cast<size_t>(pool_allocation->byte_offset->value) +
                             static_cast<size_t>(conflict_buf_info->size_bytes->value);
        // The start offset must honour the alignment of the buffer being placed.
        next_offset = round_up_to_byte_alignment(next_offset, buf_info->alignment->value);
        // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid.
        if (IsValidPlacement(pool_allocation->pool_info, next_offset, size_bytes)) {
          // There could be multiple conflicting BufferInfo in the same pool.
          // Thus, we need to make sure we pick the largest offset of them all.
          if (next_offset > candidate_it->second) {
            candidate_it->second = next_offset;
          }
        } else {
          pool_offset_candidates.erase(candidate_it);
        }
      }
      auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates);
      pool_allocations.Set(
          buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates.at(selected_pool))));
    }
    return pool_allocations;
  }
};

/*!
 * \brief This class implements Greedy by the size of BufferInfo
 * greedy algorithm. Please refer to main documentation of the file
 * for more details.
 */
class GreedySize : public GreedyBase {
 public:
  GreedySize() {}
  /*!
   * \brief Plans memory by placing larger buffers first.
   *
   * Ties on size are broken by the number of conflicts (more conflicts first) and then by
   * name_hint in descending lexicographic order, so the placement order is deterministic.
   * \param buffer_info_arr The buffers to be planned.
   * \return A map from every BufferInfo to its assigned pool and byte offset.
   */
  Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) override {
    std::vector<BufferInfo> buffer_info_vec;
    buffer_info_vec.reserve(buffer_info_arr.size());
    for (const auto& buffer_info : buffer_info_arr) {
      buffer_info_vec.push_back(buffer_info);
    }
    std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
              [](const BufferInfo& a, const BufferInfo& b) {
                // Compare the raw int64 values so the comparison is plainly numeric.
                if (a->size_bytes->value != b->size_bytes->value) {
                  return a->size_bytes->value > b->size_bytes->value;
                }
                if (a->conflicts.size() != b->conflicts.size()) {
                  return a->conflicts.size() > b->conflicts.size();
                }
                return std::string(a->name_hint->data) > std::string(b->name_hint->data);
              });
    return PostSortAllocation(buffer_info_vec);
  }
};

/*!
 * \brief This class implements Greedy by the number of conflicts of
 * BufferInfo greedy algorithm. Please refer to main documentation
 * of the file for more details.
 */
class GreedyConflicts : public GreedyBase {
 public:
  GreedyConflicts() {}
  /*!
   * \brief Plans memory by placing the buffers with the most liveness conflicts first.
   *
   * Ties on the conflict count are broken by size (larger first) and then by name_hint in
   * descending lexicographic order, so the placement order is deterministic.
   * \param buffer_info_arr The buffers to be planned.
   * \return A map from every BufferInfo to its assigned pool and byte offset.
   */
  Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) override {
    std::vector<BufferInfo> buffer_info_vec;
    buffer_info_vec.reserve(buffer_info_arr.size());
    for (const auto& buffer_info : buffer_info_arr) {
      buffer_info_vec.push_back(buffer_info);
    }
    std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
              [](const BufferInfo& a, const BufferInfo& b) {
                if (a->conflicts.size() != b->conflicts.size()) {
                  return a->conflicts.size() > b->conflicts.size();
                }
                if (a->size_bytes->value != b->size_bytes->value) {
                  return a->size_bytes->value > b->size_bytes->value;
                }
                return std::string(a->name_hint->data) > std::string(b->name_hint->data);
              });
    return PostSortAllocation(buffer_info_vec);
  }
};

/*!
 * \brief Entry point for the greedy-by-size memory planner.
 * \param buffer_info_arr The buffers to be planned.
 * \param memory_pressure Not used by this algorithm.
 * \return A map from every BufferInfo to its assigned pool and byte offset.
 */
Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr,
                                             const Integer& memory_pressure) {
  GreedySize planner;
  return planner.PlanMemory(buffer_info_arr);
}

/*!
 * \brief Entry point for the greedy-by-conflicts memory planner.
 * \param buffer_info_arr The buffers to be planned.
 * \param memory_pressure Not used by this algorithm.
 * \return A map from every BufferInfo to its assigned pool and byte offset.
 */
Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr,
                                                  const Integer& memory_pressure) {
  GreedyConflicts planner;
  return planner.PlanMemory(buffer_info_arr);
}

// Expose both greedy planners as global packed functions so they can be looked up by name.
TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
    .set_body_typed([](Array<BufferInfo> buffers,
                       Integer pressure) -> Map<BufferInfo, PoolAllocation> {
      return GreedyBySize(buffers, pressure);
    });

TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts")
    .set_body_typed([](Array<BufferInfo> buffers,
                       Integer pressure) -> Map<BufferInfo, PoolAllocation> {
      return GreedyByConflicts(buffers, pressure);
    });

}  // namespace algo
}  // namespace usmp
}  // namespace tir
}  // namespace tvm