From 69809b34ab9f44956b2a0642e7bb2d3b1fc9a67c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Tue, 1 Jul 2025 20:12:11 +0200 Subject: [PATCH] add initial implementation --- botorch/models/transforms/input.py | 120 ++++++++++++++++++++- test/models/transforms/test_input.py | 155 +++++++++++++++++++++++++++ 2 files changed, 271 insertions(+), 4 deletions(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 48f31d9eaf..ad6c9dfb2e 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -1625,6 +1625,122 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor: return p.transpose(-3, -2) # p is batch_shape x n_p x n x d +class NumericToCategoricalEncoding(InputTransform): + """Transform categorical parameters from an integer representation + to a vector based representation like one-hot encoding or a descriptor + encoding. + """ + + def __init__( + self, + dim: int, + categorical_features: dict[int, int], + encoders: dict[int, Callable[[Tensor], Tensor]], + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + ): + r"""Initialize. + + Args: + dim: The dimension of the numerically encoded input. + categorical_features: A dictionary mapping the index of each + categorical feature to its cardinality. This assumes that categoricals + are integer encoded. + encoders: A dictionary mapping the index of each categorical feature to + a callable that encodes the categorical feature into a vector + representation. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: False. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: False. + """ + super().__init__() + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + + self.encoders = encoders + self.categorical_features = categorical_features + + if len(self.categorical_features) > dim: + raise ValueError( + "The number of categorical features exceeds the provided dimension." + ) + + # check that the encoders match the categorical features + if set(self.encoders.keys()) != set(self.categorical_features.keys()): + raise ValueError( + "The keys of `encoders` must match the keys of `categorical_features`." + ) + + self.ordinal_idx = list( + self.categorical_features.keys() + ) # indices of categorical features before encoding + + self.numerical_idx = list( + set(range(dim)) - set(self.ordinal_idx) + ) # indices of numerical features before encoding + + self.new_numerical_idx = [] # indices of numerical features after encoding + self.encoded_idx = [] # indices of categorical features after encoding + + offset = 0 + for idx in range(dim): + if idx in self.numerical_idx: + self.new_numerical_idx.append(idx + offset) + else: + card = self.categorical_features[idx] + self.encoded_idx.append( + np.arange( + idx + offset, idx + offset + card + ).tolist() # indices of categorical features after encoding + ) + offset += card - 1 # adjust offset for next categorical feature + + def transform(self, X: Tensor) -> Tensor: + r"""Transform the categorical inputs into a vector representation. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n x d'`-dim tensor of where the integer encoded + categoricals are transformed to a vector representation. + """ + if len(self.categorical_features) > 0: + s = list(X.shape) + s[-1] = len(self.numerical_idx) + len(np.concatenate(self.encoded_idx)) + X_encoded = torch.zeros(size=s).to(X) + X_encoded[..., self.new_numerical_idx] = X[..., self.numerical_idx] + for i, idx in enumerate(self.categorical_features.keys()): + X_encoded[..., self.encoded_idx[i]] = self.encoders[idx]( + X[..., idx].long(), + ).to(X_encoded) + return X_encoded + return X + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + return ( + type(self) is type(other) + and (self.transform_on_train == other.transform_on_train) + and (self.transform_on_eval == other.transform_on_eval) + and (self.transform_on_fantasize == other.transform_on_fantasize) + and self.categorical_features == other.categorical_features + and self.encoders == other.encoders + ) + + class OneHotToNumeric(InputTransform): r"""Transform categorical parameters from a one-hot to a numeric representation.""" @@ -1649,10 +1765,6 @@ def __init__( transform in eval() mode. Default: True. transform_on_fantasize: A boolean indicating whether to apply the transform when called from within a `fantasize` call. Default: False. - - Returns: - A `batch_shape x n x d'`-dim tensor of where the one-hot encoded - categoricals are transformed to integer representation. """ super().__init__() self.transform_on_train = transform_on_train diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py index 67586e83d4..29361e6f85 100644 --- a/test/models/transforms/test_input.py +++ b/test/models/transforms/test_input.py @@ -7,6 +7,7 @@ import itertools from abc import ABC from copy import deepcopy +from functools import partial from itertools import product from random import randint @@ -25,6 +26,7 @@ InteractionFeatures, Log10, Normalize, + NumericToCategoricalEncoding, OneHotToNumeric, ReversibleInputTransform, Round, @@ -1207,6 +1209,159 @@ def test_warp_mro(self) -> None: ), ) + def test_numeric_to_categorical_encoding(self) -> None: + # test exceptions + with self.assertRaises( + ValueError, + msg="The number of categorical features exceeds the provided dimension.", + ): + categorical_features = {0: 2, 1: 3} + NumericToCategoricalEncoding( + dim=1, + categorical_features=categorical_features, + encoders={ + 0: partial(one_hot, num_classes=2), + 1: partial(one_hot, num_classes=3), + }, + ) + with self.assertRaises( + ValueError, + msg="The keys of `encoders` must match the keys of `categorical_features`.", + ): + categorical_features = {0: 2, 1: 3} + NumericToCategoricalEncoding( + dim=4, + categorical_features=categorical_features, + encoders={ + 0: partial(one_hot, num_classes=2), + 2: partial(one_hot, num_classes=3), + }, + ) + + torch.manual_seed(randint(0, 1000)) + for dtype in (torch.float, torch.double): + # one categorical at start + dim = 3 + categorical_features = {0: 3} + tf = NumericToCategoricalEncoding( + dim=dim, + categorical_features=categorical_features, + encoders={0: partial(one_hot, num_classes=3)}, + ) + tf.eval() + cat_numeric = torch.randint(0, 3, (3,), device=self.device) + cat_one_hot = one_hot(cat_numeric, num_classes=3) + cont = torch.rand(3, 2, dtype=dtype, device=self.device) + + X_numeric = torch.cat( + [cat_numeric.view(-1, 1).to(dtype=dtype, device=self.device), cont], + dim=-1, + ) + + expected = torch.cat( + [cat_one_hot, cont], + dim=-1, + ) + X_one_hot = tf(X_numeric) + self.assertTrue(torch.equal(X_one_hot, expected)) + # two categoricals at end + dim = 4 + categorical_features = {2: 3, 3: 2} + tf = NumericToCategoricalEncoding( + dim=dim, + categorical_features=categorical_features, + encoders={ + 2: partial(one_hot, num_classes=3), + 3: partial(one_hot, num_classes=2), + }, + ) + tf.eval() + cat_numeric1 = torch.randint(0, 3, (3,), device=self.device) + cat_one_hot1 = one_hot(cat_numeric1, num_classes=3) + cat_numeric2 = torch.randint(0, 2, (3,), device=self.device) + cat_one_hot2 = one_hot(cat_numeric2, num_classes=2) + cont = torch.rand(3, 2, dtype=dtype, device=self.device) + + X_numeric = torch.cat( + [ + cont, + cat_numeric1.view(-1, 1).to(dtype=dtype, device=self.device), + cat_numeric2.view(-1, 1).to(dtype=dtype, device=self.device), + ], + dim=-1, + ) + + expected = torch.cat( + [cont, cat_one_hot1, cat_one_hot2], + dim=-1, + ) + X_one_hot = tf(X_numeric) + self.assertTrue(torch.equal(X_one_hot, expected)) + # two categoricals, one at start, one at end + dim = 4 + categorical_features = {0: 3, 3: 2} + tf = NumericToCategoricalEncoding( + dim=dim, + categorical_features=categorical_features, + encoders={ + 0: partial(one_hot, num_classes=3), + 3: partial(one_hot, num_classes=2), + }, + ) + tf.eval() + cat_numeric1 = torch.randint(0, 3, (3,), device=self.device) + cat_one_hot1 = one_hot(cat_numeric1, num_classes=3) + cat_numeric2 = torch.randint(0, 2, (3,), device=self.device) + cat_one_hot2 = one_hot(cat_numeric2, num_classes=2) + cont = torch.rand(3, 2, dtype=dtype, device=self.device) + + X_numeric = torch.cat( + [ + cat_numeric1.view(-1, 1).to(dtype=dtype, device=self.device), + cont, + cat_numeric2.view(-1, 1).to(dtype=dtype, device=self.device), + ], + dim=-1, + ) + + expected = torch.cat( + [cat_one_hot1, cont, cat_one_hot2], + dim=-1, + ) + X_one_hot = tf(X_numeric) + self.assertTrue(torch.equal(X_one_hot, expected)) + # only categoricals + dim = 2 + categorical_features = {0: 3, 1: 2} + tf = NumericToCategoricalEncoding( + dim=dim, + categorical_features=categorical_features, + encoders={ + 0: partial(one_hot, num_classes=3), + 1: partial(one_hot, num_classes=2), + }, + ) + tf.eval() + cat_numeric1 = torch.randint(0, 3, (3,), device=self.device) + cat_one_hot1 = one_hot(cat_numeric1, num_classes=3) + cat_numeric2 = torch.randint(0, 2, (3,), device=self.device) + cat_one_hot2 = one_hot(cat_numeric2, num_classes=2) + + X_numeric = torch.cat( + [ + cat_numeric1.view(-1, 1).to(dtype=dtype, device=self.device), + cat_numeric2.view(-1, 1).to(dtype=dtype, device=self.device), + ], + dim=-1, + ) + + expected = torch.cat( + [cat_one_hot1, cat_one_hot2], + dim=-1, + ) + X_one_hot = tf(X_numeric) + self.assertTrue(torch.equal(X_one_hot, expected)) + def test_one_hot_to_numeric(self) -> None: dim = 8 # test exceptions