Skip to content

Commit 64e1581

Browse files
EddyLXJ authored and facebook-github-bot committed
Adding eviction metadata tensor fqn (pytorch#4611)
Summary: X-link: pytorch/torchrec#3247. Pull Request resolved: pytorch#4611. X-link: facebookresearch/FBGEMM#1646. Adds a new metadata FQN to the KVZCH checkpoint, which is needed by the eviction filter during publishing. Differential Revision: D78768842.
1 parent e8d708c commit 64e1581

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2783,6 +2783,7 @@ def split_embedding_weights(
27832783
Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
27842784
Optional[List[torch.Tensor]],
27852785
Optional[List[torch.Tensor]],
2786+
Optional[List[torch.Tensor]],
27862787
]:
27872788
"""
27882789
This method is intended to be used by the checkpointing engine
@@ -2802,6 +2803,7 @@ def split_embedding_weights(
28022803
2nd arg: input id sorted in bucket id ascending order
28032804
3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
28042805
where for the i th element, we have i + bucket_id_start = global bucket id
2806+
4th arg: kvzch eviction metadata for each input id sorted in bucket id ascending order
28052807
"""
28062808
snapshot_handle, checkpoint_handle = self._may_create_snapshot_for_state_dict(
28072809
no_snapshot=no_snapshot,
@@ -2818,16 +2820,19 @@ def split_embedding_weights(
28182820
self._cached_kvzch_data.cached_weight_tensor_per_table,
28192821
self._cached_kvzch_data.cached_id_tensor_per_table,
28202822
self._cached_kvzch_data.cached_bucket_splits,
2823+
[], # metadata tensor is not needed for checkpointing loading
28212824
)
28222825
start_time = time.time()
28232826
pmt_splits = []
28242827
bucket_sorted_id_splits = [] if self.kv_zch_params else None
28252828
active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
2829+
metadata_splits = [] if self.kv_zch_params else None
28262830

28272831
table_offset = 0
28282832
for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
28292833
bucket_ascending_id_tensor = None
28302834
bucket_t = None
2835+
metadata_tensor = None
28312836
row_offset = table_offset
28322837
metaheader_dim = 0
28332838
if self.kv_zch_params:
@@ -2859,6 +2864,12 @@ def split_embedding_weights(
28592864
bucket_size,
28602865
)
28612866
)
2867+
metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
2868+
bucket_ascending_id_tensor,
2869+
torch.as_tensor(bucket_ascending_id_tensor.size(0)),
2870+
snapshot_handle,
2871+
)
2872+
28622873
# 3. convert local id back to global id
28632874
bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
28642875

@@ -2874,11 +2885,17 @@ def split_embedding_weights(
28742885
device=torch.device("cpu"),
28752886
dtype=torch.int64,
28762887
)
2888+
metadata_tensor = torch.zeros(
2889+
(self.local_weight_counts[i], 1),
2890+
device=torch.device("cpu"),
2891+
dtype=torch.int64,
2892+
)
28772893
# self.local_weight_counts[i] = 0 # Reset the count
28782894

28792895
# pyre-ignore [16] bucket_sorted_id_splits is not None
28802896
bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
28812897
active_id_cnt_per_bucket_split.append(bucket_t)
2898+
metadata_splits.append(metadata_tensor)
28822899

28832900
# for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
28842901
# but in backend, local id is used during training, so the KVTensorWrapper need to convert global id to local id
@@ -2934,7 +2951,12 @@ def split_embedding_weights(
29342951
f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
29352952
)
29362953

2937-
return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
2954+
return (
2955+
pmt_splits,
2956+
bucket_sorted_id_splits,
2957+
active_id_cnt_per_bucket_split,
2958+
metadata_splits,
2959+
)
29382960

29392961
@torch.jit.ignore
29402962
def _apply_state_dict_w_offloading(self) -> None:

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def execute_ssd_backward_(
788788
def split_optimizer_states_(
789789
self, emb: SSDTableBatchedEmbeddingBags
790790
) -> List[List[torch.Tensor]]:
791-
_, bucket_asc_ids_list, _ = emb.split_embedding_weights(
791+
_, bucket_asc_ids_list, _, _ = emb.split_embedding_weights(
792792
no_snapshot=False, should_flush=True
793793
)
794794

@@ -1113,7 +1113,7 @@ def test_ssd_emb_state_dict(
11131113
split_optimizer_states = self.split_optimizer_states_(emb)
11141114

11151115
# Compare emb state dict with expected values from nn.EmbeddingBag
1116-
emb_state_dict, _, _ = emb.split_embedding_weights(no_snapshot=False)
1116+
emb_state_dict, _, _, _ = emb.split_embedding_weights(no_snapshot=False)
11171117
for feature_index, table_index in self.get_physical_table_arg_indices_(
11181118
emb.feature_table_map
11191119
):
@@ -1728,9 +1728,12 @@ def test_kv_emb_state_dict(
17281728
split_optimizer_states = []
17291729

17301730
# Compare emb state dict with expected values from nn.EmbeddingBag
1731-
emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
1732-
emb.split_embedding_weights(no_snapshot=False, should_flush=True)
1733-
)
1731+
(
1732+
emb_state_dict_list,
1733+
bucket_asc_ids_list,
1734+
num_active_id_per_bucket_list,
1735+
metadata_list,
1736+
) = emb.split_embedding_weights(no_snapshot=False, should_flush=True)
17341737

17351738
for s in emb.split_optimizer_states(
17361739
bucket_asc_ids_list, no_snapshot=False, should_flush=True
@@ -1797,6 +1800,7 @@ def test_kv_emb_state_dict(
17971800
)
17981801
self.assertLess(table_index, len(emb_state_dict_list))
17991802
assert len(split_optimizer_states[table_index][0]) == num_ids
1803+
assert len(metadata_list[table_index]) == num_ids
18001804
# NOTE: The [0] index is a hack since the test is fixed to use
18011805
# EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
18021806
# be upgraded in the future to support multiple optimizers
@@ -1943,7 +1947,7 @@ def test_kv_opt_state_w_offloading(
19431947
)
19441948

19451949
# Compare emb state dict with expected values from nn.EmbeddingBag
1946-
emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
1950+
emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list, _ = (
19471951
emb.split_embedding_weights(no_snapshot=False, should_flush=True)
19481952
)
19491953
split_optimizer_states = emb.split_optimizer_states(
@@ -2172,7 +2176,7 @@ def test_kv_state_dict_w_backend_return_whole_row(
21722176
)
21732177

21742178
# Compare emb state dict with expected values from nn.EmbeddingBag
2175-
emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
2179+
emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list, _ = (
21762180
emb.split_embedding_weights(no_snapshot=False, should_flush=True)
21772181
)
21782182
split_optimizer_states = emb.split_optimizer_states(
@@ -2440,7 +2444,7 @@ def test_apply_kv_state_dict(
24402444
)
24412445

24422446
# Compare emb state dict with expected values from nn.EmbeddingBag
2443-
emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
2447+
emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list, _ = (
24442448
emb.split_embedding_weights(no_snapshot=False, should_flush=True)
24452449
)
24462450
split_optimizer_states = emb.split_optimizer_states(
@@ -2508,6 +2512,7 @@ def test_apply_kv_state_dict(
25082512
emb_state_dict_list2,
25092513
bucket_asc_ids_list2,
25102514
num_active_id_per_bucket_list2,
2515+
_,
25112516
) = emb2.split_embedding_weights(no_snapshot=False, should_flush=True)
25122517
split_optimizer_states2 = emb2.split_optimizer_states(
25132518
bucket_asc_ids_list2, no_snapshot=False, should_flush=True
@@ -2963,7 +2968,7 @@ def copy_opt_states_hook(
29632968
emb.flush()
29642969

29652970
# Compare emb state dict with expected values from nn.EmbeddingBag
2966-
_emb_state_dict_list, bucket_asc_ids_list, _num_active_id_per_bucket_list = (
2971+
_emb_state_dict_list, bucket_asc_ids_list, _num_active_id_per_bucket_list, _ = (
29672972
emb.split_embedding_weights(no_snapshot=False, should_flush=True)
29682973
)
29692974
assert bucket_asc_ids_list is not None

0 commit comments

Comments
 (0)