Skip to content

Commit b291ea0

Browse files
authored
Refactor Modality object (#26)
1 parent 187fb48 commit b291ea0

File tree

33 files changed

+394
-309
lines changed

33 files changed

+394
-309
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ experimentation and research for new techniques.
99
## Quick Start
1010
### Installation
1111
#### Prerequisites
12-
The library requires Python 3.9 or later. We recommend using a virtual environment to manage dependencies. You can create
12+
The library requires Python 3.10 or later. We recommend using a virtual environment to manage dependencies. You can create
1313
a virtual environment using the following command:
1414
```bash
1515
python3 -m venv /path/to/new/virtual/environment

mmlearn/conf/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Hydra/Hydra-zen-based configurations."""
22

33
import functools
4+
import os
45
import warnings
56
from dataclasses import dataclass, field
67
from enum import Enum
@@ -29,7 +30,7 @@
2930

3031
def _get_default_ckpt_dir() -> Any:
3132
"""Get the default checkpoint directory."""
32-
return SI("/checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID}")
33+
return SI("${hydra:runtime.output_dir}/checkpoints")
3334

3435

3536
_DataLoaderConf = builds(

mmlearn/datasets/chexpert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __getitem__(self, idx: int) -> Example:
105105

106106
return Example(
107107
{
108-
Modalities.RGB: image,
108+
Modalities.RGB.name: image,
109109
Modalities.RGB.target: label,
110110
"qid": entry["qid"],
111111
EXAMPLE_INDEX_KEY: idx,

mmlearn/datasets/core/data_collator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from collections.abc import Mapping
44
from dataclasses import dataclass
5-
from typing import Any, Callable, Optional, Union
5+
from typing import Any, Callable, Optional
66

77
from torch.utils.data import default_collate
88

99
from mmlearn.datasets.core.example import Example
10-
from mmlearn.datasets.core.modalities import Modalities, Modality
10+
from mmlearn.datasets.core.modalities import Modalities
1111

1212

1313
@dataclass
@@ -43,9 +43,9 @@ def __call__(self, examples: list[Example]) -> dict[str, Any]:
4343

4444
if self.batch_processors is not None:
4545
for key, processor in self.batch_processors.items():
46-
batch_key: Union[str, Modality] = key
46+
batch_key: str = key
4747
if Modalities.has_modality(key):
48-
batch_key = Modalities.get_modality(key)
48+
batch_key = Modalities.get_modality(key).name
4949

5050
if batch_key in batch:
5151
batch_processed = processor(batch[batch_key])

mmlearn/datasets/core/modalities.py

Lines changed: 84 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
"""Module for managing supported modalities in the library."""
22

33
import re
4-
from typing import TYPE_CHECKING, Any, Optional
4+
import warnings
5+
from dataclasses import dataclass, field
6+
from typing import Any, ClassVar, Optional
57

68
from typing_extensions import Self
79

810

9-
_default_supported_modalities = ["rgb", "depth", "thermal", "text", "audio", "video"]
11+
_DEFAULT_SUPPORTED_MODALITIES = ["rgb", "depth", "thermal", "text", "audio", "video"]
1012

1113

12-
class Modality(str):
14+
@dataclass
15+
class Modality:
1316
"""Class to represent a modality in the library.
1417
1518
This class is used to represent a modality in the library. It contains the name of
@@ -24,61 +27,46 @@ class Modality(str):
2427
modality_specific_properties : Optional[dict[str, str]], optional, default=None
2528
Additional properties specific to the modality, by default None
2629
27-
Attributes
28-
----------
29-
value : str
30-
The name of the modality.
31-
properties : dict[str, str]
32-
The properties associated with the modality. By default, the properties are
33-
`target`, `mask`, `embedding`, `masked_embedding`, and `ema_embedding`.
34-
These default properties apply to all newly created modality types
35-
automatically. Modality-specific properties can be added using the
36-
`add_property` method or by passing them as a dictionary to the constructor.
30+
Raises
31+
------
32+
ValueError
33+
If the property already exists for the modality or if the format string is
34+
invalid.
3735
"""
3836

39-
_default_properties = {
40-
"target": "{}_target",
41-
"attention_mask": "{}_attention_mask",
42-
"mask": "{}_mask",
43-
"embedding": "{}_embedding",
44-
"masked_embedding": "{}_masked_embedding",
45-
"ema_embedding": "{}_ema_embedding",
46-
}
47-
48-
if TYPE_CHECKING:
49-
50-
def __getattr__(self, attr: str) -> Any:
51-
"""Get the value of the attribute."""
52-
...
53-
54-
def __setattr__(self, attr: str, value: Any) -> None:
55-
"""Set the value of the attribute."""
56-
...
57-
58-
def __new__(
59-
cls, name: str, modality_specific_properties: Optional[dict[str, str]] = None
60-
) -> Self:
37+
name: str
38+
target: str = field(init=False, repr=False)
39+
attention_mask: str = field(init=False, repr=False)
40+
mask: str = field(init=False, repr=False)
41+
embedding: str = field(init=False, repr=False)
42+
masked_embedding: str = field(init=False, repr=False)
43+
ema_embedding: str = field(init=False, repr=False)
44+
modality_specific_properties: Optional[dict[str, str]] = field(
45+
default=None, repr=False
46+
)
47+
48+
def __post_init__(self) -> None:
6149
"""Initialize the modality with the name and properties."""
62-
instance = super(Modality, cls).__new__(cls, name.lower())
63-
properties = cls._default_properties.copy()
64-
if modality_specific_properties is not None:
65-
properties.update(modality_specific_properties)
66-
instance._properties = properties
67-
68-
for property_name, format_string in instance._properties.items():
69-
instance._set_property_as_attr(property_name, format_string)
70-
71-
return instance
72-
73-
@property
74-
def value(self) -> str:
75-
"""Return the name of the modality."""
76-
return self.__str__()
50+
self.name = self.name.lower()
51+
self._properties = {}
52+
53+
for field_name in self.__dataclass_fields__:
54+
if field_name not in ("name", "modality_specific_properties"):
55+
field_value = f"{self.name}_{field_name}"
56+
self._properties[field_name] = field_value
57+
setattr(self, field_name, field_value)
58+
59+
if self.modality_specific_properties is not None:
60+
for (
61+
property_name,
62+
format_string,
63+
) in self.modality_specific_properties.items():
64+
self.add_property(property_name, format_string)
7765

7866
@property
7967
def properties(self) -> dict[str, str]:
8068
"""Return the properties associated with the modality."""
81-
return {name: getattr(self, name) for name in self._properties}
69+
return self._properties
8270

8371
def add_property(self, name: str, format_string: str) -> None:
8472
"""Add a new property to the modality.
@@ -92,49 +80,38 @@ def add_property(self, name: str, format_string: str) -> None:
9280
placeholder that will be replaced with the name of the modality when the
9381
property is accessed.
9482
83+
Warns
84+
-----
85+
UserWarning
86+
If the property already exists for the modality. It will overwrite the
87+
existing property.
88+
9589
Raises
9690
------
9791
ValueError
98-
If the property already exists for the modality or if the format string is
99-
invalid.
92+
If `format_string` is invalid. A valid format string contains at least one
93+
placeholder enclosed in curly braces.
10094
"""
10195
if name in self._properties:
102-
raise ValueError(
96+
warnings.warn(
10397
f"Property '{name}' already exists for modality '{super().__str__()}'."
98+
"Will overwrite the existing property.",
99+
category=UserWarning,
100+
stacklevel=2,
104101
)
105-
self._properties[name] = format_string
106-
self._set_property_as_attr(name, format_string)
107102

108-
def _set_property_as_attr(self, name: str, format_string: str) -> None:
109-
"""Set the property as an attribute of the modality."""
110103
if not _is_format_string(format_string):
111104
raise ValueError(
112105
f"Invalid format string '{format_string}' for property "
113106
f"'{name}' of modality '{super().__str__()}'."
114107
)
115-
setattr(self, name, format_string.format(self.value))
108+
109+
self._properties[name] = format_string.format(self.name)
110+
setattr(self, name, self._properties[name])
116111

117112
def __str__(self) -> str:
118113
"""Return the object as a string."""
119-
return self.lower()
120-
121-
def __repr__(self) -> str:
122-
"""Return the string representation of the modality."""
123-
return f"<Modality: {self.upper()}>"
124-
125-
def __hash__(self) -> int:
126-
"""Return the hash of the modality name and properties."""
127-
return hash((self.value, tuple(self._properties.items())))
128-
129-
def __eq__(self, other: object) -> bool:
130-
"""Check if two modality types are equal.
131-
132-
Two modality types are equal if they have the same name and properties.
133-
"""
134-
return isinstance(other, Modality) and (
135-
(self.__str__() == other.__str__())
136-
and (self._properties == other._properties)
137-
)
114+
return self.name.lower()
138115

139116

140117
class ModalityRegistry:
@@ -146,16 +123,15 @@ class ModalityRegistry:
146123
ensure that there is only one instance of the registry in the library.
147124
"""
148125

149-
_instance = None
126+
_instance: ClassVar[Any] = None
127+
_modality_registry: dict[str, Modality] = {}
150128

151129
def __new__(cls) -> Self:
152130
"""Create a new instance of the class if it does not exist."""
153131
if cls._instance is None:
154-
cls._instance = super(ModalityRegistry, cls).__new__(cls)
155-
cls._instance._modality_registry = {} # type: ignore[attr-defined]
156-
for modality in _default_supported_modalities:
157-
cls._instance.register_modality(modality)
158-
return cls._instance
132+
cls._instance = super().__new__(cls)
133+
cls._instance._modality_registry = {}
134+
return cls._instance # type: ignore[no-any-return]
159135

160136
def register_modality(
161137
self, name: str, modality_specific_properties: Optional[dict[str, str]] = None
@@ -169,13 +145,19 @@ def register_modality(
169145
modality_specific_properties : Optional[dict[str, str]], optional, default=None
170146
Additional properties specific to the modality.
171147
172-
Raises
173-
------
174-
ValueError
175-
If the modality already exists in the registry.
148+
Warns
149+
-----
150+
UserWarning
151+
If the modality already exists in the registry. It will overwrite the
152+
existing modality.
153+
176154
"""
177155
if name.lower() in self._modality_registry:
178-
raise ValueError(f"Modality '{name}' already exists in the registry.")
156+
warnings.warn(
157+
f"Modality '{name}' already exists in the registry. Overwriting...",
158+
category=UserWarning,
159+
stacklevel=2,
160+
)
179161

180162
name = name.lower()
181163
modality = Modality(name, modality_specific_properties)
@@ -194,18 +176,21 @@ def add_default_property(self, name: str, format_string: str) -> None:
194176
placeholder that will be replaced with the name of the modality when the
195177
property is accessed.
196178
179+
Warns
180+
-----
181+
UserWarning
182+
If the property already exists for the default properties. It will
183+
overwrite the existing property.
184+
197185
Raises
198186
------
199187
ValueError
200-
If the property already exists for the default properties or if the format
201-
string is invalid.
188+
If the format string is invalid. A valid format string contains at least one
189+
placeholder enclosed in curly braces.
202190
"""
203191
for modality in self._modality_registry.values():
204192
modality.add_property(name, format_string)
205193

206-
# add the property to the default properties for new modalities
207-
Modality._default_properties[name.lower()] = format_string
208-
209194
def has_modality(self, name: str) -> bool:
210195
"""Check if the modality exists in the registry.
211196
@@ -234,7 +219,7 @@ def get_modality(self, name: str) -> Modality:
234219
Modality
235220
The modality object from the registry.
236221
"""
237-
return self._modality_registry[name.lower()] # type: ignore[index,return-value]
222+
return self._modality_registry[name.lower()]
238223

239224
def get_modality_properties(self, name: str) -> dict[str, str]:
240225
"""Get the properties of a modality from the registry.
@@ -264,7 +249,7 @@ def list_modalities(self) -> list[Modality]:
264249
def __getattr__(self, name: str) -> Modality:
265250
"""Access a modality as an attribute by its name."""
266251
if name.lower() in self._modality_registry:
267-
return self._modality_registry[name.lower()] # type: ignore[index,return-value]
252+
return self._modality_registry[name.lower()]
268253
raise AttributeError(
269254
f"'{self.__class__.__name__}' object has no attribute '{name}'"
270255
)
@@ -292,3 +277,6 @@ def _is_format_string(string: str) -> bool:
292277

293278

294279
Modalities = ModalityRegistry()
280+
281+
for modality in _DEFAULT_SUPPORTED_MODALITIES:
282+
Modalities.register_modality(modality)

mmlearn/datasets/imagenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __getitem__(self, index: int) -> Example:
5858
image, target = super().__getitem__(index)
5959
example = Example(
6060
{
61-
Modalities.RGB: image,
61+
Modalities.RGB.name: image,
6262
Modalities.RGB.target: target,
6363
EXAMPLE_INDEX_KEY: index,
6464
}

mmlearn/datasets/librispeech.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ def __getitem__(self, idx: int) -> Example:
108108

109109
return Example(
110110
{
111-
Modalities.AUDIO: waveform,
112-
Modalities.TEXT: transcript,
111+
Modalities.AUDIO.name: waveform,
112+
Modalities.TEXT.name: transcript,
113113
EXAMPLE_INDEX_KEY: idx,
114114
},
115115
)

mmlearn/datasets/llvip.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ def __getitem__(self, idx: int) -> Example:
7070
rgb_image = PILImage.open(rgb_image_path).convert("RGB")
7171
ir_image = PILImage.open(ir_image_path).convert("L")
7272

73-
sample = Example(
73+
example = Example(
7474
{
75-
Modalities.RGB: self.transform(rgb_image),
76-
Modalities.THERMAL: self.transform(ir_image),
75+
Modalities.RGB.name: self.transform(rgb_image),
76+
Modalities.THERMAL.name: self.transform(ir_image),
7777
EXAMPLE_INDEX_KEY: idx,
7878
},
7979
)
@@ -85,11 +85,11 @@ def __getitem__(self, idx: int) -> Example:
8585
.replace("train", "")
8686
)
8787
annot = self._get_bbox(annot_path)
88-
sample["annotation"] = {
88+
example["annotation"] = {
8989
"bboxes": torch.from_numpy(annot["bboxes"]),
9090
"labels": torch.from_numpy(annot["labels"]),
9191
}
92-
return sample
92+
return example
9393

9494
def _get_bbox(self, filename: str) -> Dict[str, np.ndarray]:
9595
"""Parse the XML file to get bounding boxes and labels.

0 commit comments

Comments
 (0)