Skip to content

Commit 685e045

Browse files
JKSenthilfacebook-github-bot
authored and committed
add load_from_full_model_state_dict util (meta-pytorch#1031)
Summary: Pull Request resolved: meta-pytorch#1031 Reviewed By: diego-urgell Differential Revision: D82737624 fbshipit-source-id: bd541bd6ed3f81181716dd0663b512b92737e60a
1 parent 093952e commit 685e045

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

torchtnt/utils/checkpoint.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414
from enum import Enum
1515
from functools import total_ordering
1616
from operator import xor
17-
from typing import Dict, List, Literal, Optional, Pattern, Tuple, Union
17+
from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, Union
1818

1919
import fsspec
20+
import torch
2021
import torch.distributed as dist
2122
from fsspec.core import url_to_fs
2223
from pyre_extensions import none_throws
24+
from torch import nn
25+
from torch.distributed.tensor import distribute_tensor
26+
from torch.nn.modules.module import _IncompatibleKeys
2327
from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast
2428

2529
logger: logging.Logger = logging.getLogger(__name__)
@@ -849,3 +853,50 @@ def _metadata_exists(
849853
fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str
850854
) -> bool:
851855
return fs.exists(os.path.join(dirpath, metadata_fname))
856+
857+
858+
def load_from_full_model_state_dict(
859+
model: torch.nn.Module,
860+
full_sd: Dict[str, Any],
861+
device: torch.device,
862+
strict: bool = False,
863+
cpu_offload: bool = False,
864+
release_sd: bool = True,
865+
) -> _IncompatibleKeys:
866+
"""
867+
Converting full state dict into a sharded state dict
868+
and loading it into FSDP model
869+
Args:
870+
model (Module): Model to generate fully qualified names for cpu_state_dict
871+
full_sd (dict[str, Any]): a full state dict to load into the model (mmap=True for efficient loading)
872+
device (torch.device): device used to move full state dict tensors
873+
strict (bool): flag to check if to load the model in strict mode
874+
cpu_offload (bool): flag to check if offload to CPU is enabled
875+
release_sd (bool): whether to release memory of full_sd to save ram usage
876+
Returns:
877+
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
878+
* **missing_keys** is a list of str containing the missing keys
879+
* **unexpected_keys** is a list of str containing the unexpected keys
880+
"""
881+
meta_sharded_sd = model.state_dict()
882+
sharded_sd = {}
883+
for param_name, full_tensor in sorted(full_sd.items()):
884+
sharded_meta_param = meta_sharded_sd.get(param_name)
885+
assert sharded_meta_param is not None, f"{param_name} not found in model"
886+
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
887+
if not hasattr(sharded_meta_param, "device_mesh"):
888+
# In cases where parts of the model aren't sharded, some parameters will be plain tensors
889+
sharded_tensor = full_tensor
890+
else:
891+
sharded_tensor = distribute_tensor(
892+
full_tensor,
893+
sharded_meta_param.device_mesh,
894+
sharded_meta_param.placements,
895+
)
896+
if cpu_offload:
897+
sharded_tensor = sharded_tensor.cpu()
898+
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
899+
if release_sd:
900+
full_sd[param_name] = None
901+
# choose `assign=True` since we cannot call `copy_` on meta tensor
902+
return model.load_state_dict(sharded_sd, strict=strict, assign=True)

0 commit comments

Comments
 (0)