6
6
# LICENSE file in the root directory of this source tree.
7
7
8
8
# pyre-strict
9
- from typing import Dict , List , Optional , Union
9
+ import logging as logger
10
+ from collections import Counter , OrderedDict
11
+ from typing import Dict , Iterable , List , Optional , Union
10
12
11
13
import torch
12
14
@@ -49,6 +51,8 @@ class ModelDeltaTracker:
49
51
call.
50
52
delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
51
53
mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
54
+ fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: None.
55
+
52
56
"""
53
57
54
58
DEFAULT_CONSUMER : str = "default"
@@ -59,11 +63,15 @@ def __init__(
59
63
consumers : Optional [List [str ]] = None ,
60
64
delete_on_read : bool = True ,
61
65
mode : TrackingMode = TrackingMode .ID_ONLY ,
66
+ fqns_to_skip : Iterable [str ] = (),
62
67
) -> None :
63
68
self ._model = model
64
69
self ._consumers : List [str ] = consumers or [self .DEFAULT_CONSUMER ]
65
70
self ._delete_on_read = delete_on_read
66
71
self ._mode = mode
72
+ self ._fqn_to_feature_map : Dict [str , List [str ]] = {}
73
+ self ._fqns_to_skip : Iterable [str ] = fqns_to_skip
74
+ self .fqn_to_feature_names ()
67
75
pass
68
76
69
77
def record_lookup (self , kjt : KeyedJaggedTensor , states : torch .Tensor ) -> None :
@@ -85,14 +93,69 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
85
93
"""
86
94
return {}
87
95
88
- def fqn_to_feature_names (self , module : nn . Module ) -> Dict [str , List [str ]]:
96
+ def fqn_to_feature_names (self ) -> Dict [str , List [str ]]:
89
97
"""
90
- Returns a mapping from FQN to feature names for a given module.
91
-
92
- Args:
93
- module (nn.Module): the module to retrieve feature names for.
98
+ Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model.
94
99
"""
95
- return {}
100
+ if (self ._fqn_to_feature_map is not None ) and len (self ._fqn_to_feature_map ) > 0 :
101
+ return self ._fqn_to_feature_map
102
+
103
+ table_to_feature_names : Dict [str , List [str ]] = OrderedDict ()
104
+ table_to_fqn : Dict [str , str ] = OrderedDict ()
105
+ for fqn , named_module in self ._model .named_modules ():
106
+ split_fqn = fqn .split ("." )
107
+
108
+ should_skip = False
109
+ for fqn_to_skip in self ._fqns_to_skip :
110
+ if fqn_to_skip in split_fqn :
111
+ logger .info (f"Skipping { fqn } because it is part of fqns_to_skip" )
112
+ should_skip = True
113
+ break
114
+ if should_skip :
115
+ continue
116
+
117
+ # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
118
+ if isinstance (named_module , SUPPORTED_MODULES ):
119
+ for table_name , config in named_module ._table_name_to_config .items ():
120
+ logger .info (
121
+ f"Found { table_name } for { fqn } with features { config .feature_names } "
122
+ )
123
+ table_to_feature_names [table_name ] = config .feature_names
124
+ for table_name in table_to_feature_names :
125
+ # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
126
+ # will incorrectly match fqn with all the table names that have the same prefix
127
+ if table_name in split_fqn :
128
+ embedding_fqn = fqn .replace ("_dmp_wrapped_module.module." , "" )
129
+ if table_name in table_to_fqn :
130
+ # Sanity check for validating that we don't have more then one table mapping to same fqn.
131
+ logger .warning (
132
+ f"Override { table_to_fqn [table_name ]} with { embedding_fqn } for entry { table_name } "
133
+ )
134
+ table_to_fqn [table_name ] = embedding_fqn
135
+ logger .info (f"Table to fqn: { table_to_fqn } " )
136
+ flatten_names = [
137
+ name for names in table_to_feature_names .values () for name in names
138
+ ]
139
+ # TODO: Validate if there is a better way to handle duplicate feature names.
140
+ # Logging a warning if duplicate feature names are found across tables, but continue execution as this is allowed in some models
141
+ if len (set (flatten_names )) != len (flatten_names ):
142
+ counts = Counter (flatten_names )
143
+ duplicates = [item for item , count in counts .items () if count > 1 ]
144
+ logger .warning (f"duplicate feature names found: { duplicates } " )
145
+
146
+ fqn_to_feature_names : Dict [str , List [str ]] = OrderedDict ()
147
+ for table_name in table_to_feature_names :
148
+ if table_name not in table_to_fqn :
149
+ # This is likely unexpected, where we can't locate the FQN associated with this table.
150
+ logger .warning (
151
+ f"Table { table_name } not found in { table_to_fqn } , skipping"
152
+ )
153
+ continue
154
+ fqn_to_feature_names [table_to_fqn [table_name ]] = table_to_feature_names [
155
+ table_name
156
+ ]
157
+ self ._fqn_to_feature_map = fqn_to_feature_names
158
+ return fqn_to_feature_names
96
159
97
160
def clear (self , consumer : Optional [str ] = None ) -> None :
98
161
"""
0 commit comments