/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file simplify_fc_transpose.cc
 *
 * \brief Mutate ```y = nn.dense(x, transpose(w, [1, 0]))``` to
 *        ```y = nn.dense(x, wt)```
 */
#include <tvm/ir/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace relay {

// Find name of weight in ```y = nn.dense(x, tranpose(w, [1, 0]))```
class FCTransposeVisitor : private ExprVisitor {
 public:
  FCTransposeVisitor() : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {}

  Array<String> Search(const Expr& expr) {
    VisitExpr(expr);
    return memo_;
  }

 private:
  void VisitExpr_(const CallNode* n) final {
    if (n->op == dense_op_) {
      const auto weight = n->args[1].as<CallNode>();
      if (weight) {
        if (weight->op == transpose_op_) {
          if (weight->args[0].as<VarNode>()) {
            const auto arg = weight->args[0].as<VarNode>();
            memo_.push_back(arg->name_hint());
          }
        }
      }
    }
    for (const auto& arg : n->args) {
      VisitExpr(arg);
    }
  }

  const Op& dense_op_;
  const Op& transpose_op_;
  Array<String> memo_;
};  // SearchDenseOpWeight

// Returns the name hints of all weight variables in `e` that feed an `nn.dense` call through a
// `transpose` call, as collected by FCTransposeVisitor.
Array<String> SearchFCTranspose(const Expr& e) {
  FCTransposeVisitor visitor;
  return visitor.Search(e);
}

// Register SearchFCTranspose in the global function registry under the name
// "relay.analysis.search_fc_transpose" (presumably so language frontends can look it up by
// name -- TODO confirm against the callers).
TVM_REGISTER_GLOBAL("relay.analysis.search_fc_transpose").set_body_typed(SearchFCTranspose);

// Mutate ```y = nn.dense(x, tranpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
class FCTransposeMutator : public ExprRewriter {
 public:
  explicit FCTransposeMutator(const Array<ObjectRef>& target_weights)
      : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {
    for (size_t i = 0; i < target_weights.size(); ++i) {
      ICHECK(target_weights[i]->IsInstance<runtime::StringObj>());
      std::string k = target_weights[i].as<runtime::StringObj>()->data;
      target_weights_.emplace(k);
    }
  }

  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
    if (pre->op == dense_op_) {
      const auto data = post.as<CallNode>()->args[0];
      const auto weight = pre->args[1].as<CallNode>();
      if (weight) {
        if (weight->op == transpose_op_) {
          const auto arg = weight->args[0];
          if (arg.as<VarNode>()) {
            const auto& arg_node = arg.as<VarNode>();
            ICHECK_GT(target_weights_.count(arg_node->name_hint()), 0);
            const auto& tt = arg_node->type_annotation.as<TensorTypeNode>();
            auto wt_type = TensorType({tt->shape[1], tt->shape[0]}, tt->dtype);
            Var wt(arg_node->name_hint() + ".T", wt_type);
            return Call(dense_op_, {data, wt}, pre->attrs, pre->type_args);
          }
        }
      }
    }
    return post;
  }

 private:
  // Cached op
  const Op& dense_op_;
  const Op& transpose_op_;
  std::unordered_set<std::string> target_weights_;
};  // class DenseToSparseDenseAlter

// Applies FCTransposeMutator to `e` in post-order and returns the rewritten expression.
// `target_weights` holds the names of the weight variables that are expected to be rewritten.
Expr SimplifyFCTranspose(const Expr& e, const Array<ObjectRef>& target_weights) {
  FCTransposeMutator mutator(target_weights);
  return PostOrderRewrite(e, &mutator);
}

namespace transform {

// Builds the function-level pass that rewrites ```nn.dense(x, transpose(w))``` for the weights
// named in `target_weights` and then rebuilds the function's parameter list so that the newly
// introduced variables are parameters rather than free variables.
Pass SimplifyFCTranspose(const Array<ObjectRef>& target_weights) {
  using PassFunc = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>;
  PassFunc pass_func = [=](Function f, IRModule m, PassContext pc) {
    // Step 1: run the rewrite. Variables it introduces are not parameters of the result, so
    // they show up as its free variables.
    const Function rewritten = Downcast<Function>(SimplifyFCTranspose(f, target_weights));
    const Array<Var> transposed_weights = FreeVars(rewritten);

    // Step 2: make those variables the only parameters. Whatever is free afterwards is the set
    // of remaining variables the body still refers to.
    const Function weights_only(transposed_weights, rewritten->body, rewritten->ret_type,
                                rewritten->type_params, rewritten->attrs);
    Array<Var> all_params = FreeVars(weights_only);

    // Step 3: final parameter list = remaining variables followed by the introduced ones.
    // This avoids the free-variable warning on the resulting function.
    for (const Var& weight : transposed_weights) {
      all_params.push_back(weight);
    }
    return Function(all_params, weights_only->body, weights_only->ret_type,
                    weights_only->type_params, weights_only->attrs);
  };
  return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"});
}

// Register the pass factory in the global function registry under the name
// "relay._transform.SimplifyFCTranspose". Inside this namespace the unqualified name
// `SimplifyFCTranspose` refers to the Pass-returning overload declared above.
TVM_REGISTER_GLOBAL("relay._transform.SimplifyFCTranspose").set_body_typed(SimplifyFCTranspose);

}  // namespace transform

}  // namespace relay
}  // namespace tvm