Skip to content

Commit 28494cb

Browse files
committed
Extended Simformer and tests to handle non 3-dimensional data
1 parent f5208ed commit 28494cb

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

sbi/neural_nets/net_builders/vector_field_nets.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,12 @@ def forward(
12691269
Vector field evaluation at the provided points
12701270
"""
12711271

1272+
if self.in_features == 1 and inputs.dim() == 2:
1273+
inputs = inputs.unsqueeze(-1) # [B, T] -> [B, T, 1]
1274+
to_squeeze = True
1275+
else:
1276+
to_squeeze = False
1277+
12721278
B, T, _ = inputs.shape
12731279
device = inputs.device
12741280

@@ -1317,6 +1323,10 @@ def forward(
13171323

13181324
# Output projection
13191325
out = self.out_linear(h) # [B, T, F]
1326+
1327+
if to_squeeze:
1328+
out = out.squeeze(-1)
1329+
13201330
return out
13211331

13221332

@@ -1585,7 +1595,7 @@ def build_simformer_network(
15851595
del batch_y # Unused
15861596
del embedding_net # Unused
15871597

1588-
in_features = batch_x.shape[-1]
1598+
in_features = 1 if batch_x.dim() == 2 else batch_x.shape[-1]
15891599
num_nodes = batch_x.shape[1]
15901600

15911601
# Create the vector field network (Simformer)

tests/vector_field_nets_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_vector_field_builders_shape_and_build(
160160
@pytest.mark.parametrize("time_emb_type", ["sinusoidal", "random_fourier"])
161161
@pytest.mark.parametrize("time_embedding_dim", [8, 16])
162162
@pytest.mark.parametrize("batch_dim", [1, 3])
163-
@pytest.mark.parametrize("input_dim", [(1, 1), (2, 1), (3, 5)])
163+
@pytest.mark.parametrize("input_dim", [(1,), (3,), (1, 1), (2, 1), (3, 5)])
164164
def test_simformer_builder_shape_and_build(
165165
builder,
166166
builder_kwargs,

0 commit comments

Comments
 (0)