
Commit dd20e10

salizade22 authored and facebook-github-bot committed
Adjust _pre_load_state_dict_hook to skip excluded tensors (#3208)
Summary: Pull Request resolved: #3208

For more context, please refer to S539698.

We see KeyErrors when calling `_pre_load_state_dict_hook`. Upon code inspection, we see the following pattern: a dictionary key is manually constructed (https://fburl.com/code/kg6b1y97), and then the code calls `state_dict[key]` directly, without checking whether the key is present in the state dict. With this change we skip keys that are not in the state dict, which avoids the KeyError.

Reviewed By: iamzainhuda

Differential Revision: D78511161

fbshipit-source-id: 2fa5cc6ef04ecfad15ecfec6731ef3ed5579e3ec
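For illustration, here is a minimal, runnable sketch of the failure pattern the summary describes and the guard that fixes it. The table names and state dict contents are hypothetical, not torchrec's actual data:

```python
# Sketch of the pattern: a key is constructed manually, then used to
# index the state dict. The second table's weight is absent, e.g.
# because it was excluded upstream. (Hypothetical example data.)
state_dict = {
    "embedding_bags.table_0.weight": "tensor_0",
    # "embedding_bags.table_1.weight" was excluded.
}

prefix = ""
for table_name in ["table_0", "table_1"]:
    key = f"{prefix}embedding_bags.{table_name}.weight"
    # Before the fix: state_dict[key] raised KeyError for table_1.
    # After the fix: keys missing from the state dict are skipped.
    if key not in state_dict:
        continue
    weight = state_dict[key]
    print(f"loaded {key} -> {weight}")
```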
1 parent 45d5c4d commit dd20e10

File tree

1 file changed (+5, -0 lines)


torchrec/distributed/embeddingbag.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -836,6 +836,11 @@ def _pre_load_state_dict_hook(
                 continue
 
             key = f"{prefix}embedding_bags.{table_name}.weight"
+
+            # If key not in state dict, continue
+            if key not in state_dict:
+                continue
+
             # gather model shards from both DTensor and ShardedTensor maps
             model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[
                 table_name
```
