@@ -807,7 +807,11 @@ def __init__(self, scaler: list):
807
807
raise ValueError ("Unsupported scaler type '%s'." % x )
808
808
809
809
# noinspection PyPep8Naming
810
- def fit_transform (self , y = None , * , X = None , copy = True , sample_weight = None , atomic_number = None ):
810
+ def fit_transform (self , y : Union [np .ndarray , List [np .ndarray ]] = None ,
811
+ X : Union [np .ndarray , List [np .ndarray ], None ] = None ,
812
+ atomic_number : Union [np .ndarray , List [np .ndarray ], None ] = None ,
813
+ copy : bool = True ,
814
+ sample_weight = None ):
811
815
r"""Fit and transform all target labels for QM.
812
816
813
817
Args:
@@ -826,7 +830,7 @@ def fit_transform(self, y=None, *, X=None, copy=True, sample_weight=None, atomic
826
830
# noinspection PyPep8Naming
827
831
def transform (self , y : Union [np .ndarray , List [np .ndarray ]] = None ,
828
832
X : Union [np .ndarray , List [np .ndarray ], None ] = None ,
829
- atomic_number : List [np .ndarray , None ] = None ,
833
+ atomic_number : Union [ np . ndarray , List [np .ndarray ] , None ] = None ,
830
834
copy = True ):
831
835
r"""Transform all target labels for QM. Requires :obj:`fit()` called previously.
832
836
@@ -855,7 +859,7 @@ def transform(self, y: Union[np.ndarray, List[np.ndarray]] = None,
855
859
# noinspection PyPep8Naming
856
860
def fit (self , y : Union [np .ndarray , List [np .ndarray ]] = None ,
857
861
X : Union [np .ndarray , List [np .ndarray ], None ] = None ,
858
- atomic_number : List [np .ndarray , None ] = None ,
862
+ atomic_number : Union [ np . ndarray , List [np .ndarray ] , None ] = None ,
859
863
sample_weight = None ):
860
864
r"""Fit scaling of QM graph labels or targets.
861
865
@@ -878,7 +882,7 @@ def fit(self, y: Union[np.ndarray, List[np.ndarray]] = None,
878
882
# noinspection PyPep8Naming
879
883
def inverse_transform (self , y : Union [np .ndarray , List [np .ndarray ]] = None ,
880
884
X : Union [np .ndarray , List [np .ndarray ], None ] = None ,
881
- atomic_number : List [np .ndarray , None ] = None ,
885
+ atomic_number : Union [ np . ndarray , List [np .ndarray ] , None ] = None ,
882
886
copy : bool = True ):
883
887
r"""Back-transform all target labels for QM.
884
888
0 commit comments