@@ -105,6 +105,11 @@ class SWIGResources implements AutoCloseable {
105
105
*/
106
106
private Integer boosterNumFeatures ;
107
107
108
+ /**
109
+ * Number of classes in the trained LGBM.
110
+ */
111
+ private Integer boosterNumClasses = null ;
112
+
108
113
/**
109
114
* Names of features in the trained LightGBM boosting model.
110
115
* Whilst not a swig resource, it is automatically retrieved during model loading,
@@ -252,19 +257,21 @@ private void initBoosterFastContributionsHandle(final String LightGBMParameters)
252
257
* Assumes the model was already loaded from file.
253
258
* Initializes the remaining SWIG resources needed to use the model.
254
259
*
260
+ * The size of {@link #swigOutContributionsPtr} is computed accoring to
261
+ * https://lightgbm.readthedocs.io/en/latest/C-API.html#c.LGBM_BoosterPredictForMatSingleRow
262
+ *
255
263
* @throws LightGBMException in case there's an error in the C++ core library.
256
264
*/
257
265
private void initAuxiliaryModelResources () throws LightGBMException {
258
-
259
- this .boosterNumFeatures = computeBoosterNumFeaturesFromModel ();
266
+ computeBoosterNumFeaturesFromModel ();
260
267
logger .debug ("Loaded LightGBM Model has {} features." , this .boosterNumFeatures );
261
268
262
269
this .boosterFeatureNames = computeBoosterFeatureNamesFromModel ();
263
270
264
271
this .swigOutLengthInt64Ptr = lightgbmlibJNI .new_int64_tp ();
265
272
this .swigInstancePtr = lightgbmlibJNI .new_doubleArray (getBoosterNumFeatures ());
266
- this .swigOutScoresPtr = lightgbmlibJNI .new_doubleArray (BINARY_LGBM_NUM_CLASSES );
267
- this .swigOutContributionsPtr = lightgbmlibJNI .new_doubleArray (this .boosterNumFeatures + 1 );
273
+ this .swigOutScoresPtr = lightgbmlibJNI .new_doubleArray (this . boosterNumClasses );
274
+ this .swigOutContributionsPtr = lightgbmlibJNI .new_doubleArray (( long ) this .boosterNumClasses * ( this . boosterNumFeatures + 1 ) );
268
275
}
269
276
270
277
/**
@@ -302,6 +309,14 @@ private void releaseInitializedSWIGResources() throws LightGBMException {
302
309
lightgbmlibJNI .delete_intp (this .swigOutIntPtr );
303
310
this .swigOutIntPtr = null ;
304
311
}
312
+ if (this .boosterNumFeatures != null ) {
313
+ lightgbmlibJNI .delete_intp (this .boosterNumFeatures );
314
+ this .boosterNumFeatures = null ;
315
+ }
316
+ if (this .boosterNumClasses != null ) {
317
+ lightgbmlibJNI .delete_intp (this .boosterNumClasses );
318
+ this .boosterNumClasses = null ;
319
+ }
305
320
if (this .swigOutContributionsPtr != null ) {
306
321
lightgbmlibJNI .delete_doubleArray (this .swigOutContributionsPtr );
307
322
this .swigOutContributionsPtr = null ;
@@ -373,17 +388,31 @@ public String[] getBoosterFeatureNames() {
373
388
* Computes the number of features in the model and returns it.
374
389
*
375
390
* @throws LightGBMException when there is a LightGBM C++ error.
376
- * @returns int with the number of Booster features.
377
391
*/
378
- private Integer computeBoosterNumFeaturesFromModel () throws LightGBMException {
379
-
380
- final int returnCodeLGBM = lightgbmlibJNI .LGBM_BoosterGetNumFeature (
392
+ private void computeBoosterNumFeaturesFromModel () throws LightGBMException {
393
+ final int returnCodeNumFeatsLGBM = lightgbmlibJNI .LGBM_BoosterGetNumFeature (
381
394
this .swigBoosterHandle ,
382
395
this .swigOutIntPtr );
383
- if (returnCodeLGBM == -1 )
396
+ if (returnCodeNumFeatsLGBM == -1 )
384
397
throw new LightGBMException ();
385
398
386
- return lightgbmlibJNI .intp_value (this .swigOutIntPtr );
399
+
400
+ if (this .boosterNumFeatures != null ) {
401
+ lightgbmlibJNI .delete_intp (this .boosterNumFeatures );
402
+ this .boosterNumFeatures = null ;
403
+ }
404
+ this .boosterNumFeatures = lightgbmlibJNI .intp_value (this .swigOutIntPtr );
405
+
406
+ final int returnCodeNumClassesLGBM = lightgbmlibJNI .LGBM_BoosterGetNumClasses (
407
+ this .swigBoosterHandle ,
408
+ this .swigOutIntPtr );
409
+ if (returnCodeNumClassesLGBM == -1 )
410
+ throw new LightGBMException ();
411
+ if (this .boosterNumClasses != null ) {
412
+ lightgbmlibJNI .delete_intp (this .boosterNumClasses );
413
+ this .boosterNumClasses = null ;
414
+ }
415
+ this .boosterNumClasses = lightgbmlibJNI .intp_value (this .swigOutIntPtr );
387
416
}
388
417
389
418
/**
0 commit comments