diff --git a/openml-lightgbm/lightgbm-builder/make-lightgbm b/openml-lightgbm/lightgbm-builder/make-lightgbm index 6928602c..319698ae 160000 --- a/openml-lightgbm/lightgbm-builder/make-lightgbm +++ b/openml-lightgbm/lightgbm-builder/make-lightgbm @@ -1 +1 @@ -Subproject commit 6928602c37bae318fe7dfd3226893ced596a112f +Subproject commit 319698ae745652cb6d33a22651aa889be48d311e diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java index 69eb7b70..222b7f59 100644 --- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java @@ -22,27 +22,21 @@ import com.feedzai.openml.data.schema.CategoricalValueSchema; import com.feedzai.openml.data.schema.DatasetSchema; import com.feedzai.openml.data.schema.FieldSchema; +import com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil; +import com.feedzai.openml.provider.lightgbm.schema.TrainSchemaUtil; import com.google.common.collect.ImmutableSet; -import com.microsoft.ml.lightgbm.SWIGTYPE_p_float; -import com.microsoft.ml.lightgbm.SWIGTYPE_p_int; -import com.microsoft.ml.lightgbm.SWIGTYPE_p_void; -import com.microsoft.ml.lightgbm.lightgbmlib; -import com.microsoft.ml.lightgbm.lightgbmlibConstants; -import java.util.HashMap; -import java.util.Optional; -import java.util.Set; +import com.microsoft.ml.lightgbm.*; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.file.Path; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; import static 
com.feedzai.openml.provider.lightgbm.FairGBMDescriptorUtil.CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME; import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME; +import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SOFT_LABEL_PARAMETER_NAME; +import static com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil.getSoftLabelFieldName; import static java.lang.Integer.parseInt; import static java.util.stream.Collectors.toList; @@ -64,18 +58,20 @@ final class LightGBMBinaryClassificationModelTrainer { * Train data is copied from the input stream into an array of chunks. * Each chunk will have this many instances. Must be set before `fit()`. *

+ * * @implNote Performance overhead notes: * - Too small? Performance overhead - excessive in-memory data fragmentation. * - Too large? RAM overhead - in the worst case the last chunk has only 1 instance. - * Each instance might have upwards of 400 features. Each costs 8 bytes. - * E.g.: 100k instances of 400 features => 320MB / chunk + * Each instance might have upwards of 400 features. Each costs 8 bytes. + * E.g.: 100k instances of 400 features => 320MB / chunk */ static final long DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE = 200000; /** * This class is not meant to be instantiated. */ - private LightGBMBinaryClassificationModelTrainer() {} + private LightGBMBinaryClassificationModelTrainer() { + } /** * See LightGBMBinaryClassificationModelTrainer#fit overload below. @@ -89,10 +85,12 @@ static void fit(final Dataset dataset, final Map params, final Path outputModelFilePath) { - fit(dataset, - params, - outputModelFilePath, - DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE); + fit( + dataset, + params, + outputModelFilePath, + DEFAULT_TRAIN_DATA_CHUNK_INSTANCES_SIZE + ); } /** @@ -129,7 +127,9 @@ static void fit(final Dataset dataset, final long instancesPerChunk) { final DatasetSchema schema = dataset.getSchema(); - final int numFeatures = schema.getPredictiveFields().size(); + final Optional softLabelColumnIndex = SoftLabelParamParserUtil.getSoftLabelColumnIndex(params, schema); + final int numActualFeatures = TrainSchemaUtil.getNumActualFeatures(schema, params); + // Parse train parameters to LightGBM format final String trainParams = getLightGBMTrainParamsString(params, schema); @@ -138,7 +138,7 @@ static void fit(final Dataset dataset, logger.debug("LightGBM model trainParams: {}", trainParams); final SWIGTrainData swigTrainData = new SWIGTrainData( - numFeatures, + numActualFeatures, instancesPerChunk, FairGBMParamParserUtil.isFairnessConstrained(params)); final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster(); @@ -146,7 +146,21 @@ static void fit(final 
Dataset dataset, /// Create LightGBM dataset final int constraintGroupColIndex = FairGBMParamParserUtil.getConstraintGroupColumnIndex(params, schema).orElse( FairGBMParamParserUtil.NO_SPECIFIC); - createTrainDataset(dataset, numFeatures, trainParams, constraintGroupColIndex, swigTrainData); + + + if (softLabelColumnIndex.isPresent()) { + logger.debug("Replacing hard label by soft label for training."); + } + + createTrainDataset( + dataset, + numActualFeatures, + trainParams, + constraintGroupColIndex, + getSoftLabelFieldName(params), + softLabelColumnIndex, + swigTrainData + ); /// Create Booster from dataset createBoosterStructure(swigTrainBooster, swigTrainData, trainParams); @@ -193,16 +207,18 @@ private static String[] getFieldNames(final List fields) { * - Initializing the label data in the dataset + releasing the label array; * - Setting the feature names in the dataset. * - * @param dataset Dataset - * @param numFeatures Number of features - * @param trainParams LightGBM-formatted params string ("key1=value1 key2=value2 ...") - * @param constraintGroupColIndex The index of the constraint group column. - * @param swigTrainData SWIGTrainData object + * @param dataset Dataset + * @param numActualFeatures Number of features excluding soft label if used + * @param trainParams LightGBM-formatted params string ("key1=value1 key2=value2 ...") + * @param constraintGroupColIndex The index of the constraint group column. 
+ * @param swigTrainData SWIGTrainData object */ private static void createTrainDataset(final Dataset dataset, - final int numFeatures, + final int numActualFeatures, final String trainParams, final int constraintGroupColIndex, + final Optional softLabelFieldName, + final Optional softLabelColumnIndex, final SWIGTrainData swigTrainData) { logger.info("Creating LightGBM dataset"); @@ -211,12 +227,13 @@ private static void createTrainDataset(final Dataset dataset, copyTrainDataToSWIGArrays( dataset, swigTrainData, - constraintGroupColIndex + constraintGroupColIndex, + softLabelColumnIndex ); initializeLightGBMTrainDatasetFeatures( swigTrainData, - numFeatures, + numActualFeatures, trainParams ); @@ -230,7 +247,12 @@ private static void createTrainDataset(final Dataset dataset, ); } - setLightGBMDatasetFeatureNames(swigTrainData.swigDatasetHandle, dataset.getSchema()); + setLightGBMDatasetFeatureNames( + swigTrainData.swigDatasetHandle, + dataset.getSchema(), + softLabelFieldName, + numActualFeatures + ); logger.info("Created LightGBM dataset."); } @@ -239,17 +261,17 @@ private static void createTrainDataset(final Dataset dataset, * Initializes the LightGBM dataset structure and copies the feature data. * * @param swigTrainData SWIGTrainData object. - * @param numFeatures Number of features used to predict. + * @param numActualFeatures Number of features used to predict. * @param trainParams LightGBM string with the train params ("key1=value1 key2=value2 ..."). */ private static void initializeLightGBMTrainDatasetFeatures(final SWIGTrainData swigTrainData, - final int numFeatures, + final int numActualFeatures, final String trainParams) { logger.debug("Initializing LightGBM in-memory structure and setting feature data."); /// First generate the array that has the chunk sizes for `LGBM_DatasetCreateFromMats`. 
- final SWIGTYPE_p_int swigChunkSizesArray = genSWIGFeatureChunkSizesArray(swigTrainData, numFeatures); + final SWIGTYPE_p_int swigChunkSizesArray = genSWIGFeatureChunkSizesArray(swigTrainData, numActualFeatures); /// Now create the LightGBM Dataset itself from the chunks: logger.debug("Creating LGBM_Dataset from chunked data..."); @@ -258,7 +280,7 @@ private static void initializeLightGBMTrainDatasetFeatures(final SWIGTrainData s swigTrainData.swigFeaturesChunkedArray.data_as_void(), // input data (void**) lightgbmlibConstants.C_API_DTYPE_FLOAT64, swigChunkSizesArray, - numFeatures, + numActualFeatures, 1, // rowMajor. trainParams, // parameters. null, // No alignment with other datasets. @@ -278,11 +300,11 @@ private static void initializeLightGBMTrainDatasetFeatures(final SWIGTrainData s * Generates a SWIG array of ints with the size of each train chunk (partition). * * @param swigTrainData SWIGTrainData object. - * @param numFeatures Number of features used to predict. + * @param numActualFeatures Number of features used to predict (excludes soft label field if used). * @return SWIG (int*) array of the train chunks' sizes. 
*/ private static SWIGTYPE_p_int genSWIGFeatureChunkSizesArray(final SWIGTrainData swigTrainData, - final int numFeatures) { + final int numActualFeatures) { logger.debug("Retrieving chunked data block sizes..."); @@ -299,11 +321,11 @@ private static SWIGTYPE_p_int genSWIGFeatureChunkSizesArray(final SWIGTrainData lightgbmlib.intArray_setitem( swigChunkSizesArray, numChunks - 1, - (int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count() / numFeatures + (int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count() / numActualFeatures ); logger.debug("FTL: chunk-size report: chunk #{} is partial-chunk of size {}", numChunks - 1, - (int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count() / numFeatures); + (int) swigTrainData.swigFeaturesChunkedArray.get_last_chunk_add_count() / numActualFeatures); return swigChunkSizesArray; } @@ -370,14 +392,20 @@ private static void setLightGBMDatasetConstraintGroupData(final SWIGTrainData sw * @param swigDatasetHandle SWIG dataset handle * @param schema Dataset schema */ - private static void setLightGBMDatasetFeatureNames(final SWIGTYPE_p_void swigDatasetHandle, final DatasetSchema schema) { + private static void setLightGBMDatasetFeatureNames(final SWIGTYPE_p_void swigDatasetHandle, + final DatasetSchema schema, + final Optional softLabelFieldName, + final int numActualFeatures) { - final int numFeatures = schema.getPredictiveFields().size(); - final String[] featureNames = getFieldNames(schema.getPredictiveFields()); - logger.debug("featureNames {}", Arrays.toString(featureNames)); + final String[] actualFeatureNames = TrainSchemaUtil.getActualFeatureNames(schema, softLabelFieldName); + logger.debug("featureNames {} (numFeatures = {})", Arrays.toString(actualFeatureNames), numActualFeatures); - final int returnCodeLGBM = lightgbmlib.LGBM_DatasetSetFeatureNames(swigDatasetHandle, featureNames, numFeatures); + final int returnCodeLGBM = lightgbmlib.LGBM_DatasetSetFeatureNames( + 
swigDatasetHandle, + actualFeatureNames, + numActualFeatures + ); if (returnCodeLGBM == -1) { logger.error("Could not set feature names."); throw new LightGBMException(); @@ -388,8 +416,8 @@ private static void setLightGBMDatasetFeatureNames(final SWIGTYPE_p_void swigDat * Creates the booster structure with all the parameters and training dataset resources (but doesn't train). * * @param swigTrainBooster An object with the training resources already initialized. - * @param swigTrainData SWIGTrainData object. - * @param trainParams the LightGBM string-formatted string with properties in the form "key1=value1 key2=value2 ...". + * @param swigTrainData SWIGTrainData object. + * @param trainParams the LightGBM string-formatted string with properties in the form "key1=value1 key2=value2 ...". * @see LightGBMBinaryClassificationModelTrainer#trainBooster(SWIGTYPE_p_void, int) . */ static void createBoosterStructure(final SWIGTrainBooster swigTrainBooster, @@ -474,20 +502,11 @@ static void saveModelFileToDisk(final SWIGTYPE_p_void swigBoosterHandle, final P logger.info("Saved model to disk"); } - /** - * Takes the data in dataset and copies it into the features and label C++ arrays through SWIG. - * - * @param dataset Input train dataset (with target label) - * @param swigTrainData SWIGTrainData object. 
- */ - private static void copyTrainDataToSWIGArrays(final Dataset dataset, - final SWIGTrainData swigTrainData) { - copyTrainDataToSWIGArrays(dataset, swigTrainData, FairGBMParamParserUtil.NO_SPECIFIC); - } private static void copyTrainDataToSWIGArrays(final Dataset dataset, final SWIGTrainData swigTrainData, - final int constraintGroupIndex) { + final int constraintGroupIndex, + final Optional softLabelIndex) { final DatasetSchema datasetSchema = dataset.getSchema(); final int numFields = datasetSchema.getFieldSchemas().size(); @@ -495,27 +514,36 @@ private static void copyTrainDataToSWIGArrays(final Dataset dataset, This is ensured in validateForFit, by using the ValidationUtils' validateCategoricalSchema: */ - final int targetIndex = datasetSchema.getTargetIndex().get(); + final int hardLabelIndex = datasetSchema.getTargetIndex().get(); + + /* + * If a soft label is used, it replaces the hard label for training. + * In such case, the soft label cannot be used as a feature. + */ + final int labelFieldIndex = softLabelIndex.orElse(hardLabelIndex); final Iterator iterator = dataset.getInstances(); while (iterator.hasNext()) { final Instance instance = iterator.next(); - swigTrainData.addLabelValue((float) instance.getValue(targetIndex)); + swigTrainData.addLabelValue((float) instance.getValue(labelFieldIndex)); if (constraintGroupIndex != FairGBMParamParserUtil.NO_SPECIFIC) { swigTrainData.addConstraintGroupValue((int) instance.getValue(constraintGroupIndex)); } for (int colIdx = 0; colIdx < numFields; ++colIdx) { - if (colIdx != targetIndex) { - swigTrainData.addFeatureValue(instance.getValue(colIdx)); + if (colIdx != hardLabelIndex) { + // If we're using the soft label, remove all information from its feature (keeps schema and avoids label leakage): + swigTrainData.addFeatureValue( + colIdx != labelFieldIndex ? 
instance.getValue(colIdx) : 0 + ); } } } assert swigTrainData.swigLabelsChunkedArray.get_add_count() == - (swigTrainData.swigFeaturesChunkedArray.get_add_count() / swigTrainData.numFeatures); + (swigTrainData.swigFeaturesChunkedArray.get_add_count() / swigTrainData.numActualFeatures); if (swigTrainData.fairnessConstrained) { assert swigTrainData.swigConstraintGroupChunkedArray.get_add_count() == @@ -523,8 +551,8 @@ private static void copyTrainDataToSWIGArrays(final Dataset dataset, } logger.debug("Copied train data of size {} into {} chunks.", - swigTrainData.swigLabelsChunkedArray.get_add_count(), - swigTrainData.swigLabelsChunkedArray.get_chunks_count() + swigTrainData.swigLabelsChunkedArray.get_add_count(), + swigTrainData.swigLabelsChunkedArray.get_chunks_count() ); if (swigTrainData.swigLabelsChunkedArray.get_add_count() == 0) { @@ -547,11 +575,13 @@ private static String getLightGBMTrainParamsString(final Map map preprocessedMapParams .put("categorical_feature", StringUtils.join(categoricalFeatureIndices, ",")); - // Set default objective parameter - final Optional objective = getLightGBMObjective(mapParams); - if (! objective.isPresent()) { - // Default to objective=binary - preprocessedMapParams.put("objective", "binary"); + // Set default objective parameter - by default "binary": + final Optional customObjective = getLightGBMObjective(mapParams); + if (!customObjective.isPresent()) { + final String objective = SoftLabelParamParserUtil.getSoftLabelFieldName(mapParams).isPresent() ? 
+ "cross_entropy" : "binary"; + + preprocessedMapParams.put("objective", objective); } // Add constraint_group_column parameter @@ -560,7 +590,12 @@ private static String getLightGBMTrainParamsString(final Map map integer -> preprocessedMapParams.put(CONSTRAINT_GROUP_COLUMN_PARAMETER_NAME, Integer.toString(integer))); // Add all **other** parameters - mapParams.forEach(preprocessedMapParams::putIfAbsent); + final Set lgbmUnknownParams = ImmutableSet.of(SOFT_LABEL_PARAMETER_NAME); + mapParams + .entrySet() + .stream() + .filter(entry -> !lgbmUnknownParams.contains(entry.getKey())) + .forEach(entry -> preprocessedMapParams.putIfAbsent(entry.getKey(), entry.getValue())); // Build string containing params in LightGBM format final StringBuilder paramsBuilder = new StringBuilder(); diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMDescriptorUtil.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMDescriptorUtil.java index cf03abe2..1225c1bf 100644 --- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMDescriptorUtil.java +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMDescriptorUtil.java @@ -20,6 +20,7 @@ import com.feedzai.openml.provider.descriptor.ModelParameter; import com.feedzai.openml.provider.descriptor.fieldtype.BooleanFieldType; import com.feedzai.openml.provider.descriptor.fieldtype.ChoiceFieldType; +import com.feedzai.openml.provider.descriptor.fieldtype.FreeTextFieldType; import com.google.common.collect.ImmutableSet; import java.util.Set; @@ -72,6 +73,7 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil { */ public static final String FEATURE_FRACTION_PARAMETER_DESCRIPTION = "Feature fraction by tree"; + public static final String SOFT_LABEL_PARAMETER_NAME = "soft_label"; /** * Defines the set of model parameters accepted by the LightGBM model. 
*/ @@ -99,6 +101,17 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil { MANDATORY, intRange(0, Integer.MAX_VALUE, 100) ), + new ModelParameter( + SOFT_LABEL_PARAMETER_NAME, + "(Soft classification) Soft Label", + "Train the classifier with a soft label instead. \n" + + "These values should be between 0.0 and 1.0. This field is ignored for validations. \n" + + "Simply pass the name of the field.\n" + + "If this field is selected, it is automatically dropped from the selected features.", + NOT_MANDATORY, + new FreeTextFieldType("") + ) + , new ModelParameter( "learning_rate", "Learning rate", @@ -314,7 +327,7 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil { ), new ModelParameter( "min_data_in_bin", - "Minimium bin data", + "Minimum bin data", "Minimum number of samples inside one bin. Limit over-fitting. E.g., not using 1 point per bin.", NOT_MANDATORY, intRange(1, Integer.MAX_VALUE, 3) diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.java index f492f8cf..6d4c9374 100644 --- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.java +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMModelCreator.java @@ -18,13 +18,10 @@ package com.feedzai.openml.provider.lightgbm; import com.feedzai.openml.data.Dataset; -import com.feedzai.openml.data.schema.AbstractValueSchema; -import com.feedzai.openml.data.schema.CategoricalValueSchema; -import com.feedzai.openml.data.schema.DatasetSchema; -import com.feedzai.openml.data.schema.FieldSchema; -import com.feedzai.openml.data.schema.StringValueSchema; +import com.feedzai.openml.data.schema.*; import com.feedzai.openml.provider.descriptor.fieldtype.ParamValidationError; import 
com.feedzai.openml.provider.exception.ModelLoadingException; +import com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil; import com.feedzai.openml.provider.model.MachineLearningModelTrainer; import com.feedzai.openml.util.load.LoadSchemaUtils; import com.google.common.collect.ImmutableList; @@ -39,13 +36,8 @@ import java.util.Optional; import java.util.Random; -import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BAGGING_FRACTION_PARAMETER_NAME; -import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BAGGING_FREQUENCY_PARAMETER_NAME; -import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BOOSTING_TYPE_PARAMETER_NAME; -import static com.feedzai.openml.util.validate.ValidationUtils.baseLoadValidations; -import static com.feedzai.openml.util.validate.ValidationUtils.checkParams; -import static com.feedzai.openml.util.validate.ValidationUtils.validateCategoricalSchema; -import static com.feedzai.openml.util.validate.ValidationUtils.validateModelPathToTrain; +import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.*; +import static com.feedzai.openml.util.validate.ValidationUtils.*; import static java.nio.file.Files.createTempFile; /** @@ -145,7 +137,6 @@ public LightGBMBinaryClassificationModel fit(final Dataset dataset, } - @Override public List validateForFit(final Path pathToPersist, final DatasetSchema schema, @@ -155,7 +146,8 @@ public List validateForFit(final Path pathToPersist, errorsBuilder .addAll(validateModelPathToTrain(pathToPersist)) .addAll(validateSchema(schema)) - .addAll(validateFitParams(params)); + .addAll(validateFitParams(params)) + .addAll(validateSoftLabelParam(schema, params)); return errorsBuilder.build(); } @@ -183,6 +175,40 @@ private List validateSchema(final DatasetSchema schema) { return errorsBuilder.build(); } + /** + * Ensure that if the soft label parameter is passed, it uses a floating point field. + * + * @param schema Dataset schema. 
+ * @param params Model fit parameters. + * @return list of validation errors. + */ + private List validateSoftLabelParam(final DatasetSchema schema, + final Map params) { + // Don't test anything if the parameter is not set: + final Optional softLabelFieldName = SoftLabelParamParserUtil.getSoftLabelFieldName(params); + if (!softLabelFieldName.isPresent()) { + return ImmutableList.of(); + } + + // Check if the field exists in the dataset: + final Optional softLabelIndex = SoftLabelParamParserUtil.getSoftLabelColumnIndex(params, schema); + if (!softLabelIndex.isPresent()) { + return ImmutableList.of(new ParamValidationError( + String.format("Soft label field '%s' doesn't exist in the dataset. Please select it in the features.", softLabelFieldName.get()))); + } + + // Ensure the soft label is numerical: + final FieldSchema softLabelSchema = schema + .getFieldSchemas() + .get(softLabelIndex.get()); + final AbstractValueSchema valueSchema = softLabelSchema.getValueSchema(); + if (!(softLabelSchema.getValueSchema() instanceof NumericValueSchema)) { + return ImmutableList.of(new ParamValidationError("Soft label must be a numeric field!")); + } + + return ImmutableList.of(); + } + /** * Validate model fit parameters. * @@ -221,7 +247,7 @@ private boolean baggingDisabled(final Map params) { @Override public LightGBMBinaryClassificationModel loadModel(final Path modelPath, - final DatasetSchema schema) throws ModelLoadingException { + final DatasetSchema schema) throws ModelLoadingException { final Path modelFilePath = getPath(modelPath); @@ -340,11 +366,10 @@ private static Optional getNumTargetClasses(final DatasetSchema schema) /** * Gets the feature names from schema. * - * @implNote The space character is replaced with underscore - * to comply with LightGBM's model features representation. - * * @param schema Schema * @return Feature names from the schema. 
+ * @implNote The space character is replaced with underscore + * to comply with LightGBM's model features representation. * @since 1.0.18 */ private static String[] getFeatureNamesFrom(final DatasetSchema schema) { @@ -360,7 +385,7 @@ private static String[] getFeatureNamesFrom(final DatasetSchema schema) { * {@link DatasetSchema} and an array of feature names. * This way the first mismatch is logged, improving debug. * - * @param schema Schema + * @param schema Schema * @param featureNames Feature names to validate. * @return {@code true} if the schema predictive field names * match the provided array, {@code false} otherwise. diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGTrainData.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGTrainData.java index 1036e753..f4004a83 100644 --- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGTrainData.java +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/SWIGTrainData.java @@ -17,28 +17,21 @@ package com.feedzai.openml.provider.lightgbm; -import com.microsoft.ml.lightgbm.SWIGTYPE_p_float; -import com.microsoft.ml.lightgbm.SWIGTYPE_p_int; -import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void; -import com.microsoft.ml.lightgbm.SWIGTYPE_p_void; -import com.microsoft.ml.lightgbm.doubleChunkedArray; -import com.microsoft.ml.lightgbm.floatChunkedArray; -import com.microsoft.ml.lightgbm.int32ChunkedArray; -import com.microsoft.ml.lightgbm.lightgbmlib; +import com.microsoft.ml.lightgbm.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Handles train data resources and provides basic operations to manipulate train data. + *

+ * This class is responsible for initializing, managing and releasing all + * LightGBM SWIG train resources and resource handlers in a memory-safe manner, + * and provides basic operations to manipulate train data. + *

+ * Whatever happens, it guarantees that no memory leaks are left behind. * - * This class is responsible for initializing, managing and releasing all - * Handles train data resources and provides basic operations to manipulate train data. - * LightGBM SWIG train resources and resource handlers in a memory-safe manner. - * - * Whatever happens, it guarantees that no memory leaks are left behind. - * - * @author Alberto Ferreira (alberto.ferreira@feedzai.com) - * @since 1.0.10 + * @author Alberto Ferreira (alberto.ferreira@feedzai.com) + * @since 1.0.10 */ public class SWIGTrainData implements AutoCloseable { @@ -56,7 +49,7 @@ public class SWIGTrainData implements AutoCloseable { * SWIG Features chunked data array. * This objects stores elements in float64 format * as a list of chunks that it manages automatically. - * + *

* In the current implementation, features are stored in row-major order, i.e., * each instance is stored contiguously. */ @@ -80,7 +73,7 @@ public class SWIGTrainData implements AutoCloseable { /** * Number of features per instance. */ - public final int numFeatures; + public final int numActualFeatures; /** * Number of instances to store in each chunk. @@ -104,42 +97,42 @@ public class SWIGTrainData implements AutoCloseable { /** * Constructor. - * + *

* Allocates all the initial handles necessary to bootstrap (but not use) the * in-memory LightGBM dataset + booster structures. - * + *

* After that the BoosterHandle and the DatasetHandle will still need to be initialized at the proper times: - * @see SWIGTrainData#initSwigDatasetHandle() * - * @param numFeatures The number of features. + * @param numActualFeatures The number of features (excludes soft label if used). * @param numInstancesChunk The number of instances per chunk of data. + * @see SWIGTrainData#initSwigDatasetHandle() */ - public SWIGTrainData(final int numFeatures, final long numInstancesChunk) { - this(numFeatures, numInstancesChunk, false); + public SWIGTrainData(final int numActualFeatures, final long numInstancesChunk) { + this(numActualFeatures, numInstancesChunk, false); } /** * Constructor. - * + *

* Allocates all the initial handles necessary to bootstrap (but not use) the * in-memory LightGBM dataset, plus booster structures. - * +

* If fairnessConstrained=true, this will also include data on which sensitive * group each instance belongs to. * - * @param numFeatures The number of features. - * @param numInstancesChunk The number of instances per chunk of data. - * @param fairnessConstrained Whether this data will be used for a model with fairness (group) constraints. + * @param numActualFeatures The number of features (excludes soft label if used). + * @param numInstancesChunk The number of instances per chunk of data. + * @param fairnessConstrained Whether this data will be used for a model with fairness (group) constraints. */ - public SWIGTrainData(final int numFeatures, final long numInstancesChunk, final boolean fairnessConstrained) { - this.numFeatures = numFeatures; + public SWIGTrainData(final int numActualFeatures, final long numInstancesChunk, final boolean fairnessConstrained) { + this.numActualFeatures = numActualFeatures; this.numInstancesChunk = numInstancesChunk; this.swigOutDatasetHandlePtr = lightgbmlib.voidpp_handle(); this.fairnessConstrained = fairnessConstrained; logger.debug("Intermediate SWIG train buffers will be allocated in chunks of {} instances.", numInstancesChunk); // 1-D Array in row-major-order that stores only the features (excludes label) in double format by chunks: - this.swigFeaturesChunkedArray = new doubleChunkedArray(numFeatures * numInstancesChunk); + this.swigFeaturesChunkedArray = new doubleChunkedArray(numActualFeatures * numInstancesChunk); // 1-D Array with the labels (float32): this.swigLabelsChunkedArray = new floatChunkedArray(numInstancesChunk); @@ -152,6 +145,7 @@ public SWIGTrainData(final int numFeatures, final long numInstancesChunk, final /** * Adds a value to the features' ChunkedArray. + * * @param value value to insert. */ public void addFeatureValue(double value) { @@ -167,6 +161,7 @@ public void addLabelValue(float value) { /** * Adds a value to the constraint group ChunkedArray. + * * @param value the value to add. 
*/ public void addConstraintGroupValue(int value) { @@ -239,7 +234,7 @@ void destroySwigConstraintGroupDataArray() { /** * Release the memory of the chunked features array. * This can be called after instantiating the dataset. - * + *

* Although this simply calls `release()`. * After this that object becomes unusable. * To cleanup and reuse call `clear()` instead. diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SchemaFieldsUtil.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SchemaFieldsUtil.java new file mode 100644 index 00000000..4881d9e4 --- /dev/null +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SchemaFieldsUtil.java @@ -0,0 +1,72 @@ +package com.feedzai.openml.provider.lightgbm.parameters; + +import com.feedzai.openml.data.schema.DatasetSchema; +import com.feedzai.openml.data.schema.FieldSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Optional; + +/** + * Utilities to compute fields/column indices. + * + * @author alberto.ferreira + */ +public class SchemaFieldsUtil { + + /** + * Logger for this class. + */ + private static final Logger logger = LoggerFactory.getLogger(SchemaFieldsUtil.class); + + /** + * Placeholder to use when an integer argument is not provided. + * E.g., when running standard unconstrained LightGBM the constrained_group_column parameter will take this value. 
+ */ + public static final int NO_SPECIFIC = -1; + + public static Optional getColumnIndex(final String fieldName, + final DatasetSchema schema) { + + final List featureFields = schema.getPredictiveFields(); + Optional field = featureFields + .stream() + .filter(field_ -> field_.getFieldName().equalsIgnoreCase(fieldName)) + .findFirst(); + + // Check if column exists + if (!field.isPresent()) { + logger.error(String.format( + "Column '%s' was not found in the dataset.", + fieldName)); + return Optional.empty(); + } + + return Optional.of(field.get().getFieldIndex()); + } + + /** + * Gets the index of the given field without the label column (LightGBM-specific format). + * + * @param fieldName Name of the field in the dataset. + * @param schema Schema of the dataset. + * @return the index of the given field without the label column if the field + * exists in the schema, else returns an empty Optional. + */ + public static Optional getFieldIndexWithoutLabel(final String fieldName, + final DatasetSchema schema) { + + final Optional fieldIndex = getColumnIndex(fieldName, schema); + if (!fieldIndex.isPresent()) { + return Optional.empty(); + } + + // Compute column index in LightGBM-specific format (disregarding target column) + final int targetIndex = schema.getTargetIndex() + .orElseThrow(RuntimeException::new); // Our model is supervised. It needs the target. + + final int offsetIfFieldIsAfterLabel = fieldIndex.get() > targetIndex ? 
-1 : 0; + return Optional.of(fieldIndex.get() + offsetIfFieldIsAfterLabel); + } +} diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SoftLabelParamParserUtil.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SoftLabelParamParserUtil.java new file mode 100644 index 00000000..cfe8132c --- /dev/null +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/parameters/SoftLabelParamParserUtil.java @@ -0,0 +1,63 @@ +package com.feedzai.openml.provider.lightgbm.parameters; + +import com.feedzai.openml.data.schema.DatasetSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.Optional; + +import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SOFT_LABEL_PARAMETER_NAME; + +/** + * Utility to parse parameters specific to the FairGBM model. + * + * @author Alberto Ferreira (alberto.ferreira@feedzai.com) + */ +public class SoftLabelParamParserUtil { + + /** + * Logger for this class. + */ + private static final Logger logger = LoggerFactory.getLogger(SoftLabelParamParserUtil.class); + + /** + * This class is not meant to be instantiated. + */ + private SoftLabelParamParserUtil() { + } + + public static Optional getSoftLabelFieldName(final Map params) { + final String softLabelFieldName = params.get(SOFT_LABEL_PARAMETER_NAME); + + return softLabelFieldName == null || softLabelFieldName.equals("") ? + Optional.empty() : Optional.of(softLabelFieldName.trim()); + } + + public static boolean useSoftLabel(final Map mapParams) { + return getSoftLabelFieldName(mapParams).isPresent(); + } + + + /** + * Gets the (canonical) index of the constraint group column. + * NOTE: the soft label column must be part of the features in the Dataset, but it may be ignored for training + * + * @param params LightGBM train parameters. + * @param schema Schema of the dataset. 
+ * @return the index of the soft label column if one was provided, else returns an empty Optional. + */ + public static Optional getSoftLabelColumnIndex(final Map params, + final DatasetSchema schema) { + final Optional softLabelFieldName = getSoftLabelFieldName(params); + return softLabelFieldName.isPresent() ? + SchemaFieldsUtil.getColumnIndex(softLabelFieldName.get(), schema) : Optional.empty(); + } + + public static Optional getSoftLabelFieldIndexWithoutLabel(final Map params, + final DatasetSchema schema) { + final Optional softLabelFieldName = getSoftLabelFieldName(params); + return softLabelFieldName.isPresent() ? + SchemaFieldsUtil.getFieldIndexWithoutLabel(softLabelFieldName.get(), schema) : Optional.empty(); + } +} diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/schema/TrainSchemaUtil.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/schema/TrainSchemaUtil.java new file mode 100644 index 00000000..191465d9 --- /dev/null +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/schema/TrainSchemaUtil.java @@ -0,0 +1,91 @@ +package com.feedzai.openml.provider.lightgbm.schema; + +import com.feedzai.openml.data.schema.DatasetSchema; +import com.feedzai.openml.data.schema.FieldSchema; +import com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static java.util.stream.Collectors.toList; + +/** + * Utils to account for features + * + * @author alberto.ferreira + */ +public class TrainSchemaUtil { + + private TrainSchemaUtil() { + + } + + /** + * Gets the actual number of features to use for train. + * If the soft label is used, it must be passed in the predictiveFields (features screen), but then excluded + * from the features, or there would be label leakage, as it is the label during fit. 
+ * + * @param schema train dataset Schema + * @param params train parameters + * @return the actual number of features (excluding soft label in case it is used to train - in which case it's removed from the features) + */ + static public int getNumActualFeatures(final DatasetSchema schema, final Map params) { + final int rawNumPredictiveFields = schema.getPredictiveFields().size(); + return rawNumPredictiveFields; + /* + + final Optional softLabelFieldName = SoftLabelParamParserUtil.getSoftLabelFieldName(params); + + if (!softLabelFieldName.isPresent()) { + return rawNumPredictiveFields; + } + + final boolean isSoftLabelSelectedAsFeature = schema + .getPredictiveFields() + .stream() + .anyMatch( + field -> field.getFieldName() + .equals(softLabelFieldName.get())); + + return rawNumPredictiveFields - (isSoftLabelSelectedAsFeature ? 1 : 0); + */ + } + + static public String[] getActualFeatureNames(final DatasetSchema schema, + final Optional softLabelFieldName) { + + return getFieldNamesArray(schema.getPredictiveFields()); + /* + if (!softLabelFieldName.isPresent()) { + return getFieldNamesArray(schema.getPredictiveFields()); + } + + return getFieldNamesArray( + schema.getPredictiveFields() + .stream() + .filter(field -> !field.getFieldName().equals(softLabelFieldName.orElse(""))) + .collect(toList()) + ); + + */ + } + + + /** + * @param fields List of FieldSchema fields. + * @return Names of the fields in the input list. + */ + private static String[] getFieldNamesArray(final List fields) { + return fields.stream().map(FieldSchema::getFieldName).toArray(String[]::new); + } + + /** + * @param fields List of FieldSchema fields. + * @return Names of the fields in the input list. 
+ */ + private static List getFieldNames(final List fields) { + return Arrays.asList(getFieldNamesArray(fields)); + } +} diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerSoftTest.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerSoftTest.java new file mode 100644 index 00000000..19fddbf7 --- /dev/null +++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerSoftTest.java @@ -0,0 +1,119 @@ +package com.feedzai.openml.provider.lightgbm; + +import com.feedzai.openml.provider.exception.ModelLoadingException; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.feedzai.openml.provider.lightgbm.LightGBMBinaryClassificationModelTrainerTest.average; +import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME; +import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SOFT_LABEL_PARAMETER_NAME; +import static com.feedzai.openml.provider.lightgbm.resources.schemas.SoftSchemas.SOFT_SCHEMA; +import static java.nio.file.Files.createTempDirectory; +import static org.assertj.core.api.Assertions.assertThat; + +public class LightGBMBinaryClassificationModelTrainerSoftTest { + + /** + * Parameters for model train. + */ + private static final Map MODEL_PARAMS = TestParameters.getDefaultLightGBMParameters(); + + /** + * Maximum number of instances to train (to speed up tests). + */ + private static final int MAX_NUMBER_OF_INSTANCES_TO_TRAIN = 5000; + + /** + * Maximum number of instances to score (to speed up tests). 
+ */ + private static final int MAX_NUMBER_OF_INSTANCES_TO_SCORE = 300; + + /** + * Dataset resource name to use for both fit and validation stages during tests. + */ + static final String DATASET_RESOURCE_NAME = "test_data/soft.csv"; + + /** + * The number of iterations used to test model fit. + * The smaller, the faster the tests go, but make sure you're testing anything still. + */ + private static final String NUM_ITERATIONS_FOR_FAST_TESTS = "2"; + + /** + * For unit tests, as train data is smaller, using the default chunk sizes to train + * would mean that performing fits with multiple chunks would not be tested. + * Hence, all model score tests are done with smaller chunk sizes to ensure + * fitting with chunked data works. + */ + public static final int SMALL_TRAIN_DATA_CHUNK_INSTANCES_SIZE = 221; + + /** + * Load the LightGBM utils or nothing will work. + * Also changes parameters. + */ + @BeforeClass + public static void setupFixture() { + LightGBMUtils.loadLibs(); + + // Override number of iterations in fit tests for faster tests: + MODEL_PARAMS.replace(NUM_ITERATIONS_PARAMETER_NAME, NUM_ITERATIONS_FOR_FAST_TESTS); + } + + @Test + public void fitWithSoftLabels() throws ModelLoadingException, URISyntaxException, IOException { + + final Map trainParams = modelParamsWith(MODEL_PARAMS, SOFT_LABEL_PARAMETER_NAME, "soft"); + + assertThat(new LightGBMModelCreator().validateForFit(createTempDirectory("lixo"), SOFT_SCHEMA, trainParams)).isEmpty(); + + final ArrayList> scoresPerClass = LightGBMBinaryClassificationModelTrainerTest.fitModelAndGetFirstScoresPerClass( + "test_data/soft.csv", + SOFT_SCHEMA, + 9999, + 9999, + 100, + trainParams + ); + + assertThat(average(scoresPerClass.get(1)) - average(scoresPerClass.get(0))) + .isGreaterThan(0.1); + } + + @Test + public void fitWithSoftLabelsUninformativeHasNoDistinction() throws ModelLoadingException, URISyntaxException, IOException { + + final Map trainParams = modelParamsWith(MODEL_PARAMS, 
SOFT_LABEL_PARAMETER_NAME, "soft_uninformative"); + + assertThat(new LightGBMModelCreator().validateForFit(createTempDirectory("lixo"), SOFT_SCHEMA, trainParams)).isEmpty(); + + final ArrayList> scoresPerClass = LightGBMBinaryClassificationModelTrainerTest.fitModelAndGetFirstScoresPerClass( + "test_data/soft.csv", + SOFT_SCHEMA, + 9999, + 9999, + 100, + trainParams + ); + + assertThat(average(scoresPerClass.get(1)) - average(scoresPerClass.get(0))) + .isLessThan(0.01); + } + + public static Map modelParamsWith(final Map params, final String key, final String value) { + final Map trainParams = new HashMap<>(); + params + .entrySet() + .stream() + .filter(entry -> !entry.getKey().equals(key)) + .forEach(entry -> trainParams.put(entry.getKey(), entry.getValue())); + trainParams.put(key, value); + return trainParams; + } +} diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java index 2b075b99..de8afbab 100644 --- a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java +++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainerTest.java @@ -19,9 +19,16 @@ import com.feedzai.openml.data.Dataset; import com.feedzai.openml.data.Instance; +import com.feedzai.openml.data.schema.CategoricalValueSchema; import com.feedzai.openml.data.schema.DatasetSchema; +import com.feedzai.openml.data.schema.FieldSchema; +import com.feedzai.openml.data.schema.NumericValueSchema; import com.feedzai.openml.mocks.MockDataset; import com.feedzai.openml.provider.exception.ModelLoadingException; +import com.feedzai.openml.provider.lightgbm.resources.schemas.SoftSchemas; +import com.google.common.collect.ImmutableList; 
+import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.junit.BeforeClass; import org.junit.Test; @@ -29,15 +36,14 @@ import java.net.URISyntaxException; import java.nio.file.Files; import java.nio.file.Path; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Random; +import java.util.*; +import java.util.stream.Collectors; import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.NUM_ITERATIONS_PARAMETER_NAME; +import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.SOFT_LABEL_PARAMETER_NAME; +import static com.feedzai.openml.provider.lightgbm.resources.schemas.SoftSchemas.SOFT_SCHEMA; +import static com.google.common.escape.Escapers.builder; +import static java.nio.file.Files.createTempDirectory; import static java.nio.file.Files.createTempFile; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -239,12 +245,12 @@ public void fitWithNoInstances() { final Dataset emptyDataset = new MockDataset(TestSchemas.NUMERICALS_SCHEMA_WITH_LABEL_AT_END, noInstances); assertThatThrownBy(() -> - new LightGBMModelCreator().fit( - emptyDataset, - new Random(), - MODEL_PARAMS - ) + new LightGBMModelCreator().fit( + emptyDataset, + new Random(), + MODEL_PARAMS ) + ) .isInstanceOf(RuntimeException.class); } @@ -301,6 +307,8 @@ public void fitWithNonASCIIFeatureNameIsPossible() { .isBetween(0.0, 1.0); } + + /** * Test Feature Contributions with target at end. * @@ -461,8 +469,8 @@ static ArrayList> fitModelAndGetFirstScoresPerClass( * @return Array with arrays of class scores. 
*/ static ArrayList> getClassScores(final Dataset dataset, - final LightGBMBinaryClassificationModel model, - final int maxInstances) { + final LightGBMBinaryClassificationModel model, + final int maxInstances) { final int targetIndex = dataset.getSchema().getTargetIndex().get(); // We need a label for the tests. diff --git a/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/resources/schemas/SoftSchemas.java b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/resources/schemas/SoftSchemas.java new file mode 100644 index 00000000..9d6a44ff --- /dev/null +++ b/openml-lightgbm/lightgbm-provider/src/test/java/com/feedzai/openml/provider/lightgbm/resources/schemas/SoftSchemas.java @@ -0,0 +1,36 @@ +package com.feedzai.openml.provider.lightgbm.resources.schemas; + +import com.feedzai.openml.data.schema.CategoricalValueSchema; +import com.feedzai.openml.data.schema.DatasetSchema; +import com.feedzai.openml.data.schema.FieldSchema; +import com.feedzai.openml.data.schema.NumericValueSchema; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; + +/** + * Schemas with soft labels. 
+ * + * @author alberto.ferreira + */ +public class SoftSchemas { + + public static final List SOFT_SCHEMA_FIELDS = ImmutableList.builder() + .add(new FieldSchema("f_float", 0, new NumericValueSchema(false))) + .add(new FieldSchema("f_cat", 1, new CategoricalValueSchema(false, ImmutableSet.of("a", "b")))) + .add(new FieldSchema("f_int", 2, new NumericValueSchema(false))) + .add(new FieldSchema("soft", 3, new NumericValueSchema(false))) + .add(new FieldSchema("soft_uninformative", 4, new NumericValueSchema(false))) + .add(new FieldSchema("hard", 5, new CategoricalValueSchema(true, ImmutableSet.of("0", "1")))) + .add(new FieldSchema("tempo_ms", 6, new NumericValueSchema(false))) + .build(); + + public static final DatasetSchema SOFT_SCHEMA = new DatasetSchema( + 5, + ImmutableList.builder() + .addAll(SOFT_SCHEMA_FIELDS) + .build() + ); + +} diff --git a/openml-lightgbm/lightgbm-provider/src/test/resources/test_data/soft.csv b/openml-lightgbm/lightgbm-provider/src/test/resources/test_data/soft.csv new file mode 100644 index 00000000..910a7fc1 --- /dev/null +++ b/openml-lightgbm/lightgbm-provider/src/test/resources/test_data/soft.csv @@ -0,0 +1,113 @@ +f_float,f_cat,f_int,soft,soft_uninformative,hard,tempo_ms +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 
+0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8 +0.1,a,10,0.1,0.1,0,0 +0.2,a,10,0.2,0.9,0,1 +0.3,a,10,0.3,0.1,0,2 +0.2,a,10,0.2,0.9,0,3 +0.9,b,90,0.9,0.1,1,4 +0.8,b,90,0.8,0.9,1,5 +0.6,b,90,0.6,0.1,1,6 +0.9,b,90,0.9,0.9,1,8