Skip to content

Commit 0b34b6d

Browse files
Raahul Kalyaan Jakka authored and facebook-github-bot committed
Adding E2E unit tests for KVTensorMetaData class (#4298)
Summary: Pull Request resolved: #4298. X-link: facebookresearch/FBGEMM#1372. Context: in the Publish component, we have aligned on not using the conventional serialization and deserialization; instead, we need to create a KVTensorMetaData object to pass data to the Publish component. In this diff we add a unit test for KVTensorMetaData consistency, covering: (a) serialization of KVT data, (b) construction of the KVTensorMetaData object, (c) creation of the ReadOnlyEmbeddingKVDB object, and (d) narrow() data consistency between the PMT and KVTensorMetaData paths. Reviewed By: duduyi2013. Differential Revision: D76234751
1 parent 6041d30 commit 0b34b6d

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

fbgemm_gpu/test/tbe/ssd/kv_backend_test.py

Lines changed: 76 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -617,3 +617,79 @@ def test_rocksdb_se_de_testing(
617617
t1 = pmt.wrapped.narrow(0, 0, Es[i])
618618
t2 = lo.wrapped.narrow(0, 0, Es[i])
619619
assert torch.equal(t1, t2)
620+
621+
@given(
    T=st.integers(min_value=3, max_value=3),
    D=st.integers(min_value=1, max_value=1),
    log_E=st.integers(min_value=1, max_value=1),
    mixed=st.booleans(),
    weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
)
@settings(**default_settings)
def test_rocksdb_kv_metadata_testing(
    self,
    T: int,
    D: int,
    log_E: int,
    mixed: bool,
    weights_precision: SparseType,
) -> None:
    """End-to-end consistency test for KVTensorMetaData.

    Writes known rows into an SSD-backed TBE, takes a rocksdb hard-link
    checkpoint, then for each KV-backed partially materialized tensor (PMT):
    builds a ReadOnlyEmbeddingKVDB from the PMT's KVTensorMetaData, reads the
    table's row range back from the checkpoint, and asserts the data matches
    what the PMT itself returns via narrow().

    Args:
        T: number of tables (pinned to 3 by the strategy above).
        D: embedding-dimension multiplier (pinned to 1).
        log_E: log of rows per table (pinned to 1).
        mixed: whether tables get mixed dimensions.
        weights_precision: FP32 or FP16 row storage.
    """
    # Generate a TBE with T tables; Es/Ds are per-table row counts and dims.
    emb, Es, Ds = self.generate_fbgemm_kv_tbe(T, D, log_E, weights_precision, mixed)

    total_E = sum(Es)
    # One row id per slot, in order. (The original draft first built a random
    # permutation with np.random.choice and immediately overwrote it with
    # torch.arange; that dead assignment is removed here.)
    indices = torch.arange(total_E, dtype=torch.int64)

    weights = torch.randn(
        total_E, emb.cache_row_dim, dtype=weights_precision.as_dtype()
    )
    count = torch.as_tensor([total_E])

    # Write the rows into the TBE backend and wait for the async fill to land.
    emb.ssd_db.set(indices, weights, count)
    emb.ssd_db.wait_util_filling_work_done()

    # Flush data from the TBE cache to SSD so the checkpoint sees everything.
    emb.ssd_db.flush()

    # Create a hard_link_snapshot (i.e., a rocksdb checkpoint) and fetch the
    # partially materialized tensors backed by it.
    emb.ssd_db.create_rocksdb_hard_link_snapshot(0)
    pmts = emb.split_embedding_weights(no_snapshot=False)

    # For each PMT: construct a ReadOnlyEmbeddingKVDB from its metadata —
    # the multi-process-safe alternative to pickle round-tripping — and
    # verify the checkpoint contents against the original weights.
    for i, pmt in enumerate(pmts[0]):
        if type(pmt) is torch.Tensor:
            # Plain dense tensor (not KV-backed): nothing to verify here.
            continue
        # NOTE(review): if generate_kvtensor_metadata is a method rather than
        # a property, this needs call parentheses — confirm against the PMT
        # API; as written it would bind the method object instead.
        kv_metadata = pmt.generate_kvtensor_metadata

        readonly_rdb = torch.classes.fbgemm.ReadOnlyEmbeddingKVDB(
            kv_metadata.checkpoint_paths,
            kv_metadata.tbe_uuid,
            kv_metadata.rdb_num_shards,
            kv_metadata.rdb_num_threads,
            kv_metadata.max_D,
        )

        # dtype code 5 denotes fp16 in the serialized metadata; anything
        # else in this test is fp32 (strategy only samples FP32/FP16).
        d_type = torch.float16 if kv_metadata.dtype == 5 else torch.float32

        # Read this table's Es[i] rows (at table_offset) from the checkpoint.
        t = torch.empty(Es[i], kv_metadata.max_D, dtype=d_type)
        readonly_rdb.get_range_from_rdb_checkpoint(
            t, kv_metadata.table_offset, Es[i], 0  # start row, count, offset
        )
        # Trim padding columns down to the table's true width, then compare
        # against the same rows read through the PMT.
        t1 = t.narrow(1, 0, kv_metadata.table_shape[1])
        t2 = pmt.wrapped.narrow(0, 0, Es[i])

        assert torch.equal(t1, t2)

0 commit comments

Comments
 (0)