diff --git a/openml-lightgbm/lightgbm-builder/make-lightgbm b/openml-lightgbm/lightgbm-builder/make-lightgbm
index 6928602c..319698ae 160000
--- a/openml-lightgbm/lightgbm-builder/make-lightgbm
+++ b/openml-lightgbm/lightgbm-builder/make-lightgbm
@@ -1 +1 @@
-Subproject commit 6928602c37bae318fe7dfd3226893ced596a112f
+Subproject commit 319698ae745652cb6d33a22651aa889be48d311e
diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java
index 69eb7b70..222b7f59 100644
--- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java
+++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java
@@ -22,27 +22,21 @@
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
+import com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil;
+import com.feedzai.openml.provider.lightgbm.schema.TrainSchemaUtil;
import com.google.common.collect.ImmutableSet;
-import com.microsoft.ml.lightgbm.SWIGTYPE_p_float;
-import com.microsoft.ml.lightgbm.SWIGTYPE_p_int;
-import com.microsoft.ml.lightgbm.SWIGTYPE_p_void;
-import com.microsoft.ml.lightgbm.lightgbmlib;
-import com.microsoft.ml.lightgbm.lightgbmlibConstants;
-import java.util.HashMap;
-import java.util.Optional;
-import java.util.Set;
+import com.microsoft.ml.lightgbm.*;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.file.Path;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
import static com.feedzai.openml.provider.lightgbm.FairGBMDescriptorUtil.CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME;
import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME;
+import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SOFT_LABEL_PARAMETER_NAME;
+import static com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil.getSoftLabelFieldName;
import static java.lang.Integer.parseInt;
import static java.util.stream.Collectors.toList;
@@ -64,18 +58,20 @@ final class LightGBMBinaryClassificationModelTrainer {
* Train data is copied from the input stream into an array of chunks.
* Each chunk will have this many instances. Must be set before `fit()`.
*
+ *
* @implNote Performance overhead notes:
* - Too small? Performance overhead - excessive in-memory data fragmentation.
* - Too large? RAM overhead - in the worst case the last chunk has only 1 instance.
- * Each instance might have upwards of 400 features. Each costs 8 bytes.
- * E.g.: 100k instances of 400 features => 320MB / chunk
+ * Each instance might have upwards of 400 features. Each costs 8 bytes.
+ * E.g.: 100k instances of 400 features => 320MB / chunk
*/
static final long DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE = 200000;
/**
* This class is not meant to be instantiated.
*/
- private LightGBMBinaryClassificationModelTrainer() {}
+ private LightGBMBinaryClassificationModelTrainer() {
+ }
/**
* See LightGBMBinaryClassificationModelTrainer#fit overload below.
@@ -89,10 +85,12 @@ static void fit(final Dataset dataset,
final Map params,
final Path outputModelFilePath) {
- fit(dataset,
- params,
- outputModelFilePath,
- DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE);
+ fit(
+ dataset,
+ params,
+ outputModelFilePath,
+ DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE
+ );
}
/**
@@ -129,7 +127,9 @@ static void fit(final Dataset dataset,
final long instancesPerChunk) {
final DatasetSchema schema = dataset.getSchema();
- final int numFeatures = schema.getPredictiveFields().size();
+ final Optional softLabelColumnIndex = SoftLabelParamParserUtil.getSoftLabelColumnIndex(params, schema);
+ final int numActualFeatures = TrainSchemaUtil.getNumActualFeatures(schema, params);
+
// Parse train parameters to LightGBM format
final String trainParams = getLightGBMTrainParamsString(params, schema);
@@ -138,7 +138,7 @@ static void fit(final Dataset dataset,
logger.debug("LightGBM model trainParams: {}", trainParams);
final SWIGTrainData swigTrainData = new SWIGTrainData(
- numFeatures,
+ numActualFeatures,
instancesPerChunk,
FairGBMParamParserUtil.isFairnessConstrained(params));
final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster();
@@ -146,7 +146,21 @@ static void fit(final Dataset dataset,
/// Create LightGBM dataset
final int constraintGroupColIndex = FairGBMParamParserUtil.getConstraintGroupColumnIndex(params, schema).orElse(
FairGBMParamParserUtil.NO_SPECIFIC);
- createTrainDataset(dataset, numFeatures, trainParams, constraintGroupColIndex, swigTrainData);
+
+
+ if (softLabelColumnIndex.isPresent()) {
+ logger.debug("Replacing hard label by soft label for training.");
+ }
+
+ createTrainDataset(
+ dataset,
+ numActualFeatures,
+ trainParams,
+ constraintGroupColIndex,
+ getSoftLabelFieldName(params),
+ softLabelColumnIndex,
+ swigTrainData
+ );
/// Create Booster from dataset
createBoosterStructure(swigTrainBooster, swigTrainData, trainParams);
@@ -193,16 +207,18 @@ private static String[] getFieldNames(final List fields) {
* - Initializing the label data in the dataset + releasing the label array;
* - Setting the feature names in the dataset.
*
- * @param dataset Dataset
- * @param numFeatures Number of features
- * @param trainParams LightGBM-formatted params string ("key1=value1 key2=value2 ...")
- * @param constraintGroupColIndex The index of the constraint group column.
- * @param swigTrainData SWIGTrainData object
+ * @param dataset Dataset
+ * @param numActualFeatures Number of features excluding soft label if used
+ * @param trainParams LightGBM-formatted params string ("key1=value1 key2=value2 ...")
+ * @param constraintGroupColIndex The index of the constraint group column.
+ * @param swigTrainData SWIGTrainData object
*/
private static void createTrainDataset(final Dataset dataset,
- final int numFeatures,
+ final int numActualFeatures,
final String trainParams,
final int constraintGroupColIndex,
+ final Optional softLabelFieldName,
+ final Optional softLabelColumnIndex,
final SWIGTrainData swigTrainData) {
logger.info("Creating LightGBM dataset");
@@ -211,12 +227,13 @@ private static void createTrainDataset(final Dataset dataset,
copyTrainDataToSWIGArrays(
dataset,
swigTrainData,
- constraintGroupColIndex
+ constraintGroupColIndex,
+ softLabelColumnIndex
);
initializeLightGBMTrainDatasetFeatures(
swigTrainData,
- numFeatures,
+ numActualFeatures,
trainParams
);
@@ -230,7 +247,12 @@ private static void createTrainDataset(final Dataset dataset,
);
}
- setLightGBMDatasetFeatureNames(swigTrainData.swigDatasetHandle, dataset.getSchema());
+ setLightGBMDatasetFeatureNames(
+ swigTrainData.swigDatasetHandle,
+ dataset.getSchema(),
+ softLabelFieldName,
+ numActualFeatures
+ );
logger.info("Created LightGBM dataset.");
}
@@ -239,17 +261,17 @@ private static void createTrainDataset(final Dataset dataset,
* Initializes the LightGBM dataset structure and copies the feature data.
*
* @param swigTrainData SWIGTrainData object.
- * @param numFeatures Number of features used to predict.
+ * @param numActualFeatures Number of features used to predict.
* @param trainParams LightGBM string with the train params ("key1=value1 key2=value2 ...").
*/
private static void initializeLightGBMTrainDatasetFeatures(final SWIGTrainData swigTrainData,
- final int numFeatures,
+ final int numActualFeatures,
final String trainParams) {
logger.debug("Initializing LightGBM in-memory structure and setting feature data.");
/// First generate the array that has the chunk sizes for `LGBM_DatasetCreateFromMats`.
- final SWIGTYPE_p_int swigChunkSizesArray = genSWIGFeatureChunkSizesArray(swigTrainData, numFeatures);
+ final SWIGTYPE_p_int swigChunkSizesArray = genSWIGFeatureChunkSizesArray(swigTrainData, numActualFeatures);
/// Now create the LightGBM Dataset itself from the chunks:
logger.debug("Creating LGBM_Dataset from chunked data...");
@@ -258,7 +280,7 @@ private static void initializeLightGBMTrainDatasetFeatures(final SWIGTrainData s
swigTrainData.swigFeaturesChunkedArray.data_as_void(), // input data (void**)
lightgbmlibConstants.C_API_DTYPE_FLOAT64,
swigChunkSizesArray,
- numFeatures,
+ numActualFeatures,
1, // rowMajor.
trainParams, // parameters.
null, // No alignment with other datasets.
@@ -278,11 +300,11 @@ private static void initializeLightGBMTrainDatasetFeatures(final SWIGTrainData s
* Generates a SWIG array of ints with the size of each train chunk (partition).
*
* @param swigTrainData SWIGTrainData object.
- * @param numFeatures Number of features used to predict.
+ * @param numActualFeatures Number of features used to predict (excludes soft label field if used).
* @return SWIG (int*) array of the train chunks' sizes.
*/
private static SWIGTYPE_p_int genSWIGFeatureChunkSizesArray(final SWIGTrainData swigTrainData,
- final int numFeatures) {
+ final int numActualFeatures) {
logger.debug("Retrieving chunked data block sizes...");
@@ -299,11 +321,11 @@ private static SWIGTYPE_p_int genSWIGFeatureChunkSizesArray(final SWIGTrainData
lightgbmlib.intArray_setitem(
swigChunkSizesArray,
numChunks - 1,
- (int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count() / numFeatures
+ (int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count() / numActualFeatures
);
logger.debug("FTL: chunk-size report: chunk #{} is partial-chunk of size {}",
numChunks - 1,
- (int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count() / numFeatures);
+ (int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count() / numActualFeatures);
return swigChunkSizesArray;
}
@@ -370,14 +392,20 @@ private static void setLightGBMDatasetConstraintGroupData(final SWIGTrainData sw
* @param swigDatasetHandle SWIG dataset handle
* @param schema Dataset schema
*/
- private static void setLightGBMDatasetFeatureNames(final SWIGTYPE_p_void swigDatasetHandle, final DatasetSchema schema) {
+ private static void setLightGBMDatasetFeatureNames(final SWIGTYPE_p_void swigDatasetHandle,
+ final DatasetSchema schema,
+ final Optional softLabelFieldName,
+ final int numActualFeatures) {
- final int numFeatures = schema.getPredictiveFields().size();
- final String[] featureNames = getFieldNames(schema.getPredictiveFields());
- logger.debug("featureNames {}", Arrays.toString(featureNames));
+ final String[] actualFeatureNames = TrainSchemaUtil.getActualFeatureNames(schema, softLabelFieldName);
+ logger.debug("featureNames {} (numFeatures = {})", Arrays.toString(actualFeatureNames), numActualFeatures);
- final int returnCodeLGBM = lightgbmlib.LGBM_DatasetSetFeatureNames(swigDatasetHandle, featureNames, numFeatures);
+ final int returnCodeLGBM = lightgbmlib.LGBM_DatasetSetFeatureNames(
+ swigDatasetHandle,
+ actualFeatureNames,
+ numActualFeatures
+ );
if (returnCodeLGBM == -1) {
logger.error("Could not set feature names.");
throw new LightGBMException();
@@ -388,8 +416,8 @@ private static void setLightGBMDatasetFeatureNames(final SWIGTYPE_p_void swigDat
* Creates the booster structure with all the parameters and training dataset resources (but doesn't train).
*
* @param swigTrainBooster An object with the training resources already initialized.
- * @param swigTrainData SWIGTrainData object.
- * @param trainParams the LightGBM string-formatted string with properties in the form "key1=value1 key2=value2 ...".
+ * @param swigTrainData SWIGTrainData object.
+ * @param trainParams the LightGBM string-formatted string with properties in the form "key1=value1 key2=value2 ...".
* @see LightGBMBinaryClassificationModelTrainer#trainBooster(SWIGTYPE_p_void, int) .
*/
static void createBoosterStructure(final SWIGTrainBooster swigTrainBooster,
@@ -474,20 +502,11 @@ static void saveModelFileToDisk(final SWIGTYPE_p_void swigBoosterHandle, final P
logger.info("Saved model to disk");
}
- /**
- * Takes the data in dataset and copies it into the features and label C++ arrays through SWIG.
- *
- * @param dataset Input train dataset (with target label)
- * @param swigTrainData SWIGTrainData object.
- */
- private static void copyTrainDataToSWIGArrays(final Dataset dataset,
- final SWIGTrainData swigTrainData) {
- copyTrainDataToSWIGArrays(dataset, swigTrainData, FairGBMParamParserUtil.NO_SPECIFIC);
- }
private static void copyTrainDataToSWIGArrays(final Dataset dataset,
final SWIGTrainData swigTrainData,
- final int constraintGroupIndex) {
+ final int constraintGroupIndex,
+ final Optional softLabelIndex) {
final DatasetSchema datasetSchema = dataset.getSchema();
final int numFields = datasetSchema.getFieldSchemas().size();
@@ -495,27 +514,36 @@ private static void copyTrainDataToSWIGArrays(final Dataset dataset,
This is ensured in validateForFit, by using the
ValidationUtils' validateCategoricalSchema:
*/
- final int targetIndex = datasetSchema.getTargetIndex().get();
+ final int hardLabelIndex = datasetSchema.getTargetIndex().get();
+
+ /*
+ * If a soft label is used, it replaces the hard label for training.
+ * In such case, the soft label cannot be used as a feature.
+ */
+ final int labelFieldIndex = softLabelIndex.orElse(hardLabelIndex);
final Iterator iterator = dataset.getInstances();
while (iterator.hasNext()) {
final Instance instance = iterator.next();
- swigTrainData.addLabelValue((float) instance.getValue(targetIndex));
+ swigTrainData.addLabelValue((float) instance.getValue(labelFieldIndex));
if (constraintGroupIndex != FairGBMParamParserUtil.NO_SPECIFIC) {
swigTrainData.addConstraintGroupValue((int) instance.getValue(constraintGroupIndex));
}
for (int colIdx = 0; colIdx < numFields; ++colIdx) {
- if (colIdx != targetIndex) {
- swigTrainData.addFeatureValue(instance.getValue(colIdx));
+ if (colIdx != hardLabelIndex) {
+ // If we're using the soft label, remove all information from its feature (keeps schema and avoids label leakage):
+ swigTrainData.addFeatureValue(
+ colIdx != labelFieldIndex ? instance.getValue(colIdx) : 0
+ );
}
}
}
assert swigTrainData.swigLabelsChunkedArray.get_add_count() ==
- (swigTrainData.swigFeaturesChunkedArray.get_add_count() / swigTrainData.numFeatures);
+ (swigTrainData.swigFeaturesChunkedArray.get_add_count() / swigTrainData.numActualFeatures);
if (swigTrainData.fairnessConstrained) {
assert swigTrainData.swigConstraintGroupChunkedArray.get_add_count() ==
@@ -523,8 +551,8 @@ private static void copyTrainDataToSWIGArrays(final Dataset dataset,
}
logger.debug("Copied train data of size {} into {} chunks.",
- swigTrainData.swigLabelsChunkedArray.get_add_count(),
- swigTrainData.swigLabelsChunkedArray.get_chunks_count()
+ swigTrainData.swigLabelsChunkedArray.get_add_count(),
+ swigTrainData.swigLabelsChunkedArray.get_chunks_count()
);
if (swigTrainData.swigLabelsChunkedArray.get_add_count() == 0) {
@@ -547,11 +575,13 @@ private static String getLightGBMTrainParamsString(final Map map
preprocessedMapParams
.put("categorical_feature", StringUtils.join(categoricalFeatureIndices, ","));
- // Set default objective parameter
- final Optional objective = getLightGBMObjective(mapParams);
- if (! objective.isPresent()) {
- // Default to objective=binary
- preprocessedMapParams.put("objective", "binary");
+ // Set default objective parameter - by default "binary":
+ final Optional customObjective = getLightGBMObjective(mapParams);
+ if (!customObjective.isPresent()) {
+ final String objective = SoftLabelParamParserUtil.getSoftLabelFieldName(mapParams).isPresent() ?
+ "cross_entropy" : "binary";
+
+ preprocessedMapParams.put("objective", objective);
}
// Add constraint_group_column parameter
@@ -560,7 +590,12 @@ private static String getLightGBMTrainParamsString(final Map map
integer -> preprocessedMapParams.put(CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME, Integer.toString(integer)));
// Add all **other** parameters
- mapParams.forEach(preprocessedMapParams::putIfAbsent);
+ final Set lgbmUnknownParams = ImmutableSet.of(SOFT_LABEL_PARAMETER_NAME);
+ mapParams
+ .entrySet()
+ .stream()
+ .filter(entry -> !lgbmUnknownParams.contains(entry.getKey()))
+ .forEach(entry -> preprocessedMapParams.putIfAbsent(entry.getKey(), entry.getValue()));
// Build string containing params in LightGBM format
final StringBuilder paramsBuilder = new StringBuilder();
diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMDescriptorUtil.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMDescriptorUtil.java
index cf03abe2..1225c1bf 100644
--- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMDescriptorUtil.java
+++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMDescriptorUtil.java
@@ -20,6 +20,7 @@
import com.feedzai.openml.provider.descriptor.ModelParameter;
import com.feedzai.openml.provider.descriptor.fieldtype.BooleanFieldType;
import com.feedzai.openml.provider.descriptor.fieldtype.ChoiceFieldType;
+import com.feedzai.openml.provider.descriptor.fieldtype.FreeTextFieldType;
import com.google.common.collect.ImmutableSet;
import java.util.Set;
@@ -72,6 +73,7 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil {
*/
public static final String FEATURE_FRACTION_PARAMETER_DESCRIPTION = "Feature fraction by tree";
+ public static final String SOFT_LABEL_PARAMETER_NAME = "soft_label";
/**
* Defines the set of model parameters accepted by the LightGBM model.
*/
@@ -99,6 +101,17 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil {
MANDATORY,
intRange(0, Integer.MAX_VALUE, 100)
),
+ new ModelParameter(
+ SOFT_LABEL_PARAMETER_NAME,
+ "(Soft classification) Soft Label",
+ "Train the classifier with a soft label instead. \n"
+ + "These values should be between 0.0 and 1.0. This field is ignored for validations. \n"
+ + "Simply pass the name of the field.\n"
+ + "If this field is selected, it is automatically dropped from the selected features.",
+ NOT_MANDATORY,
+ new FreeTextFieldType("")
+ )
+ ,
new ModelParameter(
"learning_rate",
"Learning rate",
@@ -314,7 +327,7 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil {
),
new ModelParameter(
"min_data_in_bin",
- "Minimium bin data",
+ "Minimum bin data",
"Minimum number of samples inside one bin. Limit over-fitting. E.g., not using 1 point per bin.",
NOT_MANDATORY,
intRange(1, Integer.MAX_VALUE, 3)
diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.java
index f492f8cf..6d4c9374 100644
--- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.java
+++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.java
@@ -18,13 +18,10 @@
package com.feedzai.openml.provider.lightgbm;
import com.feedzai.openml.data.Dataset;
-import com.feedzai.openml.data.schema.AbstractValueSchema;
-import com.feedzai.openml.data.schema.CategoricalValueSchema;
-import com.feedzai.openml.data.schema.DatasetSchema;
-import com.feedzai.openml.data.schema.FieldSchema;
-import com.feedzai.openml.data.schema.StringValueSchema;
+import com.feedzai.openml.data.schema.*;
import com.feedzai.openml.provider.descriptor.fieldtype.ParamValidationError;
import com.feedzai.openml.provider.exception.ModelLoadingException;
+import com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil;
import com.feedzai.openml.provider.model.MachineLearningModelTrainer;
import com.feedzai.openml.util.load.LoadSchemaUtils;
import com.google.common.collect.ImmutableList;
@@ -39,13 +36,8 @@
import java.util.Optional;
import java.util.Random;
-import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BAGGING_FRACTION_PARAMETER_NAME;
-import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BAGGING_FREQUENCY_PARAMETER_NAME;
-import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BOOSTING_TYPE_PARAMETER_NAME;
-import static com.feedzai.openml.util.validate.ValidationUtils.baseLoadValidations;
-import static com.feedzai.openml.util.validate.ValidationUtils.checkParams;
-import static com.feedzai.openml.util.validate.ValidationUtils.validateCategoricalSchema;
-import static com.feedzai.openml.util.validate.ValidationUtils.validateModelPathToTrain;
+import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.*;
+import static com.feedzai.openml.util.validate.ValidationUtils.*;
import static java.nio.file.Files.createTempFile;
/**
@@ -145,7 +137,6 @@ public LightGBMBinaryClassificationModel fit(final Dataset dataset,
}
-
@Override
public List validateForFit(final Path pathToPersist,
final DatasetSchema schema,
@@ -155,7 +146,8 @@ public List validateForFit(final Path pathToPersist,
errorsBuilder
.addAll(validateModelPathToTrain(pathToPersist))
.addAll(validateSchema(schema))
- .addAll(validateFitParams(params));
+ .addAll(validateFitParams(params))
+ .addAll(validateSoftLabelParam(schema, params));
return errorsBuilder.build();
}
@@ -183,6 +175,40 @@ private List validateSchema(final DatasetSchema schema) {
return errorsBuilder.build();
}
+ /**
+ * Ensure that if the soft label parameter is passed, it references a numeric field.
+ *
+ * @param schema Dataset schema.
+ * @param params Model fit parameters.
+ * @return list of validation errors.
+ */
+ private List validateSoftLabelParam(final DatasetSchema schema,
+ final Map params) {
+ // Don't test anything if the parameter is not set:
+ final Optional softLabelFieldName = SoftLabelParamParserUtil.getSoftLabelFieldName(params);
+ if (!softLabelFieldName.isPresent()) {
+ return ImmutableList.of();
+ }
+
+ // Check if the field exists in the dataset:
+ final Optional softLabelIndex = SoftLabelParamParserUtil.getSoftLabelColumnIndex(params, schema);
+ if (!softLabelIndex.isPresent()) {
+ return ImmutableList.of(new ParamValidationError(
+ String.format("Soft label field '%s' doesn't exist in the dataset. Please select it in the features.", softLabelFieldName.get())));
+ }
+
+ // Ensure the soft label is numerical:
+ final FieldSchema softLabelSchema = schema
+ .getFieldSchemas()
+ .get(softLabelIndex.get());
+ final AbstractValueSchema valueSchema = softLabelSchema.getValueSchema();
+ if (!(softLabelSchema.getValueSchema() instanceof NumericValueSchema)) {
+ return ImmutableList.of(new ParamValidationError("Soft label must be a numeric field!"));
+ }
+
+ return ImmutableList.of();
+ }
+
/**
* Validate model fit parameters.
*
@@ -221,7 +247,7 @@ private boolean baggingDisabled(final Map params) {
@Override
public LightGBMBinaryClassificationModel loadModel(final Path modelPath,
- final DatasetSchema schema) throws ModelLoadingException {
+ final DatasetSchema schema) throws ModelLoadingException {
final Path modelFilePath = getPath(modelPath);
@@ -340,11 +366,10 @@ private static Optional getNumTargetClasses(final DatasetSchema schema)
/**
* Gets the feature names from schema.
*
- * @implNote The space character is replaced with underscore
- * to comply with LightGBM's model features representation.
- *
* @param schema Schema
* @return Feature names from the schema.
+ * @implNote The space character is replaced with underscore
+ * to comply with LightGBM's model features representation.
* @since 1.0.18
*/
private static String[] getFeatureNamesFrom(final DatasetSchema schema) {
@@ -360,7 +385,7 @@ private static String[] getFeatureNamesFrom(final DatasetSchema schema) {
* {@link DatasetSchema} and an array of feature names.
* This way the first mismatch is logged, improving debug.
*
- * @param schema Schema
+ * @param schema Schema
* @param featureNames Feature names to validate.
* @return {@code true} if the schema predictive field names
* match the provided array, {@code false} otherwise.
diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGTrainData.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGTrainData.java
index 1036e753..f4004a83 100644
--- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGTrainData.java
+++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGTrainData.java
@@ -17,28 +17,21 @@
package com.feedzai.openml.provider.lightgbm;
-import com.microsoft.ml.lightgbm.SWIGTYPE_p_float;
-import com.microsoft.ml.lightgbm.SWIGTYPE_p_int;
-import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void;
-import com.microsoft.ml.lightgbm.SWIGTYPE_p_void;
-import com.microsoft.ml.lightgbm.doubleChunkedArray;
-import com.microsoft.ml.lightgbm.floatChunkedArray;
-import com.microsoft.ml.lightgbm.int32ChunkedArray;
-import com.microsoft.ml.lightgbm.lightgbmlib;
+import com.microsoft.ml.lightgbm.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Handles train data resources and provides basic operations to manipulate train data.
+ *
+ * This class is responsible for initializing, managing and releasing
+ * all LightGBM SWIG train resources and resource handlers
+ * in a memory-safe manner.
+ *
+ * Whatever happens, it guarantees that no memory leaks are left behind.
*
- * This class is responsible for initializing, managing and releasing all
- * Handles train data resources and provides basic operations to manipulate train data.
- * LightGBM SWIG train resources and resource handlers in a memory-safe manner.
- *
- * Whatever happens, it guarantees that no memory leaks are left behind.
- *
- * @author Alberto Ferreira (alberto.ferreira@feedzai.com)
- * @since 1.0.10
+ * @author Alberto Ferreira (alberto.ferreira@feedzai.com)
+ * @since 1.0.10
*/
public class SWIGTrainData implements AutoCloseable {
@@ -56,7 +49,7 @@ public class SWIGTrainData implements AutoCloseable {
* SWIG Features chunked data array.
* This objects stores elements in float64 format
* as a list of chunks that it manages automatically.
- *
+ *
* In the current implementation, features are stored in row-major order, i.e.,
* each instance is stored contiguously.
*/
@@ -80,7 +73,7 @@ public class SWIGTrainData implements AutoCloseable {
/**
* Number of features per instance.
*/
- public final int numFeatures;
+ public final int numActualFeatures;
/**
* Number of instances to store in each chunk.
@@ -104,42 +97,42 @@ public class SWIGTrainData implements AutoCloseable {
/**
* Constructor.
- *
+ *
* Allocates all the initial handles necessary to bootstrap (but not use) the
* in-memory LightGBM dataset + booster structures.
- *
+ *
* After that the BoosterHandle and the DatasetHandle will still need to be initialized at the proper times:
- * @see SWIGTrainData#initSwigDatasetHandle()
*
- * @param numFeatures The number of features.
+ * @param numActualFeatures The number of features (excludes soft label if used).
* @param numInstancesChunk The number of instances per chunk of data.
+ * @see SWIGTrainData#initSwigDatasetHandle()
*/
- public SWIGTrainData(final int numFeatures, final long numInstancesChunk) {
- this(numFeatures, numInstancesChunk, false);
+ public SWIGTrainData(final int numActualFeatures, final long numInstancesChunk) {
+ this(numActualFeatures, numInstancesChunk, false);
}
/**
* Constructor.
- *
+ *
* Allocates al the initial ahndles necessary to bootstrap (but not use) the
* in-memory LightGBM dataset, plus booster structures.
- *
+ *
* If fairnessConstrained=true, this will also include data on which sensitive
* group each instance belongs to.
*
- * @param numFeatures The number of features.
- * @param numInstancesChunk The number of instances per chunk of data.
- * @param fairnessConstrained Whether this data will be used for a model with fairness (group) constraints.
+ * @param numActualFeatures The number of features (excludes soft label if used).
+ * @param numInstancesChunk The number of instances per chunk of data.
+ * @param fairnessConstrained Whether this data will be used for a model with fairness (group) constraints.
*/
- public SWIGTrainData(final int numFeatures, final long numInstancesChunk, final boolean fairnessConstrained) {
- this.numFeatures = numFeatures;
+ public SWIGTrainData(final int numActualFeatures, final long numInstancesChunk, final boolean fairnessConstrained) {
+ this.numActualFeatures = numActualFeatures;
this.numInstancesChunk = numInstancesChunk;
this.swigOutDatasetHandlePtr = lightgbmlib.voidpp_handle();
this.fairnessConstrained = fairnessConstrained;
logger.debug("Intermediate SWIG train buffers will be allocated in chunks of {} instances.", numInstancesChunk);
// 1-D Array in row-major-order that stores only the features (excludes label) in double format by chunks:
- this.swigFeaturesChunkedArray = new doubleChunkedArray(numFeatures * numInstancesChunk);
+ this.swigFeaturesChunkedArray = new doubleChunkedArray(numActualFeatures * numInstancesChunk);
// 1-D Array with the labels (float32):
this.swigLabelsChunkedArray = new floatChunkedArray(numInstancesChunk);
@@ -152,6 +145,7 @@ public SWIGTrainData(final int numFeatures, final long numInstancesChunk, final
/**
* Adds a value to the features' ChunkedArray.
+ *
* @param value value to insert.
*/
public void addFeatureValue(double value) {
@@ -167,6 +161,7 @@ public void addLabelValue(float value) {
/**
* Adds a value to the constraint group ChunkedArray.
+ *
* @param value the value to add.
*/
public void addConstraintGroupValue(int value) {
@@ -239,7 +234,7 @@ void destroySwigConstraintGroupDataArray() {
/**
* Release the memory of the chunked features array.
* This can be called after instantiating the dataset.
- *
+ *
* Although this simply calls `release()`.
* After this that object becomes unusable.
* To cleanup and reuse call `clear()` instead.
diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SchemaFieldsUtil.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SchemaFieldsUtil.java
new file mode 100644
index 00000000..4881d9e4
--- /dev/null
+++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SchemaFieldsUtil.java
@@ -0,0 +1,72 @@
+package com.feedzai.openml.provider.lightgbm.parameters;
+
+import com.feedzai.openml.data.schema.DatasetSchema;
+import com.feedzai.openml.data.schema.FieldSchema;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Utilities to compute fields/column indices.
+ *
+ * @author alberto.ferreira
+ */
+public class SchemaFieldsUtil {
+
+ /**
+ * Logger for this class.
+ */
+ private static final Logger logger = LoggerFactory.getLogger(SchemaFieldsUtil.class);
+
+ /**
+ * Placeholder to use when an integer argument is not provided.
+ * E.g., when running standard unconstrained LightGBM the constrained_group_column parameter will take this value.
+ */
+ public static final int NO_SPECIFIC = -1;
+
+ public static Optional getColumnIndex(final String fieldName,
+ final DatasetSchema schema) {
+
+ final List featureFields = schema.getPredictiveFields();
+ Optional field = featureFields
+ .stream()
+ .filter(field_ -> field_.getFieldName().equalsIgnoreCase(fieldName))
+ .findFirst();
+
+ // Check if column exists
+ if (!field.isPresent()) {
+ logger.error(String.format(
+ "Column '%s' was not found in the dataset.",
+ fieldName));
+ return Optional.empty();
+ }
+
+ return Optional.of(field.get().getFieldIndex());
+ }
+
+ /**
+ * Gets the index of the given field disregarding the label column (LightGBM-specific format).
+ *
+ * @param fieldName Name of the field in the dataset.
+ * @param schema Schema of the dataset.
+ * @return the index of the given field disregarding the label column if the field exists,
+ * else returns an empty Optional.
+ */
+ public static Optional getFieldIndexWithoutLabel(final String fieldName,
+ final DatasetSchema schema) {
+
+ final Optional fieldIndex = getColumnIndex(fieldName, schema);
+ if (!fieldIndex.isPresent()) {
+ return Optional.empty();
+ }
+
+ // Compute column index in LightGBM-specific format (disregarding target column)
+ final int targetIndex = schema.getTargetIndex()
+ .orElseThrow(RuntimeException::new); // Our model is supervised. It needs the target.
+
+ final int offsetIfFieldIsAfterLabel = fieldIndex.get() > targetIndex ? -1 : 0;
+ return Optional.of(fieldIndex.get() + offsetIfFieldIsAfterLabel);
+ }
+}
diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SoftLabelParamParserUtil.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SoftLabelParamParserUtil.java
new file mode 100644
index 00000000..cfe8132c
--- /dev/null
+++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SoftLabelParamParserUtil.java
@@ -0,0 +1,63 @@
+package com.feedzai.openml.provider.lightgbm.parameters;
+
+import com.feedzai.openml.data.schema.DatasetSchema;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Map;
+import java.util.Optional;
+
+import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SOFT_LABEL_PARAMETER_NAME;
+
+/**
+ * Utility to parse the soft label train parameters of the LightGBM model.
+ *
+ * @author Alberto Ferreira (alberto.ferreira@feedzai.com)
+ */
+public class SoftLabelParamParserUtil {
+
+ /**
+ * Logger for this class.
+ */
+ private static final Logger logger = LoggerFactory.getLogger(SoftLabelParamParserUtil.class);
+
+ /**
+ * This class is not meant to be instantiated.
+ */
+ private SoftLabelParamParserUtil() {
+ }
+
+ public static Optional getSoftLabelFieldName(final Map params) {
+ final String softLabelFieldName = params.get(SOFT_LABEL_PARAMETER_NAME);
+
+ return softLabelFieldName == null || softLabelFieldName.equals("") ?
+ Optional.empty() : Optional.of(softLabelFieldName.trim());
+ }
+
+ public static boolean useSoftLabel(final Map mapParams) {
+ return getSoftLabelFieldName(mapParams).isPresent();
+ }
+
+
+ /**
+ * Gets the (canonical) index of the soft label column.
+ * NOTE: the soft label column must be part of the features in the Dataset, but it may be ignored for training
+ *
+ * @param params LightGBM train parameters.
+ * @param schema Schema of the dataset.
+ * @return the index of the soft label column if one was provided, else returns an empty Optional.
+ */
+ public static Optional getSoftLabelColumnIndex(final Map params,
+ final DatasetSchema schema) {
+ final Optional softLabelFieldName = getSoftLabelFieldName(params);
+ return softLabelFieldName.isPresent() ?
+ SchemaFieldsUtil.getColumnIndex(softLabelFieldName.get(), schema) : Optional.empty();
+ }
+
+ public static Optional getSoftLabelFieldIndexWithoutLabel(final Map params,
+ final DatasetSchema schema) {
+ final Optional softLabelFieldName = getSoftLabelFieldName(params);
+ return softLabelFieldName.isPresent() ?
+ SchemaFieldsUtil.getFieldIndexWithoutLabel(softLabelFieldName.get(), schema) : Optional.empty();
+ }
+}
diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/schema/TrainSchemaUtil.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/schema/TrainSchemaUtil.java
new file mode 100644
index 00000000..191465d9
--- /dev/null
+++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/schema/TrainSchemaUtil.java
@@ -0,0 +1,91 @@
+package com.feedzai.openml.provider.lightgbm.schema;
+
+import com.feedzai.openml.data.schema.DatasetSchema;
+import com.feedzai.openml.data.schema.FieldSchema;
+import com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static java.util.stream.Collectors.toList;
+
+/**
+ * Utilities to compute the effective set of features used for training.
+ *
+ * @author alberto.ferreira
+ */
+public class TrainSchemaUtil {
+
+ private TrainSchemaUtil() {
+
+ }
+
+ /**
+ * Gets the actual number of features to use for train.
+ * If the soft label is used, it must be passed in the predictiveFields (features screen), but then excluded
+ * from the features, or there would be label leakage, as it is the label during fit.
+ *
+ * @param schema train dataset Schema
+ * @param params train parameters
+ * @return the actual number of features (excluding soft label in case it is used to train - in which case it's removed from the features)
+ */
+ static public int getNumActualFeatures(final DatasetSchema schema, final Map params) {
+ final int rawNumPredictiveFields = schema.getPredictiveFields().size();
+ return rawNumPredictiveFields;
+ /*
+
+ final Optional softLabelFieldName = SoftLabelParamParserUtil.getSoftLabelFieldName(params);
+
+ if (!softLabelFieldName.isPresent()) {
+ return rawNumPredictiveFields;
+ }
+
+ final boolean isSoftLabelSelectedAsFeature = schema
+ .getPredictiveFields()
+ .stream()
+ .anyMatch(
+ field -> field.getFieldName()
+ .equals(softLabelFieldName.get()));
+
+ return rawNumPredictiveFields - (isSoftLabelSelectedAsFeature ? 1 : 0);
+ */
+ }
+
+ static public String[] getActualFeatureNames(final DatasetSchema schema,
+ final Optional softLabelFieldName) {
+
+ return getFieldNamesArray(schema.getPredictiveFields());
+ /*
+ if (!softLabelFieldName.isPresent()) {
+ return getFieldNamesArray(schema.getPredictiveFields());
+ }
+
+ return getFieldNamesArray(
+ schema.getPredictiveFields()
+ .stream()
+ .filter(field -> !field.getFieldName().equals(softLabelFieldName.orElse("")))
+ .collect(toList())
+ );
+
+ */
+ }
+
+
+ /**
+ * @param fields List of FieldSchema fields.
+ * @return Names of the fields in the input list.
+ */
+ private static String[] getFieldNamesArray(final List fields) {
+ return fields.stream().map(FieldSchema::getFieldName).toArray(String[]::new);
+ }
+
+ /**
+ * @param fields List of FieldSchema fields.
+ * @return Names of the fields in the input list.
+ */
+ private static List getFieldNames(final List fields) {
+ return Arrays.asList(getFieldNamesArray(fields));
+ }
+}
diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerSoftTest.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerSoftTest.java
new file mode 100644
index 00000000..19fddbf7
--- /dev/null
+++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerSoftTest.java
@@ -0,0 +1,119 @@
+package com.feedzai.openml.provider.lightgbm;
+
+import com.feedzai.openml.provider.exception.ModelLoadingException;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static com.feedzai.openml.provider.lightgbm.LightGBMBinaryClassificationModelTrainerTest.average;
+import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME;
+import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SOFT_LABEL_PARAMETER_NAME;
+import static com.feedzai.openml.provider.lightgbm.resources.schemas.SoftSchemas.SOFT_SCHEMA;
+import static java.nio.file.Files.createTempDirectory;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class LightGBMBinaryClassificationModelTrainerSoftTest {
+
+ /**
+ * Parameters for model train.
+ */
+ private static final Map MODEL_PARAMS = TestParameters.getDefaultLightGBMParameters();
+
+ /**
+ * Maximum number of instances to train (to speed up tests).
+ */
+ private static final int MAX_NUMBER_OF_INSTANCES_TO_TRAIN = 5000;
+
+ /**
+ * Maximum number of instances to score (to speed up tests).
+ */
+ private static final int MAX_NUMBER_OF_INSTANCES_TO_SCORE = 300;
+
+ /**
+ * Dataset resource name to use for both fit and validation stages during tests.
+ */
+ static final String DATASET_RESOURCE_NAME = "test_data/soft.csv";
+
+ /**
+ * The number of iterations used to test model fit.
+ * The smaller, the faster the tests go, but make sure you're testing anything still.
+ */
+ private static final String NUM_ITERATIONS_FOR_FAST_TESTS = "2";
+
+ /**
+ * For unit tests, as train data is smaller, using the default chunk sizes to train
+ * would mean that performing fits with multiple chunks would not be tested.
+ * Hence, all model score tests are done with smaller chunk sizes to ensure
+ * fitting with chunked data works.
+ */
+ public static final int SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE = 221;
+
+ /**
+ * Load the LightGBM utils or nothing will work.
+ * Also changes parameters.
+ */
+ @BeforeClass
+ public static void setupFixture() {
+ LightGBMUtils.loadLibs();
+
+ // Override number of iterations in fit tests for faster tests:
+ MODEL_PARAMS.replace(NUM_ITERATIONS_PARAMETER_NAME, NUM_ITERATIONS_FOR_FAST_TESTS);
+ }
+
+ @Test
+ public void fitWithSoftLabels() throws ModelLoadingException, URISyntaxException, IOException {
+
+ final Map trainParams = modelParamsWith(MODEL_PARAMS, SOFT_LABEL_PARAMETER_NAME, "soft");
+
+ assertThat(new LightGBMModelCreator().validateForFit(createTempDirectory("lixo"), SOFT_SCHEMA, trainParams)).isEmpty();
+
+ final ArrayList> scoresPerClass = LightGBMBinaryClassificationModelTrainerTest.fitModelAndGetFirstScoresPerClass(
+ "test_data/soft.csv",
+ SOFT_SCHEMA,
+ 9999,
+ 9999,
+ 100,
+ trainParams
+ );
+
+ assertThat(average(scoresPerClass.get(1)) - average(scoresPerClass.get(0)))
+ .isGreaterThan(0.1);
+ }
+
+ @Test
+ public void fitWithSoftLabelsUninformativeHasNoDistinction() throws ModelLoadingException, URISyntaxException, IOException {
+
+ final Map trainParams = modelParamsWith(MODEL_PARAMS, SOFT_LABEL_PARAMETER_NAME, "soft_uninformative");
+
+ assertThat(new LightGBMModelCreator().validateForFit(createTempDirectory("lixo"), SOFT_SCHEMA, trainParams)).isEmpty();
+
+ final ArrayList> scoresPerClass = LightGBMBinaryClassificationModelTrainerTest.fitModelAndGetFirstScoresPerClass(
+ "test_data/soft.csv",
+ SOFT_SCHEMA,
+ 9999,
+ 9999,
+ 100,
+ trainParams
+ );
+
+ assertThat(average(scoresPerClass.get(1)) - average(scoresPerClass.get(0)))
+ .isLessThan(0.01);
+ }
+
+ public static Map modelParamsWith(final Map params, final String key, final String value) {
+ final Map trainParams = new HashMap<>();
+ params
+ .entrySet()
+ .stream()
+ .filter(entry -> !entry.getKey().equals(key))
+ .forEach(entry -> trainParams.put(entry.getKey(), entry.getValue()));
+ trainParams.put(key, value);
+ return trainParams;
+ }
+}
diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java
index 2b075b99..de8afbab 100644
--- a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java
+++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java
@@ -19,9 +19,16 @@
import com.feedzai.openml.data.Dataset;
import com.feedzai.openml.data.Instance;
+import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
+import com.feedzai.openml.data.schema.FieldSchema;
+import com.feedzai.openml.data.schema.NumericValueSchema;
import com.feedzai.openml.mocks.MockDataset;
import com.feedzai.openml.provider.exception.ModelLoadingException;
+import com.feedzai.openml.provider.lightgbm.resources.schemas.SoftSchemas;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -29,15 +36,14 @@
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
+import java.util.*;
+import java.util.stream.Collectors;
import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME;
+import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SOFT_LABEL_PARAMETER_NAME;
+import static com.feedzai.openml.provider.lightgbm.resources.schemas.SoftSchemas.SOFT_SCHEMA;
+import static com.google.common.escape.Escapers.builder;
+import static java.nio.file.Files.createTempDirectory;
import static java.nio.file.Files.createTempFile;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -239,12 +245,12 @@ public void fitWithNoInstances() {
final Dataset emptyDataset = new MockDataset(TestSchemas.NUMERICALS_SCHEMA_WITH_LABEL_AT_END, noInstances);
assertThatThrownBy(() ->
- new LightGBMModelCreator().fit(
- emptyDataset,
- new Random(),
- MODEL_PARAMS
- )
+ new LightGBMModelCreator().fit(
+ emptyDataset,
+ new Random(),
+ MODEL_PARAMS
)
+ )
.isInstanceOf(RuntimeException.class);
}
@@ -301,6 +307,8 @@ public void fitWithNonASCIIFeatureNameIsPossible() {
.isBetween(0.0, 1.0);
}
+
+
/**
* Test Feature Contributions with target at end.
*
@@ -461,8 +469,8 @@ static ArrayList> fitModelAndGetFirstScoresPerClass(
* @return Array with arrays of class scores.
*/
static ArrayList> getClassScores(final Dataset dataset,
- final LightGBMBinaryClassificationModel model,
- final int maxInstances) {
+ final LightGBMBinaryClassificationModel model,
+ final int maxInstances) {
final int targetIndex = dataset.getSchema().getTargetIndex().get(); // We need a label for the tests.
diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/resources/schemas/SoftSchemas.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/resources/schemas/SoftSchemas.java
new file mode 100644
index 00000000..9d6a44ff
--- /dev/null
+++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/resources/schemas/SoftSchemas.java
@@ -0,0 +1,36 @@
+package com.feedzai.openml.provider.lightgbm.resources.schemas;
+
+import com.feedzai.openml.data.schema.CategoricalValueSchema;
+import com.feedzai.openml.data.schema.DatasetSchema;
+import com.feedzai.openml.data.schema.FieldSchema;
+import com.feedzai.openml.data.schema.NumericValueSchema;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.List;
+
+/**
+ * Schemas with soft labels.
+ *
+ * @author alberto.ferreira
+ */
+public class SoftSchemas {
+
+ public static final List SOFT_SCHEMA_FIELDS = ImmutableList.builder()
+ .add(new FieldSchema("f_float", 0, new NumericValueSchema(false)))
+ .add(new FieldSchema("f_cat", 1, new CategoricalValueSchema(false, ImmutableSet.of("a", "b"))))
+ .add(new FieldSchema("f_int", 2, new NumericValueSchema(false)))
+ .add(new FieldSchema("soft", 3, new NumericValueSchema(false)))
+ .add(new FieldSchema("soft_uninformative", 4, new NumericValueSchema(false)))
+ .add(new FieldSchema("hard", 5, new CategoricalValueSchema(true, ImmutableSet.of("0", "1"))))
+ .add(new FieldSchema("tempo_ms", 6, new NumericValueSchema(false)))
+ .build();
+
+ public static final DatasetSchema SOFT_SCHEMA = new DatasetSchema(
+ 5,
+ ImmutableList.builder()
+ .addAll(SOFT_SCHEMA_FIELDS)
+ .build()
+ );
+
+}
diff --git a/openml-lightgbm/lightgbm-provider/src/test/resources/test_data/soft.csv b/openml-lightgbm/lightgbm-provider/src/test/resources/test_data/soft.csv
new file mode 100644
index 00000000..910a7fc1
--- /dev/null
+++ b/openml-lightgbm/lightgbm-provider/src/test/resources/test_data/soft.csv
@@ -0,0 +1,113 @@
+f_float,f_cat,f_int,soft,soft_uninformative,hard,tempo_ms
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8
+0.1,a,10,0.1,0.1,0,0
+0.2,a,10,0.2,0.9,0,1
+0.3,a,10,0.3,0.1,0,2
+0.2,a,10,0.2,0.9,0,3
+0.9,b,90,0.9,0.1,1,4
+0.8,b,90,0.8,0.9,1,5
+0.6,b,90,0.6,0.1,1,6
+0.9,b,90,0.9,0.9,1,8