// Declares dlr::RelayVMModel, the DLRModel implementation constructed with
// DLRBackend::kRELAYVM, together with platform macros for shared-library
// naming (LIBEXT, LIBDLR) and symbol export (DLR_DLL).
//
// NOTE(review): in this copy every `<...>` span appears to have been removed:
// the #include directives below have no targets, std::vector / std::shared_ptr
// appear with no template arguments, and a lone `>` is left in the declaration
// of output_shapes_. The header cannot compile as-is. Restore the missing
// include targets and template arguments from version control. Hedged guesses
// are noted inline; they are NOT verified.
#ifndef DLR_RELAYVM_H_
#define DLR_RELAYVM_H_

// NOTE(review): the targets of the following nine #include directives are
// missing (see the note at the top of the file). The names used below
// (std::string, std::vector, std::shared_ptr, tvm::runtime::NDArray,
// tvm::runtime::ObjectRef, tvm::runtime::vm::AllocatorType) suggest standard
// and TVM runtime headers. TODO confirm against the original source.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "dlr_common.h"
#ifdef ENABLE_DATATRANSFORM
// Provides DataTransform; only needed when data-transform support is enabled.
#include "dlr_data_transform.h"
#endif

// Platform-specific shared-library file extension (LIBEXT) and DLR library
// file name (LIBDLR). Neither macro is referenced within this header.
#ifdef _WIN32
#define LIBEXT ".dll"
#define LIBDLR "dlr.dll"
#elif __APPLE__
#define LIBEXT ".dylib"
#define LIBDLR "libdlr.dylib"
#else
#define LIBEXT ".so"
#define LIBDLR "libdlr.so"
#endif

// Export annotation applied to RelayVMModel: __declspec(dllexport) when
// compiling with MSVC or for Windows, empty on other platforms.
// NOTE(review): there is no dllimport branch, so translation units that only
// include this header on Windows also see the class as dllexport. Confirm
// that this is intended.
#if defined(_MSC_VER) || defined(_WIN32)
#define DLR_DLL __declspec(dllexport)
#else
#define DLR_DLL
#endif  // defined(_MSC_VER) || defined(_WIN32)

namespace dlr {

/*!
 * \brief DLRModel for the DLRBackend::kRELAYVM backend. Model state is held in
 *        tvm::runtime / tvm::runtime::vm types. Method definitions are not in
 *        this header.
 */
class DLR_DLL RelayVMModel : public DLRModel {
 private:
  // Static string constant. It is defined out-of-line, not in this header.
  // Presumably this is the name of the VM function that Run() invokes
  // (TODO confirm).
  static const std::string ENTRY_FUNCTION;
  // NOTE(review): template argument stripped. Presumably std::vector<std::string>.
  std::vector output_names_;
  // NOTE(review): template argument stripped. Presumably std::vector<std::string>.
  std::vector output_types_;
  // NOTE(review): template argument stripped. Presumably
  // std::shared_ptr<tvm::runtime::Module>. TODO confirm.
  std::shared_ptr vm_module_;
  // NOTE(review): template argument stripped, as for vm_module_.
  std::shared_ptr vm_executable_;
  // NOTE(review): template argument stripped. Presumably
  // std::vector<tvm::runtime::NDArray>, since GetOutput(int) returns NDArray.
  std::vector inputs_;
  tvm::runtime::ObjectRef output_ref_;
  // NOTE(review): template argument stripped. Presumably
  // std::vector<tvm::runtime::NDArray>.
  std::vector outputs_;
  // NOTE(review): the lone `>` is what remains of a nested template argument
  // list. Presumably std::vector<std::vector<int64_t>>, because
  // GetOutputShape() writes int64_t values. TODO confirm.
  std::vector> output_shapes_;
  // Not listed in either constructor's initializer list, so it stays
  // default-constructed. Because it is a const member, RelayVMModel objects
  // cannot be assigned through the implicit copy/move assignment operators.
  const tvm::runtime::NDArray empty_;
  // Both constructors set this to AllocatorType::kPooled.
  // GetAllocatorType() exposes it.
  tvm::runtime::vm::AllocatorType allocator_type_;
#ifdef ENABLE_DATATRANSFORM
  DataTransform data_transform_;
#endif
  // Setup helpers. The constructors call one SetupVMModule overload, then
  // FetchInputNodesData(), then FetchOutputNodesData(), in that order.
  // NOTE(review): element types stripped. Presumably the first overload takes
  // file paths (std::string) and the second takes in-memory model elements.
  void SetupVMModule(const std::vector& paths);
  void SetupVMModule(const std::vector& model_elems);
  void FetchInputNodesData();
  void FetchOutputNodesData();
  // Declared here and implemented elsewhere. They are not called from this
  // header.
  void UpdateOutputs();
  void UpdateInputs();
  DLDataType GetInputDLDataType(int index);

 public:
  /*!
   * \brief Construct from a list of model files.
   * \param files forwarded to SetupVMModule (presumably the `paths` overload).
   * \param dev   device passed to the DLRModel base together with kRELAYVM.
   */
  explicit RelayVMModel(const std::vector& files, const DLDevice& dev)
      : DLRModel(dev, DLRBackend::kRELAYVM),
        allocator_type_(tvm::runtime::vm::AllocatorType::kPooled) {
    SetupVMModule(files);
    FetchInputNodesData();
    FetchOutputNodesData();
  }
  /*!
   * \brief Construct from in-memory model elements.
   * \param model_elems forwarded to SetupVMModule (presumably the
   *        `model_elems` overload). It is taken by value, so an lvalue
   *        argument is copied.
   * \param dev device passed to the DLRModel base together with kRELAYVM.
   * NOTE(review): consider `const&` for consistency with the other constructor.
   */
  explicit RelayVMModel(std::vector model_elems, const DLDevice& dev)
      : DLRModel(dev, DLRBackend::kRELAYVM),
        allocator_type_(tvm::runtime::vm::AllocatorType::kPooled) {
    SetupVMModule(model_elems);
    FetchInputNodesData();
    FetchOutputNodesData();
  }

  // ---- Inputs ----
  int GetInputIndex(const char* name) const;
  // NOTE(review): top-level const on a by-value return type has no effect for
  // callers. It is presumably kept to mirror the DLRModel declaration these
  // override.
  virtual const int GetInputDim(int index) const override;
  virtual const int64_t GetInputSize(int index) const override;
  virtual const char* GetInputName(int index) const override;
  virtual const char* GetInputType(int index) const override;
  virtual const char* GetWeightName(int index) const override;
  // NOTE(review): template argument stripped. Presumably std::vector<std::string>.
  virtual std::vector GetWeightNames() const override;
  virtual void GetInput(const char* name, void* input) override;
  virtual void SetInput(const char* name, const int64_t* shape, const void* input, int dim) override;
  void SetInputTensor(const char* name, DLTensor* tensor);
  virtual int GetNumInputs() const override;

  // ---- Execution ----
  virtual void Run() override;

  // ---- Outputs ----
  // Declaring GetOutput here hides any other DLRModel::GetOutput overloads
  // that are not redeclared in this class.
  tvm::runtime::NDArray GetOutput(int index);
  virtual void GetOutput(int index, void* out) override;
  void GetOutputManagedTensorPtr(int index, const DLManagedTensor** out);
  virtual const void* GetOutputPtr(int index) const override;
  virtual void GetOutputShape(int index, int64_t* shape) const override;
  virtual void GetOutputSizeDim(int index, int64_t* size, int* dim) override;
  virtual const char* GetOutputType(int index) const override;
  void GetOutputTensor(int index, DLTensor* out);

  // ---- Runtime configuration ----
  virtual void SetNumThreads(int threads) override;
  virtual void UseCPUAffinity(bool use) override;
  // NOTE(review): could be declared const; left unchanged.
  tvm::runtime::vm::AllocatorType GetAllocatorType();

  /* The following methods use the metadata file to look up input and output
   * names. */
  virtual const char* GetOutputName(const int index) const override;
  virtual int GetOutputIndex(const char* name) const override;
  virtual void GetOutputByName(const char* name, void* out) override;
};

}  // namespace dlr

#endif  // DLR_RELAYVM_H_