#include "dlr_treelite.h"

#include <gtest/gtest.h>

#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

/// Test runner entry point: initialises GoogleTest and runs every registered test.
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
#ifndef _WIN32
  testing::FLAGS_gtest_death_test_style = "threadsafe";
#endif  // _WIN32
  return RUN_ALL_TESTS();
}

/// Fixture shared by all Treelite backend tests.
///
/// Each test gets a freshly constructed dlr::TreeliteModel built from the files
/// found under "./xgboost_test", plus a 69-element input buffer filled with
/// rand()-generated values in [0, 1].
class TreeliteTest : public ::testing::Test {
 protected:
  const int64_t in_size = 69;
  std::vector<float> data{std::vector<float>(in_size)};
  const int in_dim = 2;
  const int64_t in_shape[2] = {1, 69};
  const int64_t out_size = 1;
  const int out_dim = 2;

  // Owned by the fixture; released automatically when the fixture is destroyed.
  std::unique_ptr<dlr::TreeliteModel> model;

  TreeliteTest() {
    // Setup input data
    for (float& value : data) {
      value = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
    }

    // Instantiate model
    // NOTE(review): device_type 1 is presumably kDLCPU - confirm against DLPack.
    const int device_type = 1;
    const int device_id = 0;
    DLDevice dev = {static_cast<DLDeviceType>(device_type), device_id};
    std::vector<std::string> paths = {"./xgboost_test"};
    std::vector<std::string> files = dlr::FindFiles(paths);
    model = std::make_unique<dlr::TreeliteModel>(files, dev);
  }
};

TEST_F(TreeliteTest, TestGetNumInputs) {
  EXPECT_EQ(model->GetNumInputs(), 1);
}

TEST_F(TreeliteTest, TestGetInputName) {
  EXPECT_STREQ(model->GetInputName(0), "data");
}

TEST_F(TreeliteTest, TestGetInputType) {
  EXPECT_STREQ(model->GetInputType(0), "float32");
}

TEST_F(TreeliteTest, TestGetInputSize) {
  EXPECT_NO_THROW(model->SetInput("data", in_shape, data.data(), in_dim));
  EXPECT_EQ(model->GetInputSize(0), in_size);
}

TEST_F(TreeliteTest, TestGetInputDim) {
  EXPECT_EQ(model->GetInputDim(0), in_dim);
}

TEST_F(TreeliteTest, TestGetInputShape) {
  // Before any input is set the batch dimension is reported as -1.
  std::vector<int64_t> shape{-1, in_size};
  EXPECT_EQ(model->GetInputShape(0), shape);

  // After SetInput the reported shape must match the shape that was supplied.
  EXPECT_NO_THROW(model->SetInput("data", in_shape, data.data(), in_dim));
  shape.assign(std::begin(in_shape), std::end(in_shape));
  EXPECT_EQ(model->GetInputShape(0), shape);
}

TEST_F(TreeliteTest, TestGetInput) {
  EXPECT_NO_THROW(model->SetInput("data", in_shape, data.data(), in_dim));
  // Two outcomes are accepted: either the input round-trips unchanged, or the
  // call throws dmlc::Error carrying the "not supported" message.
  try {
    std::vector<float> observed_input_data(in_size);
    model->GetInput("data", observed_input_data.data());
    EXPECT_EQ(observed_input_data, data);
  } catch (const dmlc::Error& e) {
    EXPECT_STREQ(e.what(), "GetInput is not supported by Treelite backend.");
  }
}

TEST_F(TreeliteTest, TestGetNumOutputs) {
  EXPECT_EQ(model->GetNumOutputs(), 1);
}

TEST_F(TreeliteTest, TestGetOutputType) {
  EXPECT_STREQ(model->GetOutputType(0), "float32");
}

TEST_F(TreeliteTest, TestGetOutputShape) {
  // Value-initialised so a throwing GetOutputShape cannot leave the
  // comparisons below reading indeterminate memory.
  int64_t out_shape[2] = {0, 0};
  // input is not set - output batch and size are unknown.
  EXPECT_NO_THROW(model->GetOutputShape(0, out_shape));
  EXPECT_EQ(out_shape[0], -1);
  EXPECT_EQ(out_shape[1], 1);
  // Set input
  EXPECT_NO_THROW(model->SetInput("data", in_shape, data.data(), in_dim));
  // Check OutputShape again
  EXPECT_NO_THROW(model->GetOutputShape(0, out_shape));
  EXPECT_EQ(out_shape[0], 1);
  EXPECT_EQ(out_shape[1], 1);
}

TEST_F(TreeliteTest, TestGetOutputSizeDim) {
  // Value-initialised for the same reason as out_shape in TestGetOutputShape.
  int64_t output_size = 0;
  int output_dim = 0;
  EXPECT_NO_THROW(model->GetOutputSizeDim(0, &output_size, &output_dim));
  // input is not set - output batch and size are unknown.
  EXPECT_EQ(output_size, -1);
  EXPECT_EQ(output_dim, out_dim);
  // Set input
  EXPECT_NO_THROW(model->SetInput("data", in_shape, data.data(), in_dim));
  // Check OutputSizeDim again
  EXPECT_NO_THROW(model->GetOutputSizeDim(0, &output_size, &output_dim));
  EXPECT_EQ(output_size, out_size);
  EXPECT_EQ(output_dim, out_dim);
}

TEST_F(TreeliteTest, TestGetOutput) {
  EXPECT_NO_THROW(model->SetInput("data", in_shape, data.data(), in_dim));
  EXPECT_NO_THROW(model->Run());

  float output[1] = {0.0f};
  EXPECT_NO_THROW(model->GetOutput(0, output));

  // Fatal assertions: the pointer is dereferenced below, so the test must stop
  // here if retrieving it throws or yields null.
  const float* output_p = nullptr;
  ASSERT_NO_THROW(output_p = static_cast<const float*>(model->GetOutputPtr(0)));
  ASSERT_NE(output_p, nullptr);
  EXPECT_EQ(output_p[0], output[0]);
}

TEST_F(TreeliteTest, TestScoreOutput) {
  EXPECT_NO_THROW(model->SetInput("data", in_shape, data.data(), in_dim));
  EXPECT_NO_THROW(model->SetPredMargin(true));
  EXPECT_NO_THROW(model->Run());

  float output[1] = {0.0f};
  EXPECT_NO_THROW(model->GetOutput(0, output));

  // Fatal assertions: the pointer is dereferenced below, so the test must stop
  // here if retrieving it throws or yields null.
  const float* output_p = nullptr;
  ASSERT_NO_THROW(output_p = static_cast<const float*>(model->GetOutputPtr(0)));
  ASSERT_NE(output_p, nullptr);
  EXPECT_EQ(output_p[0], output[0]);
}