Skip to content

Commit fbf7b9b

Browse files
raahul46facebook-github-bot
authored andcommitted
Temporary Commit at 6/8/2025, 3:25:18 PM
Differential Revision: D76234752
1 parent d473af0 commit fbf7b9b

File tree

3 files changed

+50
-3
lines changed

3 files changed

+50
-3
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
117117

118118
void deserialize(const std::string& serialized);
119119

120+
std::vector<std::string> get_kvtensor_serializable_metadata() const;
121+
120122
friend void to_json(json& j, const KVTensorWrapper& kvt);
121123
friend void from_json(const json& j, KVTensorWrapper& kvt);
122124

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,29 @@ std::string KVTensorWrapper::serialize() const {
391391
return json_serialized.dump();
392392
}
393393

394+
std::vector<std::string> KVTensorWrapper::get_kvtensor_serializable_metadata()
395+
const {
396+
std::vector<std::string> metadata;
397+
auto* db = dynamic_cast<EmbeddingRocksDB*>(db_.get());
398+
auto checkpoint_paths = db->get_checkpoints(checkpoint_handle_->uuid);
399+
metadata.push_back(std::to_string(checkpoint_paths.size()));
400+
for (const auto& path : checkpoint_paths) {
401+
metadata.push_back(path);
402+
}
403+
metadata.push_back(db->get_tbe_uuid());
404+
metadata.push_back(std::to_string(db->num_shards()));
405+
metadata.push_back(std::to_string(db->num_threads()));
406+
metadata.push_back(std::to_string(db->get_max_D()));
407+
metadata.push_back(std::to_string(row_offset_));
408+
CHECK_EQ(shape_.size(), 2);
409+
metadata.push_back(std::to_string(shape_[0]));
410+
metadata.push_back(std::to_string(shape_[1]));
411+
metadata.push_back(
412+
std::to_string(static_cast<int64_t>(options_.dtype().toScalarType())));
413+
metadata.push_back(checkpoint_handle_->uuid);
414+
return metadata;
415+
}
416+
394417
std::string KVTensorWrapper::logs() const {
395418
std::stringstream ss;
396419
if (db_) {
@@ -871,6 +894,26 @@ static auto dram_kv_embedding_cache_wrapper =
871894
.def(
872895
"get_feature_evict_metric",
873896
&DramKVEmbeddingCacheWrapper::get_feature_evict_metric);
897+
static auto embedding_rocks_db_read_only_wrapper =
898+
torch::class_<ReadOnlyEmbeddingKVDB>("fbgemm", "ReadOnlyEmbeddingKVDB")
899+
.def(
900+
torch::init<
901+
std::vector<std::string>,
902+
std::string,
903+
int64_t,
904+
int64_t,
905+
int64_t,
906+
int64_t>(),
907+
"",
908+
{torch::arg("rdb_shard_checkpoint_paths"),
909+
torch::arg("tbe_uuid"),
910+
torch::arg("num_shards"),
911+
torch::arg("num_threads"),
912+
torch::arg("max_D"),
913+
torch::arg("cache_size") = 0})
914+
.def(
915+
"get_range_from_rdb_checkpoint",
916+
&ReadOnlyEmbeddingKVDB::get_range_from_rdb_checkpoint);
874917

875918
static auto kv_tensor_wrapper =
876919
torch::class_<KVTensorWrapper>("fbgemm", "KVTensorWrapper")
@@ -931,7 +974,10 @@ static auto kv_tensor_wrapper =
931974
[](std::string data) -> c10::intrusive_ptr<KVTensorWrapper> {
932975
return c10::make_intrusive<KVTensorWrapper>(data);
933976
})
934-
.def("logs", &KVTensorWrapper::logs, "");
977+
.def("logs", &KVTensorWrapper::logs, "")
978+
.def(
979+
"get_kvtensor_serializable_metadata",
980+
&KVTensorWrapper::get_kvtensor_serializable_metadata);
935981

936982
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
937983
m.def(

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,8 +1286,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
12861286
///
12871287
/// @brief An implementation of ReadOnlyEmbeddingKVDB for RocksDB
12881288
///
1289-
class ReadOnlyEmbeddingKVDB
1290-
: public std::enable_shared_from_this<ReadOnlyEmbeddingKVDB> {
1289+
class ReadOnlyEmbeddingKVDB : public torch::jit::CustomClassHolder {
12911290
public:
12921291
explicit ReadOnlyEmbeddingKVDB(
12931292
const std::vector<std::string>& rdb_shard_checkpoint_paths,

0 commit comments

Comments
 (0)