
Commit 3c4e566

Adapted Masked Conditional VF Estimator Wrapper to handle 2-dim inputs natively

1 parent 22f43d9 · commit 3c4e566

File tree

1 file changed: +59 −24 lines changed

sbi/neural_nets/estimators/base.py

Lines changed: 59 additions & 24 deletions
@@ -749,7 +749,13 @@ def __init__(
                 to save memory resources
         """

-        T, F = original_estimator.input_shape
+        T = original_estimator.input_shape[0]
+        if len(original_estimator.input_shape) == 2:
+            F = original_estimator.input_shape[1]
+            self.is_features_dim_missing = False
+        else:
+            F = 1
+            self.is_features_dim_missing = True

         # Input checks for fixed_condition_mask
         if fixed_condition_mask.dim() != 1:
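Before this change, `T, F = original_estimator.input_shape` raised a ValueError whenever the estimator's input_shape was 1-dim, i.e. (T,) with no feature axis. A minimal standalone sketch of the new branching (the helper resolve_shape is illustrative, not part of sbi):

import torch

def resolve_shape(input_shape: torch.Size) -> tuple:
    # Mirror the branch above: take T, then check for a trailing feature axis.
    T = input_shape[0]
    if len(input_shape) == 2:
        return T, input_shape[1], False  # (T, F): feature axis present
    return T, 1, True                    # (T,): feature axis missing, F -> 1

print(resolve_shape(torch.Size((10, 3))))  # (10, 3, False)
print(resolve_shape(torch.Size((10,))))    # (10, 1, True)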
@@ -786,10 +792,10 @@ def __init__(
                 "for all entries."
             )

+        # Count number of latent and observed nodes
         num_latent = int(torch.sum(fixed_condition_mask == 0).item())
         num_observed = int(torch.sum(fixed_condition_mask == 1).item())

-        # Count number of latent and observed nodes
         self._new_input_shape = torch.Size((num_latent * F,))
         self._new_condition_shape = torch.Size((num_observed * F,))

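For context, a small sketch of the shape bookkeeping around this hunk, with illustrative values: the mask partitions the T nodes into latent (0) and observed (1) entries, and the wrapper then exposes flattened shapes of size num_latent * F and num_observed * F.

import torch

fixed_condition_mask = torch.tensor([0, 1, 0, 1, 1])  # 0 = latent, 1 = observed
F = 3  # illustrative feature dim

num_latent = int(torch.sum(fixed_condition_mask == 0).item())    # 2
num_observed = int(torch.sum(fixed_condition_mask == 1).item())  # 3

print(torch.Size((num_latent * F,)))    # torch.Size([6])  -> new input shape
print(torch.Size((num_observed * F,)))  # torch.Size([9])  -> new condition shape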
@@ -828,12 +834,20 @@ def __init__(
         self._observed_idx = (self._fixed_condition_mask == 1).nonzero(as_tuple=True)[0]

         # Get the mean/std for the latent nodes from the original estimator
-        latent_mean_base_unflattened = original_estimator.mean_base[
-            :, self._latent_idx, :
-        ]
-        latent_std_base_unflattened = original_estimator.std_base[
-            :, self._latent_idx, :
-        ]
+        if len(original_estimator.input_shape) == 1:
+            latent_mean_base_unflattened = original_estimator.mean_base[
+                :, self._latent_idx
+            ]
+            latent_std_base_unflattened = original_estimator.std_base[
+                :, self._latent_idx
+            ]
+        else:
+            latent_mean_base_unflattened = original_estimator.mean_base[
+                :, self._latent_idx, :
+            ]
+            latent_std_base_unflattened = original_estimator.std_base[
+                :, self._latent_idx, :
+            ]

         latent_mean_base_flattened = latent_mean_base_unflattened.flatten(start_dim=1)
         latent_std_base_flattened = latent_std_base_unflattened.flatten(start_dim=1)
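A sketch of why the mean/std indexing must branch, with illustrative tensor sizes: for a 1-dim input_shape the base statistics are assumed here to be (1, T) and take a single index, while for a 2-dim input_shape they are (1, T, F) and keep the trailing feature axis; both results are then flattened from dim 1, matching the flatten(start_dim=1) calls above.

import torch

latent_idx = torch.tensor([0, 2])

mean_base_1d = torch.randn(1, 5)     # base stats for input_shape == (T,)
mean_base_2d = torch.randn(1, 5, 3)  # base stats for input_shape == (T, F)

print(mean_base_1d[:, latent_idx].shape)     # torch.Size([1, 2])
print(mean_base_2d[:, latent_idx, :].shape)  # torch.Size([1, 2, 3])

# Both branches then flatten everything after the batch axis:
print(mean_base_2d[:, latent_idx, :].flatten(start_dim=1).shape)  # torch.Size([1, 6])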
@@ -1000,26 +1014,47 @@ def _assemble_full_inputs(self, input, condition):
         B = int(torch.prod(torch.tensor(input.shape[:-1])).item())
         C = int(torch.prod(torch.tensor(condition.shape[:-1])).item())

-        input_part_unflattened = input.reshape(B, self._num_latent, self._original_F)
-        condition_part_unflattened = condition.reshape(
-            -1, self._num_observed, self._original_F
-        ).repeat(B // C, 1, 1)
-
-        full_inputs = torch.zeros(
-            B,
-            self._original_T,
-            self._original_F,
-            dtype=input.dtype,
-            device=input.device,
-        )
-        # Place unflattened parts into the correct positions
-        full_inputs[:, self._latent_idx, :] = input_part_unflattened
-        full_inputs[:, self._observed_idx, :] = condition_part_unflattened
+        if self.is_features_dim_missing:
+            input_part_unflattened = input.reshape(B, self._num_latent)
+            condition_part_unflattened = condition.reshape(
+                -1, self._num_observed
+            ).repeat(B // C, 1)
+
+            full_inputs = torch.zeros(
+                B,
+                self._original_T,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            # Place unflattened parts into the correct positions
+            full_inputs[:, self._latent_idx] = input_part_unflattened
+            full_inputs[:, self._observed_idx] = condition_part_unflattened
+        else:
+            input_part_unflattened = input.reshape(
+                B, self._num_latent, self._original_F
+            )
+            condition_part_unflattened = condition.reshape(
+                -1, self._num_observed, self._original_F
+            ).repeat(B // C, 1, 1)
+
+            full_inputs = torch.zeros(
+                B,
+                self._original_T,
+                self._original_F,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            # Place unflattened parts into the correct positions
+            full_inputs[:, self._latent_idx, :] = input_part_unflattened
+            full_inputs[:, self._observed_idx, :] = condition_part_unflattened

         return full_inputs

     def _disassemble_full_outputs(self, full_outputs, original_latent_tensor):
-        latent_part = full_outputs[:, self._latent_idx, :]  # (B, num_latent, F)
+        if self.is_features_dim_missing:
+            latent_part = full_outputs[:, self._latent_idx]  # (B, num_latent)
+        else:
+            latent_part = full_outputs[:, self._latent_idx, :]  # (B, num_latent, F)

         return latent_part.reshape_as(original_latent_tensor)

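A standalone round-trip sketch of the assemble/disassemble pair for the missing-features case (all tensor names and sizes are illustrative): latent and observed values are scattered by index into the full (B, T) input, and the latent slice is recovered afterwards with reshape_as.

import torch

mask = torch.tensor([0, 1, 0, 1])  # 0 = latent node, 1 = observed node
latent_idx = (mask == 0).nonzero(as_tuple=True)[0]
observed_idx = (mask == 1).nonzero(as_tuple=True)[0]

B, T = 3, mask.numel()
latent = torch.randn(B, latent_idx.numel())      # flattened latent input, (B, num_latent)
observed = torch.randn(1, observed_idx.numel())  # one shared condition, (1, num_observed)

# Assemble: place both parts into the full (B, T) input expected by the
# wrapped estimator (mirrors the is_features_dim_missing branch).
full_inputs = torch.zeros(B, T)
full_inputs[:, latent_idx] = latent
full_inputs[:, observed_idx] = observed.repeat(B, 1)

# Disassemble: pull the latent slice back out and restore its shape.
latent_part = full_inputs[:, latent_idx]
assert torch.equal(latent_part.reshape_as(latent), latent)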