/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file relay/backend/util.cc
 * \brief Relay backend utilities.
 */

#include "utils.h"

#include <tvm/parser/parser.h>
#include <tvm/relay/qnn/transform.h>

#include "te_compiler.h"

namespace tvm {
namespace relay {
namespace backend {

TVM_REGISTER_NODE_TYPE(StorageInfoNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<StorageInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const auto* node = ref.as<StorageInfoNode>();
      p->stream << "StorageInfoNode("
                << "storage_ids=[";
      for (auto id : node->storage_ids) {
        p->stream << id << ",";
      }
      p->stream << "], virtual_devices=[";
      for (const auto& virtual_device : node->virtual_devices) {
        p->stream << virtual_device << ",";
      }
      p->stream << "], storage_size_in_bytes=[";
      for (auto bytes : node->storage_sizes_in_bytes) {
        p->stream << bytes << ",";
      }
      p->stream << "])";
    });

StorageInfo::StorageInfo(std::vector<int64_t> storage_ids,
                         std::vector<VirtualDevice> virtual_devices,
                         std::vector<int64_t> storage_sizes_in_bytes) {
  ICHECK_EQ(storage_ids.size(), virtual_devices.size());
  ICHECK_EQ(storage_ids.size(), storage_sizes_in_bytes.size());
  auto node = make_object<StorageInfoNode>();
  node->storage_ids = std::move(storage_ids);
  node->virtual_devices = std::move(virtual_devices);
  node->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes);
  data_ = std::move(node);
}

// This is the legacy interface for devices as DLDeviceTypes (represented by integers)
TVM_REGISTER_GLOBAL("relay.ir.StorageInfo")
    .set_body_typed([](const Array<Integer>& sids, const Array<Integer>& device_types,
                       const Array<Integer>& sizes_in_bytes) {
      std::vector<int64_t> sids_v;
      sids_v.reserve(sids.size());
      for (auto s : sids) {
        sids_v.push_back(s);
      }
      std::vector<VirtualDevice> virtual_devices_v;
      virtual_devices_v.reserve(device_types.size());
      for (const auto& device_type : device_types) {
        virtual_devices_v.emplace_back(VirtualDevice::ForDeviceType(device_type));
      }
      std::vector<int64_t> size_in_bytes_v;
      size_in_bytes_v.reserve(sizes_in_bytes.size());
      for (auto s : sizes_in_bytes) {
        size_in_bytes_v.push_back(s);
      }
      return StorageInfo(std::move(sids_v), std::move(virtual_devices_v),
                         std::move(size_in_bytes_v));
    });

TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageInfo si) {
  Array<tvm::Integer> ids;
  for (auto id : si->storage_ids) {
    ids.push_back(id);
  }
  return ids;
});

// This is the legacy interface for devices as DLDeviceTypes (represented by integers)
TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) {
  Array<tvm::Integer> device_types;
  for (const auto& virtual_device : si->virtual_devices) {
    device_types.push_back(virtual_device->device_type());
  }
  return device_types;
});

TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageSizes").set_body_typed([](StorageInfo si) {
  Array<tvm::Integer> storage_sizes_in_bytes;
  for (auto id : si->storage_sizes_in_bytes) {
    storage_sizes_in_bytes.push_back(id);
  }
  return storage_sizes_in_bytes;
});

TVM_REGISTER_NODE_TYPE(StaticMemoryPlanNode);

StaticMemoryPlan::StaticMemoryPlan(Map<Expr, StorageInfo> expr_to_storage_info) {
  auto n = make_object<StaticMemoryPlanNode>();
  n->expr_to_storage_info = std::move(expr_to_storage_info);
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan")
    .set_body_typed([](const Map<Expr, StorageInfo>& expr_to_storage_info) {
      return StaticMemoryPlan(expr_to_storage_info);
    });

// TODO(mbs): Cf GetMemorySizeBytes in aot_executor_codegen.cc, GetMemorySize in
// graph_plan_memory.cc
int64_t CalculateRelayExprSizeBytes(const Type& expr_type) {
  if (expr_type->IsInstance<TupleTypeNode>()) {
    auto tuple_type = Downcast<TupleType>(expr_type);
    int64_t size = 0;
    for (const auto& field : tuple_type->fields) {
      size += CalculateRelayExprSizeBytes(field);
    }
    return size;
  }
  auto tensor_type = expr_type.as<TensorTypeNode>();
  auto shape = tensor_type->shape;
  int num_of_elements = 1;
  for (const auto& dim_index_expr : shape) {
    if (dim_index_expr->IsInstance<IntImmNode>()) {
      num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
    } else {
      // If shape is dynamic, we cannot calculate workspace in compile time.
      num_of_elements = 0;
    }
  }
  auto element_size = tensor_type->dtype.bytes();
  return element_size * num_of_elements;
}

TVM_REGISTER_NODE_TYPE(FunctionInfoNode);

FunctionInfo::FunctionInfo(Map<Target, Integer> workspace_sizes, Map<Target, Integer> io_sizes,
                           Map<Target, Integer> constant_sizes,
                           Map<Target, tir::PrimFunc> tir_primfuncs,
                           Map<Target, Function> relay_primfuncs) {
  ObjectPtr<FunctionInfoNode> n = make_object<FunctionInfoNode>();
  n->workspace_sizes = std::move(workspace_sizes);
  n->io_sizes = std::move(io_sizes);
  n->constant_sizes = std::move(constant_sizes);
  n->tir_primfuncs = std::move(tir_primfuncs);
  n->relay_primfuncs = std::move(relay_primfuncs);
  data_ = std::move(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<FunctionInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const FunctionInfoNode*>(ref.get());
      p->stream << "FunctionInfoNode(\n"
                << "workspace_sizes=" << node->workspace_sizes << ",\n  io_sizes=" << node->io_sizes
                << ",\n  constant_sizes=" << node->constant_sizes
                << ",\n  tir_primfuncs=" << node->tir_primfuncs
                << ",\n  relay_primfuncs=" << node->relay_primfuncs << ")";
    });

Array<Pass> GetPassPrefix(bool is_homegeneous, bool is_vm) {
  Array<Pass> pass_seqs;
  // TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton
  // by most passes there's little utility in including this now. Plus we'd need to only do
  // this if there's no existing spans to work from.
  // pass_seqs.push_back(parser::AnnotateSpans());
  Array<runtime::String> entry_functions{"main"};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  pass_seqs.push_back(transform::ToBasicBlockNormalForm());
  // Run all dialect legalization passes.
  pass_seqs.push_back(relay::qnn::transform::Legalize());

  // Legalize pass is restricted to homogeneous execution for now.
  if (is_homegeneous) {
    pass_seqs.push_back(transform::Legalize());
  }

  pass_seqs.push_back(transform::SimplifyInference());

  if (is_vm) {
    // eta expand to support constructors in argument position
    pass_seqs.push_back(transform::EtaExpand(
        /* expand_constructor */ true, /* expand_global_var */ false));
  } else {
    // Convert Dynamic ops to static versions
    pass_seqs.push_back(transform::DynamicToStatic());
  }

  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    Expr expr = args[0];
    if (expr.as<CallNode>()) {
      auto call_node = expr.as<CallNode>();
      auto op_node = call_node->op.as<OpNode>();
      if (op_node->name == "cast") {
        auto attrs = call_node->attrs.as<CastAttrs>();
        if (attrs->dtype == DataType::Int(32)) {
          *rv = true;
        }
      }
    }
    *rv = false;
  });
  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
  pass_seqs.push_back(transform::SimplifyExpr());
  pass_seqs.push_back(transform::CombineParallelConv2D(3));
  pass_seqs.push_back(transform::CombineParallelDense(3));
  pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
  pass_seqs.push_back(transform::FoldConstant());
  pass_seqs.push_back(transform::FoldScaleAxis());
  pass_seqs.push_back(transform::CanonicalizeCast());
  pass_seqs.push_back(transform::CanonicalizeOps());

  // Alter layout transformation is currently only applied to homogeneous execution.
  if (is_homegeneous) {
    if (!is_vm) {
      pass_seqs.push_back(transform::InferType());
    }
    pass_seqs.push_back(transform::AlterOpLayout());
  }

  // Fast math optimizations.
  pass_seqs.push_back(transform::FastMath());
  pass_seqs.push_back(transform::FoldConstant());
  return pass_seqs;
}

std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map) {
  std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> std_map;
  for (auto kv : input_map) {
    std_map[kv.first] = kv.second;
  }
  return std_map;
}

Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
    std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map) {
  Map<Target, IRModule> tvm_map;
  for (auto kv : input_map) {
    tvm_map.Set(kv.first, kv.second);
  }
  return tvm_map;
}

void UpdateAutoSchedulerOpWeights(const IRModule& module) {
  const auto* te_compiler_update_weights =
      runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights");

  ICHECK(te_compiler_update_weights != nullptr)
      << "auto_scheduler.relay_integration.te_compiler_update_weights";

  Map<String, Integer> weight_map =
      module->GetAttr<Map<String, Integer>>("op_weights", Map<String, Integer>()).value();

  (*te_compiler_update_weights)(weight_map);
}

}  // namespace backend
}  // namespace relay
}  // namespace tvm