Skip to content

Commit 975b6ed

Browse files
maliafzalfacebook-github-bot
authored andcommitted
Add logic for fqn_to_feature_names (#3059)
Summary: Pull Request resolved: #3059 # This Diff Added implementation for fqn_to_feature_names method along with initial testing framework and UTs for fqn_to_feature_names # ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training Differential Revision: D75908963
1 parent 3de503f commit 975b6ed

File tree

3 files changed

+592
-7
lines changed

3 files changed

+592
-7
lines changed

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
from typing import Dict, List, Optional, Union
9+
import logging as logger
10+
from collections import Counter, OrderedDict
11+
from typing import Dict, Iterable, List, Optional, Union
1012

1113
import torch
1214

@@ -49,6 +51,8 @@ class ModelDeltaTracker:
4951
call.
5052
delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
5153
mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
54+
fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: None.
55+
5256
"""
5357

5458
DEFAULT_CONSUMER: str = "default"
@@ -59,11 +63,15 @@ def __init__(
5963
consumers: Optional[List[str]] = None,
6064
delete_on_read: bool = True,
6165
mode: TrackingMode = TrackingMode.ID_ONLY,
66+
fqns_to_skip: Iterable[str] = (),
6267
) -> None:
6368
self._model = model
6469
self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
6570
self._delete_on_read = delete_on_read
6671
self._mode = mode
72+
self._fqn_to_feature_map: Dict[str, List[str]] = {}
73+
self._fqns_to_skip: Iterable[str] = fqns_to_skip
74+
self.fqn_to_feature_names()
6775
pass
6876

6977
def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
@@ -85,14 +93,69 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
8593
"""
8694
return {}
8795

88-
def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
96+
def fqn_to_feature_names(self) -> Dict[str, List[str]]:
8997
"""
90-
Returns a mapping from FQN to feature names for a given module.
91-
92-
Args:
93-
module (nn.Module): the module to retrieve feature names for.
98+
Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model.
9499
"""
95-
return {}
100+
if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
101+
return self._fqn_to_feature_map
102+
103+
table_to_feature_names: Dict[str, List[str]] = OrderedDict()
104+
table_to_fqn: Dict[str, str] = OrderedDict()
105+
for fqn, named_module in self._model.named_modules():
106+
split_fqn = fqn.split(".")
107+
108+
should_skip = False
109+
for fqn_to_skip in self._fqns_to_skip:
110+
if fqn_to_skip in split_fqn:
111+
logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
112+
should_skip = True
113+
break
114+
if should_skip:
115+
continue
116+
117+
# Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
118+
if isinstance(named_module, SUPPORTED_MODULES):
119+
for table_name, config in named_module._table_name_to_config.items():
120+
logger.info(
121+
f"Found {table_name} for {fqn} with features {config.feature_names}"
122+
)
123+
table_to_feature_names[table_name] = config.feature_names
124+
for table_name in table_to_feature_names:
125+
# Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
126+
# will incorrectly match fqn with all the table names that have the same prefix
127+
if table_name in split_fqn:
128+
embedding_fqn = fqn.replace("_dmp_wrapped_module.module.", "")
129+
if table_name in table_to_fqn:
130+
# Sanity check for validating that we don't have more then one table mapping to same fqn.
131+
logger.warning(
132+
f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
133+
)
134+
table_to_fqn[table_name] = embedding_fqn
135+
logger.info(f"Table to fqn: {table_to_fqn}")
136+
flatten_names = [
137+
name for names in table_to_feature_names.values() for name in names
138+
]
139+
# TODO: Validate if there is a better way to handle duplicate feature names.
140+
# Logging a warning if duplicate feature names are found across tables, but continue execution as this is allowed in some models
141+
if len(set(flatten_names)) != len(flatten_names):
142+
counts = Counter(flatten_names)
143+
duplicates = [item for item, count in counts.items() if count > 1]
144+
logger.warning(f"duplicate feature names found: {duplicates}")
145+
146+
fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
147+
for table_name in table_to_feature_names:
148+
if table_name not in table_to_fqn:
149+
# This is likely unexpected, where we can't locate the FQN associated with this table.
150+
logger.warning(
151+
f"Table {table_name} not found in {table_to_fqn}, skipping"
152+
)
153+
continue
154+
fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
155+
table_name
156+
]
157+
self._fqn_to_feature_map = fqn_to_feature_names
158+
return fqn_to_feature_names
96159

97160
def clear(self, consumer: Optional[str] = None) -> None:
98161
"""

0 commit comments

Comments
 (0)