|
14 | 14 | from enum import Enum |
15 | 15 | from functools import total_ordering |
16 | 16 | from operator import xor |
17 | | -from typing import Dict, List, Literal, Optional, Pattern, Tuple, Union |
| 17 | +from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, Union |
18 | 18 |
|
19 | 19 | import fsspec |
| 20 | +import torch |
20 | 21 | import torch.distributed as dist |
21 | 22 | from fsspec.core import url_to_fs |
22 | 23 | from pyre_extensions import none_throws |
| 24 | +from torch import nn |
| 25 | +from torch.distributed.tensor import distribute_tensor |
| 26 | +from torch.nn.modules.module import _IncompatibleKeys |
23 | 27 | from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast |
24 | 28 |
|
25 | 29 | logger: logging.Logger = logging.getLogger(__name__) |
@@ -849,3 +853,50 @@ def _metadata_exists( |
849 | 853 | fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str |
850 | 854 | ) -> bool: |
851 | 855 | return fs.exists(os.path.join(dirpath, metadata_fname)) |
| 856 | + |
| 857 | + |
| 858 | +def load_from_full_model_state_dict( |
| 859 | +    model: torch.nn.Module, |
| 860 | +    full_sd: Dict[str, Any], |
| 861 | +    device: torch.device, |
| 862 | +    strict: bool = False, |
| 863 | +    cpu_offload: bool = False, |
| 864 | +    release_sd: bool = True, |
| 865 | +) -> _IncompatibleKeys: |
| 866 | +    """ |
| 867 | +    Converts a full (unsharded) state dict into a sharded state dict |
| 868 | +    and loads it into an FSDP-wrapped model. |
| 869 | +    Args: |
| 870 | +        model (Module): the sharded (FSDP) model to load the state dict into |
| 871 | +        full_sd (dict[str, Any]): a full state dict to load into the model (load with ``mmap=True`` for memory-efficient loading) |
| 872 | +        device (torch.device): device to move the full state dict tensors to before sharding |
| 873 | +        strict (bool): whether to load the state dict in strict mode |
| 874 | +        cpu_offload (bool): whether to offload the sharded tensors to CPU after distribution |
| 875 | +        release_sd (bool): whether to release entries of ``full_sd`` as they are consumed, to reduce peak memory usage |
| 876 | +    Returns: |
| 877 | +        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: |
| 878 | +            * **missing_keys** is a list of str containing the missing keys |
| 879 | +            * **unexpected_keys** is a list of str containing the unexpected keys |
| 880 | +    """ |
| 881 | +    meta_sharded_sd = model.state_dict() |
| 882 | +    sharded_sd = {} |
| 883 | +    for param_name, full_tensor in sorted(full_sd.items()): |
| 884 | +        sharded_meta_param = meta_sharded_sd.get(param_name) |
| 885 | +        assert sharded_meta_param is not None, f"{param_name} not found in model" |
| 886 | +        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device) |
| 887 | +        if not hasattr(sharded_meta_param, "device_mesh"): |
| 888 | +            # In cases where parts of the model aren't sharded, some parameters will be plain tensors |
| 889 | +            sharded_tensor = full_tensor |
| 890 | +        else: |
| 891 | +            sharded_tensor = distribute_tensor( |
| 892 | +                full_tensor, |
| 893 | +                sharded_meta_param.device_mesh, |
| 894 | +                sharded_meta_param.placements, |
| 895 | +            ) |
| 896 | +        if cpu_offload: |
| 897 | +            sharded_tensor = sharded_tensor.cpu() |
| 898 | +        sharded_sd[param_name] = nn.Parameter(sharded_tensor) |
| 899 | +        if release_sd: |
| 900 | +            full_sd[param_name] = None |
| 901 | +    # choose `assign=True` since we cannot call `copy_` on a meta tensor |
| 902 | +    return model.load_state_dict(sharded_sd, strict=strict, assign=True) |
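
For reference, a minimal usage sketch of the new helper (not part of the diff): it assumes an FSDP2-style setup where the model is built on the meta device and sharded with `fully_shard`; `MyTransformer`, its `layers` attribute, and the checkpoint path are hypothetical placeholders.

```python
# Minimal usage sketch, assuming PyTorch with FSDP2 (`fully_shard`) available.
# `MyTransformer`, `model.layers`, and the checkpoint path are hypothetical.
import torch
from torch.distributed.fsdp import fully_shard

with torch.device("meta"):
    model = MyTransformer()  # parameters allocated on the meta device, no real memory

for block in model.layers:  # assumed submodule layout
    fully_shard(block)
fully_shard(model)

# mmap=True keeps the checkpoint on disk and pages tensors in lazily, so peak
# host memory stays closer to one shard than to the full model.
full_sd = torch.load(
    "checkpoint/full_model.pt", mmap=True, map_location="cpu", weights_only=True
)

missing, unexpected = load_from_full_model_state_dict(
    model,
    full_sd,
    device=torch.device("cuda"),
    strict=True,
    release_sd=True,  # drop each full tensor once its shard has been built
)
```

Because the helper loads with `assign=True`, the meta-device parameters are replaced by the newly built sharded parameters rather than copied into in place.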