#include #include "dlr.h" #include "dlr_tvm.h" #include "test_utils.hpp" class TVMElemTest : public ::testing::Test { protected: std::vector img; const int batch_size = 1; size_t img_size = 224 * 224 * 3; const int64_t input_shape[4] = {1, 224, 224, 3}; const int input_dim = 4; const std::string graph_file = "./resnet_v1_5_50/compiled_model.json"; const std::string params_file = "./resnet_v1_5_50/compiled.params"; const std::string so_file = "./resnet_v1_5_50/compiled.so"; const std::string meta_file = "./resnet_v1_5_50/compiled.meta"; int device_type = 1; int device_id = 0; DLDevice dev = {static_cast(device_type), device_id}; dlr::TVMModel* model; TVMElemTest() { std::string graph_str = dlr::LoadFileToString(graph_file); std::string params_str = dlr::LoadFileToString(params_file, std::ios::in | std::ios::binary); std::string meta_str = dlr::LoadFileToString(meta_file); std::vector model_elems = { {DLRModelElemType::TVM_GRAPH, nullptr, graph_str.c_str(), 0}, {DLRModelElemType::TVM_PARAMS, nullptr, params_str.data(), params_str.size()}, {DLRModelElemType::TVM_LIB, so_file.c_str(), nullptr, 0}, {DLRModelElemType::NEO_METADATA, nullptr, meta_str.c_str(), 0}}; // Setup input data img = LoadImageAndPreprocess("cat224-3.txt", img_size, batch_size); // Instantiate model model = new dlr::TVMModel(model_elems, dev); } ~TVMElemTest() { delete model; } }; TEST_F(TVMElemTest, TestCreateModel_LibTvmIsPointer) { std::string so_data = dlr::LoadFileToString(so_file, std::ios::in | std::ios::binary); std::string graph_str = dlr::LoadFileToString(graph_file); std::string params_str = dlr::LoadFileToString(params_file, std::ios::in | std::ios::binary); std::vector model_elems = { {DLRModelElemType::TVM_GRAPH, nullptr, graph_str.c_str(), 0}, {DLRModelElemType::TVM_PARAMS, nullptr, params_str.data(), params_str.size()}, {DLRModelElemType::TVM_LIB, nullptr, so_file.data(), so_file.size()}}; EXPECT_THROW( { try { new dlr::TVMModel(model_elems, dev); } catch (const dmlc::Error& e) { EXPECT_STREQ(e.what(), "Invalid TVM model element TVM_LIB. TVM_LIB must be a file path."); throw; } }, dmlc::Error); } TEST_F(TVMElemTest, TestCreateModel_GraphIsMissing) { std::string params_str = dlr::LoadFileToString(params_file, std::ios::in | std::ios::binary); std::vector model_elems = { {DLRModelElemType::TVM_PARAMS, nullptr, params_str.data(), params_str.size()}, {DLRModelElemType::TVM_LIB, so_file.c_str(), nullptr, 0}}; EXPECT_THROW( { try { new dlr::TVMModel(model_elems, dev); } catch (const dmlc::Error& e) { EXPECT_STREQ(e.what(), "Invalid TVM model. Must have TVM_GRAPH, TVM_PARAMS and TVM_LIB elements"); throw; } }, dmlc::Error); } TEST_F(TVMElemTest, TestGetNumInputs) { EXPECT_EQ(model->GetNumInputs(), 1); } TEST_F(TVMElemTest, TestGetInput) { EXPECT_NO_THROW(model->SetInput("input_tensor", input_shape, img.data(), input_dim)); std::vector observed_input_data(img_size); EXPECT_NO_THROW(model->GetInput("input_tensor", observed_input_data.data())); EXPECT_EQ(img, observed_input_data); } TEST_F(TVMElemTest, TestGetInputShape) { std::vector in_shape(std::begin(input_shape), std::end(input_shape)); EXPECT_EQ(model->GetInputShape(0), in_shape); } TEST_F(TVMElemTest, TestGetInputSize) { EXPECT_EQ(model->GetInputSize(0), 1 * 224 * 224 * 3); } TEST_F(TVMElemTest, TestGetInputDim) { EXPECT_EQ(model->GetInputDim(0), 4); } int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); #ifndef _WIN32 testing::FLAGS_gtest_death_test_style = "threadsafe"; #endif // _WIN32 return RUN_ALL_TESTS(); }