/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Collect the srefs of all blocks whose `name_hint` equals `name` in one function.
 *
 * The function is fetched from `self->mod` by `func_name` and obtained as a PrimFuncNode via
 * TVM_TYPE_AS. Its body is walked with a StmtVisitor; the walk keeps descending after a
 * match, so blocks nested inside a matching block are examined as well.
 *
 * \param self The schedule state providing the module and the stmt2ref table
 * \param name The name compared against each visited block's `name_hint`
 * \param func_name The key used to look the function up in `self->mod`
 * \return The srefs of the matching blocks, in the order the visitor encounters them
 */
Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const String& func_name) {
  struct NamedBlockFinder : public StmtVisitor {
    explicit NamedBlockFinder(const ScheduleState& self, const String& wanted)
        : self_(self), wanted_(wanted) {}

    void VisitStmt_(const BlockNode* block) override {
      const bool is_match = block->name_hint == wanted_;
      if (is_match) {
        // A matching block must already have an sref recorded in the state.
        auto it = self_->stmt2ref.find(block);
        ICHECK(it != self_->stmt2ref.end());
        matched_.push_back(it->second);
      }
      // Descend regardless of the match so that nested blocks are inspected too.
      StmtVisitor::VisitStmt_(block);
    }

    /*! \brief The schedule state whose stmt2ref table is queried */
    const ScheduleState& self_;
    /*! \brief The block name being searched for */
    const String& wanted_;
    /*! \brief Srefs of the blocks matched so far */
    Array<StmtSRef> matched_;
  };

  BaseFunc func = self->mod->Lookup(func_name);
  const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode);
  NamedBlockFinder finder(self, name);
  finder(prim_func->body);
  return std::move(finder.matched_);
}

/*!
 * \brief Gather the chain of loop srefs directly enclosing the given sref.
 *
 * Starting from `block_sref->parent`, parent links are followed for as long as the current
 * ancestor's statement is a ForNode; the walk stops at the first ancestor that is not a loop
 * (or when there is no parent). Ancestors are collected nearest-first and returned in the
 * reverse order, i.e. the farthest collected ancestor comes first.
 *
 * \param block_sref The sref whose enclosing loops are queried
 * \return The loop srefs, farthest ancestor first
 */
Array<StmtSRef> GetLoops(const StmtSRef& block_sref) {
  std::vector<StmtSRef> nearest_first;
  StmtSRefNode* cursor = block_sref->parent;
  while (cursor != nullptr && cursor->stmt->IsInstance<ForNode>()) {
    nearest_first.push_back(GetRef<StmtSRef>(cursor));
    cursor = cursor->parent;
  }
  // Reverse iterators flip the order so the farthest ancestor is emitted first.
  return Array<StmtSRef>(nearest_first.rbegin(), nearest_first.rend());
}

/*!
 * \brief Collect the srefs of the blocks found in the body of a loop or block.
 *
 * The body of the statement behind `parent_sref` is visited. Whenever a block is reached its
 * sref is recorded and the visitor does not descend into that block, so blocks nested inside
 * a collected block are not reported. If the statement is neither a ForNode nor a BlockNode,
 * nothing is visited and the result is empty.
 *
 * \param self The schedule state providing the stmt2ref table
 * \param parent_sref The sref of the loop or block whose body is searched
 * \return The srefs of the collected blocks, in visiting order
 */
Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) {
  struct ShallowBlockCollector : public StmtVisitor {
    explicit ShallowBlockCollector(const ScheduleState& self) : self(self) {}

    /*! \brief The schedule state whose stmt2ref table is queried */
    const ScheduleState& self;
    /*! \brief Srefs of the blocks collected so far */
    Array<StmtSRef> result;

   private:
    // Deliberately does not call the base visitor: the walk stops at each block it reaches.
    void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); }
  };

  ShallowBlockCollector collector(self);
  const auto* stmt = parent_sref->stmt;
  if (stmt->IsInstance<ForNode>()) {
    collector(static_cast<const ForNode*>(stmt)->body);
  } else if (stmt->IsInstance<BlockNode>()) {
    collector(static_cast<const BlockNode*>(stmt)->body);
  }
  return std::move(collector.result);
}

/*!
 * \brief Get the source blocks of the RAW/WAW dependencies that end at the given block.
 *
 * The scope root of `block_sref` is resolved first; the dependencies whose destination is
 * `block_sref` are then read from that scope, and the `src` of every edge of kind kRAW or
 * kWAW is returned. Edges of any other kind are skipped. No de-duplication is performed.
 *
 * \param self The schedule state
 * \param block_sref The sref of the block whose producers are queried
 * \return The `src` srefs of the qualifying edges, in the order the edges are stored
 */
Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
  // NOTE(review): the name of GetScopeRoot's fourth parameter is not visible in this file;
  // confirm it against the declaration before labelling the second flag.
  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, false);
  Array<Dependency> incoming = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref);
  Array<StmtSRef> producers;
  producers.reserve(incoming.size());
  for (const Dependency& dep : incoming) {
    const auto kind = dep->kind;
    if (kind != DepKind::kRAW && kind != DepKind::kWAW) {
      continue;
    }
    producers.push_back(dep->src);
  }
  return producers;
}

/*!
 * \brief Get the destination blocks of the RAW/WAW dependencies that start at the given block.
 *
 * The scope root of `block_sref` is resolved first; the dependencies whose source is
 * `block_sref` are then read from that scope, and the `dst` of every edge of kind kRAW or
 * kWAW is returned. Edges of any other kind are skipped. No de-duplication is performed.
 *
 * \param self The schedule state
 * \param block_sref The sref of the block whose consumers are queried
 * \return The `dst` srefs of the qualifying edges, in the order the edges are stored
 */
Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
  // NOTE(review): the name of GetScopeRoot's fourth parameter is not visible in this file;
  // confirm it against the declaration before labelling the second flag.
  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, false);
  Array<Dependency> outgoing = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref);
  Array<StmtSRef> consumers;
  consumers.reserve(outgoing.size());
  for (const Dependency& dep : outgoing) {
    const auto kind = dep->kind;
    if (kind != DepKind::kRAW && kind != DepKind::kWAW) {
      continue;
    }
    consumers.push_back(dep->dst);
  }
  return consumers;
}

/******** InstructionKind Registration ********/

/*!
 * \brief Traits of the "GetBlock" instruction kind.
 *
 * Declares zero inputs, two attributes and zero decisions; the two attributes reach the
 * unpacked callbacks as `name` and `func_name`. The instruction is flagged as pure.
 */
struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
  static constexpr const char* kName = "GetBlock";
  static constexpr bool kIsPure = true;

 private:
  static constexpr size_t kNumInputs = 0;
  static constexpr size_t kNumAttrs = 2;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction by forwarding to Schedule::GetBlock. */
  static BlockRV UnpackedApplyToSchedule(Schedule sch, String name, String func_name) {
    BlockRV block = sch->GetBlock(name, func_name);
    return block;
  }

  /*! \brief Describe the instruction as a `get_block` call with a single output. */
  static String UnpackedAsPython(Array<String> outputs, String name, String func_name) {
    PythonAPICall call("get_block");
    call.Input("name", name);
    call.Input("func_name", func_name);
    call.SingleOutput(outputs);
    return call.Str();
  }

  // The CRTP base needs access to the private constants and callbacks above.
  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

/*!
 * \brief Traits of the "GetLoops" instruction kind.
 *
 * Declares one input (the block random variable), no attributes and no decisions. The
 * instruction is flagged as pure.
 */
struct GetLoopsTraits : public UnpackedInstTraits<GetLoopsTraits> {
  static constexpr const char* kName = "GetLoops";
  static constexpr bool kIsPure = true;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction by forwarding to Schedule::GetLoops. */
  static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
    Array<LoopRV> loops = sch->GetLoops(block_rv);
    return loops;
  }

  /*! \brief Describe the instruction as a `get_loops` call with a list of outputs. */
  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
    PythonAPICall call("get_loops");
    call.Input("block", block_rv);
    call.OutputList(outputs);
    return call.Str();
  }

  // The CRTP base needs access to the private constants and callbacks above.
  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

/*!
 * \brief Traits of the "GetChildBlocks" instruction kind.
 *
 * Declares one input, no attributes and no decisions. The single input may be either a
 * BlockRV or a LoopRV and is dispatched on its runtime type. The instruction is flagged as
 * pure.
 */
struct GetChildBlocksTraits : public UnpackedInstTraits<GetChildBlocksTraits> {
  static constexpr const char* kName = "GetChildBlocks";
  static constexpr bool kIsPure = true;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*!
   * \brief Apply the instruction by forwarding to the matching Schedule::GetChildBlocks
   * overload; any input that is neither a BlockRV nor a LoopRV is reported via LOG(FATAL).
   */
  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) {
    const auto* as_block = block_or_loop_rv.as<BlockRVNode>();
    if (as_block != nullptr) {
      return sch->GetChildBlocks(GetRef<BlockRV>(as_block));
    }
    const auto* as_loop = block_or_loop_rv.as<LoopRVNode>();
    if (as_loop != nullptr) {
      return sch->GetChildBlocks(GetRef<LoopRV>(as_loop));
    }
    LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey();
    // Reached only if LOG(FATAL) returns; presumably it never does.
    throw;
  }

  /*!
   * \brief Describe the instruction as a `get_child_blocks` call with a list of outputs.
   * The input is recorded under an empty argument name.
   */
  static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv) {
    PythonAPICall call("get_child_blocks");
    call.Input("", block_or_loop_rv);
    call.OutputList(outputs);
    return call.Str();
  }

  // The CRTP base needs access to the private constants and callbacks above.
  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

/*!
 * \brief Traits of the "GetProducers" instruction kind.
 *
 * Declares one input (the block random variable), no attributes and no decisions. The
 * instruction is flagged as pure.
 */
struct GetProducersTraits : public UnpackedInstTraits<GetProducersTraits> {
  static constexpr const char* kName = "GetProducers";
  static constexpr bool kIsPure = true;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction by forwarding to Schedule::GetProducers. */
  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
    Array<BlockRV> producers = sch->GetProducers(block_rv);
    return producers;
  }

  /*! \brief Describe the instruction as a `get_producers` call with a list of outputs. */
  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
    PythonAPICall call("get_producers");
    call.Input("block", block_rv);
    call.OutputList(outputs);
    return call.Str();
  }

  // The CRTP base needs access to the private constants and callbacks above.
  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

/*!
 * \brief Traits of the "GetConsumers" instruction kind.
 *
 * Declares one input (the block random variable), no attributes and no decisions. The
 * instruction is flagged as pure.
 */
struct GetConsumersTraits : public UnpackedInstTraits<GetConsumersTraits> {
  static constexpr const char* kName = "GetConsumers";
  static constexpr bool kIsPure = true;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction by forwarding to Schedule::GetConsumers. */
  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
    Array<BlockRV> consumers = sch->GetConsumers(block_rv);
    return consumers;
  }

  /*! \brief Describe the instruction as a `get_consumers` call with a list of outputs. */
  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
    PythonAPICall call("get_consumers");
    call.Input("block", block_rv);
    call.OutputList(outputs);
    return call.Str();
  }

  // The CRTP base needs access to the private constants and callbacks above.
  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

// One TVM_REGISTER_INST_KIND_TRAITS invocation per traits struct defined in this file.
// Each struct carries its own instruction name in `kName`.
TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits);

}  // namespace tir
}  // namespace tvm