#include <treelite/c_api_common.h>
#include <treelite/c_api_runtime.h>
#include <treelite/predictor.h>
#include <dmlc/endian.h>
#include <dmlc/logging.h>
#include <dmlc/memory_io.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>
#include "./treelite4j.h"

namespace {

// Write a native pointer into element 0 of a Java long[].
// jlong is used directly so the buffer type always matches what
// SetLongArrayRegion expects, without per-platform typedef workarounds.
void setHandle(JNIEnv* jenv, jlongArray jhandle, void* handle) {
  const jlong out = reinterpret_cast<jlong>(handle);
  jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
}

// Write a size_t value into element 0 of a Java long[].
void setSize(JNIEnv* jenv, jlongArray jout, size_t value) {
  const jlong out = static_cast<jlong>(value);
  jenv->SetLongArrayRegion(jout, 0, 1, &out);
}

// Convert a JNI boolean into the 0/1 int flag passed to the C API.
inline int toFlag(jboolean flag) {
  return (flag == JNI_TRUE) ? 1 : 0;
}

}  // namespace anonymous

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreeliteGetLastError
 * Signature: ()Ljava/lang/String;
 *
 * Returns the string reported by TreeliteGetLastError() as a Java String,
 * or null when the C API returns a null pointer.
 */
JNIEXPORT jstring JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreeliteGetLastError(
    JNIEnv* jenv, jclass jcls) {
  const char* result = TreeliteGetLastError();
  return result ? jenv->NewStringUTF(result) : nullptr;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreeliteAssembleSparseBatch
 * Signature: ([F[I[JJJ[J)I
 *
 * Builds a sparse batch from the three Java arrays and stores the resulting
 * handle in jout[0]. The array elements obtained here are NOT released in
 * this function; TreeliteDeleteSparseBatch releases them through the
 * pointers held by the batch.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreeliteAssembleSparseBatch(
    JNIEnv* jenv, jclass jcls, jfloatArray jdata, jintArray jcol_ind,
    jlongArray jrow_ptr, jlong jnum_row, jlong jnum_col, jlongArray jout) {
  jfloat* data = jenv->GetFloatArrayElements(jdata, nullptr);
  jint* col_ind = jenv->GetIntArrayElements(jcol_ind, nullptr);
  jlong* row_ptr = jenv->GetLongArrayElements(jrow_ptr, nullptr);
  // Initialised so that neither value is indeterminate on any path.
  CSRBatchHandle out = nullptr;
  jint ret = -1;
  // jlong row pointers are reinterpreted as size_t, which is only valid when
  // size_t is 64 bits wide.
  if (sizeof(size_t) == sizeof(uint64_t)) {
    ret = static_cast<jint>(TreeliteAssembleSparseBatch(
        data,
        reinterpret_cast<const uint32_t*>(col_ind),
        reinterpret_cast<const size_t*>(row_ptr),
        static_cast<size_t>(jnum_row),
        static_cast<size_t>(jnum_col),
        &out));
  } else {
    LOG(FATAL) << "32-bit platform not supported yet";
  }
  setHandle(jenv, jout, out);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreeliteDeleteSparseBatch
 * Signature: (J[F[I[J)I
 *
 * Releases the Java array elements referenced by the batch, then deletes
 * the batch itself.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreeliteDeleteSparseBatch(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jfloatArray jdata,
    jintArray jcol_ind, jlongArray jrow_ptr) {
  treelite::CSRBatch* batch = reinterpret_cast<treelite::CSRBatch*>(jhandle);
  // C-style casts kept on purpose: each one converts the batch's stored
  // pointer to the JNI element pointer type in a single step.
  jenv->ReleaseFloatArrayElements(jdata, (jfloat*)batch->data, 0);
  jenv->ReleaseIntArrayElements(jcol_ind, (jint*)batch->col_ind, 0);
  jenv->ReleaseLongArrayElements(jrow_ptr, (jlong*)batch->row_ptr, 0);
  return static_cast<jint>(
      TreeliteDeleteSparseBatch(static_cast<CSRBatchHandle>(batch)));
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreeliteAssembleDenseBatch
 * Signature: ([FFJJ[J)I
 *
 * Builds a dense batch from the Java float array and stores the resulting
 * handle in jout[0]. The array elements obtained here are NOT released in
 * this function; TreeliteDeleteDenseBatch releases them.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreeliteAssembleDenseBatch(
    JNIEnv* jenv, jclass jcls, jfloatArray jdata, jfloat jmissing_value,
    jlong jnum_row, jlong jnum_col, jlongArray jout) {
  jfloat* data = jenv->GetFloatArrayElements(jdata, nullptr);
  DenseBatchHandle out = nullptr;
  const jint ret = static_cast<jint>(TreeliteAssembleDenseBatch(
      data,
      static_cast<float>(jmissing_value),
      static_cast<size_t>(jnum_row),
      static_cast<size_t>(jnum_col),
      &out));
  setHandle(jenv, jout, out);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreeliteDeleteDenseBatch
 * Signature: (J[F)I
 *
 * Releases the Java array elements referenced by the batch, then deletes
 * the batch itself.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreeliteDeleteDenseBatch(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jfloatArray jdata) {
  treelite::DenseBatch* batch =
      reinterpret_cast<treelite::DenseBatch*>(jhandle);
  // C-style cast kept on purpose (see TreeliteDeleteSparseBatch).
  jenv->ReleaseFloatArrayElements(jdata, (jfloat*)batch->data, 0);
  return static_cast<jint>(
      TreeliteDeleteDenseBatch(static_cast<DenseBatchHandle>(batch)));
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreeliteBatchGetDimension
 * Signature: (JZ[J[J)I
 *
 * Stores the batch's row count in jout_num_row[0] and its column count in
 * jout_num_col[0].
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreeliteBatchGetDimension(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jboolean jbatch_sparse,
    jlongArray jout_num_row, jlongArray jout_num_col) {
  size_t num_row = 0;
  size_t num_col = 0;
  const jint ret = static_cast<jint>(TreeliteBatchGetDimension(
      reinterpret_cast<void*>(jhandle), toFlag(jbatch_sparse),
      &num_row, &num_col));
  // save dimensions
  setSize(jenv, jout_num_row, num_row);
  setSize(jenv, jout_num_col, num_col);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreelitePredictorLoad
 * Signature: (Ljava/lang/String;I[J)I
 *
 * Loads a predictor from the given library path and stores its handle in
 * jout[0]. Returns -1 without touching jout when the path string cannot be
 * obtained from the JVM.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreelitePredictorLoad(
    JNIEnv* jenv, jclass jcls, jstring jlibrary_path,
    jint jnum_worker_thread, jlongArray jout) {
  const char* library_path = jenv->GetStringUTFChars(jlibrary_path, nullptr);
  if (library_path == nullptr) {
    // GetStringUTFChars failed; a Java exception is pending, so make no
    // further JNI calls and just report failure.
    return -1;
  }
  PredictorHandle out = nullptr;
  const jint ret = static_cast<jint>(TreelitePredictorLoad(
      library_path, static_cast<int>(jnum_worker_thread), &out));
  // The UTF-8 copy of the path must be handed back to the JVM.
  jenv->ReleaseStringUTFChars(jlibrary_path, library_path);
  setHandle(jenv, jout, out);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreelitePredictorPredictBatch
 * Signature: (JJZZZ[F[J)I
 *
 * Runs batch prediction, writing results into jout_result and the number of
 * result entries into jout_result_size[0].
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreelitePredictorPredictBatch(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jlong jbatch,
    jboolean jbatch_sparse, jboolean jverbose, jboolean jpred_margin,
    jfloatArray jout_result, jlongArray jout_result_size) {
  jfloat* out_result = jenv->GetFloatArrayElements(jout_result, nullptr);
  size_t result_size = 0;
  const jint ret = static_cast<jint>(TreelitePredictorPredictBatch(
      reinterpret_cast<PredictorHandle>(jhandle),
      reinterpret_cast<void*>(jbatch),
      toFlag(jbatch_sparse), toFlag(jverbose), toFlag(jpred_margin),
      out_result, &result_size));
  // Mode 0: copy the predictions back to the Java array and free the buffer.
  jenv->ReleaseFloatArrayElements(jout_result, out_result, 0);
  setSize(jenv, jout_result_size, result_size);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreelitePredictorPredictInst
 * Signature: (J[BZ[F[J)I
 *
 * Decodes jinst as a packed array of TreelitePredictorEntry, runs
 * single-instance prediction, and writes results into jout_result and the
 * number of result entries into jout_result_size[0].
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreelitePredictorPredictInst(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jbyteArray jinst,
    jboolean jpred_margin, jfloatArray jout_result,
    jlongArray jout_result_size) {
  // Validate the length before pinning anything, so a failed check does not
  // leave array elements unreleased.
  const size_t nbytes = static_cast<size_t>(jenv->GetArrayLength(jinst));
  CHECK_EQ(nbytes % sizeof(TreelitePredictorEntry), 0U);
  const size_t num_elem = nbytes / sizeof(TreelitePredictorEntry);

  // read Entry[] array from bytes
  std::vector<TreelitePredictorEntry> inst(num_elem, {-1});
  jbyte* inst_bytes = jenv->GetByteArrayElements(jinst, nullptr);
  {
    dmlc::MemoryFixedSizeStream fs(inst_bytes, nbytes);
    for (TreelitePredictorEntry& entry : inst) {
      fs.Read(&entry, sizeof(TreelitePredictorEntry));
    }
  }
  // The input is only read, so release without copying anything back.
  jenv->ReleaseByteArrayElements(jinst, inst_bytes, JNI_ABORT);

  if (!DMLC_LITTLE_ENDIAN) {
    // re-order bytes on big-endian machines; done on the private copy, one
    // entry at a time, so the caller's Java array is never modified
    dmlc::ByteSwap(inst.data(), sizeof(TreelitePredictorEntry), num_elem);
  }

  jfloat* out_result = jenv->GetFloatArrayElements(jout_result, nullptr);
  size_t result_size = 0;
  const jint ret = static_cast<jint>(TreelitePredictorPredictInst(
      reinterpret_cast<PredictorHandle>(jhandle), inst.data(),
      toFlag(jpred_margin), out_result, &result_size));
  // Mode 0: copy the predictions back to the Java array and free the buffer.
  jenv->ReleaseFloatArrayElements(jout_result, out_result, 0);
  setSize(jenv, jout_result_size, result_size);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreelitePredictorQueryResultSize
 * Signature: (JJZ[J)I
 *
 * Stores the value reported by TreelitePredictorQueryResultSize in jout[0].
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreelitePredictorQueryResultSize(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jlong jbatch,
    jboolean jbatch_sparse, jlongArray jout) {
  size_t result_size = 0;
  const jint ret = static_cast<jint>(TreelitePredictorQueryResultSize(
      reinterpret_cast<PredictorHandle>(jhandle),
      reinterpret_cast<void*>(jbatch),
      toFlag(jbatch_sparse), &result_size));
  // store dimension
  setSize(jenv, jout, result_size);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreelitePredictorQueryResultSizeSingleInst
 * Signature: (J[J)I
 *
 * Stores the value reported by TreelitePredictorQueryResultSizeSingleInst
 * in jout[0].
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreelitePredictorQueryResultSizeSingleInst(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jlongArray jout) {
  size_t result_size = 0;
  const jint ret = static_cast<jint>(
      TreelitePredictorQueryResultSizeSingleInst(
          reinterpret_cast<PredictorHandle>(jhandle), &result_size));
  // store dimension
  setSize(jenv, jout, result_size);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreelitePredictorQueryNumOutputGroup
 * Signature: (J[J)I
 *
 * Stores the value reported by TreelitePredictorQueryNumOutputGroup in
 * jout[0].
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreelitePredictorQueryNumOutputGroup(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jlongArray jout) {
  size_t num_output_group = 0;
  const jint ret = static_cast<jint>(TreelitePredictorQueryNumOutputGroup(
      reinterpret_cast<PredictorHandle>(jhandle), &num_output_group));
  // store dimension
  setSize(jenv, jout, num_output_group);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreelitePredictorQueryNumFeature
 * Signature: (J[J)I
 *
 * Stores the value reported by TreelitePredictorQueryNumFeature in jout[0].
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreelitePredictorQueryNumFeature(
    JNIEnv* jenv, jclass jcls, jlong jhandle, jlongArray jout) {
  size_t num_feature = 0;
  const jint ret = static_cast<jint>(TreelitePredictorQueryNumFeature(
      reinterpret_cast<PredictorHandle>(jhandle), &num_feature));
  // store dimension
  setSize(jenv, jout, num_feature);
  return ret;
}

/*
 * Class:     ml_dmlc_treelite4j_TreeliteJNI
 * Method:    TreelitePredictorFree
 * Signature: (J)I
 *
 * Frees the predictor referred to by jhandle.
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_treelite4j_TreeliteJNI_TreelitePredictorFree(
    JNIEnv* jenv, jclass jcls, jlong jhandle) {
  return static_cast<jint>(
      TreelitePredictorFree(reinterpret_cast<PredictorHandle>(jhandle)));
}