@@ -788,7 +788,7 @@ def execute_ssd_backward_(
788
788
def split_optimizer_states_ (
789
789
self , emb : SSDTableBatchedEmbeddingBags
790
790
) -> List [List [torch .Tensor ]]:
791
- _ , bucket_asc_ids_list , _ = emb .split_embedding_weights (
791
+ _ , bucket_asc_ids_list , _ , _ = emb .split_embedding_weights (
792
792
no_snapshot = False , should_flush = True
793
793
)
794
794
@@ -1113,7 +1113,7 @@ def test_ssd_emb_state_dict(
1113
1113
split_optimizer_states = self .split_optimizer_states_ (emb )
1114
1114
1115
1115
# Compare emb state dict with expected values from nn.EmbeddingBag
1116
- emb_state_dict , _ , _ = emb .split_embedding_weights (no_snapshot = False )
1116
+ emb_state_dict , _ , _ , _ = emb .split_embedding_weights (no_snapshot = False )
1117
1117
for feature_index , table_index in self .get_physical_table_arg_indices_ (
1118
1118
emb .feature_table_map
1119
1119
):
@@ -1728,9 +1728,12 @@ def test_kv_emb_state_dict(
1728
1728
split_optimizer_states = []
1729
1729
1730
1730
# Compare emb state dict with expected values from nn.EmbeddingBag
1731
- emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
1732
- emb .split_embedding_weights (no_snapshot = False , should_flush = True )
1733
- )
1731
+ (
1732
+ emb_state_dict_list ,
1733
+ bucket_asc_ids_list ,
1734
+ num_active_id_per_bucket_list ,
1735
+ metadata_list ,
1736
+ ) = emb .split_embedding_weights (no_snapshot = False , should_flush = True )
1734
1737
1735
1738
for s in emb .split_optimizer_states (
1736
1739
bucket_asc_ids_list , no_snapshot = False , should_flush = True
@@ -1797,6 +1800,7 @@ def test_kv_emb_state_dict(
1797
1800
)
1798
1801
self .assertLess (table_index , len (emb_state_dict_list ))
1799
1802
assert len (split_optimizer_states [table_index ][0 ]) == num_ids
1803
+ assert len (metadata_list [table_index ]) == num_ids
1800
1804
# NOTE: The [0] index is a hack since the test is fixed to use
1801
1805
# EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
1802
1806
# be upgraded in the future to support multiple optimizers
@@ -1943,7 +1947,7 @@ def test_kv_opt_state_w_offloading(
1943
1947
)
1944
1948
1945
1949
# Compare emb state dict with expected values from nn.EmbeddingBag
1946
- emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
1950
+ emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list , _ = (
1947
1951
emb .split_embedding_weights (no_snapshot = False , should_flush = True )
1948
1952
)
1949
1953
split_optimizer_states = emb .split_optimizer_states (
@@ -2172,7 +2176,7 @@ def test_kv_state_dict_w_backend_return_whole_row(
2172
2176
)
2173
2177
2174
2178
# Compare emb state dict with expected values from nn.EmbeddingBag
2175
- emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
2179
+ emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list , _ = (
2176
2180
emb .split_embedding_weights (no_snapshot = False , should_flush = True )
2177
2181
)
2178
2182
split_optimizer_states = emb .split_optimizer_states (
@@ -2440,7 +2444,7 @@ def test_apply_kv_state_dict(
2440
2444
)
2441
2445
2442
2446
# Compare emb state dict with expected values from nn.EmbeddingBag
2443
- emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list = (
2447
+ emb_state_dict_list , bucket_asc_ids_list , num_active_id_per_bucket_list , _ = (
2444
2448
emb .split_embedding_weights (no_snapshot = False , should_flush = True )
2445
2449
)
2446
2450
split_optimizer_states = emb .split_optimizer_states (
@@ -2508,6 +2512,7 @@ def test_apply_kv_state_dict(
2508
2512
emb_state_dict_list2 ,
2509
2513
bucket_asc_ids_list2 ,
2510
2514
num_active_id_per_bucket_list2 ,
2515
+ _ ,
2511
2516
) = emb2 .split_embedding_weights (no_snapshot = False , should_flush = True )
2512
2517
split_optimizer_states2 = emb2 .split_optimizer_states (
2513
2518
bucket_asc_ids_list2 , no_snapshot = False , should_flush = True
@@ -2963,7 +2968,7 @@ def copy_opt_states_hook(
2963
2968
emb .flush ()
2964
2969
2965
2970
# Compare emb state dict with expected values from nn.EmbeddingBag
2966
- _emb_state_dict_list , bucket_asc_ids_list , _num_active_id_per_bucket_list = (
2971
+ _emb_state_dict_list , bucket_asc_ids_list , _num_active_id_per_bucket_list , _ = (
2967
2972
emb .split_embedding_weights (no_snapshot = False , should_flush = True )
2968
2973
)
2969
2974
assert bucket_asc_ids_list is not None
0 commit comments