/*!
 * Copyright (c) 2017 by Contributors
 * \file predictor.cc
 * \author Philip Cho
 * \brief Load prediction function exported as a shared library
 */

#include <treelite/predictor.h>
#include <treelite/common.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/timer.h>
#include <cstdint>
#include <cstring>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include <fstream>
#include <limits>
#include <functional>
#include <thread>
#include <type_traits>
#include "common/math.h"
#include "common/filesystem.h"
#include "thread_pool/thread_pool.h"

#ifdef _WIN32
#define NOMINMAX
#include <windows.h>
#else
#include <dlfcn.h>
#endif

namespace {

enum class InputType : uint8_t {
  kSparseBatch = 0, kDenseBatch = 1
};

struct InputToken {
  InputType input_type;
  const void* data;         // pointer to input data
  bool pred_margin;         // whether to store raw margin or transformed scores
  size_t num_feature;       // # features (columns) accepted by the tree ensemble model
  size_t num_output_group;  // size of output per instance (row)
  treelite::Predictor::PredFuncHandle pred_func_handle;
  size_t rbegin, rend;      // range of instances (rows) assigned to each worker
  float* out_pred;          // buffer to store output from each worker
};

struct OutputToken {
  size_t query_result_size;
};

inline std::string GetProtocol(const char* name) {
  const char* p = std::strstr(name, "://");
  if (p == nullptr) {
    return "";
  } else {
    return std::string(name, p - name + 3);
  }
}

using PredThreadPool
  = treelite::ThreadPool<InputToken, OutputToken, treelite::Predictor>;

inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) {
#ifdef _WIN32
  HMODULE handle = LoadLibraryA(name);
#else
  void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
#endif
  return static_cast<treelite::Predictor::LibraryHandle>(handle);
}

inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) {
#ifdef _WIN32
  FreeLibrary(static_cast<HMODULE>(handle));
#else
  dlclose(static_cast<void*>(handle));
#endif
}

template <typename HandleType>
inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle,
                               const char* name) {
#ifdef _WIN32
  FARPROC func_handle = GetProcAddress(static_cast<HMODULE>(lib_handle), name);
#else
  void* func_handle = dlsym(static_cast<void*>(lib_handle), name);
#endif
  return static_cast<HandleType>(func_handle);
}

template <typename PredFunc>
inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature,
                       size_t rbegin, size_t rend,
                       float* out_pred, PredFunc func) {
  CHECK_LE(batch->num_col, num_feature);
  std::vector<TreelitePredictorEntry> inst(
    std::max(batch->num_col, num_feature), {-1});
  CHECK(rbegin < rend && rend <= batch->num_row);
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
            && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
  const int64_t rend_ = static_cast<int64_t>(rend);
  const float* data = batch->data;
  const uint32_t* col_ind = batch->col_ind;
  const size_t* row_ptr = batch->row_ptr;
  size_t total_output_size = 0;
  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
    const size_t ibegin = row_ptr[rid];
    const size_t iend = row_ptr[rid + 1];
    // scatter the nonzero entries of the current row into the dense working array
    for (size_t i = ibegin; i < iend; ++i) {
      inst[col_ind[i]].fvalue = data[i];
    }
    total_output_size += func(rid, &inst[0], out_pred);
    // restore the touched entries to "missing" before moving to the next row
    for (size_t i = ibegin; i < iend; ++i) {
      inst[col_ind[i]].missing = -1;
    }
  }
  return total_output_size;
}
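/*
 * Illustration (for exposition only, not part of the library): the CSR
 * variant above keeps one dense working array `inst` and touches only the
 * entries present in the current row. For a 4-feature model and a row with
 * nonzeros (col 1, 0.5) and (col 3, 2.0), the array evolves as:
 *
 *   initial : [missing, missing, missing, missing]
 *   scatter : [missing, 0.5,     missing, 2.0    ]
 *   predict : func(rid, &inst[0], out_pred)
 *   reset   : [missing, missing, missing, missing]
 *
 * Resetting only the touched entries keeps the per-row cost proportional to
 * the number of nonzeros rather than to num_feature.
 */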
template <typename PredFunc>
inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature,
                       size_t rbegin, size_t rend,
                       float* out_pred, PredFunc func) {
  const bool nan_missing
    = treelite::common::math::CheckNAN(batch->missing_value);
  CHECK_LE(batch->num_col, num_feature);
  std::vector<TreelitePredictorEntry> inst(
    std::max(batch->num_col, num_feature), {-1});
  CHECK(rbegin < rend && rend <= batch->num_row);
  CHECK(sizeof(size_t) < sizeof(int64_t)
        || (rbegin <= static_cast<size_t>(std::numeric_limits<int64_t>::max())
            && rend <= static_cast<size_t>(std::numeric_limits<int64_t>::max())));
  const int64_t rbegin_ = static_cast<int64_t>(rbegin);
  const int64_t rend_ = static_cast<int64_t>(rend);
  const size_t num_col = batch->num_col;
  const float missing_value = batch->missing_value;
  const float* data = batch->data;
  const float* row;
  size_t total_output_size = 0;
  for (int64_t rid = rbegin_; rid < rend_; ++rid) {
    row = &data[rid * num_col];
    for (size_t j = 0; j < num_col; ++j) {
      if (treelite::common::math::CheckNAN(row[j])) {
        CHECK(nan_missing)
          << "The missing_value argument must be set to NaN if there is any "
          << "NaN in the matrix.";
      } else if (nan_missing || row[j] != missing_value) {
        inst[j].fvalue = row[j];
      }
    }
    total_output_size += func(rid, &inst[0], out_pred);
    for (size_t j = 0; j < num_col; ++j) {
      inst[j].missing = -1;
    }
  }
  return total_output_size;
}
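/*
 * Illustration (for exposition only): with missing_value = -1.0f (so
 * nan_missing is false), the dense variant above treats -1.0f cells as
 * missing and copies every other cell:
 *
 *   row = [0.5, -1.0, 2.0]  ->  inst = [0.5, missing, 2.0]
 *
 * A NaN cell would trip the CHECK above, since a NaN that is not declared as
 * the missing value would otherwise be passed silently to the prediction
 * function.
 */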
template <typename BatchType>
inline size_t PredictBatch_(const BatchType* batch, bool pred_margin,
                            size_t num_feature, size_t num_output_group,
                            treelite::Predictor::PredFuncHandle pred_func_handle,
                            size_t rbegin, size_t rend,
                            size_t expected_query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  /* Pass the correct prediction function to PredLoop.
     We also need to specify how the function should be called. */
  size_t query_result_size;
    // Dimension of output vector:
    // can be either [num_data] or [num_class]*[num_data].
    // Note that size of prediction may be smaller than out_pred
    // (this occurs when pred_function is set to "max_index").
  if (num_output_group > 1) {  // multi-class classification task
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, num_feature, rbegin, rend, out_pred,
        [pred_func, num_output_group, pred_margin]
        (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
          return pred_func(inst, static_cast<int>(pred_margin),
                           &out_pred[rid * num_output_group]);
        });
  } else {  // every other task
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size =
      PredLoop(batch, num_feature, rbegin, rend, out_pred,
        [pred_func, pred_margin]
        (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t {
          out_pred[rid] = pred_func(inst, static_cast<int>(pred_margin));
          return 1;
        });
  }
  return query_result_size;
}

inline size_t PredictInst_(TreelitePredictorEntry* inst, bool pred_margin,
                           size_t num_output_group,
                           treelite::Predictor::PredFuncHandle pred_func_handle,
                           size_t expected_query_result_size, float* out_pred) {
  CHECK(pred_func_handle != nullptr)
    << "A shared library needs to be loaded first using Load()";
  size_t query_result_size;  // dimension of output vector
  if (num_output_group > 1) {  // multi-class classification task
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    query_result_size = pred_func(inst, static_cast<int>(pred_margin), out_pred);
  } else {  // every other task
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle);
    out_pred[0] = pred_func(inst, static_cast<int>(pred_margin));
    query_result_size = 1;
  }
  return query_result_size;
}

}  // anonymous namespace

namespace treelite {

Predictor::Predictor(int num_worker_thread)
  : lib_handle_(nullptr),
    num_output_group_query_func_handle_(nullptr),
    num_feature_query_func_handle_(nullptr),
    pred_func_handle_(nullptr),
    thread_pool_handle_(nullptr),
    num_worker_thread_(num_worker_thread),
    tempdir_(nullptr) {}

Predictor::~Predictor() {
  Free();
}
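// Usage sketch (for exposition only; error handling omitted). Predictor,
// Load(), QueryResultSize() and PredictBatch() are defined in this file and
// its header; "./model.so" stands for a hypothetical compiled model and
// `batch` for a CSRBatch (or DenseBatch) prepared by the caller:
//
//   treelite::Predictor predictor(-1);  // -1: use all available hardware threads
//   predictor.Load("./model.so");
//   std::vector<float> out(predictor.QueryResultSize(batch, 0, batch->num_row));
//   size_t size = predictor.PredictBatch(batch, 0, false, out.data());
//   // only the first `size` entries of `out` are valid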
void Predictor::Load(const char* name) {
  const std::string protocol = GetProtocol(name);
  if (protocol == "file://" || protocol.empty()) {
    // local file
    lib_handle_ = OpenLibrary(name);
  } else {
    // remote file: fetch to a temporary directory and load from there
    tempdir_.reset(new common::filesystem::TemporaryDirectory());
    temp_libfile_ = tempdir_->AddFile(common::filesystem::GetBasename(name));
    {
      std::unique_ptr<dmlc::Stream> strm(dmlc::Stream::Create(name, "r"));
      dmlc::istream is(strm.get());
      std::ofstream of(temp_libfile_);
      of << is.rdbuf();
    }
    lib_handle_ = OpenLibrary(temp_libfile_.c_str());
  }
  if (lib_handle_ == nullptr) {
    LOG(FATAL) << "Failed to load dynamic shared library `" << name << "'";
  }

  /* 1. query # of output groups */
  num_output_group_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_output_group");
  using QueryFunc = size_t (*)(void);
  QueryFunc query_func
    = reinterpret_cast<QueryFunc>(num_output_group_query_func_handle_);
  CHECK(query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_output_group() function";
  num_output_group_ = query_func();

  /* 2. query # of features */
  num_feature_query_func_handle_
    = LoadFunction<QueryFuncHandle>(lib_handle_, "get_num_feature");
  query_func = reinterpret_cast<QueryFunc>(num_feature_query_func_handle_);
  CHECK(query_func != nullptr)
    << "Dynamic shared library `" << name
    << "' does not contain valid get_num_feature() function";
  num_feature_ = query_func();
  CHECK_GT(num_feature_, 0) << "num_feature cannot be zero";

  /* 3. load appropriate function for margin prediction */
  CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero";
  if (num_output_group_ > 1) {  // multi-class classification
    pred_func_handle_
      = LoadFunction<PredFuncHandle>(lib_handle_, "predict_multiclass");
    using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict_multiclass() function";
  } else {  // everything else
    pred_func_handle_ = LoadFunction<PredFuncHandle>(lib_handle_, "predict");
    using PredFunc = float (*)(TreelitePredictorEntry*, int);
    PredFunc pred_func = reinterpret_cast<PredFunc>(pred_func_handle_);
    CHECK(pred_func != nullptr)
      << "Dynamic shared library `" << name
      << "' does not contain valid predict() function";
  }

  if (num_worker_thread_ == -1) {
    num_worker_thread_ = static_cast<int>(std::thread::hardware_concurrency());
  }
  thread_pool_handle_ = static_cast<ThreadPoolHandle>(
    new PredThreadPool(num_worker_thread_ - 1, this,
      [](SpscQueue<InputToken>* incoming_queue,
         SpscQueue<OutputToken>* outgoing_queue,
         const Predictor* predictor) {
        InputToken input;
        while (incoming_queue->Pop(&input)) {
          size_t query_result_size;
          const size_t rbegin = input.rbegin;
          const size_t rend = input.rend;
          switch (input.input_type) {
            case InputType::kSparseBatch:
              {
                const CSRBatch* batch
                  = static_cast<const CSRBatch*>(input.data);
                query_result_size
                  = PredictBatch_(batch, input.pred_margin, input.num_feature,
                                  input.num_output_group,
                                  input.pred_func_handle, rbegin, rend,
                                  predictor->QueryResultSize(batch, rbegin, rend),
                                  input.out_pred);
              }
              break;
            case InputType::kDenseBatch:
              {
                const DenseBatch* batch
                  = static_cast<const DenseBatch*>(input.data);
                query_result_size
                  = PredictBatch_(batch, input.pred_margin, input.num_feature,
                                  input.num_output_group,
                                  input.pred_func_handle, rbegin, rend,
                                  predictor->QueryResultSize(batch, rbegin, rend),
                                  input.out_pred);
              }
              break;
          }
          outgoing_queue->Push(OutputToken{query_result_size});
        }
      }));
}

void Predictor::Free() {
  if (lib_handle_ != nullptr) {  // guard against Free() before a successful Load()
    CloseLibrary(lib_handle_);
  }
  delete static_cast<PredThreadPool*>(thread_pool_handle_);
}

template <typename BatchType>
static inline std::vector<size_t> SplitBatch(const BatchType* batch,
                                             size_t split_factor) {
  const size_t num_row = batch->num_row;
  CHECK_LE(split_factor, num_row);
  const size_t portion = num_row / split_factor;
  const size_t remainder = num_row % split_factor;
  std::vector<size_t> workload(split_factor, portion);
  std::vector<size_t> row_ptr(split_factor + 1, 0);
  // distribute the leftover rows among the first `remainder` workers
  for (size_t i = 0; i < remainder; ++i) {
    ++workload[i];
  }
  size_t accum = 0;
  for (size_t i = 0; i < split_factor; ++i) {
    accum += workload[i];
    row_ptr[i + 1] = accum;
  }
  return row_ptr;
}
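/*
 * Worked example (for exposition only): SplitBatch() with num_row = 10 and
 * split_factor = 4 computes portion = 2 and remainder = 2, giving workloads
 * {3, 3, 2, 2} and returning row_ptr = {0, 3, 6, 8, 10}. Worker i then
 * processes the row range [row_ptr[i], row_ptr[i+1]).
 */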
template <typename BatchType>
inline size_t Predictor::PredictBatchBase_(const BatchType* batch, int verbose,
                                           bool pred_margin, float* out_result) {
  static_assert(std::is_same<BatchType, DenseBatch>::value
                || std::is_same<BatchType, CSRBatch>::value,
                "PredictBatchBase_: unrecognized batch type");
  const double tstart = dmlc::GetTime();
  PredThreadPool* pool = static_cast<PredThreadPool*>(thread_pool_handle_);
  const InputType input_type
    = std::is_same<BatchType, CSRBatch>::value
      ? InputType::kSparseBatch : InputType::kDenseBatch;
  InputToken request{input_type, static_cast<const void*>(batch), pred_margin,
                     num_feature_, num_output_group_, pred_func_handle_,
                     0, batch->num_row, out_result};
  OutputToken response;
  CHECK_GT(batch->num_row, 0);
  const int nthread = std::min(num_worker_thread_,
                               static_cast<int>(batch->num_row));
  const std::vector<size_t> row_ptr = SplitBatch(batch, nthread);
  for (int tid = 0; tid < nthread - 1; ++tid) {
    request.rbegin = row_ptr[tid];
    request.rend = row_ptr[tid + 1];
    pool->SubmitTask(tid, request);
  }
  size_t total_size = 0;
  {
    // assign work to the master thread as well
    const size_t rbegin = row_ptr[nthread - 1];
    const size_t rend = row_ptr[nthread];
    const size_t query_result_size
      = PredictBatch_(batch, pred_margin, num_feature_, num_output_group_,
                      pred_func_handle_, rbegin, rend,
                      QueryResultSize(batch, rbegin, rend), out_result);
    total_size += query_result_size;
  }
  for (int tid = 0; tid < nthread - 1; ++tid) {
    if (pool->WaitForTask(tid, &response)) {
      total_size += response.query_result_size;
    }
  }
  // re-shape output if total_size < dimension of out_result
  if (total_size < QueryResultSize(batch, 0, batch->num_row)) {
    CHECK_GT(num_output_group_, 1);
    CHECK_EQ(total_size % batch->num_row, 0);
    const size_t query_size_per_instance = total_size / batch->num_row;
    CHECK_GT(query_size_per_instance, 0);
    CHECK_LT(query_size_per_instance, num_output_group_);
    // compact the output: move each row's entries so that rows are laid out
    // contiguously at stride query_size_per_instance
    for (size_t rid = 0; rid < batch->num_row; ++rid) {
      for (size_t k = 0; k < query_size_per_instance; ++k) {
        out_result[rid * query_size_per_instance + k]
          = out_result[rid * num_output_group_ + k];
      }
    }
  }
  const double tend = dmlc::GetTime();
  if (verbose > 0) {
    LOG(INFO) << "Treelite: Finished prediction in " << tend - tstart << " sec";
  }
  return total_size;
}

size_t Predictor::PredictBatch(const CSRBatch* batch, int verbose,
                               bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}

size_t Predictor::PredictBatch(const DenseBatch* batch, int verbose,
                               bool pred_margin, float* out_result) {
  return PredictBatchBase_(batch, verbose, pred_margin, out_result);
}

size_t Predictor::PredictInst(TreelitePredictorEntry* inst, bool pred_margin,
                              float* out_result) {
  return PredictInst_(inst, pred_margin, num_output_group_, pred_func_handle_,
                      QueryResultSizeSingleInst(), out_result);
}

}  // namespace treelite
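/*
 * Worked example of the re-shape step in PredictBatchBase_ (for exposition
 * only): suppose num_output_group_ = 3 but the model was compiled with a
 * prediction transform such as "max_index", so each row yields one value.
 * With num_row = 2, the workers write at stride num_output_group_:
 *
 *   out_result = [r0, _, _, r1, _, _]   (total_size = 2 < 6)
 *
 * and the compaction loop rewrites the buffer at stride 1:
 *
 *   out_result = [r0, r1, _, _, _, _]
 *
 * The caller should read only the first total_size entries.
 */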