NumericToCategoricalEncoding Input Transform. #2907
@@ -1625,6 +1625,122 @@
        return p.transpose(-3, -2)  # p is batch_shape x n_p x n x d


class NumericToCategoricalEncoding(InputTransform):
    """Transform categorical parameters from an integer representation
    to a vector-based representation such as one-hot encoding or a
    descriptor encoding.
Comment on lines +1629 to +1631:
Would be good to have a description of how the columns in the output of the transform are organized. Ideally there would be a concrete example in the docstring.
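A minimal sketch of how the columns end up organized (my illustration, assuming a one-hot encoder; the module path is assumed from context, since the diff does not name the file):

    import torch
    from functools import partial
    from torch.nn.functional import one_hot

    # module path assumed; the diff does not name the file
    from botorch.models.transforms.input import NumericToCategoricalEncoding

    # dim=3 input with one categorical at index 0 (cardinality 3):
    # input columns  [cat0, x1, x2] expand in place to
    # output columns [cat0==0, cat0==1, cat0==2, x1, x2]
    tf = NumericToCategoricalEncoding(
        dim=3,
        categorical_features={0: 3},
        encoders={0: partial(one_hot, num_classes=3)},
    )
    tf.eval()  # transform_on_eval=True, so calling tf applies the encoding
    X = torch.tensor([[2.0, 0.5, 0.1]])
    tf(X)  # tensor([[0., 0., 1., 0.5, 0.1]])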
""" | ||
|
||
def __init__( | ||
self, | ||
dim: int, | ||
categorical_features: dict[int, int], | ||
encoders: dict[int, Callable[[Tensor], Tensor]], | ||
transform_on_train: bool = True, | ||
transform_on_eval: bool = True, | ||
transform_on_fantasize: bool = True, | ||
): | ||
r"""Initialize. | ||
|
||
Args: | ||
dim: The dimension of the numerically encoded input. | ||
categorical_features: A dictionary mapping the index of each | ||
Comment: this arg name could be more descriptive, e.g. …
                categorical feature to its cardinality. This assumes that categoricals
                are integer encoded.
            encoders: A dictionary mapping the index of each categorical feature to
                a callable that encodes the categorical feature into a vector
                representation.
            transform_on_train: A boolean indicating whether to apply the
                transform in train() mode. Default: True.
            transform_on_eval: A boolean indicating whether to apply the
                transform in eval() mode. Default: True.
            transform_on_fantasize: A boolean indicating whether to apply the
                transform when called from within a `fantasize` call. Default: True.
        """
        super().__init__()
        self.transform_on_train = transform_on_train
        self.transform_on_eval = transform_on_eval
        self.transform_on_fantasize = transform_on_fantasize

        self.encoders = encoders
        self.categorical_features = categorical_features

        if len(self.categorical_features) > dim:
            raise ValueError(
                "The number of categorical features exceeds the provided dimension."
            )

        # check that the encoders match the categorical features
        if set(self.encoders.keys()) != set(self.categorical_features.keys()):
            raise ValueError(
                "The keys of `encoders` must match the keys of `categorical_features`."
            )

        self.ordinal_idx = list(
            self.categorical_features.keys()
        )  # indices of categorical features before encoding

        self.numerical_idx = list(
            set(range(dim)) - set(self.ordinal_idx)
        )  # indices of numerical features before encoding

        self.new_numerical_idx = []  # indices of numerical features after encoding
        self.encoded_idx = []  # indices of categorical features after encoding

        offset = 0
        for idx in range(dim):
            if idx in self.numerical_idx:
                self.new_numerical_idx.append(idx + offset)
            else:
                card = self.categorical_features[idx]
                self.encoded_idx.append(
                    np.arange(
                        idx + offset, idx + offset + card
                    ).tolist()  # indices of categorical features after encoding
                )
                offset += card - 1  # adjust offset for next categorical feature
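        # Worked example of the bookkeeping above (illustrative values):
        # for dim=4 and categorical_features={0: 3, 3: 2} the loop yields
        #   numerical_idx     == [1, 2]
        #   new_numerical_idx == [3, 4]
        #   encoded_idx       == [[0, 1, 2], [5, 6]]
        # i.e. the 4-column input expands to a 7-column output laid out as
        #   [cat0 encoding (3 cols), x1, x2, cat3 encoding (2 cols)]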

    def transform(self, X: Tensor) -> Tensor:
        r"""Transform the categorical inputs into a vector representation.

        Args:
            X: A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x n x d'`-dim tensor where the integer-encoded
            categoricals are transformed to a vector representation.
        """
        if len(self.categorical_features) > 0:
Comment: Since the whole transform would be a nullop if no …
            s = list(X.shape)
            s[-1] = len(self.numerical_idx) + len(np.concatenate(self.encoded_idx))
            X_encoded = torch.zeros(size=s).to(X)
            X_encoded[..., self.new_numerical_idx] = X[..., self.numerical_idx]
            for i, idx in enumerate(self.categorical_features.keys()):
                X_encoded[..., self.encoded_idx[i]] = self.encoders[idx](
                    X[..., idx].long(),
                ).to(X_encoded)
            return X_encoded
        return X
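The docstring also mentions descriptor encodings. A minimal sketch of such an encoder, continuing the setup from the sketch above (descriptor values made up; note that `encoded_idx` allocates exactly `cardinality` output columns per categorical, so the encoder's output width must equal the cardinality):

    # 3 categories, each described by a 3-dim descriptor vector
    descriptors = torch.tensor([[0.1, 2.3, 0.0], [0.5, 1.1, 1.0], [0.9, 0.2, 0.5]])
    tf = NumericToCategoricalEncoding(
        dim=2,
        categorical_features={0: 3},
        encoders={0: lambda x: descriptors[x]},
    )
    tf.eval()
    X = torch.tensor([[1.0, 0.7]])
    tf(X)  # tensor([[0.5000, 1.1000, 1.0000, 0.7000]])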

    def equals(self, other: InputTransform) -> bool:
        r"""Check if another input transform is equivalent.

        Args:
            other: Another input transform.

        Returns:
            A boolean indicating if the other transform is equivalent.
        """
        return (
            type(self) is type(other)
            and (self.transform_on_train == other.transform_on_train)
            and (self.transform_on_eval == other.transform_on_eval)
            and (self.transform_on_fantasize == other.transform_on_fantasize)
            and self.categorical_features == other.categorical_features
            and self.encoders == other.encoders
        )
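One caveat on the `self.encoders == other.encoders` check: `functools.partial` objects compare by identity, not by value, so two transforms built with separately constructed but otherwise identical partials will not be `equals`:

    from functools import partial
    from torch.nn.functional import one_hot

    e1 = {0: partial(one_hot, num_classes=3)}
    e2 = {0: partial(one_hot, num_classes=3)}
    e1 == e2  # False, since partial does not define value equality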


class OneHotToNumeric(InputTransform):
    r"""Transform categorical parameters from a one-hot to a numeric representation."""

@@ -1649,10 +1765,6 @@
                transform in eval() mode. Default: True.
            transform_on_fantasize: A boolean indicating whether to apply the
                transform when called from within a `fantasize` call. Default: False.

        Returns:
            A `batch_shape x n x d'`-dim tensor of where the one-hot encoded
            categoricals are transformed to integer representation.
        """
        super().__init__()
        self.transform_on_train = transform_on_train
@@ -7,6 +7,7 @@
import itertools
from abc import ABC
from copy import deepcopy
from functools import partial
from itertools import product
from random import randint

@@ -25,6 +26,7 @@
    InteractionFeatures,
    Log10,
    Normalize,
    NumericToCategoricalEncoding,
    OneHotToNumeric,
    ReversibleInputTransform,
    Round,
@@ -1207,6 +1209,159 @@ def test_warp_mro(self) -> None:
            ),
        )

    def test_numeric_to_categorical_encoding(self) -> None:
        # test exceptions
        with self.assertRaises(
            ValueError,
            msg="The number of categorical features exceeds the provided dimension.",
        ):
            categorical_features = {0: 2, 1: 3}
            NumericToCategoricalEncoding(
                dim=1,
                categorical_features=categorical_features,
                encoders={
                    0: partial(one_hot, num_classes=2),
                    1: partial(one_hot, num_classes=3),
                },
            )
        with self.assertRaises(
            ValueError,
            msg="The keys of `encoders` must match the keys of `categorical_features`.",
        ):
            categorical_features = {0: 2, 1: 3}
            NumericToCategoricalEncoding(
                dim=4,
                categorical_features=categorical_features,
                encoders={
                    0: partial(one_hot, num_classes=2),
                    2: partial(one_hot, num_classes=3),
                },
            )

        torch.manual_seed(randint(0, 1000))
Comment: This seems to defy the point of locking the seed?
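If determinism is the intent, a fixed seed would address this (a suggestion, not part of the diff):

    torch.manual_seed(0)  # any constant makes the test reproducible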
        for dtype in (torch.float, torch.double):
            # one categorical at start
            dim = 3
            categorical_features = {0: 3}
            tf = NumericToCategoricalEncoding(
                dim=dim,
                categorical_features=categorical_features,
                encoders={0: partial(one_hot, num_classes=3)},
            )
            tf.eval()
            cat_numeric = torch.randint(0, 3, (3,), device=self.device)
            cat_one_hot = one_hot(cat_numeric, num_classes=3)
            cont = torch.rand(3, 2, dtype=dtype, device=self.device)

            X_numeric = torch.cat(
                [cat_numeric.view(-1, 1).to(dtype=dtype, device=self.device), cont],
Comment: No need for the …
                dim=-1,
            )

            expected = torch.cat(
                [cat_one_hot, cont],
                dim=-1,
            )
            X_one_hot = tf(X_numeric)
            self.assertTrue(torch.equal(X_one_hot, expected))
            # two categoricals at end
            dim = 4
            categorical_features = {2: 3, 3: 2}
            tf = NumericToCategoricalEncoding(
                dim=dim,
                categorical_features=categorical_features,
                encoders={
                    2: partial(one_hot, num_classes=3),
                    3: partial(one_hot, num_classes=2),
                },
            )
            tf.eval()
            cat_numeric1 = torch.randint(0, 3, (3,), device=self.device)
            cat_one_hot1 = one_hot(cat_numeric1, num_classes=3)
            cat_numeric2 = torch.randint(0, 2, (3,), device=self.device)
            cat_one_hot2 = one_hot(cat_numeric2, num_classes=2)
            cont = torch.rand(3, 2, dtype=dtype, device=self.device)

            X_numeric = torch.cat(
                [
                    cont,
                    cat_numeric1.view(-1, 1).to(dtype=dtype, device=self.device),
                    cat_numeric2.view(-1, 1).to(dtype=dtype, device=self.device),
Comment on lines +1288 to +1289:
same comment as above about …
                ],
                dim=-1,
            )

            expected = torch.cat(
                [cont, cat_one_hot1, cat_one_hot2],
                dim=-1,
            )
            X_one_hot = tf(X_numeric)
            self.assertTrue(torch.equal(X_one_hot, expected))
            # two categoricals, one at start, one at end
            dim = 4
            categorical_features = {0: 3, 3: 2}
            tf = NumericToCategoricalEncoding(
                dim=dim,
                categorical_features=categorical_features,
                encoders={
                    0: partial(one_hot, num_classes=3),
                    3: partial(one_hot, num_classes=2),
                },
            )
            tf.eval()
            cat_numeric1 = torch.randint(0, 3, (3,), device=self.device)
            cat_one_hot1 = one_hot(cat_numeric1, num_classes=3)
            cat_numeric2 = torch.randint(0, 2, (3,), device=self.device)
            cat_one_hot2 = one_hot(cat_numeric2, num_classes=2)
            cont = torch.rand(3, 2, dtype=dtype, device=self.device)

            X_numeric = torch.cat(
                [
                    cat_numeric1.view(-1, 1).to(dtype=dtype, device=self.device),
                    cont,
                    cat_numeric2.view(-1, 1).to(dtype=dtype, device=self.device),
                ],
                dim=-1,
            )

            expected = torch.cat(
                [cat_one_hot1, cont, cat_one_hot2],
                dim=-1,
            )
            X_one_hot = tf(X_numeric)
            self.assertTrue(torch.equal(X_one_hot, expected))
            # only categoricals
            dim = 2
            categorical_features = {0: 3, 1: 2}
            tf = NumericToCategoricalEncoding(
                dim=dim,
                categorical_features=categorical_features,
                encoders={
                    0: partial(one_hot, num_classes=3),
                    1: partial(one_hot, num_classes=2),
                },
            )
            tf.eval()
            cat_numeric1 = torch.randint(0, 3, (3,), device=self.device)
            cat_one_hot1 = one_hot(cat_numeric1, num_classes=3)
            cat_numeric2 = torch.randint(0, 2, (3,), device=self.device)
            cat_one_hot2 = one_hot(cat_numeric2, num_classes=2)

            X_numeric = torch.cat(
                [
                    cat_numeric1.view(-1, 1).to(dtype=dtype, device=self.device),
                    cat_numeric2.view(-1, 1).to(dtype=dtype, device=self.device),
                ],
                dim=-1,
            )

            expected = torch.cat(
                [cat_one_hot1, cat_one_hot2],
                dim=-1,
            )
            X_one_hot = tf(X_numeric)
            self.assertTrue(torch.equal(X_one_hot, expected))

    def test_one_hot_to_numeric(self) -> None:
        dim = 8
        # test exceptions
Comment: wait, should this not be CategoricalToNumeric?