
Commit 5083f92

aliafzal authored and facebook-github-bot committed
ModelDeltaTracer implementation for tracking logic (#3060)
Summary:

### Diff Summary

This diff introduces the implementation of the tracking logic for the ID and Embedding modes.

1. **Record Functions**
   - `record_lookup()`: Handles recording of IDs and embeddings based on the tracking mode.
   - `record_ids()`: Records IDs from a KeyedJaggedTensor.
   - `record_embeddings()`: Records IDs along with embeddings, ensuring size compatibility between IDs and embeddings.
2. **Delta Retrieval**
   - `get_delta()`: Retrieves per-FQN local IDs for each sparse feature.
3. **Tracked Modules Access**
   - `get_tracked_modules()`: Returns a dictionary of tracked modules.

## ModelDeltaTracker Context

ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a model built with TorchRec. It is particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming of updated embeddings between systems during online training

Differential Revision: D76094097
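A brief usage sketch (editorial note, not part of the commit) of how the API added here looks from the trainer side. It assumes the forward pass is wired to call `record_lookup()` for every embedding lookup, which this diff does not show; `sharded_model`, `train_batches`, and `publish_ids` are placeholder names, and import paths follow the module paths visible in this diff:

```python
# Illustrative sketch only: `sharded_model`, `train_batches`, and `publish_ids`
# are placeholders; the wiring that makes lookups call record_lookup() is
# outside the scope of this change.
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
from torchrec.distributed.model_tracker.types import TrackingMode

tracker = ModelDeltaTracker(
    model=sharded_model,          # DMP-wrapped model with sharded EC/EBC modules
    consumers=["publisher"],      # each consumer keeps its own read offset
    delete_on_read=True,          # drop rows once every consumer has read them
    mode=TrackingMode.ID_ONLY,    # TrackingMode.EMBEDDING also records states
)

for batch in train_batches:
    sharded_model(batch)          # embedding lookups are recorded per batch

    # Per-FQN tensor of locally hit IDs since this consumer's last read.
    ids_per_fqn = tracker.get_delta_ids(consumer="publisher")
    for fqn, ids in ids_per_fqn.items():
        publish_ids(fqn, ids)
```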
1 parent 31ef926 commit 5083f92

File tree

1 file changed: +180 / -16 lines changed


torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 180 additions & 16 deletions
@@ -8,13 +8,14 @@
 # pyre-strict
 import logging as logger
 from collections import Counter, OrderedDict
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 
 from torch import nn
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
+from torchrec.distributed.model_tracker.delta_store import DeltaStore
 from torchrec.distributed.model_tracker.types import (
     DeltaRows,
     EmbdUpdateMode,
@@ -41,16 +42,17 @@ class ModelDeltaTracker:
     ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support
     for tracking corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or
     unique rows for a given model, which can help compute topk or to stream updated embeddings from predictors to trainers during
-    online training. Unique IDs or states can be retrieved by calling the get_unique() method.
+    online training. Unique IDs or states can be retrieved by calling the get_delta() method.
 
     Args:
         model (nn.Module): the model to track.
         consumers (List[str], optional): list of consumers to track. Each consumer will
-            have its own batch offset index. Every get_unique_ids invocation will
-            only return the new ids for the given consumer since last get_unique_ids
-            call.
+            have its own batch offset index. Every get_delta and get_delta_ids invocation will
+            only return the new values for the given consumer since the last call.
         delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
+        auto_compact (bool, optional): overlap compaction with communication at each train cycle.
         mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
+        fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: None.
     """
 
     DEFAULT_CONSUMER: str = "default"
@@ -60,40 +62,178 @@ def __init__(
         model: nn.Module,
         consumers: Optional[List[str]] = None,
         delete_on_read: bool = True,
+        auto_compact: bool = False,
         mode: TrackingMode = TrackingMode.ID_ONLY,
         fqns_to_skip: Iterable[str] = (),
     ) -> None:
         self._model = model
         self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
         self._delete_on_read = delete_on_read
+        self._auto_compact = auto_compact
         self._mode = mode
         self._fqn_to_feature_map: Dict[str, List[str]] = {}
         self._fqns_to_skip: Iterable[str] = fqns_to_skip
+        self.per_consumer_batch_idx: Dict[str, int] = {
+            c: -1 for c in (consumers or [self.DEFAULT_CONSUMER])
+        }
+        self.curr_batch_idx: int = 0
+        self.curr_compact_index: int = 0
+
+        self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
+
+        # preprocess_fn is used to preprocess the module inputs before tracking.
+        self.preprocess_fn: Optional[
+            Callable[..., Tuple[KeyedJaggedTensor, torch.Tensor]]
+        ] = None
+
+        # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
+        self.tracked_modules: Dict[str, nn.Module] = {}
+        self.feature_to_fqn: Dict[str, str] = {}
+        # Generate the mapping from FQN to feature names.
         self.fqn_to_feature_names()
-        pass
+        # Validate the mode is supported for the given module
+        self._validate_mode()
 
-    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
+        # Mapping feature name to corresponding FQNs. This is used for retrieving
+        # the FQN associated with a given feature name in record_lookup().
+        for fqn, feature_names in self._fqn_to_feature_map.items():
+            for feature_name in feature_names:
+                if feature_name in self.feature_to_fqn:
+                    logger.warn(f"Duplicate feature name: {feature_name} in fqn {fqn}")
+                    continue
+                self.feature_to_fqn[feature_name] = fqn
+        logger.info(f"feature_to_fqn: {self.feature_to_fqn}")
+
+    def record_lookup(
+        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
+    ) -> None:
         """
         Record Ids from a given KeyedJaggedTensor and embeddings/ parameter states.
 
         Args:
             kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
             states (torch.Tensor): the states to record.
         """
-        pass
+
+        # In ID_ONLY mode, we only track feature IDs received in the current batch.
+        if self._mode == TrackingMode.ID_ONLY:
+            self.record_ids(kjt)
+        # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
+        elif self._mode == TrackingMode.EMBEDDING:
+            self.record_embeddings(kjt, states)
+
+        else:
+            raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
+
+    def record_ids(self, kjt: KeyedJaggedTensor) -> None:
+        """
+        Record Ids from a given KeyedJaggedTensor.
+
+        Args:
+            kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
+        """
+        per_table_ids: Dict[str, List[torch.Tensor]] = {}
+        for key in kjt.keys():
+            table_fqn = self.feature_to_fqn[key]
+            ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
+            ids_list.append(kjt[key].values())
+            per_table_ids[table_fqn] = ids_list
+
+        for table_fqn, ids_list in per_table_ids.items():
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=table_fqn,
+                ids=torch.cat(ids_list),
+                embeddings=None,
+            )
+
+    def record_embeddings(
+        self, kjt: KeyedJaggedTensor, embeddings: torch.Tensor
+    ) -> None:
+        """
+        Record Ids along with Embeddings from a given KeyedJaggedTensor and embeddings.
+
+        Args:
+            kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
+            embeddings (torch.Tensor): the embeddings to record.
+        """
+        per_table_ids: Dict[str, List[torch.Tensor]] = {}
+        per_table_emb: Dict[str, List[torch.Tensor]] = {}
+        assert embeddings.numel() % kjt.values().numel() == 0, (
+            f"ids and embeddings size mismatch, expect [{kjt.values().numel()} * emb_dim], "
+            f"but got {embeddings.numel()}"
+        )
+        embeddings_2d = embeddings.view(kjt.values().numel(), -1)
+
+        offset: int = 0
+        for key in kjt.keys():
+            table_fqn = self.feature_to_fqn[key]
+            ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
+            emb_list: List[torch.Tensor] = per_table_emb.get(table_fqn, [])
+
+            ids = kjt[key].values()
+            ids_list.append(ids)
+            emb_list.append(embeddings_2d[offset : offset + ids.numel()])
+            offset += ids.numel()
+
+            per_table_ids[table_fqn] = ids_list
+            per_table_emb[table_fqn] = emb_list
+
+        for table_fqn, ids_list in per_table_ids.items():
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=table_fqn,
+                ids=torch.cat(ids_list),
+                embeddings=torch.cat(per_table_emb[table_fqn]),
+            )
+
+    def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
+        """
+        Return a dictionary of hit local IDs for each sparse feature. Ids are
+        first keyed by submodule FQN.
+
+        Args:
+            consumer (str, optional): The consumer to retrieve unique IDs for. If not specified, "default" is used as the default consumer.
+        """
+        per_table_delta_rows = self.get_delta(consumer)
+        return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}
 
     def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         """
-        Return a dictionary of hit local IDs for each sparse feature. The IDs are first keyed by submodule FQN.
+        Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The values are first keyed by submodule FQN.
 
         Args:
-            consumer (str, optional): The consumer to retrieve IDs for. If not specified, "default" is used as the default consumer.
+            consumer (str, optional): The consumer to retrieve delta values for. If not specified, "default" is used as the default consumer.
         """
-        return {}
+        consumer = consumer or self.DEFAULT_CONSUMER
+        assert (
+            consumer in self.per_consumer_batch_idx
+        ), f"consumer {consumer} not present in {self.per_consumer_batch_idx.values()}"
+
+        index_end: int = self.curr_batch_idx + 1
+        index_start = max(self.per_consumer_batch_idx.values())
+
+        # In case of multiple consumers, it is possible that the previous consumer has already compacted these indices
+        # and index_start could be equal to index_end, in which case we should not compact again.
+        if index_start < index_end:
+            self.compact(index_start, index_end)
+        tracker_rows = self.store.get_delta(
+            from_idx=self.per_consumer_batch_idx[consumer]
+        )
+        self.per_consumer_batch_idx[consumer] = index_end
+        if self._delete_on_read:
+            self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+        return tracker_rows
+
+    def get_tracked_modules(self) -> Dict[str, nn.Module]:
+        """
+        Returns a dictionary of tracked modules.
+        """
+        return self.tracked_modules
 
     def fqn_to_feature_names(self) -> Dict[str, List[str]]:
         """
-        Returns a mapping from FQN to feature names for a given module.
+        Returns a mapping from FQN to feature names and updates the tracked_modules dict from the given model with supported modules.
 
         Args:
             module (nn.Module): the module to retrieve feature names for.
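The record_embeddings() hunk above reshapes the flat lookup output to [num_ids, emb_dim] and walks an offset to slice it per feature before regrouping by table FQN. Below is a dependency-free sketch of that bookkeeping, with plain tensors standing in for a KeyedJaggedTensor and invented feature/table names:

```python
# Standalone illustration of the offset walk in record_embeddings(); plain
# tensors replace the KeyedJaggedTensor, and the feature->table mapping is made up.
from typing import Dict, List
import torch

feature_to_fqn = {"f1": "ebc.table_a", "f2": "ebc.table_a", "f3": "ebc.table_b"}
ids_per_feature: Dict[str, torch.Tensor] = {
    "f1": torch.tensor([4, 7]),
    "f2": torch.tensor([1]),
    "f3": torch.tensor([9, 2, 5]),
}
total_ids = sum(v.numel() for v in ids_per_feature.values())
emb_dim = 4
flat_embeddings = torch.randn(total_ids * emb_dim)  # lookup output, flattened

assert flat_embeddings.numel() % total_ids == 0
embeddings_2d = flat_embeddings.view(total_ids, -1)  # [num_ids, emb_dim]

per_table_ids: Dict[str, List[torch.Tensor]] = {}
per_table_emb: Dict[str, List[torch.Tensor]] = {}
offset = 0
for key, ids in ids_per_feature.items():
    fqn = feature_to_fqn[key]
    per_table_ids.setdefault(fqn, []).append(ids)
    per_table_emb.setdefault(fqn, []).append(embeddings_2d[offset : offset + ids.numel()])
    offset += ids.numel()

for fqn in per_table_ids:
    ids = torch.cat(per_table_ids[fqn])
    emb = torch.cat(per_table_emb[fqn])
    print(fqn, ids.shape, emb.shape)  # e.g. ebc.table_a torch.Size([3]) torch.Size([3, 4])
```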
@@ -114,19 +254,19 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                     break
             if should_skip:
                 continue
-
             # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
             if isinstance(named_module, SUPPORTED_MODULES):
                 for table_name, config in named_module._table_name_to_config.items():
                     logger.info(
                         f"Found {table_name} for {fqn} with features {config.feature_names}"
                     )
                     table_to_feature_names[table_name] = config.feature_names
+                    self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module
             for table_name in table_to_feature_names:
                 # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
                 # will incorrectly match fqn with all the table names that have the same prefix
                 if table_name in split_fqn:
-                    embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
+                    embedding_fqn = self._clean_fqn_fn(fqn)
                     if table_name in table_to_fqn:
                         # Sanity check for validating that we don't have more than one table mapping to the same fqn.
                         logger.warning(
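The hunk above keys tracked modules by the cleaned FQN so they line up with state_dict() keys. A toy check of the prefix stripping done by `_clean_fqn_fn()` (the FQN below is invented for illustration):

```python
# Toy check of the DMP prefix stripping used when keying tracked modules; the FQN is fabricated.
fqn = "_dmp_wrapped_module.module.sparse.ebc.embedding_bags.table_0"
cleaned = fqn.replace("_dmp_wrapped_module.module.", "")
print(cleaned)  # sparse.ebc.embedding_bags.table_0
```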
@@ -164,7 +304,19 @@ def clear(self, consumer: Optional[str] = None) -> None:
         Args:
             consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is used as the default consumer.
         """
-        pass
+        # 1. If consumer is None, delete globally.
+        if consumer is None:
+            self.store.delete()
+            return
+
+        assert (
+            consumer in self.per_consumer_batch_idx
+        ), f"consumer {consumer} not found in {self.per_consumer_batch_idx.values()}"
+
+        # 2. For single consumer, we can just delete all ids
+        if len(self.per_consumer_batch_idx) == 1:
+            self.store.delete()
+            return
 
     def compact(self, start_idx: int, end_idx: int) -> None:
         """
@@ -174,4 +326,16 @@ def compact(self, start_idx: int, end_idx: int) -> None:
             start_idx (int): Starting index for compaction.
             end_idx (int): Ending index for compaction.
         """
-        pass
+        self.store.compact(start_idx, end_idx)
+
+    def _clean_fqn_fn(self, fqn: str) -> str:
+        # strip DMP internal module FQN prefix to match state dict FQN
+        return fqn.replace("_dmp_wrapped_module.module.", "")
+
+    def _validate_mode(self) -> None:
+        "To validate the mode is supported for the given module"
+        for module in self.tracked_modules.values():
+            assert not (
+                isinstance(module, ShardedEmbeddingBagCollection)
+                and self._mode == TrackingMode.EMBEDDING
+            ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
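For the read-offset bookkeeping in get_delta() (per-consumer batch index, compaction window, delete-on-read), here is a dependency-free simplification; the list-based store is a stand-in for DeltaStore, whose exact from_idx/up_to_idx semantics are assumed rather than confirmed by this diff:

```python
# Simplified model of the per-consumer read-offset logic in get_delta();
# the in-memory list is a stand-in for DeltaStore, and its index semantics are assumed.
from typing import Dict, List, Tuple

store: List[Tuple[int, str]] = []                      # (batch_idx, payload) records
per_consumer_batch_idx: Dict[str, int] = {"publisher": -1, "checkpointer": -1}
curr_batch_idx = 0
delete_on_read = True


def record(payload: str) -> None:
    store.append((curr_batch_idx, payload))


def get_delta(consumer: str) -> List[str]:
    global store
    from_idx = per_consumer_batch_idx[consumer]
    rows = [p for b, p in store if b >= from_idx]       # only batches unread by this consumer
    per_consumer_batch_idx[consumer] = curr_batch_idx + 1
    if delete_on_read:
        # Drop only what every consumer has already read.
        up_to_idx = min(per_consumer_batch_idx.values())
        store = [(b, p) for b, p in store if b >= up_to_idx]
    return rows


record("table_a ids, batch 0")
print(get_delta("publisher"))     # ['table_a ids, batch 0']
print(get_delta("publisher"))     # [] -- nothing new for this consumer
print(get_delta("checkpointer"))  # ['table_a ids, batch 0'] -- still unread here
```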

0 commit comments
