@@ -248,6 +248,12 @@ def __init__(
             self.total_hash_size_bits: int = 0
         else:
             self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
+        self.register_buffer(
+            "table_hash_size_cumsum",
+            torch.tensor(
+                hash_size_cumsum, device=self.current_device, dtype=torch.int64
+            ),
+        )
         # The last element is to easily access # of rows of each table by
         self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
         self.total_hash_size: int = hash_size_cumsum[-1]
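For reference, a small worked example of the arithmetic in this hunk, using hypothetical table sizes (the cumulative list ends with the total hash space, which is exactly what the new table_hash_size_cumsum buffer stores):

    from math import log2

    import torch

    # Hypothetical tables with 100 and 1000 rows; cumulative sums with a
    # leading 0 and the total as the last element (illustration only).
    hash_size_cumsum = [0, 100, 1100]

    total_hash_size = hash_size_cumsum[-1]                        # 1100
    total_hash_size_bits = int(log2(float(total_hash_size)) + 1)  # int(11.10...) == 11

    # The new buffer keeps the same per-table cumulative sizes on the module.
    table_hash_size_cumsum = torch.tensor(hash_size_cumsum, dtype=torch.int64)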
@@ -288,6 +294,10 @@ def __init__(
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
+        self.register_buffer(
+            "table_dims",
+            torch.tensor(dims, device="cpu", dtype=torch.int64),
+        )

         (info_B_num_bits_, info_B_mask_) = torch.ops.fbgemm.get_infos_metadata(
             self.D_offsets,  # unused tensor
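A short sketch of why a per-table buffer is useful next to the per-feature one, assuming the usual TBE convention that feature_table_map assigns each feature to a table and feature_dims repeats the table dim per feature (all values here are hypothetical):

    import torch

    # Hypothetical setup: two tables, three features; features 0 and 1 share table 0.
    dims = [128, 64]                   # embedding dim per *table*
    feature_table_map = [0, 0, 1]
    feature_dims = [dims[t] for t in feature_table_map]  # [128, 128, 64], per *feature*

    # Per-feature buffer (registered above) vs. the new per-table buffer.
    feature_dims_t = torch.tensor(feature_dims, device="cpu", dtype=torch.int64)
    table_dims_t = torch.tensor(dims, device="cpu", dtype=torch.int64)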
@@ -518,6 +528,7 @@ def __init__(
             logging.warning("dist is not initialized, treating as single gpu cases")
             tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
         self.tbe_unique_id = tbe_unique_id
+        self.l2_cache_size = l2_cache_size
         logging.info(f"tbe_unique_id: {tbe_unique_id}")
         if self.backend_type == BackendType.SSD:
             logging.info(
@@ -564,12 +575,12 @@ def __init__(
                 self.res_params.table_offsets,
                 self.res_params.table_sizes,
                 (
-                    tensor_pad4(self.feature_dims.cpu())
+                    tensor_pad4(self.table_dims)
                     if self.enable_optimizer_offloading
                     else None
                 ),
                 (
-                    self.hash_size_cumsum.cpu()
+                    self.table_hash_size_cumsum.cpu()
                     if self.enable_optimizer_offloading
                     else None
                 ),
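The change above feeds per-table dims rather than per-feature dims into tensor_pad4. As a hedged sketch only, assuming tensor_pad4 rounds each dim up to the next multiple of 4 (the padding its name implies), a hypothetical stand-in would behave like this:

    import torch

    def tensor_pad4_sketch(dims: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in: round each element up to the next multiple of 4.
        return (dims + 3) // 4 * 4

    table_dims = torch.tensor([62, 64, 130], dtype=torch.int64)
    print(tensor_pad4_sketch(table_dims))  # tensor([ 64,  64, 132])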
@@ -607,74 +618,35 @@ def __init__(
                 f"self.cache_row_dim={self.cache_row_dim},"
                 f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
                 f"feature_dims={self.feature_dims},"
-                f"hash_size_cumsum={self.hash_size_cumsum}, "
-                f"eviction_policy={self.kv_zch_params.eviction_policy}, "
-            )
-            # prepare eviction policy parameters
-            counter_eviction_threshold_tensor = None
-            ttls_in_mins_tensor = None
-            counter_decay_rates_tensor = None
-            l2_weight_thresholds_tensor = None
-            if self.kv_zch_params.eviction_policy.eviction_trigger_mode != 0:
-                counter_eviction_threshold = [
-                    self.kv_zch_params.eviction_policy.counter_thresholds[t]
-                    for t in self.feature_table_map
-                ]
-                counter_eviction_threshold_tensor = torch.tensor(
-                    counter_eviction_threshold,
-                    device=torch.device("cpu"),
-                    dtype=torch.uint32,
-                )
-                ttls_in_mins = [
-                    self.kv_zch_params.eviction_policy.ttls_in_mins[t]
-                    for t in self.feature_table_map
-                ]
-                ttls_in_mins_tensor = torch.tensor(
-                    ttls_in_mins,
-                    device=torch.device("cpu"),
-                    dtype=torch.uint32,
-                )
-                counter_decay_rates = [
-                    self.kv_zch_params.eviction_policy.counter_decay_rates[t]
-                    for t in self.feature_table_map
-                ]
-                counter_decay_rates_tensor = torch.tensor(
-                    counter_decay_rates,
-                    device=torch.device("cpu"),
-                    dtype=torch.float32,
-                )
-                l2_weight_thresholds = [
-                    self.kv_zch_params.eviction_policy.l2_weight_thresholds[t]
-                    for t in self.feature_table_map
-                ]
-                l2_weight_thresholds_tensor = torch.tensor(
-                    l2_weight_thresholds,
-                    device=torch.device("cpu"),
-                    dtype=torch.float32,
-                )
-
+                f"hash_size_cumsum={self.hash_size_cumsum} "
+            )
+            table_dims = (
+                tensor_pad4(self.table_dims)
+                if self.enable_optimizer_offloading
+                else None
+            )  # table_dims
+            eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
+                self.kv_zch_params.eviction_policy.eviction_trigger_mode,  # eviction trigger mode, 0: disabled, 1: iteration, 2: mem_util, 3: manual
+                self.kv_zch_params.eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
+                self.kv_zch_params.eviction_policy.eviction_step_intervals,  # trigger_step_interval if trigger mode is iteration
+                self.l2_cache_size,  # mem_util_threshold_in_GB if trigger mode is mem_util
+                self.kv_zch_params.eviction_policy.ttls_in_mins,  # ttls_in_mins for each table if eviction strategy is timestamp
+                self.kv_zch_params.eviction_policy.counter_thresholds,  # counter_thresholds for each table if eviction strategy is feature score
+                self.kv_zch_params.eviction_policy.counter_decay_rates,  # counter_decay_rates for each table if eviction strategy is feature score
+                self.kv_zch_params.eviction_policy.l2_weight_thresholds,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
+                table_dims.tolist() if table_dims is not None else None,
+            )
             self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
                 self.cache_row_dim,
                 ssd_uniform_init_lower,
                 ssd_uniform_init_upper,
-                self.kv_zch_params.eviction_policy.eviction_trigger_mode,  # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
-                self.kv_zch_params.eviction_policy.eviction_step_intervals,  # trigger_step_interval if trigger mode is iteration
-                l2_cache_size,  # mem_util_threshold_in_GB if trigger mode is mem_util
-                self.kv_zch_params.eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
-                counter_eviction_threshold_tensor,  # counter_thresholds for each table if eviction strategy is feature score
-                ttls_in_mins_tensor,  # ttls_in_mins for each table if eviction strategy is timestamp
-                counter_decay_rates_tensor,  # counter_decay_rates for each table if eviction strategy is feature score
-                l2_weight_thresholds_tensor,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
+                eviction_config,
                 ssd_rocksdb_shards,  # num_shards
                 ssd_rocksdb_shards,  # num_threads
                 weights_precision.bit_rate(),  # row_storage_bitwidth
+                table_dims,
                 (
-                    tensor_pad4(self.feature_dims.cpu())
-                    if self.enable_optimizer_offloading
-                    else None
-                ),  # table_dims
-                (
-                    self.hash_size_cumsum.cpu()
+                    self.table_hash_size_cumsum.cpu()
                     if self.enable_optimizer_offloading
                     else None
                 ),  # hash_size_cumsum
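The inline comments in the hunk above document the positional arguments of torch.classes.fbgemm.FeatureEvictConfig. As a hedged sketch only (hypothetical two-table values, assuming the fbgemm TorchScript extension is loaded; whether the strategy-specific lists may be passed as None is an assumption, not something this diff confirms), an iteration-triggered, counter-based config could look like:

    import torch

    # Argument order follows the comments in the diff; all values are hypothetical.
    eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
        1,             # eviction_trigger_mode: 1 = trigger on iteration count
        1,             # eviction_strategy: 1 = counter (feature score)
        1000,          # trigger_step_interval (iteration mode)
        0,             # mem_util_threshold_in_GB (unused outside mem_util mode)
        None,          # ttls_in_mins per table (timestamp strategy only)
        [10, 10],      # counter_thresholds per table
        [0.99, 0.99],  # counter_decay_rates per table
        None,          # l2_weight_thresholds per table (l2-norm strategy only)
        [128, 64],     # padded per-table dims, or None
    )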
@@ -2478,6 +2450,13 @@ def _may_create_snapshot_for_state_dict(
                 f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
             )
         elif self.backend_type == BackendType.DRAM:
+            # if there is any ongoing eviction, let's wait until eviction is finished before state_dict
+            # so that we can reach a consistent model state before/after state_dict
+            evict_wait_start_time = time.time()
+            self.ssd_db.wait_until_eviction_done()
+            logging.info(
+                f"state_dict wait for ongoing eviction: {time.time() - evict_wait_start_time} s"
+            )
             self.flush(force=should_flush)
         return snapshot_handle, checkpoint_handle