@@ -749,7 +749,13 @@ def __init__(
             to save memory resources
         """

-        T, F = original_estimator.input_shape
+        T = original_estimator.input_shape[0]
+        if len(original_estimator.input_shape) == 2:
+            F = original_estimator.input_shape[1]
+            self.is_features_dim_missing = False
+        else:
+            F = 1
+            self.is_features_dim_missing = True

        # Input checks for fixed_condition_mask
        if fixed_condition_mask.dim() != 1:
@@ -786,10 +792,10 @@ def __init__(
                "for all entries."
            )

+        # Count number of latent and observed nodes
        num_latent = int(torch.sum(fixed_condition_mask == 0).item())
        num_observed = int(torch.sum(fixed_condition_mask == 1).item())

-        # Count number of latent and observed nodes
        self._new_input_shape = torch.Size((num_latent * F,))
        self._new_condition_shape = torch.Size((num_observed * F,))

@@ -828,12 +834,20 @@ def __init__(
        self._observed_idx = (self._fixed_condition_mask == 1).nonzero(as_tuple=True)[0]

        # Get the mean/std for the latent nodes from the original estimator
-        latent_mean_base_unflattened = original_estimator.mean_base[
-            :, self._latent_idx, :
-        ]
-        latent_std_base_unflattened = original_estimator.std_base[
-            :, self._latent_idx, :
-        ]
+        if len(original_estimator.input_shape) == 1:
+            latent_mean_base_unflattened = original_estimator.mean_base[
+                :, self._latent_idx
+            ]
+            latent_std_base_unflattened = original_estimator.std_base[
+                :, self._latent_idx
+            ]
+        else:
+            latent_mean_base_unflattened = original_estimator.mean_base[
+                :, self._latent_idx, :
+            ]
+            latent_std_base_unflattened = original_estimator.std_base[
+                :, self._latent_idx, :
+            ]

        latent_mean_base_flattened = latent_mean_base_unflattened.flatten(start_dim=1)
        latent_std_base_flattened = latent_std_base_unflattened.flatten(start_dim=1)
@@ -1000,26 +1014,47 @@ def _assemble_full_inputs(self, input, condition):
        B = int(torch.prod(torch.tensor(input.shape[:-1])).item())
        C = int(torch.prod(torch.tensor(condition.shape[:-1])).item())

-        input_part_unflattened = input.reshape(B, self._num_latent, self._original_F)
-        condition_part_unflattened = condition.reshape(
-            -1, self._num_observed, self._original_F
-        ).repeat(B // C, 1, 1)
-
-        full_inputs = torch.zeros(
-            B,
-            self._original_T,
-            self._original_F,
-            dtype=input.dtype,
-            device=input.device,
-        )
-        # Place unflattened parts into the correct positions
-        full_inputs[:, self._latent_idx, :] = input_part_unflattened
-        full_inputs[:, self._observed_idx, :] = condition_part_unflattened
+        if self.is_features_dim_missing:
+            input_part_unflattened = input.reshape(B, self._num_latent)
+            # Two-argument repeat: the condition is 2D here, so tiling must
+            # stay 2D to broadcast into full_inputs[:, self._observed_idx].
+            condition_part_unflattened = condition.reshape(
+                -1, self._num_observed
+            ).repeat(B // C, 1)
+
+            full_inputs = torch.zeros(
+                B,
+                self._original_T,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            # Place unflattened parts into the correct positions
+            full_inputs[:, self._latent_idx] = input_part_unflattened
+            full_inputs[:, self._observed_idx] = condition_part_unflattened
+        else:
+            input_part_unflattened = input.reshape(
+                B, self._num_latent, self._original_F
+            )
+            condition_part_unflattened = condition.reshape(
+                -1, self._num_observed, self._original_F
+            ).repeat(B // C, 1, 1)
+
+            full_inputs = torch.zeros(
+                B,
+                self._original_T,
+                self._original_F,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            # Place unflattened parts into the correct positions
+            full_inputs[:, self._latent_idx, :] = input_part_unflattened
+            full_inputs[:, self._observed_idx, :] = condition_part_unflattened

        return full_inputs

    def _disassemble_full_outputs(self, full_outputs, original_latent_tensor):
-        latent_part = full_outputs[:, self._latent_idx, :]  # (B, num_latent, F)
+        if self.is_features_dim_missing:
+            latent_part = full_outputs[:, self._latent_idx]  # (B, num_latent)
+        else:
+            latent_part = full_outputs[:, self._latent_idx, :]  # (B, num_latent, F)

        return latent_part.reshape_as(original_latent_tensor)
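
For context, a minimal standalone sketch of the reshape-and-scatter logic these hunks guard: the 2D path (`input_shape == (T, F)`) keeps the trailing feature axis, while the 1D path (`input_shape == (T,)`) drops it everywhere, which is what the `is_features_dim_missing` branches implement. All names and shapes below (`mask`, `B`, `C`, `T`, `F`) are hypothetical, not taken from the codebase.

```python
import torch

# Hypothetical setup: 5 nodes, 3 latent (mask == 0) and 2 observed (mask == 1).
mask = torch.tensor([0, 1, 0, 0, 1])
latent_idx = (mask == 0).nonzero(as_tuple=True)[0]
observed_idx = (mask == 1).nonzero(as_tuple=True)[0]
B, C, T, F = 4, 1, 5, 2  # input batch, condition batch, nodes, features

# 2D path: flattened latent/observed vectors are unflattened back to
# (batch, num_nodes, F) and scattered into a full (B, T, F) tensor.
input_flat = torch.randn(B, len(latent_idx) * F)
condition_flat = torch.randn(C, len(observed_idx) * F)

full = torch.zeros(B, T, F)
full[:, latent_idx, :] = input_flat.reshape(B, -1, F)
full[:, observed_idx, :] = condition_flat.reshape(
    -1, len(observed_idx), F
).repeat(B // C, 1, 1)

# 1D path: no trailing feature axis, so the reshape, the repeat, and the
# fancy indexing all drop one dimension.
input_1d = torch.randn(B, len(latent_idx))
condition_1d = torch.randn(C, len(observed_idx))

full_1d = torch.zeros(B, T)
full_1d[:, latent_idx] = input_1d
full_1d[:, observed_idx] = condition_1d.repeat(B // C, 1)

print(full.shape, full_1d.shape)  # torch.Size([4, 5, 2]) torch.Size([4, 5])
```

Note that the 1D path tiles the condition with a two-argument `repeat`: a three-argument `repeat` on a 2D tensor would prepend a spurious dimension and fail to broadcast into the 2D `full_1d`, which is why the assembled-inputs hunk above uses `.repeat(B // C, 1)` in that branch.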