/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "./te_compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/function.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/tags.h>

#include <functional>
#include <limits>
#include <mutex>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler_cache.h"
#include "./utils.h"

namespace tvm {
namespace relay {
// TODO(@jroesch, @csullivan): declare directly elsewhere
backend::StaticMemoryPlan GraphPlanMemory(const Function& func);

namespace tec {

using namespace tvm::relay::transform;

TVM_REGISTER_OBJECT_TYPE(TECompilerNode);

class TECompilerImpl : public TECompilerNode {
 public:
  explicit TECompilerImpl(Optional<IRModule> opt_mod) {
    // Make sure we don't collide with any existing globals in the module.
    if (opt_mod) {
      for (const auto& kv : opt_mod.value()->functions) {
        name_map_[kv.first->name_hint] = 1;
      }
    }
  }

  // Lower the function.
  CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) {
    return LowerInternal(key, mangle_fn)->cached_func;
  }

  CachedFunc Lower(const CCacheKey& key, const String mod_name) {
    auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); };
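    // For example (illustrative only): with mod_name = "tvmgen_default", the mangle_fn above
    // would rename a lowered function "fused_add" to "tvmgen_default_fused_add".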

    return Lower(key, mangle_fn);
  }

  // For now, build one module per function.
  PackedFunc JIT(const CCacheKey& key) final {
    auto mangle_fn = [](String name) { return name; };
    CCacheValue value = LowerInternal(key, mangle_fn);
    if (value->packed_func != nullptr) {
      return value->packed_func;
    }
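    // Not yet JIT-ed: build a runtime module containing just this function, then cache the
    // resulting PackedFunc.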
    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
    return value->packed_func;
  }

  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
    return LowerShapeFuncInternal(key)->cached_func;
  }

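  // Returns one IRModule containing every lowered TIR PrimFunc collected so far, each annotated
  // with the Target it was lowered for. External ("Compiler") functions are skipped.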
  IRModule GetLoweredFunctions() {
    IRModule mod;
    // Extract lowered functions from the cache
    for (const auto& it : cache_) {
      auto source_func = it.first;
      auto lowered_func = it.second;

      IRModule lowered_mod = lowered_func->cached_func->funcs;

      // Annotate functions with their target and put them in the return module
      for (const auto& kv : lowered_mod->functions) {
        const GlobalVar& var = kv.first;
        const BaseFunc& func = kv.second;

        // Only add functions that are not external functions
        if (!func->GetAttr<String>(attr::kCompiler).defined()) {
          ICHECK(func->IsInstance<tir::PrimFuncNode>())
              << "Expected all functions that are not external to be PrimFuncs, but found:"
              << std::endl
              << PrettyPrint(func);
          const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
          mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
        }
      }
    }

    // Extract lowered dynamic shape functions from the shape cache
    for (const auto& it : shape_func_cache_) {
      auto source_func = it.first;
      auto lowered_func = it.second;
      auto target = source_func->target;
      IRModule lowered_mod = lowered_func->cached_func->funcs;

      // Annotate functions with their target and put them in the return module
      for (auto kv : lowered_mod->functions) {
        const GlobalVar& var = kv.first;
        const BaseFunc& func = kv.second;
        const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
        mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
      }
    }

    return mod;
  }

  void AddExterns(IRModule module) {
    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
    std::vector<GlobalVar> to_be_deleted;
    for (const auto& kv : module->functions) {
      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
        to_be_deleted.push_back(kv.first);
      }
    }
    for (const auto& global_var : to_be_deleted) {
      module->Remove(global_var);
    }
    // However, we still need a Relay definition to go with those now-external functions, so
    // retrieve them from the cache and mark them with "ExternalSymbol".
    for (const auto& kv1 : cache_) {
      auto src_func = kv1.first->source_func;
      ICHECK(src_func.defined());
      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
            // Abandon the existing function annotations.
            Function function(function_node->params, function_node->body, function_node->ret_type,
                              function_node->type_params, /*attrs=*/{}, function_node->span);
            // Mark function as 'extern' using the "ExternalSymbol" attribute.
            function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
            module->Add(kv2.first, function);
          }
        }
      }
    }
  }

  Array<tvm::runtime::Module> LowerExternalFunctions() {
    Array<tvm::runtime::Module> ret;
    std::vector<CCacheKey> cached_ext_funcs;

    for (const auto& it : cache_) {
      auto src_func = it.first->source_func;
      ICHECK(src_func.defined());
      Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler);
      if (opt_compiler.defined()) {
        Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
        ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl
                                          << PrettyPrint(src_func);
        VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '"
                << opt_symbol_name.value() << "' and function:" << std::endl
                << PrettyPrint(src_func);
        cached_ext_funcs.push_back(it.first);

        std::string ext_name = "relay.ext." + opt_compiler.value();
        auto pf = tvm::runtime::Registry::Get(ext_name);
        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
        // No need to keep the compiler attribute at this point; the functions have been
        // extracted for the specific codegen.
        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
        VLOG_CONTEXT << ext_name;
        runtime::Module ext_mod = (*pf)(src_func);
        if (ext_mod.defined()) {
          if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
            // It's possible the codegen yielded C or C++ code which is tracked separately,
            // in which case the returned runtime module can be empty.
            VLOG(1) << "Unable to find definition for the external function '"
                    << opt_symbol_name.value()
                    << "' in the runtime module generated by external codegen '"
                    << opt_compiler.value() << "'";
          }
          ret.push_back(ext_mod);
        } else {
          // A warning only so that we can write unit tests which can return an empty runtime
          // module.
          LOG(WARNING) << "No external runtime module was generated by external codegen '"
                       << opt_compiler.value() << "'";
        }
      }
    }

    // No need to cache external functions as we collected them all to create
    // external runtime modules.
    for (const auto& it : cached_ext_funcs) {
      cache_.erase(it);
    }
    return ret;
  }

  Map<GlobalVar, String> GetDeviceContexts() { return device_contexts_; }
  void SetDeviceContexts(const Map<GlobalVar, String>& device_contexts) {
    device_contexts_ = device_contexts;
  }

  void Clear() final { cache_.clear(); }

  // List all items in the cache.
  Array<ObjectRef> ListItems() {
    std::lock_guard<std::mutex> lock(mutex_);
    Array<ObjectRef> items;
    for (auto& kv : cache_) {
      items.push_back(kv.first);
      items.push_back(kv.second);
    }
    return items;
  }

  /*!
   * \brief Get the cache key of the function that is being lowered currently
   * \return the cache key
   */
  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }

 private:
  // Implements function lowering.
  CCacheValue LowerInternal(const CCacheKey& key, std::function<String(String)> mangle_fn) {
    VLOG(1) << "lowering:" << std::endl
            << PrettyPrint(key->source_func) << std::endl
            << "for target:" << std::endl
            << key->target->ToDebugString();
    std::lock_guard<std::mutex> lock(mutex_);
    CCacheValue value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      VLOG(1) << "already lowered to name:" << std::endl
              << PrettyPrint(it->second->cached_func->prim_fn_var);
      it->second->use_count += 1;
      if (it->second->cached_func.defined()) return it->second;
      value = it->second;
    } else {
      value = CCacheValue(make_object<CCacheValueNode>());
      value->use_count = 1;
      cache_[key] = value;
    }
    cur_ccache_key_ = key;

    Optional<String> opt_compiler = key->source_func->GetAttr<String>(attr::kCompiler);
    if (opt_compiler.defined()) {
      // Don't compile now since we don't have anywhere to put the resulting runtime module.
      // Instead place the original definition in the cache and wait for LowerExternalFunctions.
      IRModule ir_module;
      Optional<String> opt_global_symbol =
          key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
      ICHECK(opt_global_symbol.defined()) << "External function has not yet been assigned a name.";
      // Note that the source_func may already be bound to a global function in the module
      // we are compiling, in which case we should not attempt to make its name unique w.r.t.
      // the module's globals. Furthermore, the external codegen tool must bind the compiled
      // function to the "global_symbol" attribute on the source_func. So do not use GetUniqueName
      // here.
      auto target = Target("ext_dev");
      auto global_var = GlobalVar(opt_global_symbol.value());
      global_var->checked_type_ = key->source_func->checked_type();
      ir_module->Add(global_var, key->source_func);
      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule{nullptr},
                                      tir::PrimFunc{nullptr}, {}, ir_module);
      // Collect the device context here since the cached entry is removed in
      // LowerExternalFunctions().
      device_contexts_.Set(value->cached_func->prim_fn_var, opt_compiler.value());
      VLOG(1) << "preparing to use external codegen '" << opt_compiler.value()
              << "' with name:" << std::endl
              << PrettyPrint(value->cached_func->prim_fn_var) << std::endl
              << "and definitions:" << std::endl
              << PrettyPrint(value->cached_func->funcs);
      return value;
    }

    // Enforce use of the target.
    With<Target> target_scope(key->target);

    ICHECK(!value->cached_func.defined());
    value->cached_func = PrimFuncFor(key->source_func, key->target, [&](std::string name) {
      auto mangled = mangle_fn(name);
      return GetUniqueName(mangled, &name_map_);
    });

    if (value->cached_func->prim_func.defined()) {
      VLOG(1) << "already have PrimFunc";
      value->cached_func->funcs->Add(value->cached_func->prim_fn_var,
                                     value->cached_func->prim_func.value());
    } else {
      // NOTE: array will copy on write.
      Array<te::Tensor> all_args = Array<te::Tensor>(value->cached_func->inputs);
      for (te::Tensor arg : value->cached_func->outputs) {
        all_args.push_back(arg);
      }
      // lower the function
      std::unordered_map<te::Tensor, tir::Buffer> binds;
      auto func_name = value->cached_func->prim_fn_var->name_hint;
      VLOG(1) << "scheduling";
      IRModule scheduled_module =
          tvm::LowerSchedule(value->cached_func->schedule, all_args, func_name, binds);
      // Unfortunately the above machinery creates its own GlobalVars instead of using *the*
      // GlobalVar we established above. Fix this before the confusion spreads any further.
      // TODO(mbs): LowerSchedule should be given prim_fn_gvar instead of func_name.
      for (const auto& kv : scheduled_module->functions) {
        GlobalVar global_var = kv.first->name_hint == value->cached_func->prim_fn_var->name_hint
                                   ? value->cached_func->prim_fn_var
                                   : kv.first;
        value->cached_func->funcs->Add(global_var, kv.second);
      }
      ICHECK(value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var)
                 .as<tir::PrimFuncNode>());
    }
    VLOG(1) << "lowered to name:" << std::endl
            << PrettyPrint(value->cached_func->prim_fn_var) << std::endl
            << "with definitions:" << std::endl
            << PrettyPrint(value->cached_func->funcs);

    return value;
  }

  // Implements shape function lowering.
  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
    VLOG(1) << "lowering dynamic shape function for:" << std::endl
            << PrettyPrint(key->source_func) << std::endl
            << "for target:" << std::endl
            << key->target->ToDebugString();
    std::lock_guard<std::mutex> lock(mutex_);
    CCacheValue value;
    auto it = shape_func_cache_.find(key);
    if (it != shape_func_cache_.end()) {
      it->second->use_count += 1;
      if (it->second->cached_func.defined()) return it->second;
      value = it->second;
    } else {
      value = CCacheValue(make_object<CCacheValueNode>());
      value->use_count = 0;
      shape_func_cache_[key] = value;
    }
    // Enforce use of the target.
    With<Target> target_scope(key->target);

    ICHECK(!value->cached_func.defined());

    using tvm::transform::PassContext;
    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
    value->cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) {
      return GetUniqueName(name, &name_map_);
    });

    ICHECK(
        value->cached_func->funcs->Lookup(value->cached_func->prim_fn_var).as<tir::PrimFuncNode>());

    VLOG(1) << "lowered to name:" << std::endl
            << PrettyPrint(value->cached_func->prim_fn_var) << std::endl
            << "with definitions:" << std::endl
            << PrettyPrint(value->cached_func->funcs);
    return value;
  }

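  // Returns the use count of each lowered PrimFunc, keyed by name. These are surfaced as
  // 'op weights' for the auto-scheduler integration in LowerTE below.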
  Map<String, Integer> GetOpWeights() const {
    Map<String, Integer> weights;
    for (const auto& kv : cache_) {
      auto value = kv.second;
      auto name = value->cached_func->prim_fn_var->name_hint;
      weights.Set(name, value->use_count);
    }
    return weights;
  }

  // TODO(mbs): Hold the output module here and reduce the cache_ to just be from
  // Function to GlobalVar.

  /*! \brief compiler cache lock */
  std::mutex mutex_;
  /*! \brief internal name map used to generate unique names */
  std::unordered_map<std::string, int> name_map_;
  /*! \brief internal compiler cache */
  std::unordered_map<CCacheKey, CCacheValue> cache_;
  /*! \brief internal compiler cache for shape funcs */
  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
  /*! \brief the cache key of the function that is currently being lowered */
  CCacheKey cur_ccache_key_;
  /*! \brief Map of GlobalVar to C Device API context names */
  Map<GlobalVar, String> device_contexts_;
};

TECompiler::TECompiler(Optional<IRModule> opt_mod) {
  auto object = make_object<TECompilerImpl>(std::move(opt_mod));
  data_ = object;
}

/*! \brief The global TE compiler */
// TODO(mbs): To be terminated with extreme prejudice.
TECompiler& TECompiler::Global() {
  static TECompiler* inst = new TECompiler(make_object<TECompilerImpl>(Optional<IRModule>()));
  return *inst;
}
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);

TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
  return TECompiler::Global();
});

TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
    .set_body_typed([](Function source_func, Target target) {
      return CCacheKey(source_func, target);
    });

TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
    .set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
      return LoweredOutput(outputs, impl);
    });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerClear").set_body_typed([](TECompiler self) {
  self->Clear();
});

TVM_REGISTER_GLOBAL("relay.backend._TECompilerLower")
    .set_body_typed([](TECompiler self, CCacheKey key, const String mod_name) {
      return self->Lower(key, mod_name);
    });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerJIT")
    .set_body_typed([](TECompiler self, CCacheKey key) { return self->JIT(key); });

TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECompiler self) {
  TECompilerImpl* ptr = dynamic_cast<TECompilerImpl*>(self.operator->());
  ICHECK(ptr != nullptr);
  return ptr->ListItems();
});
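
// The registrations above are exercised via the FFI. A minimal C++ sketch (illustrative only):
//   const runtime::PackedFunc* get = runtime::Registry::Get("relay.backend._TECompilerGlobal");
//   ICHECK(get != nullptr);
//   TECompiler compiler = (*get)();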

using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>;

/*!
 * \brief Rewrites call expressions to Relay Functions marked as "primitive"
 * to calls to the corresponding TIR PrimFunc for the appropriate target.
 *
 * \code
 * %0 = fn(...) { prim_op(...) }     OR   let %p = fn(...) { prim_op(...) }
 * ... %0(...) ...                        ... %p(...) ...
 * ==>
 * def @q(..., target=<target>) { <tir body> }
 * ... @q(...) ...
 * \endcode
 *
 * Requires FuseOps, ToANormalForm, EtaExpand and InferType to have run.
 *
 * FuseOps is needed to identify and lift all prim op calls:
 * \code
 * ... prim_op(...) ...
 * ==>
 * %0 = fn(...) { prim_op(...) }
 * ... %0(...) ...
 * \endcode
 *
 * ToANormalForm is needed so we only need to consider vars and function literals as the call
 * target.
 *
 * EtaExpand is needed to ensure all calls to primitives are direct:
 * \code
 * let %p1 = fn(...) { prim_op1(...) }
 * let %p2 = fn(...) { prim_op2(...) }
 * let %p = if (...) { %p1 } else { %p2 }
 * ... %p(...) ...
 * ==>
 * let %p1 = fn(...) { prim_op1(...) }
 * let %p2 = fn(...) { prim_op2(...) }
 * let %p = fn(...) { if (...) { %p1(...) } else { %p2(...) } }
 * ... %p(...) ...
 * \endcode
 */
class LowerTensorExprMutator : public DeviceAwareExprMutator {
 public:
  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, String module_name,
                         TECompiler compiler, VirtualDevice host_virtual_device)
      : DeviceAwareExprMutator(module),
        module_(module),
        process_fn_(std::move(process_fn)),
        module_name_(std::move(module_name)),
        compiler_(std::move(compiler)),
        host_virtual_device_(std::move(host_virtual_device)),
        debug_op_(Op::Get("debug")) {}

  /*!
   *  \brief Returns the primitive function associated with \p expr, or an undefined BaseFunc
   *  if none.
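   *
   *  For example (illustrative): in
   *  \code
   *  let %p = fn(..., Primitive=1) { prim_op(...) }
   *  ... %p(...) ...
   *  \endcode
   *  both %p and the function literal itself resolve to the primitive function.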
   */
  BaseFunc ResolveToPrimitive(const Expr& expr) {
    // NOTE: We can't assume expr->checked_type_ is defined, so can't early exit for first-order
    // expressions.
    if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
      if (!module_->ContainGlobalVar(global_var_node->name_hint)) {
        // TODO(mbs): extern function cleanup
        // Assume the function is extern and thus no longer in the IRModule.
        return {};
      } else {
        BaseFunc base_func = module_->Lookup(GetRef<GlobalVar>(global_var_node));
        return ResolveToPrimitive(base_func);
      }
    } else if (const auto* prim_func_node = expr.as<tir::PrimFuncNode>()) {
      return GetRef<tir::PrimFunc>(prim_func_node);
    } else if (const auto* var_node = expr.as<VarNode>()) {
      auto itr = primitive_functions_.find(var_node);
      if (itr == primitive_functions_.end()) {
        // Not bound to a primitive function.
        return {};
      } else {
        return itr->second;
      }
    } else if (const auto* function_node = expr.as<FunctionNode>()) {
      if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
        // Not marked as primitive by FuseOps.
        return {};
      }
      if (const auto* call_node = function_node->body.as<CallNode>()) {
        if (call_node->op == debug_op_) {
          // Debug 'primitives' are not lowered.
          return {};
        }
      }
      return GetRef<Function>(function_node);
    } else {
      return {};
    }
  }

  /*!
   * \brief Lowers the primitive function \p func to TIR for ultimate execution
   * on a device with configuration \p target. Returns the global var bound
   * to the TIR implementation, and attributes to attach to the call to identify it as
   * a TIR call.
   */
  Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Span span, Target target) {
    CCacheKey key = CCacheKey(func, target);
    CachedFunc cfunc = compiler_->Lower(key, module_name_);
    ICHECK(cfunc.defined());

    auto opt_compiler = func->GetAttr<String>(attr::kCompiler);

    // Add some metadata on top of the *original function* and invoke the callback so it can
    // be captured.
    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
    Map<GlobalVar, tir::PrimFunc> prim_fns;
    Array<GlobalVar> all_prim_fn_vars;
    for (const auto& kv : cfunc->funcs->functions) {
      if (opt_compiler) {
        // We expect just the original func, with only the ExternalSymbol attribute signaling
        // that the function will be compiled externally.
        ICHECK(kv.second.as<FunctionNode>())
            << PrettyPrint(kv.first) << " must be bound to an (external) Function";
      } else {
        // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive
        // (and the rest in support of that via tir::Calls).
        ICHECK(kv.second.as<tir::PrimFuncNode>())
            << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
        prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
        all_prim_fn_vars.push_back(kv.first);
      }
    }
    Function func_with_metadata = func;
    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", cfunc->prim_fn_var);
    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
    func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target);
    this->process_fn_(func_with_metadata);

    CallLoweredAttrs call_lowered_attrs;

    // Non-External Relay Function
    // TODO(mbs): "reshape" cleanup.
    if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
      call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
    }

    call_lowered_attrs.metadata.Set("relay_attrs", func->attrs);
    call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars);

    if (IsDynamic(func->ret_type)) {
      // Also lower the companion dynamic shape function.
      // Shape function keys use the underlying primitive function as their 'function',
      // but the generic 'cpu' target as the target since all shape functions run
      // on the host cpu irrespective of where the primitive runs.
      CCacheKey shape_key(func, host_virtual_device_->target);
      CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);

      // Capture the shape function's global var and parameters 'states' in call
      // annotations so calling convention can be recovered.
      // TODO(mbs): Shape cleanup.
      call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
      call_lowered_attrs.metadata.Set("prim_shape_fn_states",
                                      lowered_shape_func->shape_func_param_states);
      call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs",
                                      Integer(static_cast<int>(lowered_shape_func->inputs.size())));
      call_lowered_attrs.metadata.Set(
          "prim_shape_fn_num_outputs",
          Integer(static_cast<int>(lowered_shape_func->outputs.size())));
      Array<GlobalVar> all_prim_shape_fn_vars;
      for (const auto& kv : lowered_shape_func->funcs->functions) {
        CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
        all_prim_shape_fn_vars.push_back(kv.first);
      }
      call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
    }

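    // The rewritten call has the form (illustrative):
    //   call_lowered(@prim_fn_var, (arg1, ..., argn), <call_lowered_attrs>)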
    return CallLowered(cfunc->prim_fn_var, std::move(visited_args), std::move(call_lowered_attrs),
                       std::move(span));
  }

  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
    Var new_var = Downcast<Var>(Mutate(var));
    Expr new_value = Mutate(value);
    BaseFunc prim_func = ResolveToPrimitive(new_value);

    if (prim_func.defined()) {
      // Remember the let var is bound (possibly indirectly) to a primitive function.
      primitive_functions_.emplace(var.get(), prim_func);
    }
    return {new_var, new_value};
  }

  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) final {
    BaseFunc prim_func = ResolveToPrimitive(post_let_node->value);
    if (prim_func.defined()) {
      // Leaving let var scope
      primitive_functions_.erase(pre_let_node->var.get());
    }
    return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
  }

  Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
    if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
        function_node->GetAttr<String>(attr::kExternalSymbol)) {
      // Nothing to lower inside primitive/external functions.
      return GetRef<Function>(function_node);
    } else {
      return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node);
    }
  }

  Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
    // We can see five forms of calls:
    //  1. A 'normal' Relay call to a Function with the "primitive" attribute. We will need
    //     to lower that to a global PrimFunc and rewrite the call to:
    //       call_lowered(@new_global, (arg1, ..., argn), <attributes>)
    //     However there are a few special forms which are excluded from this treatment, see
    //     below.
    //  2. A 'normal' Relay call to a Function with the "compiler" attribute. We will need
    //     to invoke the appropriate BYOC toolchain function to yield a runtime module and
    //     rewrite the call to the same form as above.
    //  3. A 'normal' Relay call to a PrimFunc which has already been supplied via a global
    //     definition. We rewrite to use the call_lowered form, but otherwise nothing else
    //     needs to be done.
    //  4. A 'normal' Relay call to a Relay Function without any special attribute. These
    //     calls are not changed.
    //  5. A call_lowered call from an earlier invocation of this pass.
    // Note that ResolveToPrimitive yields a defined function only for cases 1-3.

    // Look for (possibly indirect) calls to primitives.
    BaseFunc primitive_func = ResolveToPrimitive(call_node->op);
    if (!primitive_func.defined()) {
      // Not a call to a primitive function we need to rewrite.
      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
        process_fn_(GetRef<Function>(function_node));
      }
      return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
    }

    // Prepare the arguments.
    Array<Expr> new_args;
    for (const auto& arg : call_node->args) {
      new_args.push_back(VisitExpr(arg));
    }

    // Special case: device_copies are left as calls to primitive operators
    // (thus undoing FuseOps) so that each backend can handle them directly.
    // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone.
    if (const auto* function_node = primitive_func.as<FunctionNode>()) {
      DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body);
      if (device_copy_props.body.defined()) {
        ICHECK_EQ(new_args.size(), 1);
        return DeviceCopy(new_args[0], device_copy_props.src_virtual_device,
                          device_copy_props.dst_virtual_device);
      }
    }

    // Special case: If the function has already been lowered by other means then we don't need
    // to mutate the call, but we do still need to mutate the arguments.
    if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) {
      // The function should already be Target-annotated by this point, but the TE Compiler
      // metadata is still needed for the callback.
      // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks
      GlobalVar prim_func_var = Downcast<GlobalVar>(call_node->op);
      tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);

      Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, prim_func}};
      tir::PrimFunc func_with_metadata = WithAttrs(prim_func, {
                                                                  {"prim_fn_var", prim_func_var},
                                                                  {"prim_funcs", prim_fns},
                                                              });

      ICHECK(!IsDynamic(call_node->checked_type()));
      CallLoweredAttrs call_lowered_attrs;
      call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs);

      process_fn_(func_with_metadata);
      ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
      return CallLowered(prim_func_var, std::move(new_args), std::move(call_lowered_attrs),
                         call_node->span);
    }

    // Typical case: call to fused primitive Relay Function.
    // Find the desired target device.
    Target target;
    if (primitive_func->GetAttr<String>(attr::kCompiler).defined()) {
      // The generic 'external device' target.
      // TODO(mbs): Retire once replaced by the unified BYOC compiler and target machinery.
      target = Target("ext_dev");
    } else {
      // The target corresponding to the call_node expression's annotation.
      VirtualDevice virtual_device = GetVirtualDevice(GetRef<Call>(call_node));
      ICHECK(!virtual_device->IsFullyUnconstrained());
      target = virtual_device->target;
      ICHECK(target.defined());
    }

    // Lower the primitive function for that target.
    Function function = Downcast<Function>(primitive_func);
    ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
    return MakeLoweredCall(function, std::move(new_args), call_node->span, target);
  }

  IRModule module_;
  ProcessFn process_fn_;
  // Map from in-scope let-bound variables to Functions known to be primitive, or PrimFuncs which
  // have already been lowered. We'll rewrite these to the fresh global vars bound to the lowered
  // primitive function as we go. Those vars will be bound in the target device-type specific
  // module we'll ultimately emit for each required device-type. Note that a primitive may be
  // lowered for multiple device types, each of which will be assigned a fresh var.
  std::unordered_map<const VarNode*, BaseFunc> primitive_functions_;
  String module_name_;
  TECompiler compiler_;
  /*!
   * \brief The \p VirtualDevice for the host, which is where all shape-related data and computation
   * must live.
   */
  VirtualDevice host_virtual_device_;
  // Cache ops that need to be frequently used later to reduce lookup overhead.
  const Op& debug_op_;
};

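/*!
 * \brief Returns the unique Target if \p targets has exactly one entry, otherwise the Target
 * registered for device type \p dev_type. Fatal error if no match is found.
 *
 * For example (illustrative): with targets = {kDLCPU: llvm, kDLCUDA: cuda},
 * GetTargetFromInteger(kDLCUDA, targets) returns the cuda target.
 */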
Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
  if (targets.size() == 1) {
    // The homogeneous execution case, return the only target.
    const auto& it = targets.begin();
    return (*it).second;
  } else {
    // The heterogeneous execution case, return the target associated with the
    // given device type.
    // If "dev_type" equals 0, the device name can only be obtained from "targets", and it may
    // not be "llvm", so just set it to "unknown" here.
    std::string dev_name = "unknown";
    if (dev_type != 0) {
      dev_name = runtime::DeviceName(dev_type);
    }

    if (targets.count(dev_type) == 0) {
      std::stringstream msg;
      msg << "No target is specified for provided device name: `" << dev_name << "`\n\n"
          << dev_name << " mapped to device type (" << dev_type
          << ") which was not found in the target map.\n"
          << "Available targets: \n";
      for (auto target : targets) {
        msg << "  " << target.first << " -> " << target.second << "\n";
      }
      LOG(FATAL) << msg.str();
    }
    return targets[dev_type];
  }
}

Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn,
                     VirtualDevice host_virtual_device) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function func, IRModule module, PassContext ctx) {
        LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler,
                                        host_virtual_device);
        return Downcast<Function>(lower_te.Mutate(func));
      };
  return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}

backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
                                              Map<Expr, backend::StorageInfo> storage_info_map) {
  Function func = Downcast<Function>(mod->Lookup("main"));

  VLOG_CONTEXT << "UpdateMainWorkspaceSize";
  VLOG(1) << "calculating FunctionInfo for main:" << std::endl << PrettyPrint(func);
  for (const auto& kv : targets) {
    VLOG(1) << "  target " << kv.first << " = " << kv.second->str();
  }

  // This is a Map<device, Map<storage_id, size>>
  // TODO(mbs): Collapsing VirtualDevices to just device type.
  std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
      sid_workspace;
  // This is a Map<device, size_of_inputs_and_outputs>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io;
  // This is a Map<device, size_of_constants>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts;

  // Initialize the mapping from all storage identifiers to workspace sizes,
  // the amount of device io, and the device constants.
  for (const auto& kv : storage_info_map) {
    const backend::StorageInfo& storage_info = kv.second;
    const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
    const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;
    CHECK_EQ(storage_ids.size(), virtual_devices.size());
    for (uint32_t i = 0; i < virtual_devices.size(); i++) {
      DLDeviceType device_type = virtual_devices[i]->device_type();
      sid_workspace[device_type][storage_ids[i]] = 0;
      device_io[device_type] = 0;
      device_consts[device_type] = 0;
    }
  }

  // Iterate the storage map to compute all the tensor sizes in the program.
  // There are 3 cases in this code:
  //
  // First we need to compute the sizes of all
  // inline constants.
  //
  // Second we compute the size of any bound variable as these are input and output
  // sizes of the program.
  //
  // Finally for all other expressions we check which storage identifier they have
  // been assigned and we compute the maximal size of the storage, as tensors can
  // share storage with other tensors which are the same size or larger.
  //
  // In this final case there is only one allocation for all tensors which share storage
  // which will be the maximal size of all tensors which were assigned to it.
  for (const auto& kv : storage_info_map) {
    const Expr& expr = kv.first;
    const backend::StorageInfo& storage_info = kv.second;
    int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type());
    VLOG(1) << "expression:" << std::endl
            << PrettyPrint(expr) << std::endl
            << "of type:" << std::endl
            << PrettyPrint(expr->checked_type()) << std::endl
            << "has size " << size_bytes << " and storage info:" << std::endl
            << storage_info;
    const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
    const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;

    if (expr->IsInstance<ConstantNode>()) {
      for (const auto& virtual_device : virtual_devices) {
        DLDeviceType device_type = virtual_device->device_type();
        ICHECK_EQ(device_consts.count(device_type), 1);
        device_consts[device_type] += size_bytes;
      }
    } else if (expr->IsInstance<VarNode>() || expr.same_as(func->body)) {
      CHECK_GE(virtual_devices.size(), 1) << "must be at least one device";
      for (const auto& virtual_device : virtual_devices) {
        DLDeviceType device_type = virtual_device->device_type();
        device_io[device_type] += size_bytes;
      }
    } else {
      // TODO(@electriclilies): This code is never being called, which means sid_workspace is not
      // updated. This means that storage info is probably not being created correctly, or is not
      // equivalent to what was here previously.
      for (uint32_t i = 0; i < storage_ids.size(); i++) {
        // Here we record the largest size of the tensor
        // that share the same storage id, because storage_id will
        // be shared between multiple tensors that are not live simultaneously.
        DLDeviceType device_type = virtual_devices[i]->device_type();
        if (size_bytes > sid_workspace[device_type][storage_ids[i]]) {
          sid_workspace[device_type][storage_ids[i]] = size_bytes;
        }
      }
    }
  }

  // This is a Map<device, workspace_size>
  std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace;
  // Once we know the sizes of sids, we need to accumulate them per device.
  for (const auto& dev_sid_size : sid_workspace) {
    auto dev = dev_sid_size.first;
    device_workspace[dev] = 0;
    for (const auto& sid_size : dev_sid_size.second) {
      device_workspace[dev] += sid_size.second;
    }
  }

  Map<Target, Integer> workspace_sizes;
  Map<Target, Integer> io_sizes;
  Map<Target, Integer> constant_sizes;
  Map<Target, tir::PrimFunc> tir_primfuncs;
  Map<Target, Function> relay_primfuncs;

  // Initialize all target workspaces to zero
  for (const auto& kv : targets) {
    auto tgt = kv.second;
    workspace_sizes.Set(tgt, 0);
  }

  for (const auto& dev_and_size : device_workspace) {
    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
    workspace_sizes.Set(tgt, dev_and_size.second);
    relay_primfuncs.Set(tgt, func);
  }
  for (const auto& dev_and_size : device_io) {
    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
    io_sizes.Set(tgt, dev_and_size.second);
  }

  for (const auto& dev_and_size : device_consts) {
    auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
    ICHECK_EQ(constant_sizes.count(tgt), 0);
    constant_sizes.Set(tgt, dev_and_size.second);
  }

  backend::FunctionInfo func_info(std::move(workspace_sizes), std::move(io_sizes),
                                  std::move(constant_sizes), std::move(tir_primfuncs),
                                  std::move(relay_primfuncs));
  VLOG(1) << "func_info: " << func_info;
  return func_info;
}

/*!
 * \brief A function to create the function metadata for an input function (i.e. calculate buffer
 * input/output sizes)
 * \param func The function to calculate function metadata for
 * \param function_metadata The map that stores all the function metadatas
 */
void UpdateFunctionMetadata(BaseFunc func,
                            Map<String, backend::FunctionInfo>& function_metadata,  // NOLINT(*)
                            Integer workspace_byte_alignment) {
  VLOG_CONTEXT << "UpdateFunctionMetadata";
  VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(func);
  // Originally UpdateFunctionMetadata took in a CCachedFunc and looped through all the funcs
  // stored there. Now the goal is to take only one func, because process_fn should be controlling
  // the iteration. However, to do the workspace calculations we need the PrimFuncs, so process_fn
  // needs to either access the cached funcs or be directly passed the PrimFuncs. This is bad, and
  // ideally we don't want process_fn to look at PrimFuncs at all. There's also the question of
  // what the function metadatas are and how they are used; if we can do something else to
  // replicate their behavior that might be good (i.e. annotating functions or something).
  Map<Target, Integer> workspace_sizes;
  Map<Target, Integer> io_sizes;
  Map<Target, Integer> constant_sizes;
  Map<Target, tir::PrimFunc> tir_primfuncs;
  Map<Target, Function> relay_primfuncs;

  Optional<Map<GlobalVar, tir::PrimFunc>> prim_fns =
      func->GetAttr<Map<GlobalVar, tir::PrimFunc>>("prim_funcs");
  CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler.";

  Optional<GlobalVar> prim_fn_var = func->GetAttr<GlobalVar>("prim_fn_var");
  CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler.";

  Optional<Target> relay_target = func->GetAttr<Target>(tvm::attr::kTarget);
  CHECK(relay_target) << "target must be set on Relay functions by the TECompiler.";

  for (const auto& kv : prim_fns.value()) {
    auto prim_fn = Downcast<tir::PrimFunc>(kv.second);
    CHECK(prim_fn.defined()) << "the primitive function must be defined";

    Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment);

    // Workspace sizes
    Target prim_fn_target;
    if (prim_fn->attrs->dict.count(tvm::attr::kTarget)) {
      prim_fn_target = Downcast<Target>(prim_fn->attrs->dict[tvm::attr::kTarget]);
    } else {
      prim_fn_target = relay_target.value();
    }

    workspace_sizes.Set(prim_fn_target, workspace_size);

    // Calculating size for I/O
    // TODO(mbs): See also the other three utils for calculating tensor bytesize.
    for (auto const& param : prim_fn->params) {
      bool not_a_buffer = prim_fn->buffer_map.count(param) == 0;
      if (not_a_buffer) {
        io_sizes.Set(prim_fn_target, 0);
        continue;
      }

      auto p_shape = prim_fn->buffer_map[param]->shape;
      int num_of_elements = 1;
      for (const auto& dim_index_expr : p_shape) {
        if (dim_index_expr->IsInstance<IntImmNode>()) {
          num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
        } else {
          // If the shape is dynamic, we cannot calculate the size at compile time.
          num_of_elements = 0;
        }
      }
      int element_size = prim_fn->buffer_map[param]->dtype.bytes();
      io_sizes.Set(prim_fn_target, element_size * num_of_elements);
    }
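    // For example (illustrative): a float32 buffer of shape (1, 224, 224, 3) has 150528
    // elements, so it would contribute 150528 * 4 = 602112 bytes to io_sizes above.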

    constant_sizes.Set(prim_fn_target, 0);
    tir_primfuncs.Set(prim_fn_target, prim_fn);
    if (func->IsInstance<FunctionNode>()) {
      relay_primfuncs.Set(prim_fn_target, Downcast<Function>(func));
    }
  }

  backend::FunctionInfo fi = backend::FunctionInfo(
      std::move(workspace_sizes), std::move(io_sizes), std::move(constant_sizes),
      std::move(tir_primfuncs), std::move(relay_primfuncs));

  VLOG(1) << "FunctionInfo: " << PrettyPrint(prim_fn_var.value()) << " = " << PrettyPrint(fi);

  // The primitive function name here corresponds to the string we will use to generate
  // this Relay function at the low level.
  function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}

IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn,
                 VirtualDevice host_virtual_device) {
  TECompiler compiler(module);

  // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten
  // module as we go (including rewritten Functions, lowered primitives, and runtime modules
  // generated by external toolchains), and use a pair of maps (from vars and from global vars,
  // to global vars) to remember which functions have already been lowered.

  // Lower all the callees in module:
  //  - Functions tagged with "Compiler" are unchanged (checked by CreateFunctionPass)
  //  - Functions tagged with "Primitive" are unchanged (checked by LowerTensorExprMutator)
  //  - Called functions tagged with "Compiler" are copied into the compiler cache with a fresh
  //    GlobalVar, and calls updated (sticking with regular Relay Call).
  //  - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated
  //    (using call_lowered convention).
  IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn),
                                            std::move(host_virtual_device))(module);

  // The Functions tagged with "Compiler" are now residing in the cache ready to be
  // compiled by LowerExternalFunctions. However we still need a record of them in the
  // IRModule so that the various executors can see which function names need to be
  // retrieved. They may, however, have been renamed.
  compiler->AddExterns(updated_module);

  // Add the lowered functions.
  IRModule lowered_module = compiler->GetLoweredFunctions();
  VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered functions";
  for (const auto& kv : lowered_module->functions) {
    if (updated_module->ContainGlobalVar(kv.first->name_hint)) {
      LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint
                 << "'. Existing is:" << std::endl
                 << PrettyPrint(updated_module->Lookup(kv.first->name_hint)) << std::endl
                 << "while new is:" << std::endl
                 << PrettyPrint(kv.second);
    }
    updated_module->Add(kv.first, kv.second);
  }

  // Invoke external codegen for all Functions in the cache tagged with "Compiler", and
  // annotate the module with the resulting runtime modules.
  // TODO(mbs): runtime modules should be first class rather than attributes.
  Array<runtime::Module> external_mods =
      module->GetAttr<Array<runtime::Module>>("external_mods", Array<runtime::Module>()).value();
  Array<runtime::Module> new_external_mods = compiler->LowerExternalFunctions();
  VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size()
          << " new external modules";
  for (const auto& mod : new_external_mods) {
    external_mods.push_back(mod);  // copy-on-write.
  }

  // Annotate the module with C Device API context mapping (this is until we have Targets
  // annotated for the C Device API)
  // TODO(Mousius) - Remove "device_contexts" as soon as we have the graph annotated properly with
  // Targets
  Map<GlobalVar, String> device_contexts =
      module->GetAttr<Map<GlobalVar, String>>("device_contexts", Map<GlobalVar, String>()).value();
  Map<GlobalVar, String> new_device_contexts = compiler->GetDeviceContexts();
  VLOG(1) << "capturing " << device_contexts.size() << " existing and "
          << new_device_contexts.size() << " new device contexts for external functions";
  for (const auto& kv : new_device_contexts) {
    ICHECK_EQ(device_contexts.count(kv.first), 0);
    device_contexts.Set(kv.first, kv.second);  // copy-on-write.
  }

  updated_module = WithAttrs(updated_module, {{"external_mods", std::move(external_mods)},
                                              {"device_contexts", std::move(device_contexts)}});

  if (backend::IsAutoSchedulerEnabled()) {
    // Capture all the 'operator weights', ie usage counts for each PrimFunc.
    Map<String, Integer> op_weights =
        module->GetAttr<Map<String, Integer>>("op_weights", Map<String, Integer>()).value();
    Map<String, Integer> new_op_weights = compiler->GetOpWeights();
    VLOG(1) << "capturing " << op_weights.size() << " existing and " << new_op_weights.size()
            << " new operator weights for PrimFuncs";
    for (const auto& kv : new_op_weights) {
      ICHECK_EQ(op_weights.count(kv.first), 0);
      op_weights.Set(kv.first, kv.second);  // copy-on-write.
    }
    updated_module = WithAttr(updated_module, "op_weights", std::move(op_weights));
  }

  return updated_module;
}

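/*!
 * \brief Partitions the PrimFuncs in \p mod into one IRModule per Target, using the "target"
 * attribute stamped on each PrimFunc. Relay Functions are ignored.
 */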
Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
  std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
      per_target_modules;
  for (const auto& kv : mod->functions) {
    const GlobalVar& var = kv.first;
    const BaseFunc& func = kv.second;
    if (func->IsInstance<tir::PrimFuncNode>()) {
      // Extract target
      Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
      ICHECK(target) << "Target should be set at this point";

      // Put the function in per_target_modules
      if (!per_target_modules.count(target.value())) {
        // Initialize the IRModule for this target with the attributes from the input IRModule
        IRModule target_module = IRModule({}, {}, {}, {}, mod->attrs);
        // Add the function to the IRModule
        target_module->Add(var, func);
        per_target_modules[target.value()] = target_module;
      } else {
        // The IRModule for this target is initialized, so just add the function.
        IRModule target_module = per_target_modules.at(target.value());
        target_module->Add(var, func);
      }
    } else if (!func->IsInstance<relay::FunctionNode>()) {
      LOG(FATAL)
          << "The function types in the IRModule should be relay::Function or tir::PrimFunc, "
          << "but got "
          << func->GetTypeKey();
    }
  }
  return per_target_modules;
}

Pass LowerTEPass(const String& module_name, ProcessFn process_fn,
                 VirtualDevice host_virtual_device) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
                                                                            PassContext ctx) {
    return LowerTE(module, module_name, process_fn, host_virtual_device);
  };

  return tvm::transform::Sequential(
      {tvm::relay::transform::RelayToTIRTargetHook(),
       tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType()});
}
}  // namespace tec
}  // namespace relay
}  // namespace tvm