
Commit 134304a

saumishr authored and facebook-github-bot committed
get_state_dict_for_key API with support for flat and nested state dicts for DCP checkpoints (#932)
Summary: Pull Request resolved: #932

get_state_dict_for_key API with user provided state dict

Reviewed By: diego-urgell, JKSenthil
Differential Revision: D64502115
fbshipit-source-id: 865a043f61f07596fb83aac40025da27aa3d214d
1 parent 1c7e9b6 commit 134304a
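
The commit summary describes a get_state_dict_for_key API that works on both flat and nested state dicts, but the function's actual signature and implementation are not shown on this page. The following is a rough, hypothetical Python sketch of the underlying idea only (the helper name extract_key and the dotted-key convention are assumptions, not torchtnt's API):

    # Hypothetical illustration: pulling one entry out of a flat or nested state dict.
    # extract_key is NOT torchtnt's get_state_dict_for_key; it only sketches the concept.
    from typing import Any, Dict


    def extract_key(state_dict: Dict[str, Any], key: str) -> Any:
        """Return the value for `key`; a dotted key walks nested dicts level by level."""
        value: Any = state_dict
        for part in key.split("."):
            value = value[part]
        return value


    flat = {"sum": 10, "count": 4}
    nested = {
        "storage_path": "/tmp/ckpt",
        "data": {"train": {"num_workers_per_gpu": 2}},
    }

    print(extract_key(flat, "sum"))                               # 10
    print(extract_key(nested, "data.train.num_workers_per_gpu"))  # 2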

File tree

1 file changed: +36 −0 lines


torchtnt/framework/_test_utils.py

Lines changed: 36 additions & 0 deletions
@@ -297,3 +297,39 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.sum = state_dict["sum"]
         self.count = state_dict["count"]
+
+
+class DummyStatefulConfig:
+    def __init__(
+        self,
+        storage_path: str,
+        lazy_loading: bool,
+        num_workers_per_gpu: int,
+        max_batch_length: int,
+    ) -> None:
+        self.storage_path = storage_path
+        self.lazy_loading = lazy_loading
+        self.num_workers_per_gpu = num_workers_per_gpu
+        self.max_batch_length = max_batch_length
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "storage_path": self.storage_path,
+            "data": {
+                "lazy_loading": self.lazy_loading,
+                "train": {
+                    "num_workers_per_gpu": self.num_workers_per_gpu,
+                    "dynamic_batch_config": {
+                        "max_batch_length": self.max_batch_length,
+                    },
+                },
+            },
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self.storage_path = state_dict["storage_path"]
+        self.lazy_loading = state_dict["data"]["lazy_loading"]
+        self.num_workers_per_gpu = state_dict["data"]["train"]["num_workers_per_gpu"]
+        self.max_batch_length = state_dict["data"]["train"]["dynamic_batch_config"][
+            "max_batch_length"
+        ]
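
A minimal round-trip sketch for the DummyStatefulConfig test fixture added in this diff, assuming a torchtnt checkout that contains this commit (the import path follows the file shown above; the constructor argument values are placeholders):

    from torchtnt.framework._test_utils import DummyStatefulConfig

    # Example values only; any placeholders work for this fixture.
    config = DummyStatefulConfig(
        storage_path="/tmp/data",
        lazy_loading=True,
        num_workers_per_gpu=2,
        max_batch_length=1024,
    )

    # state_dict() nests the last three fields under "data" / "train" /
    # "dynamic_batch_config", which is what gives tests a nested layout to exercise.
    sd = config.state_dict()
    assert sd["data"]["train"]["dynamic_batch_config"]["max_batch_length"] == 1024

    # load_state_dict() restores every field from that same nested layout.
    restored = DummyStatefulConfig("", False, 0, 0)
    restored.load_state_dict(sd)
    assert restored.num_workers_per_gpu == 2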
