|
34 | 34 | def build_made(
|
35 | 35 | batch_x: Tensor,
|
36 | 36 | batch_y: Tensor,
|
37 |
| - z_score_x: Optional[str] = "independent", |
38 |
| - z_score_y: Optional[str] = "independent", |
| 37 | + z_score_x: Literal[ |
| 38 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 39 | + ], |
| 40 | + z_score_y: Literal[ |
| 41 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 42 | + ], |
39 | 43 | hidden_features: int = 50,
|
40 | 44 | num_mixture_components: int = 10,
|
41 | 45 | embedding_net: nn.Module = nn.Identity(),
|
@@ -105,8 +109,12 @@ def build_made(
|
105 | 109 | def build_maf(
|
106 | 110 | batch_x: Tensor,
|
107 | 111 | batch_y: Tensor,
|
108 |
| - z_score_x: Optional[str] = "independent", |
109 |
| - z_score_y: Optional[str] = "independent", |
| 112 | + z_score_x: Literal[ |
| 113 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 114 | + ], |
| 115 | + z_score_y: Literal[ |
| 116 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 117 | + ], |
110 | 118 | hidden_features: int = 50,
|
111 | 119 | num_transforms: int = 5,
|
112 | 120 | embedding_net: nn.Module = nn.Identity(),
|
@@ -193,8 +201,12 @@ def build_maf(
|
193 | 201 | def build_maf_rqs(
|
194 | 202 | batch_x: Tensor,
|
195 | 203 | batch_y: Tensor,
|
196 |
| - z_score_x: Optional[str] = "independent", |
197 |
| - z_score_y: Optional[str] = "independent", |
| 204 | + z_score_x: Literal[ |
| 205 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 206 | + ], |
| 207 | + z_score_y: Literal[ |
| 208 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 209 | + ], |
198 | 210 | hidden_features: int = 50,
|
199 | 211 | num_transforms: int = 5,
|
200 | 212 | embedding_net: nn.Module = nn.Identity(),
|
@@ -305,8 +317,12 @@ def build_maf_rqs(
|
305 | 317 | def build_nsf(
|
306 | 318 | batch_x: Tensor,
|
307 | 319 | batch_y: Tensor,
|
308 |
| - z_score_x: Optional[str] = "independent", |
309 |
| - z_score_y: Optional[str] = "independent", |
| 320 | + z_score_x: Literal[ |
| 321 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 322 | + ], |
| 323 | + z_score_y: Literal[ |
| 324 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 325 | + ], |
310 | 326 | hidden_features: int = 50,
|
311 | 327 | num_transforms: int = 5,
|
312 | 328 | num_bins: int = 10,
|
@@ -427,8 +443,12 @@ def mask_in_layer(i):
|
427 | 443 | def build_zuko_nice(
|
428 | 444 | batch_x: Tensor,
|
429 | 445 | batch_y: Tensor,
|
430 |
| - z_score_x: Optional[str] = "independent", |
431 |
| - z_score_y: Optional[str] = "independent", |
| 446 | + z_score_x: Literal[ |
| 447 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 448 | + ], |
| 449 | + z_score_y: Literal[ |
| 450 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 451 | + ], |
432 | 452 | hidden_features: Union[Sequence[int], int] = 50,
|
433 | 453 | num_transforms: int = 5,
|
434 | 454 | embedding_net: nn.Module = nn.Identity(),
|
@@ -482,8 +502,12 @@ def build_zuko_nice(
|
482 | 502 | def build_zuko_maf(
|
483 | 503 | batch_x: Tensor,
|
484 | 504 | batch_y: Tensor,
|
485 |
| - z_score_x: Optional[str] = "independent", |
486 |
| - z_score_y: Optional[str] = "independent", |
| 505 | + z_score_x: Literal[ |
| 506 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 507 | + ], |
| 508 | + z_score_y: Literal[ |
| 509 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 510 | + ], |
487 | 511 | hidden_features: Union[Sequence[int], int] = 50,
|
488 | 512 | num_transforms: int = 5,
|
489 | 513 | embedding_net: nn.Module = nn.Identity(),
|
@@ -534,8 +558,12 @@ def build_zuko_maf(
|
534 | 558 | def build_zuko_nsf(
|
535 | 559 | batch_x: Tensor,
|
536 | 560 | batch_y: Tensor,
|
537 |
| - z_score_x: Optional[str] = "independent", |
538 |
| - z_score_y: Optional[str] = "independent", |
| 561 | + z_score_x: Literal[ |
| 562 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 563 | + ], |
| 564 | + z_score_y: Literal[ |
| 565 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 566 | + ], |
539 | 567 | hidden_features: Union[Sequence[int], int] = 50,
|
540 | 568 | num_transforms: int = 5,
|
541 | 569 | embedding_net: nn.Module = nn.Identity(),
|
@@ -595,8 +623,12 @@ def build_zuko_nsf(
|
595 | 623 | def build_zuko_ncsf(
|
596 | 624 | batch_x: Tensor,
|
597 | 625 | batch_y: Tensor,
|
598 |
| - z_score_x: Optional[str] = "independent", |
599 |
| - z_score_y: Optional[str] = "independent", |
| 626 | + z_score_x: Literal[ |
| 627 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 628 | + ], |
| 629 | + z_score_y: Literal[ |
| 630 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 631 | + ], |
600 | 632 | hidden_features: Union[Sequence[int], int] = 50,
|
601 | 633 | num_transforms: int = 5,
|
602 | 634 | embedding_net: nn.Module = nn.Identity(),
|
@@ -651,8 +683,12 @@ def build_zuko_ncsf(
|
651 | 683 | def build_zuko_sospf(
|
652 | 684 | batch_x: Tensor,
|
653 | 685 | batch_y: Tensor,
|
654 |
| - z_score_x: Optional[str] = "independent", |
655 |
| - z_score_y: Optional[str] = "independent", |
| 686 | + z_score_x: Literal[ |
| 687 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 688 | + ], |
| 689 | + z_score_y: Literal[ |
| 690 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 691 | + ], |
656 | 692 | hidden_features: Union[Sequence[int], int] = 50,
|
657 | 693 | num_transforms: int = 5,
|
658 | 694 | embedding_net: nn.Module = nn.Identity(),
|
@@ -705,8 +741,12 @@ def build_zuko_sospf(
|
705 | 741 | def build_zuko_naf(
|
706 | 742 | batch_x: Tensor,
|
707 | 743 | batch_y: Tensor,
|
708 |
| - z_score_x: Optional[str] = "independent", |
709 |
| - z_score_y: Optional[str] = "independent", |
| 744 | + z_score_x: Literal[ |
| 745 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 746 | + ], |
| 747 | + z_score_y: Literal[ |
| 748 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 749 | + ], |
710 | 750 | hidden_features: Union[Sequence[int], int] = 50,
|
711 | 751 | num_transforms: int = 5,
|
712 | 752 | embedding_net: nn.Module = nn.Identity(),
|
@@ -771,8 +811,12 @@ def build_zuko_naf(
|
771 | 811 | def build_zuko_unaf(
|
772 | 812 | batch_x: Tensor,
|
773 | 813 | batch_y: Tensor,
|
774 |
| - z_score_x: Optional[str] = "independent", |
775 |
| - z_score_y: Optional[str] = "independent", |
| 814 | + z_score_x: Literal[ |
| 815 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 816 | + ], |
| 817 | + z_score_y: Literal[ |
| 818 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 819 | + ], |
776 | 820 | hidden_features: Union[Sequence[int], int] = 50,
|
777 | 821 | num_transforms: int = 5,
|
778 | 822 | embedding_net: nn.Module = nn.Identity(),
|
@@ -837,8 +881,12 @@ def build_zuko_unaf(
|
837 | 881 | def build_zuko_cnf(
|
838 | 882 | batch_x: Tensor,
|
839 | 883 | batch_y: Tensor,
|
840 |
| - z_score_x: Optional[str] = "independent", |
841 |
| - z_score_y: Optional[str] = "independent", |
| 884 | + z_score_x: Literal[ |
| 885 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 886 | + ], |
| 887 | + z_score_y: Literal[ |
| 888 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 889 | + ], |
842 | 890 | hidden_features: Union[Sequence[int], int] = 50,
|
843 | 891 | num_transforms: int = 5,
|
844 | 892 | embedding_net: nn.Module = nn.Identity(),
|
@@ -891,8 +939,12 @@ def build_zuko_cnf(
|
891 | 939 | def build_zuko_gf(
|
892 | 940 | batch_x: Tensor,
|
893 | 941 | batch_y: Tensor,
|
894 |
| - z_score_x: Optional[str] = "independent", |
895 |
| - z_score_y: Optional[str] = "independent", |
| 942 | + z_score_x: Literal[ |
| 943 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 944 | + ], |
| 945 | + z_score_y: Literal[ |
| 946 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 947 | + ], |
896 | 948 | hidden_features: Union[Sequence[int], int] = 50,
|
897 | 949 | num_transforms: int = 3,
|
898 | 950 | embedding_net: nn.Module = nn.Identity(),
|
@@ -948,8 +1000,12 @@ def build_zuko_gf(
|
948 | 1000 | def build_zuko_bpf(
|
949 | 1001 | batch_x: Tensor,
|
950 | 1002 | batch_y: Tensor,
|
951 |
| - z_score_x: Optional[str] = "independent", |
952 |
| - z_score_y: Optional[str] = "independent", |
| 1003 | + z_score_x: Literal[ |
| 1004 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 1005 | + ], |
| 1006 | + z_score_y: Literal[ |
| 1007 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 1008 | + ], |
953 | 1009 | hidden_features: Union[Sequence[int], int] = 50,
|
954 | 1010 | num_transforms: int = 3,
|
955 | 1011 | embedding_net: nn.Module = nn.Identity(),
|
@@ -1007,10 +1063,12 @@ def build_zuko_flow(
|
1007 | 1063 | which_nf: str,
|
1008 | 1064 | batch_x: Tensor,
|
1009 | 1065 | batch_y: Tensor,
|
1010 |
| - z_score_x: Literal["none", "independent", |
1011 |
| - "structured", "transform_to_unconstrained"], |
1012 |
| - z_score_y: Literal["none", "independent", |
1013 |
| - "structured", "transform_to_unconstrained"], |
| 1066 | + z_score_x: Literal[ |
| 1067 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 1068 | + ], |
| 1069 | + z_score_y: Literal[ |
| 1070 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 1071 | + ], |
1014 | 1072 | hidden_features: Union[Sequence[int], int] = 50,
|
1015 | 1073 | num_transforms: int = 5,
|
1016 | 1074 | embedding_net: nn.Module = nn.Identity(),
|
@@ -1163,7 +1221,9 @@ def get_transform_to_unconstrained(
|
1163 | 1221 | def build_zuko_unconditional_flow(
|
1164 | 1222 | which_nf: str,
|
1165 | 1223 | batch_x: Tensor,
|
1166 |
| - z_score_x: Optional[str] = "independent", |
| 1224 | + z_score_x: Literal[ |
| 1225 | + "none", "independent", "structured", "transform_to_unconstrained" |
| 1226 | + ], |
1167 | 1227 | hidden_features: Union[Sequence[int], int] = 50,
|
1168 | 1228 | num_transforms: int = 5,
|
1169 | 1229 | **kwargs,
|
|
0 commit comments