Skip to content

Commit bcce5f7

Browse files
committed
Update AALogo().get_conservation
1 parent 3364bf0 commit bcce5f7

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

aaanalysis/_utils/check_data.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -234,16 +234,19 @@ def check_superset_subset(subset=None, superset=None, name_subset=None, name_sup
234234

235235
# df checking functions
236236
def check_df(name="df", df=None, accept_none=False, accept_nan=True, check_all_positive=False,
237-
cols_requiered=None, cols_forbidden=None, cols_nan_check=None, str_add=None):
237+
check_series=False, cols_requiered=None, cols_forbidden=None, cols_nan_check=None,
238+
str_add=None):
238239
"""Check if the provided DataFrame meets various criteria such as NaN values, required/forbidden columns, etc."""
239240
# Check DataFrame and values
240241
if df is None:
241242
if not accept_none:
242243
raise ValueError(f"'{name}' should not be None")
243244
else:
244245
return None
245-
if not isinstance(df, pd.DataFrame):
246-
str_error = add_str(str_error=f"'{name}' ({type(df)}) should be DataFrame",
246+
_check_dtype = pd.DataFrame if not check_series else pd.Series
247+
_str_check_dtype = "DataFrame" if not check_series else "Series"
248+
if not isinstance(df, _check_dtype):
249+
str_error = add_str(str_error=f"'{name}' ({type(df)}) should be {_str_check_dtype}",
247250
str_add=str_add)
248251
raise ValueError(str_error)
249252
if not accept_nan and df.isna().any().any():
@@ -278,11 +281,12 @@ def check_df(name="df", df=None, accept_none=False, accept_nan=True, check_all_p
278281
str_error = add_str(str_error=f"NaN values are not allowed in '{cols_nan_check}'.",
279282
str_add=str_add)
280283
raise ValueError(str_error)
281-
columns = list(df)
282-
cols_duplicated = [x for x in columns if columns.count(x) > 1]
283-
if len(cols_duplicated) > 0:
284-
str_error = add_str(str_error=f"The following columns are duplicated '{cols_duplicated}'.", str_add=str_add)
285-
raise ValueError(str_error)
284+
if _check_dtype == pd.DataFrame:
285+
columns = list(df)
286+
cols_duplicated = [x for x in columns if columns.count(x) > 1]
287+
if len(cols_duplicated) > 0:
288+
str_error = add_str(str_error=f"The following columns are duplicated '{cols_duplicated}'.", str_add=str_add)
289+
raise ValueError(str_error)
286290

287291

288292
def check_warning_consecutive_index(name="df", df=None):

aaanalysis/seq_analysis/_aalogo.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@
99

1010

1111
# I Helper function
12+
def check_df_logo_info(df_logo_info=None):
13+
"""Check if df_logo_info has correct format"""
14+
ut.check_df(name="df_logo_info", df=df_logo_info,
15+
check_series=True, accept_none=False, accept_nan=False)
16+
# Additional check specific to logo info: index name
17+
if df_logo_info.index.name != "pos":
18+
raise ValueError("Index name must be 'pos'")
19+
20+
1221
def _adjust_tmd(df_parts=None, tmd_len=None, start_n=False, ):
1322
"""Adjust TMD to have similar length for df logo"""
1423
if ut.COL_TMD in list(df_parts):
@@ -92,22 +101,19 @@ def get_df_logo_info(self,
92101
df_logo_info = df_logo.sum(axis=1) # vals_sum_per_pos
93102
return df_logo_info
94103

95-
def get_conservation(self,
96-
df_logo=None,
104+
@staticmethod
105+
def get_conservation(df_logo_info=None,
97106
value_type: Literal["min", "mean", "median", "max"] = "mean"):
98107
"""Compute conservation scores from sequence logos, ranging from 0 (no conservation) to
99108
4.248 (completely conserved)."""
100-
101-
if self._logo_type != "information":
102-
raise ValueError("Conservation can only be computed for 'logo_type'='information'")
103-
vals_sum_per_pos = df_logo.sum(axis=1)
109+
check_df_logo_info(df_logo_info=df_logo_info)
104110
# Compute the requested statistic over the per-position values
105111
if value_type == "min":
106-
conservation = vals_sum_per_pos.min()
112+
cons_val = df_logo_info.min()
107113
elif value_type == "mean":
108-
conservation = vals_sum_per_pos.mean()
114+
cons_val = df_logo_info.mean()
109115
elif value_type == "median":
110-
conservation = vals_sum_per_pos.median()
116+
cons_val = df_logo_info.median()
111117
else:
112-
conservation = vals_sum_per_pos.max()
113-
return conservation
118+
cons_val = df_logo_info.max()
119+
return cons_val

0 commit comments

Comments
 (0)