/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/search_policy/sketch_policy_rules.h
 * \brief Rules for generating the sketches, sampling the initial population, and mutating the
 * population in SketchPolicy.
 */

#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_
#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_

#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_task.h>

#include <random>
#include <string>
#include <utility>
#include <vector>

#include "utils.h"

namespace tvm {
namespace auto_scheduler {

// Forward declaration: the rules below only refer to SketchPolicyNode through references and
// pointers, so the full definition is not needed here.
class SketchPolicyNode;

/********** Sketch Generation Rule **********/

/*! \brief The base class for derivation rules used in the sketch generation. */
class SketchGenerationRule {
 public:
  /*! \brief Result enumeration of the condition function. */
  enum class ConditionKind : int {
    /*! \brief Skip this rule and continue to try the next rules. */
    kSkip = 0,
    /*! \brief Apply this rule and continue to try the next rules. */
    kApply = 1,
    /*! \brief Apply this rule and skip the remaining rules. */
    kApplyAndSkipRest = 2
  };

  /*!
   * \brief Condition check function of this rule.
   * \param policy The SketchPolicyNode of this rule; some of its information may be used during
   * the condition checking.
   * \param state The original state to be checked.
   * \param stage_id The index of the stage to process this condition check.
   * \return The condition check result of this rule.
   */
  virtual ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
                                      int stage_id) const = 0;

  /*!
   * \brief Apply function of this rule.
   * \param policy The SketchPolicyNode of this rule; some of its information may be used during
   * the rule applying.
   * \param state The original state to apply this rule to.
   * \param stage_id The index of the stage to apply this rule to.
   * \return A list of (state after applying this rule, index of the next stage) pairs.
   */
  virtual std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,
                                                   const State& state, int stage_id) const = 0;

  /*!
   * \brief Get the name of this rule.
   * \return A string of the rule name.
   */
  virtual std::string GetRuleName() const = 0;

  /*!
   * \brief The destructor.
   * \note Virtual so that a derived rule (e.g. RuleCustomSketch, which owns PackedFunc and String
   * members) is destroyed correctly when deleted through a SketchGenerationRule pointer. This
   * matches PopulationGenerationRule. It is declared last so the existing virtual functions keep
   * their vtable slot order.
   */
  virtual ~SketchGenerationRule() = default;
};

// Declares a SketchGenerationRule subclass called `rule_name`. GetRuleName() returns the
// stringized class name; MeetCondition() and Apply() are only declared here and must be defined
// elsewhere.
#define DEFINE_SKETCH_GENERATION_RULE(rule_name)                                    \
  class rule_name : public SketchGenerationRule {                                   \
   public:                                                                          \
    std::string GetRuleName() const final { return #rule_name; }                    \
    ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state, \
                                int stage_id) const final;                          \
    std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,        \
                                             const State& state,                    \
                                             int stage_id) const final;             \
  };

/*! \brief The rule that simply skips the current stage. It returns an unchanged state and moves
 * to the next stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleSkipStage);

/*! \brief The rule that inlines simple elementwise ops.
 * \note This rule only inlines the strictly inlineable stages. Stages marked as not strictly
 * inlineable will have a chance to try different compute-at locations in InitPopulation later.
 */
DEFINE_SKETCH_GENERATION_RULE(RuleAlwaysInline);

/*! \brief The rule that performs multi-level tiling. */
DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTiling);

/*! \brief The rule that performs multi-level tiling and fuses later consumers. */
DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTilingWithFusion);

/*! \brief The rule that adds a cache read stage. Mainly used for GPU cooperative fetching.
 * Currently only supports one-to-one matching cache reads. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheRead);

/*! \brief The rule that adds a cache write stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheWrite);

/*! \brief The rule that adds an rfactor stage. */
DEFINE_SKETCH_GENERATION_RULE(RuleAddRfactor);

/*! \brief The rule that deals with compute ops that perform "fake reduction" with const tensors.
 * This kind of op comes from the winograd transformation. */
DEFINE_SKETCH_GENERATION_RULE(RuleSimplifyComputeWithConstTensor);

/*! \brief The rule that uses cross-thread reduction for GPU. */
DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);

/*! \brief Handles special cases in the Winograd transformation for GPU. We need to change the
 * compute location of the producers of compute ops that perform "fake reduction" with const
 * tensors. */
DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);

/*! \brief The rule that allows users to generate custom sketches. */
class RuleCustomSketch : public SketchGenerationRule {
 public:
  /*!
   * \brief Build a custom sketch rule from user-provided packed functions.
   * \param meet_condition_func The packed function stored for the condition check
   * (presumably invoked by MeetCondition; see its definition).
   * \param apply_func The packed function stored for applying the rule (presumably invoked by
   * Apply; see its definition).
   * \param rule_name The name returned by GetRuleName(). Defaults to "CustomSketchRule".
   */
  RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func,
                   String rule_name = "CustomSketchRule")
      : meet_condition_func_(std::move(meet_condition_func)),
        apply_func_(std::move(apply_func)),
        rule_name_(std::move(rule_name)) {}

  /*! \brief Returns the user-supplied rule name. */
  std::string GetRuleName() const final { return rule_name_; }

  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
                              int stage_id) const final;

  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
                                           int stage_id) const final;

 private:
  /*! \brief Packed function passed in as `meet_condition_func`. */
  PackedFunc meet_condition_func_;
  /*! \brief Packed function passed in as `apply_func`. */
  PackedFunc apply_func_;
  /*! \brief The name reported by GetRuleName(). */
  String rule_name_;
};

/********** Init Population **********/

/*! \brief The base class for rules used to annotate the sketches to get the initial population. */
class PopulationGenerationRule {
 public:
  /*! \brief Result enumeration of the apply function. */
  enum class ResultKind : int {
    /*! \brief A valid state was generated. */
    kValid = 0,
    /*! \brief No valid state was generated. */
    kInvalid = 1
  };

  /*!
   * \brief Apply this rule to a state.
   * \param policy The SketchPolicyNode of this rule; some of its members may be changed while
   * the rule is applied.
   * \param state The state to apply this rule to; it is updated in place.
   * \param rand_gen The random number generator available to this rule.
   * \return kValid if any valid state was generated, otherwise kInvalid.
   */
  virtual ResultKind Apply(SketchPolicyNode* policy, State* state,
                           std::mt19937* rand_gen) const = 0;

  /*! \brief The destructor; virtual so derived rules can be destroyed through a base pointer. */
  virtual ~PopulationGenerationRule() = default;
};

// A helper to define population initialization rules. It declares a class named `rule_name`
// that derives from PopulationGenerationRule and declares Apply(), which must be defined
// elsewhere.
#define DEFINE_INIT_POPULATION_RULE(rule_name)               \
  class rule_name : public PopulationGenerationRule {        \
   public:                                                   \
    ResultKind Apply(SketchPolicyNode* policy, State* state, \
                     std::mt19937* rand_gen) const final;    \
  };

/*! \brief The rule that fills the incomplete SplitSteps. */
DEFINE_INIT_POPULATION_RULE(InitFillTileSize);

/*! \brief The rule that randomly changes the computation location for some stages that do not
 * need tiling and are not strictly inlineable (e.g. data padding). */
DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation);

/*! \brief The rule that annotates parallel for CPU. */
DEFINE_INIT_POPULATION_RULE(InitParallel);

/*! \brief The rule that annotates unroll. */
DEFINE_INIT_POPULATION_RULE(InitUnroll);

/*! \brief The rule that annotates vectorization. */
DEFINE_INIT_POPULATION_RULE(InitVectorization);

/*! \brief The rule that annotates thread binding for GPU. */
DEFINE_INIT_POPULATION_RULE(InitThreadBind);

/********** Mutation **********/

/*! \brief The base class for mutation rules used in the evolutionary search. */
class PopulationMutationRule : public PopulationGenerationRule {
 public:
  /*!
   * \brief The constructor.
   * \param selection_weight The selection weight of this rule. The probability of applying this
   * rule is proportional to this weight.
   */
  explicit PopulationMutationRule(double selection_weight) : weight{selection_weight} {}

  /*!
   * \brief The selection weight of this rule. The probability of applying this rule is
   * proportional to this value.
   */
  double weight;
};

// A helper to define mutation rules used in the evolutionary search. It declares a class named
// `rule_name` that derives from PopulationMutationRule, forwards its selection weight to the
// base constructor, and declares Apply(), which must be defined elsewhere.
#define DEFINE_MUTATE_POPULATION_RULE(rule_name)             \
  class rule_name : public PopulationMutationRule {          \
   public:                                                   \
    explicit rule_name(double selection_weight)              \
        : PopulationMutationRule(selection_weight) {}        \
    ResultKind Apply(SketchPolicyNode* policy, State* state, \
                     std::mt19937* rand_gen) const final;    \
  };

/*! \brief The rule that mutates tile sizes by randomly dividing a tile size by a factor
 * and multiplying another tile size by that factor. */
DEFINE_MUTATE_POPULATION_RULE(MutateTileSize);

/*! \brief The rule that mutates the number of fused outer iterators annotated by parallel. */
DEFINE_MUTATE_POPULATION_RULE(MutateParallel);

/*! \brief The rule that randomly changes the computation location for some stages that do not
 * need tiling and are not strictly inlineable (e.g. data padding). */
DEFINE_MUTATE_POPULATION_RULE(MutateComputeLocation);

/*! \brief The rule that mutates the value of a randomly selected auto unroll pragma step. */
DEFINE_MUTATE_POPULATION_RULE(MutateAutoUnroll);

}  // namespace auto_scheduler
}  // namespace tvm

#endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_