/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/******** Workload ********/
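// A Workload bundles an IRModule with its structural hash; the hash is either
// computed on construction or supplied directly.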

Workload::Workload(IRModule mod) {
  ObjectPtr<WorkloadNode> n = runtime::make_object<WorkloadNode>();
  n->shash = tvm::StructuralHash()(mod);
  n->mod = mod;
  data_ = std::move(n);
}

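// This overload takes a pre-computed structural hash (e.g. the one verified in
// Workload::FromJSON) and avoids re-hashing the module.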
Workload::Workload(IRModule mod, Workload::THashCode shash) {
  ObjectPtr<WorkloadNode> n = runtime::make_object<WorkloadNode>();
  n->mod = mod;
  n->shash = shash;
  data_ = std::move(n);
}

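// Serialize as a 2-element array: [shash as a string, base64-encoded JSON of the module].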
ObjectRef WorkloadNode::AsJSON() const {
  // Convert `this->mod` to JSON
  std::string json_mod = tvm::SaveJSON(this->mod);
  // Dump the JSON string to base64
  std::string b64_mod = Base64Encode(json_mod);
  // Output
  return Array<ObjectRef>{SHash2Str(this->shash), String(b64_mod)};
}

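// Inverse of WorkloadNode::AsJSON: decode the module, then re-compute its structural
// hash and check it against the recorded one.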
Workload Workload::FromJSON(const ObjectRef& json_obj) {
  IRModule mod{nullptr};
  THashCode shash = 0;
  try {
    const ArrayNode* json_array = json_obj.as<ArrayNode>();
    CHECK(json_array && json_array->size() == 2);
    // Load json[0] => shash
    String str_shash = Downcast<String>(json_array->at(0));
    // Load json[1] => mod
    {
      String b64_mod = Downcast<String>(json_array->at(1));
      std::string json_mod = Base64Decode(b64_mod);
      mod = Downcast<IRModule>(LoadJSON(json_mod));
    }
    // Verify SHash(mod) == shash
    shash = tvm::StructuralHash()(mod);
    String recalc_shash = SHash2Str(shash);
    CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. Given: " << str_shash
                                      << "; Recalculated: " << recalc_shash;
  } catch (const std::runtime_error& e) {  // includes tvm::Error and dmlc::Error
    LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
               << "\nThe error is: " << e.what();
  }
  return Workload(mod, shash);
}

/******** TuningRecord ********/

TuningRecord::TuningRecord(tir::Trace trace, Array<FloatImm> run_secs, Workload workload,
                           Target target, Array<ArgInfo> args_info) {
  ObjectPtr<TuningRecordNode> n = make_object<TuningRecordNode>();
  n->trace = trace;
  n->run_secs = run_secs;
  n->workload = workload;
  n->target = target;
  n->args_info = args_info;
  data_ = std::move(n);
}

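// Serialize as a 4-element array: [trace, run_secs, target, args_info].
// The owning workload is not included; it is provided separately when deserializing.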
ObjectRef TuningRecordNode::AsJSON() const {
  Array<ObjectRef> json_args_info;
  json_args_info.reserve(args_info.size());
  for (const ArgInfo& arg_info : args_info) {
    json_args_info.push_back(arg_info->AsJSON());
  }
  return Array<ObjectRef>{trace->AsJSON(false),  //
                          run_secs,              //
                          target->Export(),      //
                          json_args_info};
}

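// The serialized record does not carry its workload; the caller supplies it so that
// the trace can be replayed against the workload's module.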
TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) {
  tir::Trace trace{nullptr};
  Array<FloatImm> run_secs{nullptr};
  Target target{nullptr};
  Array<ArgInfo> args_info;
  try {
    const ArrayNode* json_array = json_obj.as<ArrayNode>();
    CHECK(json_array && json_array->size() == 4);
    // Load json[1] => run_secs
    run_secs = Downcast<Array<FloatImm>>(json_array->at(1));
    // Load json[2] => target
    target = Target(Downcast<Map<String, ObjectRef>>(json_array->at(2)));
    // Load json[3] => args_info
    {
      const ArrayNode* json_args_info = json_array->at(3).as<ArrayNode>();
      CHECK(json_args_info) << "TypeError: Expects an array of ArgInfo JSON objects";
      args_info.reserve(json_args_info->size());
      for (const ObjectRef& json_arg_info : *json_args_info) {
        args_info.push_back(ArgInfo::FromJSON(json_arg_info));
      }
    }
    // Load json[0] => trace
    {
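      // Re-apply the serialized trace to a fresh schedule of the workload's module,
      // then extract the trace recorded on that schedule.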
      const ObjectRef& json_trace = json_array->at(0);
      tir::Schedule sch =
          tir::Schedule::Traced(workload->mod, /*seed=*/-1, /*debug_mask=*/0,
                                /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
      tir::Trace::ApplyJSONToSchedule(json_trace, sch);
      trace = sch->trace().value();
    }
  } catch (const std::runtime_error& e) {  // includes tvm::Error and dmlc::Error
    LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
               << "\nThe error is: " << e.what();
  }
  return TuningRecord(trace, run_secs, workload, target, args_info);
}

/******** PyDatabase ********/

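// Wrap user-provided packed functions into a Database; each database operation
// forwards to the corresponding callback, which is typically implemented in Python.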
Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
                              PyDatabaseNode::FCommitWorkload f_commit_workload,
                              PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
                              PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size) {
  ObjectPtr<PyDatabaseNode> n = make_object<PyDatabaseNode>();
  n->f_has_workload = f_has_workload;
  n->f_commit_workload = f_commit_workload;
  n->f_commit_tuning_record = f_commit_tuning_record;
  n->f_get_top_k = f_get_top_k;
  n->f_size = f_size;
  return Database(n);
}

/******** FFI ********/
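// Register the node types for reflection and expose the entry points to the
// frontend under the "meta_schedule." prefix.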

TVM_REGISTER_NODE_TYPE(WorkloadNode);
TVM_REGISTER_NODE_TYPE(TuningRecordNode);
TVM_REGISTER_OBJECT_TYPE(DatabaseNode);
TVM_REGISTER_NODE_TYPE(PyDatabaseNode);
TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) {
  return Workload(mod);
});
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON")
    .set_body_method<Workload>(&WorkloadNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord")
    .set_body_typed([](tir::Trace trace, Array<FloatImm> run_secs, Workload workload, Target target,
                       Array<ArgInfo> args_info) {
      return TuningRecord(trace, run_secs, workload, target, args_info);
    });
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON")
    .set_body_method<TuningRecord>(&TuningRecordNode::AsJSON);
TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(&TuningRecord::FromJSON);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload")
    .set_body_method<Database>(&DatabaseNode::HasWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload")
    .set_body_method<Database>(&DatabaseNode::CommitWorkload);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord")
    .set_body_method<Database>(&DatabaseNode::CommitTuningRecord);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK")
    .set_body_method<Database>(&DatabaseNode::GetTopK);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method<Database>(&DatabaseNode::Size);
TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase);

}  // namespace meta_schedule
}  // namespace tvm