11"""Module for managing supported modalities in the library.""" 
22
33import  re 
4- from  typing  import  TYPE_CHECKING , Any , Optional 
4+ import  warnings 
5+ from  dataclasses  import  dataclass , field 
6+ from  typing  import  Any , ClassVar , Optional 
57
68from  typing_extensions  import  Self 
79
810
9- _default_supported_modalities  =  ["rgb" , "depth" , "thermal" , "text" , "audio" , "video" ]
11+ _DEFAULT_SUPPORTED_MODALITIES  =  ["rgb" , "depth" , "thermal" , "text" , "audio" , "video" ]
1012
1113
12- class  Modality (str ):
14+ @dataclass  
15+ class  Modality :
1316    """Class to represent a modality in the library. 
1417
1518    This class is used to represent a modality in the library. It contains the name of 
@@ -24,61 +27,46 @@ class Modality(str):
2427    modality_specific_properties : Optional[dict[str, str]], optional, default=None 
2528        Additional properties specific to the modality, by default None 
2629
27-     Attributes 
28-     ---------- 
29-     value : str 
30-         The name of the modality. 
31-     properties : dict[str, str] 
32-         The properties associated with the modality. By default, the properties are 
33-         `target`, `mask`, `embedding`, `masked_embedding`, and `ema_embedding`. 
34-         These default properties apply to all newly created modality types 
35-         automatically. Modality-specific properties can be added using the 
36-         `add_property` method or by passing them as a dictionary to the constructor. 
30+     Raises 
31+     ------ 
32+     ValueError 
33+         If the property already exists for the modality or if the format string is 
34+         invalid. 
3735    """ 
3836
39-     _default_properties  =  {
40-         "target" : "{}_target" ,
41-         "attention_mask" : "{}_attention_mask" ,
42-         "mask" : "{}_mask" ,
43-         "embedding" : "{}_embedding" ,
44-         "masked_embedding" : "{}_masked_embedding" ,
45-         "ema_embedding" : "{}_ema_embedding" ,
46-     }
47- 
48-     if  TYPE_CHECKING :
49- 
50-         def  __getattr__ (self , attr : str ) ->  Any :
51-             """Get the value of the attribute.""" 
52-             ...
53- 
54-         def  __setattr__ (self , attr : str , value : Any ) ->  None :
55-             """Set the value of the attribute.""" 
56-             ...
57- 
58-     def  __new__ (
59-         cls , name : str , modality_specific_properties : Optional [dict [str , str ]] =  None 
60-     ) ->  Self :
37+     name : str 
38+     target : str  =  field (init = False , repr = False )
39+     attention_mask : str  =  field (init = False , repr = False )
40+     mask : str  =  field (init = False , repr = False )
41+     embedding : str  =  field (init = False , repr = False )
42+     masked_embedding : str  =  field (init = False , repr = False )
43+     ema_embedding : str  =  field (init = False , repr = False )
44+     modality_specific_properties : Optional [dict [str , str ]] =  field (
45+         default = None , repr = False 
46+     )
47+ 
48+     def  __post_init__ (self ) ->  None :
6149        """Initialize the modality with the name and properties.""" 
62-         instance  =  super ( Modality ,  cls ). __new__ ( cls ,  name .lower () )
63-         properties  =  cls . _default_properties . copy () 
64-          if   modality_specific_properties   is   not   None : 
65-              properties . update ( modality_specific_properties ) 
66-         instance . _properties   =   properties 
67- 
68-         for   property_name ,  format_string   in   instance ._properties . items (): 
69-             instance . _set_property_as_attr ( property_name ,  format_string )
70- 
71-         return   instance 
72- 
73-     @ property 
74-     def   value ( self )  ->   str : 
75-         """Return the name of the modality.""" 
76-         return   self .__str__ ( )
50+         self . name  =  self . name .lower ()
51+         self . _properties  =  {} 
52+ 
53+         for   field_name   in   self . __dataclass_fields__ : 
54+              if   field_name   not   in  ( "name" ,  "modality_specific_properties" ): 
55+                  field_value   =   f" { self . name } _ { field_name } " 
56+                  self ._properties [ field_name ]  =   field_value 
57+                  setattr ( self ,  field_name ,  field_value )
58+ 
59+         if   self . modality_specific_properties   is   not   None : 
60+              for  ( 
61+                  property_name , 
62+                  format_string , 
63+             )  in   self . modality_specific_properties . items (): 
64+                  self .add_property ( property_name ,  format_string )
7765
7866    @property  
7967    def  properties (self ) ->  dict [str , str ]:
8068        """Return the properties associated with the modality.""" 
81-         return  { name :  getattr ( self ,  name )  for   name   in   self ._properties } 
69+         return  self ._properties 
8270
8371    def  add_property (self , name : str , format_string : str ) ->  None :
8472        """Add a new property to the modality. 
@@ -92,49 +80,38 @@ def add_property(self, name: str, format_string: str) -> None:
9280            placeholder that will be replaced with the name of the modality when the 
9381            property is accessed. 
9482
83+         Warns 
84+         ----- 
85+         UserWarning 
86+             If the property already exists for the modality. It will overwrite the 
87+             existing property. 
88+ 
9589        Raises 
9690        ------ 
9791        ValueError 
98-             If the property already exists for the modality or if the format string is  
99-             invalid . 
92+             If `format_string` is invalid. A valid format string contains at least one  
93+             placeholder enclosed in curly braces . 
10094        """ 
10195        if  name  in  self ._properties :
102-             raise   ValueError (
96+             warnings . warn (
10397                f"Property '{ name } { super ().__str__ ()}  
98+                 "Will overwrite the existing property." ,
99+                 category = UserWarning ,
100+                 stacklevel = 2 ,
104101            )
105-         self ._properties [name ] =  format_string 
106-         self ._set_property_as_attr (name , format_string )
107102
108-     def  _set_property_as_attr (self , name : str , format_string : str ) ->  None :
109-         """Set the property as an attribute of the modality.""" 
110103        if  not  _is_format_string (format_string ):
111104            raise  ValueError (
112105                f"Invalid format string '{ format_string }  
113106                f"'{ name } { super ().__str__ ()}  
114107            )
115-         setattr (self , name , format_string .format (self .value ))
108+ 
109+         self ._properties [name ] =  format_string .format (self .name )
110+         setattr (self , name , self ._properties [name ])
116111
117112    def  __str__ (self ) ->  str :
118113        """Return the object as a string.""" 
119-         return  self .lower ()
120- 
121-     def  __repr__ (self ) ->  str :
122-         """Return the string representation of the modality.""" 
123-         return  f"<Modality: { self .upper ()}  
124- 
125-     def  __hash__ (self ) ->  int :
126-         """Return the hash of the modality name and properties.""" 
127-         return  hash ((self .value , tuple (self ._properties .items ())))
128- 
129-     def  __eq__ (self , other : object ) ->  bool :
130-         """Check if two modality types are equal. 
131- 
132-         Two modality types are equal if they have the same name and properties. 
133-         """ 
134-         return  isinstance (other , Modality ) and  (
135-             (self .__str__ () ==  other .__str__ ())
136-             and  (self ._properties  ==  other ._properties )
137-         )
114+         return  self .name .lower ()
138115
139116
140117class  ModalityRegistry :
@@ -146,16 +123,15 @@ class ModalityRegistry:
146123    ensure that there is only one instance of the registry in the library. 
147124    """ 
148125
149-     _instance  =  None 
126+     _instance : ClassVar [Any ] =  None 
127+     _modality_registry : dict [str , Modality ] =  {}
150128
151129    def  __new__ (cls ) ->  Self :
152130        """Create a new instance of the class if it does not exist.""" 
153131        if  cls ._instance  is  None :
154-             cls ._instance  =  super (ModalityRegistry , cls ).__new__ (cls )
155-             cls ._instance ._modality_registry  =  {}  # type: ignore[attr-defined] 
156-             for  modality  in  _default_supported_modalities :
157-                 cls ._instance .register_modality (modality )
158-         return  cls ._instance 
132+             cls ._instance  =  super ().__new__ (cls )
133+             cls ._instance ._modality_registry  =  {}
134+         return  cls ._instance   # type: ignore[no-any-return] 
159135
160136    def  register_modality (
161137        self , name : str , modality_specific_properties : Optional [dict [str , str ]] =  None 
@@ -169,13 +145,19 @@ def register_modality(
169145        modality_specific_properties : Optional[dict[str, str]], optional, default=None 
170146            Additional properties specific to the modality. 
171147
172-         Raises 
173-         ------ 
174-         ValueError 
175-             If the modality already exists in the registry. 
148+         Warns 
149+         ----- 
150+         UserWarning 
151+             If the modality already exists in the registry. It will overwrite the 
152+             existing modality. 
153+ 
176154        """ 
177155        if  name .lower () in  self ._modality_registry :
178-             raise  ValueError (f"Modality '{ name }  )
156+             warnings .warn (
157+                 f"Modality '{ name }  ,
158+                 category = UserWarning ,
159+                 stacklevel = 2 ,
160+             )
179161
180162        name  =  name .lower ()
181163        modality  =  Modality (name , modality_specific_properties )
@@ -194,18 +176,21 @@ def add_default_property(self, name: str, format_string: str) -> None:
194176            placeholder that will be replaced with the name of the modality when the 
195177            property is accessed. 
196178
179+         Warns 
180+         ----- 
181+         UserWarning 
182+             If the property already exists for the default properties. It will 
183+             overwrite the existing property. 
184+ 
197185        Raises 
198186        ------ 
199187        ValueError 
200-             If the property already exists for the default properties or if the format  
201-             string is invalid . 
188+             If the format string is invalid. A valid format string contains at least one  
189+             placeholder enclosed in curly braces . 
202190        """ 
203191        for  modality  in  self ._modality_registry .values ():
204192            modality .add_property (name , format_string )
205193
206-         # add the property to the default properties for new modalities 
207-         Modality ._default_properties [name .lower ()] =  format_string 
208- 
209194    def  has_modality (self , name : str ) ->  bool :
210195        """Check if the modality exists in the registry. 
211196
@@ -234,7 +219,7 @@ def get_modality(self, name: str) -> Modality:
234219        Modality 
235220            The modality object from the registry. 
236221        """ 
237-         return  self ._modality_registry [name .lower ()]   # type: ignore[index,return-value] 
222+         return  self ._modality_registry [name .lower ()]
238223
239224    def  get_modality_properties (self , name : str ) ->  dict [str , str ]:
240225        """Get the properties of a modality from the registry. 
@@ -264,7 +249,7 @@ def list_modalities(self) -> list[Modality]:
264249    def  __getattr__ (self , name : str ) ->  Modality :
265250        """Access a modality as an attribute by its name.""" 
266251        if  name .lower () in  self ._modality_registry :
267-             return  self ._modality_registry [name .lower ()]   # type: ignore[index,return-value] 
252+             return  self ._modality_registry [name .lower ()]
268253        raise  AttributeError (
269254            f"'{ self .__class__ .__name__ } { name }  
270255        )
@@ -292,3 +277,6 @@ def _is_format_string(string: str) -> bool:
292277
293278
294279Modalities  =  ModalityRegistry ()
280+ 
281+ for  modality  in  _DEFAULT_SUPPORTED_MODALITIES :
282+     Modalities .register_modality (modality )
0 commit comments