66from rdkit import DataStructs
77
88from datasail .reader .utils import DataSet
9- from datasail .settings import LOGGER
9+ from datasail .settings import DIST_OPTIONS , LOGGER , SIM_OPTIONS
1010
11- SIM_OPTIONS = Literal [
12- "allbit" , "asymmetric" , "braunblanquet" , "cosine" , "dice" , "kulczynski" , "onbit" , "rogotgoldberg" ,
13- "russel" , "sokal" , "tanimoto"
14- ]
1511
16- # unbounded: chebyshev, cityblock, euclidean, mahalanobis, manhattan, mcconnaughey, minkowski, sqeuclidean
17- # produces inf or nan: correlation, cosine, jensenshannon, seuclidean, braycurtis
18- # boolean only: dice, kulczynski1, russelrao, sokalsneath
19- # matching == hamming, manhattan == cityblock (inofficial)
20- DIST_OPTIONS = Literal [
21- "canberra" , "hamming" , "jaccard" , "matching" , "rogerstanimoto" , "sokalmichener" , "yule"
22- ]
23-
24-
25- def get_rdkit_fct (method : SIM_OPTIONS ) -> Callable [[Any , Any ], np .ndarray ]:
12+ def get_rdkit_fct (method : str ) -> Callable [[Any , Any ], np .ndarray ]:
2613 """
2714 Get the RDKit function for the given similarity measure.
2815
@@ -57,7 +44,7 @@ def get_rdkit_fct(method: SIM_OPTIONS) -> Callable[[Any, Any], np.ndarray]:
5744 raise ValueError (f"Unknown method { method } " )
5845
5946
60- def rdkit_sim (fps , method : SIM_OPTIONS ) -> np .ndarray :
47+ def rdkit_sim (fps , method : str ) -> np .ndarray :
6148 """
6249 Compute the similarity between elements of a list of rdkit vectors.
6350
@@ -108,7 +95,7 @@ def iterable2bitvect(it) -> DataStructs.ExplicitBitVect:
10895 return output
10996
11097
111- def run_vector (dataset : DataSet , method : SIM_OPTIONS = "tanimoto" ) -> None :
98+ def run_vector (dataset : DataSet , method : str = "tanimoto" ) -> None :
11299 """
113100 Compute pairwise Tanimoto-Scores of the given dataset.
114101
@@ -120,7 +107,7 @@ def run_vector(dataset: DataSet, method: SIM_OPTIONS = "tanimoto") -> None:
120107 method = method .lower ()
121108
122109 embed = dataset .data [dataset .names [0 ]]
123- if method in get_args ( SIM_OPTIONS ) :
110+ if method in SIM_OPTIONS :
124111 if isinstance (embed , (list , tuple , np .ndarray )):
125112 if isinstance (embed [0 ], int ) or np .issubdtype (embed [0 ].dtype , int ):
126113 if method in ["allbit" , "asymmetric" , "braunblanquet" , "cosine" , "kulczynski" , "onbit" ,
@@ -137,7 +124,7 @@ def run_vector(dataset: DataSet, method: SIM_OPTIONS = "tanimoto") -> None:
137124 raise ValueError (
138125 f"Unsupported embedding type { type (embed )} . Please use either RDKit datastructures, lists, "
139126 f"tuples or one-dimensional numpy arrays." )
140- elif method in get_args ( DIST_OPTIONS ) :
127+ elif method in DIST_OPTIONS :
141128 dtype = np .bool_ if ["jaccard" , "rogerstanimoto" , "sokalmichener" , "yule" ] else np .float64
142129 if isinstance (embed , (
143130 list , tuple , DataStructs .ExplicitBitVect , DataStructs .LongSparseIntVect , DataStructs .IntSparseIntVect )):
@@ -159,7 +146,7 @@ def run(
159146 dataset : DataSet ,
160147 fps : Union [np .ndarray , DataStructs .ExplicitBitVect , DataStructs .LongSparseIntVect ,
161148 DataStructs .IntSparseIntVect ],
162- method : Union [ SIM_OPTIONS , DIST_OPTIONS ] ,
149+ method : str ,
163150) -> None :
164151 """
165152 Compute pairwise similarities of the given fingerprints.
@@ -169,11 +156,11 @@ def run(
169156 fps: The fingerprints to compute pairwise similarities for.
170157 method: The similarity measure to use.
171158 """
172- if method in get_args ( SIM_OPTIONS ) :
159+ if method in SIM_OPTIONS :
173160 dataset .cluster_similarity = rdkit_sim (fps , method )
174161 if method == "mcconnaughey" :
175162 dataset .cluster_similarity = dataset .cluster_similarity + 1 / 2
176- elif method in get_args ( DIST_OPTIONS ) :
163+ elif method in DIST_OPTIONS :
177164 if method == "mahalanobis" and len (fps ) <= len (fps [0 ]):
178165 raise ValueError (
179166 f"For clustering with the Mahalanobis method, you have to have more observations that dimensions in "
@@ -185,6 +172,8 @@ def run(
185172 dataset .cluster_distance = dataset .cluster_distance / len (fps [0 ])
186173 elif method == "yule" :
187174 dataset .cluster_distance /= 2
175+ else :
176+ raise ValueError (f"Unknown method to compare fingerprints. Found: { method } " )
188177
189178
190179if __name__ == '__main__' :
0 commit comments