from __future__ import annotations

import functools
- from typing import Optional, Union
+ import logging
+ from typing import List, Optional, Union

import torch

@@ -25,6 +26,58 @@ def decorator(func):
    return decorator


+ class KVTensorMetadata:
+     """
+     Represents a KVTensor as serialized metadata in Python.
+     This object is used to reconstruct the KVTensor in the publish component.
+     """
+
+     checkpoint_paths: List[str]
+     tbe_uuid: str
+     rdb_num_shards: int
+     rdb_num_threads: int
+     max_D: int
+     table_offset: int
+     table_shape: List[int]
+     dtype: int
+     checkpoint_uuid: str
+
+     def __init__(
+         self,
+         checkpoint_paths: List[str],
+         tbe_uuid: str,
+         rdb_num_shards: int,
+         rdb_num_threads: int,
+         max_D: int,
+         table_offset: int,
+         table_shape: List[int],
+         dtype: int,
+         checkpoint_uuid: str,
+     ) -> None:
+         """
+         Ensure the caller loads the module before creating this object:
+
+         ```
+         load_torch_module(
+             "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
+         )
+         ```
+
+         Args:
+             checkpoint_paths: paths of the RocksDB checkpoints backing the tensor
+             tbe_uuid: uuid of the owning TBE instance
+             rdb_num_shards: number of RocksDB shards
+             rdb_num_threads: number of RocksDB threads
+             max_D: maximum embedding dimension of the TBE
+             table_offset: row offset of this table within the TBE
+             table_shape: shape of this table (two ints)
+             dtype: element dtype encoded as an integer
+             checkpoint_uuid: uuid of the checkpoint
+         """
+         self.checkpoint_paths = checkpoint_paths
+         self.tbe_uuid = tbe_uuid
+         self.rdb_num_shards = rdb_num_shards
+         self.rdb_num_threads = rdb_num_threads
+         self.max_D = max_D
+         self.table_offset = table_offset
+         self.table_shape = table_shape
+         self.checkpoint_uuid = checkpoint_uuid
+         self.dtype = dtype
+
+
class PartiallyMaterializedTensor:
    """
    A tensor-like object that represents a partially materialized tensor in memory.
@@ -51,6 +104,55 @@ def __init__(self, wrapped, is_virtual: bool = False) -> None:
        self._is_virtual = is_virtual
        self._requires_grad = False

+     @property
+     def generate_kvtensor_metadata(self) -> KVTensorMetadata:
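+         """
+         Parse the serialized metadata returned by the wrapped ``KVTensorWrapper``
+         into a ``KVTensorMetadata`` object.
+
+         Usage sketch (``pmt`` here is a hypothetical PartiallyMaterializedTensor
+         backed by a KVTensorWrapper, with the SSD TBE module already loaded):
+
+             metadata = pmt.generate_kvtensor_metadata
+             paths, ckpt_uuid = metadata.checkpoint_paths, metadata.checkpoint_uuid
+         """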
+         serialized_metadata = self.wrapped.get_kvtensor_serializable_metadata()
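+         # Layout of `serialized_metadata` as consumed below (N = number of
+         # RocksDB checkpoint paths):
+         #   [0]              N
+         #   [1 .. N]         checkpoint_paths
+         #   [N + 1]          tbe_uuid
+         #   [N + 2]          rdb_num_shards
+         #   [N + 3]          rdb_num_threads
+         #   [N + 4]          max_D
+         #   [N + 5]          table_offset
+         #   [N + 6, N + 7]   table_shape (two ints)
+         #   [N + 8]          dtype
+         #   [N + 9]          checkpoint_uuid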
+         try:
+             metadata_itr = 0
+             num_rdb_ckpts = int(serialized_metadata[0])
+             metadata_itr += 1
+             checkpoint_paths: List[str] = []
+             for i in range(num_rdb_ckpts):
+                 checkpoint_paths.append(serialized_metadata[i + metadata_itr])
+             metadata_itr += num_rdb_ckpts
+             tbe_uuid = serialized_metadata[metadata_itr]
+             metadata_itr += 1
+             rdb_num_shards = int(serialized_metadata[metadata_itr])
+             metadata_itr += 1
+             rdb_num_threads = int(serialized_metadata[metadata_itr])
+             metadata_itr += 1
+             max_D = int(serialized_metadata[metadata_itr])
+             metadata_itr += 1
+             table_offset = int(serialized_metadata[metadata_itr])
+             metadata_itr += 1
+             table_shape: List[int] = []
+             table_shape.append(int(serialized_metadata[metadata_itr]))
+             metadata_itr += 1
+             table_shape.append(int(serialized_metadata[metadata_itr]))
+             metadata_itr += 1
+             dtype = int(serialized_metadata[metadata_itr])
+             metadata_itr += 1
+             checkpoint_uuid = serialized_metadata[metadata_itr]
+             metadata_itr += 1
+             res = KVTensorMetadata(
+                 checkpoint_paths,
+                 tbe_uuid,
+                 rdb_num_shards,
+                 rdb_num_threads,
+                 max_D,
+                 table_offset,
+                 table_shape,
+                 dtype,
+                 checkpoint_uuid,
+             )
+
+             return res
+         except Exception as e:
+             logging.error(
+                 f"Failed to parse metadata: {e}, here is metadata: {serialized_metadata}"
+             )
+             raise e
+
    @property
    def wrapped(self):
        """
@@ -249,6 +351,9 @@ def __eq__(self, tensor1, tensor2, **kwargs):

        return torch.equal(tensor1.full_tensor(), tensor2.full_tensor())

+     def get_kvtensor_serializable_metadata(self) -> List[str]:
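+         """
+         Return the raw serialized metadata list from the wrapped
+         ``KVTensorWrapper``; see ``generate_kvtensor_metadata`` for the parsed form.
+         """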
+         return self._wrapped.get_kvtensor_serializable_metadata()
+
    def __hash__(self):
        return id(self)