/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Loop unrolling as in Halide pipeline.
 * \file unroll_loop.cc
 */
// Unrolls the loop as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ir_utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Attribute node holding the tunable options of the loop-unrolling pass.
 *
 * Every field is declared through TVM_ATTR_FIELD below, which attaches a
 * human-readable description and a default value to it.
 */
struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
  // Step-count threshold for automatic unrolling (default 0).
  int auto_max_step;
  // Maximum nesting level of loops that may be automatically unrolled (default 8).
  int auto_max_depth;
  // Maximum loop extent that will be unrolled (default 0).
  int auto_max_extent;
  // Flag: explicitly unroll the loop instead of setting a pragma.
  // NOTE: declared as int but defaulted with `true`; UnrollLoop() forwards it to the
  // `bool explicit_unroll` constructor parameter of LoopUnroller.
  int explicit_unroll;

  TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
    TVM_ATTR_FIELD(auto_max_step)
        .describe("Threshold of number of steps in the loop to be automatically unrolled")
        .set_default(0);
    TVM_ATTR_FIELD(auto_max_depth)
        .describe("The maximum nested level of loops that can be automatically unrolled.")
        .set_default(8);
    TVM_ATTR_FIELD(auto_max_extent)
        .describe("The maximum extent of loop that will be unrolled.")
        .set_default(0);
    TVM_ATTR_FIELD(explicit_unroll)
        .describe("Whether to explicitly unroll the loop instead of setting a pragma")
        .set_default(true);
  }
};

/*!
 * \brief Object reference class paired with UnrollLoopConfigNode.
 *
 * The reference methods are generated by the not-nullable object-ref macro below.
 */
class UnrollLoopConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnrollLoopConfig, Attrs, UnrollLoopConfigNode);
};

// Register the config node type, and declare UnrollLoopConfig as the type of the
// "tir.UnrollLoop" pass-config option (the key that transform::UnrollLoop() looks up).
TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

/*!
 * \brief Statement mutator that unrolls `For` loops.
 *
 * Loops are visited in post-order. While visiting, the mutator keeps three
 * counters describing what has been seen so far in the current scope:
 *  - step_count_: number of Store/Evaluate statements, multiplied by the
 *    extent of every enclosing loop that was selected for unrolling;
 *  - unroll_depth_: number of nested loops selected for unrolling;
 *  - normal_loop_depth_: number of nested loops that were NOT selected.
 *
 * A selected loop is either expanded in place (see Unroll) or re-tagged as
 * ForKind::kUnrolled, depending on explicit_unroll_.
 */
class LoopUnroller : public StmtExprMutator {
 public:
  /*!
   * \param auto_max_step Upper bound on (extent * step count) for a serial loop to be
   *        selected automatically.
   * \param auto_max_depth Upper bound on unroll_depth_ for a serial loop to be selected
   *        automatically.
   * \param auto_max_extent Upper bound on the loop extent that selects a serial loop
   *        regardless of its step count.
   * \param explicit_unroll When true, selected loops are expanded in place; otherwise
   *        they are only re-tagged as ForKind::kUnrolled.
   */
  explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent,
                        bool explicit_unroll)
      : auto_max_step_(auto_max_step),
        auto_max_depth_(auto_max_depth),
        auto_max_extent_(auto_max_extent),
        explicit_unroll_(explicit_unroll) {}

  /*!
   * \brief Apply scoped overrides carried by unroll pragmas.
   *
   * For "pragma_auto_unroll_max_step" and "pragma_unroll_explicit" the matching setting
   * is replaced by the attribute value while the body is visited and restored afterwards.
   * In both cases the visited body is returned directly, i.e. the AttrStmt wrapper is not
   * re-emitted. Any other attribute is handled by the base mutator.
   */
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == "pragma_auto_unroll_max_step") {
      int value = static_cast<int>(Downcast<Integer>(op->value)->value);
      // swap in the override, visit, then swap the previous value back
      std::swap(value, auto_max_step_);
      Stmt ret = this->VisitStmt(op->body);
      std::swap(value, auto_max_step_);
      return ret;
    } else if (op->attr_key == "pragma_unroll_explicit") {
      bool explicit_unroll = Downcast<Integer>(op->value)->value;
      std::swap(explicit_unroll, explicit_unroll_);
      Stmt ret = this->VisitStmt(op->body);
      std::swap(explicit_unroll, explicit_unroll_);
      return ret;
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  /*!
   * \brief Decide whether a loop is unrolled and, if so, how.
   *
   * The body is mutated first so that the counters describe the loop body when the
   * decision is taken.
   */
  Stmt VisitStmt_(const ForNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<ForNode>();
    int value = GetExtent(op);
    // Widen before multiplying: both operands are ints, so the product always fits in
    // int64_t, whereas an int multiplication could overflow (undefined behaviour) and
    // wrongly pass the auto_max_step_ test below.
    const int64_t total_steps = static_cast<int64_t>(value) * step_count_;
    // condition for auto unroll
    bool auto_unroll = (op->kind == ForKind::kSerial && value >= 0 && normal_loop_depth_ == 0 &&
                        unroll_depth_ <= auto_max_depth_);

    auto_unroll = auto_unroll && (total_steps <= auto_max_step_ || value <= auto_max_extent_);

    // loops already marked as unrolled are always selected, but need a constant extent
    if (op->kind == ForKind::kUnrolled) {
      ICHECK_GE(value, 0) << "Cannot unroll non-constant loop";
      auto_unroll = true;
    }

    if (auto_unroll) {
      // value >= 0 here, so total_steps is non-negative; clamp instead of overflowing
      step_count_ = ClampStepCount(total_steps);
      unroll_depth_ += 1;
    } else {
      normal_loop_depth_ += 1;
    }

    if ((auto_unroll && explicit_unroll_) ||
        // when auto_max_extent_ is exactly 1, loops with constant extent 0 or 1 are
        // expanded no matter how many steps are in the body (loop kind is not checked)
        (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
      return Unroll(op);
    } else {
      if (auto_unroll) {
        // not expanding explicitly: only re-tag the loop, unless it is tagged already
        if (op->kind != ForKind::kUnrolled) {
          return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body,
                     op->thread_binding, op->annotations);
        }
      }
      return stmt;
    }
  }

  /*! \brief Count one step for every Store statement. */
  Stmt VisitStmt_(const StoreNode* op) final {
    AddSteps(1);
    return StmtExprMutator::VisitStmt_(op);
  }

  /*! \brief Count one step for every Evaluate statement. */
  Stmt VisitStmt_(const EvaluateNode* op) final {
    AddSteps(1);
    return StmtExprMutator::VisitStmt_(op);
  }

  /*!
   * \brief Visit each statement of a sequence with fresh counters, then merge.
   *
   * Each child starts from zeroed counters. Afterwards the step counts are summed and
   * the two depth counters take the maximum of the child and the previous siblings.
   */
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    auto fmutate = [this](const Stmt& s) {
      int step_count = step_count_;
      int unroll_depth = unroll_depth_;
      int normal_loop_depth = normal_loop_depth_;
      step_count_ = 0;
      unroll_depth_ = 0;
      normal_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      AddSteps(step_count);
      normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
      unroll_depth_ = std::max(unroll_depth_, unroll_depth);
      return ret;
    };
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
  }

  /*!
   * \brief Expand a constant-extent loop into a flat sequence of its iterations.
   * \param op The loop to expand; its extent must simplify to a constant integer.
   * \return Evaluate(0) for a zero-extent loop; otherwise the flattened sequence of
   *         `extent` copies of the body with the loop variable replaced by min + i.
   */
  Stmt Unroll(const ForNode* op) {
    int value = GetExtent(op);
    // For loop must have a constant integer extent
    ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
    if (value == 0) return Evaluate(0);
    Stmt body = op->body;
    Map<Var, PrimExpr> vmap;
    Array<Stmt> unrolled;
    for (int i = 0; i < value; ++i) {
      vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i));
      Stmt step = Substitute(body, vmap);
      unrolled.push_back(step);
    }
    return SeqStmt::Flatten(unrolled);
  }

 private:
  /*!
   * \brief Extent of the loop if it simplifies to a constant that fits in `int`,
   *        otherwise -1.
   */
  int GetExtent(const ForNode* op) {
    // constant folding.
    PrimExpr extent = analyzer_.Simplify(op->extent);
    const IntImmNode* v1 = extent.as<IntImmNode>();
    int value = -1;
    // integers that do not fit in int are treated as symbolic, as it's impossible to
    // unroll such large loops; the lower bound is checked as well so that a very
    // negative extent cannot wrap around into a positive int
    if (v1 != nullptr && v1->value >= std::numeric_limits<int>::min() &&
        v1->value <= std::numeric_limits<int>::max()) {
      value = static_cast<int>(v1->value);
    }
    return value;
  }

  // Converts a 64-bit step count to int, saturating at INT_MAX.
  static int ClampStepCount(int64_t steps) {
    return static_cast<int>(std::min<int64_t>(steps, std::numeric_limits<int>::max()));
  }

  // Adds `delta` to step_count_ using 64-bit arithmetic and saturates the result,
  // so the counter can never overflow.
  void AddSteps(int delta) {
    step_count_ = ClampStepCount(static_cast<int64_t>(step_count_) + delta);
  }

  // maximum number of step to perform auto unroll.
  int auto_max_step_;
  // maximum value of unroll_depth_ at which a serial loop is still auto unrolled.
  int auto_max_depth_;
  // max extent of loop to auto unroll
  // this does not count the total steps; it is compared against the loop extent only
  int auto_max_extent_;
  // whether selected loops are expanded in place instead of being re-tagged
  bool explicit_unroll_;
  // Number of normal loops in scope
  int normal_loop_depth_{0};
  // number of unrolled cases in current scope.
  int unroll_depth_{0};
  // Number of total steps unrolled (saturates at INT_MAX)
  int step_count_{0};
  // analyzer
  arith::Analyzer analyzer_;
};

/*!
 * \brief Run LoopUnroller over a statement with the given configuration.
 * \param stmt The statement to transform.
 * \param cfg Unrolling options; its four fields are forwarded to LoopUnroller.
 * \return The mutator's result unchanged when it is the same object as the input;
 *         otherwise the result after ConvertSSA (presumably because replicated loop
 *         bodies can repeat variable definitions -- confirm against ConvertSSA).
 */
Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
  LoopUnroller unroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent,
                        cfg->explicit_unroll);
  Stmt unrolled = unroller(stmt);
  // nothing was rewritten: hand the result back as is
  if (unrolled.same_as(stmt)) {
    return unrolled;
  }
  return ConvertSSA(unrolled);
}

namespace transform {

/*!
 * \brief Create the "tir.UnrollLoop" PrimFunc pass.
 *
 * The pass reads the "tir.UnrollLoop" option from the pass context, falling back to
 * an UnrollLoopConfig with default attribute values when the option is not set, and
 * replaces the function body with the result of UnrollLoop(body, config).
 */
Pass UnrollLoop() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    // resolve the configuration first; it does not depend on the function
    auto unroll_cfg = ctx->GetConfig<UnrollLoopConfig>("tir.UnrollLoop");
    if (!unroll_cfg.defined()) {
      unroll_cfg = AttrsWithDefaultValues<UnrollLoopConfig>();
    }
    // obtain a writable node before replacing the body
    auto* func_node = f.CopyOnWrite();
    func_node->body = UnrollLoop(std::move(f->body), unroll_cfg.value());
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {});
}

// Expose the pass factory under the global name "tir.transform.UnrollLoop".
// Inside this namespace the unqualified name resolves to transform::UnrollLoop(),
// which hides the two-argument tir::UnrollLoop(Stmt, UnrollLoopConfig).
TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop);

}  // namespace transform

}  // namespace tir
}  // namespace tvm