@@ -391,6 +391,29 @@ std::string KVTensorWrapper::serialize() const {
391
391
return json_serialized.dump ();
392
392
}
393
393
394
+ std::vector<std::string> KVTensorWrapper::get_kvtensor_serializable_metadata ()
395
+ const {
396
+ std::vector<std::string> metadata;
397
+ auto * db = dynamic_cast <EmbeddingRocksDB*>(db_.get ());
398
+ auto checkpoint_paths = db->get_checkpoints (checkpoint_handle_->uuid );
399
+ metadata.push_back (std::to_string (checkpoint_paths.size ()));
400
+ for (const auto & path : checkpoint_paths) {
401
+ metadata.push_back (path);
402
+ }
403
+ metadata.push_back (db->get_tbe_uuid ());
404
+ metadata.push_back (std::to_string (db->num_shards ()));
405
+ metadata.push_back (std::to_string (db->num_threads ()));
406
+ metadata.push_back (std::to_string (db->get_max_D ()));
407
+ metadata.push_back (std::to_string (row_offset_));
408
+ CHECK_EQ (shape_.size (), 2 );
409
+ metadata.push_back (std::to_string (shape_[0 ]));
410
+ metadata.push_back (std::to_string (shape_[1 ]));
411
+ metadata.push_back (
412
+ std::to_string (static_cast <int64_t >(options_.dtype ().toScalarType ())));
413
+ metadata.push_back (checkpoint_handle_->uuid );
414
+ return metadata;
415
+ }
416
+
394
417
std::string KVTensorWrapper::logs () const {
395
418
std::stringstream ss;
396
419
if (db_) {
@@ -871,6 +894,26 @@ static auto dram_kv_embedding_cache_wrapper =
871
894
.def(
872
895
" get_feature_evict_metric" ,
873
896
&DramKVEmbeddingCacheWrapper::get_feature_evict_metric);
897
+ static auto embedding_rocks_db_read_only_wrapper =
898
+ torch::class_<ReadOnlyEmbeddingKVDB>(" fbgemm" , " ReadOnlyEmbeddingKVDB" )
899
+ .def(
900
+ torch::init<
901
+ std::vector<std::string>,
902
+ std::string,
903
+ int64_t ,
904
+ int64_t ,
905
+ int64_t ,
906
+ int64_t >(),
907
+ " " ,
908
+ {torch::arg (" rdb_shard_checkpoint_paths" ),
909
+ torch::arg (" tbe_uuid" ),
910
+ torch::arg (" num_shards" ),
911
+ torch::arg (" num_threads" ),
912
+ torch::arg (" max_D" ),
913
+ torch::arg (" cache_size" ) = 0 })
914
+ .def(
915
+ " get_range_from_rdb_checkpoint" ,
916
+ &ReadOnlyEmbeddingKVDB::get_range_from_rdb_checkpoint);
874
917
875
918
static auto kv_tensor_wrapper =
876
919
torch::class_<KVTensorWrapper>(" fbgemm" , " KVTensorWrapper" )
@@ -931,7 +974,10 @@ static auto kv_tensor_wrapper =
931
974
[](std::string data) -> c10::intrusive_ptr<KVTensorWrapper> {
932
975
return c10::make_intrusive<KVTensorWrapper>(data);
933
976
})
934
- .def(" logs" , &KVTensorWrapper::logs, " " );
977
+ .def(" logs" , &KVTensorWrapper::logs, " " )
978
+ .def(
979
+ " get_kvtensor_serializable_metadata" ,
980
+ &KVTensorWrapper::get_kvtensor_serializable_metadata);
935
981
936
982
TORCH_LIBRARY_FRAGMENT (fbgemm, m) {
937
983
m.def (
0 commit comments