Skip to content

Commit 81ef02f

Browse files
committed
fix typing error in mol.py
1 parent d5bcc81 commit 81ef02f

File tree

1 file changed

+8
-4
lines changed
  • kgcnn/data/transform/scaler

1 file changed

+8
-4
lines changed

kgcnn/data/transform/scaler/mol.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,11 @@ def __init__(self, scaler: list):
807807
raise ValueError("Unsupported scaler type '%s'." % x)
808808

809809
# noinspection PyPep8Naming
810-
def fit_transform(self, y=None, *, X=None, copy=True, sample_weight=None, atomic_number=None):
810+
def fit_transform(self, y: Union[np.ndarray, List[np.ndarray]] = None,
811+
X: Union[np.ndarray, List[np.ndarray], None] = None,
812+
atomic_number: Union[np.ndarray, List[np.ndarray], None] = None,
813+
copy: bool = True,
814+
sample_weight=None):
811815
r"""Fit and transform all target labels for QM.
812816
813817
Args:
@@ -826,7 +830,7 @@ def fit_transform(self, y=None, *, X=None, copy=True, sample_weight=None, atomic
826830
# noinspection PyPep8Naming
827831
def transform(self, y: Union[np.ndarray, List[np.ndarray]] = None,
828832
X: Union[np.ndarray, List[np.ndarray], None] = None,
829-
atomic_number: List[np.ndarray, None] = None,
833+
atomic_number: Union[np.ndarray, List[np.ndarray], None] = None,
830834
copy=True):
831835
r"""Transform all target labels for QM. Requires :obj:`fit()` called previously.
832836
@@ -855,7 +859,7 @@ def transform(self, y: Union[np.ndarray, List[np.ndarray]] = None,
855859
# noinspection PyPep8Naming
856860
def fit(self, y: Union[np.ndarray, List[np.ndarray]] = None,
857861
X: Union[np.ndarray, List[np.ndarray], None] = None,
858-
atomic_number: List[np.ndarray, None] = None,
862+
atomic_number: Union[np.ndarray, List[np.ndarray], None] = None,
859863
sample_weight=None):
860864
r"""Fit scaling of QM graph labels or targets.
861865
@@ -878,7 +882,7 @@ def fit(self, y: Union[np.ndarray, List[np.ndarray]] = None,
878882
# noinspection PyPep8Naming
879883
def inverse_transform(self, y: Union[np.ndarray, List[np.ndarray]] = None,
880884
X: Union[np.ndarray, List[np.ndarray], None] = None,
881-
atomic_number: List[np.ndarray, None] = None,
885+
atomic_number: Union[np.ndarray, List[np.ndarray], None] = None,
882886
copy: bool = True):
883887
r"""Back-transform all target labels for QM.
884888

0 commit comments

Comments
 (0)