Skip to content

Commit 3b96f41

Browse files
committed
ref PULSEDEV-36792 openml-lightgbm: Improvement
1 parent bd3fe9f commit 3b96f41

File tree

1 file changed

+39
-10
lines changed
  • openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm

1 file changed

+39
-10
lines changed

openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGResources.java

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ class SWIGResources implements AutoCloseable {
105105
*/
106106
private Integer boosterNumFeatures;
107107

108+
/**
109+
* Number of classes in the trained LGBM.
110+
*/
111+
private Integer boosterNumClasses = null;
112+
108113
/**
109114
* Names of features in the trained LightGBM boosting model.
110115
* Whilst not a swig resource, it is automatically retrieved during model loading,
@@ -252,19 +257,21 @@ private void initBoosterFastContributionsHandle(final String LightGBMParameters)
252257
* Assumes the model was already loaded from file.
253258
* Initializes the remaining SWIG resources needed to use the model.
254259
*
260+
* The size of {@link #swigOutContributionsPtr} is computed accoring to
261+
* https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow
262+
*
255263
* @throws LightGBMException in case there's an error in the C++ core library.
256264
*/
257265
private void initAuxiliaryModelResources() throws LightGBMException {
258-
259-
this.boosterNumFeatures = computeBoosterNumFeaturesFromModel();
266+
computeBoosterNumFeaturesFromModel();
260267
logger.debug("Loaded LightGBM Model has {} features.", this.boosterNumFeatures);
261268

262269
this.boosterFeatureNames = computeBoosterFeatureNamesFromModel();
263270

264271
this.swigOutLengthInt64Ptr = lightgbmlibJNI.new_int64_tp();
265272
this.swigInstancePtr = lightgbmlibJNI.new_doubleArray(getBoosterNumFeatures());
266-
this.swigOutScoresPtr = lightgbmlibJNI.new_doubleArray(BINARY_LGBM_NUM_CLASSES);
267-
this.swigOutContributionsPtr = lightgbmlibJNI.new_doubleArray(this.boosterNumFeatures + 1);
273+
this.swigOutScoresPtr = lightgbmlibJNI.new_doubleArray(this.boosterNumClasses);
274+
this.swigOutContributionsPtr = lightgbmlibJNI.new_doubleArray((long) this.boosterNumClasses * (this.boosterNumFeatures + 1));
268275
}
269276

270277
/**
@@ -302,6 +309,14 @@ private void releaseInitializedSWIGResources() throws LightGBMException {
302309
lightgbmlibJNI.delete_intp(this.swigOutIntPtr);
303310
this.swigOutIntPtr = null;
304311
}
312+
if (this.boosterNumFeatures != null) {
313+
lightgbmlibJNI.delete_intp(this.boosterNumFeatures);
314+
this.boosterNumFeatures = null;
315+
}
316+
if (this.boosterNumClasses != null) {
317+
lightgbmlibJNI.delete_intp(this.boosterNumClasses);
318+
this.boosterNumClasses = null;
319+
}
305320
if (this.swigOutContributionsPtr != null) {
306321
lightgbmlibJNI.delete_doubleArray(this.swigOutContributionsPtr);
307322
this.swigOutContributionsPtr = null;
@@ -373,17 +388,31 @@ public String[] getBoosterFeatureNames() {
373388
* Computes the number of features in the model and returns it.
374389
*
375390
* @throws LightGBMException when there is a LightGBM C++ error.
376-
* @returns int with the number of Booster features.
377391
*/
378-
private Integer computeBoosterNumFeaturesFromModel() throws LightGBMException {
379-
380-
final int returnCodeLGBM = lightgbmlibJNI.LGBM_BoosterGetNumFeature(
392+
private void computeBoosterNumFeaturesFromModel() throws LightGBMException {
393+
final int returnCodeNumFeatsLGBM = lightgbmlibJNI.LGBM_BoosterGetNumFeature(
381394
this.swigBoosterHandle,
382395
this.swigOutIntPtr);
383-
if (returnCodeLGBM == -1)
396+
if (returnCodeNumFeatsLGBM == -1)
384397
throw new LightGBMException();
385398

386-
return lightgbmlibJNI.intp_value(this.swigOutIntPtr);
399+
400+
if (this.boosterNumFeatures != null) {
401+
lightgbmlibJNI.delete_intp(this.boosterNumFeatures);
402+
this.boosterNumFeatures = null;
403+
}
404+
this.boosterNumFeatures = lightgbmlibJNI.intp_value(this.swigOutIntPtr);
405+
406+
final int returnCodeNumClassesLGBM = lightgbmlibJNI.LGBM_BoosterGetNumClasses(
407+
this.swigBoosterHandle,
408+
this.swigOutIntPtr);
409+
if (returnCodeNumClassesLGBM == -1)
410+
throw new LightGBMException();
411+
if (this.boosterNumClasses != null) {
412+
lightgbmlibJNI.delete_intp(this.boosterNumClasses);
413+
this.boosterNumClasses = null;
414+
}
415+
this.boosterNumClasses = lightgbmlibJNI.intp_value(this.swigOutIntPtr);
387416
}
388417

389418
/**

0 commit comments

Comments
 (0)