Skip to content

Commit b61968e

Browse files
Raahul Kalyaan Jakkafacebook-github-bot
authored andcommitted
Adding KVTensorMetaData class (#4297)
Summary: Pull Request resolved: #4297 X-link: facebookresearch/FBGEMM#1373 Context: In the Publish Component, we have aligned to not use the conventional serialization and deserialization. We need to create a KVTensorMetaData object to pass data to the publish component In this Diff: 1. Adding KVTensorMetaData class 2. Adding a generate_kvtensor_metadata function to PartiallyMaterializedTensor class. We will use this to create a KVTensorMeta object given a PMT Object Reviewed By: duduyi2013 Differential Revision: D76234753
1 parent fbf7b9b commit b61968e

File tree

1 file changed

+106
-1
lines changed

1 file changed

+106
-1
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from __future__ import annotations
1010

1111
import functools
12-
from typing import Optional, Union
12+
import logging
13+
from typing import List, Optional, Union
1314

1415
import torch
1516

@@ -25,6 +26,58 @@ def decorator(func):
2526
return decorator
2627

2728

29+
class KVTensorMetadata:
30+
"""
31+
Class that is used to represent a KVTensor as a Serialized Metadata in python
32+
This object is used to reconstruct the KVTensor in the publish component
33+
"""
34+
35+
checkpoint_paths: List[str]
36+
tbe_uuid: str
37+
rdb_num_shards: int
38+
rdb_num_threads: int
39+
max_D: int
40+
table_offset: int
41+
table_shape: List[int]
42+
dtype: int
43+
checkpoint_uuid: str
44+
45+
def __init__(
46+
self,
47+
checkpoint_paths: List[str],
48+
tbe_uuid: str,
49+
rdb_num_shards: int,
50+
rdb_num_threads: int,
51+
max_D: int,
52+
table_offset: int,
53+
table_shape: List[int],
54+
dtype: int,
55+
checkpoint_uuid: str,
56+
) -> None:
57+
"""
58+
Ensure caller loads the module before creating this object.
59+
60+
```
61+
load_torch_module(
62+
"//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
63+
)
64+
```
65+
66+
Args:
67+
68+
wrapped: torch.classes.fbgemm.KVTensorWrapper
69+
"""
70+
self.checkpoint_paths = checkpoint_paths
71+
self.tbe_uuid = tbe_uuid
72+
self.rdb_num_shards = rdb_num_shards
73+
self.rdb_num_threads = rdb_num_threads
74+
self.max_D = max_D
75+
self.table_offset = table_offset
76+
self.table_shape = table_shape
77+
self.checkpoint_uuid = checkpoint_uuid
78+
self.dtype = dtype
79+
80+
2881
class PartiallyMaterializedTensor:
2982
"""
3083
A tensor-like object that represents a partially materialized tensor in memory.
@@ -51,6 +104,55 @@ def __init__(self, wrapped, is_virtual: bool = False) -> None:
51104
self._is_virtual = is_virtual
52105
self._requires_grad = False
53106

107+
@property
108+
def generate_kvtensor_metadata(self) -> KVTensorMetadata:
109+
serialized_metadata = self.wrapped.get_kvtensor_serializable_metadata()
110+
try:
111+
metadata_itr = 0
112+
num_rdb_ckpts = int(serialized_metadata[0])
113+
metadata_itr += 1
114+
checkpoint_paths: List[str] = []
115+
for i in range(num_rdb_ckpts):
116+
checkpoint_paths.append(serialized_metadata[i + metadata_itr])
117+
metadata_itr += num_rdb_ckpts
118+
tbe_uuid = serialized_metadata[metadata_itr]
119+
metadata_itr += 1
120+
rdb_num_shards = int(serialized_metadata[metadata_itr])
121+
metadata_itr += 1
122+
rdb_num_threads = int(serialized_metadata[metadata_itr])
123+
metadata_itr += 1
124+
max_D = int(serialized_metadata[metadata_itr])
125+
metadata_itr += 1
126+
table_offset = int(serialized_metadata[metadata_itr])
127+
metadata_itr += 1
128+
table_shape: List[int] = []
129+
table_shape.append(int(serialized_metadata[metadata_itr]))
130+
metadata_itr += 1
131+
table_shape.append(int(serialized_metadata[metadata_itr]))
132+
metadata_itr += 1
133+
dtype = int(serialized_metadata[metadata_itr])
134+
metadata_itr += 1
135+
checkpoint_uuid = serialized_metadata[metadata_itr]
136+
metadata_itr += 1
137+
res = KVTensorMetadata(
138+
checkpoint_paths,
139+
tbe_uuid,
140+
rdb_num_shards,
141+
rdb_num_threads,
142+
max_D,
143+
table_offset,
144+
table_shape,
145+
dtype,
146+
checkpoint_uuid,
147+
)
148+
149+
return res
150+
except Exception as e:
151+
logging.error(
152+
f"Failed to parse metadata: {e}, here is metadata: {serialized_metadata}"
153+
)
154+
raise e
155+
54156
@property
55157
def wrapped(self):
56158
"""
@@ -249,6 +351,9 @@ def __eq__(self, tensor1, tensor2, **kwargs):
249351

250352
return torch.equal(tensor1.full_tensor(), tensor2.full_tensor())
251353

354+
def get_kvtensor_serializable_metadata(self) -> List[str]:
355+
return self._wrapped.get_kvtensor_serializable_metadata()
356+
252357
def __hash__(self):
253358
return id(self)
254359

0 commit comments

Comments
 (0)