Skip to content

zaickathon: LightGBM: Add soft labels #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.feedzai.openml.provider.descriptor.ModelParameter;
import com.feedzai.openml.provider.descriptor.fieldtype.BooleanFieldType;
import com.feedzai.openml.provider.descriptor.fieldtype.ChoiceFieldType;
import com.feedzai.openml.provider.descriptor.fieldtype.FreeTextFieldType;
import com.google.common.collect.ImmutableSet;

import java.util.Set;
Expand Down Expand Up @@ -72,6 +73,7 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil {
*/
public static final String FEATURE_FRACTION_PARAMETER_DESCRIPTION = "Feature fraction by tree";

public static final String SOFT_LABEL_PARAMETER_NAME = "soft_label";
/**
* Defines the set of model parameters accepted by the LightGBM model.
*/
Expand Down Expand Up @@ -99,6 +101,17 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil {
MANDATORY,
intRange(0, Integer.MAX_VALUE, 100)
),
new ModelParameter(
SOFT_LABEL_PARAMETER_NAME,
"(Soft classification) Soft Label",
"Train the classifier with a soft label instead. \n"
+ "These values should be between 0.0 and 1.0. This field is ignored for validations. \n"
+ "Simply pass the name of the field.\n"
+ "If this field is selected, it is automatically dropped from the selected features.",
NOT_MANDATORY,
new FreeTextFieldType("")
)
,
new ModelParameter(
"learning_rate",
"Learning rate",
Expand Down Expand Up @@ -314,7 +327,7 @@ public class LightGBMDescriptorUtil extends AlgoDescriptorUtil {
),
new ModelParameter(
"min_data_in_bin",
"Minimium bin data",
"Minimum bin data",
"Minimum number of samples inside one bin. Limit over-fitting. E.g., not using 1 point per bin.",
NOT_MANDATORY,
intRange(1, Integer.MAX_VALUE, 3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@
package com.feedzai.openml.provider.lightgbm;

import com.feedzai.openml.data.Dataset;
import com.feedzai.openml.data.schema.AbstractValueSchema;
import com.feedzai.openml.data.schema.CategoricalValueSchema;
import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import com.feedzai.openml.data.schema.StringValueSchema;
import com.feedzai.openml.data.schema.*;
import com.feedzai.openml.provider.descriptor.fieldtype.ParamValidationError;
import com.feedzai.openml.provider.exception.ModelLoadingException;
import com.feedzai.openml.provider.lightgbm.parameters.SoftLabelParamParserUtil;
import com.feedzai.openml.provider.model.MachineLearningModelTrainer;
import com.feedzai.openml.util.load.LoadSchemaUtils;
import com.google.common.collect.ImmutableList;
Expand All @@ -39,13 +36,8 @@
import java.util.Optional;
import java.util.Random;

import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BAGGING_FRACTION_PARAMETER_NAME;
import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BAGGING_FREQUENCY_PARAMETER_NAME;
import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.BOOSTING_TYPE_PARAMETER_NAME;
import static com.feedzai.openml.util.validate.ValidationUtils.baseLoadValidations;
import static com.feedzai.openml.util.validate.ValidationUtils.checkParams;
import static com.feedzai.openml.util.validate.ValidationUtils.validateCategoricalSchema;
import static com.feedzai.openml.util.validate.ValidationUtils.validateModelPathToTrain;
import static com.feedzai.openml.provider.lightgbm.LightGBMDescriptorUtil.*;
import static com.feedzai.openml.util.validate.ValidationUtils.*;
import static java.nio.file.Files.createTempFile;

/**
Expand Down Expand Up @@ -145,7 +137,6 @@ public LightGBMBinaryClassificationModel fit(final Dataset dataset,
}



@Override
public List<ParamValidationError> validateForFit(final Path pathToPersist,
final DatasetSchema schema,
Expand All @@ -155,7 +146,8 @@ public List<ParamValidationError> validateForFit(final Path pathToPersist,
errorsBuilder
.addAll(validateModelPathToTrain(pathToPersist))
.addAll(validateSchema(schema))
.addAll(validateFitParams(params));
.addAll(validateFitParams(params))
.addAll(validateSoftLabelParam(schema, params));

return errorsBuilder.build();
}
Expand Down Expand Up @@ -183,6 +175,40 @@ private List<ParamValidationError> validateSchema(final DatasetSchema schema) {
return errorsBuilder.build();
}

/**
* Ensure that if the soft label parameter is passed, it uses a floating point field.
*
* @param schema Dataset schema.
* @param params Model fit parameters.
* @return list of validation errors.
*/
private List<ParamValidationError> validateSoftLabelParam(final DatasetSchema schema,
final Map<String, String> params) {
// Don't test anything if the parameter is not set:
final Optional<String> softLabelFieldName = SoftLabelParamParserUtil.getSoftLabelFieldName(params);
if (!softLabelFieldName.isPresent()) {
return ImmutableList.of();
}

// Check if the field exists in the dataset:
final Optional<Integer> softLabelIndex = SoftLabelParamParserUtil.getSoftLabelColumnIndex(params, schema);
if (!softLabelIndex.isPresent()) {
return ImmutableList.of(new ParamValidationError(
String.format("Soft label field '%s' doesn't exist in the dataset. Please select it in the features.", softLabelFieldName.get()))); // TODO: bad formatting (optional%$%WER)
}

// Ensure the soft label is numerical:
final FieldSchema softLabelSchema = schema
.getFieldSchemas()
.get(softLabelIndex.get());
final AbstractValueSchema valueSchema = softLabelSchema.getValueSchema();
if (!(softLabelSchema.getValueSchema() instanceof NumericValueSchema)) {
return ImmutableList.of(new ParamValidationError("Soft label must be a numeric field!"));
}

return ImmutableList.of();
}

/**
* Validate model fit parameters.
*
Expand Down Expand Up @@ -221,7 +247,7 @@ private boolean baggingDisabled(final Map<String, String> params) {

@Override
public LightGBMBinaryClassificationModel loadModel(final Path modelPath,
final DatasetSchema schema) throws ModelLoadingException {
final DatasetSchema schema) throws ModelLoadingException {

final Path modelFilePath = getPath(modelPath);

Expand Down Expand Up @@ -340,11 +366,10 @@ private static Optional<Integer> getNumTargetClasses(final DatasetSchema schema)
/**
* Gets the feature names from schema.
*
* @implNote The space character is replaced with underscore
* to comply with LightGBM's model features representation.
*
* @param schema Schema
* @return Feature names from the schema.
* @implNote The space character is replaced with underscore
* to comply with LightGBM's model features representation.
* @since 1.0.18
*/
private static String[] getFeatureNamesFrom(final DatasetSchema schema) {
Expand All @@ -360,7 +385,7 @@ private static String[] getFeatureNamesFrom(final DatasetSchema schema) {
* {@link DatasetSchema} and an array of feature names.
* This way the first mismatch is logged, improving debug.
*
* @param schema Schema
* @param schema Schema
* @param featureNames Feature names to validate.
* @return {@code true} if the schema predictive field names
* match the provided array, {@code false} otherwise.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,21 @@

package com.feedzai.openml.provider.lightgbm;

import com.microsoft.ml.lightgbm.SWIGTYPE_p_float;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_int;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_p_void;
import com.microsoft.ml.lightgbm.SWIGTYPE_p_void;
import com.microsoft.ml.lightgbm.doubleChunkedArray;
import com.microsoft.ml.lightgbm.floatChunkedArray;
import com.microsoft.ml.lightgbm.int32ChunkedArray;
import com.microsoft.ml.lightgbm.lightgbmlib;
import com.microsoft.ml.lightgbm.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Handles train data resources and provides basic operations to manipulate train data.
* <p>
* This class is responsible for initializing, managing and releasing all
* Handles train data resources and provides basic operations to manipulate train data.
* LightGBM SWIG train resources and resource handlers in a memory-safe manner.
* <p>
* Whatever happens, it guarantees that no memory leaks are left behind.
*
* This class is responsible for initializing, managing and releasing all
* Handles train data resources and provides basic operations to manipulate train data.
* LightGBM SWIG train resources and resource handlers in a memory-safe manner.
*
* Whatever happens, it guarantees that no memory leaks are left behind.
*
* @author Alberto Ferreira (alberto.ferreira@feedzai.com)
* @since 1.0.10
* @author Alberto Ferreira (alberto.ferreira@feedzai.com)
* @since 1.0.10
*/
public class SWIGTrainData implements AutoCloseable {

Expand All @@ -56,7 +49,7 @@ public class SWIGTrainData implements AutoCloseable {
* SWIG Features chunked data array.
* This objects stores elements in float64 format
* as a list of chunks that it manages automatically.
*
* <p>
* In the current implementation, features are stored in row-major order, i.e.,
* each instance is stored contiguously.
*/
Expand All @@ -80,7 +73,7 @@ public class SWIGTrainData implements AutoCloseable {
/**
* Number of features per instance.
*/
public final int numFeatures;
public final int numActualFeatures;

/**
* Number of instances to store in each chunk.
Expand All @@ -104,42 +97,42 @@ public class SWIGTrainData implements AutoCloseable {

/**
* Constructor.
*
* <p>
* Allocates all the initial handles necessary to bootstrap (but not use) the
* in-memory LightGBM dataset + booster structures.
*
* <p>
* After that the BoosterHandle and the DatasetHandle will still need to be initialized at the proper times:
* @see SWIGTrainData#initSwigDatasetHandle()
*
* @param numFeatures The number of features.
* @param numActualFeatures The number of features (excludes soft label if used).
* @param numInstancesChunk The number of instances per chunk of data.
* @see SWIGTrainData#initSwigDatasetHandle()
*/
public SWIGTrainData(final int numFeatures, final long numInstancesChunk) {
this(numFeatures, numInstancesChunk, false);
public SWIGTrainData(final int numActualFeatures, final long numInstancesChunk) {
this(numActualFeatures, numInstancesChunk, false);
}

/**
* Constructor.
*
* <p>
     * Allocates all the initial handles necessary to bootstrap (but not use) the
* in-memory LightGBM dataset, plus booster structures.
*
* <p>
* If fairnessConstrained=true, this will also include data on which sensitive
* group each instance belongs to.
*
* @param numFeatures The number of features.
* @param numInstancesChunk The number of instances per chunk of data.
* @param fairnessConstrained Whether this data will be used for a model with fairness (group) constraints.
* @param numActualFeatures The number of features (excludes soft label if used).
* @param numInstancesChunk The number of instances per chunk of data.
* @param fairnessConstrained Whether this data will be used for a model with fairness (group) constraints.
*/
public SWIGTrainData(final int numFeatures, final long numInstancesChunk, final boolean fairnessConstrained) {
this.numFeatures = numFeatures;
public SWIGTrainData(final int numActualFeatures, final long numInstancesChunk, final boolean fairnessConstrained) {
this.numActualFeatures = numActualFeatures;
this.numInstancesChunk = numInstancesChunk;
this.swigOutDatasetHandlePtr = lightgbmlib.voidpp_handle();
this.fairnessConstrained = fairnessConstrained;

logger.debug("Intermediate SWIG train buffers will be allocated in chunks of {} instances.", numInstancesChunk);
// 1-D Array in row-major-order that stores only the features (excludes label) in double format by chunks:
this.swigFeaturesChunkedArray = new doubleChunkedArray(numFeatures * numInstancesChunk);
this.swigFeaturesChunkedArray = new doubleChunkedArray(numActualFeatures * numInstancesChunk);

// 1-D Array with the labels (float32):
this.swigLabelsChunkedArray = new floatChunkedArray(numInstancesChunk);
Expand All @@ -152,6 +145,7 @@ public SWIGTrainData(final int numFeatures, final long numInstancesChunk, final

/**
* Adds a value to the features' ChunkedArray.
*
* @param value value to insert.
*/
public void addFeatureValue(double value) {
Expand All @@ -167,6 +161,7 @@ public void addLabelValue(float value) {

/**
* Adds a value to the constraint group ChunkedArray.
*
* @param value the value to add.
*/
public void addConstraintGroupValue(int value) {
Expand Down Expand Up @@ -239,7 +234,7 @@ void destroySwigConstraintGroupDataArray() {
/**
* Release the memory of the chunked features array.
* This can be called after instantiating the dataset.
*
* <p>
* Although this simply calls `release()`.
* After this that object becomes unusable.
* To cleanup and reuse call `clear()` instead.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package com.feedzai.openml.provider.lightgbm.parameters;

import com.feedzai.openml.data.schema.DatasetSchema;
import com.feedzai.openml.data.schema.FieldSchema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Optional;

/**
 * Utilities to compute field/column indices from a {@link DatasetSchema}.
 *
 * @author alberto.ferreira
 */
public class SchemaFieldsUtil {

    /**
     * Logger for this class.
     */
    private static final Logger logger = LoggerFactory.getLogger(SchemaFieldsUtil.class);

    /**
     * Placeholder to use when an integer argument is not provided.
     * E.g., when running standard unconstrained LightGBM the constrained_group_column parameter will take this value.
     */
    public static final int NO_SPECIFIC = -1;

    /**
     * Utility class — not meant to be instantiated.
     */
    private SchemaFieldsUtil() {
    }

    /**
     * Gets the index of a predictive field in the schema by (case-insensitive) name.
     *
     * @param fieldName Name of the field in the dataset.
     * @param schema    Schema of the dataset.
     * @return the field's index, or an empty Optional if no predictive field has that name.
     */
    public static Optional<Integer> getColumnIndex(final String fieldName,
                                                   final DatasetSchema schema) {

        final List<FieldSchema> featureFields = schema.getPredictiveFields();
        final Optional<FieldSchema> field = featureFields
                .stream()
                .filter(candidate -> candidate.getFieldName().equalsIgnoreCase(fieldName))
                .findFirst();

        // Check if column exists. Parameterized logging avoids formatting when the level is disabled.
        if (!field.isPresent()) {
            logger.error("Column '{}' was not found in the dataset.", fieldName);
            return Optional.empty();
        }

        return Optional.of(field.get().getFieldIndex());
    }

    /**
     * Gets the index of a field in the LightGBM-specific format, i.e., disregarding the label column.
     *
     * @param fieldName Name of the field in the dataset.
     * @param schema    Schema of the dataset.
     * @return the index of the field not counting the label column, or an empty Optional
     *         if the field does not exist among the schema's predictive fields.
     */
    public static Optional<Integer> getFieldIndexWithoutLabel(final String fieldName,
                                                              final DatasetSchema schema) {

        final Optional<Integer> fieldIndex = getColumnIndex(fieldName, schema);
        if (!fieldIndex.isPresent()) {
            return Optional.empty();
        }

        // Compute column index in LightGBM-specific format (disregarding target column).
        // Our model is supervised, so the schema must declare a target field.
        final int targetIndex = schema.getTargetIndex()
                .orElseThrow(() -> new IllegalStateException(
                        "Supervised model requires a target field in the schema."));

        final int offsetIfFieldIsAfterLabel = fieldIndex.get() > targetIndex ? -1 : 0;
        return Optional.of(fieldIndex.get() + offsetIfFieldIsAfterLabel);
    }
}
Loading