Commit 576f505
patch fixes for eviction (#4304)
Summary:
Pull Request resolved: #4304
X-link: facebookresearch/FBGEMM#1380
Differential Revision: D76244371
1 parent abb5272 commit 576f505

12 files changed: +980 −341 lines

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 4 additions & 4 deletions
@@ -60,16 +60,16 @@ class EvictionPolicy(NamedTuple):
         0  # trigger_step_interval if trigger mode is iteration
     )
     counter_thresholds: Optional[List[int]] = (
-        None  # count_thresholds for each feature if eviction strategy is feature score
+        None  # count_thresholds for each table if eviction strategy is feature score
     )
     ttls_in_mins: Optional[List[int]] = (
-        None  # ttls_in_mins for each feature if eviction strategy is timestamp
+        None  # ttls_in_mins for each table if eviction strategy is timestamp
     )
     counter_decay_rates: Optional[List[float]] = (
-        None  # count_decay_rates for each feature if eviction strategy is feature score
+        None  # count_decay_rates for each table if eviction strategy is feature score
     )
     l2_weight_thresholds: Optional[List[float]] = (
-        None  # l2_weight_thresholds for each feature if eviction strategy is feature l2 norm
+        None  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
     )
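
These lists are now interpreted per table rather than per feature. As a minimal sketch of how a caller might populate them, assuming two embedding tables, the field names shown above, and this module's import path (defaults and any fields not visible in this diff are assumptions, not taken from the commit):

# Hypothetical illustration: two tables -> one entry per table in each list.
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EvictionPolicy

eviction_policy = EvictionPolicy(
    eviction_trigger_mode=1,           # 1: trigger eviction by iteration count
    eviction_strategy=1,               # 1: counter (feature score) based eviction
    eviction_step_intervals=1000,      # run eviction every 1000 iterations
    counter_thresholds=[3, 5],         # per-table count threshold
    counter_decay_rates=[0.95, 0.99],  # per-table counter decay rate
)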

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 41 additions & 62 deletions
@@ -248,6 +248,12 @@ def __init__(
             self.total_hash_size_bits: int = 0
         else:
             self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
+            self.register_buffer(
+                "table_hash_size_cumsum",
+                torch.tensor(
+                    hash_size_cumsum, device=self.current_device, dtype=torch.int64
+                ),
+            )
         # The last element is to easily access # of rows of each table by
         self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
         self.total_hash_size: int = hash_size_cumsum[-1]
@@ -288,6 +294,10 @@ def __init__(
             "feature_dims",
             torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
         )
+        self.register_buffer(
+            "table_dims",
+            torch.tensor(dims, device="cpu", dtype=torch.int64),
+        )

         (info_B_num_bits_, info_B_mask_) = torch.ops.fbgemm.get_infos_metadata(
             self.D_offsets,  # unused tensor
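
The new table_dims and table_hash_size_cumsum buffers are indexed per table, while the existing feature_dims buffer stays per feature (one table can back several features through the feature-to-table map). A self-contained sketch of that relationship, using made-up values rather than anything from this commit:

import torch

# Hypothetical setup: 2 tables, 3 features; features 0 and 1 share table 0.
dims = [64, 128]               # per-table embedding dims
feature_table_map = [0, 0, 1]  # feature index -> table index

# Per-table buffer, analogous to the new "table_dims" buffer.
table_dims = torch.tensor(dims, dtype=torch.int64)

# Per-feature view, analogous to "feature_dims": expand through the map.
feature_dims = torch.tensor([dims[t] for t in feature_table_map], dtype=torch.int64)

print(table_dims.tolist())    # [64, 128]
print(feature_dims.tolist())  # [64, 64, 128]
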
@@ -518,6 +528,7 @@ def __init__(
             logging.warning("dist is not initialized, treating as single gpu cases")
             tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
         self.tbe_unique_id = tbe_unique_id
+        self.l2_cache_size = l2_cache_size
         logging.info(f"tbe_unique_id: {tbe_unique_id}")
         if self.backend_type == BackendType.SSD:
             logging.info(
@@ -564,12 +575,12 @@ def __init__(
                 self.res_params.table_offsets,
                 self.res_params.table_sizes,
                 (
-                    tensor_pad4(self.feature_dims.cpu())
+                    tensor_pad4(self.table_dims)
                     if self.enable_optimizer_offloading
                     else None
                 ),
                 (
-                    self.hash_size_cumsum.cpu()
+                    self.table_hash_size_cumsum.cpu()
                     if self.enable_optimizer_offloading
                     else None
                 ),
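
tensor_pad4 is applied only when optimizer offloading is enabled; judging by its name it rounds each per-table dim up to the next multiple of 4. A minimal stand-in under that assumption (an illustration, not FBGEMM's implementation):

import torch

def pad4(dims: torch.Tensor) -> torch.Tensor:
    # Round every element up to the next multiple of 4, assuming that is
    # what the real tensor_pad4 helper does for offloaded row layouts.
    return (dims + 3) // 4 * 4

print(pad4(torch.tensor([60, 128, 130], dtype=torch.int64)).tolist())  # [60, 128, 132]
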
@@ -607,74 +618,35 @@ def __init__(
                 f"self.cache_row_dim={self.cache_row_dim},"
                 f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
                 f"feature_dims={self.feature_dims},"
-                f"hash_size_cumsum={self.hash_size_cumsum}, "
-                f"eviction_policy={self.kv_zch_params.eviction_policy}, "
-            )
-            # prepare eviction policy parameters
-            counter_eviction_threshold_tensor = None
-            ttls_in_mins_tensor = None
-            counter_decay_rates_tensor = None
-            l2_weight_thresholds_tensor = None
-            if self.kv_zch_params.eviction_policy.eviction_trigger_mode != 0:
-                counter_eviction_threshold = [
-                    self.kv_zch_params.eviction_policy.counter_thresholds[t]
-                    for t in self.feature_table_map
-                ]
-                counter_eviction_threshold_tensor = torch.tensor(
-                    counter_eviction_threshold,
-                    device=torch.device("cpu"),
-                    dtype=torch.uint32,
-                )
-                ttls_in_mins = [
-                    self.kv_zch_params.eviction_policy.ttls_in_mins[t]
-                    for t in self.feature_table_map
-                ]
-                ttls_in_mins_tensor = torch.tensor(
-                    ttls_in_mins,
-                    device=torch.device("cpu"),
-                    dtype=torch.uint32,
-                )
-                counter_decay_rates = [
-                    self.kv_zch_params.eviction_policy.counter_decay_rates[t]
-                    for t in self.feature_table_map
-                ]
-                counter_decay_rates_tensor = torch.tensor(
-                    counter_decay_rates,
-                    device=torch.device("cpu"),
-                    dtype=torch.float32,
-                )
-                l2_weight_thresholds = [
-                    self.kv_zch_params.eviction_policy.l2_weight_thresholds[t]
-                    for t in self.feature_table_map
-                ]
-                l2_weight_thresholds_tensor = torch.tensor(
-                    l2_weight_thresholds,
-                    device=torch.device("cpu"),
-                    dtype=torch.float32,
-                )
-
+                f"hash_size_cumsum={self.hash_size_cumsum}"
+            )
+            table_dims = (
+                tensor_pad4(self.table_dims)
+                if self.enable_optimizer_offloading
+                else None
+            )  # table_dims
+            eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
+                self.kv_zch_params.eviction_policy.eviction_trigger_mode,  # eviction trigger mode, 0: disabled, 1: iteration, 2: mem_util, 3: manual
+                self.kv_zch_params.eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
+                self.kv_zch_params.eviction_policy.eviction_step_intervals,  # trigger_step_interval if trigger mode is iteration
+                self.l2_cache_size,  # mem_util_threshold_in_GB if trigger mode is mem_util
+                self.kv_zch_params.eviction_policy.ttls_in_mins,  # ttls_in_mins for each table if eviction strategy is timestamp
+                self.kv_zch_params.eviction_policy.counter_thresholds,  # counter_thresholds for each table if eviction strategy is feature score
+                self.kv_zch_params.eviction_policy.counter_decay_rates,  # counter_decay_rates for each table if eviction strategy is feature score
+                self.kv_zch_params.eviction_policy.l2_weight_thresholds,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
+                table_dims.tolist() if table_dims else None,
+            )
             self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
                 self.cache_row_dim,
                 ssd_uniform_init_lower,
                 ssd_uniform_init_upper,
-                self.kv_zch_params.eviction_policy.eviction_trigger_mode,  # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
-                self.kv_zch_params.eviction_policy.eviction_step_intervals,  # trigger_step_interval if trigger mode is iteration
-                l2_cache_size,  # mem_util_threshold_in_GB if trigger mode is mem_util
-                self.kv_zch_params.eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
-                counter_eviction_threshold_tensor,  # counter_thresholds for each table if eviction strategy is feature score
-                ttls_in_mins_tensor,  # ttls_in_mins for each table if eviction strategy is timestamp
-                counter_decay_rates_tensor,  # counter_decay_rates for each table if eviction strategy is feature score
-                l2_weight_thresholds_tensor,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
+                eviction_config,
                 ssd_rocksdb_shards,  # num_shards
                 ssd_rocksdb_shards,  # num_threads
                 weights_precision.bit_rate(),  # row_storage_bitwidth
+                table_dims,
                 (
-                    tensor_pad4(self.feature_dims.cpu())
-                    if self.enable_optimizer_offloading
-                    else None
-                ),  # table_dims
-                (
-                    self.hash_size_cumsum.cpu()
+                    self.table_hash_size_cumsum.cpu()
                     if self.enable_optimizer_offloading
                     else None
                 ),  # hash_size_cumsum
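
With the eviction parameters consolidated into one config object and read per table, the caller no longer expands them per feature through feature_table_map. A hypothetical validation sketch (the helper and its names are illustrative, not part of this commit) that checks each per-table list against the number of tables before building such a config:

from typing import Optional

def check_per_table_lengths(num_tables: int, **per_table_lists: Optional[list]) -> None:
    # Each eviction list (counter_thresholds, ttls_in_mins, ...) should now hold
    # one entry per table; reject per-feature sized lists early.
    for name, values in per_table_lists.items():
        if values is not None and len(values) != num_tables:
            raise ValueError(
                f"{name} has {len(values)} entries, expected {num_tables} (one per table)"
            )

# Example: 2 tables, 3 features.
check_per_table_lengths(2, counter_thresholds=[3, 5], counter_decay_rates=[0.95, 0.99])
try:
    check_per_table_lengths(2, ttls_in_mins=[60, 120, 240])  # per-feature sized list
except ValueError as err:
    print(err)
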
@@ -2478,6 +2450,13 @@ def _may_create_snapshot_for_state_dict(
                 f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
             )
         elif self.backend_type == BackendType.DRAM:
+            # If there is any ongoing eviction, wait until it finishes before state_dict
+            # so that we reach a consistent model state before/after state_dict.
+            evict_wait_start_time = time.time()
+            self.ssd_db.wait_until_eviction_done()
+            logging.info(
+                f"state_dict wait for ongoing eviction: {time.time() - evict_wait_start_time} s"
+            )
         self.flush(force=should_flush)
         return snapshot_handle, checkpoint_handle
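
The intent of this DRAM branch is an ordering guarantee: drain any in-flight eviction, then flush, then snapshot, so the checkpoint reflects a stable table state. A stripped-down sketch of that ordering (wait_until_eviction_done is the call shown above; the backend object and its other method names are stand-ins, not FBGEMM's API):

import logging
import time

def checkpoint_with_eviction_barrier(backend) -> None:
    # 1) Drain any eviction that is still running so keys and weights stop changing.
    wait_start = time.time()
    backend.wait_until_eviction_done()
    logging.info("waited %.3f s for ongoing eviction", time.time() - wait_start)

    # 2) Flush cached rows to the backing store.
    backend.flush()

    # 3) Only then take the snapshot used for state_dict.
    backend.create_snapshot()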

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_base.h

Lines changed: 0 additions & 51 deletions
This file was deleted.
