From de77e2706180f565e414772fd3bcbe6113212c56 Mon Sep 17 00:00:00 2001 From: Scott Lowe Date: Sat, 19 Apr 2025 13:32:33 +0100 Subject: [PATCH] BUG: Support mapping NaN -> -1 in label2index We need to support this because the value returned by the dataset for missing values in all categorical columns is NaN, even for columns that contain taxonomic strings. --- bioscan_dataset/bioscan1m.py | 15 +++++++++------ bioscan_dataset/bioscan5m.py | 15 +++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/bioscan_dataset/bioscan1m.py b/bioscan_dataset/bioscan1m.py index ac3311e..0d50dcc 100644 --- a/bioscan_dataset/bioscan1m.py +++ b/bioscan_dataset/bioscan1m.py @@ -891,8 +891,8 @@ def label2index( int or numpy.array[int] The integer index or indices corresponding to the text label or labels in the specified column. - Entries containing missing values, indicated by empty strings, are mapped - to ``-1``. + Entries containing missing values, indicated by empty strings or NaN values, + are mapped to ``-1``. """ if column is not None: pass @@ -900,10 +900,10 @@ def label2index( column = self.target_type[0] else: raise ValueError("column must be specified if there isn't a single target_type") - if isinstance(label, str): + if pandas.isna(label) or label == "": # Single index - if label == "": - return -1 + return -1 + if isinstance(label, str): try: return self.metadata[column].cat.categories.get_loc(label) except KeyError: @@ -915,7 +915,10 @@ def label2index( ) labels = label try: - out = [-1 if lab == "" else self.metadata[column].cat.categories.get_loc(lab) for lab in labels] + out = [ + -1 if lab == "" or pandas.isna(lab) else self.metadata[column].cat.categories.get_loc(lab) + for lab in labels + ] except KeyError: raise KeyError(f"Label {repr(label)} not found in metadata column {repr(column)}") from None out = np.asarray(out) diff --git a/bioscan_dataset/bioscan5m.py b/bioscan_dataset/bioscan5m.py index 556a676..0577857 100644 --- a/bioscan_dataset/bioscan5m.py +++ b/bioscan_dataset/bioscan5m.py @@ -612,8 +612,8 @@ def label2index( int or numpy.array[int] The integer index or indices corresponding to the text label or labels in the specified column. - Entries containing missing values, indicated by empty strings, are mapped - to ``-1``. + Entries containing missing values, indicated by empty strings or NaN values, + are mapped to ``-1``. Examples -------- @@ -628,10 +628,10 @@ def label2index( column = self.target_type[0] else: raise ValueError("column must be specified if there isn't a single target_type") - if isinstance(label, str): + if pandas.isna(label) or label == "": # Single index - if label == "": - return -1 + return -1 + if isinstance(label, str): try: return self.metadata[column].cat.categories.get_loc(label) except KeyError: @@ -643,7 +643,10 @@ def label2index( ) labels = label try: - out = [-1 if lab == "" else self.metadata[column].cat.categories.get_loc(lab) for lab in labels] + out = [ + -1 if lab == "" or pandas.isna(lab) else self.metadata[column].cat.categories.get_loc(lab) + for lab in labels + ] except KeyError: raise KeyError(f"Label {repr(label)} not found in metadata column {repr(column)}") from None out = np.asarray(out)