Skip to content

Commit 580fd50

Browse files
Nastya KrouglovaNastya Krouglova
authored andcommitted
adjust literals and format
1 parent 33129d8 commit 580fd50

File tree

1 file changed

+93
-33
lines changed
  • sbi/neural_nets/net_builders

1 file changed

+93
-33
lines changed

sbi/neural_nets/net_builders/flow.py

Lines changed: 93 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@
3434
def build_made(
3535
batch_x: Tensor,
3636
batch_y: Tensor,
37-
z_score_x: Optional[str] = "independent",
38-
z_score_y: Optional[str] = "independent",
37+
z_score_x: Literal[
38+
"none", "independent", "structured", "transform_to_unconstrained"
39+
],
40+
z_score_y: Literal[
41+
"none", "independent", "structured", "transform_to_unconstrained"
42+
],
3943
hidden_features: int = 50,
4044
num_mixture_components: int = 10,
4145
embedding_net: nn.Module = nn.Identity(),
@@ -105,8 +109,12 @@ def build_made(
105109
def build_maf(
106110
batch_x: Tensor,
107111
batch_y: Tensor,
108-
z_score_x: Optional[str] = "independent",
109-
z_score_y: Optional[str] = "independent",
112+
z_score_x: Literal[
113+
"none", "independent", "structured", "transform_to_unconstrained"
114+
],
115+
z_score_y: Literal[
116+
"none", "independent", "structured", "transform_to_unconstrained"
117+
],
110118
hidden_features: int = 50,
111119
num_transforms: int = 5,
112120
embedding_net: nn.Module = nn.Identity(),
@@ -193,8 +201,12 @@ def build_maf(
193201
def build_maf_rqs(
194202
batch_x: Tensor,
195203
batch_y: Tensor,
196-
z_score_x: Optional[str] = "independent",
197-
z_score_y: Optional[str] = "independent",
204+
z_score_x: Literal[
205+
"none", "independent", "structured", "transform_to_unconstrained"
206+
],
207+
z_score_y: Literal[
208+
"none", "independent", "structured", "transform_to_unconstrained"
209+
],
198210
hidden_features: int = 50,
199211
num_transforms: int = 5,
200212
embedding_net: nn.Module = nn.Identity(),
@@ -305,8 +317,12 @@ def build_maf_rqs(
305317
def build_nsf(
306318
batch_x: Tensor,
307319
batch_y: Tensor,
308-
z_score_x: Optional[str] = "independent",
309-
z_score_y: Optional[str] = "independent",
320+
z_score_x: Literal[
321+
"none", "independent", "structured", "transform_to_unconstrained"
322+
],
323+
z_score_y: Literal[
324+
"none", "independent", "structured", "transform_to_unconstrained"
325+
],
310326
hidden_features: int = 50,
311327
num_transforms: int = 5,
312328
num_bins: int = 10,
@@ -427,8 +443,12 @@ def mask_in_layer(i):
427443
def build_zuko_nice(
428444
batch_x: Tensor,
429445
batch_y: Tensor,
430-
z_score_x: Optional[str] = "independent",
431-
z_score_y: Optional[str] = "independent",
446+
z_score_x: Literal[
447+
"none", "independent", "structured", "transform_to_unconstrained"
448+
],
449+
z_score_y: Literal[
450+
"none", "independent", "structured", "transform_to_unconstrained"
451+
],
432452
hidden_features: Union[Sequence[int], int] = 50,
433453
num_transforms: int = 5,
434454
embedding_net: nn.Module = nn.Identity(),
@@ -482,8 +502,12 @@ def build_zuko_nice(
482502
def build_zuko_maf(
483503
batch_x: Tensor,
484504
batch_y: Tensor,
485-
z_score_x: Optional[str] = "independent",
486-
z_score_y: Optional[str] = "independent",
505+
z_score_x: Literal[
506+
"none", "independent", "structured", "transform_to_unconstrained"
507+
],
508+
z_score_y: Literal[
509+
"none", "independent", "structured", "transform_to_unconstrained"
510+
],
487511
hidden_features: Union[Sequence[int], int] = 50,
488512
num_transforms: int = 5,
489513
embedding_net: nn.Module = nn.Identity(),
@@ -534,8 +558,12 @@ def build_zuko_maf(
534558
def build_zuko_nsf(
535559
batch_x: Tensor,
536560
batch_y: Tensor,
537-
z_score_x: Optional[str] = "independent",
538-
z_score_y: Optional[str] = "independent",
561+
z_score_x: Literal[
562+
"none", "independent", "structured", "transform_to_unconstrained"
563+
],
564+
z_score_y: Literal[
565+
"none", "independent", "structured", "transform_to_unconstrained"
566+
],
539567
hidden_features: Union[Sequence[int], int] = 50,
540568
num_transforms: int = 5,
541569
embedding_net: nn.Module = nn.Identity(),
@@ -595,8 +623,12 @@ def build_zuko_nsf(
595623
def build_zuko_ncsf(
596624
batch_x: Tensor,
597625
batch_y: Tensor,
598-
z_score_x: Optional[str] = "independent",
599-
z_score_y: Optional[str] = "independent",
626+
z_score_x: Literal[
627+
"none", "independent", "structured", "transform_to_unconstrained"
628+
],
629+
z_score_y: Literal[
630+
"none", "independent", "structured", "transform_to_unconstrained"
631+
],
600632
hidden_features: Union[Sequence[int], int] = 50,
601633
num_transforms: int = 5,
602634
embedding_net: nn.Module = nn.Identity(),
@@ -651,8 +683,12 @@ def build_zuko_ncsf(
651683
def build_zuko_sospf(
652684
batch_x: Tensor,
653685
batch_y: Tensor,
654-
z_score_x: Optional[str] = "independent",
655-
z_score_y: Optional[str] = "independent",
686+
z_score_x: Literal[
687+
"none", "independent", "structured", "transform_to_unconstrained"
688+
],
689+
z_score_y: Literal[
690+
"none", "independent", "structured", "transform_to_unconstrained"
691+
],
656692
hidden_features: Union[Sequence[int], int] = 50,
657693
num_transforms: int = 5,
658694
embedding_net: nn.Module = nn.Identity(),
@@ -705,8 +741,12 @@ def build_zuko_sospf(
705741
def build_zuko_naf(
706742
batch_x: Tensor,
707743
batch_y: Tensor,
708-
z_score_x: Optional[str] = "independent",
709-
z_score_y: Optional[str] = "independent",
744+
z_score_x: Literal[
745+
"none", "independent", "structured", "transform_to_unconstrained"
746+
],
747+
z_score_y: Literal[
748+
"none", "independent", "structured", "transform_to_unconstrained"
749+
],
710750
hidden_features: Union[Sequence[int], int] = 50,
711751
num_transforms: int = 5,
712752
embedding_net: nn.Module = nn.Identity(),
@@ -771,8 +811,12 @@ def build_zuko_naf(
771811
def build_zuko_unaf(
772812
batch_x: Tensor,
773813
batch_y: Tensor,
774-
z_score_x: Optional[str] = "independent",
775-
z_score_y: Optional[str] = "independent",
814+
z_score_x: Literal[
815+
"none", "independent", "structured", "transform_to_unconstrained"
816+
],
817+
z_score_y: Literal[
818+
"none", "independent", "structured", "transform_to_unconstrained"
819+
],
776820
hidden_features: Union[Sequence[int], int] = 50,
777821
num_transforms: int = 5,
778822
embedding_net: nn.Module = nn.Identity(),
@@ -837,8 +881,12 @@ def build_zuko_unaf(
837881
def build_zuko_cnf(
838882
batch_x: Tensor,
839883
batch_y: Tensor,
840-
z_score_x: Optional[str] = "independent",
841-
z_score_y: Optional[str] = "independent",
884+
z_score_x: Literal[
885+
"none", "independent", "structured", "transform_to_unconstrained"
886+
],
887+
z_score_y: Literal[
888+
"none", "independent", "structured", "transform_to_unconstrained"
889+
],
842890
hidden_features: Union[Sequence[int], int] = 50,
843891
num_transforms: int = 5,
844892
embedding_net: nn.Module = nn.Identity(),
@@ -891,8 +939,12 @@ def build_zuko_cnf(
891939
def build_zuko_gf(
892940
batch_x: Tensor,
893941
batch_y: Tensor,
894-
z_score_x: Optional[str] = "independent",
895-
z_score_y: Optional[str] = "independent",
942+
z_score_x: Literal[
943+
"none", "independent", "structured", "transform_to_unconstrained"
944+
],
945+
z_score_y: Literal[
946+
"none", "independent", "structured", "transform_to_unconstrained"
947+
],
896948
hidden_features: Union[Sequence[int], int] = 50,
897949
num_transforms: int = 3,
898950
embedding_net: nn.Module = nn.Identity(),
@@ -948,8 +1000,12 @@ def build_zuko_gf(
9481000
def build_zuko_bpf(
9491001
batch_x: Tensor,
9501002
batch_y: Tensor,
951-
z_score_x: Optional[str] = "independent",
952-
z_score_y: Optional[str] = "independent",
1003+
z_score_x: Literal[
1004+
"none", "independent", "structured", "transform_to_unconstrained"
1005+
],
1006+
z_score_y: Literal[
1007+
"none", "independent", "structured", "transform_to_unconstrained"
1008+
],
9531009
hidden_features: Union[Sequence[int], int] = 50,
9541010
num_transforms: int = 3,
9551011
embedding_net: nn.Module = nn.Identity(),
@@ -1007,10 +1063,12 @@ def build_zuko_flow(
10071063
which_nf: str,
10081064
batch_x: Tensor,
10091065
batch_y: Tensor,
1010-
z_score_x: Literal["none", "independent",
1011-
"structured", "transform_to_unconstrained"],
1012-
z_score_y: Literal["none", "independent",
1013-
"structured", "transform_to_unconstrained"],
1066+
z_score_x: Literal[
1067+
"none", "independent", "structured", "transform_to_unconstrained"
1068+
],
1069+
z_score_y: Literal[
1070+
"none", "independent", "structured", "transform_to_unconstrained"
1071+
],
10141072
hidden_features: Union[Sequence[int], int] = 50,
10151073
num_transforms: int = 5,
10161074
embedding_net: nn.Module = nn.Identity(),
@@ -1163,7 +1221,9 @@ def get_transform_to_unconstrained(
11631221
def build_zuko_unconditional_flow(
11641222
which_nf: str,
11651223
batch_x: Tensor,
1166-
z_score_x: Optional[str] = "independent",
1224+
z_score_x: Literal[
1225+
"none", "independent", "structured", "transform_to_unconstrained"
1226+
],
11671227
hidden_features: Union[Sequence[int], int] = 50,
11681228
num_transforms: int = 5,
11691229
**kwargs,

0 commit comments

Comments
 (0)