diff --git a/fast_llm/config.py b/fast_llm/config.py index cdc1dd5d..0004501b 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -379,6 +379,8 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: Validate a class and mark it as read-only This should not be overridden in derived classes. """ + if self._validated: + return self try: expected_class = self.get_subclass(self.type) except KeyError as e: @@ -392,15 +394,14 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: # Done during validation so we don't accidentally use default subtypes as updates. self.type = self.dynamic_type_name - if not self._validated: - try: - self._validate() - except (ValidationError, FieldTypeError) as e: - if _is_validating: - raise - else: - raise type(e)("\n".join(e.args)) from None - self._validated = True + try: + self._validate() + except (ValidationError, FieldTypeError) as e: + if _is_validating: + raise + else: + raise type(e)("\n".join(e.args)) from None + self._validated = True return self def _validate(self) -> None: diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 6681d70e..7faf599f 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -54,7 +54,7 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm})) shard_names = self.get_shard_names(config) # Make sure all shards to load are in the checkpoint. - Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards)) + Assert.leq(set(shard_names), set(loaded_metadata.shards)) Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names)) # Using `log_fn=bool` sets the output to true if the error list is non-empty. @@ -96,7 +96,13 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: ) path = config.path / f"rank_{rank}.safetensors" log_main_rank(f"Loading from {path}", log_fn=logger.info) - # TODO: skip shards without overlap. + + # First do a dry run to check if there is any overlap. + if not self._has_shard_overlaps(loaded_model): + # No overlap found, skip this file. + continue + + # TODO: Lazy loading? 
with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f: # TODO: Use self_shard if "state_shard" in f.keys(): @@ -112,22 +118,34 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names } - for shard_name, loaded_shard in loaded_shards.items(): - loaded_model.get_shard_meta(shard_name).validate(loaded_shard) - - self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names} - - counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device) - for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): - for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): - self_fsdp.copy_shard_overlaps( - loaded_fsdp, - self_fsdp_shards, - loaded_fsdp_shards, - counter, - self._model.distributed.device, - ) - - context.mark_as_loaded(counter.item()) + self._copy_shard_overlaps(loaded_model, loaded_shards, context) return loaded_metadata.metadata + + def _has_shard_overlaps(self, loaded_model) -> bool: + for _, loaded_fsdp, _ in loaded_model.split_shards_by_fsdp({}): + for _, self_fsdp, _ in self._model.split_shards_by_fsdp({}): + counter = self_fsdp.copy_shard_overlaps( + loaded_fsdp, + None, + None, + ) + if counter: + return True + return False + + def _copy_shard_overlaps(self, loaded_model, loaded_shards, context): + for shard_name, loaded_shard in loaded_shards.items(): + loaded_model.get_shard_meta(shard_name).validate(loaded_shard) + + self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in loaded_shards} + + for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): + for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): + counter = self_fsdp.copy_shard_overlaps( + loaded_fsdp, + self_fsdp_shards, + loaded_fsdp_shards, + ) + for parameter, count in counter.items(): + context.mark_as_loaded(count, parameter, True) diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index e72a3a15..2e2a0188 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -5,9 +5,9 @@ from torch.distributed import all_reduce from fast_llm.core.distributed import add_ephemeral_timeout +from fast_llm.engine.multi_stage.config import ShardName from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.functional.triton.pointwise import triton_fill -from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -48,14 +48,17 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: self._validate() - def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None) -> None: + def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None, partial: bool = False) -> None: self._loaded += count if parameter is not None: parameter_name, shard_name = parameter if shard_name not in self._loaded_parameters: self._loaded_parameters[shard_name] = {} - Assert.not_incl(parameter_name, self._loaded_parameters[shard_name]) - self._loaded_parameters[shard_name][parameter_name] = count + if not partial and parameter_name in self._loaded_parameters[shard_name]: + raise ValueError(f"Duplicate loaded parameter ({parameter_name}, {shard_name})") + self._loaded_parameters[shard_name][parameter_name] = ( + 
self._loaded_parameters[shard_name].get(parameter_name, 0) + count + ) def _validate(self) -> None: errors = [] @@ -105,7 +108,7 @@ def _check_missing(self, errors: list[str]) -> None: f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {parameter_name} in stage {stage.index}, shard {shard_name}" f" (locally {local_missing_for_param:,} out of {local_values.numel():,})" ) - missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item() + missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item() if fsdp._global_pad > 0 else 0 if missing_for_pad > 0: global_total += missing_for_pad local_missing_for_pad = ( @@ -127,52 +130,63 @@ def _check_missing(self, errors: list[str]) -> None: ) def _check_parameters(self, errors: list[str]) -> None: - loaded_shard_names = set(self._loaded_parameters) - shard_names = set(self._self_shards) - if loaded_shard_names != shard_names: - errors.append(f"Incorrect loaded shards: {loaded_shard_names}!={shard_names}") - for shard_name in shard_names & loaded_shard_names: - counter_per_parameter = { - parameter_name: self._loaded_parameters[shard_name].pop(parameter_name, None) - for parameter_name in self._model.parameter_names - } - for parameter_name, count in self._loaded_parameters[shard_name].items(): - errors.append(f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})') - for parameter_name, counter in counter_per_parameter.items(): - if self._model.is_parameter_on_device(parameter_name): - if counter is None: - errors.append(f'Missing parameter "{parameter_name}" for shard "{shard_name}"') - elif counter is not None and counter > 0: - errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"') - if self._distributed.world_group is not None: - counter_list = [] - for parameter_name, counter in counter_per_parameter.items(): - parameter_stage = self._model.get_parameter_stage(parameter_name) - parameter_meta = parameter_stage.get_parameter_meta(parameter_name) - if ( - counter is None - or (not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0) - or parameter_stage.is_tied_weight_copy - ): - # Ignore the counter from missing or duplicate tensors. - counter = 0 - counter_list.append(counter) - - counter_tensor = torch.tensor(counter_list, dtype=torch.int64).to(self._distributed.device) - - add_ephemeral_timeout(self._distributed.world_group, self._timeout) - all_reduce(counter_tensor, group=self._distributed.world_group) - counter_per_parameter = { - parameter_name: counter - for parameter_name, counter in zip(counter_per_parameter, counter_tensor.tolist()) - } - for parameter_name, counter in counter_per_parameter.items(): - parameter_size = ( - self._model.get_parameter_stage(parameter_name) - .get_parameter_meta(parameter_name) - .global_shape.numel() + if set(self._loaded_parameters) != set(self._self_shards): + errors.append(f"Incorrect loaded shards: {tuple(self._loaded_parameters)}!={tuple(self._self_shards)}") + + counters = [] + # Compare local counts against expected values. 
+ for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters: + for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]: + counter = self._loaded_parameters[shard_name].pop(parameter_meta.tensor_name, 0) + local_size = ( + fsdp.get_parameter_size_in_shard(parameter_name, shard_name) + if self._model.is_parameter_on_device(parameter_name) + else 0 ) + if counter != local_size: + errors.append( + f'Local counter mismatch for parameter "{parameter_name}"' + f' and shard "{shard_name}": loaded {counter}, expected {local_size}' + ) + + counter_ = counter + # Accumulate in a list for global counter check. + if ( + not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0 + ) or stage.is_tied_weight_copy: + # Ignore the counter from duplicate tensors. + counter = 0 + if parameter_name == "layers.1.norm_1.weight": + logger.info( + f"Parameter {parameter_name} local {counter_} keep {counter} (size {parameter_meta.numel()} / {parameter_meta.global_shape.numel()})" + ) + counters.append(counter) + + # Check for unexpected parameters. + for shard_name, loaded in self._loaded_parameters.items(): + for parameter_name, count in loaded.items(): + errors.append(f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})') + + # All-reduce to get global counts. + if self._distributed.world_group is not None: + counter_tensor = torch.tensor(counters, dtype=torch.int64).to(self._distributed.device) + # This may be the first distributed barrier after loading, so we need to wait for everyone to finish. + add_ephemeral_timeout(self._distributed.world_group, self._timeout) + all_reduce(counter_tensor, group=self._distributed.world_group) + counters = counter_tensor.tolist() + + # Compare global counts against expected values. 
+ for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters: + for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]: + counter = counters.pop(0) + if parameter_name == "layers.1.norm_1.weight": + logger.info( + f"Parameter {parameter_name} global {counter} (size {parameter_meta.numel()} / {parameter_meta.global_shape.numel()})" + ) + parameter_size = parameter_meta.global_shape.numel() if counter != parameter_size: errors.append( - f'Global counter mismatch for parameter "{parameter_name}" and shard "{shard_name}": {counter} != {parameter_size}' + f'Global counter mismatch for parameter "{parameter_name}"' + f' and shard "{shard_name}": loaded {counter}, expected {parameter_size}' ) + assert not counters diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 49ce1525..99c1bcf7 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -66,6 +66,10 @@ def parallel_dim_index(self) -> int | None: def parallel_group(self) -> "ProcessGroup|None": return None if self._parallel_dim is None else self._parallel_dim.group + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self.parallel_dim is not None + return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + class CompositeTensorDim(TensorDim): def __init__(self, name: str, dims: tuple[TensorDim, ...]): @@ -106,6 +110,12 @@ def global_expanded_shape(self) -> tuple[int, ...]: def parallel_dim_index(self) -> int | None: return self._parallel_dim_index + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self.parallel_dim_index is not None + dims = list(self.dims) + dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) + class DefaultDimNames: # Scalar diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 8b689cde..7fd9fed1 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -61,7 +61,7 @@ class DistributedDim: name: str size: int rank: int - global_ranks: range | tuple[int, ...] = None + global_ranks: range | tuple[int, ...] def __post_init__(self): self._is_setup = False @@ -275,8 +275,6 @@ def _validate(self) -> None: data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) - print("data_stride", data_stride) - print("pipeline_stride", pipeline_stride) self._add_distributed_dim( DistributedDim( diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 9719ff2e..fbbf9b6a 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -90,13 +90,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): global _default_pool assert _default_pool is self _default_pool = None + self.shutdown() - def __del__(self): + def shutdown(self): # Shutdown the process group backend explicitly to prevent a nccl warning. # We can't call `destroy_process_group` directly because pytorch doesn't know about it. 
for group in self._process_groups.values(): - if group is not None and hasattr(group, "_shutdown"): - group._shutdown() # noqa + group.shutdown() + + def __del__(self): + self.shutdown() _default_pool: ProcessGroupPool | None = None @@ -114,7 +117,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): config_class: typing.ClassVar[type[DistributedConfig]] = DistributedConfig - def __init__(self, config: DistributedConfig, use_cpu: bool = False, pool: ProcessGroupPool | None = None): + def __init__(self, config: DistributedConfig, use_cpu: bool = False): super().__init__(config) assert self._config.reference_config is None self._use_cpu = use_cpu @@ -128,14 +131,13 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False, pool: Proce self.device = torch.device(self._config.local_rank) torch.cuda.set_device(self.device) - if pool is None and _default_pool is None: + self._local_pool = _default_pool is None + if self._local_pool: self._pool = ProcessGroupPool(self._config.rank, self._config.world_size, self._config.timeout) else: - if pool is None: - pool = _default_pool - Assert.eq(pool._world_size, self._config.world_size) - Assert.eq(pool._rank, self._config.rank) - self._pool = pool + self._pool = _default_pool + Assert.eq(self._pool._world_size, self._config.world_size) + Assert.eq(self._pool._rank, self._config.rank) self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world]) self.data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.data]) @@ -210,3 +212,7 @@ def set_step(self, step: int, phase: PhaseType) -> None: seed_shift = step * self._config.sample_seed_shift + self._phase_seeds_shifts[phase] self.pp_generator.manual_seed((self._pp_seed + seed_shift) % MAX_SEED) self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED) + + def __del__(self): + if self._local_pool: + self._pool.shutdown() diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5cf51dd5..5b44bf14 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -1,3 +1,5 @@ +import dataclasses +import math import typing import torch @@ -8,7 +10,7 @@ from fast_llm.core.ops import gather_op from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.engine.distributed.config import DistributedDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill @@ -35,18 +37,16 @@ def __init__( self, name: str, parameter_metas: list[ParameterMeta], - fsdp_dim: DistributedDim, - training_dtype: DataType, - gradient_buffer_dtype: DataType, - optimization_dtype: DataType, + distributed_config: DistributedConfig, + full_precision_gradient_buffer: bool = False, + full_precision_shards: bool = True, + is_tied_weight_copy: bool = False, ): self._name = name self._parameter_metas = {parameter_meta.tensor_name: parameter_meta for parameter_meta in parameter_metas} - self._fsdp_dim = fsdp_dim - self._training_dtype = training_dtype - self._gradient_buffer_dtype = gradient_buffer_dtype - self._optimization_dtype = optimization_dtype - + self._distributed_config = distributed_config + self._fsdp_dim = 
self._distributed_config.get_distributed_dim(DistributedDimNames.data) + self._is_tied_weight_copy = is_tied_weight_copy self._requires_grad = any(parameter_meta.requires_grad for parameter_meta in self._parameter_metas.values()) parameter_sizes = [meta.numel() for meta in self._parameter_metas.values()] @@ -75,29 +75,41 @@ def __init__( ) # TODO: Use parallel_dim property instead? - weight_shard_dim = TensorDim("weight_shard", (self._parameter_count + self._global_pad) // self._fsdp_dim.size) - grad_shard_dim = TensorDim("grad_shard", weight_shard_dim.size if self._requires_grad else 0) + weight_shard_dim = TensorDim("weight_shard", self._shard_size) + grad_shard_dim = TensorDim("grad_shard", self._shard_size if self._requires_grad else 0) self._weight_shard_meta = TensorMeta.from_dims( (weight_shard_dim,), tensor_name=f"{self._name}_weight_shard", - dtype=self._optimization_dtype.torch, + dtype=( + self._distributed_config.optimization_dtype + if full_precision_shards + else self._distributed_config.training_dtype + ).torch, ) # TODO: Distinguish grad and optimizer shard? self._grad_shard_meta = TensorMeta.from_dims( (grad_shard_dim,), tensor_name=f"{self._name}_grad_shard", - dtype=self._optimization_dtype.torch, + dtype=( + self._distributed_config.optimization_dtype + if full_precision_shards + else self._distributed_config.training_dtype + ).torch, ) self._weight_buffer_meta = TensorMeta.from_dims( (TensorDim("weight_buffer", weight_shard_dim.size * self._fsdp_dim.size),), tensor_name=f"{self._name}_weight_buffer", - dtype=self._training_dtype.torch, + dtype=self._distributed_config.training_dtype.torch, ) self._grad_buffer_meta = TensorMeta.from_dims( (TensorDim("grad_buffer", weight_shard_dim.size * self._fsdp_dim.size if self._requires_grad else 0),), tensor_name=f"{self._name}_grad_buffer", - dtype=self._gradient_buffer_dtype.torch, + dtype=( + self._distributed_config.optimization_dtype + if full_precision_gradient_buffer + else self._distributed_config.training_dtype + ).torch, ) @property @@ -302,7 +314,7 @@ def import_state_tensor( """ Assert.eq(shard.shape, (self._shard_size,)) tensor_shard = self.parameter_global_to_shard(tensor, parameter_name) - begin, end = self._parameter_range_in_shard(parameter_name) + begin, end = self._get_parameter_range_in_shard(parameter_name) Assert.eq(tensor_shard.numel(), end - begin) shard[begin:end].copy_(tensor_shard) return end - begin @@ -385,20 +397,34 @@ def reduce_gradients( else: triton_copy(self._grad_buffer_local_shard, self._grad_shard) - def _parameter_range_in_shard(self, parameter_name: str) -> tuple[int, int]: + def _get_parameter_range_in_shard(self, parameter_name: str) -> tuple[int, int]: begin = self.index_buffer_to_shard(self.get_parameter_begin_in_buffer(parameter_name)) end = self.index_buffer_to_shard(self.get_parameter_end_in_buffer(parameter_name)) return begin, end + def get_parameter_size_in_shard(self, parameter_name: str, shard_name: str = ShardName.weights) -> int: + if not self._requires_grad and shard_name != ShardName.weights: + return 0 + begin, end = self._get_parameter_range_in_shard(parameter_name) + return end - begin + def invalidate_buffer(self) -> None: # Buffer is no longer valid (Updated weights or overwritten by other stage) assert self._mode.support_forward self._is_restored = False def parameter_global_to_shard( - self, global_param: torch.Tensor | SafeTensorSlice, parameter_name: str + self, + global_param: torch.Tensor | SafeTensorSlice, + parameter_name: str, + *, + _parameter_meta: TensorMeta 
| None = None, ) -> torch.Tensor: - shard_param = self.get_parameter_meta(parameter_name).global_to_local(global_param).flatten() + if _parameter_meta is None: + # Used with reduced tensor-parallel in `copy_shard_overlaps` + _parameter_meta = self._parameter_metas[parameter_name] + # This may copy the data. + shard_param = _parameter_meta.global_to_local(global_param).flatten() if self._fsdp_dim.size > 1: shard_param = shard_param[ self._index_buffer_to_param( @@ -407,13 +433,14 @@ def parameter_global_to_shard( ] return shard_param - def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, device: torch.device) -> torch.Tensor: + def _get_parameter_shard_indices_in_full_weight( + self, parameter_name: str, device: torch.device, parameter_meta: TensorMeta + ) -> torch.Tensor: """ Create an index array for the global parameter, where each entry corresponds to the index where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - parameter_meta = self.get_parameter_meta(parameter_name) # Create an empty index for the global parameter. index = torch.full( @@ -423,42 +450,194 @@ def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, devic device=device, ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard - begin, end = self._parameter_range_in_shard(parameter_name) - self.parameter_global_to_shard(index, parameter_name).copy_( - torch.arange(begin, end, dtype=torch.int64, device=device) - ) + begin, end = self._get_parameter_range_in_shard(parameter_name) + + buffer_index = parameter_meta.global_to_local(index, expand=True) + # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. + # In that case, we work with a separate tensor to be copied back into `buffer_index`. + try: + buffer_index_flat = buffer_index.view(-1) + is_view = True + except RuntimeError: + buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) + is_view = False + + # Copy the shard indices at their respective positions in the flat buffer index. + buffer_index_flat[ + self._index_buffer_to_param( + self._fsdp_dim.rank * self._shard_size, parameter_name + ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) + ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) + + # If needed, copy the flat buffer index back into the index. + if not is_view: + buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) + return index def copy_shard_overlaps( self, - loaded_fsdp: "FSDP", - shards: dict[str, torch.Tensor], - loaded_shards: dict[str, torch.Tensor], - counter: torch.Tensor, - device: torch.device, - ) -> None: + loaded_fsdp: typing.Self, + shards: dict[str, torch.Tensor] | None, + loaded_shards: dict[str, torch.Tensor] | None, + ) -> dict[tuple[str, str], int]: """ See MultiStage._load_partial. - TODO: Not intended to work with frozen weights, need to enforce. 
""" - Assert.eq(set(shards), set(loaded_shards)) + if shards is not None: + Assert.eq(set(shards), set(loaded_shards)) index_overlap = [name for name in loaded_fsdp._parameter_metas if name in self._parameter_metas] - for name in index_overlap: - overlap_index_map = self.parameter_global_to_shard( - loaded_fsdp._get_parameter_shard_indices_in_full_weight(name, device), name + counter = {} + + self_tensor_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + loaded_tensor_dim = loaded_fsdp._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + # The shared tensor-parallel part (usually the smallest of the two) can be safely ignored. + if (shared_tp := math.gcd(self_tensor_dim.size, loaded_tensor_dim.size)) > 1: + self_tensor_dim, self_new_size, self_shared_rank = _reduce_tensor_parallel_size(self_tensor_dim, shared_tp) + loaded_tensor_dim, loaded_new_size, loaded_shared_rank = _reduce_tensor_parallel_size( + loaded_tensor_dim, shared_tp + ) + + if self_shared_rank != loaded_shared_rank: + # Disjoint tensor-parallel slices, no possible overlap. + # (Duplicated parameters will be loaded from the new rank 0 which prevents unnecessary file loading). + return counter + + for parameter_name in index_overlap: + self_meta = self._parameter_metas[parameter_name] + loaded_meta = loaded_fsdp._parameter_metas[parameter_name] + + if shared_tp > 1: + self_meta = self_meta.replace_tensor_parallel_dim(self_tensor_dim) + loaded_meta = loaded_meta.replace_tensor_parallel_dim(loaded_tensor_dim) + + if not loaded_meta.is_tensor_parallel and loaded_tensor_dim.rank != 0: + # Loaded parameter is tensor-parallel duplicate, ignore. + continue + + if self_meta.tensor_parallel_size == loaded_meta.tensor_parallel_size == 1: + self._copy_shard_overlaps(loaded_fsdp, shards, loaded_shards, parameter_name, counter) + else: + self._copy_tensor_parallel_shard_overlaps( + loaded_fsdp, shards, loaded_shards, parameter_name, counter, self_meta, loaded_meta + ) + + return counter + + def _copy_shard_overlaps( + self, + loaded_fsdp: typing.Self, + shards: dict[str, torch.Tensor] | None, + loaded_shards: dict[str, torch.Tensor] | None, + parameter_name: str, + counter: dict[tuple[str, str], int], + ): + # Common case: the overlap is a contiguous slice of the shards. + + # Find the slice of the parameter contained in each shard. + self_shard_begin_in_buffer = self._fsdp_dim.rank * self._shard_size + self_shard_end_in_buffer = (self._fsdp_dim.rank + 1) * self._shard_size + self_shard_begin_in_param = self._index_buffer_to_param(self_shard_begin_in_buffer, parameter_name) + self_shard_end_in_param = self._index_buffer_to_param(self_shard_end_in_buffer, parameter_name) + loaded_shard_begin_in_buffer = loaded_fsdp._fsdp_dim.rank * loaded_fsdp._shard_size + loaded_shard_end_in_buffer = (loaded_fsdp._fsdp_dim.rank + 1) * loaded_fsdp._shard_size + loaded_shard_begin_in_param = loaded_fsdp._index_buffer_to_param(loaded_shard_begin_in_buffer, parameter_name) + loaded_shard_end_in_param = loaded_fsdp._index_buffer_to_param(loaded_shard_end_in_buffer, parameter_name) + + # Calculate the overap. + overlap_begin_in_param = max(self_shard_begin_in_param, loaded_shard_begin_in_param) + overlap_end_in_param = min(self_shard_end_in_param, loaded_shard_end_in_param) + + if (overlap_size := overlap_end_in_param - overlap_begin_in_param) <= 0: + return + + # Map the overlap back to the shards. 
+ overlap_begin_in_self_shard = ( + self._parameter_begins_in_buffer[parameter_name] + overlap_begin_in_param - self_shard_begin_in_buffer + ) + overlap_begin_in_loaded_shard = ( + loaded_fsdp._parameter_begins_in_buffer[parameter_name] + + overlap_begin_in_param + - loaded_shard_begin_in_buffer + ) + + if shards is None: + # Dry run. + counter[(parameter_name, "")] = overlap_size + return + + for shard_name, shard in shards.items(): + # Shards can be empty (frozen weights) + if shard.numel() == 0: + continue + counter[(parameter_name, shard_name)] = overlap_size + + # Copy the overlap. + shard[overlap_begin_in_self_shard : overlap_begin_in_self_shard + overlap_size] = ( + loaded_shards[shard_name][overlap_begin_in_loaded_shard : overlap_begin_in_loaded_shard + overlap_size] + if loaded_shards[shard_name].numel() > 0 + else 0 ) - overlap_mask = overlap_index_map >= 0 - overlap_index_map_masked = overlap_index_map[overlap_mask] - overlap_count = overlap_mask.sum() - begin, end = self._parameter_range_in_shard(name) - - for shard_name, shard in shards.items(): - # Shards can be empty (frozen weights) - if shard.numel() == 0: - continue - if loaded_shards[shard_name].numel() == 0: - shard[begin:end][overlap_mask] = 0 - counter += overlap_count - continue - shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] - counter += overlap_count + + def _copy_tensor_parallel_shard_overlaps( + self, + loaded_fsdp: typing.Self, + shards: dict[str, torch.Tensor] | None, + loaded_shards: dict[str, torch.Tensor] | None, + parameter_name: str, + counter: dict[tuple[str, str], int], + self_meta: TensorMeta, + loaded_meta: TensorMeta, + ): + + self_begin, self_end = self._get_parameter_range_in_shard(parameter_name) + loaded_begin, loaded_end = loaded_fsdp._get_parameter_range_in_shard(parameter_name) + if self_begin >= self_end or loaded_begin >= loaded_end: + # Parameter is not present in both shards, no overlap. + return + + # Tensor-parallel case: the overlap cannot be represented as a slice. + if shards is None: + # Dry run. Since we only need to know if there can be overlap, + # we skip the slow computation and return a dummy value. + counter[(parameter_name, "")] = 1 + return + + device = next(iter(shards.values())).device + # Create an array that associates each entry in the `parameter_name` slice of `shard` + # to the index of the same parameter entry in `loaded_shard`, or -1 if not present. + overlap_index_map = self.parameter_global_to_shard( + loaded_fsdp._get_parameter_shard_indices_in_full_weight(parameter_name, device, loaded_meta), + parameter_name, + _parameter_meta=self_meta, + ) + # Create a mask to exclude the missing entries. + overlap_mask = overlap_index_map >= 0 + overlap_index_map_masked = overlap_index_map[overlap_mask] + overlap_size = overlap_mask.sum().item() + if overlap_size == 0: + return + begin, end = self._get_parameter_range_in_shard(parameter_name) + + for shard_name, shard in shards.items(): + # Shards can be empty (frozen weights) + if shard.numel() == 0: + continue + counter[(parameter_name, shard_name)] = overlap_size + # Masked copy of the overlap index map. 
+ shard[begin:end][overlap_mask] = ( + loaded_shards[shard_name][overlap_index_map_masked] if loaded_shards[shard_name].numel() > 0 else 0 + ) + + +def _reduce_tensor_parallel_size(distributed_dim: DistributedDim, shared_size: int): + new_size = distributed_dim.size // shared_size + shared_rank = distributed_dim.rank // new_size + new_dim = dataclasses.replace( + distributed_dim, + size=new_size, + rank=distributed_dim.rank % new_size, + global_ranks=distributed_dim.global_ranks[shared_size * shared_rank : shared_size * (shared_rank + 1)], + ) + return new_dim, new_size, shared_rank diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 515b977a..1f734268 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -453,6 +453,13 @@ def distributed(self) -> Distributed: assert self._is_setup return self._distributed + @property + def stages_fsdp_parameters(self) -> typing.Generator[tuple[Stage, FSDP, str, ParameterMeta], None, None]: + for stage in self._stages: + for fsdp in stage.fsdps: + for parameter_name in fsdp.parameter_names: + yield stage, fsdp, parameter_name, stage.get_parameter_meta(parameter_name) + def invalidate_buffers(self) -> None: for stage in self._stages_on_device.values(): stage.invalidate_buffer() diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 3ca28ba5..2f18f136 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -8,7 +8,7 @@ from fast_llm.core.distributed import check_parallel_match from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import ShardName, StageConfig, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP @@ -51,20 +51,13 @@ def __init__( parameter_metas, frozen_metas = self._get_parameter_metas() self._parameter_metas = parameter_metas + frozen_metas self._fsdps = [] - gradient_buffer_dtype = ( - self._distributed_config.optimization_dtype - if self._config.full_precision_gradients - else self._distributed_config.training_dtype - ) if parameter_metas: self._fsdps.append( FSDP( f"stage_{self._index}", parameter_metas, - self._distributed_config.get_distributed_dim(DistributedDimNames.data), - training_dtype=self._distributed_config.training_dtype, - gradient_buffer_dtype=gradient_buffer_dtype, - optimization_dtype=self._distributed_config.optimization_dtype, + self._distributed_config, + full_precision_gradient_buffer=self._config.full_precision_gradients, ) ) if frozen_metas: @@ -72,14 +65,9 @@ def __init__( FSDP( f"stage_{self._index}_frozen", frozen_metas, - self._distributed_config.get_distributed_dim(DistributedDimNames.data), - training_dtype=self._distributed_config.training_dtype, - gradient_buffer_dtype=gradient_buffer_dtype, - optimization_dtype=( - self._distributed_config.optimization_dtype - if self._config.store_frozen_weights_in_optimization_precision - else self._distributed_config.training_dtype.torch - ), + self._distributed_config, + full_precision_gradient_buffer=self._config.full_precision_gradients, + full_precision_shards=self._config.store_frozen_weights_in_optimization_precision, ) ) # TODO: 
Separate fsdp for tied weights? @@ -291,7 +279,7 @@ def get_param_groups( ) grads_norm_slices = [] for name in grad_norm_names: - begin, end = fsdp._parameter_range_in_shard(name) + begin, end = fsdp._get_parameter_range_in_shard(name) if len(grads_norm_slices) < 0 and begin == grads_norm_slices[-1].stop: grads_norm_slices[-1] = slice(grads_norm_slices[-1].start, end) else: diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a3cf078d..766398d0 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -96,7 +96,7 @@ def run( done = training_progress.done completed_steps = training_progress.completed_steps - if done or self.config.enabled(completed_steps): + if (done and self.config.enabled()) or self.config.enabled(completed_steps): return self.evaluator.run(training_progress, run_index=self._config.get_run_count(completed_steps - 1)) else: return EvaluationMetrics() diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8c549259..513510ec 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,9 +145,9 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - loss = (per_sample_loss * loss_mask).sum() / torch.maximum(loss_mask.sum(), 1) - else: - loss = per_sample_loss.mean() + per_sample_loss = per_sample_loss * loss_mask + + loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.MEAN, group=group) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d2c01af0..d8425786 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -319,10 +319,10 @@ def _get_weight_and_bias_converters( class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Starcoder2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Starcoder2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "rotary", "type"),), @@ -446,10 +446,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlamaForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "LlamaForCausalLM" return super()._create_config_converters() + [ # TODO: Llama supports biases ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), @@ -498,10 +498,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Qwen2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), @@ -544,10 
+544,10 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MistralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MistralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MistralForCausalLM" return super()._create_config_converters() + [ IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] @@ -568,10 +568,10 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MixtralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MixtralForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk @@ -609,13 +609,13 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" modeling_file = modeling_mtp_llama.__file__ configuration_file = configuration_mtp_llama.__file__ configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MTPLlamaForCausalLM" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -697,6 +697,7 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DreamModel" modeling_file = modeling_dream.__file__ configuration_file = configuration_dream.__file__ generation_utils_file = generation_utils.__file__ @@ -704,7 +705,6 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DreamModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -725,6 +725,7 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam ) format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DiffusionLlamaModel" modeling_file = modeling_diffusion_llama.__file__ configuration_file = configuration_diffusion_llama.__file__ generation_utils_file = generation_utils.__file__ @@ -732,7 +733,6 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DiffusionLlamaModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py 
index 84930756..d780e4d6 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,3 +1,4 @@ +import functools import math import typing @@ -86,12 +87,28 @@ def __new__( data, ) - @property - def is_tensor_parallel(self) -> bool: + @functools.cached_property + def tensor_parallel_dim_index(self) -> int | None: # TODO: Avoid hard-coded assumptions on tensor parallel. - return any( - dim.parallel_dim is not None and dim.parallel_dim.name == DistributedDimNames.tensor for dim in self.dims - ) + indexes = [ + i + for i, dim in enumerate(self.dims) + if dim.parallel_dim is not None and dim.parallel_dim.name == DistributedDimNames.tensor + ] + assert len(indexes) <= 1, indexes + return indexes[0] if indexes else None + + @functools.cached_property + def is_tensor_parallel(self) -> bool: + return self.tensor_parallel_dim_index is not None + + @functools.cached_property + def tensor_parallel_size(self) -> int: + return self.dims[self.tensor_parallel_dim_index].parallel_dim.size if self.is_tensor_parallel else 1 + + @functools.cached_property + def tensor_parallel_rank(self) -> int: + return self.dims[self.tensor_parallel_dim_index].parallel_dim.rank if self.is_tensor_parallel else 0 def __repr__(self, *, tensor_contents=()): return super().__repr__( @@ -170,6 +187,8 @@ def local_to_global( def global_to_local( self, tensor: torch.Tensor | SafeTensorSlice, + # Return an expanded tensor, avoiding `flatten` which copies the data. + expand: bool = False, ) -> torch.Tensor: """ Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. @@ -178,15 +197,13 @@ def global_to_local( tensor_ = tensor[:] assert not self._reductions - for i, dim in enumerate(self.dims): + for i, dim in reversed(list(enumerate(self.dims))): if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = ( - tensor_.unflatten(i, dim.global_expanded_shape) - .chunk(dim.parallel_dim.size, i + dim.parallel_dim_index)[dim.parallel_dim.rank] - .flatten(i, i + len(dim.expanded_shape) - 1) - ) + tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( + dim.parallel_dim.size, i + dim.parallel_dim_index + )[dim.parallel_dim.rank] - return tensor_.view(self.shape) + return tensor_ if expand else tensor_.reshape(self.shape) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -201,6 +218,17 @@ def memory_usage(self) -> int: def validate(self, tensor: torch.Tensor, device: torch.device | None = None) -> torch.Tensor: return validate_tensor(tensor, self, device) + def replace_tensor_parallel_dim(self, distributed_dim: DistributedDim) -> "TensorMeta": + # Replace the tensor-parallel `DistributedDim` in `meta`. 
+ # Note: This will turn `ParameterMeta` into `TensorMeta` + if not self.is_tensor_parallel: + return self + dims = list(self.dims) + dims[self.tensor_parallel_dim_index] = dims[self.tensor_parallel_dim_index].replace_parallel_dim( + distributed_dim + ) + return TensorMeta(self, tensor_name=self.tensor_name, dims=tuple(dims), reductions=self._reductions) + class ParameterMeta(TensorMeta): def __init__( diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index fff0c49e..9e706ebe 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -1,4 +1,5 @@ import gc +import logging import pathlib import typing @@ -14,18 +15,23 @@ ) from fast_llm.engine.distributed.distributed import ProcessGroupPool from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.utils import header from tests.models.test_checkpoint import do_get_convert_path from tests.utils.model_configs import ModelTestingConfig from tests.utils.run_test_script import parse_run_distributed_script +logger = logging.getLogger(__name__) + def _test_load_and_save_parallel( model_testing_config: ModelTestingConfig, pretrained_path: pathlib.Path, - pretrained_format: CheckpointFormat, + pretrained_format: type[CheckpointFormat], distributed_config: dict[str, typing.Any], save_path: pathlib.Path, ): + logger.info(header(save_path.name)) + logger.info(f"Loading {pretrained_format.name} checkpoint from {pretrained_path}") model = model_testing_config.model_class.from_pretrained( CheckpointLoadConfig(path=pretrained_path, format=pretrained_format), # The world size and rank are already set through environment variable. @@ -33,6 +39,7 @@ def _test_load_and_save_parallel( mode=StageMode.inference, ) for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): + logger.info(f"Loading {save_format.name} checkpoint to {save_path / save_format.name}") model.save_checkpoint(CheckpointSaveConfig(path=save_path / save_format.name, format=save_format)) del model gc.collect() @@ -70,6 +77,57 @@ def main(args: list[str] | None = None) -> None: distributed_config={}, save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_dp2", ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=pretrained_path, + pretrained_format=pretrained_format, + distributed_config={"tensor_parallel": 2}, + save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_tp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=pretrained_path, + pretrained_format=pretrained_format, + distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True}, + save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_stp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=pretrained_path, + pretrained_format=pretrained_format, + distributed_config={"pipeline_parallel": 2}, + save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_pp2", + ) + + dist = DistributedCheckpointFormat.name + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=base_path / f"load_pretrained_{dist}_in_dp2" / dist, + pretrained_format=DistributedCheckpointFormat, + distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True}, + save_path=base_path / "load_pretrained_dp2_in_stp2", + ) + _test_load_and_save_parallel( + 
model_testing_config=model_testing_config, + pretrained_path=base_path / f"load_pretrained_{dist}_in_stp2" / dist, + pretrained_format=DistributedCheckpointFormat, + distributed_config={}, + save_path=base_path / "load_pretrained_stp2_in_dp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=base_path / f"load_pretrained_{dist}_in_tp2" / dist, + pretrained_format=DistributedCheckpointFormat, + distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True}, + save_path=base_path / "load_pretrained_tp2_in_pp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=base_path / f"load_pretrained_{dist}_in_pp2" / dist, + pretrained_format=DistributedCheckpointFormat, + distributed_config={"tensor_parallel": 2}, + save_path=base_path / "load_pretrained_pp2_in_tp2", + ) if __name__ == "__main__": diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 8d5928d7..63a25747 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -356,54 +356,80 @@ def test_save_and_load_in_parallel(run_distributed_script_for_all_models, load_a ) -@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_parallel_checkpoint(model_testing_config, load_and_save_parallel_base_path, get_convert_path): - # Check the consistency of the checkpoints saved in `test_save_and_load_in_parallel` - checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) - # Compare Distributed checkpoints - for rank in range(2): - _compare_safetensor_files( - *[ - load_and_save_parallel_base_path - / f"load_pretrained_{format_.name}_in_dp2" - / DistributedCheckpointFormat.name - / f"rank_{rank}.safetensors" - for format_ in checkpoint_formats +@pytest.fixture(scope="module") +def parallel_checkpoint_names(model_testing_config): + names = [] + for format_ in (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format): + names.extend( + [ + f"load_pretrained_{format_.name}_in_dp2", + f"load_pretrained_{format_.name}_in_tp2", + f"load_pretrained_{format_.name}_in_stp2", + f"load_pretrained_{format_.name}_in_pp2", ] ) - # Compare Fast-LLM checkpoints - _compare_safetensor_files( - # Fast-LLM checkpoints are independent of the distributed configuration that saved it. - get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / f"model_0.safetensors", - *[ - load_and_save_parallel_base_path - / f"load_pretrained_{format_.name}_in_dp2" - / FastLLMCheckpointFormat.name - / f"model_0.safetensors" - for format_ in checkpoint_formats - ], + names.extend( + [ + "load_pretrained_dp2_in_stp2", + "load_pretrained_stp2_in_dp2", + "load_pretrained_tp2_in_pp2", + "load_pretrained_pp2_in_tp2", + ] ) + return names @pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_parallel_checkpoint( - model_testing_config, load_and_save_parallel_base_path, get_convert_path, load_and_compare_checkpoints +def test_load_parallel_checkpoint_in_single_gpu( + load_and_save_parallel_base_path, get_convert_path, load_and_compare_checkpoints, parallel_checkpoint_names ): # Test single-gpu loading of multi-gpu distributed checkpoints. 
- checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ _WEIGHT_SHARD_SAVE_NAME ] - for format_ in checkpoint_formats: + for name in parallel_checkpoint_names: load_and_compare_checkpoints( DistributedCheckpointFormat, - load_and_save_parallel_base_path - / f"load_pretrained_{format_.name}_in_dp2" - / DistributedCheckpointFormat.name, + load_and_save_parallel_base_path / name / DistributedCheckpointFormat.name, None, reference_shard, ) + + +@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) +def test_parallel_checkpoint_consistency(model_testing_config, load_and_save_parallel_base_path, get_convert_path): + # Check the consistency of the checkpoints saved in `test_save_and_load_in_parallel` + checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) + # Compare Distributed checkpoints + for config in ("dp2", "tp2", "stp2", "pp2"): + for rank in range(2): + _compare_safetensor_files( + *[ + load_and_save_parallel_base_path + / f"load_pretrained_{format_.name}_in_{config}" + / DistributedCheckpointFormat.name + / f"rank_{rank}.safetensors" + for format_ in checkpoint_formats + ] + ) + + +@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) +def test_multi_gpu_fast_llm_checkpoint( + model_testing_config, load_and_save_parallel_base_path, get_convert_path, parallel_checkpoint_names +): + # Fast-LLM checkpoints are independent of the distributed configuration that saved it. + # TODO: Check pipeline-parallel checkpoints (two files). + _compare_safetensor_files( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / f"model_0.safetensors", + *[ + load_and_save_parallel_base_path / name / FastLLMCheckpointFormat.name / f"model_0.safetensors" + for name in parallel_checkpoint_names + if "in_pp2" not in name + ], + ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b8dd29e8..199d5b72 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -433,12 +433,12 @@ def _update_and_add_testing_config( checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, # TODO: New base image broke mixtral groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.broken, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.broken, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, - ModelTestingGroup.megatron: ModelTestingGroupAction.broken, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.normal, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, )
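The common, non-tensor-parallel path of `FSDP._copy_shard_overlaps` above reduces to interval arithmetic on the flattened parameter. A minimal sketch of that logic, where `ShardWindow` and `contiguous_overlap` are hypothetical stand-ins for the shard bookkeeping rather than Fast-LLM APIs:

import dataclasses


@dataclasses.dataclass
class ShardWindow:
    param_begin: int   # first element of the flattened parameter held by this rank
    param_end: int     # one past the last element held by this rank
    shard_offset: int  # where that slice starts inside the rank's 1D shard tensor


def contiguous_overlap(self_window: ShardWindow, loaded_window: ShardWindow) -> tuple[int, int, int] | None:
    # Intersect the two parameter slices.
    begin = max(self_window.param_begin, loaded_window.param_begin)
    end = min(self_window.param_end, loaded_window.param_end)
    if (size := end - begin) <= 0:
        return None  # Disjoint ranks: nothing to copy.
    # Map the overlap back to positions inside each shard tensor.
    return (
        self_window.shard_offset + begin - self_window.param_begin,
        loaded_window.shard_offset + begin - loaded_window.param_begin,
        size,
    )


# A 10-element parameter saved with one sharding and reloaded with another:
# the new rank holds elements [5, 10), the loaded file holds elements [4, 8).
print(contiguous_overlap(ShardWindow(5, 10, 0), ShardWindow(4, 8, 0)))  # (0, 1, 3): copy 3 values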
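The rank factoring in `_reduce_tensor_parallel_size` works on the gcd of the two tensor-parallel sizes: each rank splits into a "shared" part, which selects the same slice of the weights in both configurations, and a residual rank that still needs remapping, so only ranks with equal shared parts can overlap. A numeric sketch, with `reduce_tp` as a simplified stand-in operating on plain integers rather than `DistributedDim`:

import math


def reduce_tp(size: int, rank: int, shared: int) -> tuple[int, int]:
    new_size = size // shared
    return rank % new_size, rank // new_size  # (residual rank, shared rank)


self_size, loaded_size = 4, 2              # e.g. reloading a tensor-parallel-2 checkpoint with tensor-parallel-4
shared = math.gcd(self_size, loaded_size)  # 2
for self_rank in range(self_size):
    for loaded_rank in range(loaded_size):
        _, self_shared = reduce_tp(self_size, self_rank, shared)
        _, loaded_shared = reduce_tp(loaded_size, loaded_rank, shared)
        if self_shared == loaded_shared:
            print(f"tp4 rank {self_rank} may overlap with tp2 rank {loaded_rank}")
# tp4 ranks 0 and 1 pair with tp2 rank 0; tp4 ranks 2 and 3 pair with tp2 rank 1.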
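When the two shardings are not contiguous slices of each other (the tensor-parallel case handled by `_get_parameter_shard_indices_in_full_weight` and `_copy_tensor_parallel_shard_overlaps`), the copy goes through a global index map: every entry records its position in the loaded shard, or -1 if absent, and the new sharding then slices and mask-gathers that map. A toy 1D illustration with hypothetical sizes, not the Fast-LLM API:

import torch

global_size = 8
# Loaded configuration: 2-way split, this file holds the second half, entries 4..7.
loaded_index = torch.full((global_size,), -1, dtype=torch.int64)
loaded_index[4:8] = torch.arange(4)  # positions inside the loaded shard

# New configuration: 4-way split, this rank holds entries 2..5.
overlap_index_map = loaded_index[2:6]
overlap_mask = overlap_index_map >= 0  # only entries 4 and 5 are available in this file
loaded_shard = torch.tensor([40.0, 50.0, 60.0, 70.0])
new_shard = torch.zeros(4)
new_shard[overlap_mask] = loaded_shard[overlap_index_map[overlap_mask]]
print(new_shard)  # tensor([ 0.,  0., 40., 50.])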