# pyre-strict

import logging as logger

from collections import Counter, OrderedDict
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch

from torch import nn
from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
+from torchrec.distributed.model_tracker.delta_store import DeltaStore
from torchrec.distributed.model_tracker.types import (
    DeltaRows,
    EmbdUpdateMode,
@@ -41,15 +42,16 @@ class ModelDeltaTracker:
    ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support
    for tracking the corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or
    unique rows for a given model, which can help compute top-k results or stream updated embeddings from predictors to trainers during
-    online training. Unique IDs or states can be retrieved by calling the get_unique() method.
+    online training. Unique IDs or states can be retrieved by calling the get_delta() method.

    Args:
        model (nn.Module): the model to track.
        consumers (List[str], optional): list of consumers to track. Each consumer will
-            have its own batch offset index. Every get_unique_ids invocation will
-            only return the new ids for the given consumer since last get_unique_ids
-            call.
+            have its own batch offset index. Every get_delta and get_delta_ids invocation will
+            only return the new values for the given consumer since the last call.
        delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
+        auto_compact (bool, optional): whether to trigger compaction automatically during communication at each train cycle.
+            When False, compaction is instead triggered by the get_delta() call. Default: False.
        mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
        fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: None.
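For reviewers, a minimal usage sketch of the API described above; `sharded_model` and the consumer names are illustrative placeholders, not part of this diff:

```python
from torchrec.distributed.model_tracker.types import TrackingMode

# Hypothetical setup: `sharded_model` is an already-sharded torchrec model.
tracker = ModelDeltaTracker(
    model=sharded_model,
    consumers=["publisher", "evaluator"],  # each consumer keeps its own offset
    delete_on_read=True,
    mode=TrackingMode.ID_ONLY,
)

# ... run a few training batches; lookups get recorded per batch ...

# Each consumer only receives rows touched since its own last read.
publisher_delta = tracker.get_delta(consumer="publisher")
for fqn, rows in publisher_delta.items():
    print(fqn, rows.ids.shape)
```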
@@ -62,36 +64,179 @@ def __init__(
        model: nn.Module,
        consumers: Optional[List[str]] = None,
        delete_on_read: bool = True,
+        auto_compact: bool = False,
        mode: TrackingMode = TrackingMode.ID_ONLY,
        fqns_to_skip: Iterable[str] = (),
    ) -> None:
        self._model = model
        self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
        self._delete_on_read = delete_on_read
+        self._auto_compact = auto_compact
        self._mode = mode
        self._fqn_to_feature_map: Dict[str, List[str]] = {}
        self._fqns_to_skip: Iterable[str] = fqns_to_skip
+
+        # per_consumer_batch_idx tracks the batch index for each consumer. It is
+        # used to retrieve the delta values for a given consumer, as well as the
+        # start_ids for the compaction window.
+        self.per_consumer_batch_idx: Dict[str, int] = {
+            c: -1 for c in (consumers or [self.DEFAULT_CONSUMER])
+        }
+        self.curr_batch_idx: int = 0
+        self.curr_compact_index: int = 0
+
+        self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
+
+        # Map from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection.
+        self.tracked_modules: Dict[str, nn.Module] = {}
+        self.feature_to_fqn: Dict[str, str] = {}
+        # Generate the mapping from FQN to feature names.
        self.fqn_to_feature_names()
-        pass
+        # Validate that the tracking mode is supported for the tracked modules.
+        self._validate_mode()
+
+        # Map each feature name to its FQN. This is used for retrieving the FQN
+        # associated with a given feature name in record_lookup().
+        for fqn, feature_names in self._fqn_to_feature_map.items():
+            for feature_name in feature_names:
+                if feature_name in self.feature_to_fqn:
+                    logger.warning(f"Duplicate feature name: {feature_name} in fqn {fqn}")
+                    continue
+                self.feature_to_fqn[feature_name] = fqn
+        logger.info(f"feature_to_fqn: {self.feature_to_fqn}")
+
108
+ def record_lookup (
109
+ self , emb_module : nn .Module , kjt : KeyedJaggedTensor , states : torch .Tensor
110
+ ) -> None :
111
+ """
112
+ Records lookups (IDs and optionally embeddings) based on the current tracking mode.
113
+ This method is run post lookup after the embedding lookup has been performed because it needs
114
+ access to both the input IDs and the resulting embeddings.
115
+
116
+ This function processes the input KeyedJaggedTensor and records either just the IDs
117
+ (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).
118
+
119
+ Args:
120
+ emb_module (nn.Module): The embedding module in which the lookup was performed.
121
+ kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
122
+ states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
123
+ """
124
+
125
+ # In ID_ONLY mode, we only track feature IDs received in the current batch.
126
+ if self ._mode == TrackingMode .ID_ONLY :
127
+ self .record_ids (kjt )
128
+ # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
129
+ elif self ._mode == TrackingMode .EMBEDDING :
130
+ self .record_embeddings (kjt , states )
131
+
132
+ else :
133
+ raise NotImplementedError (f"Tracking mode { self ._mode } is not supported" )
134
+
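The docstring's ordering requirement (run post-lookup, with both input and output in hand) can be pictured with a plain forward hook. This is only an illustration, not how torchrec wires the tracker in; it assumes the hooked module takes the KJT as its first argument and returns a plain tensor:

```python
# Hypothetical wiring: a forward hook fires after the lookup, so it can hand
# record_lookup both the input KJT and the lookup output.
def _post_lookup_hook(module, args, output):
    kjt = args[0]  # assumed: the input KeyedJaggedTensor
    tracker.record_lookup(module, kjt, output)

for fqn, module in tracker.get_tracked_modules().items():
    module.register_forward_hook(_post_lookup_hook)
```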
+    def record_ids(self, kjt: KeyedJaggedTensor) -> None:
+        """
+        Record IDs from a given KeyedJaggedTensor.
+
+        Args:
+            kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
+        """
+        per_table_ids: Dict[str, List[torch.Tensor]] = {}
+        for key in kjt.keys():
+            table_fqn = self.feature_to_fqn[key]
+            ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
+            ids_list.append(kjt[key].values())
+            per_table_ids[table_fqn] = ids_list
+
+        for table_fqn, ids_list in per_table_ids.items():
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=table_fqn,
+                ids=torch.cat(ids_list),
+                embeddings=None,
+            )
+
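For intuition, a small worked example of the per-table grouping; the feature names and table FQN are made up:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features (2 samples each) that both map to the same table.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1]),
)
print(kjt["f1"].values())  # tensor([1, 2, 3])
print(kjt["f2"].values())  # tensor([4, 5])
# If f1 and f2 both resolve to "ebc.table_0" in feature_to_fqn, record_ids
# appends torch.cat -> tensor([1, 2, 3, 4, 5]) under that single FQN.
```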
-    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
+    def record_embeddings(
+        self, kjt: KeyedJaggedTensor, embeddings: torch.Tensor
+    ) -> None:
        """
-        Record Ids from a given KeyedJaggedTensor and embeddings/parameter states.
+        Record IDs along with their corresponding embeddings from a given KeyedJaggedTensor and embeddings tensor.

        Args:
            kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
-            states (torch.Tensor): the states to record.
+            embeddings (torch.Tensor): the embeddings to record.
+        """
+        per_table_ids: Dict[str, List[torch.Tensor]] = {}
+        per_table_emb: Dict[str, List[torch.Tensor]] = {}
+        assert embeddings.numel() % kjt.values().numel() == 0, (
+            f"ids and embeddings size mismatch, expected [{kjt.values().numel()} * emb_dim], "
+            f"but got {embeddings.numel()}"
+        )
+        embeddings_2d = embeddings.view(kjt.values().numel(), -1)
+
+        offset: int = 0
+        for key in kjt.keys():
+            table_fqn = self.feature_to_fqn[key]
+            ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
+            emb_list: List[torch.Tensor] = per_table_emb.get(table_fqn, [])
+
+            ids = kjt[key].values()
+            ids_list.append(ids)
+            emb_list.append(embeddings_2d[offset : offset + ids.numel()])
+            offset += ids.numel()
+
+            per_table_ids[table_fqn] = ids_list
+            per_table_emb[table_fqn] = emb_list
+
+        for table_fqn, ids_list in per_table_ids.items():
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=table_fqn,
+                ids=torch.cat(ids_list),
+                embeddings=torch.cat(per_table_emb[table_fqn]),
+            )
+
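The assert and `view` above assume the flat lookup output carries exactly one emb_dim-sized row per ID in the batch; a toy shape check of that invariant:

```python
import torch

num_ids, emb_dim = 5, 8
flat = torch.randn(num_ids * emb_dim)  # flat lookup output: 40 elements
assert flat.numel() % num_ids == 0     # 40 % 5 == 0, so the reshape is valid
rows = flat.view(num_ids, -1)
print(rows.shape)  # torch.Size([5, 8]) -- one embedding row per ID
```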
+    def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
+        """
+        Return a dictionary of hit local IDs for each sparse feature, keyed by
+        submodule FQN.
+
+        Args:
+            consumer (str, optional): The consumer to retrieve unique IDs for. If not specified, "default" is used as the default consumer.
        """
-        pass
+        per_table_delta_rows = self.get_delta(consumer)
+        return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}

    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
        """
-        Return a dictionary of hit local IDs for each sparse feature. The IDs are first keyed by submodule FQN.
+        Return a dictionary of hit local IDs and parameter states/embeddings for each sparse feature, keyed by submodule FQN.

        Args:
-            consumer (str, optional): The consumer to retrieve IDs for. If not specified, "default" is used as the default consumer.
+            consumer (str, optional): The consumer to retrieve delta values for. If not specified, "default" is used as the default consumer.
        """
-        return {}
+        consumer = consumer or self.DEFAULT_CONSUMER
+        assert (
+            consumer in self.per_consumer_batch_idx
+        ), f"consumer {consumer} not present in {self.per_consumer_batch_idx.keys()}"
+
+        index_end: int = self.curr_batch_idx + 1
+        index_start = max(self.per_consumer_batch_idx.values())
+
+        # With multiple consumers, a previous consumer may already have compacted
+        # these indices, making index_start equal to index_end; in that case we
+        # should not compact again.
+        if index_start < index_end:
+            self.compact(index_start, index_end)
+        tracker_rows = self.store.get_delta(
+            from_idx=self.per_consumer_batch_idx[consumer]
+        )
+        self.per_consumer_batch_idx[consumer] = index_end
+        if self._delete_on_read:
+            self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+        return tracker_rows
+
+    def get_tracked_modules(self) -> Dict[str, nn.Module]:
+        """
+        Returns a dictionary of the tracked modules.
+        """
+        return self.tracked_modules

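A sketch of the per-consumer offset behavior, reusing the hypothetical `tracker` from the earlier sketch:

```python
# Suppose batches 0..4 have been recorded.
fast = tracker.get_delta(consumer="publisher")  # rows from batches 0-4
# ... batches 5..9 recorded ...
fast = tracker.get_delta(consumer="publisher")  # rows from batches 5-9 only
slow = tracker.get_delta(consumer="evaluator")  # rows from batches 0-9
# With delete_on_read=True, rows are physically dropped only once every
# consumer has read past them (min over per_consumer_batch_idx).
```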
    def fqn_to_feature_names(self) -> Dict[str, List[str]]:
        """
@@ -114,19 +259,19 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                    break
            if should_skip:
                continue
-
            # Using FQNs of the embedding and mapping them to features, as the state_dict() API uses these to key states.
            if isinstance(named_module, SUPPORTED_MODULES):
                for table_name, config in named_module._table_name_to_config.items():
                    logger.info(
                        f"Found {table_name} for {fqn} with features {config.feature_names}"
                    )
                    table_to_feature_names[table_name] = config.feature_names
+                self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module
            for table_name in table_to_feature_names:
                # Use the split FQN to get an exact table-name match. Otherwise, checking "table_name in fqn"
                # would incorrectly match the fqn against all table names that share the same prefix.
                if table_name in split_fqn:
-                    embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
+                    embedding_fqn = self._clean_fqn_fn(fqn)
                    if table_name in table_to_fqn:
                        # Sanity check: validate that we don't have more than one table mapping to the same fqn.
                        logger.warning(
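A quick illustration of why the membership test above runs against the split FQN rather than a substring check:

```python
fqn = "sparse.ebc.table_10.weight"
split_fqn = fqn.split(".")

print("table_1" in fqn)         # True  -- substring check wrongly matches table_10
print("table_1" in split_fqn)   # False -- exact component match
print("table_10" in split_fqn)  # True
```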
@@ -165,7 +310,19 @@ def clear(self, consumer: Optional[str] = None) -> None:
        Args:
            consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is used as the default consumer.
        """
-        pass
+        # 1. If no consumer is given, delete globally.
+        if consumer is None:
+            self.store.delete()
+            return
+
+        assert (
+            consumer in self.per_consumer_batch_idx
+        ), f"consumer {consumer} not found in {self.per_consumer_batch_idx.keys()}"
+
+        # 2. For a single consumer, we can simply delete all ids.
+        if len(self.per_consumer_batch_idx) == 1:
+            self.store.delete()
+            return

    def compact(self, start_idx: int, end_idx: int) -> None:
        """
@@ -175,4 +332,16 @@ def compact(self, start_idx: int, end_idx: int) -> None:
            start_idx (int): Starting index for compaction.
            end_idx (int): Ending index for compaction.
        """
-        pass
+        self.store.compact(start_idx, end_idx)
+
+    def _clean_fqn_fn(self, fqn: str) -> str:
+        # Strip the DMP-internal module prefix so FQNs match state_dict keys.
+        return fqn.replace("_dmp_wrapped_module.module.", "")
+
+    def _validate_mode(self) -> None:
+        """Validate that the tracking mode is supported for the tracked modules."""
+        for module in self.tracked_modules.values():
+            assert not (
+                isinstance(module, ShardedEmbeddingBagCollection)
+                and self._mode == TrackingMode.EMBEDDING
+            ), "EBC's lookup returns pooled embeddings; we do not currently support tracking raw embeddings."