Skip to content

Commit 4c2fc4a

Browse files
authored
Merge pull request #12 from bioscan-ml/enh_index2label
ENH: Add index2label and label2index methods
2 parents a5a1456 + 55a66b5 commit 4c2fc4a

File tree

5 files changed

+154
-4
lines changed

5 files changed

+154
-4
lines changed

README.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@ If this is set to ``target_format="text"``, the output will instead be the raw l
200200
The default setting is ``target_format="index"``.
201201
Note that if multiple targets types are given, each label will be returned in the same format.
202202

203+
To map target indices back to text labels, the dataset class provides the ``index2label`` method.
204+
Similarly, the ``label2index`` method can be used to map text labels to indices.
205+
203206

204207
Data transforms
205208
~~~~~~~~~~~~~~~

bioscan_dataset/bioscan1m.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
import os
1212
from enum import Enum
13-
from typing import Any, Tuple
13+
from typing import Any, Iterable, Tuple, Union
1414

15+
import numpy as np
16+
import numpy.typing as npt
1517
import pandas
1618
import PIL
1719
import torch
@@ -339,11 +341,74 @@ def __init__(
339341

340342
self._load_metadata()
341343

344+
def index2label(self, column: str, index: Union[int, Iterable[int]]) -> Union[str, npt.NDArray[np.str_]]:
345+
r"""
346+
Convert target's integer index to text label.
347+
348+
.. versionadded:: 1.1.0
349+
350+
Parameters
351+
----------
352+
column : str
353+
The dataset column name to map. This is the same as the ``target_type``.
354+
index : int or Iterable[int]
355+
The integer index or indices to map to labels.
356+
357+
Returns
358+
-------
359+
str or numpy.array[str]
360+
The text label or labels corresponding to the integer index or indices
361+
in the specified column.
362+
Entries containing missing values, indicated by negative indices, are mapped
363+
to an empty string.
364+
"""
365+
if not hasattr(index, "__len__"):
366+
# Single index
367+
if index < 0:
368+
return ""
369+
return self.metadata[column].cat.categories[index]
370+
index = np.asarray(index)
371+
out = self.metadata[column].cat.categories[index]
372+
out = np.asarray(out)
373+
out[index < 0] = ""
374+
return out
375+
376+
def label2index(self, column: str, label: Union[str, Iterable[str]]) -> Union[int, npt.NDArray[np.int_]]:
377+
r"""
378+
Convert target's text label to integer index.
379+
380+
.. versionadded:: 1.1.0
381+
382+
Parameters
383+
----------
384+
column : str
385+
The dataset column name to map. This is the same as the ``target_type``.
386+
label : str or Iterable[str]
387+
The text label or labels to map to integer indices.
388+
389+
Returns
390+
-------
391+
int or numpy.array[int]
392+
The integer index or indices corresponding to the text label or labels
393+
in the specified column.
394+
Entries containing missing values, indicated by empty strings, are mapped
395+
to ``-1``.
396+
"""
397+
if isinstance(label, str):
398+
# Single index
399+
if label == "":
400+
return -1
401+
return self.metadata[column].cat.categories.get_loc(label)
402+
labels = label
403+
out = [-1 if lab == "" else self.metadata[column].cat.categories.get_loc(lab) for lab in labels]
404+
out = np.asarray(out)
405+
return out
406+
342407
def __len__(self):
343408
return len(self.metadata)
344409

345410
def __getitem__(self, index: int) -> Tuple[Any, ...]:
346-
"""
411+
r"""
347412
Get a sample from the dataset.
348413
349414
Parameters

bioscan_dataset/bioscan5m.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
import os
1212
from enum import Enum
13-
from typing import Any, Tuple
13+
from typing import Any, Iterable, Tuple, Union
1414

15+
import numpy as np
16+
import numpy.typing as npt
1517
import pandas
1618
import PIL
1719
import torch
@@ -406,11 +408,88 @@ def __init__(
406408

407409
self._load_metadata()
408410

411+
def index2label(self, column: str, index: Union[int, Iterable[int]]) -> Union[str, npt.NDArray[np.str_]]:
412+
r"""
413+
Convert target's integer index to text label.
414+
415+
.. versionadded:: 1.1.0
416+
417+
Parameters
418+
----------
419+
column : str
420+
The dataset column name to map. This is the same as the ``target_type``.
421+
index : int or Iterable[int]
422+
The integer index or indices to map to labels.
423+
424+
Returns
425+
-------
426+
str or numpy.array[str]
427+
The text label or labels corresponding to the integer index or indices
428+
in the specified column.
429+
Entries containing missing values, indicated by negative indices, are mapped
430+
to an empty string.
431+
432+
Examples
433+
--------
434+
>>> dataset.index2label("order", [4])
435+
'Diptera'
436+
>>> dataset.index2label("order", [4, 9, -1, 4])
437+
array(['Diptera', 'Lepidoptera', '', 'Diptera'], dtype=object)
438+
"""
439+
if not hasattr(index, "__len__"):
440+
# Single index
441+
if index < 0:
442+
return ""
443+
return self.metadata[column].cat.categories[index]
444+
index = np.asarray(index)
445+
out = self.metadata[column].cat.categories[index]
446+
out = np.asarray(out)
447+
out[index < 0] = ""
448+
return out
449+
450+
def label2index(self, column: str, label: Union[str, Iterable[str]]) -> Union[int, npt.NDArray[np.int_]]:
451+
r"""
452+
Convert target's text label to integer index.
453+
454+
.. versionadded:: 1.1.0
455+
456+
Parameters
457+
----------
458+
column : str
459+
The dataset column name to map. This is the same as the ``target_type``.
460+
label : str or Iterable[str]
461+
The text label or labels to map to integer indices.
462+
463+
Returns
464+
-------
465+
int or numpy.array[int]
466+
The integer index or indices corresponding to the text label or labels
467+
in the specified column.
468+
Entries containing missing values, indicated by empty strings, are mapped
469+
to ``-1``.
470+
471+
Examples
472+
--------
473+
>>> dataset.label2index("order", "Diptera")
474+
4
475+
>>> dataset.label2index("order", ["Diptera", "Lepidoptera", "", "Diptera"])
476+
array([4, 9, -1, 4])
477+
"""
478+
if isinstance(label, str):
479+
# Single index
480+
if label == "":
481+
return -1
482+
return self.metadata[column].cat.categories.get_loc(label)
483+
labels = label
484+
out = [-1 if lab == "" else self.metadata[column].cat.categories.get_loc(lab) for lab in labels]
485+
out = np.asarray(out)
486+
return out
487+
409488
def __len__(self):
410489
return len(self.metadata)
411490

412491
def __getitem__(self, index: int) -> Tuple[Any, ...]:
413-
"""
492+
r"""
414493
Get a sample from the dataset.
415494
416495
Parameters

docs/source/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def auto_convert_readme(_):
6868
# Change hard API URLs to dynamically generated class links
6969
readme_rst = readme_rst.replace("`BIOSCAN1M <BS1M-class_>`_", ":class:`~.bioscan_dataset.BIOSCAN1M`")
7070
readme_rst = readme_rst.replace("`BIOSCAN5M <BS5M-class_>`_", ":class:`~.bioscan_dataset.BIOSCAN5M`")
71+
readme_rst = readme_rst.replace("``index2label``", ":meth:`~.bioscan_dataset.BIOSCAN5M.index2label`")
72+
readme_rst = readme_rst.replace("``label2index``", ":meth:`~.bioscan_dataset.BIOSCAN5M.label2index`")
7173
print(f"Writing {readme_path_output}")
7274
with open(readme_path_output, "w") as f:
7375
f.write(readme_rst)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
numpy>=1.21.0
12
pandas>=1.0.0
23
Pillow>=4.1.1
34
torch>=1.4.0

0 commit comments

Comments
 (0)