From cb86f4521b91e9f87e7fd23a4e836dfc63bbdc21 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 5 Jun 2025 15:26:02 -0400 Subject: [PATCH 01/69] Test all models --- fast_llm/layers/transformer/config.py | 8 +- tests/common.py | 470 ----------------------- tests/conftest.py | 9 +- tests/data/common.py | 2 +- tests/data/test_blending.py | 2 +- tests/data/test_concatenate.py | 2 +- tests/data/test_concatenated_memmap.py | 2 +- tests/data/test_dataset_from_file.py | 2 +- tests/data/test_fim.py | 2 +- tests/data/test_memmap.py | 2 +- tests/data/test_sampling.py | 2 +- tests/data/test_slice.py | 2 +- tests/layers/test_lm_head.py | 2 +- tests/test_checkpoint.py | 344 ++++++++--------- tests/test_config.py | 5 +- tests/test_functional.py | 2 +- tests/test_gpt_generate_and_forward.py | 69 ++-- tests/test_match_megatron.py | 156 +------- tests/test_mb.py | 68 ++-- tests/test_mb_seq_first.py | 39 +- tests/test_ms.py | 32 +- tests/test_mtp.py | 2 +- tests/test_multi_stage.py | 6 +- tests/test_seq_first.py | 39 +- tests/test_simple.py | 73 ++-- tests/test_ssms.py | 2 +- tests/test_triton_kernels.py | 2 +- tests/utils/__init__.py | 0 tests/{ => utils}/compare_tensor_logs.py | 0 tests/utils/dataset.py | 82 ++++ tests/utils/model_configs.py | 276 +++++++++++++ tests/utils/run_test_script.py | 118 ++++++ tests/utils/utils.py | 55 +++ 33 files changed, 885 insertions(+), 992 deletions(-) delete mode 100644 tests/common.py create mode 100644 tests/utils/__init__.py rename tests/{ => utils}/compare_tensor_logs.py (100%) create mode 100644 tests/utils/dataset.py create mode 100644 tests/utils/model_configs.py create mode 100644 tests/utils/run_test_script.py create mode 100644 tests/utils/utils.py diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e7ef0b15..235aa366 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -711,13 +711,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: ) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in ( + return self.use_flash_attention and distributed_config.training_dtype in ( DataType.float16, DataType.bfloat16, ) - - # Config parameter `window_size` only can be used with flash attention - if not use_flash_attention: - Assert.is_(self.window_size, None) - - return use_flash_attention diff --git a/tests/common.py b/tests/common.py deleted file mode 100644 index d531972e..00000000 --- a/tests/common.py +++ /dev/null @@ -1,470 +0,0 @@ -import os -import pathlib -import random -import shutil -import string -import subprocess -import sys - -import numpy as np -import pytest -import torch -import yaml - -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.models.gpt.config import ( - LlamaGPTHuggingfaceCheckpointFormat, - MistralGPTHuggingfaceCheckpointFormat, - MixtralGPTHuggingfaceCheckpointFormat, - MTPLlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, - Starcoder2GPTHuggingfaceCheckpointFormat, -) -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat -from fast_llm.tools.train import CliTrainingConfig -from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs - -# 
FIXME: figure out correct import of megatron modules without this hack -sys.path.append(os.getcwd()) - -# TODO: Use `pytest_addoption` instead? -# Keep all results in one place to allow recovering them for debugging in case of failure. -TEST_RESULTS_PATH = pathlib.Path(os.environ.get("TEST_RESULTS_PATH", "/tmp/fast_llm_tests")).resolve() -FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0 -REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0 -_LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) -TEST_MODEL = os.environ.get("MODEL", "llama") - -ARTIFACT_PATH = "runs/0/artifacts" - -TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" -TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" -DATASET_CACHE = TEST_RESULTS_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common" / "dataset" -DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" - -TEST_VOCAB_SIZE = 8192 -# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% -TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" -TEST_DATASET_TOKENS = 1000000 - -CONFIG_BASE_FAST_LLM = [ - "training.logs.interval=1", - "run.tensor_logs.save=True", - "run.tensor_logs.show=False", - "model.base_model.transformer.num_layers=2", - "model.base_model.transformer.hidden_size=256", - "model.base_model.transformer.num_attention_heads=8", - "model.base_model.transformer.init_method_std=0.022", - f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", - f"model.multi_stage.debug_param_init={_LOG_LEVEL}", - f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", - f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", - f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", - "model.multi_stage.debug_tensor_parallel=True", - "model.distributed.reproducible_init=True", - "model.distributed.timeout=10", - "training.train_iters=2", - "training.num_workers=0", - "training.timeout=30", - "batch.batch_size=8", - "batch.sequence_length=512", - "data.datasets.training.type=slice", - "data.datasets.training.end=0.969", - "data.datasets.training.dataset.type=memmap", - f"data.datasets.training.dataset.path={DATASET_PREFIX}", - "data.datasets.validation.type=slice", - "data.datasets.validation.begin=0.969", - "data.datasets.validation.end=0.999", - "data.datasets.validation.dataset.type=memmap", - f"data.datasets.validation.dataset.path={DATASET_PREFIX}", - "data.datasets.test.type=slice", - "data.datasets.test.begin=0.999", - "data.datasets.test.end=1", - "data.datasets.test.dataset.type=memmap", - f"data.datasets.test.dataset.path={DATASET_PREFIX}", - "optimizer.learning_rate.base=0.0001", -] -CONFIG_BASE_MEGATRON = [ - "--num-layers=2", - "--hidden-size=256", - "--num-attention-heads=8", - "--log-interval=1", - "--train-iters=2", - "--eval-iters=0", - "--hidden-dropout=0", - "--attention-dropout=0", - f"--debug_param_init={_LOG_LEVEL}", - f"--debug_layer_outputs={_LOG_LEVEL}", - f"--debug_layer_gradients={_LOG_LEVEL}", - f"--debug_all_param_gradients={_LOG_LEVEL}", - "--debug_param_update=0", - "--global-batch-size=8", - "--max-position-embeddings=512", - "--seq-length=512", - "--init-method-std=0.022", - "--lr=0.0001", - "--num-workers=0", - "--valid-num-workers=0", - "--tokenizer-type=NullTokenizer", - # Megatron messes with the vocab size, so we have to subtract 1. 
- f"--vocab-size={TEST_VOCAB_SIZE-1}", - f"--data-path={DATASET_PREFIX}", - "--lr-decay-style=constant", - # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) - "--use-mcore-models", - # local implementation doesn't allow for RMS norm. - "--transformer-impl=transformer_engine", -] - -CONFIG_SC1_FAST_LLM = CONFIG_BASE_FAST_LLM + ["model.base_model.max_position_embeddings=512"] -CONFIG_SC1_MEGATRON = CONFIG_BASE_MEGATRON + ["--group-query-attention"] -CONFIG_SC1_COMMON = CONFIG_SC1_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_GPT2_FAST_LLM = CONFIG_SC1_FAST_LLM + ["model.base_model.transformer.head_groups=8"] -CONFIG_GPT2_MEGATRON = CONFIG_BASE_MEGATRON -CONFIG_GPT2_COMMON = CONFIG_GPT2_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_SC2_FAST_LLM = CONFIG_BASE_FAST_LLM + [ - "model.base_model.transformer.head_groups=4", - "model.base_model.transformer.rotary.type=default", -] -CONFIG_SC2_MEGATRON = CONFIG_SC1_MEGATRON + [ - "--num-query-groups=4", - "--use-rotary-position-embeddings", - "--no-position-embedding", -] -CONFIG_SC2_COMMON = CONFIG_SC2_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_LLAMA_MEGATRON = CONFIG_SC2_MEGATRON + [ - "--swiglu", - "--disable-bias-linear", - "--normalization=RMSNorm", - "--ffn-hidden-size=1024", - "--untie-embeddings-and-output-weights", -] -CONFIG_LLAMA_FAST_LLM = CONFIG_SC2_FAST_LLM + [ - "model.base_model.transformer.gated=True", - "model.base_model.transformer.activation_type=silu", - "model.base_model.transformer.add_linear_biases=False", - "model.base_model.transformer.normalization.type=rms_norm", - "model.base_model.transformer.ffn_hidden_size=1024", - "model.base_model.tie_word_embeddings=False", -] -CONFIG_LLAMA_COMMON = CONFIG_LLAMA_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -# Megatron does not support Llama3-style Rotary Embeddings -CONFIG_LLAMA3_MEGATRON = None -CONFIG_LLAMA3_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.rotary.type=llama3", -] -CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -# Megatron does not support per sub layer biases -CONFIG_QWEN2_MEGATRON = None -CONFIG_QWEN2_FAST_LLM = CONFIG_SC2_FAST_LLM + [ - "model.base_model.transformer.gated=True", - "model.base_model.transformer.activation_type=silu", - "model.base_model.transformer.add_linear_biases=only_attn_qkv", - "model.base_model.transformer.normalization.type=rms_norm", - "model.base_model.transformer.ffn_hidden_size=1024", - "model.base_model.tie_word_embeddings=False", -] -CONFIG_QWEN2_COMMON = CONFIG_QWEN2_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -# Yarn-style Rotary Embeddings -CONFIG_LLAMA_YARN_MEGATRON = None -CONFIG_LLAMA_YARN_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.rotary.type=yarn", -] -CONFIG_LLAMA_YARN_COMMON = CONFIG_LLAMA_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"] - - -CONFIG_MIXTRAL_MEGATRON = CONFIG_LLAMA_MEGATRON + [ - "--num-experts=4", - "--moe-router-topk=4", -] -CONFIG_MIXTRAL_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.num_experts=4", - "model.base_model.transformer.num_experts_per_token=4", -] -CONFIG_MIXTRAL_COMMON = CONFIG_MIXTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_MIXTRAL_YARN_MEGATRON = None -CONFIG_MIXTRAL_YARN_FAST_LLM = CONFIG_MIXTRAL_FAST_LLM + [ - "model.base_model.transformer.rotary.type=yarn", -] -CONFIG_MIXTRAL_YARN_COMMON = 
CONFIG_MIXTRAL_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_LLAMA_MTP_MEGATRON = None -CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.prediction_heads=4", -] -CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout==['t','m']"] -CONFIG_LLAMBA_MEGATRON = CONFIG_LLAMA_MEGATRON + [] -CONFIG_LLAMBA_COMMON = CONFIG_LLAMBA_FAST_LLM - -_CONFIGS = { - "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None), - "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), - "starcoder2": ( - "gpt", - CONFIG_SC2_FAST_LLM, - CONFIG_SC2_MEGATRON, - CONFIG_SC2_COMMON, - Starcoder2GPTHuggingfaceCheckpointFormat, - ), - "llama": ( - "gpt", - CONFIG_LLAMA_FAST_LLM, - CONFIG_LLAMA_MEGATRON, - CONFIG_LLAMA_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "llama3": ( - "gpt", - CONFIG_LLAMA3_FAST_LLM, - CONFIG_LLAMA3_MEGATRON, - CONFIG_LLAMA3_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "qwen2": ( - "gpt", - CONFIG_QWEN2_FAST_LLM, - CONFIG_QWEN2_MEGATRON, - CONFIG_QWEN2_COMMON, - Qwen2GPTHuggingfaceCheckpointFormat, - ), - "llama-yarn": ( - "gpt", - CONFIG_LLAMA_YARN_FAST_LLM, - CONFIG_LLAMA_YARN_MEGATRON, - CONFIG_LLAMA_YARN_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "mistral": ( - "gpt", - CONFIG_LLAMA_FAST_LLM, - CONFIG_LLAMA_MEGATRON, - CONFIG_LLAMA_COMMON, - MistralGPTHuggingfaceCheckpointFormat, - ), - "mixtral": ( - "gpt", - CONFIG_MIXTRAL_FAST_LLM, - CONFIG_MIXTRAL_MEGATRON, - CONFIG_MIXTRAL_COMMON, - MixtralGPTHuggingfaceCheckpointFormat, - ), - "llamba": ( - "hybrid_ssm", - CONFIG_LLAMBA_FAST_LLM, - CONFIG_LLAMBA_MEGATRON, - CONFIG_LLAMBA_COMMON, - LLambaHuggingfaceCheckpointFormat, - ), - "mixtral-yarn": ( - "gpt", - CONFIG_MIXTRAL_YARN_FAST_LLM, - CONFIG_MIXTRAL_YARN_MEGATRON, - CONFIG_MIXTRAL_YARN_COMMON, - MixtralGPTHuggingfaceCheckpointFormat, - ), - "llama-mtp": ( - "gpt", - CONFIG_LLAMA_MTP_FAST_LLM, - CONFIG_LLAMA_MTP_MEGATRON, - CONFIG_LLAMA_MTP_COMMON, - MTPLlamaGPTHuggingfaceCheckpointFormat, - ), -} - -TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL] - - -requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") - - -def get_test_dataset( - prefix: pathlib.Path = DATASET_PREFIX, - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - max_spans: int = 0, -): - if not TOKENIZER_FILE.is_file(): - import transformers - - transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) - - if not ( - prefix.with_suffix(".idx").is_file() - and prefix.with_suffix(".bin").is_file() - and prefix.parent.joinpath("fast_llm_config.yaml").is_file() - ): - import transformers - - texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() - tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) - - samples = [ - GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts - ] - if max_spans > 0: - lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) - spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) - for sample, span in zip(samples, spans): - span = np.unique(span) - 
sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) - - GPTMemmapDataset.write_dataset(prefix, samples) - yaml.safe_dump( - {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") - ) - - -def get_test_concatenated_memmap_dataset( - path: pathlib.Path, - num_files: int, - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - seed_shift: int = 55, -): - index_file = path / "index.txt" - if not index_file.is_file(): - for i in range(num_files): - get_test_dataset( - prefix=path / f"dataset_{i}", - seed=seed + i * seed_shift, - num_tokens=num_tokens, - characters=characters, - vocab_size=vocab_size, - ) - index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)]) - - -@pytest.fixture(scope="session") -def run_test_script(worker_resources): - def do_run_test_script( - name: str, - script: list[str], - num_gpus: int = 1, - *, - model_type: str = TEST_MODEL_TYPE, - is_megatron: bool = False, - compare: str | None = None, - config: CompareConfig | None = None, - prepare_fn=None, - compare_fn=None, - do_compare: bool = True, - ): - if torch.cuda.device_count() < num_gpus: - pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") - env = os.environ.copy() - if is_megatron: - # Prevent Megatron from complaining. - env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - env["NVTE_FLASH_ATTN"] = "0" - path = TEST_RESULTS_PATH / name - skip = False - artifact_path = path / ARTIFACT_PATH - if path.exists(): - assert path.is_dir() - # TODO: Better way to check if the previous attempt succeeded. - if ( - REUSE_RESULTS - and artifact_path.is_dir() - and len(list((artifact_path / "0").iterdir())) >= (1 if is_megatron else 3) - ): - skip = True - elif FORCE_REUSE_RESULTS: - raise RuntimeError(artifact_path) - else: - shutil.rmtree(path) - elif FORCE_REUSE_RESULTS: - raise RuntimeError(path) - if prepare_fn is not None: - skip = prepare_fn(TEST_RESULTS_PATH / name, None if compare is None else TEST_RESULTS_PATH / compare, skip) - if is_megatron: - script = [*script, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] - else: - script = [model_type, *script, f"run.experiment_dir={path}"] - header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] - command = [ - "python", - "-m", - "torch.distributed.run", - f"--nproc-per-node={num_gpus}", - f"--rdzv-endpoint=localhost:{worker_resources.rendezvous_port}", - f"--master-port={worker_resources.torchrun_port}", - *header, - *script, - ] - print(" ".join(command)) - if skip: - print("Reusing existing run.") - else: - get_test_dataset() - if num_gpus == 1 and not is_megatron: - CliTrainingConfig.parse_and_run(script) - else: - completed_proc = subprocess.run(command, env=env, timeout=60) - if completed_proc.returncode: - raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") - if compare and do_compare: - if compare_fn is not None: - compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare) - compare_tensor_logs( - TEST_RESULTS_PATH / compare / ARTIFACT_PATH, - TEST_RESULTS_PATH / name / ARTIFACT_PATH, - config, - ) - - return do_run_test_script - - -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if 
hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model - - -def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): - config = HybridSSMBaseModelConfig( - transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), - ssm=SSMConfig(), - hybrid_block_layout=hybrid_block_layout, - prediction_heads=prediction_heads, - default_mtp_type=default_mtp_type, - init_method_std_embed=0.02, - init_method_min_embed=-0.02, - init_method_max_embed=0.02, - use_position_embeddings=True, - tie_word_embeddings=False, - ) - return config diff --git a/tests/conftest.py b/tests/conftest.py index edc52e03..3d1e940b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,14 @@ from xdist.scheduler import LoadGroupScheduling # Make fixtures available globally without import -from tests.common import run_test_script # isort: skip +from tests.utils.run_test_script import ( # isort: skip + run_test_script, + run_test_script_base_path, + run_test_script_for_all_models, +) + +from tests.utils.model_configs import model_testing_config # isort: skip +from tests.utils.utils import result_path # isort: skip def pytest_addoption(parser): diff --git a/tests/data/common.py b/tests/data/common.py index cacb28e6..2d3cb905 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -23,7 +23,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div -from tests.common import TEST_VOCAB_SIZE +from tests.utils.dataset import TEST_VOCAB_SIZE def get_sampling_data( diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index de97eaa2..438782df 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -5,13 +5,13 @@ from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig from fast_llm.utils import Assert, normalize_probabilities -from tests.common import DATASET_CACHE, DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_sampled_dataset, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, ) +from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, get_test_dataset _DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 1142d536..e951cc2b 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTConcatenatedDatasetConfig -from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_indexed_dataset, compare_sampled_dataset, @@ -8,6 +7,7 @@ get_test_data_and_compare_samples, ) from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, 
MEMMAP_DATASET_TOKENS +from tests.utils.dataset import DATASET_PREFIX, get_test_dataset GPT_CONCATENATED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_concatenated_memmap.py b/tests/data/test_concatenated_memmap.py index 09929040..0ab7c7fe 100644 --- a/tests/data/test_concatenated_memmap.py +++ b/tests/data/test_concatenated_memmap.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTConcatenatedMemmapConfig -from tests.common import DATASET_CACHE, get_test_concatenated_memmap_dataset from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -8,6 +7,7 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES +from tests.utils.dataset import DATASET_CACHE, get_test_concatenated_memmap_dataset _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap" diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index 280b3413..3f7d1a13 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -1,7 +1,7 @@ from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS +from tests.utils.dataset import DATASET_PREFIX, get_test_dataset def test_dataset_from_file(): diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7b614d2f..7472f195 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -1,13 +1,13 @@ from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.gpt.config import GPTFimSampledDatasetConfig from fast_llm.data.tokenizer import Tokenizer -from tests.common import DATASET_PREFIX, TOKENIZER_PATH, get_test_dataset from tests.data.common import ( compare_sampled_dataset, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, ) +from tests.utils.dataset import DATASET_PREFIX, TOKENIZER_PATH, get_test_dataset GPT_FIM_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index be801220..fcd7756d 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -3,8 +3,8 @@ import pytest from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from tests.common import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset from tests.data.common import compare_indexed_dataset, get_dataset_config +from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset MEMMAP_DATASET_LENGTH = 6153 MEMMAP_DATASET_TOKENS = 508327 diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 38679582..32d76fa4 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -7,13 +7,13 @@ from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.utils import Assert -from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) +from tests.utils.dataset import DATASET_PREFIX, get_test_dataset try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 
299e2054..f8eedc5b 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTDatasetSliceConfig -from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -8,6 +7,7 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES +from tests.utils.dataset import DATASET_PREFIX, get_test_dataset GPT_SLICE_TRAINING_SAMPLES = [ [80, 268, 79, 260, 207, 3086], diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 7578a5f0..95da48e7 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -19,7 +19,7 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.model import GPTBaseModel from fast_llm.utils import Assert -from tests.common import requires_cuda +from tests.utils.utils import requires_cuda def _lm_head( diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 216f7828..e7929440 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -15,34 +15,18 @@ ModelConfigType, ) from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode -from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConvertConfig -from tests.common import ( - CONFIG_COMMON, - FORCE_REUSE_RESULTS, - HUGGINGFACE_CHECKPOINT_FORMAT, - REUSE_RESULTS, - TEST_MODEL, - TEST_MODEL_TYPE, - TEST_RESULTS_PATH, - requires_cuda, -) -from tests.compare_tensor_logs import CompareConfig, compare_logged_tensor - -TEST_MODEL_CONFIG_CLS = model_registry[TEST_MODEL_TYPE] -TEST_MODEL_HF_CLS = TEST_MODEL_CONFIG_CLS.get_huggingface_model_for_causal_lm_class() -TEST_MODEL_CLS = TEST_MODEL_CONFIG_CLS.get_model_class() -TEST_BASE_MODEL_CONFIG_CLS = TEST_MODEL_CONFIG_CLS.get_base_model_config_class() +from tests.utils.compare_tensor_logs import CompareConfig, compare_logged_tensor +from tests.utils.utils import requires_cuda -WEIGHT_SHARD_SAVE_NAME = f"{ShardName.weights}_shard" +_WEIGHT_SHARD_SAVE_NAME = f"{ShardName.weights}_shard" @requires_cuda -def test_checkpoint_and_eval(run_test_script): +def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_config): # A baseline config (single-gpu, bf16, flash-attn). - run_test_script( - f"test_{TEST_MODEL}_checkpoint_and_eval", - CONFIG_COMMON + run_test_script_for_all_models( + model_testing_config.config_args + [ "training.checkpoint.interval=1", "training.evaluations.validation.interval=2", @@ -72,168 +56,172 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) -def test_resume(run_test_script): +def test_resume(run_test_script_for_all_models): # Resume from iteration=1 and compare outputs with the baseline run. - run_test_script( - f"test_{TEST_MODEL}_resume", - CONFIG_COMMON - + [ + run_test_script_for_all_models( + [ "training.checkpoint.interval=1", "training.evaluations.validation.interval=2", "training.evaluations.validation.iterations=1", ], - compare=f"test_{TEST_MODEL}_checkpoint_and_eval", + compare=f"test_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, compare_fn=_compare_resume_fn, ) @pytest.mark.depends(on=["test_checkpoint_and_eval"]) -def test_resume_frozen(run_test_script): +def test_resume_frozen(run_test_script_for_all_models): # Resume with frozen mlp. No comparison. 
- run_test_script( - f"test_{TEST_MODEL}_resume_frozen", - CONFIG_COMMON - + [ + run_test_script_for_all_models( + "test_resume_frozen", + [ "training.checkpoint.interval=1", "training.evaluations.validation.interval=2", "training.evaluations.validation.iterations=1", "model.base_model.transformer.mlp_lr_scale=0.", ], - compare=f"test_{TEST_MODEL}_checkpoint_and_eval", + compare="test_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, do_compare=False, ) def _run_conversion(config: ConvertConfig): - if config.output.path.is_dir() and not REUSE_RESULTS: + if config.output.path.exists(): + assert config.output.path.is_dir() shutil.rmtree(config.output.path) - if not config.output.path.is_dir(): - if FORCE_REUSE_RESULTS: - raise RuntimeError(config.output.path) - config.run() + config.run() -_CKPT_PATH = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_checkpoint_and_eval" / "checkpoint" / "2" -CONVERT_PATH = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_convert_model" +@pytest.fixture(scope="module") +def convert_paths(run_test_script_base_path): + return { + "checkpoint": run_test_script_base_path / "test_checkpoint_and_eval" / "checkpoint" / "2", + "distributed_0": run_test_script_base_path / "test_convert_model" / "distributed_0", + "distributed_1": run_test_script_base_path / "test_convert_model" / "distributed_1", + "fast_llm_0": run_test_script_base_path / "test_convert_model" / "fast_llm_0", + "fast_llm_1": run_test_script_base_path / "test_convert_model" / "fast_llm_1", + "huggingface_0": run_test_script_base_path / "test_convert_model" / "huggingface_0", + "huggingface_1": run_test_script_base_path / "test_convert_model" / "huggingface_1", + } @pytest.mark.depends(on=["test_checkpoint_and_eval"]) -def test_convert_distributed_to_fast_llm(): +def test_convert_distributed_to_fast_llm(model_testing_config, convert_paths): _run_conversion( ConvertConfig( input=CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, ), output=CheckpointSaveConfig( - path=CONVERT_PATH / "fast_llm_0", + path=convert_paths["fast_llm_0"], format=FastLLMCheckpointFormat, ), - model=TEST_MODEL_CONFIG_CLS, + model=model_testing_config.model_config_class, ) ) @pytest.mark.depends(on=["test_convert_distributed_to_fast_llm"]) -def test_convert_fast_llm_to_huggingface(): - if HUGGINGFACE_CHECKPOINT_FORMAT is None: - pytest.skip(f"Conversion not supported for {TEST_MODEL}") +def test_convert_fast_llm_to_huggingface(model_testing_config, convert_paths): + if model_testing_config.checkpoint_format is None: + pytest.skip(f"Conversion not supported for {model_testing_config.name}") _run_conversion( ConvertConfig( input=CheckpointLoadConfig( - path=CONVERT_PATH / "fast_llm_0", + path=convert_paths["fast_llm_0"], format=FastLLMCheckpointFormat, ), output=CheckpointSaveConfig( - path=CONVERT_PATH / "huggingface_0", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + path=convert_paths["huggingface_0"], + format=model_testing_config.checkpoint_format, ), - model=TEST_MODEL_CONFIG_CLS, + model=model_testing_config.model_config_class, ) ) @pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface"]) -def test_convert_huggingface_to_distributed(): +def test_convert_huggingface_to_distributed(model_testing_config, convert_paths): _run_conversion( ConvertConfig( input=CheckpointLoadConfig( - path=CONVERT_PATH / "huggingface_0", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + path=convert_paths["huggingface_0"], + format=model_testing_config.checkpoint_format, ), output=CheckpointSaveConfig( - 
path=CONVERT_PATH / "distributed_0", + path=convert_paths["distributed_0"], format=DistributedCheckpointFormat, ), - model=TEST_MODEL_CONFIG_CLS, + model=model_testing_config.model_config_class, ) ) @pytest.mark.depends(on=["test_checkpoint_and_eval"]) -def test_convert_distributed_to_huggingface(): - if HUGGINGFACE_CHECKPOINT_FORMAT is None: - pytest.skip(f"Conversion not supported for {TEST_MODEL}") +def test_convert_distributed_to_huggingface(model_testing_config, convert_paths): + if model_testing_config.checkpoint_format is None: + pytest.skip(f"Conversion not supported for {model_testing_config.name}") _run_conversion( ConvertConfig( input=CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, ), output=CheckpointSaveConfig( - path=CONVERT_PATH / "huggingface_1", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + path=convert_paths["huggingface_1"], + format=model_testing_config.checkpoint_format, ), - model=TEST_MODEL_CONFIG_CLS, + model=model_testing_config.model_config_class, ) ) @pytest.mark.depends(on=["test_convert_distributed_to_huggingface"]) -def test_convert_huggingface_to_fast_llm(): +def test_convert_huggingface_to_fast_llm(model_testing_config, convert_paths): _run_conversion( ConvertConfig( input=CheckpointLoadConfig( - path=CONVERT_PATH / "huggingface_1", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + path=convert_paths["huggingface_1"], + format=model_testing_config.checkpoint_format, ), output=CheckpointSaveConfig( - path=CONVERT_PATH / "fast_llm_1", + path=convert_paths["fast_llm_1"], format=FastLLMCheckpointFormat, ), - model=TEST_MODEL_CONFIG_CLS, + model=model_testing_config.model_config_class, ) ) @pytest.mark.depends(on=["test_convert_huggingface_to_fast_llm"]) -def test_convert_fast_llm_to_distributed(): +def test_convert_fast_llm_to_distributed(model_testing_config, convert_paths): _run_conversion( ConvertConfig( input=CheckpointLoadConfig( - path=CONVERT_PATH / "fast_llm_1", + path=convert_paths["fast_llm_1"], format=FastLLMCheckpointFormat, ), output=CheckpointSaveConfig( - path=CONVERT_PATH / "distributed_1", + path=convert_paths["distributed_1"], format=DistributedCheckpointFormat, ), - model=TEST_MODEL_CONFIG_CLS, + model=model_testing_config.model_config_class, ) ) @pytest.mark.depends(on=["test_convert_huggingface_to_distributed", "test_convert_fast_llm_to_distributed"]) -def test_converted_distributed(): +def test_converted_distributed(convert_paths): # Compare the fast llm weights # TODO: Compare configs - w = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors") - w0 = safetensors.torch.load_file(CONVERT_PATH / "distributed_0" / "rank_0.safetensors") - w1 = safetensors.torch.load_file(CONVERT_PATH / "distributed_1" / "rank_0.safetensors") - assert w.keys() >= {WEIGHT_SHARD_SAVE_NAME} - assert w0.keys() == w1.keys() == {WEIGHT_SHARD_SAVE_NAME} + w = safetensors.torch.load_file(convert_paths["checkpoint"] / "rank_0.safetensors") + w0 = safetensors.torch.load_file(convert_paths["distributed_0"] / "rank_0.safetensors") + w1 = safetensors.torch.load_file(convert_paths["distributed_1"] / "rank_0.safetensors") + assert w.keys() >= {_WEIGHT_SHARD_SAVE_NAME} + assert w0.keys() == w1.keys() == {_WEIGHT_SHARD_SAVE_NAME} for key in w0: assert w[key].shape == w0[key].shape, (key, w[key].shape, w0[key].shape) assert (w[key] == w0[key]).all(), (w[key], w0[key]) @@ -242,9 +230,9 @@ def test_converted_distributed(): @pytest.mark.depends(on=["test_convert_distributed_to_fast_llm", 
"test_convert_huggingface_to_fast_llm"]) -def test_converted_fast_llm(): - s0 = safetensors.torch.load_file(CONVERT_PATH / "fast_llm_0" / "model_0.safetensors") - s1 = safetensors.torch.load_file(CONVERT_PATH / "fast_llm_1" / "model_0.safetensors") +def test_converted_fast_llm(convert_paths): + s0 = safetensors.torch.load_file(convert_paths["fast_llm_0"] / "model_0.safetensors") + s1 = safetensors.torch.load_file(convert_paths["fast_llm_1"] / "model_0.safetensors") assert s0.keys() == s1.keys() for key in s0: assert s0[key].shape == s1[key].shape, (key, s0[key].shape, s1[key].shape) @@ -252,9 +240,9 @@ def test_converted_fast_llm(): @pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface", "test_convert_distributed_to_huggingface"]) -def test_converted_huggingface(): - h0 = safetensors.torch.load_file(CONVERT_PATH / "huggingface_0" / "model_0.safetensors") - h1 = safetensors.torch.load_file(CONVERT_PATH / "huggingface_1" / "model_0.safetensors") +def test_converted_huggingface(convert_paths): + h0 = safetensors.torch.load_file(convert_paths["huggingface_0"] / "model_0.safetensors") + h1 = safetensors.torch.load_file(convert_paths["huggingface_1"] / "model_0.safetensors") assert h0.keys() == h1.keys() for key in h0: assert h0[key].shape == h1[key].shape, (key, h0[key].shape, h1[key].shape) @@ -270,45 +258,45 @@ def _compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMM @pytest.mark.depends(on=["test_converted_distributed"]) -def test_load_pretrained_distributed_checkpoint(): - config = TEST_MODEL_CONFIG_CLS.from_dict( - yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r"))["model"], strict=False +def test_load_pretrained_distributed_checkpoint(model_testing_config, convert_paths): + config = model_testing_config.model_config_class.from_dict( + yaml.safe_load((convert_paths["checkpoint"] / ".." / ".." 
/ "config.yaml").open("r"))["model"], strict=False ) pretrained_config_ref = CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, optimizer_state=True, load_config=ModelConfigType.model, ) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) + model = model_testing_config.model_class.from_pretrained(pretrained_config_ref) _compare_model_configs(config, model.config) state_shards = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) + convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) ) for shard_name in model.state_shard_names: assert (state_shards[f"{shard_name}_shard"] == model.get_shard(shard_name)).all() @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) -def test_load_converted_distributed_checkpoint(): - config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( +def test_load_converted_distributed_checkpoint(model_testing_config, convert_paths): + config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) ) - model = TEST_MODEL_CLS.from_pretrained( + model = model_testing_config.model_class.from_pretrained( CheckpointLoadConfig( - path=CONVERT_PATH / "distributed_0", + path=convert_paths["distributed_0"], format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) ) - config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + config_alt = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( - path=CONVERT_PATH / "distributed_1", + path=convert_paths["distributed_1"], format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) @@ -316,30 +304,30 @@ def test_load_converted_distributed_checkpoint(): _compare_architectures(config_ref, model.config) _compare_model_configs(model.config, config_alt) weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) - )[WEIGHT_SHARD_SAVE_NAME] + convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) + )[_WEIGHT_SHARD_SAVE_NAME] assert (weight_shard == model.get_shard(ShardName.weights)).all() @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) -def test_load_converted_fast_llm_checkpoint(): - config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( +def test_load_converted_fast_llm_checkpoint(model_testing_config, convert_paths): + config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) ) - model = TEST_MODEL_CLS.from_pretrained( + model = model_testing_config.model_class.from_pretrained( CheckpointLoadConfig( - path=CONVERT_PATH / "fast_llm_0", + path=convert_paths["fast_llm_0"], format=FastLLMCheckpointFormat, load_config=ModelConfigType.model, ) ) - config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + config_alt = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( - path=CONVERT_PATH / "fast_llm_1", + path=convert_paths["fast_llm_1"], format=FastLLMCheckpointFormat, load_config=ModelConfigType.model, ) @@ -347,48 +335,48 @@ def test_load_converted_fast_llm_checkpoint(): _compare_architectures(config_ref, model.config) 
_compare_architectures(config_ref, config_alt) weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) - )[WEIGHT_SHARD_SAVE_NAME] + convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) + )[_WEIGHT_SHARD_SAVE_NAME] assert (weight_shard == model.get_shard(ShardName.weights)).all() @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) -def test_load_converted_huggingface_checkpoint(): - config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( +def test_load_converted_huggingface_checkpoint(model_testing_config, convert_paths): + config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) ) - model = TEST_MODEL_CLS.from_pretrained( + model = model_testing_config.model_class.from_pretrained( CheckpointLoadConfig( - path=CONVERT_PATH / "huggingface_1", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + path=convert_paths["huggingface_1"], + format=model_testing_config.checkpoint_format, load_config=ModelConfigType.model, ), mode=StageMode.weights, ) - config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + config_alt = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( - path=CONVERT_PATH / "huggingface_0", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + path=convert_paths["huggingface_0"], + format=model_testing_config.checkpoint_format, load_config=ModelConfigType.model, ) ) _compare_architectures(config_ref, model.config) _compare_model_configs(model.config, config_alt) weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) - )[WEIGHT_SHARD_SAVE_NAME] + convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) + )[_WEIGHT_SHARD_SAVE_NAME] assert (weight_shard == model.get_shard(ShardName.weights)).all() @pytest.mark.depends(on=["test_load_converted_fast_llm_checkpoint", "test_load_converted_huggingface_checkpoint"]) -def test_run_converted_model(): - model_ref = TEST_MODEL_HF_CLS.from_pretrained( +def test_run_converted_model(model_testing_config, convert_paths): + model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) @@ -397,18 +385,20 @@ def test_run_converted_model(): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = TEST_MODEL_HF_CLS.from_pretrained(CONVERT_PATH / "fast_llm_0") - model_from_hf = TEST_MODEL_HF_CLS.from_pretrained( + model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + convert_paths["fast_llm_0"] + ) + model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( CheckpointLoadConfig( - path=CONVERT_PATH / "huggingface_0", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + path=convert_paths["huggingface_0"], + format=model_testing_config.checkpoint_format, load_config=ModelConfigType.model, ) ) errors = [] compare = CompareConfig() model_as_hf = transformers.AutoModelForCausalLM.from_pretrained( - CONVERT_PATH / "huggingface_0", trust_remote_code=HUGGINGFACE_CHECKPOINT_FORMAT.trust_remote_code + 
convert_paths["huggingface_0"], trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code ).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), @@ -434,14 +424,13 @@ def test_run_converted_model(): @pytest.mark.slow @pytest.mark.depends(on=["test_load_converted_distributed_checkpoint"]) -def test_load_pretrained_distributed_in_dp2(run_test_script): - run_test_script( - f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2", - CONFIG_COMMON - + [ +def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, convert_paths): + run_test_script_for_all_models( + "test_load_pretrained_distributed_in_dp2", + [ "training.checkpoint.interval=1", "training.train_iters=1", - f"pretrained.path={CONVERT_PATH / 'distributed_0'}", + f"pretrained.path={convert_paths["distributed_0"]}", f"pretrained.format={DistributedCheckpointFormat.name}", "schedule.skip_step=True", ], @@ -450,14 +439,13 @@ def test_load_pretrained_distributed_in_dp2(run_test_script): @pytest.mark.depends(on=["test_load_converted_distributed_checkpoint"]) -def test_load_pretrained_distributed_with_config(run_test_script): - run_test_script( - f"test_{TEST_MODEL}_load_pretrained_distributed_with_config", - CONFIG_COMMON - + [ +def test_load_pretrained_distributed_with_config(run_test_script_for_all_models, convert_paths): + run_test_script_for_all_models( + "test_load_pretrained_distributed_with_config", + [ "training.checkpoint.interval=1", "training.train_iters=1", - f"pretrained.path={CONVERT_PATH / 'distributed_0'}", + f"pretrained.path={convert_paths["distributed_0"]}", f"pretrained.format={DistributedCheckpointFormat.name}", "schedule.skip_step=True", ], @@ -465,10 +453,10 @@ def test_load_pretrained_distributed_with_config(run_test_script): @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) -def test_load_pretrained_in_dp2_match_checkpoint(): - test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" +def test_load_pretrained_in_dp2_match_checkpoint(model_testing_config, convert_paths, run_test_script_base_path): + test_ckpt_path = run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" pretrained_config_ref = CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, load_config=ModelConfigType.fast_llm, ) @@ -477,21 +465,21 @@ def test_load_pretrained_in_dp2_match_checkpoint(): format=DistributedCheckpointFormat, load_config=ModelConfigType.fast_llm, ) - config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) + config_ref = model_testing_config.model_config_class.from_pretrained(pretrained_config_ref) + config_test = model_testing_config.model_config_class.from_pretrained(pretrained_config_test) _compare_model_configs(config_ref, config_test) - shards_ref = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors") + shards_ref = safetensors.torch.load_file(convert_paths["checkpoint"] / "rank_0.safetensors") shards_test = [safetensors.torch.load_file(test_ckpt_path / f"rank_{i}.safetensors") for i in range(2)] - ref_model = TEST_MODEL_CLS(config_ref) - test_model = TEST_MODEL_CLS(config_test) + ref_model = model_testing_config.model_class(config_ref) + test_model = model_testing_config.model_class(config_test) - weight_shard_ref_split = 
shards_ref[WEIGHT_SHARD_SAVE_NAME].split(ref_model._stage_weight_shard_sizes) + weight_shard_ref_split = shards_ref[_WEIGHT_SHARD_SAVE_NAME].split(ref_model._stage_weight_shard_sizes) weight_shards_test_split = [ - shard_test[WEIGHT_SHARD_SAVE_NAME].split(test_model._stage_weight_shard_sizes) for shard_test in shards_test + shard_test[_WEIGHT_SHARD_SAVE_NAME].split(test_model._stage_weight_shard_sizes) for shard_test in shards_test ] for shard_test in shards_test: for shard_name, shard in shard_test.items(): - if shard_name != WEIGHT_SHARD_SAVE_NAME: + if shard_name != _WEIGHT_SHARD_SAVE_NAME: assert (shard == 0).all() # noqa assert len(ref_model._stage_weight_shard_sizes) == len(test_model._stage_weight_shard_sizes) @@ -510,37 +498,36 @@ def test_load_pretrained_in_dp2_match_checkpoint(): @pytest.mark.slow @pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) -def test_load_distributed_checkpoint_dp2(): +def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, run_test_script_base_path): # This also tests conversion which uses `FastLLMModel.from_checkpoint` pretrained_config_ref = CheckpointLoadConfig( - path=_CKPT_PATH, + path=convert_paths["checkpoint"], format=DistributedCheckpointFormat, load_config=ModelConfigType.fast_llm, ) pretrained_config_test = CheckpointLoadConfig( - path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", + path=run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) + config = model_testing_config.model_config_class.from_pretrained(pretrained_config_ref) + model = model_testing_config.model_class.from_pretrained(pretrained_config_test, mode=StageMode.weights) _compare_model_configs(config, model.config) weight_shard = safetensors.torch.load_file( - _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) - )[WEIGHT_SHARD_SAVE_NAME] + convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) + )[_WEIGHT_SHARD_SAVE_NAME] assert (weight_shard == model.get_shard(ShardName.weights)).all() @pytest.mark.slow @pytest.mark.depends(on=["test_load_converted_fast_llm_checkpoint", "test_load_pretrained_in_dp2_match_checkpoint"]) -def test_load_pretrained_fast_llm_in_dp2(run_test_script): +def test_load_pretrained_fast_llm_in_dp2(run_test_script, convert_paths, run_test_script_base_path): run_test_script( - f"test_{TEST_MODEL}_load_pretrained_fast_llm_in_dp2", - CONFIG_COMMON - + [ + "test_load_pretrained_fast_llm_in_dp2", + [ "training.checkpoint.interval=1", "training.train_iters=1", - f"pretrained.path={CONVERT_PATH / 'fast_llm_0'}", + f"pretrained.path={convert_paths["fast_llm_0"]}", f"pretrained.format=fast_llm", "schedule.skip_step=True", ], @@ -548,15 +535,15 @@ def test_load_pretrained_fast_llm_in_dp2(run_test_script): ) for rank in range(2): ref_shard = safetensors.torch.load_file( - TEST_RESULTS_PATH - / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" + run_test_script_base_path + / f"test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" / f"rank_{rank}.safetensors" ) test_shard = safetensors.torch.load_file( - TEST_RESULTS_PATH - / f"test_{TEST_MODEL}_load_pretrained_fast_llm_in_dp2" + run_test_script_base_path + / 
f"test_load_pretrained_fast_llm_in_dp2" / "checkpoint" / "1" / f"rank_{rank}.safetensors" @@ -567,30 +554,31 @@ def test_load_pretrained_fast_llm_in_dp2(run_test_script): @pytest.mark.slow @pytest.mark.depends(on=["test_load_converted_huggingface_checkpoint", "test_load_pretrained_in_dp2_match_checkpoint"]) -def test_load_pretrained_huggingface_in_dp2(run_test_script): - run_test_script( - f"test_{TEST_MODEL}_load_pretrained_huggingface_in_dp2", - CONFIG_COMMON - + [ +def test_load_pretrained_huggingface_in_dp2( + run_test_script_for_all_models, model_testing_config, run_test_script_base_path, convert_paths +): + run_test_script_for_all_models( + "test_load_pretrained_huggingface_in_dp2", + [ "training.checkpoint.interval=1", "training.train_iters=1", - f"pretrained.path={CONVERT_PATH / 'huggingface_0'}", - f"pretrained.format={HUGGINGFACE_CHECKPOINT_FORMAT.name}", + f"pretrained.path={convert_paths["huggingface_0"]}", + f"pretrained.format={model_testing_config.checkpoint_format.name}", "schedule.skip_step=True", ], num_gpus=2, ) for rank in range(2): ref_shard = safetensors.torch.load_file( - TEST_RESULTS_PATH - / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" + run_test_script_base_path + / f"test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" / f"rank_{rank}.safetensors" ) test_shard = safetensors.torch.load_file( - TEST_RESULTS_PATH - / f"test_{TEST_MODEL}_load_pretrained_huggingface_in_dp2" + run_test_script_base_path + / f"test_load_pretrained_huggingface_in_dp2" / "checkpoint" / "1" / f"rank_{rank}.safetensors" diff --git a/tests/test_config.py b/tests/test_config.py index 80bed418..98a4c07c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,7 +14,6 @@ from fast_llm.models.auto import trainer_registry from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert, check_equal_nested -from tests.common import TEST_RESULTS_PATH def run_without_import(cmd: str): @@ -101,8 +100,8 @@ def test_serialize_default_config_updates(cls, default): @pytest.mark.parametrize("load_config", tuple(ModelConfigType)) -def test_pretrained_config(load_config: ModelConfigType): - config_path = TEST_RESULTS_PATH / "pretrained_config" +def test_pretrained_config(load_config: ModelConfigType, result_path): + config_path = result_path / "pretrained_config" pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { diff --git a/tests/test_functional.py b/tests/test_functional.py index 908a5537..03a0ae8a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,7 +8,7 @@ from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert -from tests.common import requires_cuda +from tests.utils.utils import requires_cuda def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: diff --git a/tests/test_gpt_generate_and_forward.py b/tests/test_gpt_generate_and_forward.py index a16d4c71..ca75cf3e 100644 --- a/tests/test_gpt_generate_and_forward.py +++ b/tests/test_gpt_generate_and_forward.py @@ -9,13 +9,7 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM -from tests.common import CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT, 
TEST_MODEL, TEST_RESULTS_PATH, requires_cuda - - -def _prepare_checkpoint(model: str) -> str: - path = TEST_RESULTS_PATH.resolve() / "generate/model" - model_path = huggingface_hub.snapshot_download(repo_id=model, local_dir=path) - return model_path +from tests.utils.utils import requires_cuda def _prepare_data(tokenizer, use_batch_size2: bool): @@ -179,12 +173,11 @@ def _test_for_batches( @pytest.fixture(scope="module") -def model_and_tokenizer(): - model = "HuggingFaceTB/SmolLM2-135M-Instruct" - fast_llm_checkpoint_format = LlamaGPTHuggingfaceCheckpointFormat - model_path = _prepare_checkpoint(model) - tokenizer = AutoTokenizer.from_pretrained(model_path) - return model_path, tokenizer, fast_llm_checkpoint_format +def model_path(result_path): + return huggingface_hub.snapshot_download( + repo_id="HuggingFaceTB/SmolLM2-135M-Instruct", + local_dir=result_path / "generate/model", + ) def _test_generate( @@ -224,35 +217,33 @@ def _test_generate( ], ) def test_generate( - model_and_tokenizer, + model_path, use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2, ): - model_path, tokenizer, fast_llm_checkpoint_format = model_and_tokenizer _test_generate( model_path, - fast_llm_checkpoint_format, + LlamaGPTHuggingfaceCheckpointFormat, use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2, - tokenizer=tokenizer, + tokenizer=AutoTokenizer.from_pretrained(model_path), ) +@pytest.mark.slow @requires_cuda -def test_export_for_generate(run_test_script): +def test_export_for_generate(run_test_script_for_all_models, model_testing_config): # Not really testing, anything, but handles dependencies more easily than a fixture. - run_test_script( - f"test_{TEST_MODEL}_export_for_generate", - CONFIG_COMMON - + [ + run_test_script_for_all_models( + [ "training.train_iters=1", - f"training.export.format={HUGGINGFACE_CHECKPOINT_FORMAT.name}", + f"training.export.format={model_testing_config.checkpoint_format.name}", "training.export.interval=1", ], ) @@ -273,6 +264,8 @@ def test_export_for_generate(run_test_script): ], ) def test_small_generate( + model_testing_config, + run_test_script_base_path, use_flash_attention, use_bf16, max_new_tokens, @@ -280,8 +273,8 @@ def test_small_generate( min_matching_tokens_batch_size_2, ): _test_generate( - TEST_RESULTS_PATH / f"test_{TEST_MODEL}_export_for_generate/export/{HUGGINGFACE_CHECKPOINT_FORMAT.name}/1", - HUGGINGFACE_CHECKPOINT_FORMAT, + run_test_script_base_path / f"test_export_for_generate/export/{model_testing_config.checkpoint_format.name}/1", + model_testing_config.checkpoint_format, use_flash_attention, use_bf16, max_new_tokens, @@ -312,20 +305,21 @@ def _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format) @requires_cuda @pytest.mark.extra_slow def test_generate_from_model( - model_and_tokenizer, + model_path, ): - model_path, tokenizer, fast_llm_checkpoint_format = model_and_tokenizer - _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format) + _test_generate_from_model( + model_path, AutoTokenizer.from_pretrained(model_path), LlamaGPTHuggingfaceCheckpointFormat + ) @requires_cuda @pytest.mark.slow @pytest.mark.depends(on=["test_export_for_generate"]) -def test_small_generate_from_model(): +def test_small_generate_from_model(model_testing_config, run_test_script_base_path): _test_generate_from_model( - TEST_RESULTS_PATH / 
f"test_{TEST_MODEL}_export_for_generate/export/{HUGGINGFACE_CHECKPOINT_FORMAT.name}/1", + run_test_script_base_path / f"test_export_for_generate/export/{model_testing_config.checkpoint_format.name}/1", None, - HUGGINGFACE_CHECKPOINT_FORMAT, + model_testing_config.checkpoint_format, ) @@ -361,16 +355,17 @@ def _test_forward_return_hidden_states( @pytest.mark.extra_slow @requires_cuda -def test_forward_return_hidden_states(model_and_tokenizer): - model_path, tokenizer, fast_llm_checkpoint_format = model_and_tokenizer - _test_forward_return_hidden_states(model_path, fast_llm_checkpoint_format, tokenizer.vocab_size) +def test_forward_return_hidden_states(model_path): + _test_forward_return_hidden_states( + model_path, LlamaGPTHuggingfaceCheckpointFormat, AutoTokenizer.from_pretrained(model_path).vocab_size + ) @pytest.mark.slow @requires_cuda @pytest.mark.depends(on=["test_export_for_generate"]) -def test_small_forward_return_hidden_states(): +def test_small_forward_return_hidden_states(model_testing_config, run_test_script_base_path): _test_forward_return_hidden_states( - TEST_RESULTS_PATH / f"test_{TEST_MODEL}_export_for_generate/export/{HUGGINGFACE_CHECKPOINT_FORMAT.name}/1", - HUGGINGFACE_CHECKPOINT_FORMAT, + run_test_script_base_path / f"test_export_for_generate/export/{model_testing_config.checkpoint_format.name}/1", + model_testing_config.checkpoint_format, ) diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index 1857f0f8..a77906ae 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -1,158 +1,32 @@ import pytest -from tests.common import ( - CONFIG_GPT2_FAST_LLM, - CONFIG_GPT2_MEGATRON, - CONFIG_LLAMA_FAST_LLM, - CONFIG_LLAMA_MEGATRON, - CONFIG_MIXTRAL_FAST_LLM, - CONFIG_MIXTRAL_MEGATRON, - CONFIG_SC1_FAST_LLM, - CONFIG_SC1_MEGATRON, - CONFIG_SC2_FAST_LLM, - CONFIG_SC2_MEGATRON, - DATASET_PREFIX, -) -from tests.compare_tensor_logs import CompareConfig +from tests.utils.compare_tensor_logs import CompareConfig +from tests.utils.dataset import DATASET_PREFIX @pytest.mark.slow -@pytest.mark.skip(reason="Skipping mostly redundant test") -def test_sc1_meg(run_test_script): - # Starcoder 1 (GPT2 with MQA) with Megatron. - run_test_script("test_sc1_meg", CONFIG_SC1_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) - - -CONFIG_MATCH_MEGATRON = [ - "data.datasets={}", - f"data.path={DATASET_PREFIX}", -] - - -@pytest.mark.depends(on=["test_sc1_meg"]) -def test_sc1_match_meg(run_test_script): - # Starcoder 1 (GPT2 with MQA) with Fast-llm. - # QKV tensors are in a different format. - run_test_script( - "test_sc1_match_meg", - CONFIG_SC1_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"], - compare="test_sc1_meg", - config=CompareConfig( - ignore_tensors=[ - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", - ".mlp.layer_2.weight", - ] - ), - ) - - -@pytest.mark.slow -@pytest.mark.skip(reason="Skipping mostly redundant test") -@pytest.mark.depends(on=["test_sc1_match_meg"]) -def test_sc2_meg(run_test_script): - # Starcoder 2 (GPT2 with MQA and RoPE) with Megatron. - run_test_script("test_sc2_meg", CONFIG_SC2_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) - - -@pytest.mark.depends(on=["test_sc2_meg"]) -def test_sc2_match_meg(run_test_script): - # Starcoder 2 (GPT2 with MQA and RoPE) with Fast-llm. - # QKV tensors are in a different format, - # dense not matching because of the way initialization is corrected for RoPE format. 
-    run_test_script(
-        "test_sc2_match_meg",
-        CONFIG_SC2_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"],
-        compare="test_sc2_meg",
-        config=CompareConfig(
-            ignore_tensors=[
-                ".self_attn.query_key_value.",
-                ".self_attn.query.",
-                ".self_attn.key_value.",
-                ".self_attn.dense.",
-                ".mlp.layer_2.weight",
-            ]
-        ),
-    )
-
-
-@pytest.mark.slow
-def test_gpt2_meg(run_test_script):
-    # GPT2 (MHA, layer norm, absolute embeddings) with Megatron.
-    run_test_script("test_gpt2_meg", CONFIG_GPT2_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)
-
-
-@pytest.mark.depends(on=["test_gpt2_meg"])
-def test_gpt2_match_meg(run_test_script):
-    # GPT2 (MHA, layer norm, absolute embeddings) with Fast-llm.
-    # QKV tensors are in a different format.
-    run_test_script(
-        "test_gpt2_match_meg",
-        CONFIG_GPT2_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"],
-        compare="test_gpt2_meg",
-        config=CompareConfig(
-            ignore_tensors=[
-                ".self_attn.query_key_value.",
-                ".self_attn.query.",
-                ".self_attn.key_value.",
-                ".mlp.layer_2.weight",
-            ]
-        ),
-    )
+def test_megatron(run_test_script_for_all_models, model_testing_config):
+    run_test_script_for_all_models([], is_megatron=True)
 
 
 @pytest.mark.slow
-def test_mistral_meg(run_test_script):
-    # Mistral with Megatron.
-    # No linear bias, swiglu activation, RMSNorm
-    run_test_script("test_mistral_meg", CONFIG_LLAMA_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)
-
-
-@pytest.mark.depends(on=["test_mistral_meg"])
-def test_mistral_match_meg(run_test_script):
-    # Mistral with Fast-LLM.
-    run_test_script(
-        "test_mistral_match_meg",
-        CONFIG_LLAMA_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"],
-        compare="test_mistral_meg",
+@pytest.mark.depends(on=["test_megatron"])
+def test_match_megatron(run_test_script_for_all_models, model_testing_config):
+    run_test_script_for_all_models(
+        [
+            "model.distributed.training_dtype=fp32",
+            "data.datasets={}",
+            f"data.path={DATASET_PREFIX}",
+            "model.base_model.use_megatron_initialization=True",
+        ],
+        compare="test_megatron",
         config=CompareConfig(
             ignore_tensors=[
                 ".self_attn.query_key_value.",
                 ".self_attn.query.",
                 ".self_attn.key_value.",
-                ".self_attn.dense.",
                 ".mlp.layer_2.weight",
             ]
         ),
-    )
-
-
-@pytest.mark.slow
-def test_mixtral_meg(run_test_script):
-    # Mistral with Megatron.
-    # No linear bias, swiglu activation, RMSNorm
-    run_test_script("test_mixtral_meg", CONFIG_MIXTRAL_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)
-
-
-@pytest.mark.depends(on=["test_mixtral_meg"])
-def test_mixtral_match_meg(run_test_script):
-    # Mistral with Fast-LLM.
- run_test_script( - "test_mixtral_match_meg", - CONFIG_MIXTRAL_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"], - compare="test_mixtral_meg", - config=CompareConfig( - ignore_tensors=[ - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", - ".self_attn.dense.", - ".mlp.layer_1.weight", - ".mlp.layer_2.weight", - ".mlp.experts", - "Global layer 2 fw: Transformer layer 2 output", - ], - max_rel_tolerance=1.5e-1, - ), + use_performance_args=False, ) diff --git a/tests/test_mb.py b/tests/test_mb.py index 82ac4c25..80350df9 100644 --- a/tests/test_mb.py +++ b/tests/test_mb.py @@ -1,82 +1,84 @@ import pytest -from tests.common import CONFIG_COMMON, TEST_MODEL -from tests.compare_tensor_logs import CompareConfig - -CONFIG_DF = CONFIG_COMMON + ["batch.depth_first_micro_batches=4"] -CONFIG_BF = CONFIG_COMMON + ["batch.breadth_first_micro_batches=4"] -CONFIG_BF_DF = CONFIG_COMMON + ["batch.depth_first_micro_batches=2", "batch.breadth_first_micro_batches=2"] +from tests.utils.compare_tensor_logs import CompareConfig # TODO: Compare grads with simple -def test_model_df4(run_test_script): +def test_model_df4(run_test_script_for_all_models): # Depth-first gradient accumulation baseline. - run_test_script(f"test_{TEST_MODEL}_df4", CONFIG_DF) + run_test_script_for_all_models("test_model_df4", ["batch.depth_first_micro_batches=4"]) @pytest.mark.slow @pytest.mark.depends(on=["test_model_df4"]) -def test_model_df4_z3(run_test_script): +def test_model_df4_z3(run_test_script_for_all_models): # Gradient accumulation with ZeRO-3. - run_test_script( - f"test_{TEST_MODEL}_df4_z3", - CONFIG_DF + ["model.multi_stage.zero_stage=3"], + run_test_script_for_all_models( + "test_model_df4_z3", + ["model.multi_stage.zero_stage=3", "batch.depth_first_micro_batches=4"], num_gpus=2, - compare=f"test_{TEST_MODEL}_df4", + compare="test_model_df4", config=CompareConfig(ignore_duplicates=["Global gradient"]), ) @pytest.mark.depends(on=["test_model_df4"], scope="session") -def test_model_bf4(run_test_script): +def test_model_bf4(run_test_script_for_all_models): # Breadth-first gradient accumulation baseline. - run_test_script(f"test_{TEST_MODEL}_bf4", CONFIG_BF, compare=f"test_{TEST_MODEL}_df4") + run_test_script_for_all_models(["batch.breadth_first_micro_batches=4"], compare="test_model_df4") @pytest.mark.depends(on=["test_model_df4", "test_model_bf4"]) -def test_model_bf2_df2(run_test_script): +def test_model_bf2_df2(run_test_script_for_all_models): # Mixed gradient accumulation baseline. - run_test_script(f"test_{TEST_MODEL}_bf2_df2", CONFIG_BF_DF, compare=f"test_{TEST_MODEL}_df4") + run_test_script_for_all_models( + ["batch.depth_first_micro_batches=2", "batch.breadth_first_micro_batches=2"], compare="test_model_df4" + ) @pytest.mark.slow @pytest.mark.depends(on=["test_model_bf4"]) -def test_model_pp2s2_bf4(run_test_script): +def test_model_pp2s2_bf4(run_test_script_for_all_models): # Pipeline-parallel without tied weights. 
- run_test_script( - f"test_{TEST_MODEL}_pp2s2_bf4", - CONFIG_BF + ["model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2"], + run_test_script_for_all_models( + [ + "batch.breadth_first_micro_batches=4", + "model.distributed.pipeline_parallel=2", + "model.multi_stage.layers_per_stage=2", + ], num_gpus=2, - compare=f"test_{TEST_MODEL}_df4", + compare="test_model_df4", ) @pytest.mark.slow @pytest.mark.depends(on=["test_model_bf4"]) -def test_model_pp2s1_bf4(run_test_script): +def test_model_pp2s1_bf4(run_test_script_for_all_models): # Pipeline-parallel with tied weights. - run_test_script( - f"test_{TEST_MODEL}_pp2s1_bf4", - CONFIG_BF + ["model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=1"], + run_test_script_for_all_models( + [ + "batch.breadth_first_micro_batches=4", + "model.distributed.pipeline_parallel=2", + "model.multi_stage.layers_per_stage=1", + ], num_gpus=2, - compare=f"test_{TEST_MODEL}_df4", + compare="test_model_df4", config=CompareConfig(ignore_duplicates=["layers.0.word_embeddings_weight"]), ) @pytest.mark.slow @pytest.mark.depends(on=["test_model_bf4"]) -def test_model_dp2_tp2_pp2s2_bf4(run_test_script): +def test_model_dp2_tp2_pp2s2_bf4(run_test_script_for_all_models): # Simple 3d parallelism # TODO: Test fails - run_test_script( - f"test_{TEST_MODEL}_dp2_tp2_pp2s2_bf4", - CONFIG_BF - + [ + run_test_script_for_all_models( + [ + "batch.breadth_first_micro_batches=4", "model.distributed.tensor_parallel=2", "model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=1", ], num_gpus=8, - compare=f"test_{TEST_MODEL}_df4", + compare="test_model_df4", ) diff --git a/tests/test_mb_seq_first.py b/tests/test_mb_seq_first.py index 345a7bc4..5146dc9a 100644 --- a/tests/test_mb_seq_first.py +++ b/tests/test_mb_seq_first.py @@ -1,57 +1,48 @@ import pytest -from tests.common import CONFIG_COMMON, TEST_MODEL -from tests.compare_tensor_logs import CompareConfig - -CONFIG_DF_SF = CONFIG_COMMON + ["batch.depth_first_micro_batches=4", "model.base_model.sequence_first=True"] -CONFIG_BF_SF = CONFIG_COMMON + ["batch.breadth_first_micro_batches=4", "model.base_model.sequence_first=True"] -CONFIG_BF_DF_SF = CONFIG_COMMON + [ - "batch.depth_first_micro_batches=2", - "batch.breadth_first_micro_batches=2", - "model.base_model.sequence_first=True", -] +from tests.utils.compare_tensor_logs import CompareConfig # TODO: Compare grads with simple -def test_model_df4_sf(run_test_script): +def test_model_df4_sf(run_test_script_for_all_models): # Sequence-first gradient accumulation baseline. - run_test_script(f"test_{TEST_MODEL}_df4_sf", CONFIG_DF_SF) + run_test_script_for_all_models(["batch.depth_first_micro_batches=4", "model.base_model.sequence_first=True"]) @pytest.mark.slow @pytest.mark.depends(on=["test_model_df4_sf"]) -def test_model_dp2_sp2_df4(run_test_script): +def test_model_dp2_sp2_df4(run_test_script_for_all_models): # Sequence-tensor-parallel with gradient accumulation. 
# TODO: Compiled cross-entropy broken for this config - run_test_script( - f"test_{TEST_MODEL}_dp2_sp2_df4", - CONFIG_BF_SF - + [ + run_test_script_for_all_models( + [ + "batch.breadth_first_micro_batches=4", + "model.base_model.sequence_first=True", "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", "run.torch_dynamo_enable=False", ], num_gpus=4, - compare=f"test_{TEST_MODEL}_df4_sf", + compare="test_model_df4_sf", ) @pytest.mark.slow @pytest.mark.skip(reason="Test is broken.") @pytest.mark.depends(on=["test_model_df4_sf"]) -def test_model_dp2_sp2_pp2s1(run_test_script): +def test_model_dp2_sp2_pp2s1(run_test_script_for_all_models): # 3d-parallel with sequence-tensor-parallel. # TODO: Compiled cross-entropy broken for this config - run_test_script( - f"test_{TEST_MODEL}_dp2_sp2_pp2s1", - CONFIG_BF_SF - + [ + run_test_script_for_all_models( + [ + "batch.breadth_first_micro_batches=4", + "model.base_model.sequence_first=True", "model.distributed.tensor_parallel=2", "model.distributed.pipeline_parallel=2", "model.distributed.sequence_tensor_parallel=True", "run.torch_dynamo_enable=False", ], num_gpus=8, - compare=f"test_{TEST_MODEL}_df4_sf", + compare="test_model_df4_sf", config=CompareConfig(ignore_duplicates=["layers.0.word_embeddings_weight"]), ) diff --git a/tests/test_ms.py b/tests/test_ms.py index 90d16672..256eafe3 100644 --- a/tests/test_ms.py +++ b/tests/test_ms.py @@ -1,38 +1,36 @@ import pytest -from tests.common import CONFIG_COMMON, TEST_MODEL - -CONFIG_MS = CONFIG_COMMON + ["batch.micro_sequence_length=256"] - # TODO: Compare grads with simple -def test_model_ms256(run_test_script): +def test_model_ms256(run_test_script_for_all_models): # Micro-sequence baseline - run_test_script(f"test_{TEST_MODEL}_ms256", CONFIG_MS) + run_test_script_for_all_models(["batch.micro_sequence_length=256"]) @pytest.mark.slow @pytest.mark.depends(on=["test_model_ms256"]) -def test_model_pp2s2_ms256(run_test_script): +def test_model_pp2s2_ms256(run_test_script_for_all_models): # Sequence-pipeline-parallel - run_test_script( - f"test_{TEST_MODEL}_pp2s2_ms256", - CONFIG_MS + ["model.distributed.pipeline_parallel=2", "model.multi_stage.layers_per_stage=2"], + run_test_script_for_all_models( + [ + "batch.micro_sequence_length=256", + "model.distributed.pipeline_parallel=2", + "model.multi_stage.layers_per_stage=2", + ], num_gpus=2, - compare=f"test_{TEST_MODEL}_ms256", + compare="test_model_ms256", ) @pytest.mark.slow @pytest.mark.skip @pytest.mark.depends(on=["test_model_ms256"]) -def test_model_dp2s2_stp2_pp2s2_ms256(run_test_script): +def test_model_dp2s2_stp2_pp2s2_ms256(run_test_script_for_all_models): # TODO: Handle this case. 
# Sequence-3d-parallel - run_test_script( - f"test_{TEST_MODEL}_dp2s2_stp2_pp2s2_ms256", - CONFIG_MS - + [ + run_test_script_for_all_models( + [ + "batch.micro_sequence_length=256", "model.distributed.pipeline_parallel=2", "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", @@ -40,5 +38,5 @@ def test_model_dp2s2_stp2_pp2s2_ms256(run_test_script): "model.multi_stage.layers_per_stage=2", ], num_gpus=8, - compare=f"test_{TEST_MODEL}_ms256", + compare="test_model_ms256", ) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index edce4e74..5c4660b7 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -14,7 +14,7 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.model import GPTBaseModel from fast_llm.utils import Assert -from tests.common import get_hybrid_config, materialize_meta_tensors, requires_cuda +from tests.utils.utils import get_hybrid_config, materialize_meta_tensors, requires_cuda try: from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index bb468ceb..6d3861eb 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -4,7 +4,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.tools.train import CliTrainingConfig from fast_llm.utils import Assert -from tests.common import CONFIG_COMMON, requires_cuda +from tests.utils.utils import requires_cuda def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @@ -17,8 +17,8 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda -def test_frozen_weights(): - args = CONFIG_COMMON + ["run.tensor_logs.save=False"] +def test_frozen_weights(model_testing_config): + args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args)._multi_stage model_frozen = _get_trainer_from_args(args + ["model.base_model.transformer.mlp_lr_scale=[0]"])._multi_stage diff --git a/tests/test_seq_first.py b/tests/test_seq_first.py index a8f4c036..3e8b7ea1 100644 --- a/tests/test_seq_first.py +++ b/tests/test_seq_first.py @@ -1,53 +1,48 @@ import pytest -from tests.common import CONFIG_COMMON, TEST_MODEL - -CONFIG_SF = CONFIG_COMMON + ["model.base_model.sequence_first=True"] - # TODO: Compare grads with simple -def test_model_sf(run_test_script): +def test_model_sf(run_test_script_for_all_models): # Sequence-first baseline. - run_test_script(f"test_{TEST_MODEL}_sf", CONFIG_SF) + run_test_script_for_all_models("test_model_sf", ["model.base_model.sequence_first=True"]) @pytest.mark.slow @pytest.mark.depends(on=["test_model_sf"]) -def test_model_sp2(run_test_script): +def test_model_sp2(run_test_script_for_all_models): # Sequence-tensor-parallel. 
- run_test_script( - f"test_{TEST_MODEL}_sp2", - CONFIG_SF + ["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True"], + run_test_script_for_all_models( + "test_model_sp2", + ["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True"], num_gpus=2, - compare=f"test_{TEST_MODEL}_sf", + compare="test_model_sf", ) @pytest.mark.slow @pytest.mark.depends(on=["test_model_sf"]) -def test_model_sdp2(run_test_script): +def test_model_sdp2(run_test_script_for_all_models): # Sequence-data-parallel - run_test_script( - f"test_{TEST_MODEL}_sdp2", - CONFIG_COMMON + ["model.distributed.sequence_data_parallel=2"], + run_test_script_for_all_models( + "test_model_sdp2", + ["model.distributed.sequence_data_parallel=2"], num_gpus=2, - compare=f"test_{TEST_MODEL}_sf", + compare="test_model_sf", ) @pytest.mark.slow @pytest.mark.depends(on=["test_model_sf"]) -def test_model_sp2_ce4(run_test_script): +def test_model_sp2_ce4(run_test_script_for_all_models): # Sequence-tensor-parallel with cross-entropy splits. - run_test_script( - f"test_{TEST_MODEL}_sp2_ce4", - CONFIG_SF - + [ + run_test_script_for_all_models( + "test_model_sp2_ce4", + [ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", "model.base_model.parallel_embeddings=False", "model.base_model.cross_entropy_splits=4", ], num_gpus=2, - compare=f"test_{TEST_MODEL}_sf", + compare="test_model_sf", ) diff --git a/tests/test_simple.py b/tests/test_simple.py index 3128626d..bc48e26b 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1,14 +1,11 @@ import pytest -from tests.common import CONFIG_COMMON, CONFIG_FAST_LLM, TEST_MODEL - -def test_model_safe(run_test_script): +def test_model_safe(run_test_script_for_all_models): # The safest possible config, identical to the one in test_match_megatron except for the initialization. - run_test_script( - f"test_{TEST_MODEL}_safe", - CONFIG_FAST_LLM - + [ + run_test_script_for_all_models( + [ + "model.distributed.training_dtype=fp32", "run.torch_dynamo_enable=False", "schedule.data_overlap=False", "model.base_model.transformer.dropless_moe=False", @@ -17,29 +14,25 @@ def test_model_safe(run_test_script): @pytest.mark.depends(on=["test_model_safe"]) -def test_model(run_test_script): +def test_model(run_test_script_for_all_models): # A baseline config (single-gpu, bf16, flash-attn). # Also tests for multiple data loaders. - run_test_script( - f"test_{TEST_MODEL}", CONFIG_COMMON + ["training.num_workers=2"], compare=f"test_{TEST_MODEL}_safe" - ) + run_test_script_for_all_models(["training.num_workers=2"], compare="test_model_safe") @pytest.mark.slow @pytest.mark.depends(on=["test_model"]) -def test_model_dp2(run_test_script): +def test_model_dp2(run_test_script_for_all_models): # Simple data-parallel. - run_test_script(f"test_{TEST_MODEL}_dp2", CONFIG_COMMON, num_gpus=2, compare=f"test_{TEST_MODEL}") + run_test_script_for_all_models([], num_gpus=2, compare="test_model") @pytest.mark.slow -def test_model_dp2_timeout(run_test_script): +def test_model_dp2_timeout(run_test_script_for_all_models): # Test sampling timeout # TODO: Find a better way to test this - run_test_script( - f"test_{TEST_MODEL}_dp2_timeout", - CONFIG_COMMON - + [ + run_test_script_for_all_models( + [ # Use a short timeout "model.distributed.timeout=4", # Make a dataset that would timeout under the distributed timeout @@ -49,10 +42,10 @@ def test_model_dp2_timeout(run_test_script): # Use a bigger timeout for the dataset. 
"training.timeout=10", # Remove testing clutter. - f"model.multi_stage.debug_param_init=0", - f"model.multi_stage.debug_layer_outputs=0", - f"model.multi_stage.debug_layer_gradients=0", - f"model.multi_stage.debug_all_param_gradients=0", + "model.multi_stage.debug_param_init=0", + "model.multi_stage.debug_layer_outputs=0", + "model.multi_stage.debug_layer_gradients=0", + "model.multi_stage.debug_all_param_gradients=0", ], num_gpus=2, ) @@ -60,45 +53,41 @@ def test_model_dp2_timeout(run_test_script): @pytest.mark.slow @pytest.mark.depends(on=["test_model"]) -def test_model_tp2(run_test_script): +def test_model_tp2(run_test_script_for_all_models): # Simple tensor-parallel. - run_test_script( - f"test_{TEST_MODEL}_tp2", - CONFIG_COMMON + ["model.distributed.tensor_parallel=2"], + run_test_script_for_all_models( + ["model.distributed.tensor_parallel=2"], num_gpus=2, - compare=f"test_{TEST_MODEL}", + compare="test_model", ) @pytest.mark.depends(on=["test_model"]) -def test_model_ce4(run_test_script): +def test_model_ce4(run_test_script_for_all_models): # Cross-entropy splits. - run_test_script( - f"test_{TEST_MODEL}_ce4", - CONFIG_COMMON + ["model.base_model.cross_entropy_splits=4"], - compare=f"test_{TEST_MODEL}", + run_test_script_for_all_models( + ["model.base_model.cross_entropy_splits=4"], + compare="test_model", ) @pytest.mark.slow @pytest.mark.depends(on=["test_model"]) -def test_model_dp2_z2(run_test_script): +def test_model_dp2_z2(run_test_script_for_all_models): # Data-parallel with zero stage 2. - run_test_script( - f"test_{TEST_MODEL}_dp2_z2", - CONFIG_COMMON + ["model.multi_stage.zero_stage=2"], + run_test_script_for_all_models( + ["model.multi_stage.zero_stage=2"], num_gpus=2, - compare=f"test_{TEST_MODEL}", + compare="test_model", ) @pytest.mark.slow @pytest.mark.depends(on=["test_model"]) -def test_model_dp2_z3(run_test_script): +def test_model_dp2_z3(run_test_script_for_all_models): # Data-parallel with zero stage 3. 
- run_test_script( - f"test_{TEST_MODEL}_dp2_z3", - CONFIG_COMMON + ["model.multi_stage.zero_stage=3"], + run_test_script_for_all_models( + ["model.multi_stage.zero_stage=3"], num_gpus=2, - compare=f"test_{TEST_MODEL}", + compare="test_model", ) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index a6922a45..a1d460c2 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -16,7 +16,7 @@ from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat -from tests.common import get_hybrid_config, materialize_meta_tensors +from tests.utils.utils import get_hybrid_config, materialize_meta_tensors try: from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index 108a2898..9befe64f 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -31,7 +31,7 @@ from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies from fast_llm.utils import Assert, rms_diff -from tests.common import requires_cuda +from tests.utils.utils import requires_cuda @requires_cuda diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/compare_tensor_logs.py b/tests/utils/compare_tensor_logs.py similarity index 100% rename from tests/compare_tensor_logs.py rename to tests/utils/compare_tensor_logs.py diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py new file mode 100644 index 00000000..23c487a7 --- /dev/null +++ b/tests/utils/dataset.py @@ -0,0 +1,82 @@ +import pathlib +import random +import string + +import numpy as np +import yaml + +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample + +# TODO: Fixture +TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") +TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" +TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" +DATASET_CACHE = TEST_RESULTS_PATH / "dataset" +DATASET_PREFIX = DATASET_CACHE / "common" / "dataset" +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" +TEST_VOCAB_SIZE = 8192 +# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% +TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" +TEST_DATASET_TOKENS = 1000000 + + +def get_test_dataset( + prefix: pathlib.Path = DATASET_PREFIX, + seed: int = 1234, + num_tokens: int = TEST_DATASET_TOKENS, + characters: str = TEST_CHARACTERS, + vocab_size: int = TEST_VOCAB_SIZE, + max_spans: int = 0, +): + if not TOKENIZER_FILE.is_file(): + import transformers + + transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) + + if not ( + prefix.with_suffix(".idx").is_file() + and prefix.with_suffix(".bin").is_file() + and prefix.parent.joinpath("fast_llm_config.yaml").is_file() + ): + import transformers + + texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() + tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) + + samples = [ + GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts + ] + if max_spans > 0: + lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) + spans = np.sort(np.random.RandomState(seed + 
3847).randint(0, lengths[:, None], [len(samples), max_spans])) + for sample, span in zip(samples, spans): + span = np.unique(span) + sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) + + GPTMemmapDataset.write_dataset(prefix, samples) + yaml.safe_dump( + {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") + ) + + +def get_test_concatenated_memmap_dataset( + path: pathlib.Path, + num_files: int, + seed: int = 1234, + num_tokens: int = TEST_DATASET_TOKENS, + characters: str = TEST_CHARACTERS, + vocab_size: int = TEST_VOCAB_SIZE, + seed_shift: int = 55, +): + index_file = path / "index.txt" + if not index_file.is_file(): + for i in range(num_files): + get_test_dataset( + prefix=path / f"dataset_{i}", + seed=seed + i * seed_shift, + num_tokens=num_tokens, + characters=characters, + vocab_size=vocab_size, + ) + index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)]) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py new file mode 100644 index 00000000..963f6ae9 --- /dev/null +++ b/tests/utils/model_configs.py @@ -0,0 +1,276 @@ +import dataclasses +import functools +import os +import typing + +import pytest + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.models.auto import model_registry +from fast_llm.models.gpt.config import ( + LlamaGPTHuggingfaceCheckpointFormat, + MistralGPTHuggingfaceCheckpointFormat, + MixtralGPTHuggingfaceCheckpointFormat, + MTPLlamaGPTHuggingfaceCheckpointFormat, + Qwen2GPTHuggingfaceCheckpointFormat, + Starcoder2GPTHuggingfaceCheckpointFormat, +) +from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from tests.utils.dataset import DATASET_PREFIX, TEST_VOCAB_SIZE + +_LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class ModelTestingConfig: + name: str = None + model_type: str + config_args: list[str] + megatron_args: list[str] | None + checkpoint_format: CheckpointFormat | None + + @functools.cached_property + def model_config_class(self): + return model_registry[self.model_type] + + @functools.cached_property + def huggingface_model_for_causal_lm_class(self): + return self.model_config_class.get_huggingface_model_for_causal_lm_class() + + @functools.cached_property + def model_class(self): + return self.model_config_class.get_model_class() + + @functools.cached_property + def base_model_config_class(self): + return self.model_config_class.get_base_model_config_class() + + +def _update_and_add_testing_config( + old_name: str, + new_name: str, + *, + model_type: str | None = None, + extra_args: list[str] | None = None, + megatron_args: list[str] | None = ..., + checkpoint_format: CheckpointFormat | None = ..., +): + config = _MODEL_CONFIGS[old_name] + updates: dict[str, typing.Any] = {"name": new_name} + if model_type is not None: + updates["model_type"] = model_type + if extra_args is not None: + updates["config_args"] = config.config_args + extra_args + if megatron_args is not ...: + if megatron_args is None: + updates["megatron_args"] = None + elif config.megatron_args is None: + updates["megatron_args"] = megatron_args + else: + updates["megatron_args"] = config.megatron_args + megatron_args + if checkpoint_format is not ...: + updates["checkpoint_format"] = checkpoint_format + + _MODEL_CONFIGS[new_name] = dataclasses.replace(config, **updates) + + +_MODEL_CONFIGS: dict[str, ModelTestingConfig] = {} + + +_MODEL_CONFIGS["gpt2"] = 
ModelTestingConfig( + name="gpt2", + model_type="gpt", + config_args=[ + "training.logs.interval=1", + "run.tensor_logs.save=True", + "run.tensor_logs.show=False", + "model.base_model.max_position_embeddings=512", + "model.base_model.transformer.num_layers=2", + "model.base_model.transformer.hidden_size=256", + "model.base_model.transformer.num_attention_heads=8", + "model.base_model.transformer.head_groups=8", + "model.base_model.transformer.init_method_std=0.022", + f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", + f"model.multi_stage.debug_param_init={_LOG_LEVEL}", + f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", + f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", + f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", + "model.multi_stage.debug_tensor_parallel=True", + "model.distributed.reproducible_init=True", + "model.distributed.timeout=10", + "model.distributed.training_dtype=bf16", + "training.train_iters=2", + "training.num_workers=0", + "training.timeout=30", + "batch.batch_size=8", + "batch.sequence_length=512", + "data.datasets.training.type=slice", + "data.datasets.training.end=0.969", + "data.datasets.training.dataset.type=memmap", + f"data.datasets.training.dataset.path={DATASET_PREFIX}", + "data.datasets.validation.type=slice", + "data.datasets.validation.begin=0.969", + "data.datasets.validation.end=0.999", + "data.datasets.validation.dataset.type=memmap", + f"data.datasets.validation.dataset.path={DATASET_PREFIX}", + "data.datasets.test.type=slice", + "data.datasets.test.begin=0.999", + "data.datasets.test.end=1", + "data.datasets.test.dataset.type=memmap", + f"data.datasets.test.dataset.path={DATASET_PREFIX}", + "optimizer.learning_rate.base=0.0001", + ], + megatron_args=[ + "--num-layers=2", + "--hidden-size=256", + "--num-attention-heads=8", + "--log-interval=1", + "--train-iters=2", + "--eval-iters=0", + "--hidden-dropout=0", + "--attention-dropout=0", + f"--debug_param_init={_LOG_LEVEL}", + f"--debug_layer_outputs={_LOG_LEVEL}", + f"--debug_layer_gradients={_LOG_LEVEL}", + f"--debug_all_param_gradients={_LOG_LEVEL}", + "--debug_param_update=0", + "--global-batch-size=8", + "--micro-batch-size=8", + "--max-position-embeddings=512", + "--seq-length=512", + "--init-method-std=0.022", + "--lr=0.0001", + "--num-workers=0", + "--valid-num-workers=0", + "--tokenizer-type=NullTokenizer", + # Megatron messes with the vocab size, so we have to subtract 1. + f"--vocab-size={TEST_VOCAB_SIZE - 1}", + f"--data-path={DATASET_PREFIX}", + "--lr-decay-style=constant", + # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) + "--use-mcore-models", + # local implementation doesn't allow for RMS norm. 
+ "--transformer-impl=transformer_engine", + ], + checkpoint_format=None, +) + +_update_and_add_testing_config( + "gpt2", + "starcoder", + extra_args=["model.base_model.transformer.head_groups=1"], + megatron_args=["--group-query-attention"], + checkpoint_format=None, +) + +_update_and_add_testing_config( + "gpt2", + "starcoder2", + extra_args=[ + "model.base_model.transformer.head_groups=4", + "model.base_model.transformer.rotary.type=default", + ], + megatron_args=[ + "--group-query-attention", + "--num-query-groups=4", + "--use-rotary-position-embeddings", + "--no-position-embedding", + ], + checkpoint_format=Starcoder2GPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + "starcoder2", + "llama", + extra_args=[ + "model.base_model.transformer.gated=True", + "model.base_model.transformer.activation_type=silu", + "model.base_model.transformer.add_linear_biases=False", + "model.base_model.transformer.normalization.type=rms_norm", + "model.base_model.transformer.ffn_hidden_size=1024", + "model.base_model.tie_word_embeddings=False", + ], + megatron_args=[ + "--swiglu", + "--disable-bias-linear", + "--normalization=RMSNorm", + "--ffn-hidden-size=1024", + "--untie-embeddings-and-output-weights", + ], + checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + "llama", + "llama3", + extra_args=["model.base_model.transformer.rotary.type=llama3"], + # Megatron doesn't support Llama3-style Rotary Embeddings + megatron_args=None, + checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + "llama", + "llama_yarn", + extra_args=["model.base_model.transformer.rotary.type=yarn"], + # Megatron doesn't support Yarn-style Rotary Embeddings + megatron_args=None, + checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + "llama", + "llama_mtp", + extra_args=["model.base_model.prediction_heads=4"], + # Megatron doesn't support multi-token prediction. + megatron_args=None, + checkpoint_format=MTPLlamaGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + "llama", + "qwen2", + extra_args=["model.base_model.transformer.add_linear_biases=only_attn_qkv"], + # Megatron doesn't support per sub layer biases + megatron_args=None, + checkpoint_format=Qwen2GPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + "llama", + "mistral", + extra_args=["model.base_model.transformer.window_size=128"], + # Megatron doesn't support sliding windows. + megatron_args=None, + checkpoint_format=MistralGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # We ignore sliding windows to enable comparison with Megatron. + "llama", + "mixtral", + extra_args=[ + "model.base_model.transformer.num_experts=4", + "model.base_model.transformer.num_experts_per_token=4", + ], + megatron_args=[ + "--num-experts=4", + "--moe-router-topk=4", + ], + checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # We ignore sliding windows to enable comparison with Megatron. 
+ "llama", + "llamba", + model_type="hybrid_ssm", + extra_args=["model.base_model.hybrid_block_layout=['t','m']"], + megatron_args=None, + checkpoint_format=LLambaHuggingfaceCheckpointFormat, +) + + +@pytest.fixture(scope="session", params=_MODEL_CONFIGS.keys()) +def model_testing_config(request) -> ModelTestingConfig: + return _MODEL_CONFIGS[request.param] diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py new file mode 100644 index 00000000..c11d3f3b --- /dev/null +++ b/tests/utils/run_test_script.py @@ -0,0 +1,118 @@ +import os +import pathlib +import shutil +import subprocess +import sys + +import pytest +import torch + +from fast_llm.tools.train import CliTrainingConfig +from tests.utils.compare_tensor_logs import CompareConfig, compare_tensor_logs +from tests.utils.dataset import get_test_dataset + +# FIXME: figure out correct import of megatron modules without this hack +sys.path.append(os.getcwd()) + +_ARTIFACT_PATH = "runs/0/artifacts" + + +@pytest.fixture(scope="session") +def run_test_script(worker_resources): + def do_run_test_script( + path: pathlib.Path, + args: list[str], + num_gpus: int = 1, + *, + model_type: str, + is_megatron: bool = False, + compare_path: pathlib.Path | None = None, + config: CompareConfig | None = None, + prepare_fn=None, + compare_fn=None, + do_compare: bool = True, + ): + if torch.cuda.device_count() < num_gpus: + pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") + env = os.environ.copy() + if is_megatron: + # Prevent Megatron from complaining. + env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + env["NVTE_FLASH_ATTN"] = "0" + skip = False + if path.exists(): + assert path.is_dir() + # TODO: Better way to check if the previous attempt succeeded. + shutil.rmtree(path) + if prepare_fn is not None: + skip = prepare_fn(path, None if compare_path is None else compare_path, skip) + if is_megatron: + args = [*args, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] + else: + args = [model_type, *args, f"run.experiment_dir={path}"] + header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] + command = [ + "python", + "-m", + "torch.distributed.run", + f"--nproc-per-node={num_gpus}", + f"--rdzv-endpoint=localhost:{worker_resources.rendezvous_port}", + f"--master-port={worker_resources.torchrun_port}", + *header, + *args, + ] + print(" ".join(command)) + if skip: + print("Reusing existing run.") + else: + get_test_dataset() + if num_gpus == 1 and not is_megatron: + CliTrainingConfig.parse_and_run(args) + else: + completed_proc = subprocess.run(command, env=env, timeout=60) + if completed_proc.returncode: + raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") + if compare_path is not None and do_compare: + if compare_fn is not None: + compare_fn(path, compare_path) + compare_tensor_logs( + compare_path / _ARTIFACT_PATH, + path / _ARTIFACT_PATH, + config, + ) + + return do_run_test_script + + +@pytest.fixture(scope="session") +def run_test_script_base_path(model_testing_config, result_path, request): + return result_path / "models" / model_testing_config.name + + +@pytest.fixture(scope="function") +def run_test_script_for_all_models(run_test_script, run_test_script_base_path, model_testing_config, request): + def do_run_test_script_for_all_models( + extra_args: list[str], + num_gpus: int = 1, + *, + is_megatron: bool = False, + compare: str | None = None, + config: CompareConfig | None = None, + prepare_fn=None, + 
compare_fn=None, + do_compare: bool = True, + ): + run_test_script( + run_test_script_base_path / request.node.originalname, + (model_testing_config.megatron_args if is_megatron else model_testing_config.config_args) + extra_args, + num_gpus, + model_type=model_testing_config.model_type, + is_megatron=is_megatron, + compare_path=None if compare is None else run_test_script_base_path / compare, + config=config, + prepare_fn=prepare_fn, + compare_fn=compare_fn, + do_compare=do_compare, + ) + + return do_run_test_script_for_all_models diff --git a/tests/utils/utils.py b/tests/utils/utils.py new file mode 100644 index 00000000..bf2059fa --- /dev/null +++ b/tests/utils/utils.py @@ -0,0 +1,55 @@ +import pathlib + +import pytest +import torch + +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig + +requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") + + +@pytest.fixture(scope="session") +def result_path(): + return pathlib.Path("/tmp/fast_llm_tests") + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): + config = HybridSSMBaseModelConfig( + transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), + ssm=SSMConfig(), + hybrid_block_layout=hybrid_block_layout, + prediction_heads=prediction_heads, + default_mtp_type=default_mtp_type, + init_method_std_embed=0.02, + init_method_min_embed=-0.02, + init_method_max_embed=0.02, + use_position_embeddings=True, + tie_word_embeddings=False, + ) + return config From f8850e4c09e677ab94ca062c51272fbe3689699c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Jun 2025 14:41:00 -0400 Subject: [PATCH 02/69] Parametrized dependencies --- tests/conftest.py | 76 ++++++---- tests/test_checkpoint.py | 84 ++++++++--- tests/test_gpt_generate_and_forward.py | 6 +- tests/test_match_megatron.py | 2 +- tests/test_mb.py | 12 +- tests/test_mb_seq_first.py | 4 +- tests/test_ms.py | 4 +- tests/test_seq_first.py | 8 +- tests/test_simple.py | 12 +- tests/utils/depends.py | 200 +++++++++++++++++++++++++ 10 files changed, 337 insertions(+), 71 deletions(-) create mode 100644 tests/utils/depends.py diff --git a/tests/conftest.py b/tests/conftest.py index 3d1e940b..4cf6158d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,13 +2,12 @@ import math import os -import networkx import pytest -import pytest_depends -import pytest_depends.main 
import torch from xdist.scheduler import LoadGroupScheduling +from tests.utils.depends import DependencyManager + # Make fixtures available globally without import from tests.utils.run_test_script import ( # isort: skip run_test_script, @@ -20,14 +19,24 @@ from tests.utils.utils import result_path # isort: skip +manager: DependencyManager | None = None + + def pytest_addoption(parser): - parser.addoption("--skip-slow", action="store_true") - parser.addoption( + group = parser.getgroup("fast_llm") + group.addoption("--skip-slow", action="store_true") + group.addoption( "--run-extra-slow", action="store_true", default=False, help="Run tests marked as extra_slow", ) + group.addoption( + "--show-dependencies", + action="store_true", + default=False, + help="List all dependencies of all tests as a list of nodeids + the names that could not be resolved.", + ) @dataclasses.dataclass @@ -49,6 +58,7 @@ def pytest_configure(config): config.addinivalue_line( "markers", "extra_slow: Mark test as extra slow and skip unless --run-extra-slow is given." ) + config.addinivalue_line("markers", "depends_on(name='name', on=['other_name']): marks dependencies between tests.") # TODO: Spawned processes (multi-gpu, Megatron) ignore resource allocation. is_parallel = hasattr(config, "workerinput") if is_parallel: @@ -98,6 +108,8 @@ def pytest_configure(config): @pytest.hookimpl(trylast=True) def pytest_collection_modifyitems(config, items): + global manager + if config.getoption("--skip-slow"): skip_slow = pytest.mark.skip(reason="Skipping slow tests") for item in items: @@ -109,26 +121,40 @@ def pytest_collection_modifyitems(config, items): if "extra_slow" in item.keywords: item.add_marker(skip_extra_slow) - manager: pytest_depends.DependencyManager = pytest_depends.managers[-1] - # Build the undirected graph as in `DependencyManager.sorted_items`. - dag = networkx.DiGraph() - for item in manager.items: - node_id = pytest_depends.clean_nodeid(item.nodeid) - dag.add_node(node_id) - for dependency in manager.dependencies[node_id].dependencies: - dag.add_edge(dependency, node_id) - # Mark dependency groups for xdist. - manager.groups = {} - for i, node_ids in enumerate(sorted(networkx.weakly_connected_components(dag), key=len, reverse=True)): - if len(node_ids) > 1: - for node_id in node_ids: - manager.nodeid_to_item[node_id]._nodeid = ( - f"{manager.nodeid_to_item[node_id]._nodeid}@dependency_group_{i}" - ) - - old_clean_nodeid = pytest_depends.main.clean_nodeid - # Hack into `clean_nodeid` so pytest_depends recognizes the renamed nodes. - pytest_depends.main.clean_nodeid = lambda nodeid: old_clean_nodeid(nodeid.split("@dependency_group_")[0]) + manager = DependencyManager(items) + + # Show the extra information if requested + if config.getoption("show_dependencies"): + manager.print_name_map(config.getoption("verbose") > 1) + manager.print_processed_dependencies(config.getoption("color")) + + # Reorder the items so that tests run after their dependencies + items[:] = manager.items + + # If pytest-depends is installed, it will complain about renamed nodes whether it's used or not. + try: + import pytest_depends + except ImportError: + pass + else: + old_clean_nodeid = pytest_depends.main.clean_nodeid + # Hack into `clean_nodeid` so pytest_depends recognizes the renamed nodes. 
+ pytest_depends.main.clean_nodeid = lambda nodeid: old_clean_nodeid(nodeid.split("@dependency_group_")[0]) + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item: pytest.Function, call): + outcome = yield + manager.register_result(item, outcome.get_result()) + + +def pytest_runtest_call(item: pytest.Function): + manager.handle_missing(item) + + +def pytest_unconfigure(): + global manager + manager = None @pytest.fixture(scope="session") diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index e7929440..6e6d5806 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -55,7 +55,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): shutil.copy(compare_path / path, test_path / path) -@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +@pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) def test_resume(run_test_script_for_all_models): # Resume from iteration=1 and compare outputs with the baseline run. run_test_script_for_all_models( @@ -70,7 +70,7 @@ def test_resume(run_test_script_for_all_models): ) -@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +@pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) def test_resume_frozen(run_test_script_for_all_models): # Resume with frozen mlp. No comparison. run_test_script_for_all_models( @@ -107,7 +107,7 @@ def convert_paths(run_test_script_base_path): } -@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +@pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) def test_convert_distributed_to_fast_llm(model_testing_config, convert_paths): _run_conversion( ConvertConfig( @@ -124,7 +124,7 @@ def test_convert_distributed_to_fast_llm(model_testing_config, convert_paths): ) -@pytest.mark.depends(on=["test_convert_distributed_to_fast_llm"]) +@pytest.mark.depends_on(on=["test_convert_distributed_to_fast_llm[{model_testing_config}]"]) def test_convert_fast_llm_to_huggingface(model_testing_config, convert_paths): if model_testing_config.checkpoint_format is None: pytest.skip(f"Conversion not supported for {model_testing_config.name}") @@ -143,7 +143,7 @@ def test_convert_fast_llm_to_huggingface(model_testing_config, convert_paths): ) -@pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface"]) +@pytest.mark.depends_on(on=["test_convert_fast_llm_to_huggingface[{model_testing_config}]"]) def test_convert_huggingface_to_distributed(model_testing_config, convert_paths): _run_conversion( ConvertConfig( @@ -160,7 +160,7 @@ def test_convert_huggingface_to_distributed(model_testing_config, convert_paths) ) -@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +@pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) def test_convert_distributed_to_huggingface(model_testing_config, convert_paths): if model_testing_config.checkpoint_format is None: pytest.skip(f"Conversion not supported for {model_testing_config.name}") @@ -179,7 +179,7 @@ def test_convert_distributed_to_huggingface(model_testing_config, convert_paths) ) -@pytest.mark.depends(on=["test_convert_distributed_to_huggingface"]) +@pytest.mark.depends_on(on=["test_convert_distributed_to_huggingface[{model_testing_config}]"]) def test_convert_huggingface_to_fast_llm(model_testing_config, convert_paths): _run_conversion( ConvertConfig( @@ -196,7 +196,7 @@ def test_convert_huggingface_to_fast_llm(model_testing_config, convert_paths): ) 
-@pytest.mark.depends(on=["test_convert_huggingface_to_fast_llm"]) +@pytest.mark.depends_on(on=["test_convert_huggingface_to_fast_llm[{model_testing_config}]"]) def test_convert_fast_llm_to_distributed(model_testing_config, convert_paths): _run_conversion( ConvertConfig( @@ -213,7 +213,12 @@ def test_convert_fast_llm_to_distributed(model_testing_config, convert_paths): ) -@pytest.mark.depends(on=["test_convert_huggingface_to_distributed", "test_convert_fast_llm_to_distributed"]) +@pytest.mark.depends_on( + on=[ + "test_convert_huggingface_to_distributed[{model_testing_config}]", + "test_convert_fast_llm_to_distributed[{model_testing_config}]", + ] +) def test_converted_distributed(convert_paths): # Compare the fast llm weights # TODO: Compare configs @@ -229,7 +234,12 @@ def test_converted_distributed(convert_paths): assert (w[key] == w1[key]).all(), (w[key], w1[key]) -@pytest.mark.depends(on=["test_convert_distributed_to_fast_llm", "test_convert_huggingface_to_fast_llm"]) +@pytest.mark.depends_on( + on=[ + "test_convert_distributed_to_fast_llm[{model_testing_config}]", + "test_convert_huggingface_to_fast_llm[{model_testing_config}]", + ] +) def test_converted_fast_llm(convert_paths): s0 = safetensors.torch.load_file(convert_paths["fast_llm_0"] / "model_0.safetensors") s1 = safetensors.torch.load_file(convert_paths["fast_llm_1"] / "model_0.safetensors") @@ -239,7 +249,12 @@ def test_converted_fast_llm(convert_paths): assert (s0[key] == s1[key]).all(), (key, s0, s1) -@pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface", "test_convert_distributed_to_huggingface"]) +@pytest.mark.depends_on( + on=[ + "test_convert_fast_llm_to_huggingface[{model_testing_config}]", + "test_convert_distributed_to_huggingface[{model_testing_config}]", + ] +) def test_converted_huggingface(convert_paths): h0 = safetensors.torch.load_file(convert_paths["huggingface_0"] / "model_0.safetensors") h1 = safetensors.torch.load_file(convert_paths["huggingface_1"] / "model_0.safetensors") @@ -257,7 +272,7 @@ def _compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMM config_ref.base_model.compare_architecture(config_test.base_model) -@pytest.mark.depends(on=["test_converted_distributed"]) +@pytest.mark.depends_on(on=["test_converted_distributed[{model_testing_config}]"]) def test_load_pretrained_distributed_checkpoint(model_testing_config, convert_paths): config = model_testing_config.model_config_class.from_dict( yaml.safe_load((convert_paths["checkpoint"] / ".." / ".." 
/ "config.yaml").open("r"))["model"], strict=False @@ -277,7 +292,7 @@ def test_load_pretrained_distributed_checkpoint(model_testing_config, convert_pa assert (state_shards[f"{shard_name}_shard"] == model.get_shard(shard_name)).all() -@pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_pretrained_distributed_checkpoint[{model_testing_config}]"]) def test_load_converted_distributed_checkpoint(model_testing_config, convert_paths): config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( @@ -309,7 +324,12 @@ def test_load_converted_distributed_checkpoint(model_testing_config, convert_pat assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) +@pytest.mark.depends_on( + on=[ + "test_converted_fast_llm[{model_testing_config}]", + "test_load_pretrained_distributed_checkpoint[{model_testing_config}]", + ] +) def test_load_converted_fast_llm_checkpoint(model_testing_config, convert_paths): config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( @@ -340,7 +360,12 @@ def test_load_converted_fast_llm_checkpoint(model_testing_config, convert_paths) assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) +@pytest.mark.depends_on( + on=[ + "test_converted_fast_llm[{model_testing_config}]", + "test_load_pretrained_distributed_checkpoint[{model_testing_config}]", + ] +) def test_load_converted_huggingface_checkpoint(model_testing_config, convert_paths): config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( @@ -372,7 +397,12 @@ def test_load_converted_huggingface_checkpoint(model_testing_config, convert_pat assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.depends(on=["test_load_converted_fast_llm_checkpoint", "test_load_converted_huggingface_checkpoint"]) +@pytest.mark.depends_on( + on=[ + "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", + "test_load_converted_huggingface_checkpoint[{model_testing_config}]", + ] +) def test_run_converted_model(model_testing_config, convert_paths): model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( CheckpointLoadConfig( @@ -423,7 +453,7 @@ def test_run_converted_model(model_testing_config, convert_paths): @pytest.mark.slow -@pytest.mark.depends(on=["test_load_converted_distributed_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, convert_paths): run_test_script_for_all_models( "test_load_pretrained_distributed_in_dp2", @@ -438,7 +468,7 @@ def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, conv ) -@pytest.mark.depends(on=["test_load_converted_distributed_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) def test_load_pretrained_distributed_with_config(run_test_script_for_all_models, convert_paths): run_test_script_for_all_models( "test_load_pretrained_distributed_with_config", @@ -452,7 +482,7 @@ def test_load_pretrained_distributed_with_config(run_test_script_for_all_models, ) -@pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) 
+@pytest.mark.depends_on(on=["test_load_pretrained_distributed_in_dp2[{model_testing_config}]"]) def test_load_pretrained_in_dp2_match_checkpoint(model_testing_config, convert_paths, run_test_script_base_path): test_ckpt_path = run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" pretrained_config_ref = CheckpointLoadConfig( @@ -497,7 +527,7 @@ def test_load_pretrained_in_dp2_match_checkpoint(model_testing_config, convert_p @pytest.mark.slow -@pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]"]) def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, run_test_script_base_path): # This also tests conversion which uses `FastLLMModel.from_checkpoint` pretrained_config_ref = CheckpointLoadConfig( @@ -520,7 +550,12 @@ def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, ru @pytest.mark.slow -@pytest.mark.depends(on=["test_load_converted_fast_llm_checkpoint", "test_load_pretrained_in_dp2_match_checkpoint"]) +@pytest.mark.depends_on( + on=[ + "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", + "test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]", + ] +) def test_load_pretrained_fast_llm_in_dp2(run_test_script, convert_paths, run_test_script_base_path): run_test_script( "test_load_pretrained_fast_llm_in_dp2", @@ -553,7 +588,12 @@ def test_load_pretrained_fast_llm_in_dp2(run_test_script, convert_paths, run_tes @pytest.mark.slow -@pytest.mark.depends(on=["test_load_converted_huggingface_checkpoint", "test_load_pretrained_in_dp2_match_checkpoint"]) +@pytest.mark.depends_on( + on=[ + "test_load_converted_huggingface_checkpoint[{model_testing_config}]", + "test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]", + ] +) def test_load_pretrained_huggingface_in_dp2( run_test_script_for_all_models, model_testing_config, run_test_script_base_path, convert_paths ): diff --git a/tests/test_gpt_generate_and_forward.py b/tests/test_gpt_generate_and_forward.py index ca75cf3e..4c920afd 100644 --- a/tests/test_gpt_generate_and_forward.py +++ b/tests/test_gpt_generate_and_forward.py @@ -251,7 +251,7 @@ def test_export_for_generate(run_test_script_for_all_models, model_testing_confi @pytest.mark.slow @requires_cuda -@pytest.mark.depends(on=["test_export_for_generate"]) +@pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) @pytest.mark.parametrize( "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2", [ @@ -314,7 +314,7 @@ def test_generate_from_model( @requires_cuda @pytest.mark.slow -@pytest.mark.depends(on=["test_export_for_generate"]) +@pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) def test_small_generate_from_model(model_testing_config, run_test_script_base_path): _test_generate_from_model( run_test_script_base_path / f"test_export_for_generate/export/{model_testing_config.checkpoint_format.name}/1", @@ -363,7 +363,7 @@ def test_forward_return_hidden_states(model_path): @pytest.mark.slow @requires_cuda -@pytest.mark.depends(on=["test_export_for_generate"]) +@pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) def test_small_forward_return_hidden_states(model_testing_config, run_test_script_base_path): _test_forward_return_hidden_states( run_test_script_base_path / 
f"test_export_for_generate/export/{model_testing_config.checkpoint_format.name}/1", diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index a77906ae..5c0bbdaa 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -10,7 +10,7 @@ def test_megatron(run_test_script_for_all_models, model_testing_config): @pytest.mark.slow -@pytest.mark.depends(on=["test_megatron"]) +@pytest.mark.depends_on(on=["test_megatron[{model_testing_config}]"]) def test_match_megatron(run_test_script_for_all_models, model_testing_config): run_test_script_for_all_models( [ diff --git a/tests/test_mb.py b/tests/test_mb.py index 80350df9..e1f79fc1 100644 --- a/tests/test_mb.py +++ b/tests/test_mb.py @@ -10,7 +10,7 @@ def test_model_df4(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_df4"]) +@pytest.mark.depends_on(on=["test_model_df4[{model_testing_config}]"]) def test_model_df4_z3(run_test_script_for_all_models): # Gradient accumulation with ZeRO-3. run_test_script_for_all_models( @@ -22,13 +22,13 @@ def test_model_df4_z3(run_test_script_for_all_models): ) -@pytest.mark.depends(on=["test_model_df4"], scope="session") +@pytest.mark.depends_on(on=["test_model_df4[{model_testing_config}]"], scope="session") def test_model_bf4(run_test_script_for_all_models): # Breadth-first gradient accumulation baseline. run_test_script_for_all_models(["batch.breadth_first_micro_batches=4"], compare="test_model_df4") -@pytest.mark.depends(on=["test_model_df4", "test_model_bf4"]) +@pytest.mark.depends_on(on=["test_model_df4[{model_testing_config}]", "test_model_bf4[{model_testing_config}]"]) def test_model_bf2_df2(run_test_script_for_all_models): # Mixed gradient accumulation baseline. run_test_script_for_all_models( @@ -37,7 +37,7 @@ def test_model_bf2_df2(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_bf4"]) +@pytest.mark.depends_on(on=["test_model_bf4[{model_testing_config}]"]) def test_model_pp2s2_bf4(run_test_script_for_all_models): # Pipeline-parallel without tied weights. run_test_script_for_all_models( @@ -52,7 +52,7 @@ def test_model_pp2s2_bf4(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_bf4"]) +@pytest.mark.depends_on(on=["test_model_bf4[{model_testing_config}]"]) def test_model_pp2s1_bf4(run_test_script_for_all_models): # Pipeline-parallel with tied weights. run_test_script_for_all_models( @@ -68,7 +68,7 @@ def test_model_pp2s1_bf4(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_bf4"]) +@pytest.mark.depends_on(on=["test_model_bf4[{model_testing_config}]"]) def test_model_dp2_tp2_pp2s2_bf4(run_test_script_for_all_models): # Simple 3d parallelism # TODO: Test fails diff --git a/tests/test_mb_seq_first.py b/tests/test_mb_seq_first.py index 5146dc9a..7d3cf5ad 100644 --- a/tests/test_mb_seq_first.py +++ b/tests/test_mb_seq_first.py @@ -10,7 +10,7 @@ def test_model_df4_sf(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_df4_sf"]) +@pytest.mark.depends_on(on=["test_model_df4_sf[{model_testing_config}]"]) def test_model_dp2_sp2_df4(run_test_script_for_all_models): # Sequence-tensor-parallel with gradient accumulation. 
# TODO: Compiled cross-entropy broken for this config @@ -29,7 +29,7 @@ def test_model_dp2_sp2_df4(run_test_script_for_all_models): @pytest.mark.slow @pytest.mark.skip(reason="Test is broken.") -@pytest.mark.depends(on=["test_model_df4_sf"]) +@pytest.mark.depends_on(on=["test_model_df4_sf[{model_testing_config}]"]) def test_model_dp2_sp2_pp2s1(run_test_script_for_all_models): # 3d-parallel with sequence-tensor-parallel. # TODO: Compiled cross-entropy broken for this config diff --git a/tests/test_ms.py b/tests/test_ms.py index 256eafe3..23ef60e6 100644 --- a/tests/test_ms.py +++ b/tests/test_ms.py @@ -8,7 +8,7 @@ def test_model_ms256(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_ms256"]) +@pytest.mark.depends_on(on=["test_model_ms256[{model_testing_config}]"]) def test_model_pp2s2_ms256(run_test_script_for_all_models): # Sequence-pipeline-parallel run_test_script_for_all_models( @@ -24,7 +24,7 @@ def test_model_pp2s2_ms256(run_test_script_for_all_models): @pytest.mark.slow @pytest.mark.skip -@pytest.mark.depends(on=["test_model_ms256"]) +@pytest.mark.depends_on(on=["test_model_ms256[{model_testing_config}]"]) def test_model_dp2s2_stp2_pp2s2_ms256(run_test_script_for_all_models): # TODO: Handle this case. # Sequence-3d-parallel diff --git a/tests/test_seq_first.py b/tests/test_seq_first.py index 3e8b7ea1..3df31bb9 100644 --- a/tests/test_seq_first.py +++ b/tests/test_seq_first.py @@ -4,11 +4,11 @@ # TODO: Compare grads with simple def test_model_sf(run_test_script_for_all_models): # Sequence-first baseline. - run_test_script_for_all_models("test_model_sf", ["model.base_model.sequence_first=True"]) + run_test_script_for_all_models("test_model_sf[{model_testing_config}]", ["model.base_model.sequence_first=True"]) @pytest.mark.slow -@pytest.mark.depends(on=["test_model_sf"]) +@pytest.mark.depends_on(on=["test_model_sf[{model_testing_config}]"]) def test_model_sp2(run_test_script_for_all_models): # Sequence-tensor-parallel. run_test_script_for_all_models( @@ -20,7 +20,7 @@ def test_model_sp2(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_sf"]) +@pytest.mark.depends_on(on=["test_model_sf[{model_testing_config}]"]) def test_model_sdp2(run_test_script_for_all_models): # Sequence-data-parallel run_test_script_for_all_models( @@ -32,7 +32,7 @@ def test_model_sdp2(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_sf"]) +@pytest.mark.depends_on(on=["test_model_sf[{model_testing_config}]"]) def test_model_sp2_ce4(run_test_script_for_all_models): # Sequence-tensor-parallel with cross-entropy splits. run_test_script_for_all_models( diff --git a/tests/test_simple.py b/tests/test_simple.py index bc48e26b..8026f012 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -13,7 +13,7 @@ def test_model_safe(run_test_script_for_all_models): ) -@pytest.mark.depends(on=["test_model_safe"]) +@pytest.mark.depends_on(on=["test_model_safe[{model_testing_config}]"]) def test_model(run_test_script_for_all_models): # A baseline config (single-gpu, bf16, flash-attn). # Also tests for multiple data loaders. @@ -21,7 +21,7 @@ def test_model(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) def test_model_dp2(run_test_script_for_all_models): # Simple data-parallel. 
run_test_script_for_all_models([], num_gpus=2, compare="test_model") @@ -52,7 +52,7 @@ def test_model_dp2_timeout(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) def test_model_tp2(run_test_script_for_all_models): # Simple tensor-parallel. run_test_script_for_all_models( @@ -62,7 +62,7 @@ def test_model_tp2(run_test_script_for_all_models): ) -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) def test_model_ce4(run_test_script_for_all_models): # Cross-entropy splits. run_test_script_for_all_models( @@ -72,7 +72,7 @@ def test_model_ce4(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) def test_model_dp2_z2(run_test_script_for_all_models): # Data-parallel with zero stage 2. run_test_script_for_all_models( @@ -83,7 +83,7 @@ def test_model_dp2_z2(run_test_script_for_all_models): @pytest.mark.slow -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) def test_model_dp2_z3(run_test_script_for_all_models): # Data-parallel with zero stage 3. run_test_script_for_all_models( diff --git a/tests/utils/depends.py b/tests/utils/depends.py new file mode 100644 index 00000000..c1e2e250 --- /dev/null +++ b/tests/utils/depends.py @@ -0,0 +1,200 @@ +import re + +import colorama +import networkx +import pytest + +MARKER_NAME = "depends_on" +MARKER_KWARG_ID = "name" +MARKER_KWARG_DEPENDENCIES = "on" + +REGEX_PARAMETERS = re.compile(r"\[.+\]$") + + +def clean_nodeid(nodeid): + return nodeid.replace("::()::", "::").split("@dependency_group_")[0] + + +def get_names(item): + names = set() + + # Node id + nodeid = clean_nodeid(item.nodeid) + names.add(nodeid) + + # Node id without parameter + nodeid = REGEX_PARAMETERS.sub("", nodeid) + names.add(nodeid) + + # Node id scopes + while "::" in nodeid: + nodeid = nodeid.rsplit("::", 1)[0] + names.add(nodeid) + + # Custom name + for marker in item.iter_markers(): + if marker.name == MARKER_NAME and MARKER_KWARG_ID in marker.kwargs: + for name in as_list(marker.kwargs[MARKER_KWARG_ID]): + names.add(name) + + return names + + +def as_list(lst): + return [lst] if isinstance(lst, str) else lst + + +STEPS = ["setup", "call", "teardown"] +GOOD_OUTCOME = "passed" + + +class DependencyManager: + """Keep track of tests, their names and their dependencies.""" + + def __init__(self, items: list[pytest.Function]): + self._items = items + self._name_to_nodeids: dict[str, list[str]] = {} + self._nodeid_to_item: dict[str, pytest.Function] = {} + self._results: dict[str, dict[str, str]] = {} + self._dependencies: dict[str, set[str]] = {} + self._unresolved: dict[str, set[str]] = {} + + for item in self._items: + nodeid = clean_nodeid(item.nodeid) + # Add the mapping from nodeid to the test item + self._nodeid_to_item[nodeid] = item + # Add the mappings from all names to the node id + for name in get_names(item): + if name not in self._name_to_nodeids: + self._name_to_nodeids[name] = [] + self._name_to_nodeids[name].append(nodeid) + # Create the object that will contain the results of this test + self._results[nodeid] = {} + + for item in self._items: + # Process the dependencies of this test + # This uses the mappings created in the previous loop, and can thus not be merged into that loop + nodeid = clean_nodeid(item.nodeid) + self._dependencies[nodeid], 
self._unresolved[nodeid] = self._resolve_dependencies(item) + + self._items = self._sort_dependencies() + + @property + def items(self) -> list[pytest.Function]: + return self._items + + def register_result(self, item: pytest.Function, result: pytest.TestReport): + self._results[clean_nodeid(item.nodeid)][result.when] = result.outcome + + def handle_missing(self, item: pytest.Function): + nodeid = clean_nodeid(item.nodeid) + if missing := self._unresolved[nodeid]: + pytest.fail(f'{item.nodeid} depends on {", ".join(missing)}, which was not found', False) + + if failed := [ + dependency + for dependency in self._dependencies[nodeid] + if not all(self._results[dependency].get(step, None) == "passed" for step in ("setup", "call", "teardown")) + ]: + pytest.skip( + f'{item.nodeid} depends on {", ".join(failed)} ({self._dependencies[nodeid]} ;;;; { + [self._results[dependency] for dependency in self._dependencies[nodeid]]})' + ) + + def _resolve_dependencies(self, item: pytest.Function): + dependencies = set() + unresolved = set() + nodeid = clean_nodeid(item.nodeid) + + for marker in item.iter_markers(): + if marker.name == MARKER_NAME: + for dependency in as_list(marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])): + dependency = dependency.format(**item.callspec.params) + + # If the name is not known, try to make it absolute (ie file::[class::]method) + if dependency not in self._name_to_nodeids: + absolute_dependency = self._get_absolute_nodeid(dependency, nodeid) + if absolute_dependency in self._name_to_nodeids: + dependency = absolute_dependency + + # Add all items matching the name + if dependency in self._name_to_nodeids: + for nodeid in self._name_to_nodeids[dependency]: + dependencies.add(nodeid) + else: + unresolved.add(dependency) + + return dependencies, unresolved + + def _sort_dependencies(self): + # Build a directed graph for sorting + dag = networkx.DiGraph() + + for item in self.items: + nodeid = clean_nodeid(item.nodeid) + dag.add_node(nodeid) + for dependency in self._dependencies[nodeid]: + dag.add_edge(dependency, nodeid) + + for i, nodeids in enumerate(sorted(networkx.weakly_connected_components(dag), key=len, reverse=True)): + if len(nodeids) > 1: + for nodeid in nodeids: + self._nodeid_to_item[nodeid]._nodeid = ( + f"{self._nodeid_to_item[nodeid]._nodeid}@dependency_group_{i}" + ) + + return [self._nodeid_to_item[nodeid] for nodeid in networkx.topological_sort(dag)] + + @staticmethod + def _get_absolute_nodeid(nodeid: str, scope: str): + parts = nodeid.split("::") + # Completely relative (test_name), so add the full current scope (either file::class or file) + if len(parts) == 1: + base_nodeid = scope.rsplit("::", 1)[0] + nodeid = f"{base_nodeid}::{nodeid}" + # Contains some scope already (Class::test_name), so only add the current file scope + elif "." 
not in parts[0]: + base_nodeid = scope.split("::", 1)[0] + nodeid = f"{base_nodeid}::{nodeid}" + return clean_nodeid(nodeid) + + def print_name_map(self, verbose: bool = False): + """Print a human-readable version of the name -> test mapping.""" + print("Available dependency names:") + for name, nodeids in sorted(self._name_to_nodeids.items(), key=lambda x: x[0]): + if len(nodeids) == 1: + if name == nodeids[0]: + # This is just the base name, only print this when verbose + if verbose: + print(f" {name}") + else: + # Name refers to a single node id, so use the short format + print(f" {name} -> {nodeids[0]}") + else: + # Name refers to multiple node ids, so use the long format + print(f" {name} ->") + for nodeid in sorted(nodeids): + print(f" {nodeid}") + + def print_processed_dependencies(self, colors: bool = False): + """Print a human-readable list of the processed dependencies.""" + missing = "MISSING" + if colors: + missing = f"{colorama.Fore.RED}{missing}{colorama.Fore.RESET}" + colorama.init() + try: + print("Dependencies:") + + for nodeid in sorted(self._dependencies): + descriptions = [] + for dependency in self._dependencies[nodeid]: + descriptions.append(dependency) + for dependency in self._unresolved[nodeid]: + descriptions.append(f"{dependency} ({missing})") + if descriptions: + print(f" {nodeid} depends on") + for description in sorted(descriptions): + print(f" {description}") + finally: + if colors: + colorama.deinit() From 478ac05220d37363e8128ffec40fd17c7a3078fe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Jun 2025 15:11:35 -0400 Subject: [PATCH 03/69] fixes --- tests/test_checkpoint.py | 5 ----- tests/test_mb.py | 3 +-- tests/test_seq_first.py | 4 +--- tests/utils/depends.py | 5 +---- tests/utils/model_configs.py | 2 +- 5 files changed, 4 insertions(+), 15 deletions(-) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 6e6d5806..eea3ab0e 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -74,7 +74,6 @@ def test_resume(run_test_script_for_all_models): def test_resume_frozen(run_test_script_for_all_models): # Resume with frozen mlp. No comparison. 
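# For reference, `get_names` above lets a `depends_on` entry point at a test with several
# levels of precision; assuming a file `tests/test_simple.py` with a parametrized
# `test_model` and a "llama" parametrization (names here are illustrative), all of the
# following strings resolve to the same target:
#
#   tests/test_simple.py::test_model[llama]   - full node id
#   tests/test_simple.py::test_model          - node id without the parameter
#   test_model[llama], test_model             - relative names, made absolute by
#                                               `_get_absolute_nodeid` when referenced
#                                               from another test in the same file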
run_test_script_for_all_models( - "test_resume_frozen", [ "training.checkpoint.interval=1", "training.evaluations.validation.interval=2", @@ -456,7 +455,6 @@ def test_run_converted_model(model_testing_config, convert_paths): @pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, convert_paths): run_test_script_for_all_models( - "test_load_pretrained_distributed_in_dp2", [ "training.checkpoint.interval=1", "training.train_iters=1", @@ -471,7 +469,6 @@ def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, conv @pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) def test_load_pretrained_distributed_with_config(run_test_script_for_all_models, convert_paths): run_test_script_for_all_models( - "test_load_pretrained_distributed_with_config", [ "training.checkpoint.interval=1", "training.train_iters=1", @@ -558,7 +555,6 @@ def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, ru ) def test_load_pretrained_fast_llm_in_dp2(run_test_script, convert_paths, run_test_script_base_path): run_test_script( - "test_load_pretrained_fast_llm_in_dp2", [ "training.checkpoint.interval=1", "training.train_iters=1", @@ -598,7 +594,6 @@ def test_load_pretrained_huggingface_in_dp2( run_test_script_for_all_models, model_testing_config, run_test_script_base_path, convert_paths ): run_test_script_for_all_models( - "test_load_pretrained_huggingface_in_dp2", [ "training.checkpoint.interval=1", "training.train_iters=1", diff --git a/tests/test_mb.py b/tests/test_mb.py index e1f79fc1..fb09dcec 100644 --- a/tests/test_mb.py +++ b/tests/test_mb.py @@ -6,7 +6,7 @@ # TODO: Compare grads with simple def test_model_df4(run_test_script_for_all_models): # Depth-first gradient accumulation baseline. - run_test_script_for_all_models("test_model_df4", ["batch.depth_first_micro_batches=4"]) + run_test_script_for_all_models(["batch.depth_first_micro_batches=4"]) @pytest.mark.slow @@ -14,7 +14,6 @@ def test_model_df4(run_test_script_for_all_models): def test_model_df4_z3(run_test_script_for_all_models): # Gradient accumulation with ZeRO-3. run_test_script_for_all_models( - "test_model_df4_z3", ["model.multi_stage.zero_stage=3", "batch.depth_first_micro_batches=4"], num_gpus=2, compare="test_model_df4", diff --git a/tests/test_seq_first.py b/tests/test_seq_first.py index 3df31bb9..6e1eb07a 100644 --- a/tests/test_seq_first.py +++ b/tests/test_seq_first.py @@ -4,7 +4,7 @@ # TODO: Compare grads with simple def test_model_sf(run_test_script_for_all_models): # Sequence-first baseline. - run_test_script_for_all_models("test_model_sf[{model_testing_config}]", ["model.base_model.sequence_first=True"]) + run_test_script_for_all_models(["model.base_model.sequence_first=True"]) @pytest.mark.slow @@ -12,7 +12,6 @@ def test_model_sf(run_test_script_for_all_models): def test_model_sp2(run_test_script_for_all_models): # Sequence-tensor-parallel. 
run_test_script_for_all_models( - "test_model_sp2", ["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True"], num_gpus=2, compare="test_model_sf", @@ -24,7 +23,6 @@ def test_model_sp2(run_test_script_for_all_models): def test_model_sdp2(run_test_script_for_all_models): # Sequence-data-parallel run_test_script_for_all_models( - "test_model_sdp2", ["model.distributed.sequence_data_parallel=2"], num_gpus=2, compare="test_model_sf", diff --git a/tests/utils/depends.py b/tests/utils/depends.py index c1e2e250..8ddb5041 100644 --- a/tests/utils/depends.py +++ b/tests/utils/depends.py @@ -96,10 +96,7 @@ def handle_missing(self, item: pytest.Function): for dependency in self._dependencies[nodeid] if not all(self._results[dependency].get(step, None) == "passed" for step in ("setup", "call", "teardown")) ]: - pytest.skip( - f'{item.nodeid} depends on {", ".join(failed)} ({self._dependencies[nodeid]} ;;;; { - [self._results[dependency] for dependency in self._dependencies[nodeid]]})' - ) + pytest.skip(f'{item.nodeid} depends on failed {", ".join(failed)}') def _resolve_dependencies(self, item: pytest.Function): dependencies = set() diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 963f6ae9..d0c0d070 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -84,7 +84,7 @@ def _update_and_add_testing_config( "training.logs.interval=1", "run.tensor_logs.save=True", "run.tensor_logs.show=False", - "model.base_model.max_position_embeddings=512", + # "model.base_model.max_position_embeddings=512", "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", From d3b18a13ccd6be6c9f2d2a1b36d4deeaeebd3fc2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 9 Jun 2025 13:30:19 -0400 Subject: [PATCH 04/69] stuff --- fast_llm/layers/transformer/config.py | 5 +- tests/conftest.py | 81 +++++++++++++- tests/test_checkpoint.py | 34 ++++-- tests/test_config.py | 29 ----- tests/test_functional.py | 12 +- tests/test_gpt_generate_and_forward.py | 10 +- tests/test_match_megatron.py | 10 +- tests/test_mb.py | 12 +- tests/test_mb_seq_first.py | 6 +- tests/test_ms.py | 7 +- tests/test_multi_stage.py | 4 + tests/test_seq_first.py | 10 +- tests/test_simple.py | 16 ++- tests/utils/depends.py | 4 + tests/utils/model_configs.py | 149 ++++++++++++++++++++++++- tests/utils/run_test_script.py | 2 +- 16 files changed, 314 insertions(+), 77 deletions(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 235aa366..c0ed1472 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -711,7 +711,4 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: ) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in ( - DataType.float16, - DataType.bfloat16, - ) + return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) diff --git a/tests/conftest.py b/tests/conftest.py index 4cf6158d..829e1696 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ import dataclasses +import datetime import math import os import pytest import torch -from xdist.scheduler import LoadGroupScheduling +import xdist.scheduler from tests.utils.depends import DependencyManager @@ -15,7 +16,7 @@ run_test_script_for_all_models, ) 
-from tests.utils.model_configs import model_testing_config # isort: skip +from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip from tests.utils.utils import result_path # isort: skip @@ -25,6 +26,8 @@ def pytest_addoption(parser): group = parser.getgroup("fast_llm") group.addoption("--skip-slow", action="store_true") + group.addoption("--show-skipped", action="store_true") + group.addoption("--models", nargs="*") group.addoption( "--run-extra-slow", action="store_true", @@ -59,6 +62,7 @@ def pytest_configure(config): "markers", "extra_slow: Mark test as extra slow and skip unless --run-extra-slow is given." ) config.addinivalue_line("markers", "depends_on(name='name', on=['other_name']): marks dependencies between tests.") + config.addinivalue_line("markers", "model_testing_group(group='group'): marks model testing group.") # TODO: Spawned processes (multi-gpu, Megatron) ignore resource allocation. is_parallel = hasattr(config, "workerinput") if is_parallel: @@ -107,8 +111,11 @@ def pytest_configure(config): @pytest.hookimpl(trylast=True) -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(config, items: list[pytest.Function]): global manager + skip_slow = config.getoption("--skip-slow") + skip_extra_slow = not config.getoption("--run-extra-slow") + show_skipped = config.getoption("--show-skipped") if config.getoption("--skip-slow"): skip_slow = pytest.mark.skip(reason="Skipping slow tests") @@ -121,7 +128,23 @@ def pytest_collection_modifyitems(config, items): if "extra_slow" in item.keywords: item.add_marker(skip_extra_slow) - manager = DependencyManager(items) + new_items = [] + for item in items: + if skip_slow and "slow" in item.keywords: + if show_skipped: + item.add_marker(pytest.mark.skip(reason="Skipping slow tests")) + else: + continue + elif skip_extra_slow and "extra_slow" in item.keywords: + if show_skipped: + item.add_marker(pytest.mark.skip(reason="Skipping extra-slow tests")) + else: + continue + elif not testing_group_enabled(item, skip_slow, skip_extra_slow, show_skipped): + continue + new_items.append(item) + + manager = DependencyManager(new_items) # Show the extra information if requested if config.getoption("show_dependencies"): @@ -166,4 +189,52 @@ def worker_resources(request) -> WorkerResources: def pytest_xdist_make_scheduler(config, log): # Always use grouped load balancing to handle dependencies, and make it work with `-n`. 
assert config.getvalue("dist") == "load" - return LoadGroupScheduling(config, log) + return xdist.scheduler.LoadGroupScheduling(config, log) + + +def get_all_reports(terminalreporter): + """Reports for all stages and all outcomes""" + for reports in terminalreporter.stats.values(): + for report in reports: + if isinstance(report, pytest.TestReport): + yield report + + +def resource_usage_message(report): + """The resource usage message for a report""" + return ", ".join(content for (prefix, content) in report.get_sections(f"Captured resource {report.when}")) + + +def format_duration(seconds): + """Human-readable running time message""" + if seconds < 60: + duration_string = f"{seconds:.3f} seconds" + else: + duration_string = str(datetime.timedelta(seconds=round(seconds))) + return f"running time: {duration_string}" + + +# @pytest.hookimpl(tryfirst=True) +# def pytest_runtest_makereport(item, call): +# """Report running time of a test call""" +# if call.when == "call": +# item.add_report_section( +# call.when, "resource", format_duration(call.duration) +# ) +# +# +# @pytest.hookimpl +# def pytest_terminal_summary(terminalreporter): +# """Produce a resource usage report if any test asked for it""" +# resource_reports = [ +# (report, message) +# for report in get_all_reports(terminalreporter) +# if (message := resource_usage_message(report)) +# ] +# if not resource_reports: +# return +# terminalreporter.write_sep("=", "resource usage", bold=True) +# for report, message in resource_reports: +# terminalreporter.write_line( +# f"{report.nodeid} ({report.when}) {message}" +# ) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index eea3ab0e..06f69a96 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -17,12 +17,12 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.tools.convert import ConvertConfig from tests.utils.compare_tensor_logs import CompareConfig, compare_logged_tensor -from tests.utils.utils import requires_cuda +from tests.utils.model_configs import ModelTestingGroup _WEIGHT_SHARD_SAVE_NAME = f"{ShardName.weights}_shard" -@requires_cuda +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_config): # A baseline config (single-gpu, bf16, flash-attn). run_test_script_for_all_models( @@ -56,6 +56,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_resume(run_test_script_for_all_models): # Resume from iteration=1 and compare outputs with the baseline run. run_test_script_for_all_models( @@ -71,6 +72,7 @@ def test_resume(run_test_script_for_all_models): @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_resume_frozen(run_test_script_for_all_models): # Resume with frozen mlp. No comparison. 
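# With the collection hooks above, test selection is driven by the options registered in
# `pytest_addoption`; for instance (worker count is illustrative):
#
#   pytest tests -n 8                   - run the default testing groups for every model
#   pytest tests -n 8 --skip-slow       - drop slow tests and the slow testing groups
#   pytest tests -n 8 --show-skipped    - keep filtered-out tests visible as explicit skips
#   pytest tests -n 8 --run-extra-slow  - also run the models' redundant "other" groups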
run_test_script_for_all_models( @@ -107,6 +109,7 @@ def convert_paths(run_test_script_base_path): @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_convert_distributed_to_fast_llm(model_testing_config, convert_paths): _run_conversion( ConvertConfig( @@ -124,6 +127,7 @@ def test_convert_distributed_to_fast_llm(model_testing_config, convert_paths): @pytest.mark.depends_on(on=["test_convert_distributed_to_fast_llm[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_convert_fast_llm_to_huggingface(model_testing_config, convert_paths): if model_testing_config.checkpoint_format is None: pytest.skip(f"Conversion not supported for {model_testing_config.name}") @@ -143,6 +147,7 @@ def test_convert_fast_llm_to_huggingface(model_testing_config, convert_paths): @pytest.mark.depends_on(on=["test_convert_fast_llm_to_huggingface[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_convert_huggingface_to_distributed(model_testing_config, convert_paths): _run_conversion( ConvertConfig( @@ -160,6 +165,7 @@ def test_convert_huggingface_to_distributed(model_testing_config, convert_paths) @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_convert_distributed_to_huggingface(model_testing_config, convert_paths): if model_testing_config.checkpoint_format is None: pytest.skip(f"Conversion not supported for {model_testing_config.name}") @@ -179,6 +185,7 @@ def test_convert_distributed_to_huggingface(model_testing_config, convert_paths) @pytest.mark.depends_on(on=["test_convert_distributed_to_huggingface[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_convert_huggingface_to_fast_llm(model_testing_config, convert_paths): _run_conversion( ConvertConfig( @@ -196,6 +203,7 @@ def test_convert_huggingface_to_fast_llm(model_testing_config, convert_paths): @pytest.mark.depends_on(on=["test_convert_huggingface_to_fast_llm[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_convert_fast_llm_to_distributed(model_testing_config, convert_paths): _run_conversion( ConvertConfig( @@ -218,6 +226,7 @@ def test_convert_fast_llm_to_distributed(model_testing_config, convert_paths): "test_convert_fast_llm_to_distributed[{model_testing_config}]", ] ) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_converted_distributed(convert_paths): # Compare the fast llm weights # TODO: Compare configs @@ -239,6 +248,7 @@ def test_converted_distributed(convert_paths): "test_convert_huggingface_to_fast_llm[{model_testing_config}]", ] ) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_converted_fast_llm(convert_paths): s0 = safetensors.torch.load_file(convert_paths["fast_llm_0"] / "model_0.safetensors") s1 = safetensors.torch.load_file(convert_paths["fast_llm_1"] / "model_0.safetensors") @@ -254,6 +264,7 @@ def test_converted_fast_llm(convert_paths): "test_convert_distributed_to_huggingface[{model_testing_config}]", ] ) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_converted_huggingface(convert_paths): h0 = safetensors.torch.load_file(convert_paths["huggingface_0"] / "model_0.safetensors") h1 = safetensors.torch.load_file(convert_paths["huggingface_1"] / "model_0.safetensors") @@ -272,6 +283,7 @@ def 
_compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMM @pytest.mark.depends_on(on=["test_converted_distributed[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_pretrained_distributed_checkpoint(model_testing_config, convert_paths): config = model_testing_config.model_config_class.from_dict( yaml.safe_load((convert_paths["checkpoint"] / ".." / ".." / "config.yaml").open("r"))["model"], strict=False @@ -292,6 +304,7 @@ def test_load_pretrained_distributed_checkpoint(model_testing_config, convert_pa @pytest.mark.depends_on(on=["test_load_pretrained_distributed_checkpoint[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_converted_distributed_checkpoint(model_testing_config, convert_paths): config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( @@ -329,6 +342,7 @@ def test_load_converted_distributed_checkpoint(model_testing_config, convert_pat "test_load_pretrained_distributed_checkpoint[{model_testing_config}]", ] ) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_converted_fast_llm_checkpoint(model_testing_config, convert_paths): config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( @@ -365,6 +379,7 @@ def test_load_converted_fast_llm_checkpoint(model_testing_config, convert_paths) "test_load_pretrained_distributed_checkpoint[{model_testing_config}]", ] ) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_converted_huggingface_checkpoint(model_testing_config, convert_paths): config_ref = model_testing_config.model_config_class.from_pretrained( CheckpointLoadConfig( @@ -402,6 +417,7 @@ def test_load_converted_huggingface_checkpoint(model_testing_config, convert_pat "test_load_converted_huggingface_checkpoint[{model_testing_config}]", ] ) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_run_converted_model(model_testing_config, convert_paths): model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( CheckpointLoadConfig( @@ -451,8 +467,8 @@ def test_run_converted_model(model_testing_config, convert_paths): raise ValueError(f"Comparison failed ({len(errors)} errors)") -@pytest.mark.slow @pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, convert_paths): run_test_script_for_all_models( [ @@ -467,6 +483,7 @@ def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, conv @pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_pretrained_distributed_with_config(run_test_script_for_all_models, convert_paths): run_test_script_for_all_models( [ @@ -480,6 +497,7 @@ def test_load_pretrained_distributed_with_config(run_test_script_for_all_models, @pytest.mark.depends_on(on=["test_load_pretrained_distributed_in_dp2[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_pretrained_in_dp2_match_checkpoint(model_testing_config, convert_paths, run_test_script_base_path): test_ckpt_path = run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" 
pretrained_config_ref = CheckpointLoadConfig( @@ -523,8 +541,8 @@ def test_load_pretrained_in_dp2_match_checkpoint(model_testing_config, convert_p assert (stage_shard_test[stage_shard_ref.numel() :] == 0).all() # noqa -@pytest.mark.slow @pytest.mark.depends_on(on=["test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, run_test_script_base_path): # This also tests conversion which uses `FastLLMModel.from_checkpoint` pretrained_config_ref = CheckpointLoadConfig( @@ -546,15 +564,15 @@ def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, ru assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.slow @pytest.mark.depends_on( on=[ "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", "test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]", ] ) -def test_load_pretrained_fast_llm_in_dp2(run_test_script, convert_paths, run_test_script_base_path): - run_test_script( +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) +def test_load_pretrained_fast_llm_in_dp2(run_test_script_for_all_models, convert_paths, run_test_script_base_path): + run_test_script_for_all_models( [ "training.checkpoint.interval=1", "training.train_iters=1", @@ -583,13 +601,13 @@ def test_load_pretrained_fast_llm_in_dp2(run_test_script, convert_paths, run_tes assert (ref_shard[name] == test_shard[name]).all() -@pytest.mark.slow @pytest.mark.depends_on( on=[ "test_load_converted_huggingface_checkpoint[{model_testing_config}]", "test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]", ] ) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_pretrained_huggingface_in_dp2( run_test_script_for_all_models, model_testing_config, run_test_script_base_path, convert_paths ): diff --git a/tests/test_config.py b/tests/test_config.py index 98a4c07c..ed5d9b8a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,5 @@ import pathlib import subprocess -import unittest.mock import pytest import yaml @@ -8,9 +7,7 @@ from fast_llm.config import NoAutoValidate from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.auto import trainer_registry from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert, check_equal_nested @@ -63,32 +60,6 @@ def test_validate_example_config(): trainer_registry["gpt"].from_dict(fast_llm_config_dict) -def test_do_use_flash_attention(): - # Create a mock DistributedConfig - mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig) - - # Test case 1: use_flash_attention is True and training_dtype is float16 - config = TransformerConfig(use_flash_attention=True, window_size=None) - mock_distributed_config.training_dtype = DataType.float16 - assert config.do_use_flash_attention(mock_distributed_config) is True - - # Test case 2: use_flash_attention is False - config = TransformerConfig(use_flash_attention=False, window_size=None) - mock_distributed_config.training_dtype = 
DataType.float16 - assert config.do_use_flash_attention(mock_distributed_config) is False - - # Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16 - config = TransformerConfig(use_flash_attention=True, window_size=None) - mock_distributed_config.training_dtype = DataType.float32 - assert config.do_use_flash_attention(mock_distributed_config) is False - - # Test case 4: use_flash_attention is False and window_size is not None - config = TransformerConfig(use_flash_attention=False, window_size=512) - mock_distributed_config.training_dtype = DataType.float32 - with pytest.raises(AssertionError): - config.do_use_flash_attention(mock_distributed_config) - - @pytest.mark.parametrize( ("cls", "default"), ((GPTSamplingConfig, {}), (GPTModelConfig, {"distributed": {"world_size": 1, "rank": 0, "local_world_size": 1}})), diff --git a/tests/test_functional.py b/tests/test_functional.py index 03a0ae8a..9c01f084 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -57,9 +57,15 @@ def ref_packed_get_batch_logps( @pytest.mark.slow -@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) -@pytest.mark.parametrize("seq_length", [1024, 4096, 8192]) -@pytest.mark.parametrize("vocab_size", [1000, 2000, 8000]) +@pytest.mark.parametrize( + ("batch_size", "seq_length", "vocab_size"), + ( + (2, 32, 50), + (1, 32, 50), + (2, 100, 50), + (2, 32, 200), + ), +) def test_preference_logps(batch_size, seq_length, vocab_size): random.seed(0) torch.manual_seed(0) diff --git a/tests/test_gpt_generate_and_forward.py b/tests/test_gpt_generate_and_forward.py index 4c920afd..7f0b902f 100644 --- a/tests/test_gpt_generate_and_forward.py +++ b/tests/test_gpt_generate_and_forward.py @@ -9,6 +9,7 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -44,7 +45,7 @@ def _prepare_rand_data(vocab_size, use_batch_size2: bool): def _get_hf_model(model_path: str, use_flash_attention: bool, use_bf16: bool): - hf_kwargs = {} + hf_kwargs = {"trust_remote_code": True} if use_flash_attention: hf_kwargs["attn_implementation"] = "flash_attention_2" hf_kwargs["torch_dtype"] = torch.bfloat16 @@ -237,9 +238,11 @@ def test_generate( @pytest.mark.slow -@requires_cuda +@pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_export_for_generate(run_test_script_for_all_models, model_testing_config): # Not really testing, anything, but handles dependencies more easily than a fixture. 
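# The kwargs assembled in `_get_hf_model` above match the standard Hugging Face loading
# arguments; assuming they are ultimately forwarded to a `from_pretrained`-style call,
# the equivalent direct invocation would be roughly (checkpoint path is illustrative):
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/exported/checkpoint",
    trust_remote_code=True,  # presumably needed for exports that ship custom modeling code
    attn_implementation="flash_attention_2",  # flash attention requires fp16/bf16 weights
    torch_dtype=torch.bfloat16,
)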
+ if model_testing_config.checkpoint_format is None: + pytest.skip(f"Conversion not supported for {model_testing_config.name}") run_test_script_for_all_models( [ "training.train_iters=1", @@ -263,6 +266,7 @@ def test_export_for_generate(run_test_script_for_all_models, model_testing_confi (True, True, 10, 10, 10), ], ) +@pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_small_generate( model_testing_config, run_test_script_base_path, @@ -315,6 +319,7 @@ def test_generate_from_model( @requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_small_generate_from_model(model_testing_config, run_test_script_base_path): _test_generate_from_model( run_test_script_base_path / f"test_export_for_generate/export/{model_testing_config.checkpoint_format.name}/1", @@ -363,6 +368,7 @@ def test_forward_return_hidden_states(model_path): @pytest.mark.slow @requires_cuda +@pytest.mark.model_testing_group(ModelTestingGroup.generate) @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) def test_small_forward_return_hidden_states(model_testing_config, run_test_script_base_path): _test_forward_return_hidden_states( diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index 5c0bbdaa..9b3b591b 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -2,16 +2,19 @@ from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import DATASET_PREFIX +from tests.utils.model_configs import ModelTestingGroup -@pytest.mark.slow +@pytest.mark.model_testing_group(ModelTestingGroup.megatron) def test_megatron(run_test_script_for_all_models, model_testing_config): - run_test_script_for_all_models(is_megatron=True) + run_test_script_for_all_models([], is_megatron=True) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_megatron[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.megatron) def test_match_megatron(run_test_script_for_all_models, model_testing_config): + if model_testing_config.megatron_args is None: + pytest.skip(f"Megatron does not support model {model_testing_config.name}") run_test_script_for_all_models( [ "model.distributed.training_dtype=fp32", @@ -28,5 +31,4 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config): ".mlp.layer_2.weight", ] ), - use_performance_args=False, ) diff --git a/tests/test_mb.py b/tests/test_mb.py index fb09dcec..806ccebc 100644 --- a/tests/test_mb.py +++ b/tests/test_mb.py @@ -1,16 +1,18 @@ import pytest from tests.utils.compare_tensor_logs import CompareConfig +from tests.utils.model_configs import ModelTestingGroup # TODO: Compare grads with simple +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model_df4(run_test_script_for_all_models): # Depth-first gradient accumulation baseline. run_test_script_for_all_models(["batch.depth_first_micro_batches=4"]) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_df4[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_df4_z3(run_test_script_for_all_models): # Gradient accumulation with ZeRO-3. 
run_test_script_for_all_models( @@ -22,12 +24,14 @@ def test_model_df4_z3(run_test_script_for_all_models): @pytest.mark.depends_on(on=["test_model_df4[{model_testing_config}]"], scope="session") +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_bf4(run_test_script_for_all_models): # Breadth-first gradient accumulation baseline. run_test_script_for_all_models(["batch.breadth_first_micro_batches=4"], compare="test_model_df4") @pytest.mark.depends_on(on=["test_model_df4[{model_testing_config}]", "test_model_bf4[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_bf2_df2(run_test_script_for_all_models): # Mixed gradient accumulation baseline. run_test_script_for_all_models( @@ -35,8 +39,8 @@ def test_model_bf2_df2(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_bf4[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_pp2s2_bf4(run_test_script_for_all_models): # Pipeline-parallel without tied weights. run_test_script_for_all_models( @@ -50,8 +54,8 @@ def test_model_pp2s2_bf4(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_bf4[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_pp2s1_bf4(run_test_script_for_all_models): # Pipeline-parallel with tied weights. run_test_script_for_all_models( @@ -66,8 +70,8 @@ def test_model_pp2s1_bf4(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_bf4[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_dp2_tp2_pp2s2_bf4(run_test_script_for_all_models): # Simple 3d parallelism # TODO: Test fails diff --git a/tests/test_mb_seq_first.py b/tests/test_mb_seq_first.py index 7d3cf5ad..5a8db0b9 100644 --- a/tests/test_mb_seq_first.py +++ b/tests/test_mb_seq_first.py @@ -1,16 +1,18 @@ import pytest from tests.utils.compare_tensor_logs import CompareConfig +from tests.utils.model_configs import ModelTestingGroup # TODO: Compare grads with simple +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model_df4_sf(run_test_script_for_all_models): # Sequence-first gradient accumulation baseline. run_test_script_for_all_models(["batch.depth_first_micro_batches=4", "model.base_model.sequence_first=True"]) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_df4_sf[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_dp2_sp2_df4(run_test_script_for_all_models): # Sequence-tensor-parallel with gradient accumulation. # TODO: Compiled cross-entropy broken for this config @@ -27,9 +29,9 @@ def test_model_dp2_sp2_df4(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.skip(reason="Test is broken.") @pytest.mark.depends_on(on=["test_model_df4_sf[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_dp2_sp2_pp2s1(run_test_script_for_all_models): # 3d-parallel with sequence-tensor-parallel. 
# TODO: Compiled cross-entropy broken for this config diff --git a/tests/test_ms.py b/tests/test_ms.py index 23ef60e6..b97f84e5 100644 --- a/tests/test_ms.py +++ b/tests/test_ms.py @@ -1,14 +1,17 @@ import pytest +from tests.utils.model_configs import ModelTestingGroup + # TODO: Compare grads with simple +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model_ms256(run_test_script_for_all_models): # Micro-sequence baseline run_test_script_for_all_models(["batch.micro_sequence_length=256"]) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_ms256[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_pp2s2_ms256(run_test_script_for_all_models): # Sequence-pipeline-parallel run_test_script_for_all_models( @@ -22,9 +25,9 @@ def test_model_pp2s2_ms256(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.skip @pytest.mark.depends_on(on=["test_model_ms256[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_dp2s2_stp2_pp2s2_ms256(run_test_script_for_all_models): # TODO: Handle this case. # Sequence-3d-parallel diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 6d3861eb..06eca685 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -1,9 +1,12 @@ +import pytest + from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.tools.train import CliTrainingConfig from fast_llm.utils import Assert +from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -17,6 +20,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args)._multi_stage diff --git a/tests/test_seq_first.py b/tests/test_seq_first.py index 6e1eb07a..66b044df 100644 --- a/tests/test_seq_first.py +++ b/tests/test_seq_first.py @@ -1,14 +1,17 @@ import pytest +from tests.utils.model_configs import ModelTestingGroup + # TODO: Compare grads with simple +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model_sf(run_test_script_for_all_models): # Sequence-first baseline. run_test_script_for_all_models(["model.base_model.sequence_first=True"]) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_sf[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_sp2(run_test_script_for_all_models): # Sequence-tensor-parallel. 
run_test_script_for_all_models( @@ -18,8 +21,8 @@ def test_model_sp2(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_sf[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_sdp2(run_test_script_for_all_models): # Sequence-data-parallel run_test_script_for_all_models( @@ -29,12 +32,11 @@ def test_model_sdp2(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model_sf[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_sp2_ce4(run_test_script_for_all_models): # Sequence-tensor-parallel with cross-entropy splits. run_test_script_for_all_models( - "test_model_sp2_ce4", [ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", diff --git a/tests/test_simple.py b/tests/test_simple.py index 8026f012..4616942c 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1,6 +1,9 @@ import pytest +from tests.utils.model_configs import ModelTestingGroup + +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model_safe(run_test_script_for_all_models): # The safest possible config, identical to the one in test_match_megatron except for the initialization. run_test_script_for_all_models( @@ -14,20 +17,22 @@ def test_model_safe(run_test_script_for_all_models): @pytest.mark.depends_on(on=["test_model_safe[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model(run_test_script_for_all_models): # A baseline config (single-gpu, bf16, flash-attn). # Also tests for multiple data loaders. run_test_script_for_all_models(["training.num_workers=2"], compare="test_model_safe") -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_dp2(run_test_script_for_all_models): # Simple data-parallel. run_test_script_for_all_models([], num_gpus=2, compare="test_model") -@pytest.mark.slow +@pytest.mark.skip(reason="Flaky") +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_dp2_timeout(run_test_script_for_all_models): # Test sampling timeout # TODO: Find a better way to test this @@ -51,8 +56,8 @@ def test_model_dp2_timeout(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_tp2(run_test_script_for_all_models): # Simple tensor-parallel. run_test_script_for_all_models( @@ -63,6 +68,7 @@ def test_model_tp2(run_test_script_for_all_models): @pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model_ce4(run_test_script_for_all_models): # Cross-entropy splits. run_test_script_for_all_models( @@ -71,8 +77,8 @@ def test_model_ce4(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_dp2_z2(run_test_script_for_all_models): # Data-parallel with zero stage 2. 
run_test_script_for_all_models( @@ -82,8 +88,8 @@ def test_model_dp2_z2(run_test_script_for_all_models): ) -@pytest.mark.slow @pytest.mark.depends_on(on=["test_model[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.distributed) def test_model_dp2_z3(run_test_script_for_all_models): # Data-parallel with zero stage 3. run_test_script_for_all_models( diff --git a/tests/utils/depends.py b/tests/utils/depends.py index 8ddb5041..5e6bcc71 100644 --- a/tests/utils/depends.py +++ b/tests/utils/depends.py @@ -101,6 +101,10 @@ def handle_missing(self, item: pytest.Function): def _resolve_dependencies(self, item: pytest.Function): dependencies = set() unresolved = set() + + if "skip" in item.keywords: + return dependencies, unresolved + nodeid = clean_nodeid(item.nodeid) for marker in item.iter_markers(): diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index d0c0d070..65a063b5 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -1,4 +1,5 @@ import dataclasses +import enum import functools import os import typing @@ -21,6 +22,17 @@ _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) +class ModelTestingGroup(enum.StrEnum): + basic = "basic" + megatron = "megatron" + distributed = "distributed" + convert = "convert" + generate = "generate" + + +SLOW_TESTING_GROUPS = {ModelTestingGroup.megatron, ModelTestingGroup.distributed} + + @dataclasses.dataclass(kw_only=True, frozen=True) class ModelTestingConfig: name: str = None @@ -28,6 +40,11 @@ class ModelTestingConfig: config_args: list[str] megatron_args: list[str] | None checkpoint_format: CheckpointFormat | None + # The important groups we want to test. + testing_groups: list[ModelTestingGroup] + # Other supported groups, excluded by default because they are mostly unimportant and/or redundant. + # They can be run with `--run-extra-slow`. + other_groups: list[ModelTestingGroup] @functools.cached_property def model_config_class(self): @@ -54,9 +71,15 @@ def _update_and_add_testing_config( extra_args: list[str] | None = None, megatron_args: list[str] | None = ..., checkpoint_format: CheckpointFormat | None = ..., + testing_groups: list[ModelTestingGroup], + other_groups: list[ModelTestingGroup], ): config = _MODEL_CONFIGS[old_name] - updates: dict[str, typing.Any] = {"name": new_name} + updates: dict[str, typing.Any] = { + "name": new_name, + "testing_groups": testing_groups, + "other_groups": other_groups, + } if model_type is not None: updates["model_type"] = model_type if extra_args is not None: @@ -78,6 +101,7 @@ def _update_and_add_testing_config( _MODEL_CONFIGS["gpt2"] = ModelTestingConfig( + # Tests gpt2 features (absolute embeddings, layer norm, relu activation, tied embeddings, MHA, linear biases). name="gpt2", model_type="gpt", config_args=[ @@ -97,7 +121,7 @@ def _update_and_add_testing_config( f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", "model.multi_stage.debug_tensor_parallel=True", "model.distributed.reproducible_init=True", - "model.distributed.timeout=10", + "model.distributed.timeout=20", "model.distributed.training_dtype=bf16", "training.train_iters=2", "training.num_workers=0", @@ -153,17 +177,32 @@ def _update_and_add_testing_config( "--transformer-impl=transformer_engine", ], checkpoint_format=None, + testing_groups=[ + ModelTestingGroup.basic, + ModelTestingGroup.megatron, + ModelTestingGroup.distributed, + ], + other_groups=[], ) _update_and_add_testing_config( + # Tests MQA. 
"gpt2", "starcoder", extra_args=["model.base_model.transformer.head_groups=1"], megatron_args=["--group-query-attention"], checkpoint_format=None, + testing_groups=[ + ModelTestingGroup.basic, + ], + other_groups=[ + ModelTestingGroup.megatron, + ModelTestingGroup.distributed, + ], ) _update_and_add_testing_config( + # Tests intermediate between gpt2 and llama, closest converter to gpt2. "gpt2", "starcoder2", extra_args=[ @@ -177,9 +216,19 @@ def _update_and_add_testing_config( "--no-position-embedding", ], checkpoint_format=Starcoder2GPTHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ModelTestingGroup.convert, + ], + other_groups=[ + ModelTestingGroup.megatron, + ModelTestingGroup.distributed, + ModelTestingGroup.generate, + ], ) _update_and_add_testing_config( + # Main tested model. "starcoder2", "llama", extra_args=[ @@ -198,55 +247,108 @@ def _update_and_add_testing_config( "--untie-embeddings-and-output-weights", ], checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ModelTestingGroup.megatron, + ModelTestingGroup.distributed, + ModelTestingGroup.convert, + ModelTestingGroup.generate, + ], + other_groups=[], ) _update_and_add_testing_config( + # Tests llama3-style rotary embeddings. "llama", "llama3", extra_args=["model.base_model.transformer.rotary.type=llama3"], # Megatron doesn't support Llama3-style Rotary Embeddings megatron_args=None, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ], + other_groups=[ + ModelTestingGroup.distributed, + ModelTestingGroup.convert, + ModelTestingGroup.generate, + ], ) _update_and_add_testing_config( + # Tests yarn-style rotary embeddings. "llama", "llama_yarn", extra_args=["model.base_model.transformer.rotary.type=yarn"], # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ], + other_groups=[ + ModelTestingGroup.distributed, + ModelTestingGroup.convert, + ModelTestingGroup.generate, + ], ) _update_and_add_testing_config( + # Tests multi-token prediction, custom HF model and converter. "llama", "llama_mtp", extra_args=["model.base_model.prediction_heads=4"], # Megatron doesn't support multi-token prediction. megatron_args=None, checkpoint_format=MTPLlamaGPTHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ModelTestingGroup.convert, + ModelTestingGroup.generate, + ], + other_groups=[ + ModelTestingGroup.distributed, + ], ) _update_and_add_testing_config( + # Tests partial linear biases, Qwen2 converter. "llama", "qwen2", extra_args=["model.base_model.transformer.add_linear_biases=only_attn_qkv"], # Megatron doesn't support per sub layer biases megatron_args=None, checkpoint_format=Qwen2GPTHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ModelTestingGroup.convert, + ], + other_groups=[ + ModelTestingGroup.distributed, + ModelTestingGroup.generate, + ], ) _update_and_add_testing_config( + # Tests sliding window attention, mistral converter. "llama", "mistral", extra_args=["model.base_model.transformer.window_size=128"], # Megatron doesn't support sliding windows. 
megatron_args=None, checkpoint_format=MistralGPTHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ModelTestingGroup.convert, + ModelTestingGroup.generate, + ], + other_groups=[ + ModelTestingGroup.distributed, + ], ) _update_and_add_testing_config( - # We ignore sliding windows to enable comparison with Megatron. + # Tests mixture of experts, mixtral converter. "llama", "mixtral", extra_args=[ @@ -258,19 +360,58 @@ def _update_and_add_testing_config( "--moe-router-topk=4", ], checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ModelTestingGroup.megatron, + ModelTestingGroup.distributed, + ModelTestingGroup.convert, + ModelTestingGroup.generate, + ], + other_groups=[], ) _update_and_add_testing_config( - # We ignore sliding windows to enable comparison with Megatron. + # Tests hybrid ssm, llamba converter. + # TODO: Conversion fails. "llama", "llamba", model_type="hybrid_ssm", extra_args=["model.base_model.hybrid_block_layout=['t','m']"], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, + testing_groups=[ + ModelTestingGroup.basic, + ModelTestingGroup.distributed, + ModelTestingGroup.convert, + ModelTestingGroup.generate, + ], + other_groups=[], ) @pytest.fixture(scope="session", params=_MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: return _MODEL_CONFIGS[request.param] + + +def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slow: bool, show_skipped: bool) -> bool: + if "model_testing_group" in item.keywords: + assert "model_testing_config" in item.callspec.params, item.nodeid + groups: tuple[ModelTestingGroup] = item.keywords["model_testing_group"].args + model_testing_config = item.callspec.params["model_testing_config"] + model_config = _MODEL_CONFIGS[model_testing_config] + for group in groups: + if group in model_config.testing_groups and not (skip_slow and group in SLOW_TESTING_GROUPS): + pass + elif group in model_config.other_groups and not skip_extra_slow: + pass + elif show_skipped: + item.add_marker( + pytest.mark.skip(reason=f"Skipping testing group {group} for model {model_testing_config}.") + ) + else: + return False + elif hasattr(item, "callspec"): + assert "model_testing_config" not in item.callspec.params, item.nodeid + + return True diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index c11d3f3b..26666df8 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -69,7 +69,7 @@ def do_run_test_script( if num_gpus == 1 and not is_megatron: CliTrainingConfig.parse_and_run(args) else: - completed_proc = subprocess.run(command, env=env, timeout=60) + completed_proc = subprocess.run(command, env=env, timeout=120) if completed_proc.returncode: raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") if compare_path is not None and do_compare: From 8c64f03e3ab657c1a857cca4743c5f6962674184 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 9 Jun 2025 14:17:29 -0400 Subject: [PATCH 05/69] fix --- tests/test_match_megatron.py | 20 ++++++++++++-------- tests/utils/model_configs.py | 4 +++- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index 9b3b591b..4f82d575 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -15,6 +15,17 @@ def test_megatron(run_test_script_for_all_models, model_testing_config): def 
test_match_megatron(run_test_script_for_all_models, model_testing_config): if model_testing_config.megatron_args is None: pytest.skip(f"Megatron does not support model {model_testing_config.name}") + + ignore_tensors = [ + ".self_attn.query_key_value.", + ".self_attn.query.", + ".self_attn.key_value.", + ".mlp.layer_2.weight", + ".mlp.experts.", + ] + if model_testing_config.name == "mixtral": + ignore_tensors.extend([".mlp.experts.", ".mlp.layer_1.weight"]) + run_test_script_for_all_models( [ "model.distributed.training_dtype=fp32", @@ -23,12 +34,5 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config): "model.base_model.use_megatron_initialization=True", ], compare="test_megatron", - config=CompareConfig( - ignore_tensors=[ - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", - ".mlp.layer_2.weight", - ] - ), + config=CompareConfig(ignore_tensors=ignore_tensors), ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 65a063b5..a444307e 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -108,7 +108,7 @@ def _update_and_add_testing_config( "training.logs.interval=1", "run.tensor_logs.save=True", "run.tensor_logs.show=False", - # "model.base_model.max_position_embeddings=512", + "model.base_model.max_position_embeddings=512", "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", @@ -208,6 +208,8 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.transformer.head_groups=4", "model.base_model.transformer.rotary.type=default", + # Unused, but prevents issues with conversion tests. + "model.base_model.max_position_embeddings=2048", ], megatron_args=[ "--group-query-attention", From c0f648cdbb97b902e4c9fc96636856ea17ea41c1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 10 Jun 2025 12:48:02 -0400 Subject: [PATCH 06/69] fixes --- fast_llm/layers/transformer/transformer.py | 2 +- tests/test_mb.py | 7 +++- tests/test_multi_stage.py | 37 ++++++++++++++-------- tests/utils/model_configs.py | 24 ++++++++++---- 4 files changed, 48 insertions(+), 22 deletions(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 40dd2e00..115629d6 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -20,7 +20,7 @@ class BaseBlock(Layer, abc.ABC): """ - A transformer-like decoder base block block with abstract mixer. + A transformer-like decoder base block with abstract mixer. 
""" _mixer_module_name = "self_attn" diff --git a/tests/test_mb.py b/tests/test_mb.py index 806ccebc..781de6e8 100644 --- a/tests/test_mb.py +++ b/tests/test_mb.py @@ -66,7 +66,12 @@ def test_model_pp2s1_bf4(run_test_script_for_all_models): ], num_gpus=2, compare="test_model_df4", - config=CompareConfig(ignore_duplicates=["layers.0.word_embeddings_weight"]), + config=CompareConfig( + ignore_duplicates=[ + "layers.0.word_embeddings_weight", + "layers.0.position_embeddings_weight", + ] + ), ) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 06eca685..8753cf48 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,6 +3,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer +from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.tools.train import CliTrainingConfig from fast_llm.utils import Assert @@ -23,31 +24,39 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): args = model_testing_config.config_args + ["run.tensor_logs.save=False"] - model_ref = _get_trainer_from_args(args)._multi_stage - model_frozen = _get_trainer_from_args(args + ["model.base_model.transformer.mlp_lr_scale=[0]"])._multi_stage + model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage + model_frozen = _get_trainer_from_args( + args + + [ + f"model.base_model.transformer.mlp_lr_scale={[0]*model_ref.config.base_model.transformer.num_experts}", + f"model.base_model.transformer.router_lr_scale=0", + ], + model_testing_config.model_type, + )._multi_stage Assert.eq( model_ref._num_stages, model_frozen._num_stages, ) - diff_by_layer = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, TransformerLayer) else 0 + frozen_parameter_counts = [ + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 for layer in model_ref.base_model.layers ] - assert all((diff_by_layer[i] == 0) == (i in (0, len(diff_by_layer) - 1)) for i in range(len(diff_by_layer))) - total_diff = sum(diff_by_layer) - for weight_buffer_ref, weight_buffer_frozen in zip( model_ref._weight_buffers, model_frozen._weight_buffers, strict=True ): - assert weight_buffer_ref.numel() == weight_buffer_frozen.numel() + Assert.eq(weight_buffer_ref.numel() == weight_buffer_frozen.numel()) - for grad_buffer_ref, grad_buffer_frozen, diff in zip( - model_ref._grad_buffers, model_frozen._grad_buffers, diff_by_layer, strict=True + for grad_buffer_ref, grad_buffer_frozen, frozen_parameter_count in zip( + model_ref._grad_buffers, model_frozen._grad_buffers, frozen_parameter_counts, strict=True ): - Assert.eq(grad_buffer_ref.numel() - grad_buffer_frozen.numel() == diff) + Assert.eq(grad_buffer_ref.numel() - grad_buffer_frozen.numel() == frozen_parameter_count) - for shard_name, shard_diff in zip( - model_ref._shard_names, [0] + [total_diff] * (len(model_ref._all_shard_names) - 1), strict=True + for shard_name, shard_frozen_count in zip( + model_ref._shard_names, + [0] + [sum(frozen_parameter_counts)] * (len(model_ref._all_shard_names) - 1), + strict=True, ): - Assert.eq(model_ref.get_shard(shard_name).numel() - model_frozen.get_shard(shard_name).numel(), shard_diff) + Assert.eq( + 
model_ref.get_shard(shard_name).numel() - model_frozen.get_shard(shard_name).numel(), shard_frozen_count + ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index a444307e..3f989f58 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -222,6 +222,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic, ModelTestingGroup.convert, ], + # TODO: Bring back `generate` to `testing_groups` when stable. other_groups=[ ModelTestingGroup.megatron, ModelTestingGroup.distributed, @@ -254,9 +255,11 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron, ModelTestingGroup.distributed, ModelTestingGroup.convert, + ], + # TODO: Bring back `generate` to `testing_groups` when stable. + other_groups=[ ModelTestingGroup.generate, ], - other_groups=[], ) _update_and_add_testing_config( @@ -270,6 +273,7 @@ def _update_and_add_testing_config( testing_groups=[ ModelTestingGroup.basic, ], + # TODO: Bring back `generate` to `testing_groups` when stable. other_groups=[ ModelTestingGroup.distributed, ModelTestingGroup.convert, @@ -288,6 +292,7 @@ def _update_and_add_testing_config( testing_groups=[ ModelTestingGroup.basic, ], + # TODO: Bring back `generate` to `testing_groups` when stable. other_groups=[ ModelTestingGroup.distributed, ModelTestingGroup.convert, @@ -306,10 +311,11 @@ def _update_and_add_testing_config( testing_groups=[ ModelTestingGroup.basic, ModelTestingGroup.convert, - ModelTestingGroup.generate, ], + # TODO: Bring back `generate` to `testing_groups` when stable. other_groups=[ ModelTestingGroup.distributed, + ModelTestingGroup.generate, ], ) @@ -325,6 +331,7 @@ def _update_and_add_testing_config( ModelTestingGroup.basic, ModelTestingGroup.convert, ], + # TODO: Bring back `generate` to `testing_groups` when stable. other_groups=[ ModelTestingGroup.distributed, ModelTestingGroup.generate, @@ -342,10 +349,11 @@ def _update_and_add_testing_config( testing_groups=[ ModelTestingGroup.basic, ModelTestingGroup.convert, - ModelTestingGroup.generate, ], + # TODO: Bring back `generate` to `testing_groups` when stable. other_groups=[ ModelTestingGroup.distributed, + ModelTestingGroup.generate, ], ) @@ -367,14 +375,15 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron, ModelTestingGroup.distributed, ModelTestingGroup.convert, + ], + # TODO: Bring back `generate` to `testing_groups` when stable. + other_groups=[ ModelTestingGroup.generate, ], - other_groups=[], ) _update_and_add_testing_config( # Tests hybrid ssm, llamba converter. - # TODO: Conversion fails. "llama", "llamba", model_type="hybrid_ssm", @@ -383,11 +392,14 @@ def _update_and_add_testing_config( checkpoint_format=LLambaHuggingfaceCheckpointFormat, testing_groups=[ ModelTestingGroup.basic, + ], + # TODO: Bring back `generate` to `testing_groups` when stable. 
+ other_groups=[ + # TODO: Fix and bring these back to `testing_groups` ModelTestingGroup.distributed, ModelTestingGroup.convert, ModelTestingGroup.generate, ], - other_groups=[], ) From e92c311845a92d5de67aac5a5c2ab0ae9d759849 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Jun 2025 09:02:07 -0400 Subject: [PATCH 07/69] stuff --- tests/conftest.py | 111 ++++++++++++++++++++++------------------- tests/utils/depends.py | 4 +- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 829e1696..b688bb54 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ import dataclasses -import datetime +import gc +import json +import logging import math import os @@ -27,6 +29,7 @@ def pytest_addoption(parser): group = parser.getgroup("fast_llm") group.addoption("--skip-slow", action="store_true") group.addoption("--show-skipped", action="store_true") + group.addoption("--show-gpu-memory", type=int, default=10) group.addoption("--models", nargs="*") group.addoption( "--run-extra-slow", @@ -166,9 +169,63 @@ def pytest_collection_modifyitems(config, items: list[pytest.Function]): @pytest.hookimpl(tryfirst=True, hookwrapper=True) -def pytest_runtest_makereport(item: pytest.Function, call): +def pytest_runtest_makereport(item: pytest.Function, call: pytest.CallInfo): outcome = yield - manager.register_result(item, outcome.get_result()) + result = outcome.get_result() + manager.register_result(item, result) + + # Measure GPU memory usage. (TODO: This excludes child processes) + if call.when == "call" and torch.cuda.is_available(): + torch._C._cuda_clearCublasWorkspaces() + gc.collect() + # This also frees memory for other processes. + torch.cuda.empty_cache() + item.add_report_section( + call.when, + "resource usage", + json.dumps( + { + "duration": call.duration, + "max_memory_reserved": torch.cuda.max_memory_reserved(), + "max_memory_allocated": torch.cuda.max_memory_allocated(), + "memory_reserved": torch.cuda.memory_reserved(), + "memory_allocated": torch.cuda.memory_allocated(), + } + ), + ) + torch.cuda.reset_peak_memory_stats() + + +@pytest.hookimpl +def pytest_terminal_summary(terminalreporter): + resource_reports = {} + for reports in terminalreporter.stats.values(): + for report in reports: + if isinstance(report, pytest.TestReport): + for _, section in report.get_sections("Captured resource usage"): + if report.nodeid in resource_reports: + logging.error(f"Duplicate resource report for {report.nodeid}") + resource_reports[report.nodeid] = json.loads(section) + + if not resource_reports: + return + + terminalreporter.write_sep("=", "Highest gpu memory usage", bold=True) + sorted_nodeids = sorted( + resource_reports.keys(), + key=lambda nodeid: resource_reports[nodeid]["max_memory_reserved"], + reverse=True, + ) + logging.error(f"sorted_nodeids {sorted_nodeids}") + for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]: + terminalreporter.write_line( + f"{nodeid}:\n " + f"Max Reserved {resource_reports[nodeid]["max_memory_reserved"] / 1e6:.0f} MB | " + f"Max Allocated {resource_reports[nodeid]["max_memory_allocated"] / 1e6:.0f} MB | " + f"End Reserved {resource_reports[nodeid]["memory_reserved"] / 1e6:.0f} MB | " + f"End Allocated {resource_reports[nodeid]["memory_allocated"] / 1e6:.0f} MB | " + f"Duration {resource_reports[nodeid]["duration"]:.2f}" + ) def pytest_runtest_call(item: pytest.Function): @@ -190,51 +247,3 @@ def pytest_xdist_make_scheduler(config, log): # Always use grouped 
load balancing to handle dependencies, and make it work with `-n`. assert config.getvalue("dist") == "load" return xdist.scheduler.LoadGroupScheduling(config, log) - - -def get_all_reports(terminalreporter): - """Reports for all stages and all outcomes""" - for reports in terminalreporter.stats.values(): - for report in reports: - if isinstance(report, pytest.TestReport): - yield report - - -def resource_usage_message(report): - """The resource usage message for a report""" - return ", ".join(content for (prefix, content) in report.get_sections(f"Captured resource {report.when}")) - - -def format_duration(seconds): - """Human-readable running time message""" - if seconds < 60: - duration_string = f"{seconds:.3f} seconds" - else: - duration_string = str(datetime.timedelta(seconds=round(seconds))) - return f"running time: {duration_string}" - - -# @pytest.hookimpl(tryfirst=True) -# def pytest_runtest_makereport(item, call): -# """Report running time of a test call""" -# if call.when == "call": -# item.add_report_section( -# call.when, "resource", format_duration(call.duration) -# ) -# -# -# @pytest.hookimpl -# def pytest_terminal_summary(terminalreporter): -# """Produce a resource usage report if any test asked for it""" -# resource_reports = [ -# (report, message) -# for report in get_all_reports(terminalreporter) -# if (message := resource_usage_message(report)) -# ] -# if not resource_reports: -# return -# terminalreporter.write_sep("=", "resource usage", bold=True) -# for report, message in resource_reports: -# terminalreporter.write_line( -# f"{report.nodeid} ({report.when}) {message}" -# ) diff --git a/tests/utils/depends.py b/tests/utils/depends.py index 5e6bcc71..3fbb8f39 100644 --- a/tests/utils/depends.py +++ b/tests/utils/depends.py @@ -92,11 +92,11 @@ def handle_missing(self, item: pytest.Function): pytest.fail(f'{item.nodeid} depends on {", ".join(missing)}, which was not found', False) if failed := [ - dependency + f"{dependency} ({", ".join(f"{key}: {value}" for key, value in self._results[dependency].items()) if self._results[dependency] else "missing"})" for dependency in self._dependencies[nodeid] if not all(self._results[dependency].get(step, None) == "passed" for step in ("setup", "call", "teardown")) ]: - pytest.skip(f'{item.nodeid} depends on failed {", ".join(failed)}') + pytest.skip(f'{item.nodeid} depends on {", ".join(failed)}') def _resolve_dependencies(self, item: pytest.Function): dependencies = set() From b877fb27604be66c9ca87de11a88524e6cc5d7f9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Jun 2025 09:03:59 -0400 Subject: [PATCH 08/69] stuff --- Dockerfile | 9 ++++++++- setup.cfg | 11 +++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8c2efa85..983d785e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1.7-labs -FROM nvcr.io/nvidia/pytorch:24.11-py3 +FROM nvcr.io/nvidia/pytorch:25.05-py3 # Install dependencies. RUN apt-get update \ @@ -24,6 +24,13 @@ RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/to /usr/local/lib/python3.12/dist-packages \ /usr/local/lib/python3.12/dist-packages/__pycache__ +# The base image enforces versions for things like pytest for no good reason. +ENV PIP_CONSTRAINT="" +# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. 
+# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 +# We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) +RUN MAX_JOBS=4 pip install --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" + # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ diff --git a/setup.cfg b/setup.cfg index 381225bf..fac372eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,13 +17,13 @@ install_requires = # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation CORE = # Available through the nvidia base image - torch>=2.5.0 + torch>=2.6.0 # Numpy major needs to match torch - numpy>=1.24.4,<2.0.0 + numpy>=1.26.4,<2.0.0 # Used for checkpoints safetensors>=0.4.4 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation - flash-attn==2.7.2.post1 + flash-attn==2.7.3 mamba_ssm[causal-conv1d]==2.2.4 @@ -41,17 +41,16 @@ OPTIONAL = omegaconf>=2.3.0 # Miscellaneous requests>=2.32.3 - tqdm>=4.66.3 + tqdm>=4.67.1 DEV = # Pre-commit git hook pre-commit>=4.0.1 # Required for testing pytest>=8.3.2 - pytest-depends>=1.0.1 pytest-xdist>=3.6.1 # Somehow needed for Megatron to work with base image 24.11 - setuptools>=75.6.0 + setuptools>=78.1.1 # Required for building the documentation DOCS = From 907aef09ad944a3741ff184f36923c7cd7bb84af Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Jun 2025 09:45:29 -0400 Subject: [PATCH 09/69] attempt --- Dockerfile | 2 +- setup.cfg | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 983d785e..ae6625d0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,7 +29,7 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -RUN MAX_JOBS=4 pip install --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" +RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@v2.2.4" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ diff --git a/setup.cfg b/setup.cfg index fac372eb..c0a7d57b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,14 +17,15 @@ install_requires = # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation CORE = # Available through the nvidia base image - torch>=2.6.0 + torch>=2.7.0 # Numpy major needs to match torch numpy>=1.26.4,<2.0.0 # Used for checkpoints safetensors>=0.4.4 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.3 - mamba_ssm[causal-conv1d]==2.2.4 + # mamba_ssm[causal-conv1d]=2.2.4 # Removed here because we need to compile from github. + mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@v2.2.4 # Required for some optional features and tools. 
@@ -48,6 +49,7 @@ DEV = pre-commit>=4.0.1 # Required for testing pytest>=8.3.2 + pytest-depends>=1.0.1 pytest-xdist>=3.6.1 # Somehow needed for Megatron to work with base image 24.11 setuptools>=78.1.1 From 1340903d5b31c8f1fc0c6afb9171b6f119f3c7a4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Jun 2025 11:56:45 -0400 Subject: [PATCH 10/69] attempt --- Dockerfile | 4 ++-- setup.cfg | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index ae6625d0..05c3870c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,10 +27,10 @@ RUN mkdir -m 777 /app/Megatron-LM /app/examples /app/fast_llm /app/tests /app/to # The base image enforces versions for things like pytest for no good reason. ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. -# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 +# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) +RUN MAX_JOBS=4 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.0.post8" RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@v2.2.4" - # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ diff --git a/setup.cfg b/setup.cfg index c0a7d57b..3345ff73 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,8 +24,7 @@ CORE = safetensors>=0.4.4 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.3 - # mamba_ssm[causal-conv1d]=2.2.4 # Removed here because we need to compile from github. - mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@v2.2.4 + mamba_ssm[causal-conv1d]==2.2.4 # Required for some optional features and tools. 
From 8aed0a3e3b99edf44391f22215f69b72f640bff6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Jun 2025 18:27:06 -0400 Subject: [PATCH 11/69] Cleanup tests --- fast_llm/logging.py | 7 +- tests/conftest.py | 23 +++- tests/layers/test_lm_head.py | 86 +++++-------- tests/test_functional.py | 6 +- tests/test_mtp.py | 204 ----------------------------- tests/test_ssms.py | 241 ++--------------------------------- tests/utils/model_configs.py | 19 +++ tests/utils/utils.py | 71 +++++------ 8 files changed, 123 insertions(+), 534 deletions(-) delete mode 100644 tests/test_mtp.py diff --git a/fast_llm/logging.py b/fast_llm/logging.py index ffeb56f6..9c791ba6 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -323,16 +323,19 @@ def log_generator[ return log(f"{name} {tensor.view(dtype=torch.int64)[-8:].tolist()}", log_fn=log_fn) +_global_max_allocated = 0 _global_max_reserved = 0 def get_memory_usage_mib(reset_stats: bool = True, relative_to: dict[str, int] | None = None) -> dict[str, float]: - global _global_max_reserved + global _global_max_allocated, _global_max_reserved + max_allocated = torch.cuda.memory_allocated() / 2**20 max_reserved = torch.cuda.max_memory_reserved() / 2**20 + _global_max_allocated = max(max_allocated, _global_max_allocated) _global_max_reserved = max(max_reserved, _global_max_reserved) out = { "allocated": torch.cuda.memory_allocated() / 2**20, - "max_allocated": torch.cuda.max_memory_allocated() / 2**20, + "max_allocated": max_allocated, "reserved": torch.cuda.memory_reserved() / 2**20, "max_reserved": max_reserved, "global_max_reserved": _global_max_reserved, diff --git a/tests/conftest.py b/tests/conftest.py index b688bb54..cd4cc1d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import torch import xdist.scheduler +import fast_llm.logging from tests.utils.depends import DependencyManager # Make fixtures available globally without import @@ -176,9 +177,14 @@ def pytest_runtest_makereport(item: pytest.Function, call: pytest.CallInfo): # Measure GPU memory usage. (TODO: This excludes child processes) if call.when == "call" and torch.cuda.is_available(): + # Free memory for more accurate reporting, and to reduce OOM risk with lots of workers. + # Cublas workspace can unnecessarily keep 100s of MBs of reserved memory. torch._C._cuda_clearCublasWorkspaces() - gc.collect() - # This also frees memory for other processes. + # Lots of tensors tend to stay allocated until the next garbage collection. + # Collect only if the remaining memory is significant enough since it's costly. + if torch.cuda.memory_allocated() > 1e7: + gc.collect() + # Actually free the memory. torch.cuda.empty_cache() item.add_report_section( call.when, @@ -186,14 +192,23 @@ def pytest_runtest_makereport(item: pytest.Function, call: pytest.CallInfo): json.dumps( { "duration": call.duration, - "max_memory_reserved": torch.cuda.max_memory_reserved(), - "max_memory_allocated": torch.cuda.max_memory_allocated(), + # Relevant value for OOM risk. Also look at global max since fast-llm resets stats. + "max_memory_reserved": max( + torch.cuda.max_memory_reserved(), fast_llm.logging._global_max_reserved + ), + # Actual memory usage from the test. + "max_memory_allocated": max( + torch.cuda.max_memory_allocated(), fast_llm.logging._global_max_allocated + ), "memory_reserved": torch.cuda.memory_reserved(), "memory_allocated": torch.cuda.memory_allocated(), } ), ) torch.cuda.reset_peak_memory_stats() + # Reset global stats for next test. 
+ fast_llm.logging._global_max_reserved = 0 + fast_llm.logging._global_max_allocated = 0 @pytest.hookimpl diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 95da48e7..cad95e53 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -5,21 +5,15 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageConfig -from fast_llm.engine.multi_stage.stage import Stage from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.gpt.config import GPTBaseModelConfig -from fast_llm.models.gpt.model import GPTBaseModel +from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda +from tests.utils.utils import get_base_model, get_stage, requires_cuda def _lm_head( @@ -88,44 +82,41 @@ def test_lm_head( distributed_config_dict: dict[str, typing.Any], loss_masking: bool, ): - config = GPTBaseModelConfig.from_dict( + config = GPTModelConfig.from_dict( { - "transformer": { - "normalization": {"type": NormalizationType.rms_norm}, - "hidden_size": HIDDEN_SIZE, - "num_layers": 0, + "base_model": { + "transformer": { + "normalization": {"type": NormalizationType.rms_norm}, + "hidden_size": HIDDEN_SIZE, + "num_layers": 0, + }, + "vocab_size": VOCAB_SIZE, + "cross_entropy_impl": cross_entropy_impl, }, - "vocab_size": VOCAB_SIZE, - "cross_entropy_impl": cross_entropy_impl, + "distributed": distributed_config_dict, }, config_dict, update_type=UpdateType.update, ) - distributed_config = DistributedConfig.from_dict(distributed_config_dict) - distributed = Distributed(distributed_config) - tensor_space = TensorSpace(distributed_config) - config.setup_tensor_space(tensor_space) - tensor_space.setup(distributed) - model = GPTBaseModel(config, distributed_config) - model.setup(distributed) + model, distributed = get_base_model(config) - sequence_first = config.sequence_first or ( - config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 + sequence_first = config.base_model.sequence_first or ( + config.base_model.cross_entropy_splits is not None and config.base_model.cross_entropy_splits > 1 ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( - distributed_config.optimization_dtype.torch - if config.transformer.full_precision_residual - else distributed_config.training_dtype.torch + config.distributed.optimization_dtype.torch + if config.base_model.transformer.full_precision_residual + else config.distributed.training_dtype.torch ), device=distributed.device, requires_grad=True, ) label_shape = ( - (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) + (SEQUENCE_LENGTH + config.base_model.prediction_heads - 1, BATCH_SIZE) if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) + else (BATCH_SIZE, 
SEQUENCE_LENGTH + config.base_model.prediction_heads - 1) ) if loss_masking: loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) @@ -135,7 +126,7 @@ def test_lm_head( TransformerKwargs.sequence_first: sequence_first, TransformerKwargs.grad_output: 1.0, } - if config.distillation_model is None: + if config.base_model.distillation_model is None: target = torch.randint( 0, VOCAB_SIZE, @@ -148,25 +139,25 @@ def test_lm_head( kwargs[LanguageModelKwargs.labels] = target else: - assert config.prediction_heads == 1 + assert config.base_model.prediction_heads == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), dtype=input_.dtype, device=distributed.device, ) - kwargs[f"{config.distillation_model}_logits"] = target + kwargs[f"{config.base_model.distillation_model}_logits"] = target if loss_mask is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask - if config.tie_word_embeddings or config.prediction_heads > 1: + if config.base_model.tie_word_embeddings or config.base_model.prediction_heads > 1: logit_weight = ( torch.empty( - VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed_config.training_dtype.torch, device=distributed.device + VOCAB_SIZE, HIDDEN_SIZE, dtype=config.distributed.training_dtype.torch, device=distributed.device ) - .normal_(config.transformer.init_method_std) + .normal_(config.base_model.transformer.init_method_std) .requires_grad_(True) ) - kwargs[WORD_EMBEDDINGS_WEIGHT if config.tie_word_embeddings else OUTPUT_WEIGHTS] = logit_weight + kwargs[WORD_EMBEDDINGS_WEIGHT if config.base_model.tie_word_embeddings else OUTPUT_WEIGHTS] = logit_weight else: logit_weight = None @@ -175,18 +166,7 @@ def test_lm_head( head: LanguageModelHead = model[layer_index] Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, prediction_distance) - stage = Stage( - config=StageConfig(), - base_model=[head], - distributed_config=distributed_config, - begin=0, - end=1, - index=0, - ) - stage.setup(distributed=distributed) - stage.initialize_weights() - stage.restore_parameters() - stage.reset_gradients() + stage = get_stage([head], distributed) # Get reference outputs and grads if logit_weight is None: @@ -209,8 +189,8 @@ def test_lm_head( loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, - logit_scale_factor=config.logits_scale_factor, - logit_z_loss=config.logit_z_loss, + logit_scale_factor=config.base_model.logits_scale_factor, + logit_z_loss=config.base_model.logit_z_loss, ) # Prepare LM head inputs @@ -231,10 +211,10 @@ def test_lm_head( output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) - threshold = 1e-5 if distributed_config.training_dtype == DataType.float32 else 5e-3 + threshold = 1e-5 if config.distributed.training_dtype == DataType.float32 else 5e-3 min_threshold = ( - 1e-5 if distributed_config.training_dtype == DataType.float32 else 1e-4 - ) * config.logits_scale_factor + 1e-5 if config.distributed.training_dtype == DataType.float32 else 1e-4 + ) * config.base_model.logits_scale_factor Assert.eq(losses.keys(), loss_keys) Assert.eq(len(losses[loss_name]), 1) diff --git a/tests/test_functional.py b/tests/test_functional.py index 9c01f084..b049be85 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -226,9 +226,9 @@ def test_mlp_recomputation(gated, activation_type): def test_dropless_mlp(): num_experts = 4 experts_per_token = 4 - tokens = 1024 - hidden_size = 2048 - ffn_hidden_size = 4096 + tokens = 256 + hidden_size = 512 + 
ffn_hidden_size = 1024 std = 1 / 64 input_ = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) router_weight = torch.normal(0, std, (num_experts, hidden_size), device="cuda") diff --git a/tests/test_mtp.py b/tests/test_mtp.py deleted file mode 100644 index 5c4660b7..00000000 --- a/tests/test_mtp.py +++ /dev/null @@ -1,204 +0,0 @@ -import typing - -import pytest -import torch - -from fast_llm.config import UpdateType -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.models.gpt.config import GPTBaseModelConfig -from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.utils import Assert -from tests.utils.utils import get_hybrid_config, materialize_meta_tensors, requires_cuda - -try: - from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 - from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel -except ImportError: - MambaLayer, HybridSSMBaseModel, DiscreteMamba2 = ( - None, - None, - None, - ) - # Mamba not installed, skipping tests - - -run_hybrid_test = MambaLayer is not None and DiscreteMamba2 is not None and torch.cuda.is_available() - - -SEQUENCE_LENGTH = 200 -BATCH_SIZE = 4 -HIDDEN_SIZE = 256 -VOCAB_SIZE = 500 - - -@pytest.fixture -def distributed_config(): - return DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, - ) - - -@pytest.fixture -def distributed(distributed_config): - return Distributed(config=distributed_config) - - -@requires_cuda -@pytest.mark.parametrize( - "config_dict", - ( - {"prediction_heads": 1}, - {"prediction_heads": 2, "tie_word_embeddings": False}, - {"prediction_heads": 5, "tie_word_embeddings": False}, - ), -) -def test_transformer_mtp(config_dict: dict[str, typing.Any]): - config = GPTBaseModelConfig.from_dict( - { - "transformer": { - "hidden_size": HIDDEN_SIZE, - "num_layers": 2, - }, - "vocab_size": VOCAB_SIZE, - }, - config_dict, - update_type=UpdateType.update, - ) - distributed_config = DistributedConfig.from_dict({}) - distributed = Distributed(distributed_config) - model = GPTBaseModel(config, distributed_config) - model.setup(distributed) - materialize_meta_tensors(model, model._tensor_space) - model.to("cuda") - - sequence_first = config.sequence_first or ( - config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 - ) - target = torch.randint( - 0, - VOCAB_SIZE, - ( - (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) - if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) - ), - dtype=torch.int64, - device=distributed.device, - ) - input_ = torch.randint( - 0, - VOCAB_SIZE, - (SEQUENCE_LENGTH, BATCH_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH), - device=distributed.device, - ) - attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) - position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda", dtype=torch.int64) - kwargs = { - "position_ids": position_ids, - TransformerKwargs.sequence_first: sequence_first, 
- TransformerKwargs.attention_mask: attention_mask, - TransformerKwargs.attention_mask_value: -100, - TransformerKwargs.grad_output: 1.0, - LanguageModelKwargs.labels: target, - } - if config.tie_word_embeddings: - kwargs[WORD_EMBEDDINGS_WEIGHT] = model.embedding.word_embeddings_weight - else: - kwargs[OUTPUT_WEIGHTS] = model.model_head.output_weights - losses = {LanguageModelLossNames.multi_token_prediction_loss(i): [] for i in range(model._config.prediction_heads)} - _ = model(input_, kwargs, losses=losses) - for loss_name, loss_values in losses.items(): - Assert.gt(len(loss_values), 0) - loss = sum( - [ - sum(losses[LanguageModelLossNames.multi_token_prediction_loss(i)]) - for i in range(model._config.prediction_heads) - ] - ) - loss.backward() - - -@pytest.mark.skip(reason="Too slow") -@requires_cuda -@pytest.mark.skipif(not run_hybrid_test, reason="No CUDA available or Mamba not installed") -@pytest.mark.parametrize( - ("hybrid_block_layout", "prediction_heads", "default_mtp_type"), - [ - (["m", "t"], 1, None), - (["t", "m"], 2, None), - (["m", "t"], 2, None), - (["t", "m2"], 3, None), - (["t", "m2"], 3, "m"), - ], -) -def test_hybrid_model_mtp(distributed_config, hybrid_block_layout, prediction_heads, default_mtp_type): - hybrid_config = get_hybrid_config( - hybrid_block_layout=hybrid_block_layout, prediction_heads=prediction_heads, default_mtp_type=default_mtp_type - ) - model = HybridSSMBaseModel(hybrid_config, distributed_config) - distributed = Distributed(distributed_config) - model.setup(distributed) - tensor_space = model._tensor_space - materialize_meta_tensors(model, tensor_space) - model.to("cuda") - - num_heads, num_mtp_blocks = 0, 0 - str_block_mapping = {"t": TransformerLayer, "m": MambaLayer, "m2": DiscreteMamba2} - mtp_block_type = default_mtp_type or hybrid_block_layout[-1] - for block in model.get_output_layers(): - if isinstance(block, LanguageModelHead): - num_heads += 1 - else: - block = getattr(block, "mixer", block) - Assert.custom( - lambda _: isinstance(block, str_block_mapping[mtp_block_type]), - f"Block {block} is not of type {str_block_mapping[mtp_block_type]}", - ) - num_mtp_blocks += 1 - Assert.eq(num_heads, prediction_heads) - Assert.eq(num_mtp_blocks, prediction_heads - 1) - - batch_size = 2 - seq_length = 32 - x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") - position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) - attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape - labels = torch.randint(0, 49152, (batch_size, seq_length + model._config.prediction_heads - 1), device="cuda") - losses = {LanguageModelLossNames.multi_token_prediction_loss(i): [] for i in range(model._config.prediction_heads)} - kwargs = { - "position_ids": position_ids, - TransformerKwargs.sequence_first: False, - TransformerKwargs.attention_mask: attention_mask, - TransformerKwargs.attention_mask_value: -100, - TransformerKwargs.grad_output: True, - LanguageModelKwargs.labels: labels, - } - - if model._config.tie_word_embeddings: - kwargs[WORD_EMBEDDINGS_WEIGHT] = model.embedding.word_embeddings_weight - else: - kwargs[OUTPUT_WEIGHTS] = model.model_head.output_weights - - output = model( - x, - kwargs, - losses=losses, - ) - loss = sum( - [ - sum(losses[LanguageModelLossNames.multi_token_prediction_loss(i)]) - for i in range(model._config.prediction_heads) - ] - ) - loss.backward() diff --git a/tests/test_ssms.py b/tests/test_ssms.py index a1d460c2..52b51c8a 100644 --- 
a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -1,84 +1,31 @@ import pathlib -from functools import partial import pytest import torch from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat -from tests.utils.utils import get_hybrid_config, materialize_meta_tensors - -try: - from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 - from fast_llm.layers.ssm.llamba_block import LlambaBlock - from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel -except ImportError: - MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( - None, - None, - None, - None, - ) - # Mamba not installed, skipping tests +from fast_llm.models.ssm.model import HybridSSMModel try: from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel except ImportError: LMHeadModel = None -run_test = MambaLayer is not None and torch.cuda.is_available() - - -@pytest.fixture -def distributed_config(): - return DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, - ) - - -@pytest.fixture -def distributed(distributed_config): - return Distributed(config=distributed_config) - - -def get_hf_llamba_out(input_ids, path, format): - if format == LLambaHuggingfaceCheckpointFormat: - from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel - elif format == LlamaGPTHuggingfaceCheckpointFormat: - from transformers import LlamaForCausalLM as LMHeadModel - else: - raise ValueError(f"Invalid format: {format}") - - model = LMHeadModel.from_pretrained(path, strict=True).to("cuda") - parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) - print(f"Parameter sum: {parameter_sum}") - output = model(input_ids) - del model - torch.cuda.empty_cache() - return output, parameter_sum - @pytest.mark.slow @pytest.mark.skipif( - not run_test or LMHeadModel is None, - reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", + LMHeadModel is None, + reason=f"cartesia_pytorch.Llamba not installed", ) -def test_load_from_llamba_checkpoint(distributed_config): +def test_load_from_llamba_checkpoint(): """ Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. 
""" @@ -90,8 +37,12 @@ def test_load_from_llamba_checkpoint(distributed_config): format = LLambaHuggingfaceCheckpointFormat x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) - hf_logits = hf_logits["logits"].cpu() + + hf_model = LMHeadModel.from_pretrained(path, strict=True).to("cuda") + parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) + hf_logits = hf_model(x)["logits"].cpu() + del hf_model + torch.cuda.empty_cache() # Create checkpoint load config checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) @@ -109,7 +60,7 @@ def test_load_from_llamba_checkpoint(distributed_config): schedule_config = ScheduleConfig() with NoAutoValidate(): batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - batch_config.setup(distributed_config) + batch_config.setup(DistributedConfig.from_dict({})) batch_config.validate() schedule_runner = ScheduleRunner( config=schedule_config, @@ -131,173 +82,7 @@ def test_load_from_llamba_checkpoint(distributed_config): } input_data = [(x, common_kwargs)] - losses, success, metrics = schedule_runner.run_step( - iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True - ) + schedule_runner.run_step(iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True) logits = input_data[0][1]["logits"].cpu() assert torch.allclose(logits, hf_logits, atol=1e-2) - - -@pytest.mark.extra_slow -@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -@pytest.mark.parametrize( - "hybrid_block_layout,LAYER_CLS", - [ - (["m", "t"], MambaLayer), - (["m2", "t"], DiscreteMamba2), - ], - ids=["mamba", "discrete_mamba2"], -) -def test_mamba_layer(distributed_config, distributed, hybrid_block_layout, LAYER_CLS): - hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) - tensor_space = TensorSpace(distributed_config=distributed_config) - hybrid_config.setup_tensor_space(tensor_space) - layer = LAYER_CLS(hybrid_config.ssm, layer_idx=0, tensor_space=tensor_space) - tensor_space.setup(distributed) - materialize_meta_tensors(layer, tensor_space) - layer.to(distributed.device) - - batch_size = 2 - seq_length = 32 - hidden_size = hybrid_config.transformer.hidden_size - x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) - - # Run forward pass - output, _ = layer(x, {}) - - loss = output.sum() - loss.backward() - # Basic shape checkss - assert output.shape == x.shape - assert not torch.isnan(output).any() - assert not torch.isinf(output).any() - - -@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -def test_mamba_block(distributed_config, distributed): - hybrid_config = get_hybrid_config(hybrid_block_layout=["m", "t"]) - tensor_space = TensorSpace(distributed_config=distributed_config) - tensor_space.setup(distributed) - hybrid_config.setup_tensor_space(tensor_space) - layer_idx = 0 - - mixer_cls = partial(MambaLayer, layer_idx=layer_idx) - block = LlambaBlock( - hybrid_config.transformer, - hybrid_config.ssm, - mixer_cls=mixer_cls, - tensor_space=tensor_space, - layer_index=layer_idx, - ) - - materialize_meta_tensors(block, tensor_space) - block.to("cuda") - - batch_size = 2 - seq_length = 32 - hidden_size = hybrid_config.transformer.hidden_size - x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) - - 
hidden_states = block(x, {}) - loss = hidden_states.sum() - loss.backward() - - assert hidden_states.shape == x.shape - assert not torch.isnan(hidden_states).any() - assert not torch.isinf(hidden_states).any() - - -@pytest.mark.slow -@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -@pytest.mark.parametrize( - ("hybrid_block_layout"), - [ - (["m", "t"]), - (["m2", "t"]), - ], - ids=["mamba", "discrete_mamba2"], -) -def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layout): - hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) - model = HybridSSMBaseModel(hybrid_config, distributed_config) - distributed = Distributed(distributed_config) - model.setup(distributed) - tensor_space = model._tensor_space - materialize_meta_tensors(model, tensor_space) - model.to("cuda") - - batch_size = 2 - seq_length = 32 - x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") - position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) - attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape - labels = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") - losses = {LanguageModelLossNames.language_model_loss: []} - output = model( - x, - { - "position_ids": position_ids, - TransformerKwargs.sequence_first: False, - TransformerKwargs.attention_mask: attention_mask, - TransformerKwargs.attention_mask_value: -100, - TransformerKwargs.grad_output: True, - LanguageModelKwargs.labels: labels, - }, - losses=losses, - ) - loss = sum(losses[LanguageModelLossNames.language_model_loss]) - loss.backward() - - -# TODO: added this when inference enabled -# No inference for now -# @dataclass -# class InferenceParams: -# max_seqlen: int -# max_batch_size: int -# sequence_len_offset: int = 0 -# key_value_memory_dict: dict = None - -# def __post_init__(self): -# if self.key_value_memory_dict is None: -# self.key_value_memory_dict = {} - - -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") -# def test_hybrid_model_inference(distributed_config, hybrid_config): -# hybrid_config.ssm.use_fast_path = False -# model = HybridSSMBaseModel(hybrid_config, distributed_config) -# distributed = Distributed(distributed_config) -# model.setup(distributed) -# tensor_space = model._tensor_space -# materialize_meta_tensors(model, tensor_space) -# model.to("cuda") -# # print(model) - -# batch_size = 2 -# seq_length = 32 -# x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") -# position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) -# attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape -# labels = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") -# max_new_tokens = 10 - -# inference_params = InferenceParams( -# max_seqlen=len(x[0]) + max_new_tokens, max_batch_size=x.shape[0], sequence_len_offset=0 -# ) -# losses = {LanguageModelLossNames.language_model_loss: []} - -# output = model( -# x, -# { -# "position_ids": position_ids, -# TransformerKwargs.sequence_first: True, -# TransformerKwargs.attention_mask: attention_mask, -# TransformerKwargs.attention_mask_value: -100, -# TransformerKwargs.grad_output: True, -# LanguageModelKwargs.labels: labels, -# "inference_params": inference_params, -# }, -# losses=losses, -# ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3f989f58..1c332496 100644 --- 
a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -403,6 +403,25 @@ def _update_and_add_testing_config( ) +_update_and_add_testing_config( + # Tests hybrid ssm, llamba converter. + "llama", + "hybrid_mamba_2", + model_type="hybrid_ssm", + extra_args=["model.base_model.hybrid_block_layout=['t','m2']"], + megatron_args=None, + checkpoint_format=None, + testing_groups=[ + ModelTestingGroup.basic, + ], + # TODO: Bring back `generate` to `testing_groups` when stable. + other_groups=[ + # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.distributed, + ], +) + + @pytest.fixture(scope="session", params=_MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: return _MODEL_CONFIGS[request.param] diff --git a/tests/utils/utils.py b/tests/utils/utils.py index bf2059fa..ea689bcc 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -3,9 +3,11 @@ import pytest import torch -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig +from fast_llm.engine.base_model.base_model import BaseModel, Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig +from fast_llm.engine.multi_stage.stage import Stage requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") @@ -15,41 +17,30 @@ def result_path(): return pathlib.Path("/tmp/fast_llm_tests") -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model - - -def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): - config = HybridSSMBaseModelConfig( - transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), - ssm=SSMConfig(), - hybrid_block_layout=hybrid_block_layout, - prediction_heads=prediction_heads, - default_mtp_type=default_mtp_type, - init_method_std_embed=0.02, - init_method_min_embed=-0.02, - init_method_max_embed=0.02, - use_position_embeddings=True, - tie_word_embeddings=False, +def get_base_model(config: FastLLMModelConfig): + # Create a base model (and distributed). + # Using a full model config so we have the model type and distributed config in the same argument. 
+ distributed = Distributed(config.distributed) + tensor_space = TensorSpace(config.distributed) + config.base_model.setup_tensor_space(tensor_space) + tensor_space.setup(distributed) + base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) + base_model.setup(distributed) + return base_model, distributed + + +def get_stage(base_model: BaseModel | list[Layer], distributed: Distributed): + # Create a fast-llm stage which allocates and initializes meta tensors correctly. + stage = Stage( + config=StageConfig(), + base_model=base_model, + distributed_config=distributed.config, + begin=0, + end=1, + index=0, ) - return config + stage.setup(distributed=distributed) + stage.initialize_weights() + stage.restore_parameters() + stage.reset_gradients() + return stage From 830a380b9d0a5835975d73f9c1fda7e2c987ce95 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Jun 2025 19:21:39 -0400 Subject: [PATCH 12/69] fixes --- tests/conftest.py | 1 - tests/layers/test_lm_head.py | 67 +++++++++++++++++++----------------- tests/utils/model_configs.py | 14 ++++++-- 3 files changed, 47 insertions(+), 35 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index cd4cc1d1..bfe9f50c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -231,7 +231,6 @@ def pytest_terminal_summary(terminalreporter): key=lambda nodeid: resource_reports[nodeid]["max_memory_reserved"], reverse=True, ) - logging.error(f"sorted_nodeids {sorted_nodeids}") for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]: terminalreporter.write_line( f"{nodeid}:\n " diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index cad95e53..ea09d3b5 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -11,7 +11,7 @@ from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -82,41 +82,46 @@ def test_lm_head( distributed_config_dict: dict[str, typing.Any], loss_masking: bool, ): - config = GPTModelConfig.from_dict( + config = GPTBaseModelConfig.from_dict( { - "base_model": { - "transformer": { - "normalization": {"type": NormalizationType.rms_norm}, - "hidden_size": HIDDEN_SIZE, - "num_layers": 0, - }, - "vocab_size": VOCAB_SIZE, - "cross_entropy_impl": cross_entropy_impl, + "transformer": { + "normalization": {"type": NormalizationType.rms_norm}, + "hidden_size": HIDDEN_SIZE, + "num_layers": 0, }, - "distributed": distributed_config_dict, + "vocab_size": VOCAB_SIZE, + "cross_entropy_impl": cross_entropy_impl, }, config_dict, update_type=UpdateType.update, ) - model, distributed = get_base_model(config) - sequence_first = config.base_model.sequence_first or ( - config.base_model.cross_entropy_splits is not None and config.base_model.cross_entropy_splits > 1 + model, distributed = get_base_model( + GPTModelConfig.from_dict( + { + "base_model": config, + "distributed": distributed_config_dict, + }, + ) + ) + + sequence_first = config.sequence_first or ( + config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else 
(BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( - config.distributed.optimization_dtype.torch - if config.base_model.transformer.full_precision_residual - else config.distributed.training_dtype.torch + distributed.config.optimization_dtype.torch + if config.transformer.full_precision_residual + else distributed.config.training_dtype.torch ), device=distributed.device, requires_grad=True, ) label_shape = ( - (SEQUENCE_LENGTH + config.base_model.prediction_heads - 1, BATCH_SIZE) + (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.base_model.prediction_heads - 1) + else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) ) if loss_masking: loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) @@ -126,7 +131,7 @@ def test_lm_head( TransformerKwargs.sequence_first: sequence_first, TransformerKwargs.grad_output: 1.0, } - if config.base_model.distillation_model is None: + if config.distillation_model is None: target = torch.randint( 0, VOCAB_SIZE, @@ -139,25 +144,25 @@ def test_lm_head( kwargs[LanguageModelKwargs.labels] = target else: - assert config.base_model.prediction_heads == 1 + assert config.prediction_heads == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), dtype=input_.dtype, device=distributed.device, ) - kwargs[f"{config.base_model.distillation_model}_logits"] = target + kwargs[f"{config.distillation_model}_logits"] = target if loss_mask is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask - if config.base_model.tie_word_embeddings or config.base_model.prediction_heads > 1: + if config.tie_word_embeddings or config.prediction_heads > 1: logit_weight = ( torch.empty( - VOCAB_SIZE, HIDDEN_SIZE, dtype=config.distributed.training_dtype.torch, device=distributed.device + VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.training_dtype.torch, device=distributed.device ) - .normal_(config.base_model.transformer.init_method_std) + .normal_(config.transformer.init_method_std) .requires_grad_(True) ) - kwargs[WORD_EMBEDDINGS_WEIGHT if config.base_model.tie_word_embeddings else OUTPUT_WEIGHTS] = logit_weight + kwargs[WORD_EMBEDDINGS_WEIGHT if config.tie_word_embeddings else OUTPUT_WEIGHTS] = logit_weight else: logit_weight = None @@ -189,8 +194,8 @@ def test_lm_head( loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, - logit_scale_factor=config.base_model.logits_scale_factor, - logit_z_loss=config.base_model.logit_z_loss, + logit_scale_factor=config.logits_scale_factor, + logit_z_loss=config.logit_z_loss, ) # Prepare LM head inputs @@ -211,10 +216,10 @@ def test_lm_head( output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) - threshold = 1e-5 if config.distributed.training_dtype == DataType.float32 else 5e-3 + threshold = 1e-5 if distributed.config.training_dtype == DataType.float32 else 5e-3 min_threshold = ( - 1e-5 if config.distributed.training_dtype == DataType.float32 else 1e-4 - ) * config.base_model.logits_scale_factor + 1e-5 if distributed.config.training_dtype == DataType.float32 else 1e-4 + ) * config.logits_scale_factor Assert.eq(losses.keys(), loss_keys) Assert.eq(len(losses[loss_name]), 1) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1c332496..3f334c64 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -387,7 +387,13 @@ def _update_and_add_testing_config( "llama", "llamba", model_type="hybrid_ssm", - 
extra_args=["model.base_model.hybrid_block_layout=['t','m']"], + extra_args=[ + "model.base_model.hybrid_block_layout=['t','m']", + "model.base_model.ssm.state_size=8", + "model.base_model.ssm.chunk_size=32", + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=8", + ], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, testing_groups=[ @@ -405,10 +411,12 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests hybrid ssm, llamba converter. - "llama", + "llamba", "hybrid_mamba_2", model_type="hybrid_ssm", - extra_args=["model.base_model.hybrid_block_layout=['t','m2']"], + extra_args=[ + "model.base_model.hybrid_block_layout=['t','m2']", + ], megatron_args=None, checkpoint_format=None, testing_groups=[ From 13e1da5c9d91658ba9941a2d03d91d21e668143b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 10:41:22 -0400 Subject: [PATCH 13/69] fix --- fast_llm/functional/triton/mlp.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ee3ba304..ab408368 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -25,9 +25,6 @@ from fast_llm.functional.triton.sparse_linear import output_sparse_matmul from fast_llm.tensor import param_get_and_unset_is_zero -# Triton requires global variables to be annotated with `constexpr`. -_TritonActivationType: tl_constexpr = ActivationType - @triton_jit() def triton_mlp_activation_forward_kernel( @@ -50,18 +47,19 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + # Triton doesn't like enums, so we use str instead of ActivationType. + if activation_type == "gelu": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) - elif activation_type == _TritonActivationType.silu: + elif activation_type == "silu": out = input_ / (1 + tl.exp(-input_)) - elif activation_type == _TritonActivationType.relu: + elif activation_type == "relu": out = tl.where(input_ > 0, input_, 0) - elif activation_type == _TritonActivationType.squared_relu: + elif activation_type == "squared_relu": relu_out = tl.where(input_ > 0, input_, 0) out = relu_out * relu_out - elif activation_type == _TritonActivationType.identity: + elif activation_type == "identity": out = input_ else: tl.static_assert(False, activation_type) @@ -100,28 +98,29 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + # Triton doesn't like enums, so we use str instead of ActivationType. 
+ if activation_type == "gelu": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) if gated or recompute: out = input_ * 0.5 * (1.0 + tanh) - elif activation_type == _TritonActivationType.silu: + elif activation_type == "silu": exp = tl.exp(-input_) sigma = 1 / (1 + exp) grad = sigma * sigma + (1 + input_) / (2 + exp + 1 / exp) if gated or recompute: out = input_ * sigma - elif activation_type == _TritonActivationType.relu: + elif activation_type == "relu": grad = tl.where(input_ > 0, 1, 0) if gated or recompute: out = tl.where(input_ > 0, input_, 0) - elif activation_type == _TritonActivationType.squared_relu: + elif activation_type == "squared_relu": relu_out = tl.where(input_ > 0, input_, 0) grad = 2 * relu_out if gated or recompute: out = relu_out * relu_out - elif activation_type == _TritonActivationType.identity: + elif activation_type == "identity": grad = 1 if gated or recompute: out = input_ From 0dffe5c46ca31e0b8b1b13dfcbec6d0e712ab2d6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 14:27:13 -0400 Subject: [PATCH 14/69] fixes --- fast_llm/layers/ssm/discrete_mamba2.py | 41 ++++++++++++++++---------- fast_llm/layers/ssm/mamba_layer.py | 11 +++++-- setup.cfg | 29 +++++++++--------- tests/test_ssms.py | 2 +- 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 85916244..ecf0b29d 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -2,7 +2,6 @@ import math import einops -import mamba_ssm.ops.triton.ssd_combined import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -13,12 +12,22 @@ logger = logging.getLogger(__name__) + try: - import causal_conv1d + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as _mamba_chunk_scan_combined # noqa + + _mamba_available = True except ImportError: - # this is needed since we cannot use causal_conv1d on B200 GPUs for now - logger.warning("Note, causal_conv1d not found, will use torch.nn.functional.conv1d instead") - causal_conv1d = None + _mamba_available = False + + +try: + from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa + + _causal_conv1d_available = True +except ImportError: + _causal_conv1d_available = False + """ This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py @@ -148,6 +157,8 @@ def forward(self, hidden_states, kwargs): outputs["hidden_states"]: (B, L, D). outputs["state"]: inference cache. 
""" + + assert _mamba_available input_ = hidden_states outputs = {} # assert state is None @@ -201,7 +212,7 @@ def forward(self, hidden_states, kwargs): C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) # SSM forward - result = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( + result = _mamba_chunk_scan_combined( x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), dt=A_log, dt_softplus=True, @@ -234,11 +245,18 @@ def forward(self, hidden_states, kwargs): def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - if causal_conv1d is None or self.activation_name not in [ + if _causal_conv1d_available and self.activation_name in ( "silu", "swish", "identity", - ]: + ): + xBC = _causal_conv1d_fn( + xBC.transpose(1, 2), + einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_bias, + activation=None if self.activation_name == "identity" else self.activation_name, + ).transpose(1, 2) + else: xBC = self.act( torch.nn.functional.conv1d( xBC.transpose(1, 2), @@ -248,11 +266,4 @@ def convolutional_forward(self, xBC, padded_len): padding=self.conv_kernel_size - 1, )[..., :padded_len].transpose(1, 2) ) - else: - xBC = causal_conv1d.causal_conv1d_fn( - xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, - ).transpose(1, 2) return xBC diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7d0ee48a..7fd43789 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -2,7 +2,6 @@ from typing import Callable import einops -import mamba_ssm.ops.selective_scan_interface import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -11,6 +10,13 @@ from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ from fast_llm.utils import get_lr_scale +try: + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa + + _mamba_available = True +except ImportError: + _mamba_available = False + """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. For now it only supports training and not inference. 
@@ -153,6 +159,7 @@ def __init__(
         self._return_input = return_input
 
     def forward(self, hidden_states, kwargs):
+        assert _mamba_available
         batch, seqlen, dim = hidden_states.shape
 
         # We do matmul and transpose BLH -> HBL at the same time
@@ -167,7 +174,7 @@ def forward(self, hidden_states, kwargs):
         A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
         # In the backward pass we write dx and dz next to each other to avoid torch.cat
         # Note: if we want to support inference, we would need to implement the slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s
-        out = mamba_ssm.ops.selective_scan_interface.mamba_inner_fn(
+        out = _mamba_inner_fn(
             xz,
             self.conv1d_weight,
             self.conv1d_bias,
diff --git a/setup.cfg b/setup.cfg
index 3345ff73..bc0de459 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -6,10 +6,10 @@ packages = find_namespace:
 include_package_data = True
 python_requires = >=3.12
 install_requires =
-    requests>=2.32.3
-    PyYAML>=6.0.1
-    pybind11>=2.5.0
-    packaging>=24.1
+    requests>=2.32.4
+    PyYAML>=6.0.2
+    pybind11>=2.13.6
+    packaging>=25.0
 
 [options.extras_require]
 # Required to use the main functionality of Fast-LLM
@@ -21,7 +21,7 @@ CORE =
     # Numpy major needs to match torch
     numpy>=1.26.4,<2.0.0
     # Used for checkpoints
-    safetensors>=0.4.4
+    safetensors>=0.5.3
     # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation
     flash-attn==2.7.3
     mamba_ssm[causal-conv1d]==2.2.4
@@ -30,28 +30,27 @@ CORE =
 # Required for some optional features and tools.
 OPTIONAL =
     # Huggingface tools
-    transformers>=4.44.2
-    hf-transfer>=0.1.8
-    datasets>=3.1.0
-    huggingface-hub>=0.28.1
+    transformers>=4.52.4
+    hf-transfer>=0.1.9
+    datasets>=3.6.0
+    huggingface-hub>=0.32.6
     # Weights and biases
-    wandb>=0.17.7
+    wandb>=0.20.1
     # Hydra
     hydra-core>=1.3.2
     omegaconf>=2.3.0
     # Miscellaneous
-    requests>=2.32.3
     tqdm>=4.67.1
 
 DEV =
     # Pre-commit git hook
-    pre-commit>=4.0.1
+    pre-commit>=4.2.0
     # Required for testing
-    pytest>=8.3.2
+    pytest>=8.4.0
     pytest-depends>=1.0.1
-    pytest-xdist>=3.6.1
+    pytest-xdist>=3.7.0
     # Somehow needed for Megatron to work with base image 24.11
-    setuptools>=78.1.1
+    setuptools>=80.9.0
 
 # Required for building the documentation
 DOCS =
diff --git a/tests/test_ssms.py b/tests/test_ssms.py
index f3eb9261..ef5193b6 100644
--- a/tests/test_ssms.py
+++ b/tests/test_ssms.py
@@ -14,6 +14,7 @@
 from fast_llm.engine.schedule.schedule import Schedule
 from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames
 from fast_llm.layers.ssm.config import SSMBlockType
+from fast_llm.layers.ssm.llamba_block import LlambaBlock
 from fast_llm.layers.transformer.config import TransformerKwargs
 from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat
 from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat
@@ -21,7 +22,6 @@
 
 try:
     from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
-    from fast_llm.layers.ssm.llamba_block import LlambaBlock
     from fast_llm.layers.ssm.mamba_layer import MambaLayer
     from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel
 except Exception:
From dcc506464d175407c3d8711e73d05ae3b88c6c41 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Thu, 12 Jun 2025 14:30:29 -0400
Subject: [PATCH 15/69] fixes

---
 tests/test_ssms.py | 15 +++------------
 1 file changed, 3 insertions(+), 12 deletions(-)

diff --git a/tests/test_ssms.py b/tests/test_ssms.py
index ef5193b6..36c7b622 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -14,24 +14,15 @@ from fast_llm.engine.schedule.schedule import Schedule from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.llamba_block import LlambaBlock +from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel from tests.common import get_hybrid_config, materialize_meta_tensors -try: - from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 - from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel -except Exception: - MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( - None, - None, - None, - None, - ) - try: from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel except ImportError: From 9d415bc6f29a083e326d856fcfcc949bdad3b638 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 14:37:21 -0400 Subject: [PATCH 16/69] fixes --- .github/workflows/docs.yaml | 2 +- Dockerfile | 2 +- setup.cfg | 21 ++++++++++++++------- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 93191972..b755993c 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -33,7 +33,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/Dockerfile b/Dockerfile index 05c3870c..50810ed1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,DEV]" +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/setup.cfg b/setup.cfg index bc0de459..8a446064 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,16 +24,10 @@ CORE = safetensors>=0.5.3 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.3 - mamba_ssm[causal-conv1d]==2.2.4 -# Required for some optional features and tools. +# Small packages required for some optional features and tools. 
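+# Heavier dependencies are split out into the HUGGINGFACE and SSM extras below.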
OPTIONAL = - # Huggingface tools - transformers>=4.52.4 - hf-transfer>=0.1.9 - datasets>=3.6.0 - huggingface-hub>=0.32.6 # Weights and biases wandb>=0.20.1 # Hydra @@ -42,6 +36,19 @@ OPTIONAL = # Miscellaneous tqdm>=4.67.1 +# Huggingface tools +HUGGINGFACE = + transformers>=4.52.4 + hf-transfer>=0.1.9 + datasets>=3.6.0 + huggingface-hub>=0.32.6 + +# Required to run SSMs +# To install on cpu environment (ex. for IDE support): +# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +SSM = + mamba_ssm[causal-conv1d]==2.2.4 + DEV = # Pre-commit git hook pre-commit>=4.2.0 From 68251c29eadeb1f25d23ba1090d8f43d6665cbf4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 16:00:41 -0400 Subject: [PATCH 17/69] fixes --- fast_llm/layers/ssm/config.py | 2 +- tests/conftest.py | 14 ++++++++++++-- tests/utils/model_configs.py | 8 ++++---- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 13418254..6837507f 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -21,7 +21,7 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads -class SSMBlockType(str, enum.Enum): +class SSMBlockType(enum.StrEnum): """ An enum for the available mamba types for the MLP layer. """ diff --git a/tests/conftest.py b/tests/conftest.py index bfe9f50c..bc3d443c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -184,8 +184,12 @@ def pytest_runtest_makereport(item: pytest.Function, call: pytest.CallInfo): # Collect only if the remaining memory is significant enough since it's costly. if torch.cuda.memory_allocated() > 1e7: gc.collect() - # Actually free the memory. - torch.cuda.empty_cache() + try: + # Actually free the memory. + torch.cuda.empty_cache() + except RuntimeError: + # Happens if the test broke cuda. + return item.add_report_section( call.when, "resource usage", @@ -243,6 +247,12 @@ def pytest_terminal_summary(terminalreporter): def pytest_runtest_call(item: pytest.Function): + if torch.cuda.is_available(): + # Empty cache to check is cuda is still working (TODO: Is there a better way? Can we kill the worker?) + try: + torch.cuda.empty_cache() + except RuntimeError: + pytest.skip("Cuda runtime unavailable due to an error in an earlier test.") manager.handle_missing(item) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3f334c64..cf124690 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -370,14 +370,14 @@ def _update_and_add_testing_config( "--moe-router-topk=4", ], checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, - testing_groups=[ + testing_groups=[], + # TODO: New base image broke mixtral + # TODO: Bring back `generate` to `testing_groups` when stable. + other_groups=[ ModelTestingGroup.basic, ModelTestingGroup.megatron, ModelTestingGroup.distributed, ModelTestingGroup.convert, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. 
- other_groups=[ ModelTestingGroup.generate, ], ) From 639d6c261f8ddafae62d73631223e3f7b1cae72a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 16:54:51 -0400 Subject: [PATCH 18/69] doc --- fast_llm/layers/ssm/config.py | 1 - fast_llm/models/ssm/config.py | 8 ++++---- fast_llm/models/ssm/model.py | 6 +++--- setup.cfg | 1 - tests/utils/depends.py | 12 +++++++++++- tests/utils/model_configs.py | 2 +- 6 files changed, 19 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 6837507f..fd9c60ec 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -28,7 +28,6 @@ class SSMBlockType(enum.StrEnum): mamba = "m" mamba2_discrete = "m2d" - mamba2 = "m2" transformer = "t" diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index e27e5280..22f81fa1 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -30,14 +30,14 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, - desc=f"Pattern of blocks to use in the model. Availabel types: {SSMBlockType.__members__.values()}", + desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, ) - default_mtp_type: str | None = Field( + default_mtp_type: SSMBlockType | None = Field( default=None, - desc="Multi-token prediction mixer to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.", + desc="Multi-token prediction mixer to use in the model. 
If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) use_megatron_initialization: bool = Field( diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 118a195b..526d66c0 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -88,7 +88,7 @@ def get_layers(self) -> list[Layer]: # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == SSMBlockType.transformer.value: + if block_type == SSMBlockType.transformer: # Transformer block layers.append( TransformerLayer( @@ -100,7 +100,7 @@ def get_layers(self) -> list[Layer]: ), ) ) - elif block_type == SSMBlockType.mamba2_discrete.value: + elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, config_ssm=self._config.ssm, @@ -113,7 +113,7 @@ def get_layers(self) -> list[Layer]: ) layers.append(mamba_block) - elif block_type == SSMBlockType.mamba.value: + elif block_type == SSMBlockType.mamba: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, diff --git a/setup.cfg b/setup.cfg index 8a446064..24efcaf3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,6 @@ DEV = pre-commit>=4.2.0 # Required for testing pytest>=8.4.0 - pytest-depends>=1.0.1 pytest-xdist>=3.7.0 # Somehow needed for Megatron to work with base image 24.11 setuptools>=80.9.0 diff --git a/tests/utils/depends.py b/tests/utils/depends.py index 3fbb8f39..3972a066 100644 --- a/tests/utils/depends.py +++ b/tests/utils/depends.py @@ -49,7 +49,17 @@ def as_list(lst): class DependencyManager: - """Keep track of tests, their names and their dependencies.""" + """ + A simplified and improved version of pytest-depends. Main differences are the following: + * Add compatibility with pytest-xdist: group connected components of the dependency graph together, + and rename them with the `@dependency_group_{i}` suffix so they are run in the same worker, assuming + group scheduling is used. + * Improved parameterized dependencies so tests can depend on other tests with matching parametrization. + Ex. a test `test_model` with parameter `model` can depend on `test_other[{model}]`, + then `test_model[llama]` will depend on `test_other[llama]`, and so on. + * Improved description of missing/failed dependencies. + * Some option hard-coded for Fast-LLM. 
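+
+    Illustrative usage of the parametrized form described above (the `depends_on` marker is registered
+    in `conftest.py`):
+
+        @pytest.mark.depends_on(on=["test_other[{model}]"])
+        def test_model(model):
+            ...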
+ """ def __init__(self, items: list[pytest.Function]): self._items = items diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index cf124690..d4889e94 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -415,7 +415,7 @@ def _update_and_add_testing_config( "hybrid_mamba_2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.hybrid_block_layout=['t','m2d']", ], megatron_args=None, checkpoint_format=None, From 746542847ed3045fe62819a31a12eacfa17aeb5e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 17:24:38 -0400 Subject: [PATCH 19/69] stuff --- fast_llm/layers/ssm/config.py | 3 +- fast_llm/layers/transformer/config.py | 11 +- fast_llm/layers/transformer/transformer.py | 2 +- fast_llm/logging.py | 7 +- fast_llm/models/ssm/config.py | 8 +- fast_llm/models/ssm/model.py | 6 +- setup.cfg | 1 - tests/common.py | 454 --------------------- tests/conftest.py | 183 +++++++-- tests/data/common.py | 2 +- tests/data/test_blending.py | 2 +- tests/data/test_concatenate.py | 2 +- tests/data/test_concatenated_memmap.py | 2 +- tests/data/test_dataset_from_file.py | 2 +- tests/data/test_fim.py | 2 +- tests/data/test_memmap.py | 2 +- tests/data/test_sampling.py | 2 +- tests/data/test_slice.py | 2 +- tests/layers/test_lm_head.py | 2 +- tests/test_checkpoint.py | 16 +- tests/test_config.py | 2 +- tests/test_functional.py | 2 +- tests/test_gpt_generate_and_forward.py | 4 +- tests/test_match_megatron.py | 6 +- tests/test_mb.py | 5 +- tests/test_mb_seq_first.py | 5 +- tests/test_ms.py | 3 +- tests/test_mtp.py | 2 +- tests/test_multi_stage.py | 3 +- tests/test_seq_first.py | 3 +- tests/test_simple.py | 3 +- tests/test_ssms.py | 2 +- tests/test_triton_kernels.py | 2 +- tests/utils/__init__.py | 0 tests/{ => utils}/compare_tensor_logs.py | 0 tests/utils/dataset.py | 80 ++++ tests/utils/depends.py | 211 ++++++++++ tests/utils/model_configs.py | 233 +++++++++++ tests/utils/run_test_script.py | 96 +++++ tests/utils/utils.py | 52 +++ 40 files changed, 881 insertions(+), 544 deletions(-) create mode 100644 tests/utils/__init__.py rename tests/{ => utils}/compare_tensor_logs.py (100%) create mode 100644 tests/utils/dataset.py create mode 100644 tests/utils/depends.py create mode 100644 tests/utils/model_configs.py create mode 100644 tests/utils/run_test_script.py create mode 100644 tests/utils/utils.py diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 13418254..fd9c60ec 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -21,14 +21,13 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads -class SSMBlockType(str, enum.Enum): +class SSMBlockType(enum.StrEnum): """ An enum for the available mamba types for the MLP layer. 
""" mamba = "m" mamba2_discrete = "m2d" - mamba2 = "m2" transformer = "t" diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 9cc9510b..3e619eb9 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -711,13 +711,4 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: ) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in ( - DataType.float16, - DataType.bfloat16, - ) - - # Config parameter `window_size` only can be used with flash attention - if not use_flash_attention: - Assert.is_(self.window_size, None) - - return use_flash_attention + return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index b51ba1e9..14745207 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -20,7 +20,7 @@ class BaseBlock(Layer, abc.ABC): """ - A transformer-like decoder base block block with abstract mixer. + A transformer-like decoder base block with abstract mixer. """ _mixer_module_name = "self_attn" diff --git a/fast_llm/logging.py b/fast_llm/logging.py index ffeb56f6..9c791ba6 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -323,16 +323,19 @@ def log_generator[ return log(f"{name} {tensor.view(dtype=torch.int64)[-8:].tolist()}", log_fn=log_fn) +_global_max_allocated = 0 _global_max_reserved = 0 def get_memory_usage_mib(reset_stats: bool = True, relative_to: dict[str, int] | None = None) -> dict[str, float]: - global _global_max_reserved + global _global_max_allocated, _global_max_reserved + max_allocated = torch.cuda.memory_allocated() / 2**20 max_reserved = torch.cuda.max_memory_reserved() / 2**20 + _global_max_allocated = max(max_allocated, _global_max_allocated) _global_max_reserved = max(max_reserved, _global_max_reserved) out = { "allocated": torch.cuda.memory_allocated() / 2**20, - "max_allocated": torch.cuda.max_memory_allocated() / 2**20, + "max_allocated": max_allocated, "reserved": torch.cuda.memory_reserved() / 2**20, "max_reserved": max_reserved, "global_max_reserved": _global_max_reserved, diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index e27e5280..22f81fa1 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -30,14 +30,14 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, - desc=f"Pattern of blocks to use in the model. Availabel types: {SSMBlockType.__members__.values()}", + desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, ) - default_mtp_type: str | None = Field( + default_mtp_type: SSMBlockType | None = Field( default=None, - desc="Multi-token prediction mixer to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.", + desc="Multi-token prediction mixer to use in the model. 
If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) use_megatron_initialization: bool = Field( diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 118a195b..526d66c0 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -88,7 +88,7 @@ def get_layers(self) -> list[Layer]: # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == SSMBlockType.transformer.value: + if block_type == SSMBlockType.transformer: # Transformer block layers.append( TransformerLayer( @@ -100,7 +100,7 @@ def get_layers(self) -> list[Layer]: ), ) ) - elif block_type == SSMBlockType.mamba2_discrete.value: + elif block_type == SSMBlockType.mamba2_discrete: mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, config_ssm=self._config.ssm, @@ -113,7 +113,7 @@ def get_layers(self) -> list[Layer]: ) layers.append(mamba_block) - elif block_type == SSMBlockType.mamba.value: + elif block_type == SSMBlockType.mamba: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, diff --git a/setup.cfg b/setup.cfg index 8a446064..24efcaf3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,6 @@ DEV = pre-commit>=4.2.0 # Required for testing pytest>=8.4.0 - pytest-depends>=1.0.1 pytest-xdist>=3.7.0 # Somehow needed for Megatron to work with base image 24.11 setuptools>=80.9.0 diff --git a/tests/common.py b/tests/common.py index d531972e..a2dba74a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,470 +1,16 @@ import os -import pathlib -import random -import shutil -import string -import subprocess import sys -import numpy as np -import pytest -import torch -import yaml - -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.models.gpt.config import ( - LlamaGPTHuggingfaceCheckpointFormat, - MistralGPTHuggingfaceCheckpointFormat, - MixtralGPTHuggingfaceCheckpointFormat, - MTPLlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, - Starcoder2GPTHuggingfaceCheckpointFormat, -) -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat -from fast_llm.tools.train import CliTrainingConfig -from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs - # FIXME: figure out correct import of megatron modules without this hack sys.path.append(os.getcwd()) # TODO: Use `pytest_addoption` instead? # Keep all results in one place to allow recovering them for debugging in case of failure. 
-TEST_RESULTS_PATH = pathlib.Path(os.environ.get("TEST_RESULTS_PATH", "/tmp/fast_llm_tests")).resolve() -FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0 -REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0 -_LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) -TEST_MODEL = os.environ.get("MODEL", "llama") - -ARTIFACT_PATH = "runs/0/artifacts" -TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" -TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" -DATASET_CACHE = TEST_RESULTS_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common" / "dataset" -DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" - -TEST_VOCAB_SIZE = 8192 # Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% -TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" -TEST_DATASET_TOKENS = 1000000 - -CONFIG_BASE_FAST_LLM = [ - "training.logs.interval=1", - "run.tensor_logs.save=True", - "run.tensor_logs.show=False", - "model.base_model.transformer.num_layers=2", - "model.base_model.transformer.hidden_size=256", - "model.base_model.transformer.num_attention_heads=8", - "model.base_model.transformer.init_method_std=0.022", - f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", - f"model.multi_stage.debug_param_init={_LOG_LEVEL}", - f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", - f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", - f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", - "model.multi_stage.debug_tensor_parallel=True", - "model.distributed.reproducible_init=True", - "model.distributed.timeout=10", - "training.train_iters=2", - "training.num_workers=0", - "training.timeout=30", - "batch.batch_size=8", - "batch.sequence_length=512", - "data.datasets.training.type=slice", - "data.datasets.training.end=0.969", - "data.datasets.training.dataset.type=memmap", - f"data.datasets.training.dataset.path={DATASET_PREFIX}", - "data.datasets.validation.type=slice", - "data.datasets.validation.begin=0.969", - "data.datasets.validation.end=0.999", - "data.datasets.validation.dataset.type=memmap", - f"data.datasets.validation.dataset.path={DATASET_PREFIX}", - "data.datasets.test.type=slice", - "data.datasets.test.begin=0.999", - "data.datasets.test.end=1", - "data.datasets.test.dataset.type=memmap", - f"data.datasets.test.dataset.path={DATASET_PREFIX}", - "optimizer.learning_rate.base=0.0001", -] -CONFIG_BASE_MEGATRON = [ - "--num-layers=2", - "--hidden-size=256", - "--num-attention-heads=8", - "--log-interval=1", - "--train-iters=2", - "--eval-iters=0", - "--hidden-dropout=0", - "--attention-dropout=0", - f"--debug_param_init={_LOG_LEVEL}", - f"--debug_layer_outputs={_LOG_LEVEL}", - f"--debug_layer_gradients={_LOG_LEVEL}", - f"--debug_all_param_gradients={_LOG_LEVEL}", - "--debug_param_update=0", - "--global-batch-size=8", - "--max-position-embeddings=512", - "--seq-length=512", - "--init-method-std=0.022", - "--lr=0.0001", - "--num-workers=0", - "--valid-num-workers=0", - "--tokenizer-type=NullTokenizer", - # Megatron messes with the vocab size, so we have to subtract 1. - f"--vocab-size={TEST_VOCAB_SIZE-1}", - f"--data-path={DATASET_PREFIX}", - "--lr-decay-style=constant", - # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) - "--use-mcore-models", - # local implementation doesn't allow for RMS norm. 
- "--transformer-impl=transformer_engine", -] - -CONFIG_SC1_FAST_LLM = CONFIG_BASE_FAST_LLM + ["model.base_model.max_position_embeddings=512"] -CONFIG_SC1_MEGATRON = CONFIG_BASE_MEGATRON + ["--group-query-attention"] -CONFIG_SC1_COMMON = CONFIG_SC1_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_GPT2_FAST_LLM = CONFIG_SC1_FAST_LLM + ["model.base_model.transformer.head_groups=8"] -CONFIG_GPT2_MEGATRON = CONFIG_BASE_MEGATRON -CONFIG_GPT2_COMMON = CONFIG_GPT2_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_SC2_FAST_LLM = CONFIG_BASE_FAST_LLM + [ - "model.base_model.transformer.head_groups=4", - "model.base_model.transformer.rotary.type=default", -] -CONFIG_SC2_MEGATRON = CONFIG_SC1_MEGATRON + [ - "--num-query-groups=4", - "--use-rotary-position-embeddings", - "--no-position-embedding", -] -CONFIG_SC2_COMMON = CONFIG_SC2_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_LLAMA_MEGATRON = CONFIG_SC2_MEGATRON + [ - "--swiglu", - "--disable-bias-linear", - "--normalization=RMSNorm", - "--ffn-hidden-size=1024", - "--untie-embeddings-and-output-weights", -] -CONFIG_LLAMA_FAST_LLM = CONFIG_SC2_FAST_LLM + [ - "model.base_model.transformer.gated=True", - "model.base_model.transformer.activation_type=silu", - "model.base_model.transformer.add_linear_biases=False", - "model.base_model.transformer.normalization.type=rms_norm", - "model.base_model.transformer.ffn_hidden_size=1024", - "model.base_model.tie_word_embeddings=False", -] -CONFIG_LLAMA_COMMON = CONFIG_LLAMA_FAST_LLM + ["model.distributed.training_dtype=bf16"] # Megatron does not support Llama3-style Rotary Embeddings -CONFIG_LLAMA3_MEGATRON = None -CONFIG_LLAMA3_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.rotary.type=llama3", -] -CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"] # Megatron does not support per sub layer biases -CONFIG_QWEN2_MEGATRON = None -CONFIG_QWEN2_FAST_LLM = CONFIG_SC2_FAST_LLM + [ - "model.base_model.transformer.gated=True", - "model.base_model.transformer.activation_type=silu", - "model.base_model.transformer.add_linear_biases=only_attn_qkv", - "model.base_model.transformer.normalization.type=rms_norm", - "model.base_model.transformer.ffn_hidden_size=1024", - "model.base_model.tie_word_embeddings=False", -] -CONFIG_QWEN2_COMMON = CONFIG_QWEN2_FAST_LLM + ["model.distributed.training_dtype=bf16"] # Yarn-style Rotary Embeddings -CONFIG_LLAMA_YARN_MEGATRON = None -CONFIG_LLAMA_YARN_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.rotary.type=yarn", -] -CONFIG_LLAMA_YARN_COMMON = CONFIG_LLAMA_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"] - - -CONFIG_MIXTRAL_MEGATRON = CONFIG_LLAMA_MEGATRON + [ - "--num-experts=4", - "--moe-router-topk=4", -] -CONFIG_MIXTRAL_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.num_experts=4", - "model.base_model.transformer.num_experts_per_token=4", -] -CONFIG_MIXTRAL_COMMON = CONFIG_MIXTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_MIXTRAL_YARN_MEGATRON = None -CONFIG_MIXTRAL_YARN_FAST_LLM = CONFIG_MIXTRAL_FAST_LLM + [ - "model.base_model.transformer.rotary.type=yarn", -] -CONFIG_MIXTRAL_YARN_COMMON = CONFIG_MIXTRAL_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"] - -CONFIG_LLAMA_MTP_MEGATRON = None -CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.prediction_heads=4", -] -CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] - 
-CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout==['t','m']"] -CONFIG_LLAMBA_MEGATRON = CONFIG_LLAMA_MEGATRON + [] -CONFIG_LLAMBA_COMMON = CONFIG_LLAMBA_FAST_LLM - -_CONFIGS = { - "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None), - "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), - "starcoder2": ( - "gpt", - CONFIG_SC2_FAST_LLM, - CONFIG_SC2_MEGATRON, - CONFIG_SC2_COMMON, - Starcoder2GPTHuggingfaceCheckpointFormat, - ), - "llama": ( - "gpt", - CONFIG_LLAMA_FAST_LLM, - CONFIG_LLAMA_MEGATRON, - CONFIG_LLAMA_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "llama3": ( - "gpt", - CONFIG_LLAMA3_FAST_LLM, - CONFIG_LLAMA3_MEGATRON, - CONFIG_LLAMA3_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "qwen2": ( - "gpt", - CONFIG_QWEN2_FAST_LLM, - CONFIG_QWEN2_MEGATRON, - CONFIG_QWEN2_COMMON, - Qwen2GPTHuggingfaceCheckpointFormat, - ), - "llama-yarn": ( - "gpt", - CONFIG_LLAMA_YARN_FAST_LLM, - CONFIG_LLAMA_YARN_MEGATRON, - CONFIG_LLAMA_YARN_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "mistral": ( - "gpt", - CONFIG_LLAMA_FAST_LLM, - CONFIG_LLAMA_MEGATRON, - CONFIG_LLAMA_COMMON, - MistralGPTHuggingfaceCheckpointFormat, - ), - "mixtral": ( - "gpt", - CONFIG_MIXTRAL_FAST_LLM, - CONFIG_MIXTRAL_MEGATRON, - CONFIG_MIXTRAL_COMMON, - MixtralGPTHuggingfaceCheckpointFormat, - ), - "llamba": ( - "hybrid_ssm", - CONFIG_LLAMBA_FAST_LLM, - CONFIG_LLAMBA_MEGATRON, - CONFIG_LLAMBA_COMMON, - LLambaHuggingfaceCheckpointFormat, - ), - "mixtral-yarn": ( - "gpt", - CONFIG_MIXTRAL_YARN_FAST_LLM, - CONFIG_MIXTRAL_YARN_MEGATRON, - CONFIG_MIXTRAL_YARN_COMMON, - MixtralGPTHuggingfaceCheckpointFormat, - ), - "llama-mtp": ( - "gpt", - CONFIG_LLAMA_MTP_FAST_LLM, - CONFIG_LLAMA_MTP_MEGATRON, - CONFIG_LLAMA_MTP_COMMON, - MTPLlamaGPTHuggingfaceCheckpointFormat, - ), -} - -TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL] - - -requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") - - -def get_test_dataset( - prefix: pathlib.Path = DATASET_PREFIX, - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - max_spans: int = 0, -): - if not TOKENIZER_FILE.is_file(): - import transformers - - transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) - - if not ( - prefix.with_suffix(".idx").is_file() - and prefix.with_suffix(".bin").is_file() - and prefix.parent.joinpath("fast_llm_config.yaml").is_file() - ): - import transformers - - texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() - tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) - - samples = [ - GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts - ] - if max_spans > 0: - lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) - spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) - for sample, span in zip(samples, spans): - span = np.unique(span) - sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) - - GPTMemmapDataset.write_dataset(prefix, samples) - yaml.safe_dump( - {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") - ) - - -def get_test_concatenated_memmap_dataset( - path: 
pathlib.Path, - num_files: int, - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - seed_shift: int = 55, -): - index_file = path / "index.txt" - if not index_file.is_file(): - for i in range(num_files): - get_test_dataset( - prefix=path / f"dataset_{i}", - seed=seed + i * seed_shift, - num_tokens=num_tokens, - characters=characters, - vocab_size=vocab_size, - ) - index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)]) - - -@pytest.fixture(scope="session") -def run_test_script(worker_resources): - def do_run_test_script( - name: str, - script: list[str], - num_gpus: int = 1, - *, - model_type: str = TEST_MODEL_TYPE, - is_megatron: bool = False, - compare: str | None = None, - config: CompareConfig | None = None, - prepare_fn=None, - compare_fn=None, - do_compare: bool = True, - ): - if torch.cuda.device_count() < num_gpus: - pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") - env = os.environ.copy() - if is_megatron: - # Prevent Megatron from complaining. - env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - env["NVTE_FLASH_ATTN"] = "0" - path = TEST_RESULTS_PATH / name - skip = False - artifact_path = path / ARTIFACT_PATH - if path.exists(): - assert path.is_dir() - # TODO: Better way to check if the previous attempt succeeded. - if ( - REUSE_RESULTS - and artifact_path.is_dir() - and len(list((artifact_path / "0").iterdir())) >= (1 if is_megatron else 3) - ): - skip = True - elif FORCE_REUSE_RESULTS: - raise RuntimeError(artifact_path) - else: - shutil.rmtree(path) - elif FORCE_REUSE_RESULTS: - raise RuntimeError(path) - if prepare_fn is not None: - skip = prepare_fn(TEST_RESULTS_PATH / name, None if compare is None else TEST_RESULTS_PATH / compare, skip) - if is_megatron: - script = [*script, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] - else: - script = [model_type, *script, f"run.experiment_dir={path}"] - header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] - command = [ - "python", - "-m", - "torch.distributed.run", - f"--nproc-per-node={num_gpus}", - f"--rdzv-endpoint=localhost:{worker_resources.rendezvous_port}", - f"--master-port={worker_resources.torchrun_port}", - *header, - *script, - ] - print(" ".join(command)) - if skip: - print("Reusing existing run.") - else: - get_test_dataset() - if num_gpus == 1 and not is_megatron: - CliTrainingConfig.parse_and_run(script) - else: - completed_proc = subprocess.run(command, env=env, timeout=60) - if completed_proc.returncode: - raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") - if compare and do_compare: - if compare_fn is not None: - compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare) - compare_tensor_logs( - TEST_RESULTS_PATH / compare / ARTIFACT_PATH, - TEST_RESULTS_PATH / name / ARTIFACT_PATH, - config, - ) - - return do_run_test_script - - -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if 
"." in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model - - -def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): - config = HybridSSMBaseModelConfig( - transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), - ssm=SSMConfig(), - hybrid_block_layout=hybrid_block_layout, - prediction_heads=prediction_heads, - default_mtp_type=default_mtp_type, - init_method_std_embed=0.02, - init_method_min_embed=-0.02, - init_method_max_embed=0.02, - use_position_embeddings=True, - tie_word_embeddings=False, - ) - return config diff --git a/tests/conftest.py b/tests/conftest.py index edc52e03..284f4140 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,26 +1,39 @@ import dataclasses +import gc +import json +import logging import math import os -import networkx import pytest -import pytest_depends -import pytest_depends.main import torch -from xdist.scheduler import LoadGroupScheduling +import xdist.scheduler + +import fast_llm.logging +from tests.utils.depends import DependencyManager # Make fixtures available globally without import -from tests.common import run_test_script # isort: skip + +manager: DependencyManager | None = None def pytest_addoption(parser): - parser.addoption("--skip-slow", action="store_true") - parser.addoption( + group = parser.getgroup("fast_llm") + group.addoption("--skip-slow", action="store_true") + group.addoption("--show-skipped", action="store_true") + group.addoption("--show-gpu-memory", type=int, default=10) + group.addoption( "--run-extra-slow", action="store_true", default=False, help="Run tests marked as extra_slow", ) + group.addoption( + "--show-dependencies", + action="store_true", + default=False, + help="List all dependencies of all tests as a list of nodeids + the names that could not be resolved.", + ) @dataclasses.dataclass @@ -42,6 +55,8 @@ def pytest_configure(config): config.addinivalue_line( "markers", "extra_slow: Mark test as extra slow and skip unless --run-extra-slow is given." ) + config.addinivalue_line("markers", "depends_on(name='name', on=['other_name']): marks dependencies between tests.") + config.addinivalue_line("markers", "model_testing_group(group='group'): marks model testing group.") # TODO: Spawned processes (multi-gpu, Megatron) ignore resource allocation. 
is_parallel = hasattr(config, "workerinput") if is_parallel: @@ -90,7 +105,12 @@ def pytest_configure(config): @pytest.hookimpl(trylast=True) -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(config, items: list[pytest.Function]): + global manager + skip_slow = config.getoption("--skip-slow") + skip_extra_slow = not config.getoption("--run-extra-slow") + show_skipped = config.getoption("--show-skipped") + if config.getoption("--skip-slow"): skip_slow = pytest.mark.skip(reason="Skipping slow tests") for item in items: @@ -102,26 +122,131 @@ def pytest_collection_modifyitems(config, items): if "extra_slow" in item.keywords: item.add_marker(skip_extra_slow) - manager: pytest_depends.DependencyManager = pytest_depends.managers[-1] - # Build the undirected graph as in `DependencyManager.sorted_items`. - dag = networkx.DiGraph() - for item in manager.items: - node_id = pytest_depends.clean_nodeid(item.nodeid) - dag.add_node(node_id) - for dependency in manager.dependencies[node_id].dependencies: - dag.add_edge(dependency, node_id) - # Mark dependency groups for xdist. - manager.groups = {} - for i, node_ids in enumerate(sorted(networkx.weakly_connected_components(dag), key=len, reverse=True)): - if len(node_ids) > 1: - for node_id in node_ids: - manager.nodeid_to_item[node_id]._nodeid = ( - f"{manager.nodeid_to_item[node_id]._nodeid}@dependency_group_{i}" - ) - - old_clean_nodeid = pytest_depends.main.clean_nodeid - # Hack into `clean_nodeid` so pytest_depends recognizes the renamed nodes. - pytest_depends.main.clean_nodeid = lambda nodeid: old_clean_nodeid(nodeid.split("@dependency_group_")[0]) + new_items = [] + for item in items: + if skip_slow and "slow" in item.keywords: + if show_skipped: + item.add_marker(pytest.mark.skip(reason="Skipping slow tests")) + else: + continue + elif skip_extra_slow and "extra_slow" in item.keywords: + if show_skipped: + item.add_marker(pytest.mark.skip(reason="Skipping extra-slow tests")) + else: + continue + new_items.append(item) + + manager = DependencyManager(new_items) + + # Show the extra information if requested + if config.getoption("show_dependencies"): + manager.print_name_map(config.getoption("verbose") > 1) + manager.print_processed_dependencies(config.getoption("color")) + + # Reorder the items so that tests run after their dependencies + items[:] = manager.items + + # If pytest-depends is installed, it will complain about renamed nodes whether it's used or not. + try: + import pytest_depends + except ImportError: + pass + else: + old_clean_nodeid = pytest_depends.main.clean_nodeid + # Hack into `clean_nodeid` so pytest_depends recognizes the renamed nodes. + pytest_depends.main.clean_nodeid = lambda nodeid: old_clean_nodeid(nodeid.split("@dependency_group_")[0]) + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item: pytest.Function, call: pytest.CallInfo): + outcome = yield + result = outcome.get_result() + manager.register_result(item, result) + + # Measure GPU memory usage. (TODO: This excludes child processes) + if call.when == "call" and torch.cuda.is_available(): + # Free memory for more accurate reporting, and to reduce OOM risk with lots of workers. + # Cublas workspace can unnecessarily keep 100s of MBs of reserved memory. + torch._C._cuda_clearCublasWorkspaces() + # Lots of tensors tend to stay allocated until the next garbage collection. + # Collect only if the remaining memory is significant enough since it's costly. 
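+        # (1e7 bytes is roughly 10 MB of still-allocated tensors.)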
+        if torch.cuda.memory_allocated() > 1e7:
+            gc.collect()
+        try:
+            # Actually free the memory.
+            torch.cuda.empty_cache()
+        except RuntimeError:
+            # Happens if the test broke cuda.
+            return
+        item.add_report_section(
+            call.when,
+            "resource usage",
+            json.dumps(
+                {
+                    "duration": call.duration,
+                    # Relevant value for OOM risk. Also look at global max since fast-llm resets stats.
+                    "max_memory_reserved": max(
+                        torch.cuda.max_memory_reserved(), fast_llm.logging._global_max_reserved
+                    ),
+                    # Actual memory usage from the test.
+                    "max_memory_allocated": max(
+                        torch.cuda.max_memory_allocated(), fast_llm.logging._global_max_allocated
+                    ),
+                    "memory_reserved": torch.cuda.memory_reserved(),
+                    "memory_allocated": torch.cuda.memory_allocated(),
+                }
+            ),
+        )
+        torch.cuda.reset_peak_memory_stats()
+        # Reset global stats for next test.
+        fast_llm.logging._global_max_reserved = 0
+        fast_llm.logging._global_max_allocated = 0
+
+
+@pytest.hookimpl
+def pytest_terminal_summary(terminalreporter):
+    resource_reports = {}
+    for reports in terminalreporter.stats.values():
+        for report in reports:
+            if isinstance(report, pytest.TestReport):
+                for _, section in report.get_sections("Captured resource usage"):
+                    if report.nodeid in resource_reports:
+                        logging.error(f"Duplicate resource report for {report.nodeid}")
+                    resource_reports[report.nodeid] = json.loads(section)
+
+    if not resource_reports:
+        return
+
+    terminalreporter.write_sep("=", "Highest gpu memory usage", bold=True)
+    sorted_nodeids = sorted(
+        resource_reports.keys(),
+        key=lambda nodeid: resource_reports[nodeid]["max_memory_reserved"],
+        reverse=True,
+    )
+    for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]:
+        terminalreporter.write_line(
+            f"{nodeid}:\n "
+            f"Max Reserved {resource_reports[nodeid]["max_memory_reserved"] / 1e6:.0f} MB | "
+            f"Max Allocated {resource_reports[nodeid]["max_memory_allocated"] / 1e6:.0f} MB | "
+            f"End Reserved {resource_reports[nodeid]["memory_reserved"] / 1e6:.0f} MB | "
+            f"End Allocated {resource_reports[nodeid]["memory_allocated"] / 1e6:.0f} MB | "
+            f"Duration {resource_reports[nodeid]["duration"]:.2f}"
+        )
+
+
+def pytest_runtest_call(item: pytest.Function):
+    if torch.cuda.is_available():
+        # Empty cache to check if cuda is still working (TODO: Is there a better way? Can we kill the worker?)
+        try:
+            torch.cuda.empty_cache()
+        except RuntimeError:
+            pytest.skip("Cuda runtime unavailable due to an error in an earlier test.")
+    manager.handle_missing(item)
+
+
+def pytest_unconfigure():
+    global manager
+    manager = None
 
 
 @pytest.fixture(scope="session")
@@ -133,4 +258,4 @@ def worker_resources(request) -> WorkerResources:
 def pytest_xdist_make_scheduler(config, log):
     # Always use grouped load balancing to handle dependencies, and make it work with `-n`.
assert config.getvalue("dist") == "load" - return LoadGroupScheduling(config, log) + return xdist.scheduler.LoadGroupScheduling(config, log) diff --git a/tests/data/common.py b/tests/data/common.py index cacb28e6..2d3cb905 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -23,7 +23,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div -from tests.common import TEST_VOCAB_SIZE +from tests.utils.dataset import TEST_VOCAB_SIZE def get_sampling_data( diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index de97eaa2..438782df 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -5,13 +5,13 @@ from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig from fast_llm.utils import Assert, normalize_probabilities -from tests.common import DATASET_CACHE, DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_sampled_dataset, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, ) +from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, get_test_dataset _DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 1142d536..e951cc2b 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTConcatenatedDatasetConfig -from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_indexed_dataset, compare_sampled_dataset, @@ -8,6 +7,7 @@ get_test_data_and_compare_samples, ) from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS +from tests.utils.dataset import DATASET_PREFIX, get_test_dataset GPT_CONCATENATED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_concatenated_memmap.py b/tests/data/test_concatenated_memmap.py index 09929040..0ab7c7fe 100644 --- a/tests/data/test_concatenated_memmap.py +++ b/tests/data/test_concatenated_memmap.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTConcatenatedMemmapConfig -from tests.common import DATASET_CACHE, get_test_concatenated_memmap_dataset from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -8,6 +7,7 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES +from tests.utils.dataset import DATASET_CACHE, get_test_concatenated_memmap_dataset _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap" diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index 280b3413..3f7d1a13 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -1,7 +1,7 @@ from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS +from tests.utils.dataset import DATASET_PREFIX, get_test_dataset def test_dataset_from_file(): diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7b614d2f..7472f195 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -1,13 +1,13 @@ from fast_llm.data.config import TokenizerConfig from 
fast_llm.data.dataset.gpt.config import GPTFimSampledDatasetConfig from fast_llm.data.tokenizer import Tokenizer -from tests.common import DATASET_PREFIX, TOKENIZER_PATH, get_test_dataset from tests.data.common import ( compare_sampled_dataset, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, ) +from tests.utils.dataset import DATASET_PREFIX, TOKENIZER_PATH, get_test_dataset GPT_FIM_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index be801220..fcd7756d 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -3,8 +3,8 @@ import pytest from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from tests.common import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset from tests.data.common import compare_indexed_dataset, get_dataset_config +from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset MEMMAP_DATASET_LENGTH = 6153 MEMMAP_DATASET_TOKENS = 508327 diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 38679582..32d76fa4 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -7,13 +7,13 @@ from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.utils import Assert -from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) +from tests.utils.dataset import DATASET_PREFIX, get_test_dataset try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 299e2054..f8eedc5b 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTDatasetSliceConfig -from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -8,6 +7,7 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES +from tests.utils.dataset import DATASET_PREFIX, get_test_dataset GPT_SLICE_TRAINING_SAMPLES = [ [80, 268, 79, 260, 207, 3086], diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 7578a5f0..95da48e7 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -19,7 +19,7 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.model import GPTBaseModel from fast_llm.utils import Assert -from tests.common import requires_cuda +from tests.utils.utils import requires_cuda def _lm_head( diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 216f7828..05a62100 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -17,17 +17,11 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConvertConfig -from tests.common import ( - CONFIG_COMMON, - FORCE_REUSE_RESULTS, - HUGGINGFACE_CHECKPOINT_FORMAT, - REUSE_RESULTS, - TEST_MODEL, - TEST_MODEL_TYPE, - TEST_RESULTS_PATH, - requires_cuda, -) -from tests.compare_tensor_logs import CompareConfig, compare_logged_tensor +from tests.common import CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT, TEST_MODEL_TYPE +from tests.utils.compare_tensor_logs import 
CompareConfig, compare_logged_tensor +from tests.utils.model_configs import TEST_MODEL +from tests.utils.run_test_script import FORCE_REUSE_RESULTS, REUSE_RESULTS +from tests.utils.utils import TEST_RESULTS_PATH, requires_cuda TEST_MODEL_CONFIG_CLS = model_registry[TEST_MODEL_TYPE] TEST_MODEL_HF_CLS = TEST_MODEL_CONFIG_CLS.get_huggingface_model_for_causal_lm_class() diff --git a/tests/test_config.py b/tests/test_config.py index 80bed418..e050cb23 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,7 +14,7 @@ from fast_llm.models.auto import trainer_registry from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert, check_equal_nested -from tests.common import TEST_RESULTS_PATH +from tests.utils.utils import TEST_RESULTS_PATH def run_without_import(cmd: str): diff --git a/tests/test_functional.py b/tests/test_functional.py index 908a5537..03a0ae8a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,7 +8,7 @@ from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert -from tests.common import requires_cuda +from tests.utils.utils import requires_cuda def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: diff --git a/tests/test_gpt_generate_and_forward.py b/tests/test_gpt_generate_and_forward.py index a16d4c71..6e8d4360 100644 --- a/tests/test_gpt_generate_and_forward.py +++ b/tests/test_gpt_generate_and_forward.py @@ -9,7 +9,9 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM -from tests.common import CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT, TEST_MODEL, TEST_RESULTS_PATH, requires_cuda +from tests.common import CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT +from tests.utils.model_configs import TEST_MODEL +from tests.utils.utils import TEST_RESULTS_PATH, requires_cuda def _prepare_checkpoint(model: str) -> str: diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index 1857f0f8..3d821086 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -1,6 +1,8 @@ import pytest -from tests.common import ( +from tests.utils.compare_tensor_logs import CompareConfig +from tests.utils.dataset import DATASET_PREFIX +from tests.utils.model_configs import ( CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_LLAMA_FAST_LLM, @@ -11,9 +13,7 @@ CONFIG_SC1_MEGATRON, CONFIG_SC2_FAST_LLM, CONFIG_SC2_MEGATRON, - DATASET_PREFIX, ) -from tests.compare_tensor_logs import CompareConfig @pytest.mark.slow diff --git a/tests/test_mb.py b/tests/test_mb.py index 82ac4c25..fd613056 100644 --- a/tests/test_mb.py +++ b/tests/test_mb.py @@ -1,7 +1,8 @@ import pytest -from tests.common import CONFIG_COMMON, TEST_MODEL -from tests.compare_tensor_logs import CompareConfig +from tests.common import CONFIG_COMMON +from tests.utils.compare_tensor_logs import CompareConfig +from tests.utils.model_configs import TEST_MODEL CONFIG_DF = CONFIG_COMMON + ["batch.depth_first_micro_batches=4"] CONFIG_BF = CONFIG_COMMON + ["batch.breadth_first_micro_batches=4"] diff --git a/tests/test_mb_seq_first.py b/tests/test_mb_seq_first.py index 345a7bc4..dd00fd5f 100644 --- a/tests/test_mb_seq_first.py +++ 
b/tests/test_mb_seq_first.py @@ -1,7 +1,8 @@ import pytest -from tests.common import CONFIG_COMMON, TEST_MODEL -from tests.compare_tensor_logs import CompareConfig +from tests.common import CONFIG_COMMON +from tests.utils.compare_tensor_logs import CompareConfig +from tests.utils.model_configs import TEST_MODEL CONFIG_DF_SF = CONFIG_COMMON + ["batch.depth_first_micro_batches=4", "model.base_model.sequence_first=True"] CONFIG_BF_SF = CONFIG_COMMON + ["batch.breadth_first_micro_batches=4", "model.base_model.sequence_first=True"] diff --git a/tests/test_ms.py b/tests/test_ms.py index 90d16672..55032620 100644 --- a/tests/test_ms.py +++ b/tests/test_ms.py @@ -1,6 +1,7 @@ import pytest -from tests.common import CONFIG_COMMON, TEST_MODEL +from tests.common import CONFIG_COMMON +from tests.utils.model_configs import TEST_MODEL CONFIG_MS = CONFIG_COMMON + ["batch.micro_sequence_length=256"] diff --git a/tests/test_mtp.py b/tests/test_mtp.py index 71c55e0f..1f01954e 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -15,7 +15,7 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.model import GPTBaseModel from fast_llm.utils import Assert -from tests.common import get_hybrid_config, materialize_meta_tensors, requires_cuda +from tests.utils.utils import get_hybrid_config, materialize_meta_tensors, requires_cuda try: from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index bb468ceb..f5f09b1b 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -4,7 +4,8 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.tools.train import CliTrainingConfig from fast_llm.utils import Assert -from tests.common import CONFIG_COMMON, requires_cuda +from tests.common import CONFIG_COMMON +from tests.utils.utils import requires_cuda def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: diff --git a/tests/test_seq_first.py b/tests/test_seq_first.py index a8f4c036..9ead58e8 100644 --- a/tests/test_seq_first.py +++ b/tests/test_seq_first.py @@ -1,6 +1,7 @@ import pytest -from tests.common import CONFIG_COMMON, TEST_MODEL +from tests.common import CONFIG_COMMON +from tests.utils.model_configs import TEST_MODEL CONFIG_SF = CONFIG_COMMON + ["model.base_model.sequence_first=True"] diff --git a/tests/test_simple.py b/tests/test_simple.py index 3128626d..1523750f 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1,6 +1,7 @@ import pytest -from tests.common import CONFIG_COMMON, CONFIG_FAST_LLM, TEST_MODEL +from tests.common import CONFIG_COMMON, CONFIG_FAST_LLM +from tests.utils.model_configs import TEST_MODEL def test_model_safe(run_test_script): diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 36c7b622..9e748544 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -21,7 +21,7 @@ from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel -from tests.common import get_hybrid_config, materialize_meta_tensors +from tests.utils.utils import get_hybrid_config, materialize_meta_tensors try: from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index 108a2898..9befe64f 100644 --- 
a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -31,7 +31,7 @@ from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies from fast_llm.utils import Assert, rms_diff -from tests.common import requires_cuda +from tests.utils.utils import requires_cuda @requires_cuda diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/compare_tensor_logs.py b/tests/utils/compare_tensor_logs.py similarity index 100% rename from tests/compare_tensor_logs.py rename to tests/utils/compare_tensor_logs.py diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py new file mode 100644 index 00000000..72888dfd --- /dev/null +++ b/tests/utils/dataset.py @@ -0,0 +1,80 @@ +import pathlib +import random +import string + +import numpy as np +import yaml + +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample +from tests.utils.utils import TEST_RESULTS_PATH + +TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" +TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" +DATASET_CACHE = TEST_RESULTS_PATH / "dataset" +DATASET_PREFIX = DATASET_CACHE / "common" / "dataset" +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset" / "cache" +TEST_VOCAB_SIZE = 8192 +TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" +TEST_DATASET_TOKENS = 1000000 + + +def get_test_dataset( + prefix: pathlib.Path = DATASET_PREFIX, + seed: int = 1234, + num_tokens: int = TEST_DATASET_TOKENS, + characters: str = TEST_CHARACTERS, + vocab_size: int = TEST_VOCAB_SIZE, + max_spans: int = 0, +): + if not TOKENIZER_FILE.is_file(): + import transformers + + transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) + + if not ( + prefix.with_suffix(".idx").is_file() + and prefix.with_suffix(".bin").is_file() + and prefix.parent.joinpath("fast_llm_config.yaml").is_file() + ): + import transformers + + texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() + tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) + + samples = [ + GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts + ] + if max_spans > 0: + lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) + spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) + for sample, span in zip(samples, spans): + span = np.unique(span) + sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) + + GPTMemmapDataset.write_dataset(prefix, samples) + yaml.safe_dump( + {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") + ) + + +def get_test_concatenated_memmap_dataset( + path: pathlib.Path, + num_files: int, + seed: int = 1234, + num_tokens: int = TEST_DATASET_TOKENS, + characters: str = TEST_CHARACTERS, + vocab_size: int = TEST_VOCAB_SIZE, + seed_shift: int = 55, +): + index_file = path / "index.txt" + if not index_file.is_file(): + for i in range(num_files): + get_test_dataset( + prefix=path / f"dataset_{i}", + seed=seed + i * seed_shift, + num_tokens=num_tokens, + characters=characters, + vocab_size=vocab_size, + ) + index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)]) diff --git a/tests/utils/depends.py b/tests/utils/depends.py new 
file mode 100644 index 00000000..3972a066 --- /dev/null +++ b/tests/utils/depends.py @@ -0,0 +1,211 @@ +import re + +import colorama +import networkx +import pytest + +MARKER_NAME = "depends_on" +MARKER_KWARG_ID = "name" +MARKER_KWARG_DEPENDENCIES = "on" + +REGEX_PARAMETERS = re.compile(r"\[.+\]$") + + +def clean_nodeid(nodeid): + return nodeid.replace("::()::", "::").split("@dependency_group_")[0] + + +def get_names(item): + names = set() + + # Node id + nodeid = clean_nodeid(item.nodeid) + names.add(nodeid) + + # Node id without parameter + nodeid = REGEX_PARAMETERS.sub("", nodeid) + names.add(nodeid) + + # Node id scopes + while "::" in nodeid: + nodeid = nodeid.rsplit("::", 1)[0] + names.add(nodeid) + + # Custom name + for marker in item.iter_markers(): + if marker.name == MARKER_NAME and MARKER_KWARG_ID in marker.kwargs: + for name in as_list(marker.kwargs[MARKER_KWARG_ID]): + names.add(name) + + return names + + +def as_list(lst): + return [lst] if isinstance(lst, str) else lst + + +STEPS = ["setup", "call", "teardown"] +GOOD_OUTCOME = "passed" + + +class DependencyManager: + """ + A simplified and improved version of pytest-depends. Main differences are the following: + * Add compatibility with pytest-xdist: group connected components of the dependency graph together, + and rename them with the `@dependency_group_{i}` suffix so they are run in the same worker, assuming + group scheduling is used. + * Improved parameterized dependencies so tests can depend on other tests with matching parametrization. + Ex. a test `test_model` with parameter `model` can depend on `test_other[{model}]`, + then `test_model[llama]` will depend on `test_other[llama]`, and so on. + * Improved description of missing/failed dependencies. + * Some option hard-coded for Fast-LLM. 
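+
+    Illustrative usage of the parametrized form (hypothetical test names, assuming `test_derived`
+    is parametrized over `model`; the marker name and the `on` keyword come from MARKER_NAME and
+    MARKER_KWARG_DEPENDENCIES above):
+
+        @pytest.mark.depends_on(on=["test_base[{model}]"])
+        def test_derived(model):
+            ...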
+ """ + + def __init__(self, items: list[pytest.Function]): + self._items = items + self._name_to_nodeids: dict[str, list[str]] = {} + self._nodeid_to_item: dict[str, pytest.Function] = {} + self._results: dict[str, dict[str, str]] = {} + self._dependencies: dict[str, set[str]] = {} + self._unresolved: dict[str, set[str]] = {} + + for item in self._items: + nodeid = clean_nodeid(item.nodeid) + # Add the mapping from nodeid to the test item + self._nodeid_to_item[nodeid] = item + # Add the mappings from all names to the node id + for name in get_names(item): + if name not in self._name_to_nodeids: + self._name_to_nodeids[name] = [] + self._name_to_nodeids[name].append(nodeid) + # Create the object that will contain the results of this test + self._results[nodeid] = {} + + for item in self._items: + # Process the dependencies of this test + # This uses the mappings created in the previous loop, and can thus not be merged into that loop + nodeid = clean_nodeid(item.nodeid) + self._dependencies[nodeid], self._unresolved[nodeid] = self._resolve_dependencies(item) + + self._items = self._sort_dependencies() + + @property + def items(self) -> list[pytest.Function]: + return self._items + + def register_result(self, item: pytest.Function, result: pytest.TestReport): + self._results[clean_nodeid(item.nodeid)][result.when] = result.outcome + + def handle_missing(self, item: pytest.Function): + nodeid = clean_nodeid(item.nodeid) + if missing := self._unresolved[nodeid]: + pytest.fail(f'{item.nodeid} depends on {", ".join(missing)}, which was not found', False) + + if failed := [ + f"{dependency} ({", ".join(f"{key}: {value}" for key, value in self._results[dependency].items()) if self._results[dependency] else "missing"})" + for dependency in self._dependencies[nodeid] + if not all(self._results[dependency].get(step, None) == "passed" for step in ("setup", "call", "teardown")) + ]: + pytest.skip(f'{item.nodeid} depends on {", ".join(failed)}') + + def _resolve_dependencies(self, item: pytest.Function): + dependencies = set() + unresolved = set() + + if "skip" in item.keywords: + return dependencies, unresolved + + nodeid = clean_nodeid(item.nodeid) + + for marker in item.iter_markers(): + if marker.name == MARKER_NAME: + for dependency in as_list(marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])): + dependency = dependency.format(**item.callspec.params) + + # If the name is not known, try to make it absolute (ie file::[class::]method) + if dependency not in self._name_to_nodeids: + absolute_dependency = self._get_absolute_nodeid(dependency, nodeid) + if absolute_dependency in self._name_to_nodeids: + dependency = absolute_dependency + + # Add all items matching the name + if dependency in self._name_to_nodeids: + for nodeid in self._name_to_nodeids[dependency]: + dependencies.add(nodeid) + else: + unresolved.add(dependency) + + return dependencies, unresolved + + def _sort_dependencies(self): + # Build a directed graph for sorting + dag = networkx.DiGraph() + + for item in self.items: + nodeid = clean_nodeid(item.nodeid) + dag.add_node(nodeid) + for dependency in self._dependencies[nodeid]: + dag.add_edge(dependency, nodeid) + + for i, nodeids in enumerate(sorted(networkx.weakly_connected_components(dag), key=len, reverse=True)): + if len(nodeids) > 1: + for nodeid in nodeids: + self._nodeid_to_item[nodeid]._nodeid = ( + f"{self._nodeid_to_item[nodeid]._nodeid}@dependency_group_{i}" + ) + + return [self._nodeid_to_item[nodeid] for nodeid in networkx.topological_sort(dag)] + + @staticmethod + def 
_get_absolute_nodeid(nodeid: str, scope: str): + parts = nodeid.split("::") + # Completely relative (test_name), so add the full current scope (either file::class or file) + if len(parts) == 1: + base_nodeid = scope.rsplit("::", 1)[0] + nodeid = f"{base_nodeid}::{nodeid}" + # Contains some scope already (Class::test_name), so only add the current file scope + elif "." not in parts[0]: + base_nodeid = scope.split("::", 1)[0] + nodeid = f"{base_nodeid}::{nodeid}" + return clean_nodeid(nodeid) + + def print_name_map(self, verbose: bool = False): + """Print a human-readable version of the name -> test mapping.""" + print("Available dependency names:") + for name, nodeids in sorted(self._name_to_nodeids.items(), key=lambda x: x[0]): + if len(nodeids) == 1: + if name == nodeids[0]: + # This is just the base name, only print this when verbose + if verbose: + print(f" {name}") + else: + # Name refers to a single node id, so use the short format + print(f" {name} -> {nodeids[0]}") + else: + # Name refers to multiple node ids, so use the long format + print(f" {name} ->") + for nodeid in sorted(nodeids): + print(f" {nodeid}") + + def print_processed_dependencies(self, colors: bool = False): + """Print a human-readable list of the processed dependencies.""" + missing = "MISSING" + if colors: + missing = f"{colorama.Fore.RED}{missing}{colorama.Fore.RESET}" + colorama.init() + try: + print("Dependencies:") + + for nodeid in sorted(self._dependencies): + descriptions = [] + for dependency in self._dependencies[nodeid]: + descriptions.append(dependency) + for dependency in self._unresolved[nodeid]: + descriptions.append(f"{dependency} ({missing})") + if descriptions: + print(f" {nodeid} depends on") + for description in sorted(descriptions): + print(f" {description}") + finally: + if colors: + colorama.deinit() diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py new file mode 100644 index 00000000..26eebf4f --- /dev/null +++ b/tests/utils/model_configs.py @@ -0,0 +1,233 @@ +import os + +from fast_llm.models.gpt.config import ( + LlamaGPTHuggingfaceCheckpointFormat, + MistralGPTHuggingfaceCheckpointFormat, + MixtralGPTHuggingfaceCheckpointFormat, + MTPLlamaGPTHuggingfaceCheckpointFormat, + Qwen2GPTHuggingfaceCheckpointFormat, + Starcoder2GPTHuggingfaceCheckpointFormat, +) +from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from tests.utils.dataset import DATASET_PREFIX, TEST_VOCAB_SIZE + +_LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) +TEST_MODEL = os.environ.get("MODEL", "llama") +CONFIG_BASE_FAST_LLM = [ + "training.logs.interval=1", + "run.tensor_logs.save=True", + "run.tensor_logs.show=False", + "model.base_model.transformer.num_layers=2", + "model.base_model.transformer.hidden_size=256", + "model.base_model.transformer.num_attention_heads=8", + "model.base_model.transformer.init_method_std=0.022", + f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", + f"model.multi_stage.debug_param_init={_LOG_LEVEL}", + f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", + f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", + f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", + "model.multi_stage.debug_tensor_parallel=True", + "model.distributed.reproducible_init=True", + "model.distributed.timeout=10", + "training.train_iters=2", + "training.num_workers=0", + "training.timeout=30", + "batch.batch_size=8", + "batch.sequence_length=512", + "data.datasets.training.type=slice", + "data.datasets.training.end=0.969", + 
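+    # The remaining 0.969-0.999 and 0.999-1.0 slices of the same memmap dataset are used below for validation and test.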
"data.datasets.training.dataset.type=memmap", + f"data.datasets.training.dataset.path={DATASET_PREFIX}", + "data.datasets.validation.type=slice", + "data.datasets.validation.begin=0.969", + "data.datasets.validation.end=0.999", + "data.datasets.validation.dataset.type=memmap", + f"data.datasets.validation.dataset.path={DATASET_PREFIX}", + "data.datasets.test.type=slice", + "data.datasets.test.begin=0.999", + "data.datasets.test.end=1", + "data.datasets.test.dataset.type=memmap", + f"data.datasets.test.dataset.path={DATASET_PREFIX}", + "optimizer.learning_rate.base=0.0001", +] +CONFIG_BASE_MEGATRON = [ + "--num-layers=2", + "--hidden-size=256", + "--num-attention-heads=8", + "--log-interval=1", + "--train-iters=2", + "--eval-iters=0", + "--hidden-dropout=0", + "--attention-dropout=0", + f"--debug_param_init={_LOG_LEVEL}", + f"--debug_layer_outputs={_LOG_LEVEL}", + f"--debug_layer_gradients={_LOG_LEVEL}", + f"--debug_all_param_gradients={_LOG_LEVEL}", + "--debug_param_update=0", + "--global-batch-size=8", + "--max-position-embeddings=512", + "--seq-length=512", + "--init-method-std=0.022", + "--lr=0.0001", + "--num-workers=0", + "--valid-num-workers=0", + "--tokenizer-type=NullTokenizer", + # Megatron messes with the vocab size, so we have to subtract 1. + f"--vocab-size={TEST_VOCAB_SIZE - 1}", + f"--data-path={DATASET_PREFIX}", + "--lr-decay-style=constant", + # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) + "--use-mcore-models", + # local implementation doesn't allow for RMS norm. + "--transformer-impl=transformer_engine", +] +CONFIG_SC1_FAST_LLM = CONFIG_BASE_FAST_LLM + ["model.base_model.max_position_embeddings=512"] +CONFIG_SC1_MEGATRON = CONFIG_BASE_MEGATRON + ["--group-query-attention"] +CONFIG_SC1_COMMON = CONFIG_SC1_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_GPT2_FAST_LLM = CONFIG_SC1_FAST_LLM + ["model.base_model.transformer.head_groups=8"] +CONFIG_GPT2_MEGATRON = CONFIG_BASE_MEGATRON +CONFIG_GPT2_COMMON = CONFIG_GPT2_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_SC2_FAST_LLM = CONFIG_BASE_FAST_LLM + [ + "model.base_model.transformer.head_groups=4", + "model.base_model.transformer.rotary.type=default", +] +CONFIG_SC2_MEGATRON = CONFIG_SC1_MEGATRON + [ + "--num-query-groups=4", + "--use-rotary-position-embeddings", + "--no-position-embedding", +] +CONFIG_SC2_COMMON = CONFIG_SC2_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_LLAMA_MEGATRON = CONFIG_SC2_MEGATRON + [ + "--swiglu", + "--disable-bias-linear", + "--normalization=RMSNorm", + "--ffn-hidden-size=1024", + "--untie-embeddings-and-output-weights", +] +CONFIG_LLAMA_FAST_LLM = CONFIG_SC2_FAST_LLM + [ + "model.base_model.transformer.gated=True", + "model.base_model.transformer.activation_type=silu", + "model.base_model.transformer.add_linear_biases=False", + "model.base_model.transformer.normalization.type=rms_norm", + "model.base_model.transformer.ffn_hidden_size=1024", + "model.base_model.tie_word_embeddings=False", +] +CONFIG_LLAMA_COMMON = CONFIG_LLAMA_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_LLAMA3_MEGATRON = None +CONFIG_LLAMA3_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ + "model.base_model.transformer.rotary.type=llama3", +] +CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_QWEN2_MEGATRON = None +CONFIG_QWEN2_FAST_LLM = CONFIG_SC2_FAST_LLM + [ + "model.base_model.transformer.gated=True", + 
"model.base_model.transformer.activation_type=silu", + "model.base_model.transformer.add_linear_biases=only_attn_qkv", + "model.base_model.transformer.normalization.type=rms_norm", + "model.base_model.transformer.ffn_hidden_size=1024", + "model.base_model.tie_word_embeddings=False", +] +CONFIG_QWEN2_COMMON = CONFIG_QWEN2_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_LLAMA_YARN_MEGATRON = None +CONFIG_LLAMA_YARN_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ + "model.base_model.transformer.rotary.type=yarn", +] +CONFIG_LLAMA_YARN_COMMON = CONFIG_LLAMA_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_MIXTRAL_MEGATRON = CONFIG_LLAMA_MEGATRON + [ + "--num-experts=4", + "--moe-router-topk=4", +] +CONFIG_MIXTRAL_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ + "model.base_model.transformer.num_experts=4", + "model.base_model.transformer.num_experts_per_token=4", +] +CONFIG_MIXTRAL_COMMON = CONFIG_MIXTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_MIXTRAL_YARN_MEGATRON = None +CONFIG_MIXTRAL_YARN_FAST_LLM = CONFIG_MIXTRAL_FAST_LLM + [ + "model.base_model.transformer.rotary.type=yarn", +] +CONFIG_MIXTRAL_YARN_COMMON = CONFIG_MIXTRAL_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_LLAMA_MTP_MEGATRON = None +CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ + "model.base_model.prediction_heads=4", +] +CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout==['t','m']"] +CONFIG_LLAMBA_MEGATRON = CONFIG_LLAMA_MEGATRON + [] +CONFIG_LLAMBA_COMMON = CONFIG_LLAMBA_FAST_LLM +_CONFIGS = { + "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None), + "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), + "starcoder2": ( + "gpt", + CONFIG_SC2_FAST_LLM, + CONFIG_SC2_MEGATRON, + CONFIG_SC2_COMMON, + Starcoder2GPTHuggingfaceCheckpointFormat, + ), + "llama": ( + "gpt", + CONFIG_LLAMA_FAST_LLM, + CONFIG_LLAMA_MEGATRON, + CONFIG_LLAMA_COMMON, + LlamaGPTHuggingfaceCheckpointFormat, + ), + "llama3": ( + "gpt", + CONFIG_LLAMA3_FAST_LLM, + CONFIG_LLAMA3_MEGATRON, + CONFIG_LLAMA3_COMMON, + LlamaGPTHuggingfaceCheckpointFormat, + ), + "qwen2": ( + "gpt", + CONFIG_QWEN2_FAST_LLM, + CONFIG_QWEN2_MEGATRON, + CONFIG_QWEN2_COMMON, + Qwen2GPTHuggingfaceCheckpointFormat, + ), + "llama-yarn": ( + "gpt", + CONFIG_LLAMA_YARN_FAST_LLM, + CONFIG_LLAMA_YARN_MEGATRON, + CONFIG_LLAMA_YARN_COMMON, + LlamaGPTHuggingfaceCheckpointFormat, + ), + "mistral": ( + "gpt", + CONFIG_LLAMA_FAST_LLM, + CONFIG_LLAMA_MEGATRON, + CONFIG_LLAMA_COMMON, + MistralGPTHuggingfaceCheckpointFormat, + ), + "mixtral": ( + "gpt", + CONFIG_MIXTRAL_FAST_LLM, + CONFIG_MIXTRAL_MEGATRON, + CONFIG_MIXTRAL_COMMON, + MixtralGPTHuggingfaceCheckpointFormat, + ), + "llamba": ( + "hybrid_ssm", + CONFIG_LLAMBA_FAST_LLM, + CONFIG_LLAMBA_MEGATRON, + CONFIG_LLAMBA_COMMON, + LLambaHuggingfaceCheckpointFormat, + ), + "mixtral-yarn": ( + "gpt", + CONFIG_MIXTRAL_YARN_FAST_LLM, + CONFIG_MIXTRAL_YARN_MEGATRON, + CONFIG_MIXTRAL_YARN_COMMON, + MixtralGPTHuggingfaceCheckpointFormat, + ), + "llama-mtp": ( + "gpt", + CONFIG_LLAMA_MTP_FAST_LLM, + CONFIG_LLAMA_MTP_MEGATRON, + CONFIG_LLAMA_MTP_COMMON, + MTPLlamaGPTHuggingfaceCheckpointFormat, + ), +} + +TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL] diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py new file mode 
100644 index 00000000..27d82869 --- /dev/null +++ b/tests/utils/run_test_script.py @@ -0,0 +1,96 @@ +import os +import shutil +import subprocess + +import pytest +import torch + +from fast_llm.tools.train import CliTrainingConfig +from tests.utils.compare_tensor_logs import CompareConfig, compare_tensor_logs +from tests.utils.dataset import get_test_dataset +from tests.utils.model_configs import TEST_MODEL_TYPE +from tests.utils.utils import TEST_RESULTS_PATH + +FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0 +REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0 +ARTIFACT_PATH = "runs/0/artifacts" + + +@pytest.fixture(scope="session") +def run_test_script(worker_resources): + def do_run_test_script( + name: str, + script: list[str], + num_gpus: int = 1, + *, + model_type: str = TEST_MODEL_TYPE, + is_megatron: bool = False, + compare: str | None = None, + config: CompareConfig | None = None, + prepare_fn=None, + compare_fn=None, + do_compare: bool = True, + ): + if torch.cuda.device_count() < num_gpus: + pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") + env = os.environ.copy() + if is_megatron: + # Prevent Megatron from complaining. + env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + env["NVTE_FLASH_ATTN"] = "0" + path = TEST_RESULTS_PATH / name + skip = False + artifact_path = path / ARTIFACT_PATH + if path.exists(): + assert path.is_dir() + # TODO: Better way to check if the previous attempt succeeded. + if ( + REUSE_RESULTS + and artifact_path.is_dir() + and len(list((artifact_path / "0").iterdir())) >= (1 if is_megatron else 3) + ): + skip = True + elif FORCE_REUSE_RESULTS: + raise RuntimeError(artifact_path) + else: + shutil.rmtree(path) + elif FORCE_REUSE_RESULTS: + raise RuntimeError(path) + if prepare_fn is not None: + skip = prepare_fn(TEST_RESULTS_PATH / name, None if compare is None else TEST_RESULTS_PATH / compare, skip) + if is_megatron: + script = [*script, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] + else: + script = [model_type, *script, f"run.experiment_dir={path}"] + header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] + command = [ + "python", + "-m", + "torch.distributed.run", + f"--nproc-per-node={num_gpus}", + f"--rdzv-endpoint=localhost:{worker_resources.rendezvous_port}", + f"--master-port={worker_resources.torchrun_port}", + *header, + *script, + ] + print(" ".join(command)) + if skip: + print("Reusing existing run.") + else: + get_test_dataset() + if num_gpus == 1 and not is_megatron: + CliTrainingConfig.parse_and_run(script) + else: + completed_proc = subprocess.run(command, env=env, timeout=60) + if completed_proc.returncode: + raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") + if compare and do_compare: + if compare_fn is not None: + compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare) + compare_tensor_logs( + TEST_RESULTS_PATH / compare / ARTIFACT_PATH, + TEST_RESULTS_PATH / name / ARTIFACT_PATH, + config, + ) + + return do_run_test_script diff --git a/tests/utils/utils.py b/tests/utils/utils.py new file mode 100644 index 00000000..f37c1cb2 --- /dev/null +++ b/tests/utils/utils.py @@ -0,0 +1,52 @@ +import os +import pathlib + +import pytest +import torch + +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig + +TEST_RESULTS_PATH = 
pathlib.Path(os.environ.get("TEST_RESULTS_PATH", "/tmp/fast_llm_tests")).resolve() +requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): + config = HybridSSMBaseModelConfig( + transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), + ssm=SSMConfig(), + hybrid_block_layout=hybrid_block_layout, + prediction_heads=prediction_heads, + default_mtp_type=default_mtp_type, + init_method_std_embed=0.02, + init_method_min_embed=-0.02, + init_method_max_embed=0.02, + use_position_embeddings=True, + tie_word_embeddings=False, + ) + return config From ced34e08ce29bd2f4ac121609a5c49e47beefe9b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 17:41:26 -0400 Subject: [PATCH 20/69] stuff --- tests/conftest.py | 4 ++- tests/test_checkpoint.py | 49 +++++++++++++------------- tests/test_gpt_generate_and_forward.py | 9 +++-- tests/test_match_megatron.py | 14 ++++---- tests/test_mb.py | 15 ++++---- tests/test_mb_seq_first.py | 7 ++-- tests/test_ms.py | 7 ++-- tests/test_multi_stage.py | 2 +- tests/test_seq_first.py | 9 +++-- tests/test_simple.py | 15 ++++---- tests/utils/depends.py | 3 +- 11 files changed, 67 insertions(+), 67 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 284f4140..99490f1b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,8 @@ from tests.utils.depends import DependencyManager # Make fixtures available globally without import +from tests.utils.run_test_script import run_test_script # isort: skip + manager: DependencyManager | None = None @@ -148,7 +150,7 @@ def pytest_collection_modifyitems(config, items: list[pytest.Function]): # If pytest-depends is installed, it will complain about renamed nodes whether it's used or not. 
try: - import pytest_depends + import pytest_depends.main except ImportError: pass else: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 05a62100..55d30d3f 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -17,9 +17,8 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConvertConfig -from tests.common import CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT, TEST_MODEL_TYPE from tests.utils.compare_tensor_logs import CompareConfig, compare_logged_tensor -from tests.utils.model_configs import TEST_MODEL +from tests.utils.model_configs import CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT, TEST_MODEL, TEST_MODEL_TYPE from tests.utils.run_test_script import FORCE_REUSE_RESULTS, REUSE_RESULTS from tests.utils.utils import TEST_RESULTS_PATH, requires_cuda @@ -65,7 +64,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): shutil.copy(compare_path / path, test_path / path) -@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +@pytest.mark.depends_on(on=["test_checkpoint_and_eval"]) def test_resume(run_test_script): # Resume from iteration=1 and compare outputs with the baseline run. run_test_script( @@ -82,7 +81,7 @@ def test_resume(run_test_script): ) -@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +@pytest.mark.depends_on(on=["test_checkpoint_and_eval"]) def test_resume_frozen(run_test_script): # Resume with frozen mlp. No comparison. run_test_script( @@ -113,7 +112,7 @@ def _run_conversion(config: ConvertConfig): CONVERT_PATH = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_convert_model" -@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +@pytest.mark.depends_on(on=["test_checkpoint_and_eval"]) def test_convert_distributed_to_fast_llm(): _run_conversion( ConvertConfig( @@ -130,7 +129,7 @@ def test_convert_distributed_to_fast_llm(): ) -@pytest.mark.depends(on=["test_convert_distributed_to_fast_llm"]) +@pytest.mark.depends_on(on=["test_convert_distributed_to_fast_llm"]) def test_convert_fast_llm_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") @@ -149,7 +148,7 @@ def test_convert_fast_llm_to_huggingface(): ) -@pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface"]) +@pytest.mark.depends_on(on=["test_convert_fast_llm_to_huggingface"]) def test_convert_huggingface_to_distributed(): _run_conversion( ConvertConfig( @@ -166,7 +165,7 @@ def test_convert_huggingface_to_distributed(): ) -@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +@pytest.mark.depends_on(on=["test_checkpoint_and_eval"]) def test_convert_distributed_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") @@ -185,7 +184,7 @@ def test_convert_distributed_to_huggingface(): ) -@pytest.mark.depends(on=["test_convert_distributed_to_huggingface"]) +@pytest.mark.depends_on(on=["test_convert_distributed_to_huggingface"]) def test_convert_huggingface_to_fast_llm(): _run_conversion( ConvertConfig( @@ -202,7 +201,7 @@ def test_convert_huggingface_to_fast_llm(): ) -@pytest.mark.depends(on=["test_convert_huggingface_to_fast_llm"]) +@pytest.mark.depends_on(on=["test_convert_huggingface_to_fast_llm"]) def test_convert_fast_llm_to_distributed(): _run_conversion( ConvertConfig( @@ -219,7 +218,7 @@ def test_convert_fast_llm_to_distributed(): ) -@pytest.mark.depends(on=["test_convert_huggingface_to_distributed", 
"test_convert_fast_llm_to_distributed"]) +@pytest.mark.depends_on(on=["test_convert_huggingface_to_distributed", "test_convert_fast_llm_to_distributed"]) def test_converted_distributed(): # Compare the fast llm weights # TODO: Compare configs @@ -235,7 +234,7 @@ def test_converted_distributed(): assert (w[key] == w1[key]).all(), (w[key], w1[key]) -@pytest.mark.depends(on=["test_convert_distributed_to_fast_llm", "test_convert_huggingface_to_fast_llm"]) +@pytest.mark.depends_on(on=["test_convert_distributed_to_fast_llm", "test_convert_huggingface_to_fast_llm"]) def test_converted_fast_llm(): s0 = safetensors.torch.load_file(CONVERT_PATH / "fast_llm_0" / "model_0.safetensors") s1 = safetensors.torch.load_file(CONVERT_PATH / "fast_llm_1" / "model_0.safetensors") @@ -245,7 +244,7 @@ def test_converted_fast_llm(): assert (s0[key] == s1[key]).all(), (key, s0, s1) -@pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface", "test_convert_distributed_to_huggingface"]) +@pytest.mark.depends_on(on=["test_convert_fast_llm_to_huggingface", "test_convert_distributed_to_huggingface"]) def test_converted_huggingface(): h0 = safetensors.torch.load_file(CONVERT_PATH / "huggingface_0" / "model_0.safetensors") h1 = safetensors.torch.load_file(CONVERT_PATH / "huggingface_1" / "model_0.safetensors") @@ -263,7 +262,7 @@ def _compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMM config_ref.base_model.compare_architecture(config_test.base_model) -@pytest.mark.depends(on=["test_converted_distributed"]) +@pytest.mark.depends_on(on=["test_converted_distributed"]) def test_load_pretrained_distributed_checkpoint(): config = TEST_MODEL_CONFIG_CLS.from_dict( yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r"))["model"], strict=False @@ -283,7 +282,7 @@ def test_load_pretrained_distributed_checkpoint(): assert (state_shards[f"{shard_name}_shard"] == model.get_shard(shard_name)).all() -@pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( CheckpointLoadConfig( @@ -315,7 +314,7 @@ def test_load_converted_distributed_checkpoint(): assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) +@pytest.mark.depends_on(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_fast_llm_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( CheckpointLoadConfig( @@ -346,7 +345,7 @@ def test_load_converted_fast_llm_checkpoint(): assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) +@pytest.mark.depends_on(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_huggingface_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( CheckpointLoadConfig( @@ -378,7 +377,7 @@ def test_load_converted_huggingface_checkpoint(): assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.depends(on=["test_load_converted_fast_llm_checkpoint", "test_load_converted_huggingface_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_converted_fast_llm_checkpoint", "test_load_converted_huggingface_checkpoint"]) def test_run_converted_model(): model_ref = 
TEST_MODEL_HF_CLS.from_pretrained( CheckpointLoadConfig( @@ -427,7 +426,7 @@ def test_run_converted_model(): @pytest.mark.slow -@pytest.mark.depends(on=["test_load_converted_distributed_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint"]) def test_load_pretrained_distributed_in_dp2(run_test_script): run_test_script( f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2", @@ -443,7 +442,7 @@ def test_load_pretrained_distributed_in_dp2(run_test_script): ) -@pytest.mark.depends(on=["test_load_converted_distributed_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint"]) def test_load_pretrained_distributed_with_config(run_test_script): run_test_script( f"test_{TEST_MODEL}_load_pretrained_distributed_with_config", @@ -458,7 +457,7 @@ def test_load_pretrained_distributed_with_config(run_test_script): ) -@pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) +@pytest.mark.depends_on(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" pretrained_config_ref = CheckpointLoadConfig( @@ -503,7 +502,7 @@ def test_load_pretrained_in_dp2_match_checkpoint(): @pytest.mark.slow -@pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_pretrained_in_dp2_match_checkpoint"]) def test_load_distributed_checkpoint_dp2(): # This also tests conversion which uses `FastLLMModel.from_checkpoint` pretrained_config_ref = CheckpointLoadConfig( @@ -526,7 +525,7 @@ def test_load_distributed_checkpoint_dp2(): @pytest.mark.slow -@pytest.mark.depends(on=["test_load_converted_fast_llm_checkpoint", "test_load_pretrained_in_dp2_match_checkpoint"]) +@pytest.mark.depends_on(on=["test_load_converted_fast_llm_checkpoint", "test_load_pretrained_in_dp2_match_checkpoint"]) def test_load_pretrained_fast_llm_in_dp2(run_test_script): run_test_script( f"test_{TEST_MODEL}_load_pretrained_fast_llm_in_dp2", @@ -560,7 +559,9 @@ def test_load_pretrained_fast_llm_in_dp2(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_load_converted_huggingface_checkpoint", "test_load_pretrained_in_dp2_match_checkpoint"]) +@pytest.mark.depends_on( + on=["test_load_converted_huggingface_checkpoint", "test_load_pretrained_in_dp2_match_checkpoint"] +) def test_load_pretrained_huggingface_in_dp2(run_test_script): run_test_script( f"test_{TEST_MODEL}_load_pretrained_huggingface_in_dp2", diff --git a/tests/test_gpt_generate_and_forward.py b/tests/test_gpt_generate_and_forward.py index 6e8d4360..06cfd803 100644 --- a/tests/test_gpt_generate_and_forward.py +++ b/tests/test_gpt_generate_and_forward.py @@ -9,8 +9,7 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM -from tests.common import CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT -from tests.utils.model_configs import TEST_MODEL +from tests.utils.model_configs import CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT, TEST_MODEL from tests.utils.utils import TEST_RESULTS_PATH, requires_cuda @@ -262,7 +261,7 @@ def test_export_for_generate(run_test_script): @pytest.mark.slow @requires_cuda -@pytest.mark.depends(on=["test_export_for_generate"]) +@pytest.mark.depends_on(on=["test_export_for_generate"]) 
@pytest.mark.parametrize( "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2", [ @@ -322,7 +321,7 @@ def test_generate_from_model( @requires_cuda @pytest.mark.slow -@pytest.mark.depends(on=["test_export_for_generate"]) +@pytest.mark.depends_on(on=["test_export_for_generate"]) def test_small_generate_from_model(): _test_generate_from_model( TEST_RESULTS_PATH / f"test_{TEST_MODEL}_export_for_generate/export/{HUGGINGFACE_CHECKPOINT_FORMAT.name}/1", @@ -370,7 +369,7 @@ def test_forward_return_hidden_states(model_and_tokenizer): @pytest.mark.slow @requires_cuda -@pytest.mark.depends(on=["test_export_for_generate"]) +@pytest.mark.depends_on(on=["test_export_for_generate"]) def test_small_forward_return_hidden_states(): _test_forward_return_hidden_states( TEST_RESULTS_PATH / f"test_{TEST_MODEL}_export_for_generate/export/{HUGGINGFACE_CHECKPOINT_FORMAT.name}/1", diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index 3d821086..7d89c80a 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -29,7 +29,7 @@ def test_sc1_meg(run_test_script): ] -@pytest.mark.depends(on=["test_sc1_meg"]) +@pytest.mark.depends_on(on=["test_sc1_meg"]) def test_sc1_match_meg(run_test_script): # Starcoder 1 (GPT2 with MQA) with Fast-llm. # QKV tensors are in a different format. @@ -50,13 +50,13 @@ def test_sc1_match_meg(run_test_script): @pytest.mark.slow @pytest.mark.skip(reason="Skipping mostly redundant test") -@pytest.mark.depends(on=["test_sc1_match_meg"]) +@pytest.mark.depends_on(on=["test_sc1_match_meg"]) def test_sc2_meg(run_test_script): # Starcoder 2 (GPT2 with MQA and RoPE) with Megatron. run_test_script("test_sc2_meg", CONFIG_SC2_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) -@pytest.mark.depends(on=["test_sc2_meg"]) +@pytest.mark.depends_on(on=["test_sc2_meg"]) def test_sc2_match_meg(run_test_script): # Starcoder 2 (GPT2 with MQA and RoPE) with Fast-llm. # QKV tensors are in a different format, @@ -83,7 +83,7 @@ def test_gpt2_meg(run_test_script): run_test_script("test_gpt2_meg", CONFIG_GPT2_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) -@pytest.mark.depends(on=["test_gpt2_meg"]) +@pytest.mark.depends_on(on=["test_gpt2_meg"]) def test_gpt2_match_meg(run_test_script): # GPT2 (MHA, layer norm, absolute embeddings) with Fast-llm. # QKV tensors are in a different format. @@ -109,7 +109,7 @@ def test_mistral_meg(run_test_script): run_test_script("test_mistral_meg", CONFIG_LLAMA_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) -@pytest.mark.depends(on=["test_mistral_meg"]) +@pytest.mark.depends_on(on=["test_mistral_meg"]) def test_mistral_match_meg(run_test_script): # Mistral with Fast-LLM. run_test_script( @@ -135,9 +135,11 @@ def test_mixtral_meg(run_test_script): run_test_script("test_mixtral_meg", CONFIG_MIXTRAL_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) -@pytest.mark.depends(on=["test_mixtral_meg"]) +@pytest.mark.depends_on(on=["test_mixtral_meg"]) def test_mixtral_match_meg(run_test_script): # Mistral with Fast-LLM. 
+ # TODO: Fix dropless MOE + pytest.fail("Test fails, aborting to avoid breaking cuda", False) run_test_script( "test_mixtral_match_meg", CONFIG_MIXTRAL_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"], diff --git a/tests/test_mb.py b/tests/test_mb.py index fd613056..4df6e510 100644 --- a/tests/test_mb.py +++ b/tests/test_mb.py @@ -1,8 +1,7 @@ import pytest -from tests.common import CONFIG_COMMON from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.model_configs import TEST_MODEL +from tests.utils.model_configs import CONFIG_COMMON, TEST_MODEL CONFIG_DF = CONFIG_COMMON + ["batch.depth_first_micro_batches=4"] CONFIG_BF = CONFIG_COMMON + ["batch.breadth_first_micro_batches=4"] @@ -16,7 +15,7 @@ def test_model_df4(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_df4"]) +@pytest.mark.depends_on(on=["test_model_df4"]) def test_model_df4_z3(run_test_script): # Gradient accumulation with ZeRO-3. run_test_script( @@ -28,20 +27,20 @@ def test_model_df4_z3(run_test_script): ) -@pytest.mark.depends(on=["test_model_df4"], scope="session") +@pytest.mark.depends_on(on=["test_model_df4"], scope="session") def test_model_bf4(run_test_script): # Breadth-first gradient accumulation baseline. run_test_script(f"test_{TEST_MODEL}_bf4", CONFIG_BF, compare=f"test_{TEST_MODEL}_df4") -@pytest.mark.depends(on=["test_model_df4", "test_model_bf4"]) +@pytest.mark.depends_on(on=["test_model_df4", "test_model_bf4"]) def test_model_bf2_df2(run_test_script): # Mixed gradient accumulation baseline. run_test_script(f"test_{TEST_MODEL}_bf2_df2", CONFIG_BF_DF, compare=f"test_{TEST_MODEL}_df4") @pytest.mark.slow -@pytest.mark.depends(on=["test_model_bf4"]) +@pytest.mark.depends_on(on=["test_model_bf4"]) def test_model_pp2s2_bf4(run_test_script): # Pipeline-parallel without tied weights. run_test_script( @@ -53,7 +52,7 @@ def test_model_pp2s2_bf4(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_bf4"]) +@pytest.mark.depends_on(on=["test_model_bf4"]) def test_model_pp2s1_bf4(run_test_script): # Pipeline-parallel with tied weights. run_test_script( @@ -66,7 +65,7 @@ def test_model_pp2s1_bf4(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_bf4"]) +@pytest.mark.depends_on(on=["test_model_bf4"]) def test_model_dp2_tp2_pp2s2_bf4(run_test_script): # Simple 3d parallelism # TODO: Test fails diff --git a/tests/test_mb_seq_first.py b/tests/test_mb_seq_first.py index dd00fd5f..bb3d1e27 100644 --- a/tests/test_mb_seq_first.py +++ b/tests/test_mb_seq_first.py @@ -1,8 +1,7 @@ import pytest -from tests.common import CONFIG_COMMON from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.model_configs import TEST_MODEL +from tests.utils.model_configs import CONFIG_COMMON, TEST_MODEL CONFIG_DF_SF = CONFIG_COMMON + ["batch.depth_first_micro_batches=4", "model.base_model.sequence_first=True"] CONFIG_BF_SF = CONFIG_COMMON + ["batch.breadth_first_micro_batches=4", "model.base_model.sequence_first=True"] @@ -20,7 +19,7 @@ def test_model_df4_sf(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_df4_sf"]) +@pytest.mark.depends_on(on=["test_model_df4_sf"]) def test_model_dp2_sp2_df4(run_test_script): # Sequence-tensor-parallel with gradient accumulation. 
# TODO: Compiled cross-entropy broken for this config @@ -39,7 +38,7 @@ def test_model_dp2_sp2_df4(run_test_script): @pytest.mark.slow @pytest.mark.skip(reason="Test is broken.") -@pytest.mark.depends(on=["test_model_df4_sf"]) +@pytest.mark.depends_on(on=["test_model_df4_sf"]) def test_model_dp2_sp2_pp2s1(run_test_script): # 3d-parallel with sequence-tensor-parallel. # TODO: Compiled cross-entropy broken for this config diff --git a/tests/test_ms.py b/tests/test_ms.py index 55032620..d937f0eb 100644 --- a/tests/test_ms.py +++ b/tests/test_ms.py @@ -1,7 +1,6 @@ import pytest -from tests.common import CONFIG_COMMON -from tests.utils.model_configs import TEST_MODEL +from tests.utils.model_configs import CONFIG_COMMON, TEST_MODEL CONFIG_MS = CONFIG_COMMON + ["batch.micro_sequence_length=256"] @@ -13,7 +12,7 @@ def test_model_ms256(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_ms256"]) +@pytest.mark.depends_on(on=["test_model_ms256"]) def test_model_pp2s2_ms256(run_test_script): # Sequence-pipeline-parallel run_test_script( @@ -26,7 +25,7 @@ def test_model_pp2s2_ms256(run_test_script): @pytest.mark.slow @pytest.mark.skip -@pytest.mark.depends(on=["test_model_ms256"]) +@pytest.mark.depends_on(on=["test_model_ms256"]) def test_model_dp2s2_stp2_pp2s2_ms256(run_test_script): # TODO: Handle this case. # Sequence-3d-parallel diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index f5f09b1b..7424cd68 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -4,7 +4,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.tools.train import CliTrainingConfig from fast_llm.utils import Assert -from tests.common import CONFIG_COMMON +from tests.utils.model_configs import CONFIG_COMMON from tests.utils.utils import requires_cuda diff --git a/tests/test_seq_first.py b/tests/test_seq_first.py index 9ead58e8..123d8a68 100644 --- a/tests/test_seq_first.py +++ b/tests/test_seq_first.py @@ -1,7 +1,6 @@ import pytest -from tests.common import CONFIG_COMMON -from tests.utils.model_configs import TEST_MODEL +from tests.utils.model_configs import CONFIG_COMMON, TEST_MODEL CONFIG_SF = CONFIG_COMMON + ["model.base_model.sequence_first=True"] @@ -13,7 +12,7 @@ def test_model_sf(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_sf"]) +@pytest.mark.depends_on(on=["test_model_sf"]) def test_model_sp2(run_test_script): # Sequence-tensor-parallel. run_test_script( @@ -25,7 +24,7 @@ def test_model_sp2(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_sf"]) +@pytest.mark.depends_on(on=["test_model_sf"]) def test_model_sdp2(run_test_script): # Sequence-data-parallel run_test_script( @@ -37,7 +36,7 @@ def test_model_sdp2(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model_sf"]) +@pytest.mark.depends_on(on=["test_model_sf"]) def test_model_sp2_ce4(run_test_script): # Sequence-tensor-parallel with cross-entropy splits. 
run_test_script( diff --git a/tests/test_simple.py b/tests/test_simple.py index 1523750f..36ce1424 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1,7 +1,6 @@ import pytest -from tests.common import CONFIG_COMMON, CONFIG_FAST_LLM -from tests.utils.model_configs import TEST_MODEL +from tests.utils.model_configs import CONFIG_COMMON, CONFIG_FAST_LLM, TEST_MODEL def test_model_safe(run_test_script): @@ -17,7 +16,7 @@ def test_model_safe(run_test_script): ) -@pytest.mark.depends(on=["test_model_safe"]) +@pytest.mark.depends_on(on=["test_model_safe"]) def test_model(run_test_script): # A baseline config (single-gpu, bf16, flash-attn). # Also tests for multiple data loaders. @@ -27,7 +26,7 @@ def test_model(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model"]) def test_model_dp2(run_test_script): # Simple data-parallel. run_test_script(f"test_{TEST_MODEL}_dp2", CONFIG_COMMON, num_gpus=2, compare=f"test_{TEST_MODEL}") @@ -60,7 +59,7 @@ def test_model_dp2_timeout(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model"]) def test_model_tp2(run_test_script): # Simple tensor-parallel. run_test_script( @@ -71,7 +70,7 @@ def test_model_tp2(run_test_script): ) -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model"]) def test_model_ce4(run_test_script): # Cross-entropy splits. run_test_script( @@ -82,7 +81,7 @@ def test_model_ce4(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model"]) def test_model_dp2_z2(run_test_script): # Data-parallel with zero stage 2. run_test_script( @@ -94,7 +93,7 @@ def test_model_dp2_z2(run_test_script): @pytest.mark.slow -@pytest.mark.depends(on=["test_model"]) +@pytest.mark.depends_on(on=["test_model"]) def test_model_dp2_z3(run_test_script): # Data-parallel with zero stage 3. 
run_test_script( diff --git a/tests/utils/depends.py b/tests/utils/depends.py index 3972a066..6e10eac1 100644 --- a/tests/utils/depends.py +++ b/tests/utils/depends.py @@ -120,7 +120,8 @@ def _resolve_dependencies(self, item: pytest.Function): for marker in item.iter_markers(): if marker.name == MARKER_NAME: for dependency in as_list(marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])): - dependency = dependency.format(**item.callspec.params) + if hasattr(item, "callspec"): + dependency = dependency.format(**item.callspec.params) # If the name is not known, try to make it absolute (ie file::[class::]method) if dependency not in self._name_to_nodeids: From b328f0710f5a6709e0df1c050899639379054bed Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 17:46:04 -0400 Subject: [PATCH 21/69] stuff --- tests/test_config.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index e050cb23..72eda809 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,5 @@ import pathlib import subprocess -import unittest.mock import pytest import yaml @@ -8,9 +7,7 @@ from fast_llm.config import NoAutoValidate from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.auto import trainer_registry from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert, check_equal_nested @@ -64,32 +61,6 @@ def test_validate_example_config(): trainer_registry["gpt"].from_dict(fast_llm_config_dict) -def test_do_use_flash_attention(): - # Create a mock DistributedConfig - mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig) - - # Test case 1: use_flash_attention is True and training_dtype is float16 - config = TransformerConfig(use_flash_attention=True, window_size=None) - mock_distributed_config.training_dtype = DataType.float16 - assert config.do_use_flash_attention(mock_distributed_config) is True - - # Test case 2: use_flash_attention is False - config = TransformerConfig(use_flash_attention=False, window_size=None) - mock_distributed_config.training_dtype = DataType.float16 - assert config.do_use_flash_attention(mock_distributed_config) is False - - # Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16 - config = TransformerConfig(use_flash_attention=True, window_size=None) - mock_distributed_config.training_dtype = DataType.float32 - assert config.do_use_flash_attention(mock_distributed_config) is False - - # Test case 4: use_flash_attention is False and window_size is not None - config = TransformerConfig(use_flash_attention=False, window_size=512) - mock_distributed_config.training_dtype = DataType.float32 - with pytest.raises(AssertionError): - config.do_use_flash_attention(mock_distributed_config) - - @pytest.mark.parametrize( ("cls", "default"), ((GPTSamplingConfig, {}), (GPTModelConfig, {"distributed": {"world_size": 1, "rank": 0, "local_world_size": 1}})), From 7ed804b153146a58bafa3fb9f9b215eaa9b83048 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 17:48:54 -0400 Subject: [PATCH 22/69] stuff --- tests/test_functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/tests/test_functional.py b/tests/test_functional.py index 03a0ae8a..0689f4d8 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -218,6 +218,8 @@ def test_mlp_recomputation(gated, activation_type): @pytest.mark.slow @requires_cuda def test_dropless_mlp(): + # TODO: Fix dropless MOE + pytest.fail("Test fails, aborting to avoid breaking cuda", False) num_experts = 4 experts_per_token = 4 tokens = 1024 From 6f000359bb2413f17552b617485f68bc1e07dfe1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 17:53:04 -0400 Subject: [PATCH 23/69] stuff --- tests/conftest.py | 2 +- tests/test_functional.py | 2 ++ tests/utils/depends.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bc3d443c..0d25fc5a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -160,7 +160,7 @@ def pytest_collection_modifyitems(config, items: list[pytest.Function]): # If pytest-depends is installed, it will complain about renamed nodes whether it's used or not. try: - import pytest_depends + import pytest_depends.main except ImportError: pass else: diff --git a/tests/test_functional.py b/tests/test_functional.py index b049be85..9211259c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -224,6 +224,8 @@ def test_mlp_recomputation(gated, activation_type): @pytest.mark.slow @requires_cuda def test_dropless_mlp(): + # TODO: Fix dropless MOE + pytest.fail("Test fails, aborting to avoid breaking cuda", False) num_experts = 4 experts_per_token = 4 tokens = 256 diff --git a/tests/utils/depends.py b/tests/utils/depends.py index 3972a066..6e10eac1 100644 --- a/tests/utils/depends.py +++ b/tests/utils/depends.py @@ -120,7 +120,8 @@ def _resolve_dependencies(self, item: pytest.Function): for marker in item.iter_markers(): if marker.name == MARKER_NAME: for dependency in as_list(marker.kwargs.get(MARKER_KWARG_DEPENDENCIES, [])): - dependency = dependency.format(**item.callspec.params) + if hasattr(item, "callspec"): + dependency = dependency.format(**item.callspec.params) # If the name is not known, try to make it absolute (ie file::[class::]method) if dependency not in self._name_to_nodeids: From e45ff6aafacd981b5a3c21515b5e07c02b056f31 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 12 Jun 2025 17:56:02 -0400 Subject: [PATCH 24/69] stuff --- tests/common.py | 16 ---------------- tests/utils/dataset.py | 1 + 2 files changed, 1 insertion(+), 16 deletions(-) delete mode 100644 tests/common.py diff --git a/tests/common.py b/tests/common.py deleted file mode 100644 index a2dba74a..00000000 --- a/tests/common.py +++ /dev/null @@ -1,16 +0,0 @@ -import os -import sys - -# FIXME: figure out correct import of megatron modules without this hack -sys.path.append(os.getcwd()) - -# TODO: Use `pytest_addoption` instead? -# Keep all results in one place to allow recovering them for debugging in case of failure. 
-
-# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6%
-
-# Megatron does not support Llama3-style Rotary Embeddings
-
-# Megatron does not support per sub layer biases
-
-# Yarn-style Rotary Embeddings
diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py
index 72888dfd..6f40347b 100644
--- a/tests/utils/dataset.py
+++ b/tests/utils/dataset.py
@@ -15,6 +15,7 @@
 DATASET_PREFIX = DATASET_CACHE / "common" / "dataset"
 DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset" / "cache"
 TEST_VOCAB_SIZE = 8192
+# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6%
 TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n"
 TEST_DATASET_TOKENS = 1000000

From 67d3c92c9420af25a6b1c70e992c2c4195357a2f Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Thu, 12 Jun 2025 18:41:46 -0400
Subject: [PATCH 25/69] fix

---
 .github/workflows/ci.yaml   | 7 ++++---
 .github/workflows/docs.yaml | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 912ddaf5..0bca2dd8 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -27,10 +27,11 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install "torch>=2.2.2"
+          pip install "torch>=2.7.0"
           pip install pybind11
-          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
-
+          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
+          MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
+          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]"
       - name: Run tests
         run: pytest .

diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
index b755993c..632fa7b9 100644
--- a/.github/workflows/docs.yaml
+++ b/.github/workflows/docs.yaml
@@ -29,7 +29,7 @@ jobs:
           restore-keys: |
             mkdocs-material-
       - run: |
-          pip install "torch>=2.2.2"
+          pip install "torch>=2.7.0"
           pip install pybind11
           FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
           MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \

From c2ae03d830007a59745d4982a791ca32e3288f7b Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 13 Jun 2025 12:54:57 -0400
Subject: [PATCH 26/69] fix

---
 fast_llm/layers/ssm/discrete_mamba2.py | 4 ++--
 fast_llm/layers/ssm/mamba_layer.py     | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py
index ecf0b29d..31e81e99 100644
--- a/fast_llm/layers/ssm/discrete_mamba2.py
+++ b/fast_llm/layers/ssm/discrete_mamba2.py
@@ -17,7 +17,7 @@
     from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as _mamba_chunk_scan_combined  # noqa

     _mamba_available = True
-except ImportError:
+except (ImportError, RuntimeError):
     _mamba_available = False

@@ -25,7 +25,7 @@
     from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn  # noqa

     _causal_conv1d_available = True
-except ImportError:
+except (ImportError, RuntimeError):
     _causal_conv1d_available = False

diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py
index 7fd43789..7c824d23 100644
--- a/fast_llm/layers/ssm/mamba_layer.py
+++ b/fast_llm/layers/ssm/mamba_layer.py
@@ -14,7 +14,7 @@
     from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn  # noqa

     _mamba_available = True
-except ImportError:
+except (ImportError, RuntimeError):
     _mamba_available = False

 """

From 31da2a80ff0575afa7fd6588a446a23cd3ae86c2 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 13 Jun 2025 15:32:48 -0400
Subject: [PATCH 27/69] misc

---
 fast_llm/config.py                    |  12 +-
 fast_llm/engine/multi_stage/config.py |   7 +-
 fast_llm/models/gpt/conversion.py     |   3 +-
 fast_llm/models/ssm/config.py         |   2 +-
 tests/utils/model_configs.py          | 242 +++++++++++++-------------
 5 files changed, 136 insertions(+), 130 deletions(-)

diff --git a/fast_llm/config.py b/fast_llm/config.py
index f2197224..cdc1dd5d 100644
--- a/fast_llm/config.py
+++ b/fast_llm/config.py
@@ -490,7 +490,7 @@ def _validate_element(cls, value, type_, name: str):
             elif issubclass(origin, dict):
                 value = cls._validate_dict(value, type_, name)
             elif origin is type:
-                cls._validate_type(value, type_, name)
+                value = cls._validate_type(value, type_, name)
             else:
                 raise FieldTypeError(f"Unsupported __origin__ `{origin}`")
         elif not isinstance(type_, type):
@@ -585,10 +585,13 @@ def _validate_type(cls, value, type_: type | tuple[type, ...], name):
         args = list(getattr(type_, "__args__", []))
         if len(args) != 1:
             raise FieldTypeError(f"Invalid type specification `{get_type_name(type_)}` for field `{name}`")
+        if issubclass(args[0], Config) and isinstance(value, str):
+            value = args[0].get_subclass(value)
         if not isinstance(value, type):
             raise ValidationError(f"Unexpected type `{get_type_name(type(value))}`")
         if not issubclass(value, args[0]):
             raise ValidationError(f"Field value `{value} is not a subclass of `{get_type_name(type_)}`")
+        return value

     @classmethod
     def _validate_element_type(cls, value, type_: type | tuple[type, ...], strict: bool = True):
@@ -947,6 +950,13 @@ def
get_subclass(cls, name: str | None): raise KeyError(f"Unknown type {name} for base class {cls.__name__}") return cls_ + @classmethod + def __fast_llm_serialize__(cls) -> str: + # Used to serialize config type fields, which only makes sense for dynamic types. + # Deserialization implemented in _validate_type. + assert cls.dynamic_type_name is not None + return cls.dynamic_type_name + def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 5aa972c2..6ac157df 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -215,7 +215,6 @@ class FastLLMModelConfig(Config): DistributedCheckpointFormat, FastLLMCheckpointFormat, ) - model_name: typing.ClassVar[str] base_model: BaseModelConfig = Field(desc="Configuration for the base model.", hint=FieldHint.core) multi_stage: MultiStageConfig = Field( desc="Configuration for the stage breakdown of the model.", @@ -223,10 +222,6 @@ class FastLLMModelConfig(Config): ) distributed: DistributedConfig = Field(desc="Distributed configuration.", hint=FieldHint.core) - @classmethod - def __fast_llm_serialize__(cls) -> str: - return cls.model_name - @classmethod def get_checkpoint_format(cls, format: type[CheckpointFormat] | str) -> type[CheckpointFormat]: if isinstance(format, type) and issubclass(format, CheckpointFormat): @@ -236,7 +231,7 @@ def get_checkpoint_format(cls, format: type[CheckpointFormat] | str) -> type[Che for format_ in cls.checkpoint_formats: if format_.name == format: return format_ - raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + raise ValueError(f"Checkpoint format {format} not supported for model {cls.dynamic_type_name}") @classmethod def get_checkpoint_handler_class(cls, format: type[CheckpointFormat] | str) -> type[CheckpointHandler]: diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 5c689629..93428954 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -319,7 +319,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=DefaultRotaryConfig + fast_llm_names=(("transformer", "rotary", "type"),), + fast_llm_value=DefaultRotaryConfig.dynamic_type_name, ), ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 301aca7b..386d2f50 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -169,7 +169,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler -@config_class() +@config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index d4889e94..8357bdbe 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -7,7 +7,7 @@ import pytest from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.models.auto import model_registry +from 
fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.models.gpt.config import ( LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, @@ -30,6 +30,19 @@ class ModelTestingGroup(enum.StrEnum): generate = "generate" +class ModelTestingGroupAction(enum.StrEnum): + # Critical test, will always run. + main = "main" + # Standard test, treated as slow + normal = "normal" + # Feature is not important enough for frequent testing (ex. mostly redundant), treated as extra-slow. + unimportant = "unimportant" + # Test is known to fail, treated as extra-slow. + broken = "broken" + # Tested feature is unsupported for this model, skip unconditionally. + not_implemented = "not_implemented" + + SLOW_TESTING_GROUPS = {ModelTestingGroup.megatron, ModelTestingGroup.distributed} @@ -40,15 +53,12 @@ class ModelTestingConfig: config_args: list[str] megatron_args: list[str] | None checkpoint_format: CheckpointFormat | None - # The important groups we want to test. - testing_groups: list[ModelTestingGroup] - # Other supported groups, excluded by default because they are mostly unimportant and/or redundant. - # They can be run with `--run-extra-slow`. - other_groups: list[ModelTestingGroup] + groups: dict[ModelTestingGroup, ModelTestingGroupAction] @functools.cached_property def model_config_class(self): - return model_registry[self.model_type] + # TODO: Ok to assume the model and trainer have the same name? + return FastLLMModelConfig.get_subclass(self.model_type) @functools.cached_property def huggingface_model_for_causal_lm_class(self): @@ -71,14 +81,12 @@ def _update_and_add_testing_config( extra_args: list[str] | None = None, megatron_args: list[str] | None = ..., checkpoint_format: CheckpointFormat | None = ..., - testing_groups: list[ModelTestingGroup], - other_groups: list[ModelTestingGroup], + groups: dict[ModelTestingGroup, ModelTestingGroupAction], ): config = _MODEL_CONFIGS[old_name] updates: dict[str, typing.Any] = { "name": new_name, - "testing_groups": testing_groups, - "other_groups": other_groups, + "groups": groups, } if model_type is not None: updates["model_type"] = model_type @@ -177,12 +185,13 @@ def _update_and_add_testing_config( "--transformer-impl=transformer_engine", ], checkpoint_format=None, - testing_groups=[ - ModelTestingGroup.basic, - ModelTestingGroup.megatron, - ModelTestingGroup.distributed, - ], - other_groups=[], + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.main, + ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.normal, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, ) _update_and_add_testing_config( @@ -192,13 +201,13 @@ def _update_and_add_testing_config( extra_args=["model.base_model.transformer.head_groups=1"], megatron_args=["--group-query-attention"], checkpoint_format=None, - testing_groups=[ - ModelTestingGroup.basic, - ], - other_groups=[ - ModelTestingGroup.megatron, - ModelTestingGroup.distributed, - ], + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( @@ -218,16 +227,14 @@ def _update_and_add_testing_config( 
"--no-position-embedding", ], checkpoint_format=Starcoder2GPTHuggingfaceCheckpointFormat, - testing_groups=[ - ModelTestingGroup.basic, - ModelTestingGroup.convert, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - ModelTestingGroup.megatron, - ModelTestingGroup.distributed, - ModelTestingGroup.generate, - ], + # TODO: Add back generate as `normal` when stable. + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( @@ -250,16 +257,14 @@ def _update_and_add_testing_config( "--untie-embeddings-and-output-weights", ], checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, - testing_groups=[ - ModelTestingGroup.basic, - ModelTestingGroup.megatron, - ModelTestingGroup.distributed, - ModelTestingGroup.convert, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - ModelTestingGroup.generate, - ], + # TODO: Add back generate as `normal` when stable. + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.main, + ModelTestingGroup.convert: ModelTestingGroupAction.main, + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.normal, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, ) _update_and_add_testing_config( @@ -270,15 +275,13 @@ def _update_and_add_testing_config( # Megatron doesn't support Llama3-style Rotary Embeddings megatron_args=None, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, - testing_groups=[ - ModelTestingGroup.basic, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - ModelTestingGroup.distributed, - ModelTestingGroup.convert, - ModelTestingGroup.generate, - ], + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, + ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( @@ -289,15 +292,13 @@ def _update_and_add_testing_config( # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, - testing_groups=[ - ModelTestingGroup.basic, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - ModelTestingGroup.distributed, - ModelTestingGroup.convert, - ModelTestingGroup.generate, - ], + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, + ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( @@ -308,15 +309,14 @@ def _update_and_add_testing_config( # Megatron doesn't support multi-token prediction. megatron_args=None, checkpoint_format=MTPLlamaGPTHuggingfaceCheckpointFormat, - testing_groups=[ - ModelTestingGroup.basic, - ModelTestingGroup.convert, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. 
- other_groups=[ - ModelTestingGroup.distributed, - ModelTestingGroup.generate, - ], + # TODO: Add back generate as `normal` when stable. + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( @@ -327,15 +327,14 @@ def _update_and_add_testing_config( # Megatron doesn't support per sub layer biases megatron_args=None, checkpoint_format=Qwen2GPTHuggingfaceCheckpointFormat, - testing_groups=[ - ModelTestingGroup.basic, - ModelTestingGroup.convert, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - ModelTestingGroup.distributed, - ModelTestingGroup.generate, - ], + # TODO: Add back generate as `normal` when stable. + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( @@ -346,15 +345,14 @@ def _update_and_add_testing_config( # Megatron doesn't support sliding windows. megatron_args=None, checkpoint_format=MistralGPTHuggingfaceCheckpointFormat, - testing_groups=[ - ModelTestingGroup.basic, - ModelTestingGroup.convert, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - ModelTestingGroup.distributed, - ModelTestingGroup.generate, - ], + # TODO: Add back generate as `normal` when stable. + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( @@ -370,16 +368,14 @@ def _update_and_add_testing_config( "--moe-router-topk=4", ], checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, - testing_groups=[], # TODO: New base image broke mixtral - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - ModelTestingGroup.basic, - ModelTestingGroup.megatron, - ModelTestingGroup.distributed, - ModelTestingGroup.convert, - ModelTestingGroup.generate, - ], + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + }, ) _update_and_add_testing_config( @@ -396,16 +392,16 @@ def _update_and_add_testing_config( ], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, - testing_groups=[ - ModelTestingGroup.basic, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - # TODO: Fix and bring these back to `testing_groups` - ModelTestingGroup.distributed, - ModelTestingGroup.convert, - ModelTestingGroup.generate, - ], + # TODO: Add back generate as `normal` when stable. 
+ groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, + # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + }, ) @@ -419,14 +415,13 @@ def _update_and_add_testing_config( ], megatron_args=None, checkpoint_format=None, - testing_groups=[ - ModelTestingGroup.basic, - ], - # TODO: Bring back `generate` to `testing_groups` when stable. - other_groups=[ - # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.distributed, - ], + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) @@ -440,12 +435,17 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo assert "model_testing_config" in item.callspec.params, item.nodeid groups: tuple[ModelTestingGroup] = item.keywords["model_testing_group"].args model_testing_config = item.callspec.params["model_testing_config"] - model_config = _MODEL_CONFIGS[model_testing_config] + model_config: ModelTestingConfig = _MODEL_CONFIGS[model_testing_config] for group in groups: - if group in model_config.testing_groups and not (skip_slow and group in SLOW_TESTING_GROUPS): - pass - elif group in model_config.other_groups and not skip_extra_slow: - pass + action = model_config.groups[group] + if action == ModelTestingGroupAction.main: + return True + elif action == ModelTestingGroupAction.normal and not skip_slow: + return True + elif ( + action in (ModelTestingGroupAction.broken, ModelTestingGroupAction.unimportant) and not skip_extra_slow + ): + return True elif show_skipped: item.add_marker( pytest.mark.skip(reason=f"Skipping testing group {group} for model {model_testing_config}.") From c2ee8fee9d97dca477ed7fd700be5d440f5d6a3d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 15:37:13 -0400 Subject: [PATCH 28/69] stuff --- fast_llm/config.py | 12 +++++++++++- fast_llm/engine/multi_stage/config.py | 7 +------ fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba_layer.py | 2 +- fast_llm/models/gpt/conversion.py | 3 ++- fast_llm/models/ssm/config.py | 2 +- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index f2197224..cdc1dd5d 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -490,7 +490,7 @@ def _validate_element(cls, value, type_, name: str): elif issubclass(origin, dict): value = cls._validate_dict(value, type_, name) elif origin is type: - cls._validate_type(value, type_, name) + value = cls._validate_type(value, type_, name) else: raise FieldTypeError(f"Unsupported __origin__ `{origin}`") elif not isinstance(type_, type): @@ -585,10 +585,13 @@ def _validate_type(cls, value, type_: type | tuple[type, ...], name): args = list(getattr(type_, "__args__", [])) if len(args) != 1: raise FieldTypeError(f"Invalid type specification `{get_type_name(type_)}` for field `{name}`") + if issubclass(args[0], Config) and isinstance(value, str): + value = args[0].get_subclass(value) if not isinstance(value, type): 
raise ValidationError(f"Unexpected type `{get_type_name(type(value))}`") if not issubclass(value, args[0]): raise ValidationError(f"Field value `{value} is not a subclass of `{get_type_name(type_)}`") + return value @classmethod def _validate_element_type(cls, value, type_: type | tuple[type, ...], strict: bool = True): @@ -947,6 +950,13 @@ def get_subclass(cls, name: str | None): raise KeyError(f"Unknown type {name} for base class {cls.__name__}") return cls_ + @classmethod + def __fast_llm_serialize__(cls) -> str: + # Used to serialize config type fields, which only makes sense for dynamic types. + # Deserialization implemented in _validate_type. + assert cls.dynamic_type_name is not None + return cls.dynamic_type_name + def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 5aa972c2..6ac157df 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -215,7 +215,6 @@ class FastLLMModelConfig(Config): DistributedCheckpointFormat, FastLLMCheckpointFormat, ) - model_name: typing.ClassVar[str] base_model: BaseModelConfig = Field(desc="Configuration for the base model.", hint=FieldHint.core) multi_stage: MultiStageConfig = Field( desc="Configuration for the stage breakdown of the model.", @@ -223,10 +222,6 @@ class FastLLMModelConfig(Config): ) distributed: DistributedConfig = Field(desc="Distributed configuration.", hint=FieldHint.core) - @classmethod - def __fast_llm_serialize__(cls) -> str: - return cls.model_name - @classmethod def get_checkpoint_format(cls, format: type[CheckpointFormat] | str) -> type[CheckpointFormat]: if isinstance(format, type) and issubclass(format, CheckpointFormat): @@ -236,7 +231,7 @@ def get_checkpoint_format(cls, format: type[CheckpointFormat] | str) -> type[Che for format_ in cls.checkpoint_formats: if format_.name == format: return format_ - raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + raise ValueError(f"Checkpoint format {format} not supported for model {cls.dynamic_type_name}") @classmethod def get_checkpoint_handler_class(cls, format: type[CheckpointFormat] | str) -> type[CheckpointHandler]: diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 31e81e99..ecf0b29d 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -17,7 +17,7 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as _mamba_chunk_scan_combined # noqa _mamba_available = True -except (ImportError, RuntimeError): +except ImportError: _mamba_available = False @@ -25,7 +25,7 @@ from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa _causal_conv1d_available = True -except (ImportError, RuntimeError): +except ImportError: _causal_conv1d_available = False diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d23..7fd43789 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -14,7 +14,7 @@ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa _mamba_available = True -except (ImportError, RuntimeError): +except ImportError: _mamba_available = False """ diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 5c689629..93428954 100644 --- a/fast_llm/models/gpt/conversion.py +++ 
b/fast_llm/models/gpt/conversion.py @@ -319,7 +319,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=DefaultRotaryConfig + fast_llm_names=(("transformer", "rotary", "type"),), + fast_llm_value=DefaultRotaryConfig.dynamic_type_name, ), ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 301aca7b..386d2f50 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -169,7 +169,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler -@config_class() +@config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" From 6c775e47bec481569e3ab69861c52c01a7ae231f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 15:53:49 -0400 Subject: [PATCH 29/69] stuff --- tests/test_match_megatron.py | 157 ++--------- tests/test_simple.py | 5 +- tests/utils/model_configs.py | 512 ++++++++++++++++++++--------------- 3 files changed, 322 insertions(+), 352 deletions(-) diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index 7d89c80a..f464dd06 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -2,25 +2,12 @@ from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import DATASET_PREFIX -from tests.utils.model_configs import ( - CONFIG_GPT2_FAST_LLM, - CONFIG_GPT2_MEGATRON, - CONFIG_LLAMA_FAST_LLM, - CONFIG_LLAMA_MEGATRON, - CONFIG_MIXTRAL_FAST_LLM, - CONFIG_MIXTRAL_MEGATRON, - CONFIG_SC1_FAST_LLM, - CONFIG_SC1_MEGATRON, - CONFIG_SC2_FAST_LLM, - CONFIG_SC2_MEGATRON, -) +from tests.utils.model_configs import CONFIG_COMMON, CONFIG_MEGATRON, TEST_MODEL @pytest.mark.slow -@pytest.mark.skip(reason="Skipping mostly redundant test") -def test_sc1_meg(run_test_script): - # Starcoder 1 (GPT2 with MQA) with Megatron. - run_test_script("test_sc1_meg", CONFIG_SC1_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) +def test_megatron(run_test_script): + run_test_script(f"test_{TEST_MODEL}_megatron", CONFIG_MEGATRON, is_megatron=True) CONFIG_MATCH_MEGATRON = [ @@ -29,42 +16,31 @@ def test_sc1_meg(run_test_script): ] -@pytest.mark.depends_on(on=["test_sc1_meg"]) -def test_sc1_match_meg(run_test_script): - # Starcoder 1 (GPT2 with MQA) with Fast-llm. - # QKV tensors are in a different format. 
- run_test_script( - "test_sc1_match_meg", - CONFIG_SC1_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"], - compare="test_sc1_meg", - config=CompareConfig( - ignore_tensors=[ - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", - ".mlp.layer_2.weight", - ] - ), - ) +@pytest.mark.depends_on(on=["test_megatron"]) +def test_match_megatron(run_test_script): + if CONFIG_MEGATRON is None: + pytest.skip(f"Megatron does not support model {TEST_MODEL}") + ignore_tensors = [ + ".self_attn.query_key_value.", + ".self_attn.query.", + ".self_attn.key_value.", + ".mlp.layer_2.weight", + ".mlp.experts.", + ] + if TEST_MODEL == "mixtral": + ignore_tensors.extend([".mlp.experts.", ".mlp.layer_1.weight"]) -@pytest.mark.slow -@pytest.mark.skip(reason="Skipping mostly redundant test") -@pytest.mark.depends_on(on=["test_sc1_match_meg"]) -def test_sc2_meg(run_test_script): - # Starcoder 2 (GPT2 with MQA and RoPE) with Megatron. - run_test_script("test_sc2_meg", CONFIG_SC2_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) - - -@pytest.mark.depends_on(on=["test_sc2_meg"]) -def test_sc2_match_meg(run_test_script): - # Starcoder 2 (GPT2 with MQA and RoPE) with Fast-llm. - # QKV tensors are in a different format, - # dense not matching because of the way initialization is corrected for RoPE format. run_test_script( - "test_sc2_match_meg", - CONFIG_SC2_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"], - compare="test_sc2_meg", + f"test_{TEST_MODEL}_match_megatron", + CONFIG_COMMON + + [ + "model.distributed.training_dtype=fp32", + "data.datasets={}", + f"data.path={DATASET_PREFIX}", + "model.base_model.use_megatron_initialization=True", + ], + compare=f"test_{TEST_MODEL}_megatron", config=CompareConfig( ignore_tensors=[ ".self_attn.query_key_value.", @@ -75,86 +51,3 @@ def test_sc2_match_meg(run_test_script): ] ), ) - - -@pytest.mark.slow -def test_gpt2_meg(run_test_script): - # GPT2 (MHA, layer norm, absolute embeddings) with Megatron. - run_test_script("test_gpt2_meg", CONFIG_GPT2_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) - - -@pytest.mark.depends_on(on=["test_gpt2_meg"]) -def test_gpt2_match_meg(run_test_script): - # GPT2 (MHA, layer norm, absolute embeddings) with Fast-llm. - # QKV tensors are in a different format. - run_test_script( - "test_gpt2_match_meg", - CONFIG_GPT2_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"], - compare="test_gpt2_meg", - config=CompareConfig( - ignore_tensors=[ - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", - ".mlp.layer_2.weight", - ] - ), - ) - - -@pytest.mark.slow -def test_mistral_meg(run_test_script): - # Mistral with Megatron. - # No linear bias, swiglu activation, RMSNorm - run_test_script("test_mistral_meg", CONFIG_LLAMA_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) - - -@pytest.mark.depends_on(on=["test_mistral_meg"]) -def test_mistral_match_meg(run_test_script): - # Mistral with Fast-LLM. - run_test_script( - "test_mistral_match_meg", - CONFIG_LLAMA_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"], - compare="test_mistral_meg", - config=CompareConfig( - ignore_tensors=[ - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", - ".self_attn.dense.", - ".mlp.layer_2.weight", - ] - ), - ) - - -@pytest.mark.slow -def test_mixtral_meg(run_test_script): - # Mistral with Megatron. 
- # No linear bias, swiglu activation, RMSNorm - run_test_script("test_mixtral_meg", CONFIG_MIXTRAL_MEGATRON + ["--micro-batch-size=8"], is_megatron=True) - - -@pytest.mark.depends_on(on=["test_mixtral_meg"]) -def test_mixtral_match_meg(run_test_script): - # Mistral with Fast-LLM. - # TODO: Fix dropless MOE - pytest.fail("Test fails, aborting to avoid breaking cuda", False) - run_test_script( - "test_mixtral_match_meg", - CONFIG_MIXTRAL_FAST_LLM + CONFIG_MATCH_MEGATRON + ["model.base_model.use_megatron_initialization=True"], - compare="test_mixtral_meg", - config=CompareConfig( - ignore_tensors=[ - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", - ".self_attn.dense.", - ".mlp.layer_1.weight", - ".mlp.layer_2.weight", - ".mlp.experts", - "Global layer 2 fw: Transformer layer 2 output", - ], - max_rel_tolerance=1.5e-1, - ), - ) diff --git a/tests/test_simple.py b/tests/test_simple.py index 36ce1424..d67d06cd 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1,14 +1,15 @@ import pytest -from tests.utils.model_configs import CONFIG_COMMON, CONFIG_FAST_LLM, TEST_MODEL +from tests.utils.model_configs import CONFIG_COMMON, TEST_MODEL def test_model_safe(run_test_script): # The safest possible config, identical to the one in test_match_megatron except for the initialization. run_test_script( f"test_{TEST_MODEL}_safe", - CONFIG_FAST_LLM + CONFIG_COMMON + [ + "model.distributed.training_dtype=fp32", "run.torch_dynamo_enable=False", "schedule.data_overlap=False", "model.base_model.transformer.dropless_moe=False", diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 26eebf4f..c6c412d2 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -1,5 +1,10 @@ +import dataclasses +import functools import os +import typing +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.models.gpt.config import ( LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, @@ -12,222 +17,293 @@ from tests.utils.dataset import DATASET_PREFIX, TEST_VOCAB_SIZE _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class ModelTestingConfig: + name: str = None + model_type: str + config_args: list[str] + megatron_args: list[str] | None + checkpoint_format: CheckpointFormat | None + + @functools.cached_property + def model_config_class(self): + # TODO: Ok to assume the model and trainer have the same name? 
+ return FastLLMModelConfig.get_subclass(self.model_type) + + @functools.cached_property + def huggingface_model_for_causal_lm_class(self): + return self.model_config_class.get_huggingface_model_for_causal_lm_class() + + @functools.cached_property + def model_class(self): + return self.model_config_class.get_model_class() + + @functools.cached_property + def base_model_config_class(self): + return self.model_config_class.get_base_model_config_class() + + +def _update_and_add_testing_config( + old_name: str, + new_name: str, + *, + model_type: str | None = None, + extra_args: list[str] | None = None, + megatron_args: list[str] | None = ..., + checkpoint_format: CheckpointFormat | None = ..., +): + config = _MODEL_CONFIGS[old_name] + updates: dict[str, typing.Any] = {"name": new_name} + if model_type is not None: + updates["model_type"] = model_type + if extra_args is not None: + updates["config_args"] = config.config_args + extra_args + if megatron_args is not ...: + if megatron_args is None: + updates["megatron_args"] = None + elif config.megatron_args is None: + updates["megatron_args"] = megatron_args + else: + updates["megatron_args"] = config.megatron_args + megatron_args + if checkpoint_format is not ...: + updates["checkpoint_format"] = checkpoint_format + + _MODEL_CONFIGS[new_name] = dataclasses.replace(config, **updates) + + +_MODEL_CONFIGS: dict[str, ModelTestingConfig] = {} + + +_MODEL_CONFIGS["gpt2"] = ModelTestingConfig( + # Tests gpt2 features (absolute embeddings, layer norm, relu activation, tied embeddings, MHA, linear biases). + name="gpt2", + model_type="gpt", + config_args=[ + "training.logs.interval=1", + "run.tensor_logs.save=True", + "run.tensor_logs.show=False", + "model.base_model.max_position_embeddings=512", + "model.base_model.transformer.num_layers=2", + "model.base_model.transformer.hidden_size=256", + "model.base_model.transformer.num_attention_heads=8", + "model.base_model.transformer.head_groups=8", + "model.base_model.transformer.init_method_std=0.022", + f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", + f"model.multi_stage.debug_param_init={_LOG_LEVEL}", + f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", + f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", + f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", + "model.multi_stage.debug_tensor_parallel=True", + "model.distributed.reproducible_init=True", + "model.distributed.timeout=20", + "model.distributed.training_dtype=bf16", + "training.train_iters=2", + "training.num_workers=0", + "training.timeout=30", + "batch.batch_size=8", + "batch.sequence_length=512", + "data.datasets.training.type=slice", + "data.datasets.training.end=0.969", + "data.datasets.training.dataset.type=memmap", + f"data.datasets.training.dataset.path={DATASET_PREFIX}", + "data.datasets.validation.type=slice", + "data.datasets.validation.begin=0.969", + "data.datasets.validation.end=0.999", + "data.datasets.validation.dataset.type=memmap", + f"data.datasets.validation.dataset.path={DATASET_PREFIX}", + "data.datasets.test.type=slice", + "data.datasets.test.begin=0.999", + "data.datasets.test.end=1", + "data.datasets.test.dataset.type=memmap", + f"data.datasets.test.dataset.path={DATASET_PREFIX}", + "optimizer.learning_rate.base=0.0001", + ], + megatron_args=[ + "--num-layers=2", + "--hidden-size=256", + "--num-attention-heads=8", + "--log-interval=1", + "--train-iters=2", + "--eval-iters=0", + "--hidden-dropout=0", + "--attention-dropout=0", + f"--debug_param_init={_LOG_LEVEL}", + 
f"--debug_layer_outputs={_LOG_LEVEL}", + f"--debug_layer_gradients={_LOG_LEVEL}", + f"--debug_all_param_gradients={_LOG_LEVEL}", + "--debug_param_update=0", + "--global-batch-size=8", + "--micro-batch-size=8", + "--max-position-embeddings=512", + "--seq-length=512", + "--init-method-std=0.022", + "--lr=0.0001", + "--num-workers=0", + "--valid-num-workers=0", + "--tokenizer-type=NullTokenizer", + # Megatron messes with the vocab size, so we have to subtract 1. + f"--vocab-size={TEST_VOCAB_SIZE - 1}", + f"--data-path={DATASET_PREFIX}", + "--lr-decay-style=constant", + # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) + "--use-mcore-models", + # local implementation doesn't allow for RMS norm. + "--transformer-impl=transformer_engine", + ], + checkpoint_format=None, +) + +_update_and_add_testing_config( + # Tests MQA. + "gpt2", + "starcoder", + extra_args=["model.base_model.transformer.head_groups=1"], + megatron_args=["--group-query-attention"], + checkpoint_format=None, +) + +_update_and_add_testing_config( + # Tests intermediate between gpt2 and llama, closest converter to gpt2. + "gpt2", + "starcoder2", + extra_args=[ + "model.base_model.transformer.head_groups=4", + "model.base_model.transformer.rotary.type=default", + # Unused, but prevents issues with conversion tests. + "model.base_model.max_position_embeddings=2048", + ], + megatron_args=[ + "--group-query-attention", + "--num-query-groups=4", + "--use-rotary-position-embeddings", + "--no-position-embedding", + ], + checkpoint_format=Starcoder2GPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # Main tested model. + "starcoder2", + "llama", + extra_args=[ + "model.base_model.transformer.gated=True", + "model.base_model.transformer.activation_type=silu", + "model.base_model.transformer.add_linear_biases=False", + "model.base_model.transformer.normalization.type=rms_norm", + "model.base_model.transformer.ffn_hidden_size=1024", + "model.base_model.tie_word_embeddings=False", + ], + megatron_args=[ + "--swiglu", + "--disable-bias-linear", + "--normalization=RMSNorm", + "--ffn-hidden-size=1024", + "--untie-embeddings-and-output-weights", + ], + checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # Tests llama3-style rotary embeddings. + "llama", + "llama3", + extra_args=["model.base_model.transformer.rotary.type=llama3"], + # Megatron doesn't support Llama3-style Rotary Embeddings + megatron_args=None, + checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # Tests yarn-style rotary embeddings. + "llama", + "llama_yarn", + extra_args=["model.base_model.transformer.rotary.type=yarn"], + # Megatron doesn't support Yarn-style Rotary Embeddings + megatron_args=None, + checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # Tests multi-token prediction, custom HF model and converter. + "llama", + "llama_mtp", + extra_args=["model.base_model.prediction_heads=4"], + # Megatron doesn't support multi-token prediction. + megatron_args=None, + checkpoint_format=MTPLlamaGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # Tests partial linear biases, Qwen2 converter. 
+ "llama", + "qwen2", + extra_args=["model.base_model.transformer.add_linear_biases=only_attn_qkv"], + # Megatron doesn't support per sub layer biases + megatron_args=None, + checkpoint_format=Qwen2GPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # Tests sliding window attention, mistral converter. + "llama", + "mistral", + extra_args=["model.base_model.transformer.window_size=128"], + # Megatron doesn't support sliding windows. + megatron_args=None, + checkpoint_format=MistralGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # Tests mixture of experts, mixtral converter. + "llama", + "mixtral", + extra_args=[ + "model.base_model.transformer.num_experts=4", + "model.base_model.transformer.num_experts_per_token=4", + ], + megatron_args=[ + "--num-experts=4", + "--moe-router-topk=4", + ], + checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, +) + +_update_and_add_testing_config( + # Tests hybrid ssm, llamba converter. + "llama", + "llamba", + model_type="hybrid_ssm", + extra_args=[ + "model.base_model.hybrid_block_layout=['t','m']", + "model.base_model.ssm.state_size=8", + "model.base_model.ssm.chunk_size=32", + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=8", + ], + megatron_args=None, + checkpoint_format=LLambaHuggingfaceCheckpointFormat, +) + + +_update_and_add_testing_config( + # Tests hybrid ssm, llamba converter. + "llamba", + "hybrid_mamba_2", + model_type="hybrid_ssm", + extra_args=[ + "model.base_model.hybrid_block_layout=['t','m2d']", + ], + megatron_args=None, + checkpoint_format=None, +) + TEST_MODEL = os.environ.get("MODEL", "llama") -CONFIG_BASE_FAST_LLM = [ - "training.logs.interval=1", - "run.tensor_logs.save=True", - "run.tensor_logs.show=False", - "model.base_model.transformer.num_layers=2", - "model.base_model.transformer.hidden_size=256", - "model.base_model.transformer.num_attention_heads=8", - "model.base_model.transformer.init_method_std=0.022", - f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", - f"model.multi_stage.debug_param_init={_LOG_LEVEL}", - f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", - f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", - f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", - "model.multi_stage.debug_tensor_parallel=True", - "model.distributed.reproducible_init=True", - "model.distributed.timeout=10", - "training.train_iters=2", - "training.num_workers=0", - "training.timeout=30", - "batch.batch_size=8", - "batch.sequence_length=512", - "data.datasets.training.type=slice", - "data.datasets.training.end=0.969", - "data.datasets.training.dataset.type=memmap", - f"data.datasets.training.dataset.path={DATASET_PREFIX}", - "data.datasets.validation.type=slice", - "data.datasets.validation.begin=0.969", - "data.datasets.validation.end=0.999", - "data.datasets.validation.dataset.type=memmap", - f"data.datasets.validation.dataset.path={DATASET_PREFIX}", - "data.datasets.test.type=slice", - "data.datasets.test.begin=0.999", - "data.datasets.test.end=1", - "data.datasets.test.dataset.type=memmap", - f"data.datasets.test.dataset.path={DATASET_PREFIX}", - "optimizer.learning_rate.base=0.0001", -] -CONFIG_BASE_MEGATRON = [ - "--num-layers=2", - "--hidden-size=256", - "--num-attention-heads=8", - "--log-interval=1", - "--train-iters=2", - "--eval-iters=0", - "--hidden-dropout=0", - "--attention-dropout=0", - f"--debug_param_init={_LOG_LEVEL}", - f"--debug_layer_outputs={_LOG_LEVEL}", - f"--debug_layer_gradients={_LOG_LEVEL}", - 
f"--debug_all_param_gradients={_LOG_LEVEL}", - "--debug_param_update=0", - "--global-batch-size=8", - "--max-position-embeddings=512", - "--seq-length=512", - "--init-method-std=0.022", - "--lr=0.0001", - "--num-workers=0", - "--valid-num-workers=0", - "--tokenizer-type=NullTokenizer", - # Megatron messes with the vocab size, so we have to subtract 1. - f"--vocab-size={TEST_VOCAB_SIZE - 1}", - f"--data-path={DATASET_PREFIX}", - "--lr-decay-style=constant", - # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) - "--use-mcore-models", - # local implementation doesn't allow for RMS norm. - "--transformer-impl=transformer_engine", -] -CONFIG_SC1_FAST_LLM = CONFIG_BASE_FAST_LLM + ["model.base_model.max_position_embeddings=512"] -CONFIG_SC1_MEGATRON = CONFIG_BASE_MEGATRON + ["--group-query-attention"] -CONFIG_SC1_COMMON = CONFIG_SC1_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_GPT2_FAST_LLM = CONFIG_SC1_FAST_LLM + ["model.base_model.transformer.head_groups=8"] -CONFIG_GPT2_MEGATRON = CONFIG_BASE_MEGATRON -CONFIG_GPT2_COMMON = CONFIG_GPT2_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_SC2_FAST_LLM = CONFIG_BASE_FAST_LLM + [ - "model.base_model.transformer.head_groups=4", - "model.base_model.transformer.rotary.type=default", -] -CONFIG_SC2_MEGATRON = CONFIG_SC1_MEGATRON + [ - "--num-query-groups=4", - "--use-rotary-position-embeddings", - "--no-position-embedding", -] -CONFIG_SC2_COMMON = CONFIG_SC2_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_LLAMA_MEGATRON = CONFIG_SC2_MEGATRON + [ - "--swiglu", - "--disable-bias-linear", - "--normalization=RMSNorm", - "--ffn-hidden-size=1024", - "--untie-embeddings-and-output-weights", -] -CONFIG_LLAMA_FAST_LLM = CONFIG_SC2_FAST_LLM + [ - "model.base_model.transformer.gated=True", - "model.base_model.transformer.activation_type=silu", - "model.base_model.transformer.add_linear_biases=False", - "model.base_model.transformer.normalization.type=rms_norm", - "model.base_model.transformer.ffn_hidden_size=1024", - "model.base_model.tie_word_embeddings=False", -] -CONFIG_LLAMA_COMMON = CONFIG_LLAMA_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_LLAMA3_MEGATRON = None -CONFIG_LLAMA3_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.rotary.type=llama3", -] -CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_QWEN2_MEGATRON = None -CONFIG_QWEN2_FAST_LLM = CONFIG_SC2_FAST_LLM + [ - "model.base_model.transformer.gated=True", - "model.base_model.transformer.activation_type=silu", - "model.base_model.transformer.add_linear_biases=only_attn_qkv", - "model.base_model.transformer.normalization.type=rms_norm", - "model.base_model.transformer.ffn_hidden_size=1024", - "model.base_model.tie_word_embeddings=False", -] -CONFIG_QWEN2_COMMON = CONFIG_QWEN2_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_LLAMA_YARN_MEGATRON = None -CONFIG_LLAMA_YARN_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.rotary.type=yarn", -] -CONFIG_LLAMA_YARN_COMMON = CONFIG_LLAMA_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_MIXTRAL_MEGATRON = CONFIG_LLAMA_MEGATRON + [ - "--num-experts=4", - "--moe-router-topk=4", -] -CONFIG_MIXTRAL_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.transformer.num_experts=4", - "model.base_model.transformer.num_experts_per_token=4", -] -CONFIG_MIXTRAL_COMMON = CONFIG_MIXTRAL_FAST_LLM + 
["model.distributed.training_dtype=bf16"] -CONFIG_MIXTRAL_YARN_MEGATRON = None -CONFIG_MIXTRAL_YARN_FAST_LLM = CONFIG_MIXTRAL_FAST_LLM + [ - "model.base_model.transformer.rotary.type=yarn", -] -CONFIG_MIXTRAL_YARN_COMMON = CONFIG_MIXTRAL_YARN_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_LLAMA_MTP_MEGATRON = None -CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ - "model.base_model.prediction_heads=4", -] -CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout==['t','m']"] -CONFIG_LLAMBA_MEGATRON = CONFIG_LLAMA_MEGATRON + [] -CONFIG_LLAMBA_COMMON = CONFIG_LLAMBA_FAST_LLM -_CONFIGS = { - "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None), - "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), - "starcoder2": ( - "gpt", - CONFIG_SC2_FAST_LLM, - CONFIG_SC2_MEGATRON, - CONFIG_SC2_COMMON, - Starcoder2GPTHuggingfaceCheckpointFormat, - ), - "llama": ( - "gpt", - CONFIG_LLAMA_FAST_LLM, - CONFIG_LLAMA_MEGATRON, - CONFIG_LLAMA_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "llama3": ( - "gpt", - CONFIG_LLAMA3_FAST_LLM, - CONFIG_LLAMA3_MEGATRON, - CONFIG_LLAMA3_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "qwen2": ( - "gpt", - CONFIG_QWEN2_FAST_LLM, - CONFIG_QWEN2_MEGATRON, - CONFIG_QWEN2_COMMON, - Qwen2GPTHuggingfaceCheckpointFormat, - ), - "llama-yarn": ( - "gpt", - CONFIG_LLAMA_YARN_FAST_LLM, - CONFIG_LLAMA_YARN_MEGATRON, - CONFIG_LLAMA_YARN_COMMON, - LlamaGPTHuggingfaceCheckpointFormat, - ), - "mistral": ( - "gpt", - CONFIG_LLAMA_FAST_LLM, - CONFIG_LLAMA_MEGATRON, - CONFIG_LLAMA_COMMON, - MistralGPTHuggingfaceCheckpointFormat, - ), - "mixtral": ( - "gpt", - CONFIG_MIXTRAL_FAST_LLM, - CONFIG_MIXTRAL_MEGATRON, - CONFIG_MIXTRAL_COMMON, - MixtralGPTHuggingfaceCheckpointFormat, - ), - "llamba": ( - "hybrid_ssm", - CONFIG_LLAMBA_FAST_LLM, - CONFIG_LLAMBA_MEGATRON, - CONFIG_LLAMBA_COMMON, - LLambaHuggingfaceCheckpointFormat, - ), - "mixtral-yarn": ( - "gpt", - CONFIG_MIXTRAL_YARN_FAST_LLM, - CONFIG_MIXTRAL_YARN_MEGATRON, - CONFIG_MIXTRAL_YARN_COMMON, - MixtralGPTHuggingfaceCheckpointFormat, - ), - "llama-mtp": ( - "gpt", - CONFIG_LLAMA_MTP_FAST_LLM, - CONFIG_LLAMA_MTP_MEGATRON, - CONFIG_LLAMA_MTP_COMMON, - MTPLlamaGPTHuggingfaceCheckpointFormat, - ), -} - -TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL] +_MODEL_CONFIG = _MODEL_CONFIGS[TEST_MODEL] + + +TEST_MODEL_TYPE = _MODEL_CONFIG.model_type +CONFIG_COMMON = _MODEL_CONFIG.config_args +CONFIG_MEGATRON = _MODEL_CONFIG.megatron_args +HUGGINGFACE_CHECKPOINT_FORMAT = _MODEL_CONFIG.checkpoint_format From d41e0d5a66e6b79ae9d67cec6cf086325429d3d8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 16:02:15 -0400 Subject: [PATCH 30/69] misc --- tests/test_match_megatron.py | 3 +-- tests/utils/model_configs.py | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index 4f82d575..7645de9e 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -13,8 +13,7 @@ def test_megatron(run_test_script_for_all_models, model_testing_config): @pytest.mark.depends_on(on=["test_megatron[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.megatron) def test_match_megatron(run_test_script_for_all_models, model_testing_config): - if 
model_testing_config.megatron_args is None: - pytest.skip(f"Megatron does not support model {model_testing_config.name}") + assert model_testing_config.megatron_args is not None ignore_tensors = [ ".self_attn.query_key_value.", diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 8357bdbe..ee9ad5cb 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -43,9 +43,6 @@ class ModelTestingGroupAction(enum.StrEnum): not_implemented = "not_implemented" -SLOW_TESTING_GROUPS = {ModelTestingGroup.megatron, ModelTestingGroup.distributed} - - @dataclasses.dataclass(kw_only=True, frozen=True) class ModelTestingConfig: name: str = None From 59582c3a639002de1a53861b6e544f5a26ca05af Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 16:02:55 -0400 Subject: [PATCH 31/69] misc --- .github/workflows/ci.yaml | 7 +++---- .github/workflows/docs.yaml | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0bca2dd8..912ddaf5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,11 +27,10 @@ jobs: - name: Install dependencies run: | - pip install "torch>=2.7.0" + pip install "torch>=2.2.2" pip install pybind11 - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ - MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" + - name: Run tests run: pytest . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 632fa7b9..b755993c 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -29,7 +29,7 @@ jobs: restore-keys: | mkdocs-material- - run: | - pip install "torch>=2.7.0" + pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ From 8ecf81e4a6e69ebb94e1f7e02bd0c3f7d2633386 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 16:03:35 -0400 Subject: [PATCH 32/69] fix --- tests/test_match_megatron.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index f464dd06..9f861464 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -41,13 +41,5 @@ def test_match_megatron(run_test_script): "model.base_model.use_megatron_initialization=True", ], compare=f"test_{TEST_MODEL}_megatron", - config=CompareConfig( - ignore_tensors=[ - ".self_attn.query_key_value.", - ".self_attn.query.", - ".self_attn.key_value.", - ".self_attn.dense.", - ".mlp.layer_2.weight", - ] - ), + config=CompareConfig(ignore_tensors=ignore_tensors), ) From c5b29e257aa6067a96d76888878436d201d49a7e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 16:44:36 -0400 Subject: [PATCH 33/69] Revert "misc" This reverts commit 59582c3a639002de1a53861b6e544f5a26ca05af. 
--- .github/workflows/ci.yaml | 7 ++++--- .github/workflows/docs.yaml | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 912ddaf5..0bca2dd8 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,10 +27,11 @@ jobs: - name: Install dependencies run: | - pip install "torch>=2.2.2" + pip install "torch>=2.7.0" pip install pybind11 - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" - + FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ + MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" - name: Run tests run: pytest . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index b755993c..632fa7b9 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -29,7 +29,7 @@ jobs: restore-keys: | mkdocs-material- - run: | - pip install "torch>=2.2.2" + pip install "torch>=2.7.0" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ From edced8c829bef10e8b917195e4215a0f785ee5b6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 17:06:30 -0400 Subject: [PATCH 34/69] Cleanup tests --- tests/layers/test_lm_head.py | 49 +++--- tests/test_functional.py | 18 ++- tests/test_mtp.py | 209 -------------------------- tests/test_ssms.py | 282 ++--------------------------------- tests/utils/utils.py | 67 ++++----- 5 files changed, 72 insertions(+), 553 deletions(-) delete mode 100644 tests/test_mtp.py diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index ddb1521f..9d124d4d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -5,20 +5,14 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageConfig -from fast_llm.engine.multi_stage.stage import Stage from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.gpt.config import GPTBaseModelConfig -from fast_llm.models.gpt.model import GPTBaseModel +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda +from tests.utils.utils import get_base_model, get_stage, requires_cuda def _lm_head( @@ -100,13 +94,15 @@ def test_lm_head( config_dict, update_type=UpdateType.update, ) - distributed_config = DistributedConfig.from_dict(distributed_config_dict) - distributed = Distributed(distributed_config) - tensor_space = TensorSpace(distributed_config) - 
config.setup_tensor_space(tensor_space) - tensor_space.setup(distributed) - model = GPTBaseModel(config, distributed_config) - model.setup(distributed) + + model, distributed = get_base_model( + GPTModelConfig.from_dict( + { + "base_model": config, + "distributed": distributed_config_dict, + }, + ) + ) sequence_first = config.sequence_first or ( config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 @@ -114,9 +110,9 @@ def test_lm_head( input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( - distributed_config.optimization_dtype.torch + distributed.config.optimization_dtype.torch if config.transformer.full_precision_residual - else distributed_config.training_dtype.torch + else distributed.config.training_dtype.torch ), device=distributed.device, requires_grad=True, @@ -160,7 +156,7 @@ def test_lm_head( if config.tie_word_embeddings or config.prediction_heads > 1: logit_weight = ( torch.empty( - VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed_config.training_dtype.torch, device=distributed.device + VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.training_dtype.torch, device=distributed.device ) .normal_(config.transformer.init_method_std) .requires_grad_(True) @@ -174,18 +170,7 @@ def test_lm_head( head: LanguageModelHead = model[layer_index] Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, prediction_distance) - stage = Stage( - config=StageConfig(), - base_model=[head], - distributed_config=distributed_config, - begin=0, - end=1, - index=0, - ) - stage.setup(distributed=distributed) - stage.initialize_weights() - stage.restore_parameters() - stage.reset_gradients() + stage = get_stage([head], distributed) # Get reference outputs and grads if logit_weight is None: @@ -230,9 +215,9 @@ def test_lm_head( output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) - threshold = 1e-5 if distributed_config.training_dtype == DataType.float32 else 5e-3 + threshold = 1e-5 if distributed.config.training_dtype == DataType.float32 else 5e-3 min_threshold = ( - 1e-5 if distributed_config.training_dtype == DataType.float32 else 1e-4 + 1e-5 if distributed.config.training_dtype == DataType.float32 else 1e-4 ) * config.logits_scale_factor Assert.eq(losses.keys(), loss_keys) diff --git a/tests/test_functional.py b/tests/test_functional.py index 0689f4d8..9211259c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -57,9 +57,15 @@ def ref_packed_get_batch_logps( @pytest.mark.slow -@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) -@pytest.mark.parametrize("seq_length", [1024, 4096, 8192]) -@pytest.mark.parametrize("vocab_size", [1000, 2000, 8000]) +@pytest.mark.parametrize( + ("batch_size", "seq_length", "vocab_size"), + ( + (2, 32, 50), + (1, 32, 50), + (2, 100, 50), + (2, 32, 200), + ), +) def test_preference_logps(batch_size, seq_length, vocab_size): random.seed(0) torch.manual_seed(0) @@ -222,9 +228,9 @@ def test_dropless_mlp(): pytest.fail("Test fails, aborting to avoid breaking cuda", False) num_experts = 4 experts_per_token = 4 - tokens = 1024 - hidden_size = 2048 - ffn_hidden_size = 4096 + tokens = 256 + hidden_size = 512 + ffn_hidden_size = 1024 std = 1 / 64 input_ = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) router_weight = torch.normal(0, std, (num_experts, hidden_size), device="cuda") diff --git a/tests/test_mtp.py b/tests/test_mtp.py deleted file mode 100644 index 
1f01954e..00000000 --- a/tests/test_mtp.py +++ /dev/null @@ -1,209 +0,0 @@ -import typing - -import pytest -import torch - -from fast_llm.config import UpdateType -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.models.gpt.config import GPTBaseModelConfig -from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.utils import Assert -from tests.utils.utils import get_hybrid_config, materialize_meta_tensors, requires_cuda - -try: - from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 - from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel -except Exception: - MambaLayer, HybridSSMBaseModel, DiscreteMamba2 = ( - None, - None, - None, - ) - # Mamba not installed, skipping tests - - -run_hybrid_test = MambaLayer is not None and DiscreteMamba2 is not None and torch.cuda.is_available() - - -SEQUENCE_LENGTH = 200 -BATCH_SIZE = 4 -HIDDEN_SIZE = 256 -VOCAB_SIZE = 500 - - -@pytest.fixture -def distributed_config(): - return DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, - ) - - -@pytest.fixture -def distributed(distributed_config): - return Distributed(config=distributed_config) - - -@requires_cuda -@pytest.mark.parametrize( - "config_dict", - ( - {"prediction_heads": 1}, - {"prediction_heads": 2, "tie_word_embeddings": False}, - {"prediction_heads": 5, "tie_word_embeddings": False}, - ), -) -def test_transformer_mtp(config_dict: dict[str, typing.Any]): - config = GPTBaseModelConfig.from_dict( - { - "transformer": { - "hidden_size": HIDDEN_SIZE, - "num_layers": 2, - }, - "vocab_size": VOCAB_SIZE, - }, - config_dict, - update_type=UpdateType.update, - ) - distributed_config = DistributedConfig.from_dict({}) - distributed = Distributed(distributed_config) - model = GPTBaseModel(config, distributed_config) - model.setup(distributed) - materialize_meta_tensors(model, model._tensor_space) - model.to("cuda") - - sequence_first = config.sequence_first or ( - config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 - ) - target = torch.randint( - 0, - VOCAB_SIZE, - ( - (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) - if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) - ), - dtype=torch.int64, - device=distributed.device, - ) - input_ = torch.randint( - 0, - VOCAB_SIZE, - (SEQUENCE_LENGTH, BATCH_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH), - device=distributed.device, - ) - attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) - position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda", dtype=torch.int64) - kwargs = { - "position_ids": position_ids, - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.attention_mask: attention_mask, - TransformerKwargs.attention_mask_value: -100, - TransformerKwargs.grad_output: 1.0, - LanguageModelKwargs.labels: target, - } - if config.tie_word_embeddings: - 
kwargs[WORD_EMBEDDINGS_WEIGHT] = model.embedding.word_embeddings_weight - else: - kwargs[OUTPUT_WEIGHTS] = model.model_head.output_weights - losses = {LanguageModelLossNames.multi_token_prediction_loss(i): [] for i in range(model._config.prediction_heads)} - _ = model(input_, kwargs, losses=losses) - for loss_name, loss_values in losses.items(): - Assert.gt(len(loss_values), 0) - loss = sum( - [ - sum(losses[LanguageModelLossNames.multi_token_prediction_loss(i)]) - for i in range(model._config.prediction_heads) - ] - ) - loss.backward() - - -@pytest.mark.skip(reason="Too slow") -@requires_cuda -@pytest.mark.skipif(not run_hybrid_test, reason="No CUDA available or Mamba not installed") -@pytest.mark.parametrize( - ("hybrid_block_layout", "prediction_heads", "default_mtp_type"), - [ - ([SSMBlockType.mamba.value, SSMBlockType.transformer.value], 1, None), - ([SSMBlockType.transformer.value, SSMBlockType.mamba.value], 2, None), - ([SSMBlockType.mamba.value, SSMBlockType.transformer.value], 2, None), - ([SSMBlockType.transformer.value, SSMBlockType.mamba2_discrete.value], 3, None), - ([SSMBlockType.transformer.value, SSMBlockType.mamba2_discrete.value], 3, SSMBlockType.mamba.value), - ], -) -def test_hybrid_model_mtp(distributed_config, hybrid_block_layout, prediction_heads, default_mtp_type): - hybrid_config = get_hybrid_config( - hybrid_block_layout=hybrid_block_layout, prediction_heads=prediction_heads, default_mtp_type=default_mtp_type - ) - model = HybridSSMBaseModel(hybrid_config, distributed_config) - distributed = Distributed(distributed_config) - model.setup(distributed) - tensor_space = model._tensor_space - materialize_meta_tensors(model, tensor_space) - model.to("cuda") - - num_heads, num_mtp_blocks = 0, 0 - str_block_mapping = { - SSMBlockType.transformer: TransformerLayer, - SSMBlockType.mamba: MambaLayer, - SSMBlockType.mamba2_discrete: DiscreteMamba2, - } - mtp_block_type = default_mtp_type or hybrid_block_layout[-1] - for block in model.get_output_layers(): - if isinstance(block, LanguageModelHead): - num_heads += 1 - else: - block = getattr(block, "mixer", block) - Assert.custom( - lambda _: isinstance(block, str_block_mapping[mtp_block_type]), - f"Block {block} is not of type {str_block_mapping[mtp_block_type]}", - ) - num_mtp_blocks += 1 - Assert.eq(num_heads, prediction_heads) - Assert.eq(num_mtp_blocks, prediction_heads - 1) - - batch_size = 2 - seq_length = 32 - x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") - position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) - attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape - labels = torch.randint(0, 49152, (batch_size, seq_length + model._config.prediction_heads - 1), device="cuda") - losses = {LanguageModelLossNames.multi_token_prediction_loss(i): [] for i in range(model._config.prediction_heads)} - kwargs = { - "position_ids": position_ids, - TransformerKwargs.sequence_first: False, - TransformerKwargs.attention_mask: attention_mask, - TransformerKwargs.attention_mask_value: -100, - TransformerKwargs.grad_output: True, - LanguageModelKwargs.labels: labels, - } - - if model._config.tie_word_embeddings: - kwargs[WORD_EMBEDDINGS_WEIGHT] = model.embedding.word_embeddings_weight - else: - kwargs[OUTPUT_WEIGHTS] = model.model_head.output_weights - - output = model( - x, - kwargs, - losses=losses, - ) - loss = sum( - [ - sum(losses[LanguageModelLossNames.multi_token_prediction_loss(i)]) - for i in 
range(model._config.prediction_heads) - ] - ) - loss.backward() diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 9e748544..52b51c8a 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -1,75 +1,31 @@ import pathlib -from functools import partial import pytest import torch from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames -from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat -from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat -from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel -from tests.utils.utils import get_hybrid_config, materialize_meta_tensors +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.model import HybridSSMModel try: from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel except ImportError: LMHeadModel = None -run_test = MambaLayer is not None and torch.cuda.is_available() - - -@pytest.fixture -def distributed_config(): - return DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, - ) - - -@pytest.fixture -def distributed(distributed_config): - return Distributed(config=distributed_config) - - -def get_hf_llamba_out(input_ids, path, format): - if format == LLambaHuggingfaceCheckpointFormat: - from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel - elif format == LlamaGPTHuggingfaceCheckpointFormat: - from transformers import LlamaForCausalLM as LMHeadModel - else: - raise ValueError(f"Invalid format: {format}") - - model = LMHeadModel.from_pretrained(path, strict=True).to("cuda") - parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) - print(f"Parameter sum: {parameter_sum}") - output = model(input_ids) - del model - torch.cuda.empty_cache() - return output, parameter_sum - @pytest.mark.slow @pytest.mark.skipif( - not run_test or LMHeadModel is None, - reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", + LMHeadModel is None, + reason=f"cartesia_pytorch.Llamba not installed", ) -def test_load_from_llamba_checkpoint(distributed_config): +def test_load_from_llamba_checkpoint(): """ Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. 
""" @@ -81,8 +37,12 @@ def test_load_from_llamba_checkpoint(distributed_config): format = LLambaHuggingfaceCheckpointFormat x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) - hf_logits = hf_logits["logits"].cpu() + + hf_model = LMHeadModel.from_pretrained(path, strict=True).to("cuda") + parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) + hf_logits = hf_model(x)["logits"].cpu() + del hf_model + torch.cuda.empty_cache() # Create checkpoint load config checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) @@ -100,7 +60,7 @@ def test_load_from_llamba_checkpoint(distributed_config): schedule_config = ScheduleConfig() with NoAutoValidate(): batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - batch_config.setup(distributed_config) + batch_config.setup(DistributedConfig.from_dict({})) batch_config.validate() schedule_runner = ScheduleRunner( config=schedule_config, @@ -122,221 +82,7 @@ def test_load_from_llamba_checkpoint(distributed_config): } input_data = [(x, common_kwargs)] - losses, success, metrics = schedule_runner.run_step( - iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True - ) + schedule_runner.run_step(iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True) logits = input_data[0][1]["logits"].cpu() assert torch.allclose(logits, hf_logits, atol=1e-2) - - -def get_hf_apriel_hybrid_out(input_ids, path, format): - from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM - - model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") - parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) - print(f"Parameter sum: {parameter_sum}") - output = model(input_ids) - del model - torch.cuda.empty_cache() - return output, parameter_sum - - -@pytest.mark.slow -@pytest.mark.skipif( - not run_test - and not pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug").exists(), - reason=f"Skipping because no CUDA available or Mamba not installed", -) -def test_load_from_hybridssm_checkpoint(distributed_config): - """ - Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. 
- """ - vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json - batch_size = 2 - seq_length = 32 - - path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") - format = AprielSSMHHybridHuggingfaceCheckpointFormat - - x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) - hf_logits = hf_logits["logits"].cpu() - - # Create checkpoint load config - checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) - # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) - param_sum = 0 - for stage in model.stages: - for fsdp in stage.fsdps: - if hasattr(fsdp, "_weight_shard"): - param_sum += torch.sum(fsdp._weight_shard).item() - assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - - -@pytest.mark.extra_slow -@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -@pytest.mark.parametrize( - "hybrid_block_layout,LAYER_CLS", - [ - ([SSMBlockType.mamba, SSMBlockType.transformer], MambaLayer), - ([SSMBlockType.mamba2_discrete, SSMBlockType.transformer], DiscreteMamba2), - ], - ids=["mamba", "discrete_mamba2"], -) -def test_mamba_layer(distributed_config, distributed, hybrid_block_layout, LAYER_CLS): - hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) - tensor_space = TensorSpace(distributed_config=distributed_config) - hybrid_config.setup_tensor_space(tensor_space) - layer = LAYER_CLS(hybrid_config.ssm, layer_idx=0, tensor_space=tensor_space) - tensor_space.setup(distributed) - materialize_meta_tensors(layer, tensor_space) - layer.to(distributed.device) - - batch_size = 2 - seq_length = 32 - hidden_size = hybrid_config.transformer.hidden_size - x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) - - # Run forward pass - output, _ = layer(x, {}) - - loss = output.sum() - loss.backward() - # Basic shape checkss - assert output.shape == x.shape - assert not torch.isnan(output).any() - assert not torch.isinf(output).any() - - -@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -def test_mamba_block(distributed_config, distributed): - hybrid_config = get_hybrid_config(hybrid_block_layout=["m", "t"]) - tensor_space = TensorSpace(distributed_config=distributed_config) - tensor_space.setup(distributed) - hybrid_config.setup_tensor_space(tensor_space) - layer_idx = 0 - - mixer_cls = partial(MambaLayer, layer_idx=layer_idx) - block = LlambaBlock( - hybrid_config.transformer, - hybrid_config.ssm, - mixer_cls=mixer_cls, - tensor_space=tensor_space, - layer_index=layer_idx, - ) - - materialize_meta_tensors(block, tensor_space) - block.to("cuda") - - batch_size = 2 - seq_length = 32 - hidden_size = hybrid_config.transformer.hidden_size - x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) - - hidden_states = block(x, {}) - loss = hidden_states.sum() - loss.backward() - - assert hidden_states.shape == x.shape - assert not torch.isnan(hidden_states).any() - assert not torch.isinf(hidden_states).any() - - -@pytest.mark.slow -@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -@pytest.mark.parametrize( - ("hybrid_block_layout"), - [ - (["m", "t"]), - (["m2d", "t"]), - ], - ids=["mamba", "discrete_mamba2"], -) -def test_hybrid_model_train_with_fast_mode(distributed_config, 
hybrid_block_layout): - hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) - model = HybridSSMBaseModel(hybrid_config, distributed_config) - distributed = Distributed(distributed_config) - model.setup(distributed) - tensor_space = model._tensor_space - materialize_meta_tensors(model, tensor_space) - model.to("cuda") - - batch_size = 2 - seq_length = 32 - x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") - position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) - attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape - labels = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") - losses = {LanguageModelLossNames.language_model_loss: []} - output = model( - x, - { - "position_ids": position_ids, - TransformerKwargs.sequence_first: False, - TransformerKwargs.attention_mask: attention_mask, - TransformerKwargs.attention_mask_value: -100, - TransformerKwargs.grad_output: True, - LanguageModelKwargs.labels: labels, - }, - losses=losses, - ) - loss = sum(losses[LanguageModelLossNames.language_model_loss]) - loss.backward() - - -# TODO: added this when inference enabled -# No inference for now -# @dataclass -# class InferenceParams: -# max_seqlen: int -# max_batch_size: int -# sequence_len_offset: int = 0 -# key_value_memory_dict: dict = None - -# def __post_init__(self): -# if self.key_value_memory_dict is None: -# self.key_value_memory_dict = {} - - -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") -# def test_hybrid_model_inference(distributed_config, hybrid_config): -# hybrid_config.ssm.use_fast_path = False -# model = HybridSSMBaseModel(hybrid_config, distributed_config) -# distributed = Distributed(distributed_config) -# model.setup(distributed) -# tensor_space = model._tensor_space -# materialize_meta_tensors(model, tensor_space) -# model.to("cuda") -# # print(model) - -# batch_size = 2 -# seq_length = 32 -# x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") -# position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) -# attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape -# labels = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") -# max_new_tokens = 10 - -# inference_params = InferenceParams( -# max_seqlen=len(x[0]) + max_new_tokens, max_batch_size=x.shape[0], sequence_len_offset=0 -# ) -# losses = {LanguageModelLossNames.language_model_loss: []} - -# output = model( -# x, -# { -# "position_ids": position_ids, -# TransformerKwargs.sequence_first: True, -# TransformerKwargs.attention_mask: attention_mask, -# TransformerKwargs.attention_mask_value: -100, -# TransformerKwargs.grad_output: True, -# LanguageModelKwargs.labels: labels, -# "inference_params": inference_params, -# }, -# losses=losses, -# ) - -if __name__ == "__main__": - pytest.main(["-s", __file__]) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index f37c1cb2..11b7e403 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -4,49 +4,40 @@ import pytest import torch -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig +from fast_llm.engine.base_model.base_model import BaseModel, Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.distributed import Distributed 
+from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig +from fast_llm.engine.multi_stage.stage import Stage TEST_RESULTS_PATH = pathlib.Path(os.environ.get("TEST_RESULTS_PATH", "/tmp/fast_llm_tests")).resolve() requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model +def get_base_model(config: FastLLMModelConfig): + # Create a base model (and distributed). + # Using a full model config so we have the model type and distributed config in the same argument. + distributed = Distributed(config.distributed) + tensor_space = TensorSpace(config.distributed) + config.base_model.setup_tensor_space(tensor_space) + tensor_space.setup(distributed) + base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) + base_model.setup(distributed) + return base_model, distributed -def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): - config = HybridSSMBaseModelConfig( - transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), - ssm=SSMConfig(), - hybrid_block_layout=hybrid_block_layout, - prediction_heads=prediction_heads, - default_mtp_type=default_mtp_type, - init_method_std_embed=0.02, - init_method_min_embed=-0.02, - init_method_max_embed=0.02, - use_position_embeddings=True, - tie_word_embeddings=False, +def get_stage(base_model: BaseModel | list[Layer], distributed: Distributed): + # Create a fast-llm stage which allocates and initializes meta tensors correctly. 
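+    # This goes through the regular `Stage` setup path (replacing the removed `materialize_meta_tensors`
+    # helper): weights are allocated and initialized on the distributed device and gradient buffers are
+    # reset, so tests can call the stage's forward/backward directly.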
+ stage = Stage( + config=StageConfig(), + base_model=base_model, + distributed_config=distributed.config, + begin=0, + end=1, + index=0, ) - return config + stage.setup(distributed=distributed) + stage.initialize_weights() + stage.restore_parameters() + stage.reset_gradients() + return stage From 58677d291f37d4625307b80b323e19264b53957f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 17:31:51 -0400 Subject: [PATCH 35/69] fix --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba_layer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index ecf0b29d..31e81e99 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -17,7 +17,7 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined as _mamba_chunk_scan_combined # noqa _mamba_available = True -except ImportError: +except (ImportError, RuntimeError): _mamba_available = False @@ -25,7 +25,7 @@ from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa _causal_conv1d_available = True -except ImportError: +except (ImportError, RuntimeError): _causal_conv1d_available = False diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7fd43789..7c824d23 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -14,7 +14,7 @@ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa _mamba_available = True -except ImportError: +except (ImportError, RuntimeError): _mamba_available = False """ From e125fa9ff06f9ae148af41e14cae1c58717c88a7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 13 Jun 2025 17:37:01 -0400 Subject: [PATCH 36/69] move to directory --- tests/models/__init__.py | 0 tests/{ => models}/test_checkpoint.py | 0 .../{test_gpt_generate_and_forward.py => models/test_generate.py} | 0 tests/{ => models}/test_match_megatron.py | 0 tests/{ => models}/test_mb.py | 0 tests/{ => models}/test_mb_seq_first.py | 0 tests/{ => models}/test_ms.py | 0 tests/{ => models}/test_seq_first.py | 0 tests/{ => models}/test_simple.py | 0 9 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/models/__init__.py rename tests/{ => models}/test_checkpoint.py (100%) rename tests/{test_gpt_generate_and_forward.py => models/test_generate.py} (100%) rename tests/{ => models}/test_match_megatron.py (100%) rename tests/{ => models}/test_mb.py (100%) rename tests/{ => models}/test_mb_seq_first.py (100%) rename tests/{ => models}/test_ms.py (100%) rename tests/{ => models}/test_seq_first.py (100%) rename tests/{ => models}/test_simple.py (100%) diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_checkpoint.py b/tests/models/test_checkpoint.py similarity index 100% rename from tests/test_checkpoint.py rename to tests/models/test_checkpoint.py diff --git a/tests/test_gpt_generate_and_forward.py b/tests/models/test_generate.py similarity index 100% rename from tests/test_gpt_generate_and_forward.py rename to tests/models/test_generate.py diff --git a/tests/test_match_megatron.py b/tests/models/test_match_megatron.py similarity index 100% rename from tests/test_match_megatron.py rename to tests/models/test_match_megatron.py diff --git a/tests/test_mb.py b/tests/models/test_mb.py similarity index 100% rename from tests/test_mb.py rename to 
tests/models/test_mb.py diff --git a/tests/test_mb_seq_first.py b/tests/models/test_mb_seq_first.py similarity index 100% rename from tests/test_mb_seq_first.py rename to tests/models/test_mb_seq_first.py diff --git a/tests/test_ms.py b/tests/models/test_ms.py similarity index 100% rename from tests/test_ms.py rename to tests/models/test_ms.py diff --git a/tests/test_seq_first.py b/tests/models/test_seq_first.py similarity index 100% rename from tests/test_seq_first.py rename to tests/models/test_seq_first.py diff --git a/tests/test_simple.py b/tests/models/test_simple.py similarity index 100% rename from tests/test_simple.py rename to tests/models/test_simple.py From d164f25718878aae5c4724985513912356310f12 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 16 Jun 2025 17:05:07 -0400 Subject: [PATCH 37/69] fixes --- setup.cfg | 2 ++ tests/test_match_megatron.py | 1 + 2 files changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index fff7503a..b3b1df03 100644 --- a/setup.cfg +++ b/setup.cfg @@ -57,6 +57,8 @@ DEV = pytest-xdist>=3.7.0 # Somehow needed for Megatron to work with base image 24.11 setuptools>=80.9.0 + # dependency manager needs it. + colorama>=0.4.6 # Required for building the documentation DOCS = diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py index 9f861464..5e7f3d37 100644 --- a/tests/test_match_megatron.py +++ b/tests/test_match_megatron.py @@ -16,6 +16,7 @@ def test_megatron(run_test_script): ] +@pytest.mark.slow @pytest.mark.depends_on(on=["test_megatron"]) def test_match_megatron(run_test_script): if CONFIG_MEGATRON is None: From 917912789f923290b0d6f9b0dec03ae86daf662e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 16 Jun 2025 17:38:08 -0400 Subject: [PATCH 38/69] fix --- tests/utils/model_configs.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 19135815..481ec611 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -301,13 +301,21 @@ def _update_and_add_testing_config( ) _update_and_add_testing_config( - # Tests yarn-style rotary embeddings. + # Tests diffusion llama converter. "llama_yarn", "diffusion_llama", extra_args=[], # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, checkpoint_format=DiffusionLlamaGPTHuggingfaceCheckpointFormat, + # TODO: Add back generate as `normal` when stable. + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( @@ -347,13 +355,21 @@ def _update_and_add_testing_config( ) _update_and_add_testing_config( - # Diffusion dream converter. + # Tests diffusion dream converter. "qwen2", "dream", extra_args=[], # Megatron doesn't support per sub layer biases. megatron_args=None, checkpoint_format=DiffusionDreamGPTHuggingfaceCheckpointFormat, + # TODO: Add back generate as `normal` when stable. 
+ groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, ) _update_and_add_testing_config( From d97e4c10c209da96339152b446f3f1f7b9305566 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 16 Jun 2025 18:21:14 -0400 Subject: [PATCH 39/69] fix --- tests/models/test_checkpoint.py | 6 +++--- tests/utils/model_configs.py | 19 +++++++++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 5132ba4f..9cf60e91 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -22,7 +22,7 @@ _WEIGHT_SHARD_SAVE_NAME = f"{ShardName.weights}_shard" -@pytest.mark.model_testing_group(ModelTestingGroup.basic) +@pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_config): # A baseline config (single-gpu, bf16, flash-attn). run_test_script_for_all_models( @@ -56,7 +56,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.basic) +@pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_resume(run_test_script_for_all_models): # Resume from iteration=1 and compare outputs with the baseline run. run_test_script_for_all_models( @@ -72,7 +72,7 @@ def test_resume(run_test_script_for_all_models): @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.basic) +@pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_resume_frozen(run_test_script_for_all_models): # Resume with frozen mlp. No comparison. run_test_script_for_all_models( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 481ec611..3d654a0f 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -26,10 +26,11 @@ class ModelTestingGroup(enum.StrEnum): basic = "basic" - megatron = "megatron" - distributed = "distributed" + checkpoint = "checkpoint" convert = "convert" generate = "generate" + megatron = "megatron" + distributed = "distributed" class ModelTestingGroupAction(enum.StrEnum): @@ -186,6 +187,7 @@ def _update_and_add_testing_config( checkpoint_format=None, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.main, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.main, ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.normal, @@ -202,6 +204,7 @@ def _update_and_add_testing_config( checkpoint_format=None, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, @@ -229,6 +232,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. 
groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, @@ -259,6 +263,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.main, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.main, ModelTestingGroup.convert: ModelTestingGroupAction.main, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.normal, @@ -276,6 +281,7 @@ def _update_and_add_testing_config( checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -293,6 +299,7 @@ def _update_and_add_testing_config( checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -311,6 +318,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -329,6 +337,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -347,6 +356,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -365,6 +375,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -383,6 +394,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. 
groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, @@ -406,6 +418,7 @@ def _update_and_add_testing_config( # TODO: New base image broke mixtral groups={ ModelTestingGroup.basic: ModelTestingGroupAction.broken, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.broken, ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.broken, @@ -430,6 +443,7 @@ def _update_and_add_testing_config( # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.broken, # TODO: Fix and bring back to `testing_groups` ModelTestingGroup.generate: ModelTestingGroupAction.broken, @@ -452,6 +466,7 @@ def _update_and_add_testing_config( checkpoint_format=None, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, From c95e8ebee8f6afc61450cda6e9644fc76ad5772f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 17 Jun 2025 16:12:23 -0400 Subject: [PATCH 40/69] Fix dropless mlp --- fast_llm/functional/config.py | 2 +- fast_llm/functional/triton/sparse_copy.py | 5 + fast_llm/functional/triton/sparse_linear.py | 27 +-- setup.cfg | 4 +- tests/conftest.py | 3 + tests/functional/__init__.py | 0 tests/{ => functional}/test_functional.py | 4 +- tests/functional/test_sparse_matmul.py | 154 ++++++++++++++++++ tests/{ => functional}/test_triton_kernels.py | 0 tests/utils/model_configs.py | 12 +- 10 files changed, 190 insertions(+), 21 deletions(-) create mode 100644 tests/functional/__init__.py rename tests/{ => functional}/test_functional.py (98%) create mode 100644 tests/functional/test_sparse_matmul.py rename tests/{ => functional}/test_triton_kernels.py (100%) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 22f23174..0b7b14ab 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -15,7 +15,7 @@ class TritonConfig: MAX_BLOCK_SIZE_BYTES = 65536 -class MLPRecomputeLevel(str, enum.Enum): +class MLPRecomputeLevel(enum.StrEnum): none = "none" activation = "activation" activation_and_input = "activation_and_input" diff --git a/fast_llm/functional/triton/sparse_copy.py b/fast_llm/functional/triton/sparse_copy.py index 258a2578..7c803689 100644 --- a/fast_llm/functional/triton/sparse_copy.py +++ b/fast_llm/functional/triton/sparse_copy.py @@ -11,10 +11,15 @@ @dataclasses.dataclass() class SparseMap: sparse_rows: torch.Tensor + # The end row for each expert, including padding. `expert_ends[i] = expert_begins[i] + padded_tokens_per_expert[i]` expert_ends: torch.Tensor + # The end row for each expert, excluding padding. `expert_pad_begins[i] = expert_begins[i] + unpadded_tokens_per_expert[i]` expert_pad_begins: torch.Tensor + # The number of rows un the dense tensor, i.e., the number of tokens. 
num_rows_dense: int + # The number of sparse rows, including padding. `num_rows = expert_ends[-1]` num_rows: int + # The number of sparse rows, excluding padding. `num_rows_unpadded = num_rows_dense * num_experts_per_token` num_rows_unpadded: int num_experts: int num_experts_per_token: int diff --git a/fast_llm/functional/triton/sparse_linear.py b/fast_llm/functional/triton/sparse_linear.py index 9a086494..ae46655e 100644 --- a/fast_llm/functional/triton/sparse_linear.py +++ b/fast_llm/functional/triton/sparse_linear.py @@ -1,10 +1,12 @@ +import os + import torch from fast_llm.functional.triton import TritonConfig, tl, tl_constexpr, triton, triton_autotune, triton_jit from fast_llm.functional.triton.sparse_copy import SparseMap from fast_llm.utils import Assert, div -autotune_configs = [ +autotune_configs = ( TritonConfig( {"block_size_row": 128, "block_size_col": 256, "block_size_inner": 64, "group_size_row": 8}, num_stages=3, @@ -45,7 +47,10 @@ num_stages=5, num_warps=2, ), -] +) + +if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): + autotune_configs = (autotune_configs[2],) @triton_autotune( @@ -255,13 +260,13 @@ def output_sparse_matmul_kernel( def output_sparse_matmul( lhs: torch.Tensor, rhs: torch.Tensor, - sparse_map: SparseMap | None, + sparse_map: SparseMap | None = None, out: torch.Tensor | None = None, accumulate: bool = False, ) -> torch.Tensor: """ - Output-sparse matrix multiplication with a sparse column dimension, - i.e., with a mapping row_index -> sparse_index (obtained from expert_ends). + Output-sparse matrix multiplication with a sparse column dimension + and a mapping row_index -> sparse_index (obtained from expert_ends). Ex.: MLP layer 1 forward (Y = X x W1^T), MLP layer 2 input grad (gY = gZ x W2). Formula: out[i, js] = sum_k(lhs[i, k] * rhs[k, jd]), where jd = js + col_sparse_dim * sparse_index[i] sparse_index[i] = sum(expert_ends <= i) @@ -381,13 +386,13 @@ def input_inner_sparse_matmul_kernel( def input_inner_sparse_matmul( lhs: torch.Tensor, rhs: torch.Tensor, - sparse_map: SparseMap | None, + sparse_map: SparseMap | None = None, out: torch.Tensor | None = None, accumulate: bool = False, ) -> torch.Tensor: """ - Left-input-sparse matrix multiplication with a sparse inner dimension, - i.e., with a mapping row_index -> sparse_index (obtained from expert_ends). + Left-input-sparse matrix multiplication with a sparse inner dimension + and a mapping row_index -> sparse_index (obtained from expert_ends). Ex.: MLP layer 2 forward (Z = Y x W2^T), MLP layer 1 input grad (gX = gY x W1). Formula: out[i, j] = sum_ks(lhs[i, ks] * rhs[kd, j]), where kd = ks + inner_sparse_dim * sparse_index[i] sparse_index[i] = sum(expert_ends <= i) @@ -511,13 +516,13 @@ def input_row_sparse_matmul_kernel( def input_row_sparse_matmul( lhs: torch.Tensor, rhs: torch.Tensor, - sparse_map: SparseMap | None, + sparse_map: SparseMap | None = None, out: torch.Tensor | None = None, accumulate: bool = False, ) -> torch.Tensor: """ - Left-input-sparse matrix multiplication with a sparse row dimension, - i.e., with a mapping inner_index -> sparse_index. + Left-input-sparse matrix multiplication with a sparse row dimension + and a mapping inner_index -> sparse_index. Ex.: MLP layer 1 weight grad (gW1 = gY^T x X), MLP layer 2 weight grad (gW2^T = Y^T x gZ). 
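# The row-to-expert mapping used by the formulas in these docstrings, spelled out in plain torch
# (an illustrative reference only, not the kernel implementation):
import torch

def sparse_index(expert_ends: torch.Tensor, row: int) -> int:
    # Number of experts whose padded block ends at or before `row`:
    # rows belonging to expert e satisfy expert_ends[e - 1] <= row < expert_ends[e].
    return int((expert_ends <= row).sum())

def expanded_column(js: int, row: int, expert_ends: torch.Tensor, col_sparse_dim: int) -> int:
    # Maps sparse column js of out[row] to the dense column jd of the expanded (per-expert) weight.
    return js + col_sparse_dim * sparse_index(expert_ends, row)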
Formula: out[id, j] = sum_ks(lhs[is, ks] * rhs[ks, j]), where sparse_begin[sparse_index[id]] <= ks < sparse_end[sparse_index[id]], diff --git a/setup.cfg b/setup.cfg index b3b1df03..3b79a1d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,8 @@ CORE = safetensors>=0.5.3 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.3 + # Dropless MLP is broken with triton 3.3.0 and 3.3.1, probably because of a bug in triton. TODO: Fix + triton==3.2.0 # Small packages required for some optional features and tools. @@ -57,7 +59,7 @@ DEV = pytest-xdist>=3.7.0 # Somehow needed for Megatron to work with base image 24.11 setuptools>=80.9.0 - # dependency manager needs it. + # Dependency manager needs colorama to show colors. colorama>=0.4.6 # Required for building the documentation diff --git a/tests/conftest.py b/tests/conftest.py index 0d25fc5a..11757176 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,6 +113,9 @@ def pytest_configure(config): rendezvous_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id + 1, ) + # Skip slow autotune for tests. The default config has the highest block size, so this shouldn't hide any bug. + os.environ["FAST_LLM_SKIP_TRITON_AUTOTUNE"] = "TRUE" + @pytest.hookimpl(trylast=True) def pytest_collection_modifyitems(config, items: list[pytest.Function]): diff --git a/tests/functional/__init__.py b/tests/functional/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_functional.py b/tests/functional/test_functional.py similarity index 98% rename from tests/test_functional.py rename to tests/functional/test_functional.py index 9211259c..3ddd5d4f 100644 --- a/tests/test_functional.py +++ b/tests/functional/test_functional.py @@ -224,8 +224,6 @@ def test_mlp_recomputation(gated, activation_type): @pytest.mark.slow @requires_cuda def test_dropless_mlp(): - # TODO: Fix dropless MOE - pytest.fail("Test fails, aborting to avoid breaking cuda", False) num_experts = 4 experts_per_token = 4 tokens = 256 @@ -273,7 +271,7 @@ def test_dropless_mlp(): sparse_map = get_sparse_map(top_experts, num_experts) for i, recompute_level in enumerate(MLPRecomputeLevel): - print(recompute_level.value) # noqa + print("recompute_level", recompute_level) # noqa input_.grad = None scores.grad = None for param in params: diff --git a/tests/functional/test_sparse_matmul.py b/tests/functional/test_sparse_matmul.py new file mode 100644 index 00000000..899dad96 --- /dev/null +++ b/tests/functional/test_sparse_matmul.py @@ -0,0 +1,154 @@ +import dataclasses +import functools + +import pytest +import torch + +from fast_llm.functional.triton.sparse_copy import SparseMap +from fast_llm.functional.triton.sparse_linear import ( + dense_matmul, + input_inner_sparse_matmul, + input_row_sparse_matmul, + output_sparse_matmul, +) +from fast_llm.utils import Assert +from tests.utils.utils import requires_cuda + + +@dataclasses.dataclass +class _SparseTestData: + dense_dim: int + sparse_dim: int + expert_ends: tuple[int, ...] + tokens_per_expert: tuple[int, ...] 
+ std: float = 0.125 + + @functools.cached_property + def expert_begins(self) -> tuple[int, ...]: + return (0,) + self.expert_ends[:-1] + + @functools.cached_property + def expert_pad_begins(self) -> tuple[int, ...]: + return tuple( + expert_begin + expert_tokens + for expert_begin, expert_tokens in zip(self.expert_begins, self.tokens_per_expert, strict=True) + ) + + @functools.cached_property + def token_dim(self) -> int: + return self.expert_ends[-1] + + @property + def sparse_dim_expanded(self) -> int: + return self.sparse_dim * self.num_experts + + @functools.cached_property + def num_experts(self) -> int: + return len(self.expert_begins) + + @functools.cached_property + def sparse_map(self) -> SparseMap: + return SparseMap( + num_experts=self.num_experts, + expert_ends=torch.tensor(self.expert_ends, device="cuda"), + expert_pad_begins=torch.tensor(self.expert_pad_begins, device="cuda"), + num_rows=self.expert_ends[-1], + # Not needed + sparse_rows=None, + num_rows_dense=None, + num_rows_unpadded=None, + num_experts_per_token=None, + ) + + def normal(self, dim_0: int, dim_1: int) -> torch.Tensor: + return torch.normal(0, self.std, (dim_0, dim_1), device="cuda") + + +_SPARSE_TEST_DATAS = ( + _SparseTestData( + dense_dim=384, + sparse_dim=256, + expert_ends=(128, 384, 512), + tokens_per_expert=(78, 256, 54), + ), + _SparseTestData( + dense_dim=256, + sparse_dim=512, + expert_ends=(128, 256, 256, 384), + tokens_per_expert=(52, 125, 0, 97), + ), +) + + +@requires_cuda +@pytest.mark.slow +@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) +def test_dense_matmul(sparse_test_data): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) + rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim) + + output = dense_matmul(lhs, rhs) + output_ref = torch.matmul(lhs, rhs) + Assert.rms_close(output, output_ref, 1e-3) + + +@requires_cuda +@pytest.mark.slow +@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) +def test_output_sparse_matmul(sparse_test_data): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) + rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded) + + # Randomly initialize the output to ensure padded values have no effect. + out = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim) + output = output_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map, out) + + output_ref = torch.zeros_like(output) + for i in range(sparse_test_data.num_experts): + # Padded tokens are treated like regular ones. + output_ref[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]] = torch.matmul( + lhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]], + rhs[:, i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim], + ) + + Assert.rms_close(output, output_ref, 1e-3) + + +@requires_cuda +@pytest.mark.slow +@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) +def test_input_inner_sparse_matmul(sparse_test_data): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim) + rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim) + + output = input_inner_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map) + + output_ref = torch.zeros_like(output) + for i in range(sparse_test_data.num_experts): + # Padded tokens are treated like regular ones. 
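# Concrete values for the first _SparseTestData entry above, as derived by its cached properties:
#   expert_ends       = (128, 384, 512)
#   expert_begins     = (0, 128, 384)
#   tokens_per_expert = (78, 256, 54)
#   expert_pad_begins = (78, 384, 438)   # begin + unpadded tokens; rows past these are padding
#   token_dim         = 512              # number of (padded) sparse rows in the test tensors
# For example, expert 2 owns rows 384:512 of the sparse operands, of which rows 438:512 are padding;
# this is the slice arithmetic the reference loops in these tests reproduce with plain torch.matmul.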
+ output_ref[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]] = torch.matmul( + lhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_ends[i]], + rhs[i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim], + ) + + Assert.rms_close(output, output_ref, 1e-3) + + +@requires_cuda +@pytest.mark.slow +@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) +def test_input_row_sparse_matmul(sparse_test_data): + lhs = sparse_test_data.normal(sparse_test_data.sparse_dim, sparse_test_data.token_dim) + rhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) + + output = input_row_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map) + + output_ref = torch.zeros_like(output) + for i in range(sparse_test_data.num_experts): + # Padded tokens are excluded from the sum. + output_ref[i * sparse_test_data.sparse_dim : (i + 1) * sparse_test_data.sparse_dim] = torch.matmul( + lhs[:, sparse_test_data.expert_begins[i] : sparse_test_data.expert_pad_begins[i]], + rhs[sparse_test_data.expert_begins[i] : sparse_test_data.expert_pad_begins[i]], + ) + + Assert.rms_close(output, output_ref, 1e-3) diff --git a/tests/test_triton_kernels.py b/tests/functional/test_triton_kernels.py similarity index 100% rename from tests/test_triton_kernels.py rename to tests/functional/test_triton_kernels.py diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3d654a0f..4c225422 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -315,11 +315,12 @@ def _update_and_add_testing_config( # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, checkpoint_format=DiffusionLlamaGPTHuggingfaceCheckpointFormat, + # TODO: Conversion is broken. # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, @@ -372,11 +373,12 @@ def _update_and_add_testing_config( # Megatron doesn't support per sub layer biases. megatron_args=None, checkpoint_format=DiffusionDreamGPTHuggingfaceCheckpointFormat, + # TODO: Conversion is broken. # TODO: Add back generate as `normal` when stable. 
groups={ ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, @@ -489,13 +491,13 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo for group in groups: action = model_config.groups[group] if action == ModelTestingGroupAction.main: - return True + pass elif action == ModelTestingGroupAction.normal and not skip_slow: - return True + pass elif ( action in (ModelTestingGroupAction.broken, ModelTestingGroupAction.unimportant) and not skip_extra_slow ): - return True + pass elif show_skipped: item.add_marker( pytest.mark.skip(reason=f"Skipping testing group {group} for model {model_testing_config}.") From 468ed7eb04446fdbd7ab3beb79a1f75e321a1b01 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 17 Jun 2025 16:42:46 -0400 Subject: [PATCH 41/69] fix --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 3b79a1d0..b1e44e81 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,8 +24,8 @@ CORE = safetensors>=0.5.3 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.3 - # Dropless MLP is broken with triton 3.3.0 and 3.3.1, probably because of a bug in triton. TODO: Fix - triton==3.2.0 + # Dropless MLP is broken with triton 3.2.0, 3.3.0 and 3.3.1. TODO: Remove once a working triton version is released. + triton==3.1.0 # Small packages required for some optional features and tools. From eb734bd5b880ee4e383fa2a9a88f6f262201f028 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 17 Jun 2025 16:55:48 -0400 Subject: [PATCH 42/69] fix --- tests/utils/model_configs.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3d654a0f..4c225422 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -315,11 +315,12 @@ def _update_and_add_testing_config( # Megatron doesn't support Yarn-style Rotary Embeddings megatron_args=None, checkpoint_format=DiffusionLlamaGPTHuggingfaceCheckpointFormat, + # TODO: Conversion is broken. # TODO: Add back generate as `normal` when stable. groups={ ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, @@ -372,11 +373,12 @@ def _update_and_add_testing_config( # Megatron doesn't support per sub layer biases. megatron_args=None, checkpoint_format=DiffusionDreamGPTHuggingfaceCheckpointFormat, + # TODO: Conversion is broken. # TODO: Add back generate as `normal` when stable. 
groups={ ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, @@ -489,13 +491,13 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo for group in groups: action = model_config.groups[group] if action == ModelTestingGroupAction.main: - return True + pass elif action == ModelTestingGroupAction.normal and not skip_slow: - return True + pass elif ( action in (ModelTestingGroupAction.broken, ModelTestingGroupAction.unimportant) and not skip_extra_slow ): - return True + pass elif show_skipped: item.add_marker( pytest.mark.skip(reason=f"Skipping testing group {group} for model {model_testing_config}.") From 1e16c157fa5a0b2845943bd7e03548ed66d2175c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 19 Jun 2025 15:48:12 -0400 Subject: [PATCH 43/69] stuff --- fast_llm/engine/checkpoint/distributed.py | 58 ++++++++---- fast_llm/engine/checkpoint/safe_load.py | 12 ++- fast_llm/engine/multi_stage/fsdp.py | 109 ++++++++++++++++------ fast_llm/tensor.py | 16 +++- 4 files changed, 144 insertions(+), 51 deletions(-) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index de1625f6..04a9461f 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -53,7 +53,7 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm})) shard_names = self.get_shard_names(config) # Make sure all shards to load are in the checkpoint. - Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards)) + Assert.leq(set(shard_names), set(loaded_metadata.shards)) Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names)) # Using `log_fn=bool` sets the output to true if the error list is non-empty. @@ -95,7 +95,13 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: ) path = config.path / f"rank_{rank}.safetensors" log_main_rank(f"Loading from {path}", log_fn=logger.info) - # TODO: skip shards without overlap. + + # First do a dry run to check if there is any overlap. + if not self._has_shard_overlaps(loaded_model): + # No overlap found, skip this file. + continue + + # TODO: Lazy loading? 
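# How the two calling modes of copy_shard_overlaps (see the fsdp.py hunk later in this patch) are used here:
#   dry run:   copy_shard_overlaps(loaded_fsdp, None, None, device) only counts potential overlaps, so a
#              rank_{i}.safetensors file with no overlap is skipped before any tensor is read;
#   real copy: copy_shard_overlaps(loaded_fsdp, self_shards, loaded_shards, device) returns a
#              {(parameter_name, shard_name): element_count} dict that is passed to
#              mark_as_loaded(..., partial=True), since a parameter's elements may now arrive
#              from more than one saved rank.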
with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f: # TODO: Use self_shard if "state_shard" in f.keys(): @@ -111,22 +117,36 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names } - for shard_name, loaded_shard in loaded_shards.items(): - loaded_model.get_shard_meta(shard_name).validate(loaded_shard) - - self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names} - - counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device) - for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): - for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): - self_fsdp.copy_shard_overlaps( - loaded_fsdp, - self_fsdp_shards, - loaded_fsdp_shards, - counter, - self._model.distributed.device, - ) - - context.mark_as_loaded(counter.item()) + self._copy_shard_overlaps(loaded_model, loaded_shards, context) return loaded_metadata.metadata + + def _has_shard_overlaps(self, loaded_model) -> bool: + for _, loaded_fsdp, _ in loaded_model.split_shards_by_fsdp({}): + for _, self_fsdp, _ in self._model.split_shards_by_fsdp({}): + counter = self_fsdp.copy_shard_overlaps( + loaded_fsdp, + None, + None, + self._model.distributed.device, + ) + if counter: + return True + return False + + def _copy_shard_overlaps(self, loaded_model, loaded_shards, context): + for shard_name, loaded_shard in loaded_shards.items(): + loaded_model.get_shard_meta(shard_name).validate(loaded_shard) + + self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in loaded_shards} + + for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): + for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): + counter = self_fsdp.copy_shard_overlaps( + loaded_fsdp, + self_fsdp_shards, + loaded_fsdp_shards, + self._model.distributed.device, + ) + for parameter, count in counter.items(): + context.mark_as_loaded(count, parameter, True) diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index e72a3a15..8733bb0a 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -7,7 +7,6 @@ from fast_llm.core.distributed import add_ephemeral_timeout from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.functional.triton.pointwise import triton_fill -from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -48,14 +47,17 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: self._validate() - def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None) -> None: + def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None, partial: bool = False) -> None: self._loaded += count if parameter is not None: parameter_name, shard_name = parameter if shard_name not in self._loaded_parameters: self._loaded_parameters[shard_name] = {} - Assert.not_incl(parameter_name, self._loaded_parameters[shard_name]) - self._loaded_parameters[shard_name][parameter_name] = count + if not partial and parameter_name in self._loaded_parameters[shard_name]: + raise ValueError(f"Duplicate loaded parameter ({parameter_name}, {shard_name})") + self._loaded_parameters[shard_name][parameter_name] = ( + self._loaded_parameters[shard_name].get(parameter_name, 0) + count + ) def 
_validate(self) -> None: errors = [] @@ -105,7 +107,7 @@ def _check_missing(self, errors: list[str]) -> None: f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {parameter_name} in stage {stage.index}, shard {shard_name}" f" (locally {local_missing_for_param:,} out of {local_values.numel():,})" ) - missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item() + missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item() if fsdp._global_pad > 0 else 0 if missing_for_pad > 0: global_total += missing_for_pad local_missing_for_pad = ( diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5cf51dd5..2c5ec212 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -1,3 +1,4 @@ +import math import typing import torch @@ -75,8 +76,8 @@ def __init__( ) # TODO: Use parallel_dim property instead? - weight_shard_dim = TensorDim("weight_shard", (self._parameter_count + self._global_pad) // self._fsdp_dim.size) - grad_shard_dim = TensorDim("grad_shard", weight_shard_dim.size if self._requires_grad else 0) + weight_shard_dim = TensorDim("weight_shard", self._shard_size) + grad_shard_dim = TensorDim("grad_shard", self._shard_size if self._requires_grad else 0) self._weight_shard_meta = TensorMeta.from_dims( (weight_shard_dim,), @@ -431,34 +432,90 @@ def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, devic def copy_shard_overlaps( self, - loaded_fsdp: "FSDP", - shards: dict[str, torch.Tensor], - loaded_shards: dict[str, torch.Tensor], - counter: torch.Tensor, + loaded_fsdp: typing.Self, + shards: dict[str, torch.Tensor] | None, + loaded_shards: dict[str, torch.Tensor] | None, + # counter: torch.Tensor, device: torch.device, - ) -> None: + ) -> dict[tuple[str, str], int]: """ See MultiStage._load_partial. - TODO: Not intended to work with frozen weights, need to enforce. """ - Assert.eq(set(shards), set(loaded_shards)) + if shards is not None: + Assert.eq(set(shards), set(loaded_shards)) index_overlap = [name for name in loaded_fsdp._parameter_metas if name in self._parameter_metas] - for name in index_overlap: - overlap_index_map = self.parameter_global_to_shard( - loaded_fsdp._get_parameter_shard_indices_in_full_weight(name, device), name - ) - overlap_mask = overlap_index_map >= 0 - overlap_index_map_masked = overlap_index_map[overlap_mask] - overlap_count = overlap_mask.sum() - begin, end = self._parameter_range_in_shard(name) - - for shard_name, shard in shards.items(): - # Shards can be empty (frozen weights) - if shard.numel() == 0: + counter = {} + for parameter_name in index_overlap: + self_meta = self._parameter_metas[parameter_name] + loaded_meta = loaded_fsdp._parameter_metas[parameter_name] + + if self_meta.is_tensor_parallel: + self_tp = self_meta.tensor_parallel_dim.size + loaded_tp = loaded_meta.tensor_parallel_dim.size + self_rank = self_meta.tensor_parallel_dim.rank + loaded_rank = loaded_meta.tensor_parallel_dim.rank + # The shared tensor-parallel part (usually the smallest of the two) can be safely ignored. + shared_tp = math.gcd(self_tp, loaded_tp) + + self_tp //= shared_tp + loaded_tp //= shared_tp + + if self_rank // self_tp != loaded_rank // loaded_tp: + # Disjoint shared rank, no possible overlap. 
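# Worked example of the gcd reduction above (values are illustrative):
#   self_tp = 2, loaded_tp = 4  ->  shared_tp = gcd(2, 4) = 2, reduced self_tp = 1, loaded_tp = 2.
#   self_rank = 0 shares its tensor slice with loaded ranks 0 and 1 (0 // 1 == 0 // 2 == 1 // 2 == 0),
#   self_rank = 1 with loaded ranks 2 and 3; any other pairing is disjoint and is skipped here.
# Note that only the reduced self_tp == loaded_tp == 1 case (identical tensor-parallel degree) is
# handled below; mismatched degrees still fall through to the NotImplementedError at the end.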
+ continue + self_rank %= self_tp + loaded_rank %= loaded_tp + else: + self_tp, loaded_tp, self_rank, loaded_rank = 1, 1, 0, 0 + + if self_tp == loaded_tp == 1: + self_shard_begin_in_buffer = self._fsdp_dim.rank * self._shard_size + self_shard_end_in_buffer = (self._fsdp_dim.rank + 1) * self._shard_size + self_shard_begin_in_param = self._index_buffer_to_param(self_shard_begin_in_buffer, parameter_name) + self_shard_end_in_param = self._index_buffer_to_param(self_shard_end_in_buffer, parameter_name) + self_param_begin_in_shard, _ = self._parameter_range_in_shard(parameter_name) + + loaded_shard_begin_in_buffer = loaded_fsdp._fsdp_dim.rank * loaded_fsdp._shard_size + loaded_shard_end_in_buffer = (loaded_fsdp._fsdp_dim.rank + 1) * loaded_fsdp._shard_size + loaded_shard_begin_in_param = loaded_fsdp._index_buffer_to_param( + loaded_shard_begin_in_buffer, parameter_name + ) + loaded_shard_end_in_param = loaded_fsdp._index_buffer_to_param( + loaded_shard_end_in_buffer, parameter_name + ) + loaded_param_begin_in_shard, _ = loaded_fsdp._parameter_range_in_shard(parameter_name) + + overlap_begin_in_param = max(self_shard_begin_in_param, loaded_shard_begin_in_param) + overlap_end_in_param = min(self_shard_end_in_param, loaded_shard_end_in_param) + + if (overlap_size := overlap_end_in_param - overlap_begin_in_param) <= 0: continue - if loaded_shards[shard_name].numel() == 0: - shard[begin:end][overlap_mask] = 0 - counter += overlap_count + + overlap_begin_in_self_shard = self_param_begin_in_shard + overlap_begin_in_param + + overlap_begin_in_loaded_shard = loaded_param_begin_in_shard + overlap_begin_in_param + + if shards is None: + # Dry run, we only want the counter. + Assert.not_incl((parameter_name, ""), counter) + counter[(parameter_name, "")] = overlap_size continue - shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] - counter += overlap_count + + for shard_name, shard in shards.items(): + # Shards can be empty (frozen weights) + if shard.numel() == 0: + continue + Assert.not_incl((parameter_name, shard_name), counter) + counter[(parameter_name, shard_name)] = overlap_size + shard[overlap_begin_in_self_shard : overlap_begin_in_self_shard + overlap_size] = ( + loaded_shards[shard_name][ + overlap_begin_in_loaded_shard : overlap_begin_in_loaded_shard + overlap_size + ] + if loaded_shards[shard_name].numel() > 0 + else 0 + ) + + else: + raise NotImplementedError() + + return counter diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 84930756..b2849be8 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,3 +1,4 @@ +import functools import math import typing @@ -86,13 +87,26 @@ def __new__( data, ) - @property + @functools.cached_property def is_tensor_parallel(self) -> bool: # TODO: Avoid hard-coded assumptions on tensor parallel. return any( dim.parallel_dim is not None and dim.parallel_dim.name == DistributedDimNames.tensor for dim in self.dims ) + @functools.cached_property + def tensor_parallel_dim(self) -> DistributedDim | None: + # TODO: Avoid hard-coded assumptions on tensor parallel. 
+ if not self.is_tensor_parallel: + return None + dims = [ + dim + for dim in self.dims + if dim.parallel_dim is not None and dim.parallel_dim.name == DistributedDimNames.tensor + ] + assert len(dims) == 1, dims + return dims[0].parallel_dim + def __repr__(self, *, tensor_contents=()): return super().__repr__( tensor_contents=", ".join((self.tensor_name, f"dims={self.dim_names}", *tensor_contents)) From c338d444e403b0147d182b480846d0db5060fd59 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 19 Jun 2025 16:36:41 -0400 Subject: [PATCH 44/69] fixes --- fast_llm/models/ssm/model.py | 5 ++--- tests/models/test_checkpoint.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 526d66c0..d6a2f7e1 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -3,14 +3,13 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.models.gpt.model import GPTBaseModel +from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType logger = logging.getLogger(__name__) @@ -135,7 +134,7 @@ def get_layers(self) -> list[Layer]: return layers -class HybridSSMModel[ConfigType: HybridSSMModelConfig](FastLLMModel[ConfigType]): +class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): """ A hybrid model that combines Transformer and SSM blocks. 
""" diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 39fd0840..aff7d991 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -30,7 +30,7 @@ def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_confi + [ "training.checkpoint.interval=1", "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluators.iterations=1", + "training.evaluators.validation.evaluator.iterations=1", ], ) @@ -63,7 +63,7 @@ def test_resume(run_test_script_for_all_models): [ "training.checkpoint.interval=1", "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluators.iterations=1", + "training.evaluators.validation.evaluator.iterations=1", ], compare=f"test_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, @@ -79,7 +79,7 @@ def test_resume_frozen(run_test_script_for_all_models): [ "training.checkpoint.interval=1", "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluators.iterations=1", + "training.evaluators.validation.evaluator.iterations=1", "model.base_model.transformer.mlp_lr_scale=0.", ], compare="test_checkpoint_and_eval", @@ -442,7 +442,12 @@ def test_run_converted_model(model_testing_config, convert_paths): ) errors = [] compare = CompareConfig() - model_as_hf = transformers.AutoModel.from_pretrained( + auto_model = ( + transformers.AutoModel + if model_testing_config.name in ("diffusion_llama", "dream") + else transformers.AutoModelForCausalLM + ) + model_as_hf = auto_model.from_pretrained( convert_paths["huggingface_0"], trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code ).cuda() for name, model in zip( From 452397c1153ad79498102c024c9984f1fd479c55 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Jun 2025 18:17:39 -0400 Subject: [PATCH 45/69] stuff --- fast_llm/engine/checkpoint/safe_load.py | 79 +++++++------- fast_llm/engine/multi_stage/multi_stage.py | 7 ++ tests/test_gpt_loss.py | 121 --------------------- 3 files changed, 44 insertions(+), 163 deletions(-) delete mode 100644 tests/test_gpt_loss.py diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 8733bb0a..c68a4d2c 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -5,6 +5,7 @@ from torch.distributed import all_reduce from fast_llm.core.distributed import add_ephemeral_timeout +from fast_llm.engine.multi_stage.config import ShardName from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.functional.triton.pointwise import triton_fill @@ -129,52 +130,46 @@ def _check_missing(self, errors: list[str]) -> None: ) def _check_parameters(self, errors: list[str]) -> None: - loaded_shard_names = set(self._loaded_parameters) - shard_names = set(self._self_shards) - if loaded_shard_names != shard_names: - errors.append(f"Incorrect loaded shards: {loaded_shard_names}!={shard_names}") - for shard_name in shard_names & loaded_shard_names: - counter_per_parameter = { - parameter_name: self._loaded_parameters[shard_name].pop(parameter_name, None) - for parameter_name in self._model.parameter_names - } - for parameter_name, count in self._loaded_parameters[shard_name].items(): - errors.append(f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})') - for parameter_name, counter in counter_per_parameter.items(): + if set(self._loaded_parameters) != set(self._self_shards): + errors.append(f"Incorrect 
loaded shards: {tuple(self._loaded_parameters)}!={tuple(self._self_shards)}") + + # Get a local count for each model parameter and shard. + counters = [] + for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters: + for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]: + counter = self._loaded_parameters[shard_name].pop(parameter_meta.tensor_name, 0) if self._model.is_parameter_on_device(parameter_name): - if counter is None: + if counter == 0: errors.append(f'Missing parameter "{parameter_name}" for shard "{shard_name}"') - elif counter is not None and counter > 0: + elif counter > 0: errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"') - if self._distributed.world_group is not None: - counter_list = [] - for parameter_name, counter in counter_per_parameter.items(): - parameter_stage = self._model.get_parameter_stage(parameter_name) - parameter_meta = parameter_stage.get_parameter_meta(parameter_name) - if ( - counter is None - or (not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0) - or parameter_stage.is_tied_weight_copy - ): - # Ignore the counter from missing or duplicate tensors. - counter = 0 - counter_list.append(counter) - - counter_tensor = torch.tensor(counter_list, dtype=torch.int64).to(self._distributed.device) - - add_ephemeral_timeout(self._distributed.world_group, self._timeout) - all_reduce(counter_tensor, group=self._distributed.world_group) - counter_per_parameter = { - parameter_name: counter - for parameter_name, counter in zip(counter_per_parameter, counter_tensor.tolist()) - } - for parameter_name, counter in counter_per_parameter.items(): - parameter_size = ( - self._model.get_parameter_stage(parameter_name) - .get_parameter_meta(parameter_name) - .global_shape.numel() - ) + if ( + not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0 + ) or stage.is_tied_weight_copy: + # Ignore the counter from duplicate tensors. + counter = 0 + counters.append(counter) + + # Check for unexpected parameters. + for shard_name, loaded in self._loaded_parameters.items(): + for parameter_name, count in loaded.items(): + errors.append(f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})') + + # All-reduce to get global counts. + if self._distributed.world_group is not None: + counter_tensor = torch.tensor(counters, dtype=torch.int64).to(self._distributed.device) + # This may be the first distributed barrier after loading, so we need to wait for everyone to finish. + add_ephemeral_timeout(self._distributed.world_group, self._timeout) + all_reduce(counter_tensor, group=self._distributed.world_group) + counters = counter_tensor.tolist() + + # Compare global counts against expected values. 
+ for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters: + for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]: + counter = counters.pop(0) + parameter_size = parameter_meta.global_shape.numel() if counter != parameter_size: errors.append( f'Global counter mismatch for parameter "{parameter_name}" and shard "{shard_name}": {counter} != {parameter_size}' ) + assert not counters diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 00570be9..71e22ed9 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -453,6 +453,13 @@ def distributed(self) -> Distributed: assert self._is_setup return self._distributed + @property + def stages_fsdp_parameters(self) -> typing.Generator[tuple[Stage, FSDP, str, ParameterMeta], None, None]: + for stage in self._stages: + for fsdp in stage.fsdps: + for parameter_name in fsdp.parameter_names: + yield stage, fsdp, parameter_name, stage.get_parameter_meta(parameter_name) + def invalidate_buffers(self) -> None: for stage in self._stages_on_device.values(): stage.invalidate_buffer() diff --git a/tests/test_gpt_loss.py b/tests/test_gpt_loss.py deleted file mode 100644 index 89262eca..00000000 --- a/tests/test_gpt_loss.py +++ /dev/null @@ -1,121 +0,0 @@ -import math - -import torch - -from fast_llm.config import NoAutoValidate -from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.optimizer.config import OptimizerConfig -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig -from tests.test_gpt_generate_and_forward import model_and_tokenizer # noqa: F401 -from tests.utils.utils import requires_cuda - - -def _get_model_runner_schedule( - model_path: str, - use_flash_attention: bool, - use_bf16: bool, - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, - phase=PhaseType.inference, -): - assert phase == PhaseType.inference or phase == PhaseType.validation - updates = { - ("pretrained", "path"): model_path, - ("pretrained", "model_weights"): True, - ("pretrained", "format"): checkpoint_format.name, - ("model", "base_model", "cross_entropy_impl"): "fused", - ("model", "multi_stage", "zero_stage"): 2, - } - - if use_flash_attention: - updates[("model", "base_model", "transformer", "use_flash_attention")] = True - updates[("model", "distributed", "training_dtype")] = "bf16" - else: - updates[("model", "base_model", "transformer", "use_flash_attention")] = False - if use_bf16: - updates[("model", "distributed", "training_dtype")] = "bf16" - - config = PretrainedGPTModelConfig.from_dict({}, updates) - multi_stage = config.model.get_model_class()( - config.model, optimizer_state_names=OptimizerConfig.state_names() if phase == PhaseType.validation else () - ) - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=2048, batch_size=2) - batch_config.setup(config.model.distributed) - batch_config.validate() - - schedule = Schedule( - multi_stage=multi_stage, - batch_config=batch_config, - 
schedule_config=schedule_config, - distributed_config=config.model.distributed, - phase=phase, - ) - - runner = ScheduleRunner( - config=schedule_config, - multi_stage=multi_stage, - distributed_config=config.model.distributed, - ) - - distributed = Distributed(config.model.distributed) - - with torch.no_grad(): - multi_stage.setup(distributed) - - with torch.no_grad(): - runner.setup(distributed) - - multi_stage.load_checkpoint(config.pretrained) - - return multi_stage, runner, schedule, batch_config - - -def _test_for_phase(model_path, fast_llm_checkpoint_format, phase): - model, runner, schedule, batch_config = _get_model_runner_schedule( - model_path, True, True, fast_llm_checkpoint_format, phase - ) - - inputs = GPTBatch( - torch.randint( - 1, - model.config.base_model.vocab_size, - [2, batch_config.sequence_length + 1], - dtype=torch.int64, - generator=torch.Generator().manual_seed(42), - ) - ) - - iteration = 1 - - # we need to set phase to validation here so preprocess would crate labels from input - # so it is the same process for validation and inference phases - # otherwise we can add labels manually after preprocess for inference phase - batch = model.base_model.preprocess(inputs, phase=PhaseType.validation, iteration=iteration) - ((inputs_, kwargs),) = batch - kwargs[LanguageModelKwargs.phase] = phase - iter_losses, _, _ = runner.run_step( - iter((((inputs_, kwargs),),)), schedule, iteration=iteration, preprocessed=True - ) - - return iter_losses - - -# @pytest.mark.extra_slow -@requires_cuda -def test_loss_validation_vs_inference(model_and_tokenizer): - model_path, _, fast_llm_checkpoint_format = model_and_tokenizer - - iter_losses_validation = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.validation) - - iter_losses_inference = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.inference) - - assert len(iter_losses_validation) == len(iter_losses_inference) - for key in iter_losses_validation.keys(): - assert math.isclose(iter_losses_validation[key], iter_losses_inference[key], rel_tol=1e-5) From 0329424a342d33720c53a69428760439207eb2b3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Jun 2025 18:34:45 -0400 Subject: [PATCH 46/69] fix --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index d67729d3..b583834d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,7 +30,7 @@ ENV PIP_CONSTRAINT="" # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) RUN MAX_JOBS=4 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@74729d0" +RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" # Copy dependency files with universal write permissions for all users. 
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ From ec33c6fccb19a82f994d5924440b8e3ea8072723 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 20 Jun 2025 19:58:12 -0400 Subject: [PATCH 47/69] fixes --- fast_llm/engine/checkpoint/safe_load.py | 21 ++++++++++++------ fast_llm/engine/multi_stage/fsdp.py | 27 ++++++++++++++++------- fast_llm/engine/multi_stage/stage_base.py | 2 +- tests/utils/model_configs.py | 3 +++ 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index c68a4d2c..84a58971 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -133,16 +133,22 @@ def _check_parameters(self, errors: list[str]) -> None: if set(self._loaded_parameters) != set(self._self_shards): errors.append(f"Incorrect loaded shards: {tuple(self._loaded_parameters)}!={tuple(self._self_shards)}") - # Get a local count for each model parameter and shard. counters = [] + # Compare local counts against expected values. for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters: for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]: counter = self._loaded_parameters[shard_name].pop(parameter_meta.tensor_name, 0) - if self._model.is_parameter_on_device(parameter_name): - if counter == 0: - errors.append(f'Missing parameter "{parameter_name}" for shard "{shard_name}"') - elif counter > 0: - errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"') + local_size = ( + fsdp.get_parameter_size_in_shard(parameter_name, shard_name) + if self._model.is_parameter_on_device(parameter_name) + else 0 + ) + if counter != local_size: + errors.append( + f'Local counter mismatch for parameter "{parameter_name}"' + f' and shard "{shard_name}": loaded {counter}, expected {local_size}' + ) + # Accumulate in a list for global counter check. 
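# Why duplicated counters are zeroed before the all-reduce below: parameters that are not
# tensor-parallel are replicated on every tensor-parallel rank, and tied weights are replicated
# across stages, so summing every local count would multiply-count them; keeping the count only on
# tensor rank 0 (and only on the owning copy of tied weights) makes the reduced total directly
# comparable to parameter_meta.global_shape.numel().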
if ( not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0 ) or stage.is_tied_weight_copy: @@ -170,6 +176,7 @@ def _check_parameters(self, errors: list[str]) -> None: parameter_size = parameter_meta.global_shape.numel() if counter != parameter_size: errors.append( - f'Global counter mismatch for parameter "{parameter_name}" and shard "{shard_name}": {counter} != {parameter_size}' + f'Global counter mismatch for parameter "{parameter_name}"' + f' and shard "{shard_name}": loaded {counter}, expected {parameter_size}' ) assert not counters diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 2c5ec212..ad944a84 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -303,7 +303,7 @@ def import_state_tensor( """ Assert.eq(shard.shape, (self._shard_size,)) tensor_shard = self.parameter_global_to_shard(tensor, parameter_name) - begin, end = self._parameter_range_in_shard(parameter_name) + begin, end = self._get_parameter_range_in_shard(parameter_name) Assert.eq(tensor_shard.numel(), end - begin) shard[begin:end].copy_(tensor_shard) return end - begin @@ -386,11 +386,17 @@ def reduce_gradients( else: triton_copy(self._grad_buffer_local_shard, self._grad_shard) - def _parameter_range_in_shard(self, parameter_name: str) -> tuple[int, int]: + def _get_parameter_range_in_shard(self, parameter_name: str) -> tuple[int, int]: begin = self.index_buffer_to_shard(self.get_parameter_begin_in_buffer(parameter_name)) end = self.index_buffer_to_shard(self.get_parameter_end_in_buffer(parameter_name)) return begin, end + def get_parameter_size_in_shard(self, parameter_name: str, shard_name: str = ShardName.weights) -> int: + if not self._requires_grad and shard_name != ShardName.weights: + return 0 + begin, end = self._get_parameter_range_in_shard(parameter_name) + return end - begin + def invalidate_buffer(self) -> None: # Buffer is no longer valid (Updated weights or overwritten by other stage) assert self._mode.support_forward @@ -424,7 +430,7 @@ def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, devic device=device, ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard - begin, end = self._parameter_range_in_shard(parameter_name) + begin, end = self._get_parameter_range_in_shard(parameter_name) self.parameter_global_to_shard(index, parameter_name).copy_( torch.arange(begin, end, dtype=torch.int64, device=device) ) @@ -473,7 +479,6 @@ def copy_shard_overlaps( self_shard_end_in_buffer = (self._fsdp_dim.rank + 1) * self._shard_size self_shard_begin_in_param = self._index_buffer_to_param(self_shard_begin_in_buffer, parameter_name) self_shard_end_in_param = self._index_buffer_to_param(self_shard_end_in_buffer, parameter_name) - self_param_begin_in_shard, _ = self._parameter_range_in_shard(parameter_name) loaded_shard_begin_in_buffer = loaded_fsdp._fsdp_dim.rank * loaded_fsdp._shard_size loaded_shard_end_in_buffer = (loaded_fsdp._fsdp_dim.rank + 1) * loaded_fsdp._shard_size @@ -483,7 +488,6 @@ def copy_shard_overlaps( loaded_shard_end_in_param = loaded_fsdp._index_buffer_to_param( loaded_shard_end_in_buffer, parameter_name ) - loaded_param_begin_in_shard, _ = loaded_fsdp._parameter_range_in_shard(parameter_name) overlap_begin_in_param = max(self_shard_begin_in_param, loaded_shard_begin_in_param) overlap_end_in_param = min(self_shard_end_in_param, loaded_shard_end_in_param) @@ -491,9 +495,16 @@ def copy_shard_overlaps( if 
(overlap_size := overlap_end_in_param - overlap_begin_in_param) <= 0: continue - overlap_begin_in_self_shard = self_param_begin_in_shard + overlap_begin_in_param - - overlap_begin_in_loaded_shard = loaded_param_begin_in_shard + overlap_begin_in_param + overlap_begin_in_self_shard = ( + self._parameter_begins_in_buffer[parameter_name] + + overlap_begin_in_param + - self_shard_begin_in_buffer + ) + overlap_begin_in_loaded_shard = ( + loaded_fsdp._parameter_begins_in_buffer[parameter_name] + + overlap_begin_in_param + - loaded_shard_begin_in_buffer + ) if shards is None: # Dry run, we only want the counter. diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 3ca28ba5..b8f12de3 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -291,7 +291,7 @@ def get_param_groups( ) grads_norm_slices = [] for name in grad_norm_names: - begin, end = fsdp._parameter_range_in_shard(name) + begin, end = fsdp._get_parameter_range_in_shard(name) if len(grads_norm_slices) < 0 and begin == grads_norm_slices[-1].stop: grads_norm_slices[-1] = slice(grads_norm_slices[-1].start, end) else: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4c225422..04989a72 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -479,6 +479,9 @@ def _update_and_add_testing_config( @pytest.fixture(scope="session", params=_MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: + models = request.config.getoption("--models") + if models and request.param not in models: + pytest.skip(f"Skipping model {request.param}") return _MODEL_CONFIGS[request.param] From 2a08b144da259a42fdca5f239cbfb90984dd1533 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 23 Jun 2025 22:05:00 -0400 Subject: [PATCH 48/69] Parallel tests --- fast_llm/engine/config_utils/tensor_space.py | 5 +- fast_llm/engine/distributed/config.py | 95 +++---- fast_llm/engine/distributed/distributed.py | 171 +++++++----- fast_llm/engine/multi_stage/fsdp.py | 28 +- fast_llm/engine/multi_stage/multi_stage.py | 4 +- tests/models/distributed_test_checkpoint.py | 113 ++++++++ tests/models/test_checkpoint.py | 104 +++----- tests/utils/model_configs.py | 14 +- tests/utils/run_test_script.py | 259 ++++++++++++------- 9 files changed, 494 insertions(+), 299 deletions(-) create mode 100644 tests/models/distributed_test_checkpoint.py diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 5020bc65..49ce1525 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -147,7 +147,10 @@ def add_tensor_dim(self, dim: TensorDim) -> None: else: if dim.parallel_dim is not None: assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name - Assert.eq(dim.parallel_dim, self._distributed_config.distributed_dims[dim.parallel_dim.name]) + Assert.eq( + dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + ) self._tensor_dims[dim.name] = dim def get_tensor_dim(self, name: str) -> TensorDim: diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 8e2430d5..7fade749 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -1,3 +1,4 @@ +import dataclasses import enum import logging import os @@ -50,58 +51,33 @@ def is_training(self) -> bool: return 
self == PhaseType.training +@dataclasses.dataclass class DistributedDim: """ A dataclass to hold all relevant information on a process group without actually creating it. """ - _is_setup: bool = False - _group: "ProcessGroup|None" + _group: "ProcessGroup|None" = dataclasses.field(init=False, repr=False) + name: str + size: int + rank: int + global_ranks: range | tuple[int, ...] = None - def __init__(self, name: str, size: int = 1, rank: int = 0, id_: str | None = None, parent: str | None = None): - self._name = name - self._size = size - self._rank = rank - self._id = id_ - self._parent = parent - - @property - def name(self) -> str: - return self._name - - @property - def size(self) -> int: - return self._size - - @property - def rank(self) -> int: - return self._rank - - @property - def id(self) -> str | None: - return self._id - - @property - def parent(self) -> str | None: - return self._parent + def __post_init__(self): + self._is_setup = False + logger.info(str(self)) @property def group(self) -> "ProcessGroup|None": - assert self._is_setup + assert hasattr(self, "_group") return self._group - def __repr__(self) -> str: - return ( - f"DistributedDim(name={self.name}, size={self.size}, rank={self.rank}, id={self.id}, parent={self.parent})" - ) - def setup(self, group: "ProcessGroup|None"): - assert not self._is_setup - self._is_setup = True + assert not hasattr(self, "_group") Assert.eq(group is None, self.size == 1) if group is not None: - Assert.eq(group.size(), self._size) - Assert.eq(group.rank(), self._rank) + Assert.eq(group.size(), self.size) + Assert.eq(group.rank(), self.rank) self._group = group @@ -296,9 +272,15 @@ def _validate(self) -> None: else: self.distributed_dims = {} + data_stride = self.tensor_parallel * (1 if self.pipeline_first else self.pipeline_parallel) + pipeline_stride = self.tensor_parallel * (self.data_parallel if self.pipeline_first else 1) + self._add_distributed_dim( DistributedDim( - name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None + name=DistributedDimNames.world, + size=self.world_size, + rank=self.rank, + global_ranks=range(self.world_size), ) ) self._add_distributed_dim( @@ -306,8 +288,7 @@ def _validate(self) -> None: name=DistributedDimNames.data, size=self.data_parallel, rank=self.data_rank, - id_=f"x_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.world, + global_ranks=self._get_global_ranks(self.data_parallel, data_stride), ) ) self._add_distributed_dim( @@ -315,8 +296,7 @@ def _validate(self) -> None: name=DistributedDimNames.pipeline, size=self.pipeline_parallel, rank=self.pipeline_rank, - id_=f"x_{self.data_rank}_{self.tensor_rank}", - parent=DistributedDimNames.world, + global_ranks=self._get_global_ranks(self.pipeline_parallel, pipeline_stride), ) ) self._add_distributed_dim( @@ -324,8 +304,7 @@ def _validate(self) -> None: name=DistributedDimNames.tensor, size=self.tensor_parallel, rank=self.tensor_rank, - id_=f"x_{self.data_rank}_{self.pipeline_rank}", - parent=DistributedDimNames.world, + global_ranks=self._get_global_ranks(self.tensor_parallel, 1), ) ) self._add_distributed_dim( @@ -333,8 +312,7 @@ def _validate(self) -> None: name=DistributedDimNames.sequence_data, size=self.sequence_data_parallel, rank=self.sequence_data_rank, - id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.data, + global_ranks=self._get_global_ranks(self.sequence_data_parallel, data_stride), ) ) self._add_distributed_dim( @@ -342,8 +320,9 @@ def 
_validate(self) -> None: name=DistributedDimNames.batch_data, size=self.batch_data_parallel, rank=self.batch_data_rank, - id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.data, + global_ranks=self._get_global_ranks( + self.batch_data_parallel, data_stride * self.sequence_data_parallel + ), ) ) self._add_distributed_dim( @@ -351,16 +330,7 @@ def _validate(self) -> None: name=DistributedDimNames.tensor_and_sequence_data, size=self.sequence_data_parallel * self.tensor_parallel, rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel, - id_=f"{self.batch_data_rank}_{self.pipeline_rank}", - parent=( - DistributedDimNames.tensor - if self.sequence_data_parallel == 1 - else ( - DistributedDimNames.sequence_data - if self.tensor_parallel == 1 - else DistributedDimNames.world - ) - ), + global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1), ) ) @@ -371,12 +341,15 @@ def _validate(self) -> None: Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) + def _get_global_ranks(self, size: int, stride: int) -> range: + start = self.rank // (size * stride) * size * stride + self.rank % stride + return range(start, start + size * stride, stride) + def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: + Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank) if distributed_dim.name in self.distributed_dims: Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name]) else: - if distributed_dim.parent is not None: - assert distributed_dim.parent in self.distributed_dims self.distributed_dims[distributed_dim.name] = distributed_dim def get_distributed_dim(self, name: str) -> DistributedDim: diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 42ec97f2..54f43b85 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -19,6 +19,90 @@ logger = logging.getLogger(__name__) +class ProcessGroupPool: + def __init__(self, rank: int | None = None, world_size: int | None = None, timeout: float = 60): + + self._rank = DistributedConfig.default_rank if rank is None else rank + self._world_size = DistributedConfig.default_world_size if world_size is None else world_size + self._timeout = timeout + + if self._world_size > 1: + if rank == 0: + logger.info("Initializing TCP store.") + # We bypass `torch.distributed.init_process_group` which makes things way more complicated for no reason. + # TODO: Allow other init methods? + self.store, _, _ = next( + torch.distributed.rendezvous( + "env://", + self._rank, + self._world_size, + timeout=datetime.timedelta(seconds=timeout), + ) + ) + self._process_groups = {} + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def get_process_group(self, global_ranks: range | tuple, rank: int) -> ProcessGroup | None: + """ + Get the requested process group from the pool, or create it if it doesn't exist. + """ + group_size = len(global_ranks) + logger.info(f"WIOUGHNIOUW {global_ranks}, {group_size}, {self._rank}") + Assert.eq(global_ranks[rank], self._rank) + if group_size == 1: + return None + + for group_ranks, group in self._process_groups.items(): + # Check if an equivalent group already exists. 
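# Worked example of the range-based global_ranks produced by _get_global_ranks above, for
# world_size = 8, tensor_parallel = 2, pipeline_parallel = 2, data_parallel = 2, pipeline_first = False:
#   tensor   (size 2, stride 1): rank 5 -> range(4, 6, 1) = {4, 5}
#   pipeline (size 2, stride 2): rank 5 -> range(5, 9, 2) = {5, 7}
#   data     (size 2, stride 4): rank 5 -> range(1, 9, 4) = {1, 5}
# get_process_group() can then reuse one NCCL group per distinct range across all models in a process.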
+ if type(group_ranks) == type(global_ranks): + if group_ranks == global_ranks: + return group + elif tuple(group_ranks) == tuple(global_ranks): + return group + + prefix = ( + f"range_{global_ranks.start}_{global_ranks.stop}_{global_ranks.step}" + if isinstance(global_ranks, range) + else f"ranks_{"_".join(str(rank) for rank in global_ranks)}" + ) + + group = torch.distributed.ProcessGroupNCCL( + torch.distributed.PrefixStore(prefix + "/", self.store), + rank, + group_size, + datetime.timedelta(seconds=self._timeout), + ) + self._process_groups[global_ranks] = group + return group + + def __enter__(self): + global _default_pool + assert _default_pool is None + _default_pool = self + + def __exit__(self, exc_type, exc_val, exc_tb): + global _default_pool + assert _default_pool is self + _default_pool = None + + def __del__(self): + # Shutdown the process group backend explicitly to prevent a nccl warning. + # We can't call `destroy_process_group` directly because pytorch doesn't know about it. + for group in self._process_groups.values(): + if group is not None and hasattr(group, "_shutdown"): + group._shutdown() # noqa + + +_default_pool: ProcessGroupPool | None = None + + class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): """ A distributed instance holding pointers to the various process groups. @@ -31,7 +115,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): config_class: typing.ClassVar[type[DistributedConfig]] = DistributedConfig - def __init__(self, config: DistributedConfig, use_cpu: bool = False): + def __init__(self, config: DistributedConfig, use_cpu: bool = False, pool: ProcessGroupPool | None = None): super().__init__(config) assert self._config.reference_config is None self._use_cpu = use_cpu @@ -45,32 +129,24 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self.device = torch.device(self._config.local_rank) torch.cuda.set_device(self.device) - # We bypass `torch.distributed.init_process_group` which makes things way more complicated for no reason. - - # TODO: Allow other init methods? - # TODO: Allow different timeout for the store?
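
A minimal usage sketch of the pool introduced above (not part of the diff; it assumes the rendezvous environment variables expected by `torch.distributed.rendezvous` are already set, and uses a default `DistributedConfig` purely for illustration):

from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed, ProcessGroupPool

# Entering the pool installs it as the default, so the `Distributed` instances below reuse
# its TCP store and NCCL groups instead of each creating their own.
with ProcessGroupPool(timeout=60):
    first = Distributed(DistributedConfig())
    second = Distributed(DistributedConfig())
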
- if self._config.world_size > 1: - self._config.log_first_rank("Initializing TCP store.") - self.store, _, _ = next( - torch.distributed.rendezvous( - "env://", - self._config.rank, - self._config.world_size, - timeout=datetime.timedelta(seconds=self._config.timeout), - ) - ) - self._process_groups = {} - for name, distributed_dim in self._config.distributed_dims.items(): - Assert.eq(distributed_dim.name, name) - self.add_group(distributed_dim) - - self.world_group = self._process_groups[DistributedDimNames.world] - self.data_group = self._process_groups[DistributedDimNames.data] - self.pipeline_group = self._process_groups[DistributedDimNames.pipeline] - self.tensor_group = self._process_groups[DistributedDimNames.tensor] - self.sequence_data_group = self._process_groups[DistributedDimNames.sequence_data] - self.batch_data_group = self._process_groups[DistributedDimNames.batch_data] - self.tensor_and_sequence_data_group = self._process_groups[DistributedDimNames.tensor_and_sequence_data] + if pool is None and _default_pool is None: + self._pool = ProcessGroupPool(self._config.rank, self._config.world_size, self._config.timeout) + else: + if pool is None: + pool = _default_pool + Assert.eq(pool._world_size, self._config.world_size) + Assert.eq(pool._rank, self._config.rank) + self._pool = pool + + self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world]) + self.data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.data]) + self.pipeline_group = self.add_group(self._config.distributed_dims[DistributedDimNames.pipeline]) + self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor]) + self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data]) + self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data]) + self.tensor_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] + ) self._config.log_first_rank(f"Setting random seeds...") @@ -114,38 +190,10 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None: """ Add a process group from its definition. - The group name (`dim`) must be unique within a distributed instance, - - Note: the group id disambiguate between the different groups with the same name on the cluster. - (Ex.: there is one data-parallel group for each model-parallel rank.) - There should be exactly one device for each name, group_id and rank. - TODO: Make private, create all groups through distributed dims in config. 
""" - Assert.not_incl(distributed_dim.name, self._process_groups) - prefix = distributed_dim.name if distributed_dim.id is None else f"{distributed_dim.name}_{distributed_dim.id}" - - if distributed_dim.parent is None: - parent = None - else: - Assert.incl(distributed_dim.parent, self._process_groups) - parent = self._process_groups[distributed_dim.parent] - if distributed_dim.size == 1: - group = None - elif parent and distributed_dim.size == parent.size(): - Assert.eq(distributed_dim.rank, parent.rank()) - group = parent - else: - if parent: - Assert.lt(distributed_dim.size, parent.size()) - Assert.leq(distributed_dim.rank, parent.rank()) - self._config.log_first_rank(f"Initializing group {distributed_dim.name}, size={distributed_dim.size}...") - group = torch.distributed.ProcessGroupNCCL( - torch.distributed.PrefixStore(prefix + "/", self.store), - distributed_dim.rank, - distributed_dim.size, - datetime.timedelta(seconds=self._config.timeout), - ) - self._process_groups[distributed_dim.name] = group + self._config.log_first_rank(f"Initializing group {distributed_dim.name}, size={distributed_dim.size}...") + logger.info(f"INIT {distributed_dim}") + group = self._pool.get_process_group(distributed_dim.global_ranks, distributed_dim.rank) distributed_dim.setup(group) return group @@ -164,10 +212,3 @@ def set_step(self, step: int, phase: PhaseType) -> None: seed_shift = step * self._config.sample_seed_shift + self._phase_seeds_shifts[phase] self.pp_generator.manual_seed((self._pp_seed + seed_shift) % MAX_SEED) self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED) - - def __del__(self): - # Shutdown the process group backend explicitly to prevent a nccl warning. - # We can't call `destroy_process_group` directly because pytorch doesn't know about it. - for group in self._process_groups.values(): - if group is not None and hasattr(group, "_shutdown"): - group._shutdown() # noqa diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index ad944a84..f991b68f 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -457,22 +457,26 @@ def copy_shard_overlaps( if self_meta.is_tensor_parallel: self_tp = self_meta.tensor_parallel_dim.size - loaded_tp = loaded_meta.tensor_parallel_dim.size self_rank = self_meta.tensor_parallel_dim.rank + else: + self_tp, self_rank = 1, 0 + if loaded_meta.is_tensor_parallel: + loaded_tp = loaded_meta.tensor_parallel_dim.size loaded_rank = loaded_meta.tensor_parallel_dim.rank - # The shared tensor-parallel part (usually the smallest of the two) can be safely ignored. - shared_tp = math.gcd(self_tp, loaded_tp) + else: + loaded_tp, loaded_rank = 1, 0 - self_tp //= shared_tp - loaded_tp //= shared_tp + # The shared tensor-parallel part (usually the smallest of the two) can be safely ignored. + shared_tp = math.gcd(self_tp, loaded_tp) - if self_rank // self_tp != loaded_rank // loaded_tp: - # Disjoint shared rank, no possible overlap. - continue - self_rank %= self_tp - loaded_rank %= loaded_tp - else: - self_tp, loaded_tp, self_rank, loaded_rank = 1, 1, 0, 0 + self_tp //= shared_tp + loaded_tp //= shared_tp + + if self_rank // self_tp != loaded_rank // loaded_tp: + # Disjoint shared rank, no possible overlap. 
+ continue + self_rank %= self_tp + loaded_rank %= loaded_tp if self_tp == loaded_tp == 1: self_shard_begin_in_buffer = self._fsdp_dim.rank * self._shard_size diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 71e22ed9..1f734268 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -583,13 +583,13 @@ def setup(self, distributed: Distributed) -> None: # Setup the tied parameter process groups if len(self.all_ranks) > 1 and self.on_device: # TODO: Create a group def first? + pipeline_ranks = distributed.config.get_distributed_dim(DistributedDimNames.pipeline).global_ranks self.group = distributed.add_group( DistributedDim( name=self.name + "_tied_weight", size=len(self.all_ranks), rank=sorted(self.all_ranks).index(distributed.config.pipeline_rank), - id_=f"{distributed.config.data_rank}_x_{distributed.config.tensor_rank}", - parent=DistributedDimNames.pipeline, + global_ranks=tuple(pipeline_ranks[rank] for rank in self.all_ranks), ) ) else: diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py new file mode 100644 index 00000000..51045478 --- /dev/null +++ b/tests/models/distributed_test_checkpoint.py @@ -0,0 +1,113 @@ +import argparse +import pathlib + +from fast_llm.engine.checkpoint.config import DistributedCheckpointFormat, FastLLMCheckpointFormat +from fast_llm.engine.distributed.distributed import ProcessGroupPool +from tests.models.test_checkpoint import get_convert_paths +from tests.utils.model_configs import MODEL_CONFIGS +from tests.utils.run_test_script import do_run_test_script_for_all_models + + +def parse_args(args: list[str] | None = None): + parser = argparse.ArgumentParser() + parser.add_argument("rendezvous_port", type=int) + parser.add_argument("torchrun_port", type=int) + parser.add_argument("base_path", type=pathlib.Path) + parser.add_argument("model_testing_config", type=str) + parsed = parser.parse_args(args) + return parsed.rendezvous_port, parsed.torchrun_port, parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config] + + +def _test_load_and_save_parallel(fixture_args, test_name, distributed_args, pretrained_path, pretrained_format): + # TODO: Just save and load the model instead, no need for an actual run. + do_run_test_script_for_all_models( + [ + # First we load a checkpoint. + f"pretrained.path={pretrained_path}", + f"pretrained.format={pretrained_format}", + # We run for one mock iteration. + "training.train_iters=1", + "schedule.skip_step=True", + # Then we save a checkpoint (distributed format) and an export (fast_llm format). 
+ "training.checkpoint.interval=1", + "training.export.interval=1", + "training.export.format=fast_llm", + ] + + distributed_args, + test_name=test_name, + **fixture_args, + ) + + +def main(args: list[str] | None = None) -> None: + rendezvous_port, torchrun_port, base_path, model_testing_config = parse_args(args) + convert_paths = get_convert_paths(base_path) + + fixture_args = { + "rendezvous_port": rendezvous_port, + "torchrun_port": torchrun_port, + "base_path": base_path, + "model_testing_config": model_testing_config, + "num_gpus": 2, + } + + with ProcessGroupPool(timeout=20): + for pretrained_format, pretrained_path in ( + (DistributedCheckpointFormat.name, convert_paths["distributed_0"]), + (FastLLMCheckpointFormat.name, convert_paths["fast_llm_0"]), + (model_testing_config.checkpoint_format.name, convert_paths["huggingface_0"]), + ): + _test_load_and_save_parallel( + fixture_args, + test_name=f"test_load_pretrained_{pretrained_format}_in_dp2", + distributed_args=[], + pretrained_path=pretrained_path, + pretrained_format=pretrained_format, + ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_{pretrained_format}_in_tp2", + # distributed_args=["model.distributed.tensor_parallel=2"], + # pretrained_path=pretrained_path, + # pretrained_format=pretrained_format, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_{pretrained_format}_in_stp2", + # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], + # pretrained_path=pretrained_path, + # pretrained_format=pretrained_format, + # ) + + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_dp2_in_tp2", + # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], + # pretrained_path=base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", + # pretrained_format=DistributedCheckpointFormat.name, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_stp2_in_dp2", + # distributed_args=[], + # pretrained_path=base_path / "test_load_pretrained_distributed_in_stp2" / "checkpoint" / "1", + # pretrained_format=DistributedCheckpointFormat.name, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_tp2_in_stp2", + # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], + # pretrained_path=base_path / "test_load_pretrained_distributed_in_stp2" / "checkpoint" / "1", + # pretrained_format=DistributedCheckpointFormat.name, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_stp2_in_tp2", + # distributed_args=["model.distributed.tensor_parallel=2"], + # pretrained_path=base_path / "test_load_pretrained_distributed_in_tp2" / "checkpoint" / "1", + # pretrained_format=DistributedCheckpointFormat.name, + # ) + + +if __name__ == "__main__": + main() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index aff7d991..5b392061 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -95,19 +95,23 @@ def _run_conversion(config: ConvertConfig): config.run() -@pytest.fixture(scope="module") -def convert_paths(run_test_script_base_path): +def get_convert_paths(base_path: pathlib.Path) -> dict[str, pathlib.Path]: return { - "checkpoint": run_test_script_base_path / "test_checkpoint_and_eval" / 
"checkpoint" / "2", - "distributed_0": run_test_script_base_path / "test_convert_model" / "distributed_0", - "distributed_1": run_test_script_base_path / "test_convert_model" / "distributed_1", - "fast_llm_0": run_test_script_base_path / "test_convert_model" / "fast_llm_0", - "fast_llm_1": run_test_script_base_path / "test_convert_model" / "fast_llm_1", - "huggingface_0": run_test_script_base_path / "test_convert_model" / "huggingface_0", - "huggingface_1": run_test_script_base_path / "test_convert_model" / "huggingface_1", + "checkpoint": base_path / "test_checkpoint_and_eval" / "checkpoint" / "2", + "distributed_0": base_path / "test_convert_model" / "distributed_0", + "distributed_1": base_path / "test_convert_model" / "distributed_1", + "fast_llm_0": base_path / "test_convert_model" / "fast_llm_0", + "fast_llm_1": base_path / "test_convert_model" / "fast_llm_1", + "huggingface_0": base_path / "test_convert_model" / "huggingface_0", + "huggingface_1": base_path / "test_convert_model" / "huggingface_1", } +@pytest.fixture(scope="module") +def convert_paths(run_test_script_base_path: pathlib.Path) -> dict[str, pathlib.Path]: + return get_convert_paths(run_test_script_base_path) + + @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_convert_distributed_to_fast_llm(model_testing_config, convert_paths): @@ -473,8 +477,8 @@ def test_run_converted_model(model_testing_config, convert_paths): @pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, convert_paths): +@pytest.mark.model_testing_group(ModelTestingGroup.convert) +def test_load_pretrained_distributed(run_test_script_for_all_models, convert_paths): run_test_script_for_all_models( [ "training.checkpoint.interval=1", @@ -483,27 +487,27 @@ def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, conv f"pretrained.format={DistributedCheckpointFormat.name}", "schedule.skip_step=True", ], - num_gpus=2, ) -@pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_pretrained_distributed_with_config(run_test_script_for_all_models, convert_paths): - run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.train_iters=1", - f"pretrained.path={convert_paths["distributed_0"]}", - f"pretrained.format={DistributedCheckpointFormat.name}", - "schedule.skip_step=True", - ], - ) +@pytest.mark.depends_on( + on=[ + "test_load_converted_distributed_checkpoint[{model_testing_config}]", + "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", + "test_load_converted_huggingface_checkpoint[{model_testing_config}]", + ] +) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) +def test_save_and_load_in_parallel(run_distributed_script_for_all_models): + import tests.models.distributed_test_checkpoint + + run_distributed_script_for_all_models([tests.models.distributed_test_checkpoint.__file__], 2) -@pytest.mark.depends_on(on=["test_load_pretrained_distributed_in_dp2[{model_testing_config}]"]) +@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, 
ModelTestingGroup.distributed) -def test_load_pretrained_in_dp2_match_checkpoint(model_testing_config, convert_paths, run_test_script_base_path): +def test_load_pretrained_distributed_in_dp2(model_testing_config, convert_paths, run_test_script_base_path): + # Verify the content of the saved checkpoint. test_ckpt_path = run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" pretrained_config_ref = CheckpointLoadConfig( path=convert_paths["checkpoint"], @@ -545,24 +549,14 @@ def test_load_pretrained_in_dp2_match_checkpoint(model_testing_config, convert_p assert (stage_shard_test[: stage_shard_ref.numel()] == stage_shard_ref).all() assert (stage_shard_test[stage_shard_ref.numel() :] == 0).all() # noqa - -@pytest.mark.depends_on(on=["test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, run_test_script_base_path): - # This also tests conversion which uses `FastLLMModel.from_checkpoint` - pretrained_config_ref = CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.fast_llm, - ) + # Test loading on a single GPU. pretrained_config_test = CheckpointLoadConfig( - path=run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", + path=test_ckpt_path, format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) - config = model_testing_config.model_config_class.from_pretrained(pretrained_config_ref) model = model_testing_config.model_class.from_pretrained(pretrained_config_test, mode=StageMode.weights) - _compare_model_configs(config, model.config) + _compare_model_configs(config_ref, model.config) weight_shard = safetensors.torch.load_file( convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) )[_WEIGHT_SHARD_SAVE_NAME] @@ -571,22 +565,13 @@ def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, ru @pytest.mark.depends_on( on=[ - "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", - "test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]", + "test_save_and_load_in_parallel[{model_testing_config}]", + "test_load_pretrained_distributed_in_dp2[{model_testing_config}]", ] ) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_fast_llm_in_dp2(run_test_script_for_all_models, convert_paths, run_test_script_base_path): - run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.train_iters=1", - f"pretrained.path={convert_paths["fast_llm_0"]}", - f"pretrained.format=fast_llm", - "schedule.skip_step=True", - ], - num_gpus=2, - ) +def test_load_pretrained_fast_llm_in_dp2(model_testing_config, convert_paths, run_test_script_base_path): + # Compare the checkpoint from test_load_pretrained_fast_llm_in_dp2 and test_load_pretrained_distributed_in_dp2 for rank in range(2): ref_shard = safetensors.torch.load_file( run_test_script_base_path @@ -609,23 +594,14 @@ def test_load_pretrained_fast_llm_in_dp2(run_test_script_for_all_models, convert @pytest.mark.depends_on( on=[ "test_load_converted_huggingface_checkpoint[{model_testing_config}]", - "test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]", + "test_load_pretrained_distributed_in_dp2[{model_testing_config}]", ] )
@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_pretrained_huggingface_in_dp2( run_test_script_for_all_models, model_testing_config, run_test_script_base_path, convert_paths ): - run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.train_iters=1", - f"pretrained.path={convert_paths["huggingface_0"]}", - f"pretrained.format={model_testing_config.checkpoint_format.name}", - "schedule.skip_step=True", - ], - num_gpus=2, - ) + # Compare the checkpoint from test_load_pretrained_huggingface_in_dp2 and test_load_pretrained_distributed_in_dp2 for rank in range(2): ref_shard = safetensors.torch.load_file( run_test_script_base_path @@ -636,7 +612,7 @@ def test_load_pretrained_huggingface_in_dp2( ) test_shard = safetensors.torch.load_file( run_test_script_base_path - / f"test_load_pretrained_huggingface_in_dp2" + / f"test_load_pretrained_{model_testing_config.checkpoint_format.name}_in_dp2" / "checkpoint" / "1" / f"rank_{rank}.safetensors" diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 04989a72..bee43de1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -83,7 +83,7 @@ def _update_and_add_testing_config( checkpoint_format: CheckpointFormat | None = ..., groups: dict[ModelTestingGroup, ModelTestingGroupAction], ): - config = _MODEL_CONFIGS[old_name] + config = MODEL_CONFIGS[old_name] updates: dict[str, typing.Any] = { "name": new_name, "groups": groups, @@ -102,13 +102,13 @@ def _update_and_add_testing_config( if checkpoint_format is not ...: updates["checkpoint_format"] = checkpoint_format - _MODEL_CONFIGS[new_name] = dataclasses.replace(config, **updates) + MODEL_CONFIGS[new_name] = dataclasses.replace(config, **updates) -_MODEL_CONFIGS: dict[str, ModelTestingConfig] = {} +MODEL_CONFIGS: dict[str, ModelTestingConfig] = {} -_MODEL_CONFIGS["gpt2"] = ModelTestingConfig( +MODEL_CONFIGS["gpt2"] = ModelTestingConfig( # Tests gpt2 features (absolute embeddings, layer norm, relu activation, tied embeddings, MHA, linear biases). 
name="gpt2", model_type="gpt", @@ -477,12 +477,12 @@ def _update_and_add_testing_config( ) -@pytest.fixture(scope="session", params=_MODEL_CONFIGS.keys()) +@pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models") if models and request.param not in models: pytest.skip(f"Skipping model {request.param}") - return _MODEL_CONFIGS[request.param] + return MODEL_CONFIGS[request.param] def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slow: bool, show_skipped: bool) -> bool: @@ -490,7 +490,7 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo assert "model_testing_config" in item.callspec.params, item.nodeid groups: tuple[ModelTestingGroup] = item.keywords["model_testing_group"].args model_testing_config = item.callspec.params["model_testing_config"] - model_config: ModelTestingConfig = _MODEL_CONFIGS[model_testing_config] + model_config: ModelTestingConfig = MODEL_CONFIGS[model_testing_config] for group in groups: action = model_config.groups[group] if action == ModelTestingGroupAction.main: diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 69ed817a..3c8f961b 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -1,15 +1,24 @@ +import argparse +import functools import os import pathlib import shutil import subprocess import sys +import typing import pytest import torch from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig, compare_tensor_logs from tests.utils.dataset import get_test_dataset +from tests.utils.model_configs import MODEL_CONFIGS, ModelTestingConfig + +if typing.TYPE_CHECKING: + from tests.conftest import WorkerResources # FIXME: figure out correct import of megatron modules without this hack sys.path.append(os.getcwd()) @@ -17,71 +26,125 @@ _ARTIFACT_PATH = "runs/0/artifacts" -@pytest.fixture(scope="session") -def run_test_script(worker_resources): - def do_run_test_script( - path: pathlib.Path, - args: list[str], - num_gpus: int = 1, - *, - model_type: str, - is_megatron: bool = False, - compare_path: pathlib.Path | None = None, - config: CompareConfig | None = None, - prepare_fn=None, - compare_fn=None, - do_compare: bool = True, - ): - if torch.cuda.device_count() < num_gpus: - pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") - env = os.environ.copy() - if is_megatron: - # Prevent Megatron from complaining. - env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - env["NVTE_FLASH_ATTN"] = "0" - skip = False - if path.exists(): - assert path.is_dir() - # TODO: Better way to check if the previous attempt succeeded. 
- shutil.rmtree(path) - if prepare_fn is not None: - skip = prepare_fn(path, None if compare_path is None else compare_path, skip) - if is_megatron: - args = [*args, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] - else: - args = ["train", model_type, *args, f"run.experiment_dir={path}"] - header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] - command = [ - "python", - "-m", - "torch.distributed.run", - f"--nproc-per-node={num_gpus}", - f"--rdzv-endpoint=localhost:{worker_resources.rendezvous_port}", - f"--master-port={worker_resources.torchrun_port}", - *header, - *args, - ] - print(" ".join(command)) - if skip: - print("Reusing existing run.") +def do_run_distributed_script( + args: list[str], + rendezvous_port: int, + torchrun_port: int, + num_gpus: int, + timeout: float = 120, +): + command = [ + "python", + "-m", + "torch.distributed.run", + f"--nproc-per-node={num_gpus}", + f"--rdzv-endpoint=localhost:{rendezvous_port}", + f"--master-port={torchrun_port}", + *args, + ] + print(" ".join(command)) + completed_proc = subprocess.run(command, timeout=timeout) + if completed_proc.returncode: + raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") + + +def do_run_test_script( + path: pathlib.Path, + args: list[str], + num_gpus: int = 1, + *, + model_type: str, + is_megatron: bool = False, + compare_path: pathlib.Path | None = None, + config: CompareConfig | None = None, + prepare_fn=None, + compare_fn=None, + do_compare: bool = True, + rendezvous_port: int, + torchrun_port: int, +): + is_parallel = DistributedConfig.default_world_size > 1 + if is_parallel: + Assert.eq(num_gpus, DistributedConfig.default_world_size) + local_rank = DistributedConfig.default_rank + + if torch.cuda.device_count() < num_gpus: + pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") + env = os.environ.copy() + if is_megatron: + assert num_gpus == 1 + # Prevent Megatron from complaining. + env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + env["NVTE_FLASH_ATTN"] = "0" + skip = False + if local_rank == 0 and path.exists(): + assert path.is_dir() + # TODO: Better way to check if the previous attempt succeeded. 
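
For orientation (not part of the patch), the command list assembled by `do_run_distributed_script` for a 2-GPU Fast-LLM run comes out roughly as follows; the ports, model type and experiment directory below are placeholder values:

# Approximate command produced by `do_run_distributed_script`; ports, model type ("gpt")
# and experiment directory are placeholders, not values from the patch.
command = [
    "python", "-m", "torch.distributed.run",
    "--nproc-per-node=2",
    "--rdzv-endpoint=localhost:29500",
    "--master-port=29501",
    "--no-python", "fast-llm", "train", "gpt",
    "run.experiment_dir=/tmp/test_run",
]
print(" ".join(command))
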
+ shutil.rmtree(path) + if local_rank == 0 and prepare_fn is not None: + skip = prepare_fn(path, None if compare_path is None else compare_path, skip) + if is_megatron: + args = ["Megatron-LM/pretrain_gpt.py", *args, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] + else: + args = ["--no-python", "fast-llm", "train", model_type, *args, f"run.experiment_dir={path}"] + if skip: + print("Reusing existing run.") + else: + get_test_dataset() + if (num_gpus == 1 or is_parallel) and not is_megatron: + print(" ".join(args[1:])) + RunnableConfig.parse_and_run(args[2:]) else: - get_test_dataset() - if num_gpus == 1 and not is_megatron: - RunnableConfig.parse_and_run(args) - else: - completed_proc = subprocess.run(command, env=env, timeout=120) - if completed_proc.returncode: - raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") - if compare_path is not None and do_compare: - if compare_fn is not None: - compare_fn(path, compare_path) - compare_tensor_logs( - compare_path / _ARTIFACT_PATH, - path / _ARTIFACT_PATH, - config, - ) - - return do_run_test_script + do_run_distributed_script(args, rendezvous_port, torchrun_port, num_gpus) + if local_rank == 0 and compare_path is not None and do_compare: + if compare_fn is not None: + compare_fn(path, compare_path) + compare_tensor_logs( + compare_path / _ARTIFACT_PATH, + path / _ARTIFACT_PATH, + config, + ) + + +def do_run_test_script_for_all_models( + extra_args: list[str], + num_gpus: int = 1, + *, + is_megatron: bool = False, + compare: str | None = None, + config: CompareConfig | None = None, + prepare_fn=None, + compare_fn=None, + do_compare: bool = True, + rendezvous_port: int, + torchrun_port: int, + test_name: str, + base_path: pathlib.Path, + model_testing_config: ModelTestingConfig, +): + do_run_test_script( + base_path / test_name, + (model_testing_config.megatron_args if is_megatron else model_testing_config.config_args) + extra_args, + num_gpus, + model_type=model_testing_config.model_type, + is_megatron=is_megatron, + compare_path=None if compare is None else base_path / compare, + config=config, + prepare_fn=prepare_fn, + compare_fn=compare_fn, + do_compare=do_compare, + rendezvous_port=rendezvous_port, + torchrun_port=torchrun_port, + ) + + +@pytest.fixture(scope="session") +def run_test_script(worker_resources: "WorkerResources"): + return functools.partial( + do_run_test_script, + rendezvous_port=worker_resources.rendezvous_port, + torchrun_port=worker_resources.torchrun_port, + ) @pytest.fixture(scope="session") @@ -90,29 +153,51 @@ def run_test_script_base_path(model_testing_config, result_path, request): @pytest.fixture(scope="function") -def run_test_script_for_all_models(run_test_script, run_test_script_base_path, model_testing_config, request): - def do_run_test_script_for_all_models( - extra_args: list[str], - num_gpus: int = 1, - *, - is_megatron: bool = False, - compare: str | None = None, - config: CompareConfig | None = None, - prepare_fn=None, - compare_fn=None, - do_compare: bool = True, - ): - run_test_script( - run_test_script_base_path / request.node.originalname, - (model_testing_config.megatron_args if is_megatron else model_testing_config.config_args) + extra_args, +def run_test_script_for_all_models( + worker_resources: "WorkerResources", + run_test_script_base_path: pathlib.Path, + model_testing_config: ModelTestingConfig, + request: pytest.FixtureRequest, +): + return functools.partial( + do_run_test_script_for_all_models, + rendezvous_port=worker_resources.rendezvous_port, 
+ torchrun_port=worker_resources.torchrun_port, + test_name=request.node.originalname, + base_path=run_test_script_base_path, + model_testing_config=model_testing_config, + ) + + +def parse_run_distributed_script(args: list[str] | None = None): + parser = argparse.ArgumentParser() + parser.add_argument("rendezvous_port", type=int) + parser.add_argument("torchrun_port", type=int) + parser.add_argument("base_path", type=pathlib.Path) + parser.add_argument("model_testing_config", type=str) + parsed = parser.parse_args(args) + return parsed.rendezvous_port, parsed.torchrun_port, parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config] + + +@pytest.fixture(scope="session") +def run_distributed_script_for_all_models( + worker_resources: "WorkerResources", + run_test_script_base_path: pathlib.Path, + model_testing_config: ModelTestingConfig, + request: pytest.FixtureRequest, +): + def do_run_distributed_script_for_all_models(args: list[str], num_gpus=2): + do_run_distributed_script( + args + + [ + str(worker_resources.rendezvous_port), + str(worker_resources.torchrun_port), + str(run_test_script_base_path), + model_testing_config.name, + ], + worker_resources.rendezvous_port, + worker_resources.torchrun_port, num_gpus, - model_type=model_testing_config.model_type, - is_megatron=is_megatron, - compare_path=None if compare is None else run_test_script_base_path / compare, - config=config, - prepare_fn=prepare_fn, - compare_fn=compare_fn, - do_compare=do_compare, ) - return do_run_test_script_for_all_models + return do_run_distributed_script_for_all_models From ad6a482b62ab4232e9b3b53a0051c72031902ebc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 25 Jun 2025 20:57:41 -0400 Subject: [PATCH 49/69] misc --- fast_llm/cli.py | 13 +- fast_llm/engine/checkpoint/distributed.py | 3 +- fast_llm/engine/checkpoint/state_dict.py | 1 + fast_llm/engine/config_utils/logging.py | 5 +- fast_llm/engine/distributed/config.py | 1 - fast_llm/engine/distributed/distributed.py | 2 - fast_llm/utils.py | 1 + tests/conftest.py | 9 +- tests/models/distributed_test_checkpoint.py | 139 ++-- tests/models/test_checkpoint.py | 664 +++++++------------- tests/utils/dataset.py | 4 +- tests/utils/model_configs.py | 20 +- tests/utils/run_test_script.py | 31 +- tests/utils/utils.py | 5 +- 14 files changed, 374 insertions(+), 524 deletions(-) diff --git a/fast_llm/cli.py b/fast_llm/cli.py index 34546120..66ce096d 100644 --- a/fast_llm/cli.py +++ b/fast_llm/cli.py @@ -1,3 +1,4 @@ +import contextlib import logging import sys import traceback @@ -15,12 +16,12 @@ logger = logging.getLogger(__name__) -def fast_llm_main(args: list[str] | None = None): - # TODO: Add hook to register model classes? (environment variable?) +@contextlib.contextmanager +def fast_llm_main_wrapper(): # (Pre-)configure logging configure_logging() try: - RunnableConfig.parse_and_run(args) + yield except Exception as e: if sys.gettrace(): raise @@ -31,5 +32,11 @@ def fast_llm_main(args: list[str] | None = None): sys.exit(1) +def fast_llm_main(args: list[str] | None = None): + # TODO: Add hook to register model classes? (environment variable?) 
+ with fast_llm_main_wrapper(): + RunnableConfig.parse_and_run(args) + + if __name__ == "__main__": fast_llm_main() diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 04a9461f..903c8840 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -30,8 +30,8 @@ class DistributedCheckpointHandler(CheckpointHandler): @classmethod def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata): - config.path.mkdir(parents=True, exist_ok=True) serialized_metadata = metadata.to_dict() + config.path.mkdir(parents=True, exist_ok=True) yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) @classmethod @@ -40,6 +40,7 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: serialized_metadata = metadata.to_dict() + config.path.mkdir(parents=True, exist_ok=True) if self._model.config.distributed.rank == 0: yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) safetensors.torch.save_file( diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 556e97be..7a257a5f 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -216,6 +216,7 @@ def _save_next_file(self) -> None: file_name = f"{self.base_file_name}_{self._file_count}.safetensors" if self._do_save: logger.info(f"Saving tensors to {self._config.path / file_name}") + self._config.path.mkdir(parents=True, exist_ok=True) safetensors.torch.save_file( tensors=self._tensors, filename=self._config.path / file_name, diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index ea014bd0..358674a9 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -5,6 +5,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -14,8 +15,8 @@ def configure_logging( *, log_timestamps: bool = True, enable_all_loggers: bool = False, - rank: int = 0, - world_size: int = 1, + rank: int = DistributedConfig.default_rank, + world_size: int = DistributedConfig.default_world_size, directory: pathlib.Path | str | None = None, ) -> None: rank_str = str(rank).zfill(math.ceil(math.log10(world_size))) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 7fade749..5ef7b590 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -65,7 +65,6 @@ class DistributedDim: def __post_init__(self): self._is_setup = False - logger.info(str(self)) @property def group(self) -> "ProcessGroup|None": diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 54f43b85..dfc2dd60 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -54,7 +54,6 @@ def get_process_group(self, global_ranks: range | tuple, rank: int) -> ProcessGr Get the requested process group from the pool, or create it if it doesn't exist. 
""" group_size = len(global_ranks) - logger.info(f"WIOUGHNIOUW {global_ranks}, {group_size}, {self._rank}") Assert.eq(global_ranks[rank], self._rank) if group_size == 1: return None @@ -192,7 +191,6 @@ def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None: Add a process group from its definition. """ self._config.log_first_rank(f"Initializing group {distributed_dim.name}, size={distributed_dim.size}...") - logger.info(f"INIT {distributed_dim}") group = self._pool.get_process_group(distributed_dim.global_ranks, distributed_dim.rank) distributed_dim.setup(group) return group diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 8c194938..7bbdd697 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -165,6 +165,7 @@ def all_equal(x, y): x = torch.as_tensor(x) y = torch.as_tensor(y) + Assert.eq(x.shape, y.shape) neq = x != y if neq.any().item(): # noqa index = None if x.numel() == 1 else torch.where(neq) # noqa diff --git a/tests/conftest.py b/tests/conftest.py index 11757176..6e9f830e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import logging import math import os +import shutil import pytest import torch @@ -14,13 +15,14 @@ # Make fixtures available globally without import from tests.utils.run_test_script import ( # isort: skip + run_distributed_script_for_all_models, run_test_script, run_test_script_base_path, run_test_script_for_all_models, ) from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip -from tests.utils.utils import result_path # isort: skip +from tests.utils.utils import result_path, TEST_RESULTS_PATH # isort: skip manager: DependencyManager | None = None @@ -76,6 +78,11 @@ def pytest_configure(config): else: worker_id = 0 + # TODO: Remove the whole `TEST_RESULTS_PATH` once `get_test_dataset` is parallel-safe. + model_result_path = TEST_RESULTS_PATH / "models" + if model_result_path.exists(): + shutil.rmtree(model_result_path) + num_gpus = torch.cuda.device_count() if num_gpus > 0 and is_parallel: # We spread workers across GPUs. 
diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index 51045478..d27b66b7 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -1,71 +1,113 @@ -import argparse +import gc import pathlib +import typing -from fast_llm.engine.checkpoint.config import DistributedCheckpointFormat, FastLLMCheckpointFormat +import torch + +from fast_llm.cli import fast_llm_main_wrapper +from fast_llm.engine.checkpoint.config import ( + CheckpointFormat, + CheckpointLoadConfig, + CheckpointSaveConfig, + DistributedCheckpointFormat, + FastLLMCheckpointFormat, +) from fast_llm.engine.distributed.distributed import ProcessGroupPool -from tests.models.test_checkpoint import get_convert_paths -from tests.utils.model_configs import MODEL_CONFIGS -from tests.utils.run_test_script import do_run_test_script_for_all_models +from fast_llm.engine.multi_stage.config import StageMode +from tests.models.test_checkpoint import do_get_convert_path +from tests.utils.model_configs import ModelTestingConfig +from tests.utils.run_test_script import parse_run_distributed_script -def parse_args(args: list[str] | None = None): - parser = argparse.ArgumentParser() - parser.add_argument("rendezvous_port", type=int) - parser.add_argument("torchrun_port", type=int) - parser.add_argument("base_path", type=pathlib.Path) - parser.add_argument("model_testing_config", type=str) - parsed = parser.parse_args(args) - return parsed.rendezvous_port, parsed.torchrun_port, parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config] +def _test_load_and_save_parallel( + model_testing_config: ModelTestingConfig, + pretrained_path: pathlib.Path, + pretrained_format: CheckpointFormat, + distributed_config: dict[str, typing.Any], + save_path: pathlib.Path, +): + model = model_testing_config.model_class.from_pretrained( + CheckpointLoadConfig(path=pretrained_path, format=pretrained_format), + # The world size and rank are already set through environment variable. + {"distributed": distributed_config}, + mode=StageMode.inference, + ) + for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): + model.save_checkpoint(CheckpointSaveConfig(path=save_path / save_format.name, format=save_format)) + del model + gc.collect() + torch.cuda.empty_cache() -def _test_load_and_save_parallel(fixture_args, test_name, distributed_args, pretrained_path, pretrained_format): - # TODO: Just save and load the model instead, no need for an actual run. - do_run_test_script_for_all_models( - [ - # First we load a checkpoint. - f"pretrained.path={pretrained_path}", - f"pretrained.format={pretrained_format}", - # We run for one mock iteration. - "training.train_iters=1", - "schedule.skip_step=True", - # Then we save a checkpoint (distributed format) and an export (fast_llm format). - "training.checkpoint.interval=1", - "training.export.interval=1", - "training.export.format=fast_llm", - ] - + distributed_args, - test_name=test_name, - **fixture_args, - ) +# def _test_load_and_save_parallel(fixture_args, test_name, distributed_args, pretrained_path, pretrained_format): +# # TODO: Just save and load the model instead, no need for an actual run. +# do_run_test_script_for_all_models( +# [ +# # First we load a checkpoint. +# f"pretrained.path={pretrained_path}", +# f"pretrained.format={pretrained_format}", +# # We run for one mock iteration. 
+# "training.train_iters=1", +# "schedule.skip_step=True", +# # Then we save a checkpoint (distributed format) and an export (fast_llm format). +# "training.checkpoint.interval=1", +# "training.export.interval=1", +# "training.export.format=fast_llm", +# ] +# + distributed_args, +# test_name=test_name, +# **fixture_args, +# ) def main(args: list[str] | None = None) -> None: - rendezvous_port, torchrun_port, base_path, model_testing_config = parse_args(args) - convert_paths = get_convert_paths(base_path) + base_path, model_testing_config = parse_run_distributed_script(args) - fixture_args = { - "rendezvous_port": rendezvous_port, - "torchrun_port": torchrun_port, - "base_path": base_path, - "model_testing_config": model_testing_config, - "num_gpus": 2, - } + # fixture_args = { + # "rendezvous_port": rendezvous_port, + # "torchrun_port": torchrun_port, + # "base_path": base_path, + # "model_testing_config": model_testing_config, + # "num_gpus": 2, + # } with ProcessGroupPool(timeout=20): for pretrained_format, pretrained_path in ( - (DistributedCheckpointFormat.name, convert_paths["distributed_0"]), - (FastLLMCheckpointFormat.name, convert_paths["fast_llm_0"]), - (model_testing_config.checkpoint_format.name, convert_paths["huggingface_0"]), + ( + DistributedCheckpointFormat, + do_get_convert_path( + DistributedCheckpointFormat, model_testing_config.checkpoint_format, base_path=base_path.parent + ), + ), + ( + FastLLMCheckpointFormat, + do_get_convert_path( + FastLLMCheckpointFormat, model_testing_config.checkpoint_format, base_path=base_path.parent + ), + ), + ( + model_testing_config.checkpoint_format, + do_get_convert_path( + model_testing_config.checkpoint_format, DistributedCheckpointFormat, base_path=base_path.parent + ), + ), ): _test_load_and_save_parallel( - fixture_args, - test_name=f"test_load_pretrained_{pretrained_format}_in_dp2", - distributed_args=[], + model_testing_config=model_testing_config, pretrained_path=pretrained_path, pretrained_format=pretrained_format, + distributed_config={}, + save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_dp2", ) # _test_load_and_save_parallel( # fixture_args, + # test_name=f"test_load_pretrained_{pretrained_format}_in_dp2", + # distributed_args=[], + # pretrained_path=pretrained_path, + # pretrained_format=pretrained_format, + # ) + # _test_load_and_save_parallel( + # fixture_args, # test_name=f"test_load_pretrained_{pretrained_format}_in_tp2", # distributed_args=["model.distributed.tensor_parallel=2"], # pretrained_path=pretrained_path, @@ -110,4 +152,5 @@ def main(args: list[str] | None = None) -> None: if __name__ == "__main__": - main() + with fast_llm_main_wrapper(): + main() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 5b392061..8d5928d7 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -1,3 +1,4 @@ +import functools import pathlib import shutil @@ -8,6 +9,7 @@ import yaml from fast_llm.engine.checkpoint.config import ( + CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig, DistributedCheckpointFormat, @@ -15,37 +17,32 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName +from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig, compare_logged_tensor -from tests.utils.model_configs import ModelTestingGroup 
+from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup _WEIGHT_SHARD_SAVE_NAME = f"{ShardName.weights}_shard" +_CHECKPOINT_AND_EVAL_ARGS = [ + "training.checkpoint.interval=1", + "training.evaluators.validation.interval=2", + "training.evaluators.validation.evaluator.iterations=1", +] + @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_config): # A baseline config (single-gpu, bf16, flash-attn). - run_test_script_for_all_models( - model_testing_config.config_args - + [ - "training.checkpoint.interval=1", - "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluator.iterations=1", - ], - ) + run_test_script_for_all_models(_CHECKPOINT_AND_EVAL_ARGS) -def _prepare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path, skip: bool) -> bool: - if skip and (test_path / "checkpoint" / "2" / "ok").is_file(): - return True - elif test_path.is_dir(): - shutil.rmtree(test_path) +def _prepare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): shutil.copytree(compare_path, test_path) shutil.rmtree(test_path / "checkpoint" / "2") assert (test_path / "checkpoint" / "1" / "ok").is_file() # TODO: Eval shutil.rmtree(test_path / "runs") - return False def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): @@ -60,11 +57,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): def test_resume(run_test_script_for_all_models): # Resume from iteration=1 and compare outputs with the baseline run. run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluator.iterations=1", - ], + _CHECKPOINT_AND_EVAL_ARGS, compare=f"test_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, compare_fn=_compare_resume_fn, @@ -76,206 +69,116 @@ def test_resume(run_test_script_for_all_models): def test_resume_frozen(run_test_script_for_all_models): # Resume with frozen mlp. No comparison. 
run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluator.iterations=1", - "model.base_model.transformer.mlp_lr_scale=0.", - ], + _CHECKPOINT_AND_EVAL_ARGS + ["model.base_model.transformer.mlp_lr_scale=0."], compare="test_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, do_compare=False, ) -def _run_conversion(config: ConvertConfig): - if config.output.path.exists(): - assert config.output.path.is_dir() - shutil.rmtree(config.output.path) - config.run() - - -def get_convert_paths(base_path: pathlib.Path) -> dict[str, pathlib.Path]: - return { - "checkpoint": base_path / "test_checkpoint_and_eval" / "checkpoint" / "2", - "distributed_0": base_path / "test_convert_model" / "distributed_0", - "distributed_1": base_path / "test_convert_model" / "distributed_1", - "fast_llm_0": base_path / "test_convert_model" / "fast_llm_0", - "fast_llm_1": base_path / "test_convert_model" / "fast_llm_1", - "huggingface_0": base_path / "test_convert_model" / "huggingface_0", - "huggingface_1": base_path / "test_convert_model" / "huggingface_1", - } +def do_get_convert_path( + to: type[CheckpointFormat] | None = None, from_: type[CheckpointFormat] | None = None, *, base_path: pathlib.Path +) -> pathlib.Path: + if to is None or from_ is None: + return base_path / "test_checkpoint_and_eval" / "checkpoint" / "2" + return base_path / "test_convert_model" / f"{to.name}_from_{from_.name}" @pytest.fixture(scope="module") -def convert_paths(run_test_script_base_path: pathlib.Path) -> dict[str, pathlib.Path]: - return get_convert_paths(run_test_script_base_path) +def get_convert_path(run_test_script_base_path): + return functools.partial(do_get_convert_path, base_path=run_test_script_base_path) -@pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_distributed_to_fast_llm(model_testing_config, convert_paths): - _run_conversion( +@pytest.fixture(scope="module") +def run_conversion(model_testing_config: ModelTestingConfig, get_convert_path): + def do_run_conversion( + load_path: pathlib.Path, load_format: type[CheckpointFormat] | None, save_format: type[CheckpointFormat] | None + ): ConvertConfig( input=CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, + path=load_path, + format=load_format, ), output=CheckpointSaveConfig( - path=convert_paths["fast_llm_0"], - format=FastLLMCheckpointFormat, + path=get_convert_path(save_format, load_format), + format=save_format, ), model=model_testing_config.model_config_class, - ) - ) + ).run() - -@pytest.mark.depends_on(on=["test_convert_distributed_to_fast_llm[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_fast_llm_to_huggingface(model_testing_config, convert_paths): - if model_testing_config.checkpoint_format is None: - pytest.skip(f"Conversion not supported for {model_testing_config.name}") - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["fast_llm_0"], - format=FastLLMCheckpointFormat, - ), - output=CheckpointSaveConfig( - path=convert_paths["huggingface_0"], - format=model_testing_config.checkpoint_format, - ), - model=model_testing_config.model_config_class, - ) - ) - - -@pytest.mark.depends_on(on=["test_convert_fast_llm_to_huggingface[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def 
test_convert_huggingface_to_distributed(model_testing_config, convert_paths): - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["huggingface_0"], - format=model_testing_config.checkpoint_format, - ), - output=CheckpointSaveConfig( - path=convert_paths["distributed_0"], - format=DistributedCheckpointFormat, - ), - model=model_testing_config.model_config_class, - ) - ) + return do_run_conversion @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_distributed_to_huggingface(model_testing_config, convert_paths): - if model_testing_config.checkpoint_format is None: - pytest.skip(f"Conversion not supported for {model_testing_config.name}") - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - ), - output=CheckpointSaveConfig( - path=convert_paths["huggingface_1"], - format=model_testing_config.checkpoint_format, - ), - model=model_testing_config.model_config_class, - ) +def test_conversion(model_testing_config, run_conversion, get_convert_path): + # Test that the various conversions between formats complete successfully. + run_conversion( + get_convert_path(), + DistributedCheckpointFormat, + FastLLMCheckpointFormat, ) - - -@pytest.mark.depends_on(on=["test_convert_distributed_to_huggingface[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_huggingface_to_fast_llm(model_testing_config, convert_paths): - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["huggingface_1"], - format=model_testing_config.checkpoint_format, - ), - output=CheckpointSaveConfig( - path=convert_paths["fast_llm_1"], - format=FastLLMCheckpointFormat, - ), - model=model_testing_config.model_config_class, - ) + run_conversion( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + FastLLMCheckpointFormat, + model_testing_config.checkpoint_format, ) - - -@pytest.mark.depends_on(on=["test_convert_huggingface_to_fast_llm[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_fast_llm_to_distributed(model_testing_config, convert_paths): - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["fast_llm_1"], - format=FastLLMCheckpointFormat, - ), - output=CheckpointSaveConfig( - path=convert_paths["distributed_1"], - format=DistributedCheckpointFormat, - ), - model=model_testing_config.model_config_class, - ) + run_conversion( + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat), + model_testing_config.checkpoint_format, + DistributedCheckpointFormat, + ) + run_conversion( + get_convert_path(), + DistributedCheckpointFormat, + model_testing_config.checkpoint_format, + ) + run_conversion( + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat), + model_testing_config.checkpoint_format, + FastLLMCheckpointFormat, + ) + run_conversion( + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), + FastLLMCheckpointFormat, + DistributedCheckpointFormat, ) -@pytest.mark.depends_on( - on=[ - "test_convert_huggingface_to_distributed[{model_testing_config}]", - "test_convert_fast_llm_to_distributed[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def 
test_converted_distributed(convert_paths): - # Compare the fast llm weights - # TODO: Compare configs - w = safetensors.torch.load_file(convert_paths["checkpoint"] / "rank_0.safetensors") - w0 = safetensors.torch.load_file(convert_paths["distributed_0"] / "rank_0.safetensors") - w1 = safetensors.torch.load_file(convert_paths["distributed_1"] / "rank_0.safetensors") - assert w.keys() >= {_WEIGHT_SHARD_SAVE_NAME} - assert w0.keys() == w1.keys() == {_WEIGHT_SHARD_SAVE_NAME} - for key in w0: - assert w[key].shape == w0[key].shape, (key, w[key].shape, w0[key].shape) - assert (w[key] == w0[key]).all(), (w[key], w0[key]) - assert w[key].shape == w1[key].shape, (key, w[key].shape, w1[key].shape) - assert (w[key] == w1[key]).all(), (w[key], w1[key]) - +def _compare_safetensor_files( + reference_path: pathlib.Path, *other_paths: pathlib.Path, expected_keys: set[str] | None = None +): + reference = safetensors.torch.load_file(reference_path) + if expected_keys is None: + expected_keys = set(reference.keys()) + else: + Assert.geq(set(reference.keys()), expected_keys) -@pytest.mark.depends_on( - on=[ - "test_convert_distributed_to_fast_llm[{model_testing_config}]", - "test_convert_huggingface_to_fast_llm[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_converted_fast_llm(convert_paths): - s0 = safetensors.torch.load_file(convert_paths["fast_llm_0"] / "model_0.safetensors") - s1 = safetensors.torch.load_file(convert_paths["fast_llm_1"] / "model_0.safetensors") - assert s0.keys() == s1.keys() - for key in s0: - assert s0[key].shape == s1[key].shape, (key, s0[key].shape, s1[key].shape) - assert (s0[key] == s1[key]).all(), (key, s0, s1) + for other_path in other_paths: + other = safetensors.torch.load_file(other_path) + Assert.eq(other.keys(), expected_keys) + for key in expected_keys: + Assert.all_equal(reference[key], other[key]) -@pytest.mark.depends_on( - on=[ - "test_convert_fast_llm_to_huggingface[{model_testing_config}]", - "test_convert_distributed_to_huggingface[{model_testing_config}]", - ] -) +@pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_converted_huggingface(convert_paths): - h0 = safetensors.torch.load_file(convert_paths["huggingface_0"] / "model_0.safetensors") - h1 = safetensors.torch.load_file(convert_paths["huggingface_1"] / "model_0.safetensors") - assert h0.keys() == h1.keys() - for key in h0: - assert h0[key].shape == h1[key].shape, (key, h0[key].shape, h1[key].shape) - assert (h0[key] == h1[key]).all() +def test_converted_round_trip(model_testing_config, get_convert_path): + # Test that the various possible conversion paths yield identical results. 
+    _compare_safetensor_files(
+        get_convert_path() / "rank_0.safetensors",
+        get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors",
+        get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format) / "rank_0.safetensors",
+        expected_keys={_WEIGHT_SHARD_SAVE_NAME},
+    )
+    _compare_safetensor_files(
+        get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors",
+        get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors",
+    )
+    _compare_safetensor_files(
+        get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) / "model_0.safetensors",
+        get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors",
+    )
 
 
 def _compare_model_configs(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig):
@@ -286,146 +189,101 @@ def _compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMM
     config_ref.base_model.compare_architecture(config_test.base_model)
 
 
-@pytest.mark.depends_on(on=["test_converted_distributed[{model_testing_config}]"])
+@pytest.fixture(scope="module")
+def load_and_compare_checkpoints(model_testing_config):
+    def do_load_and_compare_checkpoints(
+        load_format: type[CheckpointFormat], load_path: pathlib.Path, reference_config, reference_shard
+    ):
+        model = model_testing_config.model_class.from_pretrained(
+            CheckpointLoadConfig(
+                path=load_path,
+                format=load_format,
+            )
+        )
+        if reference_config is not None:
+            _compare_model_configs(reference_config, model.config)
+        if reference_shard is not None:
+            Assert.all_equal(model.get_shard(ShardName.weights), reference_shard)
+
+    return do_load_and_compare_checkpoints
+
+
+@pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"])
 @pytest.mark.model_testing_group(ModelTestingGroup.convert)
-def test_load_pretrained_distributed_checkpoint(model_testing_config, convert_paths):
-    config = model_testing_config.model_config_class.from_dict(
-        yaml.safe_load((convert_paths["checkpoint"] / ".." / ".." / "config.yaml").open("r"))["model"], strict=False
-    )
-    pretrained_config_ref = CheckpointLoadConfig(
-        path=convert_paths["checkpoint"],
-        format=DistributedCheckpointFormat,
-        optimizer_state=True,
-        load_config=ModelConfigType.model,
+def test_load_pretrained(
+    model_testing_config, run_test_script_base_path, get_convert_path, load_and_compare_checkpoints
+):
+    # Test that loading a pretrained model from either converted checkpoint always yields the exact same model.
+ reference_config = model_testing_config.model_config_class.from_dict( + yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] ) - model = model_testing_config.model_class.from_pretrained(pretrained_config_ref) - _compare_model_configs(config, model.config) - state_shards = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) + reference_config_from_hf = model_testing_config.model_config_class.from_dict( + { + "base_model": yaml.safe_load( + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) + .joinpath("metadata.yaml") + .open("r") + )["config"]["base_model"] + } ) - for shard_name in model.state_shard_names: - assert (state_shards[f"{shard_name}_shard"] == model.get_shard(shard_name)).all() + _compare_architectures(reference_config, reference_config_from_hf) + reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ + _WEIGHT_SHARD_SAVE_NAME + ] -@pytest.mark.depends_on(on=["test_load_pretrained_distributed_checkpoint[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_converted_distributed_checkpoint(model_testing_config, convert_paths): - config_ref = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) - ) + load_and_compare_checkpoints(DistributedCheckpointFormat, get_convert_path(), reference_config, reference_shard) - model = model_testing_config.model_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["distributed_0"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + DistributedCheckpointFormat, + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat), + reference_config_from_hf, + reference_shard, ) - config_alt = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["distributed_1"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + DistributedCheckpointFormat, + get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format), + reference_config_from_hf, + reference_shard, ) - _compare_architectures(config_ref, model.config) - _compare_model_configs(model.config, config_alt) - weight_shard = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - )[_WEIGHT_SHARD_SAVE_NAME] - assert (weight_shard == model.get_shard(ShardName.weights)).all() - -@pytest.mark.depends_on( - on=[ - "test_converted_fast_llm[{model_testing_config}]", - "test_load_pretrained_distributed_checkpoint[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_converted_fast_llm_checkpoint(model_testing_config, convert_paths): - config_ref = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + FastLLMCheckpointFormat, + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + reference_config, + reference_shard, ) - model = model_testing_config.model_class.from_pretrained( - CheckpointLoadConfig( - 
path=convert_paths["fast_llm_0"], - format=FastLLMCheckpointFormat, - load_config=ModelConfigType.model, - ) - ) - config_alt = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["fast_llm_1"], - format=FastLLMCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + FastLLMCheckpointFormat, + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), + reference_config_from_hf, + reference_shard, ) - _compare_architectures(config_ref, model.config) - _compare_architectures(config_ref, config_alt) - weight_shard = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - )[_WEIGHT_SHARD_SAVE_NAME] - assert (weight_shard == model.get_shard(ShardName.weights)).all() - -@pytest.mark.depends_on( - on=[ - "test_converted_fast_llm[{model_testing_config}]", - "test_load_pretrained_distributed_checkpoint[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_converted_huggingface_checkpoint(model_testing_config, convert_paths): - config_ref = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + model_testing_config.checkpoint_format, + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat), + reference_config_from_hf, + reference_shard, ) - model = model_testing_config.model_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["huggingface_1"], - format=model_testing_config.checkpoint_format, - load_config=ModelConfigType.model, - ), - mode=StageMode.weights, + load_and_compare_checkpoints( + model_testing_config.checkpoint_format, + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat), + reference_config_from_hf, + reference_shard, ) - config_alt = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["huggingface_0"], - format=model_testing_config.checkpoint_format, - load_config=ModelConfigType.model, - ) - ) - _compare_architectures(config_ref, model.config) - _compare_model_configs(model.config, config_alt) - weight_shard = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - )[_WEIGHT_SHARD_SAVE_NAME] - assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.depends_on( - on=[ - "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", - "test_load_converted_huggingface_checkpoint[{model_testing_config}]", - ] -) +@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_run_converted_model(model_testing_config, convert_paths): +def test_huggingface_model(model_testing_config, get_convert_path): + # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Review test. Move to test_generate? 
+ fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) + hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( CheckpointLoadConfig( - path=convert_paths["checkpoint"], + path=get_convert_path(), format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) @@ -434,12 +292,10 @@ def test_run_converted_model(model_testing_config, convert_paths): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( - convert_paths["fast_llm_0"] - ) + model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( CheckpointLoadConfig( - path=convert_paths["huggingface_0"], + path=hf_path, format=model_testing_config.checkpoint_format, load_config=ModelConfigType.model, ) @@ -452,7 +308,7 @@ def test_run_converted_model(model_testing_config, convert_paths): else transformers.AutoModelForCausalLM ) model_as_hf = auto_model.from_pretrained( - convert_paths["huggingface_0"], trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code + hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code ).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), @@ -476,146 +332,78 @@ def test_run_converted_model(model_testing_config, convert_paths): raise ValueError(f"Comparison failed ({len(errors)} errors)") -@pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_pretrained_distributed(run_test_script_for_all_models, convert_paths): - run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.train_iters=1", - f"pretrained.path={convert_paths["distributed_0"]}", - f"pretrained.format={DistributedCheckpointFormat.name}", - "schedule.skip_step=True", - ], - ) +@pytest.fixture(scope="module") +def load_and_save_parallel_base_path(run_test_script_base_path): + return run_test_script_base_path / "test_load_and_save_parallel" @pytest.mark.depends_on( on=[ - "test_load_converted_distributed_checkpoint[{model_testing_config}]", - "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", - "test_load_converted_huggingface_checkpoint[{model_testing_config}]", + "test_load_pretrained[{model_testing_config}]", ] ) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_save_and_load_in_parallel(run_distributed_script_for_all_models): +def test_save_and_load_in_parallel(run_distributed_script_for_all_models, load_and_save_parallel_base_path): + # Save and load checkpoints to and from various distributed configurations. + # Combined in a single test to mitigate process creation overhead. + # TODO: Test beyond 2 gpu configs? 
import tests.models.distributed_test_checkpoint - run_distributed_script_for_all_models([tests.models.distributed_test_checkpoint.__file__], 2) + run_distributed_script_for_all_models( + [tests.models.distributed_test_checkpoint.__file__], + base_path=load_and_save_parallel_base_path, + num_gpus=2, + ) @pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_distributed_in_dp2(model_testing_config, convert_paths, run_test_script_base_path): - # Vertify the content of the saved checkpoint. - test_ckpt_path = run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" - pretrained_config_ref = CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.fast_llm, - ) - pretrained_config_test = CheckpointLoadConfig( - path=test_ckpt_path, - format=DistributedCheckpointFormat, - load_config=ModelConfigType.fast_llm, - ) - config_ref = model_testing_config.model_config_class.from_pretrained(pretrained_config_ref) - config_test = model_testing_config.model_config_class.from_pretrained(pretrained_config_test) - _compare_model_configs(config_ref, config_test) - shards_ref = safetensors.torch.load_file(convert_paths["checkpoint"] / "rank_0.safetensors") - shards_test = [safetensors.torch.load_file(test_ckpt_path / f"rank_{i}.safetensors") for i in range(2)] - ref_model = model_testing_config.model_class(config_ref) - test_model = model_testing_config.model_class(config_test) - - weight_shard_ref_split = shards_ref[_WEIGHT_SHARD_SAVE_NAME].split(ref_model._stage_weight_shard_sizes) - weight_shards_test_split = [ - shard_test[_WEIGHT_SHARD_SAVE_NAME].split(test_model._stage_weight_shard_sizes) for shard_test in shards_test - ] - for shard_test in shards_test: - for shard_name, shard in shard_test.items(): - if shard_name != _WEIGHT_SHARD_SAVE_NAME: - assert (shard == 0).all() # noqa - - assert len(ref_model._stage_weight_shard_sizes) == len(test_model._stage_weight_shard_sizes) - for i, stage_shard_ref in enumerate(weight_shard_ref_split): - assert ( - test_model._stage_weight_shard_sizes[i] - == ref_model._stage_weight_shard_sizes[i] // 2 + (-ref_model._stage_weight_shard_sizes[i] // 2) % 32 +def test_parallel_checkpoint(model_testing_config, load_and_save_parallel_base_path, get_convert_path): + # Check the consistency of the checkpoints saved in `test_save_and_load_in_parallel` + checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) + # Compare Distributed checkpoints + for rank in range(2): + _compare_safetensor_files( + *[ + load_and_save_parallel_base_path + / f"load_pretrained_{format_.name}_in_dp2" + / DistributedCheckpointFormat.name + / f"rank_{rank}.safetensors" + for format_ in checkpoint_formats + ] ) - stage_shard_test = torch.concatenate( - [weight_shard_test_split[i] for weight_shard_test_split in weight_shards_test_split] - ) - assert (stage_shard_test[: stage_shard_ref.numel()] == stage_shard_ref).all() - assert (stage_shard_test[stage_shard_ref.numel() :] == 0).all() # noqa - - # Test loading in single GPU. 
- pretrained_config_test = CheckpointLoadConfig( - path=test_ckpt_path, - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, + # Compare Fast-LLM checkpoints + _compare_safetensor_files( + # Fast-LLM checkpoints are independent of the distributed configuration that saved it. + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / f"model_0.safetensors", + *[ + load_and_save_parallel_base_path + / f"load_pretrained_{format_.name}_in_dp2" + / FastLLMCheckpointFormat.name + / f"model_0.safetensors" + for format_ in checkpoint_formats + ], ) - model = model_testing_config.model_class.from_pretrained(pretrained_config_test, mode=StageMode.weights) - _compare_model_configs(config_ref, model.config) - weight_shard = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - )[_WEIGHT_SHARD_SAVE_NAME] - assert (weight_shard == model.get_shard(ShardName.weights)).all() - - -@pytest.mark.depends_on( - on=[ - "test_save_and_load_in_parallel[{model_testing_config}]", - "test_load_pretrained_distributed_in_dp2[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_fast_llm_in_dp2(model_testing_config, convert_paths, run_test_script_base_path): - # Compare the checkpoint from test_load_pretrained_fast_llm_in_dp2 and test_load_pretrained_distributed_in_dp2 - for rank in range(2): - ref_shard = safetensors.torch.load_file( - run_test_script_base_path - / f"test_load_pretrained_distributed_in_dp2" - / "checkpoint" - / "1" - / f"rank_{rank}.safetensors" - ) - test_shard = safetensors.torch.load_file( - run_test_script_base_path - / f"test_load_pretrained_fast_llm_in_dp2" - / "checkpoint" - / "1" - / f"rank_{rank}.safetensors" - ) - for name in set(ref_shard) | set(test_shard): - assert (ref_shard[name] == test_shard[name]).all() -@pytest.mark.depends_on( - on=[ - "test_load_converted_huggingface_checkpoint[{model_testing_config}]", - "test_load_pretrained_distributed_in_dp2[{model_testing_config}]", - ] -) +@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_huggingface_in_dp2( - run_test_script_for_all_models, model_testing_config, run_test_script_base_path, convert_paths +def test_load_parallel_checkpoint( + model_testing_config, load_and_save_parallel_base_path, get_convert_path, load_and_compare_checkpoints ): - # Compare the checkpoint from test_load_pretrained_huggingface_in_dp2 and test_load_pretrained_distributed_in_dp2 - for rank in range(2): - ref_shard = safetensors.torch.load_file( - run_test_script_base_path - / f"test_load_pretrained_distributed_in_dp2" - / "checkpoint" - / "1" - / f"rank_{rank}.safetensors" - ) - test_shard = safetensors.torch.load_file( - run_test_script_base_path - / f"test_load_pretrained_{model_testing_config.checkpoint_format.name}_in_dp2" - / "checkpoint" - / "1" - / f"rank_{rank}.safetensors" + # Test single-gpu loading of multi-gpu distributed checkpoints. 
+ checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) + reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ + _WEIGHT_SHARD_SAVE_NAME + ] + + for format_ in checkpoint_formats: + load_and_compare_checkpoints( + DistributedCheckpointFormat, + load_and_save_parallel_base_path + / f"load_pretrained_{format_.name}_in_dp2" + / DistributedCheckpointFormat.name, + None, + reference_shard, ) - for name in set(ref_shard) | set(test_shard): - assert (ref_shard[name] == test_shard[name]).all() diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 23c487a7..2a12c4f7 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -7,9 +7,9 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample +from tests.utils.utils import TEST_RESULTS_PATH -# TODO: Fixture -TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") +# TODO: Fixtures TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" DATASET_CACHE = TEST_RESULTS_PATH / "dataset" diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index bee43de1..b8dd29e8 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -8,6 +8,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, @@ -52,14 +53,27 @@ class ModelTestingConfig: model_type: str config_args: list[str] megatron_args: list[str] | None - checkpoint_format: CheckpointFormat | None + checkpoint_format: type[CheckpointFormat] | None groups: dict[ModelTestingGroup, ModelTestingGroupAction] @functools.cached_property - def model_config_class(self): + def trainer_config_class(self) -> type[TrainerConfig]: + return TrainerConfig.get_subclass(self.model_type) + + @functools.cached_property + def trainer_config(self) -> TrainerConfig: + # See `RunnableConfig._from_parsed_args` + return self.trainer_config_class.from_dict(self.trainer_config_class._parse_updates(self.config_args)) + + @functools.cached_property + def model_config_class(self) -> type[FastLLMModelConfig]: # TODO: Ok to assume the model and trainer have the same name? 
return FastLLMModelConfig.get_subclass(self.model_type) + @functools.cached_property + def model_config(self) -> FastLLMModelConfig: + return self.trainer_config.model + @functools.cached_property def huggingface_model_for_causal_lm_class(self): return self.model_config_class.get_huggingface_model_for_causal_lm_class() @@ -487,7 +501,7 @@ def model_testing_config(request) -> ModelTestingConfig: def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slow: bool, show_skipped: bool) -> bool: if "model_testing_group" in item.keywords: - assert "model_testing_config" in item.callspec.params, item.nodeid + assert hasattr(item, "callspec") and "model_testing_config" in item.callspec.params, item.nodeid groups: tuple[ModelTestingGroup] = item.keywords["model_testing_group"].args model_testing_config = item.callspec.params["model_testing_config"] model_config: ModelTestingConfig = MODEL_CONFIGS[model_testing_config] diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 3c8f961b..263484db 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -2,7 +2,6 @@ import functools import os import pathlib -import shutil import subprocess import sys import typing @@ -76,26 +75,18 @@ def do_run_test_script( # Prevent Megatron from complaining. env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env["NVTE_FLASH_ATTN"] = "0" - skip = False - if local_rank == 0 and path.exists(): - assert path.is_dir() - # TODO: Better way to check if the previous attempt succeeded. - shutil.rmtree(path) if local_rank == 0 and prepare_fn is not None: - skip = prepare_fn(path, None if compare_path is None else compare_path, skip) + prepare_fn(path, None if compare_path is None else compare_path) if is_megatron: args = ["Megatron-LM/pretrain_gpt.py", *args, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] else: args = ["--no-python", "fast-llm", "train", model_type, *args, f"run.experiment_dir={path}"] - if skip: - print("Reusing existing run.") + get_test_dataset() + if (num_gpus == 1 or is_parallel) and not is_megatron: + print(" ".join(args[1:])) + RunnableConfig.parse_and_run(args[2:]) else: - get_test_dataset() - if (num_gpus == 1 or is_parallel) and not is_megatron: - print(" ".join(args[1:])) - RunnableConfig.parse_and_run(args[2:]) - else: - do_run_distributed_script(args, rendezvous_port, torchrun_port, num_gpus) + do_run_distributed_script(args, rendezvous_port, torchrun_port, num_gpus) if local_rank == 0 and compare_path is not None and do_compare: if compare_fn is not None: compare_fn(path, compare_path) @@ -171,12 +162,10 @@ def run_test_script_for_all_models( def parse_run_distributed_script(args: list[str] | None = None): parser = argparse.ArgumentParser() - parser.add_argument("rendezvous_port", type=int) - parser.add_argument("torchrun_port", type=int) parser.add_argument("base_path", type=pathlib.Path) parser.add_argument("model_testing_config", type=str) parsed = parser.parse_args(args) - return parsed.rendezvous_port, parsed.torchrun_port, parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config] + return parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config] @pytest.fixture(scope="session") @@ -186,13 +175,11 @@ def run_distributed_script_for_all_models( model_testing_config: ModelTestingConfig, request: pytest.FixtureRequest, ): - def do_run_distributed_script_for_all_models(args: list[str], num_gpus=2): + def do_run_distributed_script_for_all_models(args: list[str], num_gpus=2, base_path: pathlib.Path | None = None): 
do_run_distributed_script( args + [ - str(worker_resources.rendezvous_port), - str(worker_resources.torchrun_port), - str(run_test_script_base_path), + str(run_test_script_base_path if base_path is None else base_path), model_testing_config.name, ], worker_resources.rendezvous_port, diff --git a/tests/utils/utils.py b/tests/utils/utils.py index ea689bcc..1ea7717f 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -12,9 +12,12 @@ requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") + + @pytest.fixture(scope="session") def result_path(): - return pathlib.Path("/tmp/fast_llm_tests") + return TEST_RESULTS_PATH def get_base_model(config: FastLLMModelConfig): From 7dee1086b4380519ec9704f836f1b29f419ffeda Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 25 Jun 2025 21:04:54 -0400 Subject: [PATCH 50/69] stuff --- Dockerfile | 2 +- fast_llm/cli.py | 13 +- fast_llm/engine/checkpoint/distributed.py | 3 +- fast_llm/engine/checkpoint/state_dict.py | 1 + fast_llm/engine/config_utils/logging.py | 5 +- fast_llm/engine/config_utils/tensor_space.py | 5 +- fast_llm/engine/distributed/config.py | 94 +-- fast_llm/engine/distributed/distributed.py | 169 +++-- fast_llm/engine/multi_stage/multi_stage.py | 4 +- fast_llm/utils.py | 1 + tests/conftest.py | 9 +- tests/models/distributed_test_checkpoint.py | 156 ++++ tests/models/test_checkpoint.py | 710 +++++++------------ tests/test_gpt_loss.py | 121 ---- tests/utils/dataset.py | 4 +- tests/utils/model_configs.py | 37 +- tests/utils/run_test_script.py | 248 ++++--- tests/utils/utils.py | 5 +- 18 files changed, 755 insertions(+), 832 deletions(-) create mode 100644 tests/models/distributed_test_checkpoint.py delete mode 100644 tests/test_gpt_loss.py diff --git a/Dockerfile b/Dockerfile index d67729d3..b583834d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,7 +30,7 @@ ENV PIP_CONSTRAINT="" # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) RUN MAX_JOBS=4 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@74729d0" +RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ diff --git a/fast_llm/cli.py b/fast_llm/cli.py index 34546120..66ce096d 100644 --- a/fast_llm/cli.py +++ b/fast_llm/cli.py @@ -1,3 +1,4 @@ +import contextlib import logging import sys import traceback @@ -15,12 +16,12 @@ logger = logging.getLogger(__name__) -def fast_llm_main(args: list[str] | None = None): - # TODO: Add hook to register model classes? (environment variable?) +@contextlib.contextmanager +def fast_llm_main_wrapper(): # (Pre-)configure logging configure_logging() try: - RunnableConfig.parse_and_run(args) + yield except Exception as e: if sys.gettrace(): raise @@ -31,5 +32,11 @@ def fast_llm_main(args: list[str] | None = None): sys.exit(1) +def fast_llm_main(args: list[str] | None = None): + # TODO: Add hook to register model classes? (environment variable?) 
+ with fast_llm_main_wrapper(): + RunnableConfig.parse_and_run(args) + + if __name__ == "__main__": fast_llm_main() diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index de1625f6..6681d70e 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -30,8 +30,8 @@ class DistributedCheckpointHandler(CheckpointHandler): @classmethod def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata): - config.path.mkdir(parents=True, exist_ok=True) serialized_metadata = metadata.to_dict() + config.path.mkdir(parents=True, exist_ok=True) yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) @classmethod @@ -40,6 +40,7 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: serialized_metadata = metadata.to_dict() + config.path.mkdir(parents=True, exist_ok=True) if self._model.config.distributed.rank == 0: yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) safetensors.torch.save_file( diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 556e97be..7a257a5f 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -216,6 +216,7 @@ def _save_next_file(self) -> None: file_name = f"{self.base_file_name}_{self._file_count}.safetensors" if self._do_save: logger.info(f"Saving tensors to {self._config.path / file_name}") + self._config.path.mkdir(parents=True, exist_ok=True) safetensors.torch.save_file( tensors=self._tensors, filename=self._config.path / file_name, diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index ea014bd0..358674a9 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -5,6 +5,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -14,8 +15,8 @@ def configure_logging( *, log_timestamps: bool = True, enable_all_loggers: bool = False, - rank: int = 0, - world_size: int = 1, + rank: int = DistributedConfig.default_rank, + world_size: int = DistributedConfig.default_world_size, directory: pathlib.Path | str | None = None, ) -> None: rank_str = str(rank).zfill(math.ceil(math.log10(world_size))) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 5020bc65..49ce1525 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -147,7 +147,10 @@ def add_tensor_dim(self, dim: TensorDim) -> None: else: if dim.parallel_dim is not None: assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name - Assert.eq(dim.parallel_dim, self._distributed_config.distributed_dims[dim.parallel_dim.name]) + Assert.eq( + dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + ) self._tensor_dims[dim.name] = dim def get_tensor_dim(self, name: str) -> TensorDim: diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 8e2430d5..5ef7b590 100644 --- a/fast_llm/engine/distributed/config.py +++ 
b/fast_llm/engine/distributed/config.py @@ -1,3 +1,4 @@ +import dataclasses import enum import logging import os @@ -50,58 +51,32 @@ def is_training(self) -> bool: return self == PhaseType.training +@dataclasses.dataclass class DistributedDim: """ A dataclass to hold all relevant information on a process group without actually creating it. """ - _is_setup: bool = False - _group: "ProcessGroup|None" + _group: "ProcessGroup|None" = dataclasses.field(init=False, repr=False) + name: str + size: int + rank: int + global_ranks: range | tuple[int, ...] = None - def __init__(self, name: str, size: int = 1, rank: int = 0, id_: str | None = None, parent: str | None = None): - self._name = name - self._size = size - self._rank = rank - self._id = id_ - self._parent = parent - - @property - def name(self) -> str: - return self._name - - @property - def size(self) -> int: - return self._size - - @property - def rank(self) -> int: - return self._rank - - @property - def id(self) -> str | None: - return self._id - - @property - def parent(self) -> str | None: - return self._parent + def __post_init__(self): + self._is_setup = False @property def group(self) -> "ProcessGroup|None": - assert self._is_setup + assert hasattr(self, "_group") return self._group - def __repr__(self) -> str: - return ( - f"DistributedDim(name={self.name}, size={self.size}, rank={self.rank}, id={self.id}, parent={self.parent})" - ) - def setup(self, group: "ProcessGroup|None"): - assert not self._is_setup - self._is_setup = True + assert not hasattr(self, "_group") Assert.eq(group is None, self.size == 1) if group is not None: - Assert.eq(group.size(), self._size) - Assert.eq(group.rank(), self._rank) + Assert.eq(group.size(), self.size) + Assert.eq(group.rank(), self.rank) self._group = group @@ -296,9 +271,15 @@ def _validate(self) -> None: else: self.distributed_dims = {} + data_stride = self.tensor_parallel * (1 if self.pipeline_first else self.pipeline_parallel) + pipeline_stride = self.tensor_parallel * (self.data_parallel if self.pipeline_first else 1) + self._add_distributed_dim( DistributedDim( - name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None + name=DistributedDimNames.world, + size=self.world_size, + rank=self.rank, + global_ranks=range(self.world_size), ) ) self._add_distributed_dim( @@ -306,8 +287,7 @@ def _validate(self) -> None: name=DistributedDimNames.data, size=self.data_parallel, rank=self.data_rank, - id_=f"x_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.world, + global_ranks=self._get_global_ranks(self.data_parallel, data_stride), ) ) self._add_distributed_dim( @@ -315,8 +295,7 @@ def _validate(self) -> None: name=DistributedDimNames.pipeline, size=self.pipeline_parallel, rank=self.pipeline_rank, - id_=f"x_{self.data_rank}_{self.tensor_rank}", - parent=DistributedDimNames.world, + global_ranks=self._get_global_ranks(self.pipeline_parallel, pipeline_stride), ) ) self._add_distributed_dim( @@ -324,8 +303,7 @@ def _validate(self) -> None: name=DistributedDimNames.tensor, size=self.tensor_parallel, rank=self.tensor_rank, - id_=f"x_{self.data_rank}_{self.pipeline_rank}", - parent=DistributedDimNames.world, + global_ranks=self._get_global_ranks(self.tensor_parallel, 1), ) ) self._add_distributed_dim( @@ -333,8 +311,7 @@ def _validate(self) -> None: name=DistributedDimNames.sequence_data, size=self.sequence_data_parallel, rank=self.sequence_data_rank, - id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", - 
parent=DistributedDimNames.data, + global_ranks=self._get_global_ranks(self.sequence_data_parallel, data_stride), ) ) self._add_distributed_dim( @@ -342,8 +319,9 @@ def _validate(self) -> None: name=DistributedDimNames.batch_data, size=self.batch_data_parallel, rank=self.batch_data_rank, - id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.data, + global_ranks=self._get_global_ranks( + self.batch_data_parallel, data_stride * self.sequence_data_parallel + ), ) ) self._add_distributed_dim( @@ -351,16 +329,7 @@ def _validate(self) -> None: name=DistributedDimNames.tensor_and_sequence_data, size=self.sequence_data_parallel * self.tensor_parallel, rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel, - id_=f"{self.batch_data_rank}_{self.pipeline_rank}", - parent=( - DistributedDimNames.tensor - if self.sequence_data_parallel == 1 - else ( - DistributedDimNames.sequence_data - if self.tensor_parallel == 1 - else DistributedDimNames.world - ) - ), + global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1), ) ) @@ -371,12 +340,15 @@ def _validate(self) -> None: Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) + def _get_global_ranks(self, size: int, stride: int) -> range: + start = self.rank // (size * stride) * size * stride + self.rank % stride + return range(start, start + size * stride, stride) + def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: + Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank) if distributed_dim.name in self.distributed_dims: Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name]) else: - if distributed_dim.parent is not None: - assert distributed_dim.parent in self.distributed_dims self.distributed_dims[distributed_dim.name] = distributed_dim def get_distributed_dim(self, name: str) -> DistributedDim: diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 42ec97f2..dfc2dd60 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -19,6 +19,89 @@ logger = logging.getLogger(__name__) +class ProcessGroupPool: + def __init__(self, rank: int | None = None, world_size: int | None = None, timeout: float = 60): + + self._rank = DistributedConfig.default_rank if rank is None else rank + self._world_size = DistributedConfig.default_world_size if world_size is None else world_size + self._timeout = timeout + + if self._world_size > 1: + if rank == 0: + logger.info("Initializing TCP store.") + # We bypass `torch.distributed.init_process_group` which makes things way more complicated for no reason. + # TODO: Allow other init methods? + self.store, _, _ = next( + torch.distributed.rendezvous( + "env://", + self._rank, + self._world_size, + timeout=datetime.timedelta(seconds=timeout), + ) + ) + self._process_groups = {} + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def get_process_group(self, global_ranks: range | tuple, rank: int) -> ProcessGroup | None: + """ + Get the requested process group from the pool, or create it if it doesn't exist. + """ + group_size = len(global_ranks) + Assert.eq(global_ranks[rank], self._rank) + if group_size == 1: + return None + + for group_ranks, group in self._process_groups.items(): + # Check if an equivalent group already exists. 
+            if type(group_ranks) == type(global_ranks):
+                if group_ranks == global_ranks:
+                    return group
+            elif tuple(group_ranks) == tuple(global_ranks):
+                return group
+
+        prefix = (
+            f"range_{global_ranks.start}_{global_ranks.stop}_{global_ranks.step}"
+            if isinstance(global_ranks, range)
+            else f"ranks_{"_".join(str(rank) for rank in global_ranks)}"
+        )
+
+        group = torch.distributed.ProcessGroupNCCL(
+            torch.distributed.PrefixStore(prefix + "/", self.store),
+            global_ranks.index(rank),
+            group_size,
+            datetime.timedelta(seconds=self._timeout),
+        )
+        self._process_groups[global_ranks] = group
+        return group
+
+    def __enter__(self):
+        global _default_pool
+        assert _default_pool is None
+        _default_pool = self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _default_pool
+        assert _default_pool is self
+        _default_pool = None
+
+    def __del__(self):
+        # Shutdown the process group backend explicitly to prevent a nccl warning.
+        # We can't call `destroy_process_group` directly because pytorch doesn't know about it.
+        for group in self._process_groups.values():
+            if group is not None and hasattr(group, "_shutdown"):
+                group._shutdown() # noqa
+
+
+_default_pool: ProcessGroupPool | None = None
+
+
 class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):
     """
     A distributed instance holding pointers to the various process groups.
@@ -31,7 +114,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):
 
     config_class: typing.ClassVar[type[DistributedConfig]] = DistributedConfig
 
-    def __init__(self, config: DistributedConfig, use_cpu: bool = False):
+    def __init__(self, config: DistributedConfig, use_cpu: bool = False, pool: ProcessGroupPool | None = None):
         super().__init__(config)
         assert self._config.reference_config is None
         self._use_cpu = use_cpu
@@ -45,32 +128,24 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
         self.device = torch.device(self._config.local_rank)
         torch.cuda.set_device(self.device)
 
-        # We bypass `torch.distributed.init_process_group` which makes things way more complicated for no reason.
-
-        # TODO: Allow other init methods?
-        # TODO: Allow different timeout for the store?
- if self._config.world_size > 1: - self._config.log_first_rank("Initializing TCP store.") - self.store, _, _ = next( - torch.distributed.rendezvous( - "env://", - self._config.rank, - self._config.world_size, - timeout=datetime.timedelta(seconds=self._config.timeout), - ) - ) - self._process_groups = {} - for name, distributed_dim in self._config.distributed_dims.items(): - Assert.eq(distributed_dim.name, name) - self.add_group(distributed_dim) - - self.world_group = self._process_groups[DistributedDimNames.world] - self.data_group = self._process_groups[DistributedDimNames.data] - self.pipeline_group = self._process_groups[DistributedDimNames.pipeline] - self.tensor_group = self._process_groups[DistributedDimNames.tensor] - self.sequence_data_group = self._process_groups[DistributedDimNames.sequence_data] - self.batch_data_group = self._process_groups[DistributedDimNames.batch_data] - self.tensor_and_sequence_data_group = self._process_groups[DistributedDimNames.tensor_and_sequence_data] + if pool is None and _default_pool is None: + self._pool = ProcessGroupPool(self._config.rank, self._config.world_size, self._config.timeout) + else: + if pool is None: + pool = _default_pool + Assert.eq(pool._world_size, self._config.world_size) + Assert.eq(pool._rank, self._config.rank) + self._pool = pool + + self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world]) + self.data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.data]) + self.pipeline_group = self.add_group(self._config.distributed_dims[DistributedDimNames.pipeline]) + self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor]) + self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data]) + self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data]) + self.tensor_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] + ) self._config.log_first_rank(f"Setting random seeds...") @@ -114,38 +189,9 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None: """ Add a process group from its definition. - The group name (`dim`) must be unique within a distributed instance, - - Note: the group id disambiguate between the different groups with the same name on the cluster. - (Ex.: there is one data-parallel group for each model-parallel rank.) - There should be exactly one device for each name, group_id and rank. - TODO: Make private, create all groups through distributed dims in config. 
""" - Assert.not_incl(distributed_dim.name, self._process_groups) - prefix = distributed_dim.name if distributed_dim.id is None else f"{distributed_dim.name}_{distributed_dim.id}" - - if distributed_dim.parent is None: - parent = None - else: - Assert.incl(distributed_dim.parent, self._process_groups) - parent = self._process_groups[distributed_dim.parent] - if distributed_dim.size == 1: - group = None - elif parent and distributed_dim.size == parent.size(): - Assert.eq(distributed_dim.rank, parent.rank()) - group = parent - else: - if parent: - Assert.lt(distributed_dim.size, parent.size()) - Assert.leq(distributed_dim.rank, parent.rank()) - self._config.log_first_rank(f"Initializing group {distributed_dim.name}, size={distributed_dim.size}...") - group = torch.distributed.ProcessGroupNCCL( - torch.distributed.PrefixStore(prefix + "/", self.store), - distributed_dim.rank, - distributed_dim.size, - datetime.timedelta(seconds=self._config.timeout), - ) - self._process_groups[distributed_dim.name] = group + self._config.log_first_rank(f"Initializing group {distributed_dim.name}, size={distributed_dim.size}...") + group = self._pool.get_process_group(distributed_dim.global_ranks, distributed_dim.rank) distributed_dim.setup(group) return group @@ -164,10 +210,3 @@ def set_step(self, step: int, phase: PhaseType) -> None: seed_shift = step * self._config.sample_seed_shift + self._phase_seeds_shifts[phase] self.pp_generator.manual_seed((self._pp_seed + seed_shift) % MAX_SEED) self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED) - - def __del__(self): - # Shutdown the process group backend explicitly to prevent a nccl warning. - # We can't call `destroy_process_group` directly because pytorch doesn't know about it. - for group in self._process_groups.values(): - if group is not None and hasattr(group, "_shutdown"): - group._shutdown() # noqa diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 00570be9..515b977a 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -576,13 +576,13 @@ def setup(self, distributed: Distributed) -> None: # Setup the tied parameter process groups if len(self.all_ranks) > 1 and self.on_device: # TODO: Create a group def first? 
+ pipeline_ranks = distributed.config.get_distributed_dim(DistributedDimNames.pipeline).global_ranks self.group = distributed.add_group( DistributedDim( name=self.name + "_tied_weight", size=len(self.all_ranks), rank=sorted(self.all_ranks).index(distributed.config.pipeline_rank), - id_=f"{distributed.config.data_rank}_x_{distributed.config.tensor_rank}", - parent=DistributedDimNames.pipeline, + global_ranks=tuple(pipeline_ranks[rank] for rank in self.all_ranks), ) ) else: diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 8c194938..7bbdd697 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -165,6 +165,7 @@ def all_equal(x, y): x = torch.as_tensor(x) y = torch.as_tensor(y) + Assert.eq(x.shape, y.shape) neq = x != y if neq.any().item(): # noqa index = None if x.numel() == 1 else torch.where(neq) # noqa diff --git a/tests/conftest.py b/tests/conftest.py index 11757176..6e9f830e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import logging import math import os +import shutil import pytest import torch @@ -14,13 +15,14 @@ # Make fixtures available globally without import from tests.utils.run_test_script import ( # isort: skip + run_distributed_script_for_all_models, run_test_script, run_test_script_base_path, run_test_script_for_all_models, ) from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip -from tests.utils.utils import result_path # isort: skip +from tests.utils.utils import result_path, TEST_RESULTS_PATH # isort: skip manager: DependencyManager | None = None @@ -76,6 +78,11 @@ def pytest_configure(config): else: worker_id = 0 + # TODO: Remove the whole `TEST_RESULTS_PATH` once `get_test_dataset` is parallel-safe. + model_result_path = TEST_RESULTS_PATH / "models" + if model_result_path.exists(): + shutil.rmtree(model_result_path) + num_gpus = torch.cuda.device_count() if num_gpus > 0 and is_parallel: # We spread workers across GPUs. diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py new file mode 100644 index 00000000..d27b66b7 --- /dev/null +++ b/tests/models/distributed_test_checkpoint.py @@ -0,0 +1,156 @@ +import gc +import pathlib +import typing + +import torch + +from fast_llm.cli import fast_llm_main_wrapper +from fast_llm.engine.checkpoint.config import ( + CheckpointFormat, + CheckpointLoadConfig, + CheckpointSaveConfig, + DistributedCheckpointFormat, + FastLLMCheckpointFormat, +) +from fast_llm.engine.distributed.distributed import ProcessGroupPool +from fast_llm.engine.multi_stage.config import StageMode +from tests.models.test_checkpoint import do_get_convert_path +from tests.utils.model_configs import ModelTestingConfig +from tests.utils.run_test_script import parse_run_distributed_script + + +def _test_load_and_save_parallel( + model_testing_config: ModelTestingConfig, + pretrained_path: pathlib.Path, + pretrained_format: CheckpointFormat, + distributed_config: dict[str, typing.Any], + save_path: pathlib.Path, +): + model = model_testing_config.model_class.from_pretrained( + CheckpointLoadConfig(path=pretrained_path, format=pretrained_format), + # The world size and rank are already set through environment variable. 
+ {"distributed": distributed_config}, + mode=StageMode.inference, + ) + for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): + model.save_checkpoint(CheckpointSaveConfig(path=save_path / save_format.name, format=save_format)) + del model + gc.collect() + torch.cuda.empty_cache() + + +# def _test_load_and_save_parallel(fixture_args, test_name, distributed_args, pretrained_path, pretrained_format): +# # TODO: Just save and load the model instead, no need for an actual run. +# do_run_test_script_for_all_models( +# [ +# # First we load a checkpoint. +# f"pretrained.path={pretrained_path}", +# f"pretrained.format={pretrained_format}", +# # We run for one mock iteration. +# "training.train_iters=1", +# "schedule.skip_step=True", +# # Then we save a checkpoint (distributed format) and an export (fast_llm format). +# "training.checkpoint.interval=1", +# "training.export.interval=1", +# "training.export.format=fast_llm", +# ] +# + distributed_args, +# test_name=test_name, +# **fixture_args, +# ) + + +def main(args: list[str] | None = None) -> None: + base_path, model_testing_config = parse_run_distributed_script(args) + + # fixture_args = { + # "rendezvous_port": rendezvous_port, + # "torchrun_port": torchrun_port, + # "base_path": base_path, + # "model_testing_config": model_testing_config, + # "num_gpus": 2, + # } + + with ProcessGroupPool(timeout=20): + for pretrained_format, pretrained_path in ( + ( + DistributedCheckpointFormat, + do_get_convert_path( + DistributedCheckpointFormat, model_testing_config.checkpoint_format, base_path=base_path.parent + ), + ), + ( + FastLLMCheckpointFormat, + do_get_convert_path( + FastLLMCheckpointFormat, model_testing_config.checkpoint_format, base_path=base_path.parent + ), + ), + ( + model_testing_config.checkpoint_format, + do_get_convert_path( + model_testing_config.checkpoint_format, DistributedCheckpointFormat, base_path=base_path.parent + ), + ), + ): + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=pretrained_path, + pretrained_format=pretrained_format, + distributed_config={}, + save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_dp2", + ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_{pretrained_format}_in_dp2", + # distributed_args=[], + # pretrained_path=pretrained_path, + # pretrained_format=pretrained_format, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_{pretrained_format}_in_tp2", + # distributed_args=["model.distributed.tensor_parallel=2"], + # pretrained_path=pretrained_path, + # pretrained_format=pretrained_format, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_{pretrained_format}_in_stp2", + # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], + # pretrained_path=pretrained_path, + # pretrained_format=pretrained_format, + # ) + + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_dp2_in_tp2", + # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], + # pretrained_path=base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", + # pretrained_format=DistributedCheckpointFormat.name, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_stp2_in_dp2", + # distributed_args=[], + # pretrained_path=base_path / 
"test_load_pretrained_distributed_in_stp2" / "checkpoint" / "1", + # pretrained_format=DistributedCheckpointFormat.name, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_tp2_in_stp2", + # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], + # pretrained_path=base_path / "test_load_pretrained_distributed_in_stp2" / "checkpoint" / "1", + # pretrained_format=DistributedCheckpointFormat.name, + # ) + # _test_load_and_save_parallel( + # fixture_args, + # test_name=f"test_load_pretrained_stp2_in_tp2", + # distributed_args=["model.distributed.tensor_parallel=2"], + # pretrained_path=base_path / "test_load_pretrained_distributed_in_tp2" / "checkpoint" / "1", + # pretrained_format=DistributedCheckpointFormat.name, + # ) + + +if __name__ == "__main__": + with fast_llm_main_wrapper(): + main() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index aff7d991..8d5928d7 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -1,3 +1,4 @@ +import functools import pathlib import shutil @@ -8,6 +9,7 @@ import yaml from fast_llm.engine.checkpoint.config import ( + CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig, DistributedCheckpointFormat, @@ -15,37 +17,32 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName +from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig, compare_logged_tensor -from tests.utils.model_configs import ModelTestingGroup +from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup _WEIGHT_SHARD_SAVE_NAME = f"{ShardName.weights}_shard" +_CHECKPOINT_AND_EVAL_ARGS = [ + "training.checkpoint.interval=1", + "training.evaluators.validation.interval=2", + "training.evaluators.validation.evaluator.iterations=1", +] + @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_config): # A baseline config (single-gpu, bf16, flash-attn). - run_test_script_for_all_models( - model_testing_config.config_args - + [ - "training.checkpoint.interval=1", - "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluator.iterations=1", - ], - ) + run_test_script_for_all_models(_CHECKPOINT_AND_EVAL_ARGS) -def _prepare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path, skip: bool) -> bool: - if skip and (test_path / "checkpoint" / "2" / "ok").is_file(): - return True - elif test_path.is_dir(): - shutil.rmtree(test_path) +def _prepare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): shutil.copytree(compare_path, test_path) shutil.rmtree(test_path / "checkpoint" / "2") assert (test_path / "checkpoint" / "1" / "ok").is_file() # TODO: Eval shutil.rmtree(test_path / "runs") - return False def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): @@ -60,11 +57,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): def test_resume(run_test_script_for_all_models): # Resume from iteration=1 and compare outputs with the baseline run. 
run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluator.iterations=1", - ], + _CHECKPOINT_AND_EVAL_ARGS, compare=f"test_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, compare_fn=_compare_resume_fn, @@ -76,202 +69,116 @@ def test_resume(run_test_script_for_all_models): def test_resume_frozen(run_test_script_for_all_models): # Resume with frozen mlp. No comparison. run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.evaluators.validation.interval=2", - "training.evaluators.validation.evaluator.iterations=1", - "model.base_model.transformer.mlp_lr_scale=0.", - ], + _CHECKPOINT_AND_EVAL_ARGS + ["model.base_model.transformer.mlp_lr_scale=0."], compare="test_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, do_compare=False, ) -def _run_conversion(config: ConvertConfig): - if config.output.path.exists(): - assert config.output.path.is_dir() - shutil.rmtree(config.output.path) - config.run() +def do_get_convert_path( + to: type[CheckpointFormat] | None = None, from_: type[CheckpointFormat] | None = None, *, base_path: pathlib.Path +) -> pathlib.Path: + if to is None or from_ is None: + return base_path / "test_checkpoint_and_eval" / "checkpoint" / "2" + return base_path / "test_convert_model" / f"{to.name}_from_{from_.name}" @pytest.fixture(scope="module") -def convert_paths(run_test_script_base_path): - return { - "checkpoint": run_test_script_base_path / "test_checkpoint_and_eval" / "checkpoint" / "2", - "distributed_0": run_test_script_base_path / "test_convert_model" / "distributed_0", - "distributed_1": run_test_script_base_path / "test_convert_model" / "distributed_1", - "fast_llm_0": run_test_script_base_path / "test_convert_model" / "fast_llm_0", - "fast_llm_1": run_test_script_base_path / "test_convert_model" / "fast_llm_1", - "huggingface_0": run_test_script_base_path / "test_convert_model" / "huggingface_0", - "huggingface_1": run_test_script_base_path / "test_convert_model" / "huggingface_1", - } +def get_convert_path(run_test_script_base_path): + return functools.partial(do_get_convert_path, base_path=run_test_script_base_path) -@pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_distributed_to_fast_llm(model_testing_config, convert_paths): - _run_conversion( +@pytest.fixture(scope="module") +def run_conversion(model_testing_config: ModelTestingConfig, get_convert_path): + def do_run_conversion( + load_path: pathlib.Path, load_format: type[CheckpointFormat] | None, save_format: type[CheckpointFormat] | None + ): ConvertConfig( input=CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, + path=load_path, + format=load_format, ), output=CheckpointSaveConfig( - path=convert_paths["fast_llm_0"], - format=FastLLMCheckpointFormat, + path=get_convert_path(save_format, load_format), + format=save_format, ), model=model_testing_config.model_config_class, - ) - ) + ).run() - -@pytest.mark.depends_on(on=["test_convert_distributed_to_fast_llm[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_fast_llm_to_huggingface(model_testing_config, convert_paths): - if model_testing_config.checkpoint_format is None: - pytest.skip(f"Conversion not supported for {model_testing_config.name}") - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - 
path=convert_paths["fast_llm_0"], - format=FastLLMCheckpointFormat, - ), - output=CheckpointSaveConfig( - path=convert_paths["huggingface_0"], - format=model_testing_config.checkpoint_format, - ), - model=model_testing_config.model_config_class, - ) - ) - - -@pytest.mark.depends_on(on=["test_convert_fast_llm_to_huggingface[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_huggingface_to_distributed(model_testing_config, convert_paths): - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["huggingface_0"], - format=model_testing_config.checkpoint_format, - ), - output=CheckpointSaveConfig( - path=convert_paths["distributed_0"], - format=DistributedCheckpointFormat, - ), - model=model_testing_config.model_config_class, - ) - ) + return do_run_conversion @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_distributed_to_huggingface(model_testing_config, convert_paths): - if model_testing_config.checkpoint_format is None: - pytest.skip(f"Conversion not supported for {model_testing_config.name}") - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - ), - output=CheckpointSaveConfig( - path=convert_paths["huggingface_1"], - format=model_testing_config.checkpoint_format, - ), - model=model_testing_config.model_config_class, - ) - ) +def test_conversion(model_testing_config, run_conversion, get_convert_path): + # Test that the various conversions between formats complete successfully. + run_conversion( + get_convert_path(), + DistributedCheckpointFormat, + FastLLMCheckpointFormat, + ) + run_conversion( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + FastLLMCheckpointFormat, + model_testing_config.checkpoint_format, + ) + run_conversion( + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat), + model_testing_config.checkpoint_format, + DistributedCheckpointFormat, + ) + run_conversion( + get_convert_path(), + DistributedCheckpointFormat, + model_testing_config.checkpoint_format, + ) + run_conversion( + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat), + model_testing_config.checkpoint_format, + FastLLMCheckpointFormat, + ) + run_conversion( + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), + FastLLMCheckpointFormat, + DistributedCheckpointFormat, + ) + + +def _compare_safetensor_files( + reference_path: pathlib.Path, *other_paths: pathlib.Path, expected_keys: set[str] | None = None +): + reference = safetensors.torch.load_file(reference_path) + if expected_keys is None: + expected_keys = set(reference.keys()) + else: + Assert.geq(set(reference.keys()), expected_keys) - -@pytest.mark.depends_on(on=["test_convert_distributed_to_huggingface[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_huggingface_to_fast_llm(model_testing_config, convert_paths): - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["huggingface_1"], - format=model_testing_config.checkpoint_format, - ), - output=CheckpointSaveConfig( - path=convert_paths["fast_llm_1"], - format=FastLLMCheckpointFormat, - ), - model=model_testing_config.model_config_class, - ) - ) + for other_path in other_paths: + other = 
safetensors.torch.load_file(other_path) + Assert.eq(other.keys(), expected_keys) + for key in expected_keys: + Assert.all_equal(reference[key], other[key]) -@pytest.mark.depends_on(on=["test_convert_huggingface_to_fast_llm[{model_testing_config}]"]) +@pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_convert_fast_llm_to_distributed(model_testing_config, convert_paths): - _run_conversion( - ConvertConfig( - input=CheckpointLoadConfig( - path=convert_paths["fast_llm_1"], - format=FastLLMCheckpointFormat, - ), - output=CheckpointSaveConfig( - path=convert_paths["distributed_1"], - format=DistributedCheckpointFormat, - ), - model=model_testing_config.model_config_class, - ) +def test_converted_round_trip(model_testing_config, get_convert_path): + # Test that the various possible conversion paths yield identical results. + _compare_safetensor_files( + get_convert_path() / "rank_0.safetensors", + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", + get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format) / "rank_0.safetensors", + expected_keys={_WEIGHT_SHARD_SAVE_NAME}, + ) + _compare_safetensor_files( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors", + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors", + ) + _compare_safetensor_files( + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) / "model_0.safetensors", + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors", ) - - -@pytest.mark.depends_on( - on=[ - "test_convert_huggingface_to_distributed[{model_testing_config}]", - "test_convert_fast_llm_to_distributed[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_converted_distributed(convert_paths): - # Compare the fast llm weights - # TODO: Compare configs - w = safetensors.torch.load_file(convert_paths["checkpoint"] / "rank_0.safetensors") - w0 = safetensors.torch.load_file(convert_paths["distributed_0"] / "rank_0.safetensors") - w1 = safetensors.torch.load_file(convert_paths["distributed_1"] / "rank_0.safetensors") - assert w.keys() >= {_WEIGHT_SHARD_SAVE_NAME} - assert w0.keys() == w1.keys() == {_WEIGHT_SHARD_SAVE_NAME} - for key in w0: - assert w[key].shape == w0[key].shape, (key, w[key].shape, w0[key].shape) - assert (w[key] == w0[key]).all(), (w[key], w0[key]) - assert w[key].shape == w1[key].shape, (key, w[key].shape, w1[key].shape) - assert (w[key] == w1[key]).all(), (w[key], w1[key]) - - -@pytest.mark.depends_on( - on=[ - "test_convert_distributed_to_fast_llm[{model_testing_config}]", - "test_convert_huggingface_to_fast_llm[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_converted_fast_llm(convert_paths): - s0 = safetensors.torch.load_file(convert_paths["fast_llm_0"] / "model_0.safetensors") - s1 = safetensors.torch.load_file(convert_paths["fast_llm_1"] / "model_0.safetensors") - assert s0.keys() == s1.keys() - for key in s0: - assert s0[key].shape == s1[key].shape, (key, s0[key].shape, s1[key].shape) - assert (s0[key] == s1[key]).all(), (key, s0, s1) - - -@pytest.mark.depends_on( - on=[ - "test_convert_fast_llm_to_huggingface[{model_testing_config}]", - "test_convert_distributed_to_huggingface[{model_testing_config}]", - ] -) 
-@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_converted_huggingface(convert_paths): - h0 = safetensors.torch.load_file(convert_paths["huggingface_0"] / "model_0.safetensors") - h1 = safetensors.torch.load_file(convert_paths["huggingface_1"] / "model_0.safetensors") - assert h0.keys() == h1.keys() - for key in h0: - assert h0[key].shape == h1[key].shape, (key, h0[key].shape, h1[key].shape) - assert (h0[key] == h1[key]).all() def _compare_model_configs(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig): @@ -282,146 +189,101 @@ def _compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMM config_ref.base_model.compare_architecture(config_test.base_model) -@pytest.mark.depends_on(on=["test_converted_distributed[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_pretrained_distributed_checkpoint(model_testing_config, convert_paths): - config = model_testing_config.model_config_class.from_dict( - yaml.safe_load((convert_paths["checkpoint"] / ".." / ".." / "config.yaml").open("r"))["model"], strict=False - ) - pretrained_config_ref = CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - optimizer_state=True, - load_config=ModelConfigType.model, - ) - model = model_testing_config.model_class.from_pretrained(pretrained_config_ref) - _compare_model_configs(config, model.config) - state_shards = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - ) - for shard_name in model.state_shard_names: - assert (state_shards[f"{shard_name}_shard"] == model.get_shard(shard_name)).all() +@pytest.fixture(scope="module") +def load_and_compare_checkpoints(model_testing_config): + def do_load_and_compare_checkpoints( + load_format: type[CheckpointFormat], load_path: pathlib.Path, reference_config, reference_shard + ): + model = model_testing_config.model_class.from_pretrained( + CheckpointLoadConfig( + path=load_path, + format=load_format, + ) + ) + if reference_config is not None: + _compare_model_configs(reference_config, model.config) + if reference_shard is not None: + Assert.all_equal(model.get_shard(ShardName.weights), reference_shard) + + return do_load_and_compare_checkpoints -@pytest.mark.depends_on(on=["test_load_pretrained_distributed_checkpoint[{model_testing_config}]"]) +@pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_converted_distributed_checkpoint(model_testing_config, convert_paths): - config_ref = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) - ) +def test_load_pretrained( + model_testing_config, run_test_script_base_path, get_convert_path, load_and_compare_checkpoints +): + # Test that loading a pretrained model from any converted checkpoint always yields the exact same model.
+ reference_config = model_testing_config.model_config_class.from_dict( + yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] + ) + reference_config_from_hf = model_testing_config.model_config_class.from_dict( + { + "base_model": yaml.safe_load( + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) + .joinpath("metadata.yaml") + .open("r") + )["config"]["base_model"] + } + ) + _compare_architectures(reference_config, reference_config_from_hf) + + reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ + _WEIGHT_SHARD_SAVE_NAME + ] - model = model_testing_config.model_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["distributed_0"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints(DistributedCheckpointFormat, get_convert_path(), reference_config, reference_shard) + + load_and_compare_checkpoints( + DistributedCheckpointFormat, + get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat), + reference_config_from_hf, + reference_shard, ) - config_alt = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["distributed_1"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + DistributedCheckpointFormat, + get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format), + reference_config_from_hf, + reference_shard, ) - _compare_architectures(config_ref, model.config) - _compare_model_configs(model.config, config_alt) - weight_shard = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - )[_WEIGHT_SHARD_SAVE_NAME] - assert (weight_shard == model.get_shard(ShardName.weights)).all() - -@pytest.mark.depends_on( - on=[ - "test_converted_fast_llm[{model_testing_config}]", - "test_load_pretrained_distributed_checkpoint[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_converted_fast_llm_checkpoint(model_testing_config, convert_paths): - config_ref = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) - ) - model = model_testing_config.model_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["fast_llm_0"], - format=FastLLMCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + FastLLMCheckpointFormat, + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), + reference_config, + reference_shard, ) - config_alt = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["fast_llm_1"], - format=FastLLMCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + FastLLMCheckpointFormat, + get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format), + reference_config_from_hf, + reference_shard, ) - _compare_architectures(config_ref, model.config) - _compare_architectures(config_ref, config_alt) - weight_shard = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - )[_WEIGHT_SHARD_SAVE_NAME] - assert (weight_shard == model.get_shard(ShardName.weights)).all() - 
-@pytest.mark.depends_on( - on=[ - "test_converted_fast_llm[{model_testing_config}]", - "test_load_pretrained_distributed_checkpoint[{model_testing_config}]", - ] -) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_converted_huggingface_checkpoint(model_testing_config, convert_paths): - config_ref = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + model_testing_config.checkpoint_format, + get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat), + reference_config_from_hf, + reference_shard, ) - model = model_testing_config.model_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["huggingface_1"], - format=model_testing_config.checkpoint_format, - load_config=ModelConfigType.model, - ), - mode=StageMode.weights, - ) - config_alt = model_testing_config.model_config_class.from_pretrained( - CheckpointLoadConfig( - path=convert_paths["huggingface_0"], - format=model_testing_config.checkpoint_format, - load_config=ModelConfigType.model, - ) + load_and_compare_checkpoints( + model_testing_config.checkpoint_format, + get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat), + reference_config_from_hf, + reference_shard, ) - _compare_architectures(config_ref, model.config) - _compare_model_configs(model.config, config_alt) - weight_shard = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - )[_WEIGHT_SHARD_SAVE_NAME] - assert (weight_shard == model.get_shard(ShardName.weights)).all() -@pytest.mark.depends_on( - on=[ - "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", - "test_load_converted_huggingface_checkpoint[{model_testing_config}]", - ] -) +@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_run_converted_model(model_testing_config, convert_paths): +def test_huggingface_model(model_testing_config, get_convert_path): + # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Review test. Move to test_generate? 
+ fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) + hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( CheckpointLoadConfig( - path=convert_paths["checkpoint"], + path=get_convert_path(), format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) @@ -430,12 +292,10 @@ def test_run_converted_model(model_testing_config, convert_paths): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( - convert_paths["fast_llm_0"] - ) + model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( CheckpointLoadConfig( - path=convert_paths["huggingface_0"], + path=hf_path, format=model_testing_config.checkpoint_format, load_config=ModelConfigType.model, ) @@ -448,7 +308,7 @@ def test_run_converted_model(model_testing_config, convert_paths): else transformers.AutoModelForCausalLM ) model_as_hf = auto_model.from_pretrained( - convert_paths["huggingface_0"], trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code + hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code ).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), @@ -472,174 +332,78 @@ def test_run_converted_model(model_testing_config, convert_paths): raise ValueError(f"Comparison failed ({len(errors)} errors)") -@pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_distributed_in_dp2(run_test_script_for_all_models, convert_paths): - run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.train_iters=1", - f"pretrained.path={convert_paths["distributed_0"]}", - f"pretrained.format={DistributedCheckpointFormat.name}", - "schedule.skip_step=True", - ], - num_gpus=2, - ) - - -@pytest.mark.depends_on(on=["test_load_converted_distributed_checkpoint[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_load_pretrained_distributed_with_config(run_test_script_for_all_models, convert_paths): - run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.train_iters=1", - f"pretrained.path={convert_paths["distributed_0"]}", - f"pretrained.format={DistributedCheckpointFormat.name}", - "schedule.skip_step=True", - ], - ) - - -@pytest.mark.depends_on(on=["test_load_pretrained_distributed_in_dp2[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_in_dp2_match_checkpoint(model_testing_config, convert_paths, run_test_script_base_path): - test_ckpt_path = run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" - pretrained_config_ref = CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.fast_llm, - ) - pretrained_config_test = CheckpointLoadConfig( - path=test_ckpt_path, - format=DistributedCheckpointFormat, - 
load_config=ModelConfigType.fast_llm, - ) - config_ref = model_testing_config.model_config_class.from_pretrained(pretrained_config_ref) - config_test = model_testing_config.model_config_class.from_pretrained(pretrained_config_test) - _compare_model_configs(config_ref, config_test) - shards_ref = safetensors.torch.load_file(convert_paths["checkpoint"] / "rank_0.safetensors") - shards_test = [safetensors.torch.load_file(test_ckpt_path / f"rank_{i}.safetensors") for i in range(2)] - ref_model = model_testing_config.model_class(config_ref) - test_model = model_testing_config.model_class(config_test) - - weight_shard_ref_split = shards_ref[_WEIGHT_SHARD_SAVE_NAME].split(ref_model._stage_weight_shard_sizes) - weight_shards_test_split = [ - shard_test[_WEIGHT_SHARD_SAVE_NAME].split(test_model._stage_weight_shard_sizes) for shard_test in shards_test - ] - for shard_test in shards_test: - for shard_name, shard in shard_test.items(): - if shard_name != _WEIGHT_SHARD_SAVE_NAME: - assert (shard == 0).all() # noqa - - assert len(ref_model._stage_weight_shard_sizes) == len(test_model._stage_weight_shard_sizes) - for i, stage_shard_ref in enumerate(weight_shard_ref_split): - assert ( - test_model._stage_weight_shard_sizes[i] - == ref_model._stage_weight_shard_sizes[i] // 2 + (-ref_model._stage_weight_shard_sizes[i] // 2) % 32 - ) - - stage_shard_test = torch.concatenate( - [weight_shard_test_split[i] for weight_shard_test_split in weight_shards_test_split] - ) - assert (stage_shard_test[: stage_shard_ref.numel()] == stage_shard_ref).all() - assert (stage_shard_test[stage_shard_ref.numel() :] == 0).all() # noqa - - -@pytest.mark.depends_on(on=["test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]"]) -@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_distributed_checkpoint_dp2(model_testing_config, convert_paths, run_test_script_base_path): - # This also tests conversion which uses `FastLLMModel.from_checkpoint` - pretrained_config_ref = CheckpointLoadConfig( - path=convert_paths["checkpoint"], - format=DistributedCheckpointFormat, - load_config=ModelConfigType.fast_llm, - ) - pretrained_config_test = CheckpointLoadConfig( - path=run_test_script_base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", - format=DistributedCheckpointFormat, - load_config=ModelConfigType.model, - ) - config = model_testing_config.model_config_class.from_pretrained(pretrained_config_ref) - model = model_testing_config.model_class.from_pretrained(pretrained_config_test, mode=StageMode.weights) - _compare_model_configs(config, model.config) - weight_shard = safetensors.torch.load_file( - convert_paths["checkpoint"] / "rank_0.safetensors", device=str(model._distributed.device) - )[_WEIGHT_SHARD_SAVE_NAME] - assert (weight_shard == model.get_shard(ShardName.weights)).all() +@pytest.fixture(scope="module") +def load_and_save_parallel_base_path(run_test_script_base_path): + return run_test_script_base_path / "test_load_and_save_parallel" @pytest.mark.depends_on( on=[ - "test_load_converted_fast_llm_checkpoint[{model_testing_config}]", - "test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]", + "test_load_pretrained[{model_testing_config}]", ] ) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_fast_llm_in_dp2(run_test_script_for_all_models, convert_paths, run_test_script_base_path): - run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - 
"training.train_iters=1", - f"pretrained.path={convert_paths["fast_llm_0"]}", - f"pretrained.format=fast_llm", - "schedule.skip_step=True", - ], +def test_save_and_load_in_parallel(run_distributed_script_for_all_models, load_and_save_parallel_base_path): + # Save and load checkpoints to and from various distributed configurations. + # Combined in a single test to mitigate process creation overhead. + # TODO: Test beyond 2 gpu configs? + import tests.models.distributed_test_checkpoint + + run_distributed_script_for_all_models( + [tests.models.distributed_test_checkpoint.__file__], + base_path=load_and_save_parallel_base_path, num_gpus=2, ) + + +@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) +def test_parallel_checkpoint(model_testing_config, load_and_save_parallel_base_path, get_convert_path): + # Check the consistency of the checkpoints saved in `test_save_and_load_in_parallel` + checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) + # Compare Distributed checkpoints for rank in range(2): - ref_shard = safetensors.torch.load_file( - run_test_script_base_path - / f"test_load_pretrained_distributed_in_dp2" - / "checkpoint" - / "1" - / f"rank_{rank}.safetensors" - ) - test_shard = safetensors.torch.load_file( - run_test_script_base_path - / f"test_load_pretrained_fast_llm_in_dp2" - / "checkpoint" - / "1" - / f"rank_{rank}.safetensors" + _compare_safetensor_files( + *[ + load_and_save_parallel_base_path + / f"load_pretrained_{format_.name}_in_dp2" + / DistributedCheckpointFormat.name + / f"rank_{rank}.safetensors" + for format_ in checkpoint_formats + ] ) - for name in set(ref_shard) | set(test_shard): - assert (ref_shard[name] == test_shard[name]).all() + # Compare Fast-LLM checkpoints + _compare_safetensor_files( + # Fast-LLM checkpoints are independent of the distributed configuration that saved it. 
+ get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / f"model_0.safetensors", + *[ + load_and_save_parallel_base_path + / f"load_pretrained_{format_.name}_in_dp2" + / FastLLMCheckpointFormat.name + / f"model_0.safetensors" + for format_ in checkpoint_formats + ], + ) -@pytest.mark.depends_on( - on=[ - "test_load_converted_huggingface_checkpoint[{model_testing_config}]", - "test_load_pretrained_in_dp2_match_checkpoint[{model_testing_config}]", - ] -) + +@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_pretrained_huggingface_in_dp2( - run_test_script_for_all_models, model_testing_config, run_test_script_base_path, convert_paths +def test_load_parallel_checkpoint( + model_testing_config, load_and_save_parallel_base_path, get_convert_path, load_and_compare_checkpoints ): - run_test_script_for_all_models( - [ - "training.checkpoint.interval=1", - "training.train_iters=1", - f"pretrained.path={convert_paths["huggingface_0"]}", - f"pretrained.format={model_testing_config.checkpoint_format.name}", - "schedule.skip_step=True", - ], - num_gpus=2, - ) - for rank in range(2): - ref_shard = safetensors.torch.load_file( - run_test_script_base_path - / f"test_load_pretrained_distributed_in_dp2" - / "checkpoint" - / "1" - / f"rank_{rank}.safetensors" - ) - test_shard = safetensors.torch.load_file( - run_test_script_base_path - / f"test_load_pretrained_huggingface_in_dp2" - / "checkpoint" - / "1" - / f"rank_{rank}.safetensors" + # Test single-gpu loading of multi-gpu distributed checkpoints. + checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) + reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ + _WEIGHT_SHARD_SAVE_NAME + ] + + for format_ in checkpoint_formats: + load_and_compare_checkpoints( + DistributedCheckpointFormat, + load_and_save_parallel_base_path + / f"load_pretrained_{format_.name}_in_dp2" + / DistributedCheckpointFormat.name, + None, + reference_shard, ) - for name in set(ref_shard) | set(test_shard): - assert (ref_shard[name] == test_shard[name]).all() diff --git a/tests/test_gpt_loss.py b/tests/test_gpt_loss.py deleted file mode 100644 index 89262eca..00000000 --- a/tests/test_gpt_loss.py +++ /dev/null @@ -1,121 +0,0 @@ -import math - -import torch - -from fast_llm.config import NoAutoValidate -from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.optimizer.config import OptimizerConfig -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig -from tests.test_gpt_generate_and_forward import model_and_tokenizer # noqa: F401 -from tests.utils.utils import requires_cuda - - -def _get_model_runner_schedule( - model_path: str, - use_flash_attention: bool, - use_bf16: bool, - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, - phase=PhaseType.inference, -): - assert phase == PhaseType.inference or phase == PhaseType.validation - updates = { - ("pretrained", "path"): 
model_path, - ("pretrained", "model_weights"): True, - ("pretrained", "format"): checkpoint_format.name, - ("model", "base_model", "cross_entropy_impl"): "fused", - ("model", "multi_stage", "zero_stage"): 2, - } - - if use_flash_attention: - updates[("model", "base_model", "transformer", "use_flash_attention")] = True - updates[("model", "distributed", "training_dtype")] = "bf16" - else: - updates[("model", "base_model", "transformer", "use_flash_attention")] = False - if use_bf16: - updates[("model", "distributed", "training_dtype")] = "bf16" - - config = PretrainedGPTModelConfig.from_dict({}, updates) - multi_stage = config.model.get_model_class()( - config.model, optimizer_state_names=OptimizerConfig.state_names() if phase == PhaseType.validation else () - ) - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=2048, batch_size=2) - batch_config.setup(config.model.distributed) - batch_config.validate() - - schedule = Schedule( - multi_stage=multi_stage, - batch_config=batch_config, - schedule_config=schedule_config, - distributed_config=config.model.distributed, - phase=phase, - ) - - runner = ScheduleRunner( - config=schedule_config, - multi_stage=multi_stage, - distributed_config=config.model.distributed, - ) - - distributed = Distributed(config.model.distributed) - - with torch.no_grad(): - multi_stage.setup(distributed) - - with torch.no_grad(): - runner.setup(distributed) - - multi_stage.load_checkpoint(config.pretrained) - - return multi_stage, runner, schedule, batch_config - - -def _test_for_phase(model_path, fast_llm_checkpoint_format, phase): - model, runner, schedule, batch_config = _get_model_runner_schedule( - model_path, True, True, fast_llm_checkpoint_format, phase - ) - - inputs = GPTBatch( - torch.randint( - 1, - model.config.base_model.vocab_size, - [2, batch_config.sequence_length + 1], - dtype=torch.int64, - generator=torch.Generator().manual_seed(42), - ) - ) - - iteration = 1 - - # we need to set phase to validation here so preprocess would crate labels from input - # so it is the same process for validation and inference phases - # otherwise we can add labels manually after preprocess for inference phase - batch = model.base_model.preprocess(inputs, phase=PhaseType.validation, iteration=iteration) - ((inputs_, kwargs),) = batch - kwargs[LanguageModelKwargs.phase] = phase - iter_losses, _, _ = runner.run_step( - iter((((inputs_, kwargs),),)), schedule, iteration=iteration, preprocessed=True - ) - - return iter_losses - - -# @pytest.mark.extra_slow -@requires_cuda -def test_loss_validation_vs_inference(model_and_tokenizer): - model_path, _, fast_llm_checkpoint_format = model_and_tokenizer - - iter_losses_validation = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.validation) - - iter_losses_inference = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.inference) - - assert len(iter_losses_validation) == len(iter_losses_inference) - for key in iter_losses_validation.keys(): - assert math.isclose(iter_losses_validation[key], iter_losses_inference[key], rel_tol=1e-5) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 23c487a7..2a12c4f7 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -7,9 +7,9 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample +from tests.utils.utils import TEST_RESULTS_PATH -# TODO: Fixture -TEST_RESULTS_PATH = 
pathlib.Path("/tmp/fast_llm_tests") +# TODO: Fixtures TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" DATASET_CACHE = TEST_RESULTS_PATH / "dataset" diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4c225422..b8dd29e8 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -8,6 +8,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, @@ -52,14 +53,27 @@ class ModelTestingConfig: model_type: str config_args: list[str] megatron_args: list[str] | None - checkpoint_format: CheckpointFormat | None + checkpoint_format: type[CheckpointFormat] | None groups: dict[ModelTestingGroup, ModelTestingGroupAction] @functools.cached_property - def model_config_class(self): + def trainer_config_class(self) -> type[TrainerConfig]: + return TrainerConfig.get_subclass(self.model_type) + + @functools.cached_property + def trainer_config(self) -> TrainerConfig: + # See `RunnableConfig._from_parsed_args` + return self.trainer_config_class.from_dict(self.trainer_config_class._parse_updates(self.config_args)) + + @functools.cached_property + def model_config_class(self) -> type[FastLLMModelConfig]: # TODO: Ok to assume the model and trainer have the same name? return FastLLMModelConfig.get_subclass(self.model_type) + @functools.cached_property + def model_config(self) -> FastLLMModelConfig: + return self.trainer_config.model + @functools.cached_property def huggingface_model_for_causal_lm_class(self): return self.model_config_class.get_huggingface_model_for_causal_lm_class() @@ -83,7 +97,7 @@ def _update_and_add_testing_config( checkpoint_format: CheckpointFormat | None = ..., groups: dict[ModelTestingGroup, ModelTestingGroupAction], ): - config = _MODEL_CONFIGS[old_name] + config = MODEL_CONFIGS[old_name] updates: dict[str, typing.Any] = { "name": new_name, "groups": groups, @@ -102,13 +116,13 @@ def _update_and_add_testing_config( if checkpoint_format is not ...: updates["checkpoint_format"] = checkpoint_format - _MODEL_CONFIGS[new_name] = dataclasses.replace(config, **updates) + MODEL_CONFIGS[new_name] = dataclasses.replace(config, **updates) -_MODEL_CONFIGS: dict[str, ModelTestingConfig] = {} +MODEL_CONFIGS: dict[str, ModelTestingConfig] = {} -_MODEL_CONFIGS["gpt2"] = ModelTestingConfig( +MODEL_CONFIGS["gpt2"] = ModelTestingConfig( # Tests gpt2 features (absolute embeddings, layer norm, relu activation, tied embeddings, MHA, linear biases). 
name="gpt2", model_type="gpt", @@ -477,17 +491,20 @@ def _update_and_add_testing_config( ) -@pytest.fixture(scope="session", params=_MODEL_CONFIGS.keys()) +@pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: - return _MODEL_CONFIGS[request.param] + models = request.config.getoption("--models") + if models and request.param not in models: + pytest.skip(f"Skipping model {request.param}") + return MODEL_CONFIGS[request.param] def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slow: bool, show_skipped: bool) -> bool: if "model_testing_group" in item.keywords: - assert "model_testing_config" in item.callspec.params, item.nodeid + assert hasattr(item, "callspec") and "model_testing_config" in item.callspec.params, item.nodeid groups: tuple[ModelTestingGroup] = item.keywords["model_testing_group"].args model_testing_config = item.callspec.params["model_testing_config"] - model_config: ModelTestingConfig = _MODEL_CONFIGS[model_testing_config] + model_config: ModelTestingConfig = MODEL_CONFIGS[model_testing_config] for group in groups: action = model_config.groups[group] if action == ModelTestingGroupAction.main: diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 69ed817a..263484db 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -1,15 +1,23 @@ +import argparse +import functools import os import pathlib -import shutil import subprocess import sys +import typing import pytest import torch from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig, compare_tensor_logs from tests.utils.dataset import get_test_dataset +from tests.utils.model_configs import MODEL_CONFIGS, ModelTestingConfig + +if typing.TYPE_CHECKING: + from tests.conftest import WorkerResources # FIXME: figure out correct import of megatron modules without this hack sys.path.append(os.getcwd()) @@ -17,71 +25,117 @@ _ARTIFACT_PATH = "runs/0/artifacts" +def do_run_distributed_script( + args: list[str], + rendezvous_port: int, + torchrun_port: int, + num_gpus: int, + timeout: float = 120, +): + command = [ + "python", + "-m", + "torch.distributed.run", + f"--nproc-per-node={num_gpus}", + f"--rdzv-endpoint=localhost:{rendezvous_port}", + f"--master-port={torchrun_port}", + *args, + ] + print(" ".join(command)) + completed_proc = subprocess.run(command, timeout=timeout) + if completed_proc.returncode: + raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") + + +def do_run_test_script( + path: pathlib.Path, + args: list[str], + num_gpus: int = 1, + *, + model_type: str, + is_megatron: bool = False, + compare_path: pathlib.Path | None = None, + config: CompareConfig | None = None, + prepare_fn=None, + compare_fn=None, + do_compare: bool = True, + rendezvous_port: int, + torchrun_port: int, +): + is_parallel = DistributedConfig.default_world_size > 1 + if is_parallel: + Assert.eq(num_gpus, DistributedConfig.default_world_size) + local_rank = DistributedConfig.default_rank + + if torch.cuda.device_count() < num_gpus: + pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") + env = os.environ.copy() + if is_megatron: + assert num_gpus == 1 + # Prevent Megatron from complaining. 
+ env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + env["NVTE_FLASH_ATTN"] = "0" + if local_rank == 0 and prepare_fn is not None: + prepare_fn(path, None if compare_path is None else compare_path) + if is_megatron: + args = ["Megatron-LM/pretrain_gpt.py", *args, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] + else: + args = ["--no-python", "fast-llm", "train", model_type, *args, f"run.experiment_dir={path}"] + get_test_dataset() + if (num_gpus == 1 or is_parallel) and not is_megatron: + print(" ".join(args[1:])) + RunnableConfig.parse_and_run(args[2:]) + else: + do_run_distributed_script(args, rendezvous_port, torchrun_port, num_gpus) + if local_rank == 0 and compare_path is not None and do_compare: + if compare_fn is not None: + compare_fn(path, compare_path) + compare_tensor_logs( + compare_path / _ARTIFACT_PATH, + path / _ARTIFACT_PATH, + config, + ) + + +def do_run_test_script_for_all_models( + extra_args: list[str], + num_gpus: int = 1, + *, + is_megatron: bool = False, + compare: str | None = None, + config: CompareConfig | None = None, + prepare_fn=None, + compare_fn=None, + do_compare: bool = True, + rendezvous_port: int, + torchrun_port: int, + test_name: str, + base_path: pathlib.Path, + model_testing_config: ModelTestingConfig, +): + do_run_test_script( + base_path / test_name, + (model_testing_config.megatron_args if is_megatron else model_testing_config.config_args) + extra_args, + num_gpus, + model_type=model_testing_config.model_type, + is_megatron=is_megatron, + compare_path=None if compare is None else base_path / compare, + config=config, + prepare_fn=prepare_fn, + compare_fn=compare_fn, + do_compare=do_compare, + rendezvous_port=rendezvous_port, + torchrun_port=torchrun_port, + ) + + @pytest.fixture(scope="session") -def run_test_script(worker_resources): - def do_run_test_script( - path: pathlib.Path, - args: list[str], - num_gpus: int = 1, - *, - model_type: str, - is_megatron: bool = False, - compare_path: pathlib.Path | None = None, - config: CompareConfig | None = None, - prepare_fn=None, - compare_fn=None, - do_compare: bool = True, - ): - if torch.cuda.device_count() < num_gpus: - pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") - env = os.environ.copy() - if is_megatron: - # Prevent Megatron from complaining. - env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - env["NVTE_FLASH_ATTN"] = "0" - skip = False - if path.exists(): - assert path.is_dir() - # TODO: Better way to check if the previous attempt succeeded. 
- shutil.rmtree(path) - if prepare_fn is not None: - skip = prepare_fn(path, None if compare_path is None else compare_path, skip) - if is_megatron: - args = [*args, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] - else: - args = ["train", model_type, *args, f"run.experiment_dir={path}"] - header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] - command = [ - "python", - "-m", - "torch.distributed.run", - f"--nproc-per-node={num_gpus}", - f"--rdzv-endpoint=localhost:{worker_resources.rendezvous_port}", - f"--master-port={worker_resources.torchrun_port}", - *header, - *args, - ] - print(" ".join(command)) - if skip: - print("Reusing existing run.") - else: - get_test_dataset() - if num_gpus == 1 and not is_megatron: - RunnableConfig.parse_and_run(args) - else: - completed_proc = subprocess.run(command, env=env, timeout=120) - if completed_proc.returncode: - raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") - if compare_path is not None and do_compare: - if compare_fn is not None: - compare_fn(path, compare_path) - compare_tensor_logs( - compare_path / _ARTIFACT_PATH, - path / _ARTIFACT_PATH, - config, - ) - - return do_run_test_script +def run_test_script(worker_resources: "WorkerResources"): + return functools.partial( + do_run_test_script, + rendezvous_port=worker_resources.rendezvous_port, + torchrun_port=worker_resources.torchrun_port, + ) @pytest.fixture(scope="session") @@ -90,29 +144,47 @@ def run_test_script_base_path(model_testing_config, result_path, request): @pytest.fixture(scope="function") -def run_test_script_for_all_models(run_test_script, run_test_script_base_path, model_testing_config, request): - def do_run_test_script_for_all_models( - extra_args: list[str], - num_gpus: int = 1, - *, - is_megatron: bool = False, - compare: str | None = None, - config: CompareConfig | None = None, - prepare_fn=None, - compare_fn=None, - do_compare: bool = True, - ): - run_test_script( - run_test_script_base_path / request.node.originalname, - (model_testing_config.megatron_args if is_megatron else model_testing_config.config_args) + extra_args, +def run_test_script_for_all_models( + worker_resources: "WorkerResources", + run_test_script_base_path: pathlib.Path, + model_testing_config: ModelTestingConfig, + request: pytest.FixtureRequest, +): + return functools.partial( + do_run_test_script_for_all_models, + rendezvous_port=worker_resources.rendezvous_port, + torchrun_port=worker_resources.torchrun_port, + test_name=request.node.originalname, + base_path=run_test_script_base_path, + model_testing_config=model_testing_config, + ) + + +def parse_run_distributed_script(args: list[str] | None = None): + parser = argparse.ArgumentParser() + parser.add_argument("base_path", type=pathlib.Path) + parser.add_argument("model_testing_config", type=str) + parsed = parser.parse_args(args) + return parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config] + + +@pytest.fixture(scope="session") +def run_distributed_script_for_all_models( + worker_resources: "WorkerResources", + run_test_script_base_path: pathlib.Path, + model_testing_config: ModelTestingConfig, + request: pytest.FixtureRequest, +): + def do_run_distributed_script_for_all_models(args: list[str], num_gpus=2, base_path: pathlib.Path | None = None): + do_run_distributed_script( + args + + [ + str(run_test_script_base_path if base_path is None else base_path), + model_testing_config.name, + ], + worker_resources.rendezvous_port, + 
worker_resources.torchrun_port, num_gpus, - model_type=model_testing_config.model_type, - is_megatron=is_megatron, - compare_path=None if compare is None else run_test_script_base_path / compare, - config=config, - prepare_fn=prepare_fn, - compare_fn=compare_fn, - do_compare=do_compare, ) - return do_run_test_script_for_all_models + return do_run_distributed_script_for_all_models diff --git a/tests/utils/utils.py b/tests/utils/utils.py index ea689bcc..1ea7717f 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -12,9 +12,12 @@ requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") + + @pytest.fixture(scope="session") def result_path(): - return pathlib.Path("/tmp/fast_llm_tests") + return TEST_RESULTS_PATH def get_base_model(config: FastLLMModelConfig): From 3e818952a5cdb5065174f00606a8818a2f3f46d3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 25 Jun 2025 21:30:53 -0400 Subject: [PATCH 51/69] cleanup --- tests/models/distributed_test_checkpoint.py | 79 --------------------- 1 file changed, 79 deletions(-) diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index d27b66b7..fff0c49e 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -39,38 +39,9 @@ def _test_load_and_save_parallel( torch.cuda.empty_cache() -# def _test_load_and_save_parallel(fixture_args, test_name, distributed_args, pretrained_path, pretrained_format): -# # TODO: Just save and load the model instead, no need for an actual run. -# do_run_test_script_for_all_models( -# [ -# # First we load a checkpoint. -# f"pretrained.path={pretrained_path}", -# f"pretrained.format={pretrained_format}", -# # We run for one mock iteration. -# "training.train_iters=1", -# "schedule.skip_step=True", -# # Then we save a checkpoint (distributed format) and an export (fast_llm format). 
-# "training.checkpoint.interval=1", -# "training.export.interval=1", -# "training.export.format=fast_llm", -# ] -# + distributed_args, -# test_name=test_name, -# **fixture_args, -# ) - - def main(args: list[str] | None = None) -> None: base_path, model_testing_config = parse_run_distributed_script(args) - # fixture_args = { - # "rendezvous_port": rendezvous_port, - # "torchrun_port": torchrun_port, - # "base_path": base_path, - # "model_testing_config": model_testing_config, - # "num_gpus": 2, - # } - with ProcessGroupPool(timeout=20): for pretrained_format, pretrained_path in ( ( @@ -99,56 +70,6 @@ def main(args: list[str] | None = None) -> None: distributed_config={}, save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_dp2", ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_{pretrained_format}_in_dp2", - # distributed_args=[], - # pretrained_path=pretrained_path, - # pretrained_format=pretrained_format, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_{pretrained_format}_in_tp2", - # distributed_args=["model.distributed.tensor_parallel=2"], - # pretrained_path=pretrained_path, - # pretrained_format=pretrained_format, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_{pretrained_format}_in_stp2", - # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], - # pretrained_path=pretrained_path, - # pretrained_format=pretrained_format, - # ) - - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_dp2_in_tp2", - # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], - # pretrained_path=base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", - # pretrained_format=DistributedCheckpointFormat.name, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_stp2_in_dp2", - # distributed_args=[], - # pretrained_path=base_path / "test_load_pretrained_distributed_in_stp2" / "checkpoint" / "1", - # pretrained_format=DistributedCheckpointFormat.name, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_tp2_in_stp2", - # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], - # pretrained_path=base_path / "test_load_pretrained_distributed_in_stp2" / "checkpoint" / "1", - # pretrained_format=DistributedCheckpointFormat.name, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_stp2_in_tp2", - # distributed_args=["model.distributed.tensor_parallel=2"], - # pretrained_path=base_path / "test_load_pretrained_distributed_in_tp2" / "checkpoint" / "1", - # pretrained_format=DistributedCheckpointFormat.name, - # ) if __name__ == "__main__": From 0bbdbe11ca2758494e5c24858fa81a64ef9c8c84 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 25 Jun 2025 21:51:51 -0400 Subject: [PATCH 52/69] stuff --- Dockerfile | 2 +- setup.cfg | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index b583834d..96a86063 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. 
-RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/setup.cfg b/setup.cfg index b1e44e81..2f69b8e0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,7 +25,8 @@ CORE = # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.3 # Dropless MLP is broken with triton 3.2.0, 3.3.0 and 3.3.1. TODO: Remove once a working triton version is released. - triton==3.1.0 + # TODO: Removed because it breaks cpu-only installs and pip dependency resolution. + # triton==3.1.0 # Small packages required for some optional features and tools. From 878ed823ba8509c97fab78e3fc589e04782b080c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 12:16:48 -0400 Subject: [PATCH 53/69] test --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 96a86063..e98223de 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,8 +29,8 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -RUN MAX_JOBS=4 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" +RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" +RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" # Copy dependency files with universal write permissions for all users. 
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ From cfd3d778a536da3ff01bbe76281acc28380d8b3f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 13:42:31 -0400 Subject: [PATCH 54/69] fix --- fast_llm/engine/evaluation/evaluator.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index f07a8c48..78aad230 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -124,9 +124,13 @@ def setup( self._is_setup = True def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: - return EvaluatorSamplingParameters( - (self._name if self._config.dataset_name is None else self._config.dataset_name), - self._config.iterations * self._batch_config.batch_size, + return ( + None + if self._config.iterations is None + else EvaluatorSamplingParameters( + (self._name if self._config.dataset_name is None else self._config.dataset_name), + self._config.iterations * self._batch_config.batch_size, + ) ) def run( @@ -139,7 +143,6 @@ def run( run_index = 0 metrics = {} - formatted_metrics = None if self._evaluation_iterator is None: self._evaluation_iterator = self._get_data_iterator(self._get_completed_evaluation_steps(run_index)) From 9309abc6cc5506680ca39b8b0fbe367cebf7993e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 14:42:20 -0400 Subject: [PATCH 55/69] fix --- fast_llm/layers/language_model/head.py | 2 +- fast_llm/logging.py | 2 +- tests/conftest.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 52637869..88b0612b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -135,7 +135,7 @@ def forward( # TODO: Skip cross-entropy backward if not needed. language_model_loss = self._forward(input_, kwargs, losses) if losses is not None and language_model_loss is not None: - losses[self._loss_name].append(language_model_loss) + losses[self._loss_name].append(language_model_loss.detach()) # TODO: Return the model output when needed. if self._is_last_head: # Last head should return the loss for backward. diff --git a/fast_llm/logging.py b/fast_llm/logging.py index a70cacce..f574aa38 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -331,7 +331,7 @@ def log_generator[ def get_memory_usage_mib(reset_stats: bool = True, relative_to: dict[str, int] | None = None) -> dict[str, float]: global _global_max_allocated, _global_max_reserved - max_allocated = torch.cuda.memory_allocated() / 2**20 + max_allocated = torch.cuda.max_memory_allocated() / 2**20 max_reserved = torch.cuda.max_memory_reserved() / 2**20 _global_max_allocated = max(max_allocated, _global_max_allocated) _global_max_reserved = max(max_reserved, _global_max_reserved) diff --git a/tests/conftest.py b/tests/conftest.py index 6e9f830e..27ea5f63 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -208,14 +208,14 @@ def pytest_runtest_makereport(item: pytest.Function, call: pytest.CallInfo): "duration": call.duration, # Relevant value for OOM risk. Also look at global max since fast-llm resets stats. 
"max_memory_reserved": max( - torch.cuda.max_memory_reserved(), fast_llm.logging._global_max_reserved + torch.cuda.max_memory_reserved() / 2**20, fast_llm.logging._global_max_reserved ), # Actual memory usage from the test. "max_memory_allocated": max( - torch.cuda.max_memory_allocated(), fast_llm.logging._global_max_allocated + torch.cuda.max_memory_allocated() / 2**20, fast_llm.logging._global_max_allocated ), - "memory_reserved": torch.cuda.memory_reserved(), - "memory_allocated": torch.cuda.memory_allocated(), + "memory_reserved": torch.cuda.memory_reserved() / 2**20, + "memory_allocated": torch.cuda.memory_allocated() / 2**20, } ), ) @@ -248,10 +248,10 @@ def pytest_terminal_summary(terminalreporter): for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]: terminalreporter.write_line( f"{nodeid}:\n " - f"Max Reserved {resource_reports[nodeid]["max_memory_reserved"] / 1e6:.0f} MB | " - f"Max Allocated {resource_reports[nodeid]["max_memory_allocated"] / 1e6:.0f} MB | " - f"End Reserved {resource_reports[nodeid]["memory_reserved"] / 1e6:.0f} MB | " - f"End Allocated {resource_reports[nodeid]["memory_allocated"] / 1e6:.0f} MB | " + f"Max Reserved {resource_reports[nodeid]["max_memory_reserved"]:.0f} MiB | " + f"Max Allocated {resource_reports[nodeid]["max_memory_allocated"]:.0f} MiB | " + f"End Reserved {resource_reports[nodeid]["memory_reserved"]:.0f} MiB | " + f"End Allocated {resource_reports[nodeid]["memory_allocated"]:.0f} MiB | " f"Duration {resource_reports[nodeid]["duration"]:.2f}" ) From 3c438c7260a4e6ffd9f5ea818007b9a69448d7af Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 14:48:31 -0400 Subject: [PATCH 56/69] Misc fixes --- Dockerfile | 6 +++--- fast_llm/engine/checkpoint/distributed.py | 3 ++- fast_llm/engine/checkpoint/state_dict.py | 1 + fast_llm/layers/language_model/head.py | 2 +- fast_llm/logging.py | 2 +- setup.cfg | 3 ++- tests/conftest.py | 16 ++++++++-------- 7 files changed, 18 insertions(+), 15 deletions(-) diff --git a/Dockerfile b/Dockerfile index d67729d3..e98223de 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,15 +29,15 @@ ENV PIP_CONSTRAINT="" # There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds. # We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d) # We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?) -RUN MAX_JOBS=4 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" -RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@74729d0" +RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1" +RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2" # Copy dependency files with universal write permissions for all users. COPY --chmod=777 setup.py setup.cfg pyproject.toml ./ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. 
COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index de1625f6..6681d70e 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -30,8 +30,8 @@ class DistributedCheckpointHandler(CheckpointHandler): @classmethod def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata): - config.path.mkdir(parents=True, exist_ok=True) serialized_metadata = metadata.to_dict() + config.path.mkdir(parents=True, exist_ok=True) yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) @classmethod @@ -40,6 +40,7 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: serialized_metadata = metadata.to_dict() + config.path.mkdir(parents=True, exist_ok=True) if self._model.config.distributed.rank == 0: yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) safetensors.torch.save_file( diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 556e97be..7a257a5f 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -216,6 +216,7 @@ def _save_next_file(self) -> None: file_name = f"{self.base_file_name}_{self._file_count}.safetensors" if self._do_save: logger.info(f"Saving tensors to {self._config.path / file_name}") + self._config.path.mkdir(parents=True, exist_ok=True) safetensors.torch.save_file( tensors=self._tensors, filename=self._config.path / file_name, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 52637869..88b0612b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -135,7 +135,7 @@ def forward( # TODO: Skip cross-entropy backward if not needed. language_model_loss = self._forward(input_, kwargs, losses) if losses is not None and language_model_loss is not None: - losses[self._loss_name].append(language_model_loss) + losses[self._loss_name].append(language_model_loss.detach()) # TODO: Return the model output when needed. if self._is_last_head: # Last head should return the loss for backward. diff --git a/fast_llm/logging.py b/fast_llm/logging.py index a70cacce..f574aa38 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -331,7 +331,7 @@ def log_generator[ def get_memory_usage_mib(reset_stats: bool = True, relative_to: dict[str, int] | None = None) -> dict[str, float]: global _global_max_allocated, _global_max_reserved - max_allocated = torch.cuda.memory_allocated() / 2**20 + max_allocated = torch.cuda.max_memory_allocated() / 2**20 max_reserved = torch.cuda.max_memory_reserved() / 2**20 _global_max_allocated = max(max_allocated, _global_max_allocated) _global_max_reserved = max(max_reserved, _global_max_reserved) diff --git a/setup.cfg b/setup.cfg index b1e44e81..2f69b8e0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,7 +25,8 @@ CORE = # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.3 # Dropless MLP is broken with triton 3.2.0, 3.3.0 and 3.3.1. TODO: Remove once a working triton version is released. - triton==3.1.0 + # TODO: Removed because it breaks cpu-only installs and pip dependency resolution. + # triton==3.1.0 # Small packages required for some optional features and tools. 
diff --git a/tests/conftest.py b/tests/conftest.py index 11757176..b3798365 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -201,14 +201,14 @@ def pytest_runtest_makereport(item: pytest.Function, call: pytest.CallInfo): "duration": call.duration, # Relevant value for OOM risk. Also look at global max since fast-llm resets stats. "max_memory_reserved": max( - torch.cuda.max_memory_reserved(), fast_llm.logging._global_max_reserved + torch.cuda.max_memory_reserved() / 2**20, fast_llm.logging._global_max_reserved ), # Actual memory usage from the test. "max_memory_allocated": max( - torch.cuda.max_memory_allocated(), fast_llm.logging._global_max_allocated + torch.cuda.max_memory_allocated() / 2**20, fast_llm.logging._global_max_allocated ), - "memory_reserved": torch.cuda.memory_reserved(), - "memory_allocated": torch.cuda.memory_allocated(), + "memory_reserved": torch.cuda.memory_reserved() / 2**20, + "memory_allocated": torch.cuda.memory_allocated() / 2**20, } ), ) @@ -241,10 +241,10 @@ def pytest_terminal_summary(terminalreporter): for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]: terminalreporter.write_line( f"{nodeid}:\n " - f"Max Reserved {resource_reports[nodeid]["max_memory_reserved"] / 1e6:.0f} MB | " - f"Max Allocated {resource_reports[nodeid]["max_memory_allocated"] / 1e6:.0f} MB | " - f"End Reserved {resource_reports[nodeid]["memory_reserved"] / 1e6:.0f} MB | " - f"End Allocated {resource_reports[nodeid]["memory_allocated"] / 1e6:.0f} MB | " + f"Max Reserved {resource_reports[nodeid]["max_memory_reserved"]:.0f} MiB | " + f"Max Allocated {resource_reports[nodeid]["max_memory_allocated"]:.0f} MiB | " + f"End Reserved {resource_reports[nodeid]["memory_reserved"]:.0f} MiB | " + f"End Allocated {resource_reports[nodeid]["memory_allocated"]:.0f} MiB | " f"Duration {resource_reports[nodeid]["duration"]:.2f}" ) From 45ae71785b2c7d8c704cd3ab3d965077aa5cc67b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 14:55:11 -0400 Subject: [PATCH 57/69] Misc fixes --- fast_llm/engine/evaluation/evaluator.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index f07a8c48..78aad230 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -124,9 +124,13 @@ def setup( self._is_setup = True def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: - return EvaluatorSamplingParameters( - (self._name if self._config.dataset_name is None else self._config.dataset_name), - self._config.iterations * self._batch_config.batch_size, + return ( + None + if self._config.iterations is None + else EvaluatorSamplingParameters( + (self._name if self._config.dataset_name is None else self._config.dataset_name), + self._config.iterations * self._batch_config.batch_size, + ) ) def run( @@ -139,7 +143,6 @@ def run( run_index = 0 metrics = {} - formatted_metrics = None if self._evaluation_iterator is None: self._evaluation_iterator = self._get_data_iterator(self._get_completed_evaluation_steps(run_index)) From a478ebeb44fb6fba9d8fc89ed8098194872e8cbf Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 15:00:08 -0400 Subject: [PATCH 58/69] Misc fixes --- tests/test_gpt_loss.py | 121 ----------------------------------------- 1 file changed, 121 deletions(-) delete mode 100644 tests/test_gpt_loss.py diff --git a/tests/test_gpt_loss.py b/tests/test_gpt_loss.py deleted 
file mode 100644 index 89262eca..00000000 --- a/tests/test_gpt_loss.py +++ /dev/null @@ -1,121 +0,0 @@ -import math - -import torch - -from fast_llm.config import NoAutoValidate -from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.optimizer.config import OptimizerConfig -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig -from tests.test_gpt_generate_and_forward import model_and_tokenizer # noqa: F401 -from tests.utils.utils import requires_cuda - - -def _get_model_runner_schedule( - model_path: str, - use_flash_attention: bool, - use_bf16: bool, - checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat, - phase=PhaseType.inference, -): - assert phase == PhaseType.inference or phase == PhaseType.validation - updates = { - ("pretrained", "path"): model_path, - ("pretrained", "model_weights"): True, - ("pretrained", "format"): checkpoint_format.name, - ("model", "base_model", "cross_entropy_impl"): "fused", - ("model", "multi_stage", "zero_stage"): 2, - } - - if use_flash_attention: - updates[("model", "base_model", "transformer", "use_flash_attention")] = True - updates[("model", "distributed", "training_dtype")] = "bf16" - else: - updates[("model", "base_model", "transformer", "use_flash_attention")] = False - if use_bf16: - updates[("model", "distributed", "training_dtype")] = "bf16" - - config = PretrainedGPTModelConfig.from_dict({}, updates) - multi_stage = config.model.get_model_class()( - config.model, optimizer_state_names=OptimizerConfig.state_names() if phase == PhaseType.validation else () - ) - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=2048, batch_size=2) - batch_config.setup(config.model.distributed) - batch_config.validate() - - schedule = Schedule( - multi_stage=multi_stage, - batch_config=batch_config, - schedule_config=schedule_config, - distributed_config=config.model.distributed, - phase=phase, - ) - - runner = ScheduleRunner( - config=schedule_config, - multi_stage=multi_stage, - distributed_config=config.model.distributed, - ) - - distributed = Distributed(config.model.distributed) - - with torch.no_grad(): - multi_stage.setup(distributed) - - with torch.no_grad(): - runner.setup(distributed) - - multi_stage.load_checkpoint(config.pretrained) - - return multi_stage, runner, schedule, batch_config - - -def _test_for_phase(model_path, fast_llm_checkpoint_format, phase): - model, runner, schedule, batch_config = _get_model_runner_schedule( - model_path, True, True, fast_llm_checkpoint_format, phase - ) - - inputs = GPTBatch( - torch.randint( - 1, - model.config.base_model.vocab_size, - [2, batch_config.sequence_length + 1], - dtype=torch.int64, - generator=torch.Generator().manual_seed(42), - ) - ) - - iteration = 1 - - # we need to set phase to validation here so preprocess would crate labels from input - # so it is the same process for validation and inference phases - # otherwise we can add labels manually after preprocess for inference phase - batch = model.base_model.preprocess(inputs, phase=PhaseType.validation, iteration=iteration) - 
((inputs_, kwargs),) = batch - kwargs[LanguageModelKwargs.phase] = phase - iter_losses, _, _ = runner.run_step( - iter((((inputs_, kwargs),),)), schedule, iteration=iteration, preprocessed=True - ) - - return iter_losses - - -# @pytest.mark.extra_slow -@requires_cuda -def test_loss_validation_vs_inference(model_and_tokenizer): - model_path, _, fast_llm_checkpoint_format = model_and_tokenizer - - iter_losses_validation = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.validation) - - iter_losses_inference = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.inference) - - assert len(iter_losses_validation) == len(iter_losses_inference) - for key in iter_losses_validation.keys(): - assert math.isclose(iter_losses_validation[key], iter_losses_inference[key], rel_tol=1e-5) From 672158564ceba76ff6142cf3dae2d995d3c5ebc6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 15:31:53 -0400 Subject: [PATCH 59/69] fix --- tests/utils/run_test_script.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 263484db..ab08ad73 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -31,6 +31,7 @@ def do_run_distributed_script( torchrun_port: int, num_gpus: int, timeout: float = 120, + env: dict[str, str | None] = None, ): command = [ "python", @@ -42,7 +43,7 @@ def do_run_distributed_script( *args, ] print(" ".join(command)) - completed_proc = subprocess.run(command, timeout=timeout) + completed_proc = subprocess.run(command, timeout=timeout, env=env) if completed_proc.returncode: raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") @@ -75,6 +76,8 @@ def do_run_test_script( # Prevent Megatron from complaining. env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env["NVTE_FLASH_ATTN"] = "0" + else: + env = None if local_rank == 0 and prepare_fn is not None: prepare_fn(path, None if compare_path is None else compare_path) if is_megatron: @@ -86,7 +89,9 @@ def do_run_test_script( print(" ".join(args[1:])) RunnableConfig.parse_and_run(args[2:]) else: - do_run_distributed_script(args, rendezvous_port, torchrun_port, num_gpus) + do_run_distributed_script( + args, rendezvous_port=rendezvous_port, torchrun_port=torchrun_port, num_gpus=num_gpus, env=env + ) if local_rank == 0 and compare_path is not None and do_compare: if compare_fn is not None: compare_fn(path, compare_path) From 47c59e8d1fdf03a3624623831a03e9c72f0a3be1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 15:49:01 -0400 Subject: [PATCH 60/69] fix --- fast_llm/engine/distributed/distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index dfc2dd60..9719ff2e 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -49,12 +49,12 @@ def rank(self): def world_size(self): return self._world_size - def get_process_group(self, global_ranks: range | tuple, rank: int) -> ProcessGroup | None: + def get_process_group(self, global_ranks: range | tuple, group_rank: int) -> ProcessGroup | None: """ Get the requested process group from the pool, or create it if it doesn't exist. 
""" group_size = len(global_ranks) - Assert.eq(global_ranks[rank], self._rank) + Assert.eq(global_ranks[group_rank], self._rank) if group_size == 1: return None @@ -74,7 +74,7 @@ def get_process_group(self, global_ranks: range | tuple, rank: int) -> ProcessGr group = torch.distributed.ProcessGroupNCCL( torch.distributed.PrefixStore(prefix + "/", self.store), - global_ranks.index(rank), + group_rank, group_size, datetime.timedelta(seconds=self._timeout), ) From 35712f82f449f31dc738a673a3e2e74ead32eaab Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 17:52:15 -0400 Subject: [PATCH 61/69] Test, fixes --- fast_llm/engine/distributed/config.py | 10 ++- tests/test_config.py | 122 +++++++++++++++++++++++++- 2 files changed, 128 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 5ef7b590..8b689cde 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -247,6 +247,8 @@ def _validate(self) -> None: Assert.multiple(self.local_world_size, self.tensor_parallel) if self.pipeline_first: + # Case is useless and would cause too many complications. + Assert.eq(self.sequence_data_parallel, 1) # Smaller models can be more demanding on pipeline parallel. self.data_rank = (self.rank // self.tensor_parallel) // self.pipeline_parallel self.pipeline_rank = (self.rank // self.tensor_parallel) % self.pipeline_parallel @@ -271,8 +273,10 @@ def _validate(self) -> None: else: self.distributed_dims = {} - data_stride = self.tensor_parallel * (1 if self.pipeline_first else self.pipeline_parallel) - pipeline_stride = self.tensor_parallel * (self.data_parallel if self.pipeline_first else 1) + data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) + print("data_stride", data_stride) + print("pipeline_stride", pipeline_stride) self._add_distributed_dim( DistributedDim( @@ -345,7 +349,7 @@ def _get_global_ranks(self, size: int, stride: int) -> range: return range(start, start + size * stride, stride) def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: - Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank) + Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim) if distributed_dim.name in self.distributed_dims: Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name]) else: diff --git a/tests/test_config.py b/tests/test_config.py index 6aacee10..b6a9a985 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,4 @@ +import collections import pathlib import subprocess @@ -7,7 +8,7 @@ from fast_llm.config import NoAutoValidate from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert, check_equal_nested @@ -148,3 +149,122 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) + + +def _check_dim(dim: DistributedDim, name: str, rank: int, 
size: int, global_rank: int): + Assert.eq(dim.name, name) + Assert.eq(dim.size, size) + Assert.eq(dim.rank, rank) + # Already checked in distributed config, we repeat for extra safety. + Assert.eq(dim.global_ranks[rank], global_rank) + Assert.eq(len(dim.global_ranks), size) + + +@pytest.mark.parametrize( + ("bdp", "sdp", "tp", "pp", "pipeline_first"), + ( + (1, 1, 1, 1, False), + (4, 1, 1, 1, False), + (1, 4, 1, 1, False), + (1, 1, 4, 1, False), + (1, 1, 1, 4, False), + (1, 4, 1, 3, False), + (1, 1, 3, 2, False), + (1, 1, 3, 2, True), + (3, 1, 1, 2, False), + (3, 1, 1, 2, True), + (2, 2, 2, 3, False), + ), +) +def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline_first: bool): + world_size = bdp * sdp * tp * pp + dp = sdp * bdp + config_dict = { + "sequence_data_parallel": sdp, + "tensor_parallel": tp, + "pipeline_parallel": pp, + "pipeline_first": pipeline_first, + "world_size": world_size, + "local_world_size": world_size, + } + + all_global_ranks = collections.defaultdict(set) + rank_breakdowns = set() + for rank in range(world_size): + # Independent computation of the group ranks. + tp_rank = rank % tp + rank_ = rank // tp + if pipeline_first: + pp_rank = rank_ % pp + dp_rank = rank_ // pp + else: + dp_rank = rank_ % dp + pp_rank = rank_ // dp + + config = DistributedConfig.from_dict(config_dict, {"rank": rank}) + # Check that each group has the right size and rank. + _check_dim( + world_dim := config.get_distributed_dim(DistributedDimNames.world), + DistributedDimNames.world, + rank, + world_size, + rank, + ) + _check_dim( + tp_dim := config.get_distributed_dim(DistributedDimNames.tensor), + DistributedDimNames.tensor, + tp_rank, + tp, + rank, + ) + _check_dim( + tp_sdp_dim := config.get_distributed_dim(DistributedDimNames.tensor_and_sequence_data), + DistributedDimNames.tensor_and_sequence_data, + dp_rank % sdp * tp + tp_rank, + tp * sdp, + rank, + ) + _check_dim( + sdp_dim := config.get_distributed_dim(DistributedDimNames.sequence_data), + DistributedDimNames.sequence_data, + dp_rank % sdp, + sdp, + rank, + ) + _check_dim( + bdp_dim := config.get_distributed_dim(DistributedDimNames.batch_data), + DistributedDimNames.batch_data, + dp_rank // sdp, + bdp, + rank, + ) + _check_dim( + dp_dim := config.get_distributed_dim(DistributedDimNames.data), + DistributedDimNames.data, + dp_rank, + bdp * sdp, + rank, + ) + _check_dim( + pp_dim := config.get_distributed_dim(DistributedDimNames.pipeline), + DistributedDimNames.pipeline, + pp_rank, + pp, + rank, + ) + all_global_ranks["world"].add(tuple(world_dim.global_ranks)) + all_global_ranks["tp"].add(tuple(tp_dim.global_ranks)) + all_global_ranks["tp_sdp"].add(tuple(tp_sdp_dim.global_ranks)) + all_global_ranks["sdp"].add(tuple(sdp_dim.global_ranks)) + all_global_ranks["bdp"].add(tuple(bdp_dim.global_ranks)) + all_global_ranks["dp"].add(tuple(dp_dim.global_ranks)) + all_global_ranks["pp"].add(tuple(pp_dim.global_ranks)) + rank_breakdowns.add((tp_rank, dp_rank // sdp, dp_rank % sdp, pp_rank)) + + for name, global_ranks_set in all_global_ranks.items(): + # Check that the global ranks are partitioned into disjoint groups for each distributed dimension, + # and indirectly that `DistributedDim.global_ranks` is consistent between ranks. 
+ Assert.eq(sum(len(global_ranks) for global_ranks in global_ranks_set), world_size) + Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks})) + + Assert.eq(len(rank_breakdowns), world_size) From 30685cc0142d7ec6629f7529c10955b15a918c6a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 20:49:37 -0400 Subject: [PATCH 62/69] Fix overlap index --- fast_llm/engine/distributed/config.py | 2 +- fast_llm/engine/multi_stage/fsdp.py | 219 +++++++++++++++++--------- fast_llm/functional/linear.py | 2 +- fast_llm/tensor.py | 45 +++--- fast_llm/utils.py | 7 + 5 files changed, 176 insertions(+), 99 deletions(-) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 8b689cde..ff5569f4 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -61,7 +61,7 @@ class DistributedDim: name: str size: int rank: int - global_ranks: range | tuple[int, ...] = None + global_ranks: range | tuple[int, ...] def __post_init__(self): self._is_setup = False diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index f991b68f..f329c12c 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -1,3 +1,4 @@ +import dataclasses import math import typing @@ -403,9 +404,17 @@ def invalidate_buffer(self) -> None: self._is_restored = False def parameter_global_to_shard( - self, global_param: torch.Tensor | SafeTensorSlice, parameter_name: str + self, + global_param: torch.Tensor | SafeTensorSlice, + parameter_name: str, + *, + _parameter_meta: TensorMeta | None = None, ) -> torch.Tensor: - shard_param = self.get_parameter_meta(parameter_name).global_to_local(global_param).flatten() + if _parameter_meta is None: + # Used with reduced tensor-parallel in `copy_shard_overlaps` + _parameter_meta = self._parameter_metas[parameter_name] + # This may copy the data. + shard_param = _parameter_meta.global_to_local(global_param).flatten() if self._fsdp_dim.size > 1: shard_param = shard_param[ self._index_buffer_to_param( @@ -414,13 +423,14 @@ def parameter_global_to_shard( ] return shard_param - def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, device: torch.device) -> torch.Tensor: + def _get_parameter_shard_indices_in_full_weight( + self, parameter_name: str, device: torch.device, parameter_meta: TensorMeta + ) -> torch.Tensor: """ Create an index array for the global parameter, where each entry corresponds to the index where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - parameter_meta = self.get_parameter_meta(parameter_name) # Create an empty index for the global parameter. index = torch.full( @@ -431,9 +441,23 @@ def _get_parameter_shard_indices_in_full_weight(self, parameter_name: str, devic ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - self.parameter_global_to_shard(index, parameter_name).copy_( - torch.arange(begin, end, dtype=torch.int64, device=device) - ) + + buffer_index = parameter_meta.global_to_local(index) + buffer_flat_index = buffer_index.flatten() + + # Copy the shard indices at their respective positions in the flat buffer index. 
+ shard_index = buffer_flat_index[ + self._index_buffer_to_param( + self._fsdp_dim.rank * self._shard_size, parameter_name + ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) + ] + shard_index.copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) + + # `buffer_flat_index` may be a copy of `buffer_index`. + # If this is the case, we need to copy the result back into `buffer_index`, which itself is a view of `index`. + if buffer_flat_index.is_contiguous() and not buffer_index.is_contiguous(): + buffer_index.copy_(buffer_flat_index.view_as(buffer_index)) + return index def copy_shard_overlaps( @@ -441,8 +465,6 @@ def copy_shard_overlaps( loaded_fsdp: typing.Self, shards: dict[str, torch.Tensor] | None, loaded_shards: dict[str, torch.Tensor] | None, - # counter: torch.Tensor, - device: torch.device, ) -> dict[tuple[str, str], int]: """ See MultiStage._load_partial. @@ -455,82 +477,125 @@ def copy_shard_overlaps( self_meta = self._parameter_metas[parameter_name] loaded_meta = loaded_fsdp._parameter_metas[parameter_name] - if self_meta.is_tensor_parallel: - self_tp = self_meta.tensor_parallel_dim.size - self_rank = self_meta.tensor_parallel_dim.rank - else: - self_tp, self_rank = 1, 0 - if loaded_meta.is_tensor_parallel: - loaded_tp = loaded_meta.tensor_parallel_dim.size - loaded_rank = loaded_meta.tensor_parallel_dim.rank + # The shared tensor-parallel part (usually the smallest of the two) can be safely ignored. + if (shared_tp := math.gcd(self_meta.tensor_parallel_size, loaded_meta.tensor_parallel_size)) > 1: + self_meta, self_shared_rank = _reduce_tensor_parallelism_in_meta(self_meta, shared_tp) + loaded_meta, loaded_shared_rank = _reduce_tensor_parallelism_in_meta(loaded_meta, shared_tp) + if self_shared_rank != loaded_shared_rank: + # Disjoint tensor-parallel slices, no possible overlap. + continue + + if self_meta.tensor_parallel_size == loaded_meta.tensor_parallel_size == 1: + self._copy_shard_overlaps(loaded_fsdp, shards, loaded_shards, parameter_name, counter) else: - loaded_tp, loaded_rank = 1, 0 + raise NotImplementedError() - # The shared tensor-parallel part (usually the smallest of the two) can be safely ignored. - shared_tp = math.gcd(self_tp, loaded_tp) + return counter - self_tp //= shared_tp - loaded_tp //= shared_tp + def _copy_shard_overlaps( + self, + loaded_fsdp: typing.Self, + shards: dict[str, torch.Tensor] | None, + loaded_shards: dict[str, torch.Tensor] | None, + parameter_name: str, + counter: dict[tuple[str, str], int], + ): + self_shard_begin_in_buffer = self._fsdp_dim.rank * self._shard_size + self_shard_end_in_buffer = (self._fsdp_dim.rank + 1) * self._shard_size + self_shard_begin_in_param = self._index_buffer_to_param(self_shard_begin_in_buffer, parameter_name) + self_shard_end_in_param = self._index_buffer_to_param(self_shard_end_in_buffer, parameter_name) - if self_rank // self_tp != loaded_rank // loaded_tp: - # Disjoint shared rank, no possible overlap. 
- continue - self_rank %= self_tp - loaded_rank %= loaded_tp - - if self_tp == loaded_tp == 1: - self_shard_begin_in_buffer = self._fsdp_dim.rank * self._shard_size - self_shard_end_in_buffer = (self._fsdp_dim.rank + 1) * self._shard_size - self_shard_begin_in_param = self._index_buffer_to_param(self_shard_begin_in_buffer, parameter_name) - self_shard_end_in_param = self._index_buffer_to_param(self_shard_end_in_buffer, parameter_name) - - loaded_shard_begin_in_buffer = loaded_fsdp._fsdp_dim.rank * loaded_fsdp._shard_size - loaded_shard_end_in_buffer = (loaded_fsdp._fsdp_dim.rank + 1) * loaded_fsdp._shard_size - loaded_shard_begin_in_param = loaded_fsdp._index_buffer_to_param( - loaded_shard_begin_in_buffer, parameter_name - ) - loaded_shard_end_in_param = loaded_fsdp._index_buffer_to_param( - loaded_shard_end_in_buffer, parameter_name - ) + loaded_shard_begin_in_buffer = loaded_fsdp._fsdp_dim.rank * loaded_fsdp._shard_size + loaded_shard_end_in_buffer = (loaded_fsdp._fsdp_dim.rank + 1) * loaded_fsdp._shard_size + loaded_shard_begin_in_param = loaded_fsdp._index_buffer_to_param(loaded_shard_begin_in_buffer, parameter_name) + loaded_shard_end_in_param = loaded_fsdp._index_buffer_to_param(loaded_shard_end_in_buffer, parameter_name) - overlap_begin_in_param = max(self_shard_begin_in_param, loaded_shard_begin_in_param) - overlap_end_in_param = min(self_shard_end_in_param, loaded_shard_end_in_param) + overlap_begin_in_param = max(self_shard_begin_in_param, loaded_shard_begin_in_param) + overlap_end_in_param = min(self_shard_end_in_param, loaded_shard_end_in_param) - if (overlap_size := overlap_end_in_param - overlap_begin_in_param) <= 0: - continue + if (overlap_size := overlap_end_in_param - overlap_begin_in_param) <= 0: + return - overlap_begin_in_self_shard = ( - self._parameter_begins_in_buffer[parameter_name] - + overlap_begin_in_param - - self_shard_begin_in_buffer - ) - overlap_begin_in_loaded_shard = ( - loaded_fsdp._parameter_begins_in_buffer[parameter_name] - + overlap_begin_in_param - - loaded_shard_begin_in_buffer - ) + overlap_begin_in_self_shard = ( + self._parameter_begins_in_buffer[parameter_name] + overlap_begin_in_param - self_shard_begin_in_buffer + ) + overlap_begin_in_loaded_shard = ( + loaded_fsdp._parameter_begins_in_buffer[parameter_name] + + overlap_begin_in_param + - loaded_shard_begin_in_buffer + ) - if shards is None: - # Dry run, we only want the counter. - Assert.not_incl((parameter_name, ""), counter) - counter[(parameter_name, "")] = overlap_size - continue + if shards is None: + # Dry run. 
+ counter[(parameter_name, "")] = overlap_size + return - for shard_name, shard in shards.items(): - # Shards can be empty (frozen weights) - if shard.numel() == 0: - continue - Assert.not_incl((parameter_name, shard_name), counter) - counter[(parameter_name, shard_name)] = overlap_size - shard[overlap_begin_in_self_shard : overlap_begin_in_self_shard + overlap_size] = ( - loaded_shards[shard_name][ - overlap_begin_in_loaded_shard : overlap_begin_in_loaded_shard + overlap_size - ] - if loaded_shards[shard_name].numel() > 0 - else 0 - ) + for shard_name, shard in shards.items(): + # Shards can be empty (frozen weights) + if shard.numel() == 0: + continue + counter[(parameter_name, shard_name)] = overlap_size + shard[overlap_begin_in_self_shard : overlap_begin_in_self_shard + overlap_size] = ( + loaded_shards[shard_name][overlap_begin_in_loaded_shard : overlap_begin_in_loaded_shard + overlap_size] + if loaded_shards[shard_name].numel() > 0 + else 0 + ) - else: - raise NotImplementedError() + def _copy_tensor_parallel_shard_overlaps( + self, + loaded_fsdp: typing.Self, + shards: dict[str, torch.Tensor] | None, + loaded_shards: dict[str, torch.Tensor] | None, + parameter_name: str, + counter: dict[tuple[str, str], int], + self_meta: TensorMeta, + loaded_meta: TensorMeta, + ): + if shards is None: + # Dry run. Since we only need to know if there can be overlap, + # we skip the slow computation and return a dummy value. + counter[(parameter_name, "")] = 1 + return - return counter + device = next(iter(shards.values())).device + overlap_index_map = self.parameter_global_to_shard( + loaded_fsdp._get_parameter_shard_indices_in_full_weight(parameter_name, device, loaded_meta), + parameter_name, + _parameter_meta=self_meta, + ) + overlap_mask = overlap_index_map >= 0 + overlap_index_map_masked = overlap_index_map[overlap_mask] + overlap_count = overlap_mask.sum().item() + if overlap_count == 0: + return + begin, end = self._get_parameter_range_in_shard(parameter_name) + + for shard_name, shard in shards.items(): + # Shards can be empty (frozen weights) + if shard.numel() == 0: + continue + if loaded_shards[shard_name].numel() == 0: + shard[begin:end][overlap_mask] = 0 + counter += overlap_count + continue + shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] + counter += overlap_count + + +def _reduce_tensor_parallelism_in_meta(meta: TensorMeta, shared_tp: int) -> tuple[TensorMeta, int]: + # Make a `TensorMeta` look like it has less tensor parallelism. + dims = list(meta.dims) + dim = dims[meta.tensor_parallel_dim_index] + new_size = meta.tensor_parallel_size // shared_tp + shared_rank = meta.tensor_parallel_rank + dims[meta.tensor_parallel_dim_index] = TensorDim( + dim.name, + dim.global_size // shared_tp, + dataclasses.replace( + dim.parallel_dim, + size=new_size, + rank=meta.tensor_parallel_rank % new_size, + global_rank=dim.parallel_dim.global_ranks[shared_tp * shared_rank : shared_tp * (shared_rank + 1)], + ), + ) + return TensorMeta(meta, tensor_name=meta.tensor_name, dims=tuple(dims)), shared_rank diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py index dbc05184..60c2554d 100644 --- a/fast_llm/functional/linear.py +++ b/fast_llm/functional/linear.py @@ -38,7 +38,7 @@ def update_linear_gradients( Which one is best? (and can we fuse everything?) 
""" - grad_output = grad_output.flatten(0, -2) + grad_output = grad_output(0, -2) input_ = input_.flatten(0, -2) lhs, rhs = (input_.t(), grad_output) if transposed_weight else (grad_output.t(), input_) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b2849be8..b3c4e3fe 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -10,7 +10,7 @@ from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy -from fast_llm.utils import Assert +from fast_llm.utils import Assert, flatten_without_copy class _SafeTensorSliceMeta(type): @@ -88,24 +88,27 @@ def __new__( ) @functools.cached_property - def is_tensor_parallel(self) -> bool: - # TODO: Avoid hard-coded assumptions on tensor parallel. - return any( - dim.parallel_dim is not None and dim.parallel_dim.name == DistributedDimNames.tensor for dim in self.dims - ) - - @functools.cached_property - def tensor_parallel_dim(self) -> DistributedDim | None: + def tensor_parallel_dim_index(self) -> int | None: # TODO: Avoid hard-coded assumptions on tensor parallel. - if not self.is_tensor_parallel: - return None - dims = [ - dim - for dim in self.dims + indexes = [ + i + for i, dim in enumerate(self.dims) if dim.parallel_dim is not None and dim.parallel_dim.name == DistributedDimNames.tensor ] - assert len(dims) == 1, dims - return dims[0].parallel_dim + assert len(indexes) == 1, indexes + return indexes[0] + + @functools.cached_property + def is_tensor_parallel(self) -> bool: + return self.tensor_parallel_dim_index is not None + + @functools.cached_property + def tensor_parallel_size(self) -> int: + return self.dims[self.tensor_parallel_dim_index].parallel_dim.size if self.is_tensor_parallel else 1 + + @functools.cached_property + def tensor_parallel_rank(self) -> int: + return self.dims[self.tensor_parallel_dim_index].parallel_dim.rank if self.is_tensor_parallel else 0 def __repr__(self, *, tensor_contents=()): return super().__repr__( @@ -194,10 +197,12 @@ def global_to_local( for i, dim in enumerate(self.dims): if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = ( - tensor_.unflatten(i, dim.global_expanded_shape) - .chunk(dim.parallel_dim.size, i + dim.parallel_dim_index)[dim.parallel_dim.rank] - .flatten(i, i + len(dim.expanded_shape) - 1) + tensor_ = flatten_without_copy( + tensor_.unflatten(i, dim.global_expanded_shape).chunk( + dim.parallel_dim.size, i + dim.parallel_dim_index + )[dim.parallel_dim.rank], + i, + i + len(dim.expanded_shape) - 1, ) return tensor_.view(self.shape) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 7bbdd697..ce41fcac 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -71,6 +71,13 @@ def rms_diff(x: "torch.Tensor", y: "torch.Tensor") -> "torch.Tensor": return torch.norm(x - y, 2, dtype=torch.float32) / x.numel() ** 0.5 # noqa +def flatten_without_copy(x: "torch.Tensor", start_dim: int = 0, end_dim: int = -1) -> "torch.Tensor": + # Similar to torch.flatten, but never returns a copy. 
+ return x.view( + *x.shape[:start_dim], x.shape[start_dim:end_dim].numel(), *([] if end_dim == -1 else x.shape[end_dim + 1 :]) + ) + + class Tag: __slots__ = ("value",) From 5dbd101d26d230c905a318f214f2737d25805705 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 26 Jun 2025 21:18:27 -0400 Subject: [PATCH 63/69] stuff --- fast_llm/engine/checkpoint/distributed.py | 2 - fast_llm/functional/linear.py | 2 +- fast_llm/tensor.py | 4 +- tests/models/distributed_test_checkpoint.py | 128 ++++++++------------ 4 files changed, 53 insertions(+), 83 deletions(-) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 903c8840..7faf599f 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -129,7 +129,6 @@ def _has_shard_overlaps(self, loaded_model) -> bool: loaded_fsdp, None, None, - self._model.distributed.device, ) if counter: return True @@ -147,7 +146,6 @@ def _copy_shard_overlaps(self, loaded_model, loaded_shards, context): loaded_fsdp, self_fsdp_shards, loaded_fsdp_shards, - self._model.distributed.device, ) for parameter, count in counter.items(): context.mark_as_loaded(count, parameter, True) diff --git a/fast_llm/functional/linear.py b/fast_llm/functional/linear.py index 60c2554d..dbc05184 100644 --- a/fast_llm/functional/linear.py +++ b/fast_llm/functional/linear.py @@ -38,7 +38,7 @@ def update_linear_gradients( Which one is best? (and can we fuse everything?) """ - grad_output = grad_output(0, -2) + grad_output = grad_output.flatten(0, -2) input_ = input_.flatten(0, -2) lhs, rhs = (input_.t(), grad_output) if transposed_weight else (grad_output.t(), input_) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b3c4e3fe..1e3dfc75 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -95,8 +95,8 @@ def tensor_parallel_dim_index(self) -> int | None: for i, dim in enumerate(self.dims) if dim.parallel_dim is not None and dim.parallel_dim.name == DistributedDimNames.tensor ] - assert len(indexes) == 1, indexes - return indexes[0] + assert len(indexes) <= 1, indexes + return indexes[0] if indexes else None @functools.cached_property def is_tensor_parallel(self) -> bool: diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index d27b66b7..e247c33b 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -39,38 +39,9 @@ def _test_load_and_save_parallel( torch.cuda.empty_cache() -# def _test_load_and_save_parallel(fixture_args, test_name, distributed_args, pretrained_path, pretrained_format): -# # TODO: Just save and load the model instead, no need for an actual run. -# do_run_test_script_for_all_models( -# [ -# # First we load a checkpoint. -# f"pretrained.path={pretrained_path}", -# f"pretrained.format={pretrained_format}", -# # We run for one mock iteration. -# "training.train_iters=1", -# "schedule.skip_step=True", -# # Then we save a checkpoint (distributed format) and an export (fast_llm format). 
-# "training.checkpoint.interval=1", -# "training.export.interval=1", -# "training.export.format=fast_llm", -# ] -# + distributed_args, -# test_name=test_name, -# **fixture_args, -# ) - - def main(args: list[str] | None = None) -> None: base_path, model_testing_config = parse_run_distributed_script(args) - # fixture_args = { - # "rendezvous_port": rendezvous_port, - # "torchrun_port": torchrun_port, - # "base_path": base_path, - # "model_testing_config": model_testing_config, - # "num_gpus": 2, - # } - with ProcessGroupPool(timeout=20): for pretrained_format, pretrained_path in ( ( @@ -99,56 +70,57 @@ def main(args: list[str] | None = None) -> None: distributed_config={}, save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_dp2", ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_{pretrained_format}_in_dp2", - # distributed_args=[], - # pretrained_path=pretrained_path, - # pretrained_format=pretrained_format, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_{pretrained_format}_in_tp2", - # distributed_args=["model.distributed.tensor_parallel=2"], - # pretrained_path=pretrained_path, - # pretrained_format=pretrained_format, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_{pretrained_format}_in_stp2", - # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], - # pretrained_path=pretrained_path, - # pretrained_format=pretrained_format, - # ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=pretrained_path, + pretrained_format=pretrained_format, + distributed_config={"tensor_parallel": 2}, + save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_tp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=pretrained_path, + pretrained_format=pretrained_format, + distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True}, + save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_stp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=pretrained_path, + pretrained_format=pretrained_format, + distributed_config={"pipeline_parallel": 2}, + save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_ppp2", + ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_dp2_in_tp2", - # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], - # pretrained_path=base_path / "test_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", - # pretrained_format=DistributedCheckpointFormat.name, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_stp2_in_dp2", - # distributed_args=[], - # pretrained_path=base_path / "test_load_pretrained_distributed_in_stp2" / "checkpoint" / "1", - # pretrained_format=DistributedCheckpointFormat.name, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # test_name=f"test_load_pretrained_tp2_in_stp2", - # distributed_args=["model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=true"], - # pretrained_path=base_path / "test_load_pretrained_distributed_in_stp2" / "checkpoint" / "1", - # pretrained_format=DistributedCheckpointFormat.name, - # ) - # _test_load_and_save_parallel( - # fixture_args, - # 
test_name=f"test_load_pretrained_stp2_in_tp2", - # distributed_args=["model.distributed.tensor_parallel=2"], - # pretrained_path=base_path / "test_load_pretrained_distributed_in_tp2" / "checkpoint" / "1", - # pretrained_format=DistributedCheckpointFormat.name, - # ) + dist = DistributedCheckpointFormat.name + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=base_path / f"load_pretrained_{dist}_in_dp2" / dist, + pretrained_format=pretrained_format, + distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True}, + save_path=base_path / f"load_pretrained_dp2_in_stp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=base_path / f"load_pretrained_{dist}_in_stp2" / dist, + pretrained_format=pretrained_format, + distributed_config={}, + save_path=base_path / f"load_pretrained_stp2_in_dp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=base_path / f"load_pretrained_{dist}_in_tp2" / dist, + pretrained_format=pretrained_format, + distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True}, + save_path=base_path / f"load_pretrained_tp2_in_stp2", + ) + _test_load_and_save_parallel( + model_testing_config=model_testing_config, + pretrained_path=base_path / f"load_pretrained_{dist}_in_pp2" / dist, + pretrained_format=pretrained_format, + distributed_config={"tensor_parallel": 2}, + save_path=base_path / f"load_pretrained_pp2_in_tp2", + ) if __name__ == "__main__": From 5f368ef73ae6b16d4c0222914053934b3bebd9f5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Jun 2025 19:16:21 -0400 Subject: [PATCH 64/69] fixes --- fast_llm/engine/checkpoint/safe_load.py | 10 ++ fast_llm/engine/config_utils/tensor_space.py | 10 ++ fast_llm/engine/distributed/config.py | 2 - fast_llm/engine/multi_stage/fsdp.py | 158 ++++++++++++------- fast_llm/engine/multi_stage/stage_base.py | 24 +-- fast_llm/models/gpt/conversion.py | 16 +- fast_llm/tensor.py | 29 ++-- fast_llm/utils.py | 7 - tests/models/distributed_test_checkpoint.py | 27 ++-- tests/models/test_checkpoint.py | 90 +++++++---- 10 files changed, 228 insertions(+), 145 deletions(-) diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 84a58971..2e2a0188 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -148,12 +148,18 @@ def _check_parameters(self, errors: list[str]) -> None: f'Local counter mismatch for parameter "{parameter_name}"' f' and shard "{shard_name}": loaded {counter}, expected {local_size}' ) + + counter_ = counter # Accumulate in a list for global counter check. if ( not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0 ) or stage.is_tied_weight_copy: # Ignore the counter from duplicate tensors. counter = 0 + if parameter_name == "layers.1.norm_1.weight": + logger.info( + f"Parameter {parameter_name} local {counter_} keep {counter} (size {parameter_meta.numel()} / {parameter_meta.global_shape.numel()})" + ) counters.append(counter) # Check for unexpected parameters. 
@@ -173,6 +179,10 @@ def _check_parameters(self, errors: list[str]) -> None: for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters: for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]: counter = counters.pop(0) + if parameter_name == "layers.1.norm_1.weight": + logger.info( + f"Parameter {parameter_name} global {counter} (size {parameter_meta.numel()} / {parameter_meta.global_shape.numel()})" + ) parameter_size = parameter_meta.global_shape.numel() if counter != parameter_size: errors.append( diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 49ce1525..99c1bcf7 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -66,6 +66,10 @@ def parallel_dim_index(self) -> int | None: def parallel_group(self) -> "ProcessGroup|None": return None if self._parallel_dim is None else self._parallel_dim.group + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self.parallel_dim is not None + return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + class CompositeTensorDim(TensorDim): def __init__(self, name: str, dims: tuple[TensorDim, ...]): @@ -106,6 +110,12 @@ def global_expanded_shape(self) -> tuple[int, ...]: def parallel_dim_index(self) -> int | None: return self._parallel_dim_index + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self.parallel_dim_index is not None + dims = list(self.dims) + dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) + class DefaultDimNames: # Scalar diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index ff5569f4..7fd9fed1 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -275,8 +275,6 @@ def _validate(self) -> None: data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) - print("data_stride", data_stride) - print("pipeline_stride", pipeline_stride) self._add_distributed_dim( DistributedDim( diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index f329c12c..5b44bf14 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -10,7 +10,7 @@ from fast_llm.core.ops import gather_op from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.engine.distributed.config import DistributedDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill @@ -37,18 +37,16 @@ def __init__( self, name: str, parameter_metas: list[ParameterMeta], - fsdp_dim: DistributedDim, - training_dtype: DataType, - gradient_buffer_dtype: DataType, - optimization_dtype: DataType, + distributed_config: DistributedConfig, + full_precision_gradient_buffer: bool = False, + full_precision_shards: bool = True, + is_tied_weight_copy: bool = False, ): self._name = name self._parameter_metas = 
{parameter_meta.tensor_name: parameter_meta for parameter_meta in parameter_metas} - self._fsdp_dim = fsdp_dim - self._training_dtype = training_dtype - self._gradient_buffer_dtype = gradient_buffer_dtype - self._optimization_dtype = optimization_dtype - + self._distributed_config = distributed_config + self._fsdp_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.data) + self._is_tied_weight_copy = is_tied_weight_copy self._requires_grad = any(parameter_meta.requires_grad for parameter_meta in self._parameter_metas.values()) parameter_sizes = [meta.numel() for meta in self._parameter_metas.values()] @@ -83,23 +81,35 @@ def __init__( self._weight_shard_meta = TensorMeta.from_dims( (weight_shard_dim,), tensor_name=f"{self._name}_weight_shard", - dtype=self._optimization_dtype.torch, + dtype=( + self._distributed_config.optimization_dtype + if full_precision_shards + else self._distributed_config.training_dtype + ).torch, ) # TODO: Distinguish grad and optimizer shard? self._grad_shard_meta = TensorMeta.from_dims( (grad_shard_dim,), tensor_name=f"{self._name}_grad_shard", - dtype=self._optimization_dtype.torch, + dtype=( + self._distributed_config.optimization_dtype + if full_precision_shards + else self._distributed_config.training_dtype + ).torch, ) self._weight_buffer_meta = TensorMeta.from_dims( (TensorDim("weight_buffer", weight_shard_dim.size * self._fsdp_dim.size),), tensor_name=f"{self._name}_weight_buffer", - dtype=self._training_dtype.torch, + dtype=self._distributed_config.training_dtype.torch, ) self._grad_buffer_meta = TensorMeta.from_dims( (TensorDim("grad_buffer", weight_shard_dim.size * self._fsdp_dim.size if self._requires_grad else 0),), tensor_name=f"{self._name}_grad_buffer", - dtype=self._gradient_buffer_dtype.torch, + dtype=( + self._distributed_config.optimization_dtype + if full_precision_gradient_buffer + else self._distributed_config.training_dtype + ).torch, ) @property @@ -442,21 +452,26 @@ def _get_parameter_shard_indices_in_full_weight( # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index) - buffer_flat_index = buffer_index.flatten() + buffer_index = parameter_meta.global_to_local(index, expand=True) + # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. + # In that case, we work with a separate tensor to be copied back into `buffer_index`. + try: + buffer_index_flat = buffer_index.view(-1) + is_view = True + except RuntimeError: + buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) + is_view = False # Copy the shard indices at their respective positions in the flat buffer index. - shard_index = buffer_flat_index[ + buffer_index_flat[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) - ] - shard_index.copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) + ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # `buffer_flat_index` may be a copy of `buffer_index`. - # If this is the case, we need to copy the result back into `buffer_index`, which itself is a view of `index`. 
- if buffer_flat_index.is_contiguous() and not buffer_index.is_contiguous(): - buffer_index.copy_(buffer_flat_index.view_as(buffer_index)) + # If needed, copy the flat buffer index back into the index. + if not is_view: + buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) return index @@ -473,22 +488,40 @@ def copy_shard_overlaps( Assert.eq(set(shards), set(loaded_shards)) index_overlap = [name for name in loaded_fsdp._parameter_metas if name in self._parameter_metas] counter = {} + + self_tensor_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + loaded_tensor_dim = loaded_fsdp._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + # The shared tensor-parallel part (usually the smallest of the two) can be safely ignored. + if (shared_tp := math.gcd(self_tensor_dim.size, loaded_tensor_dim.size)) > 1: + self_tensor_dim, self_new_size, self_shared_rank = _reduce_tensor_parallel_size(self_tensor_dim, shared_tp) + loaded_tensor_dim, loaded_new_size, loaded_shared_rank = _reduce_tensor_parallel_size( + loaded_tensor_dim, shared_tp + ) + + if self_shared_rank != loaded_shared_rank: + # Disjoint tensor-parallel slices, no possible overlap. + # (Duplicated parameters will be loaded from the new rank 0 which prevents unnecessary file loading). + return counter + for parameter_name in index_overlap: self_meta = self._parameter_metas[parameter_name] loaded_meta = loaded_fsdp._parameter_metas[parameter_name] - # The shared tensor-parallel part (usually the smallest of the two) can be safely ignored. - if (shared_tp := math.gcd(self_meta.tensor_parallel_size, loaded_meta.tensor_parallel_size)) > 1: - self_meta, self_shared_rank = _reduce_tensor_parallelism_in_meta(self_meta, shared_tp) - loaded_meta, loaded_shared_rank = _reduce_tensor_parallelism_in_meta(loaded_meta, shared_tp) - if self_shared_rank != loaded_shared_rank: - # Disjoint tensor-parallel slices, no possible overlap. - continue + if shared_tp > 1: + self_meta = self_meta.replace_tensor_parallel_dim(self_tensor_dim) + loaded_meta = loaded_meta.replace_tensor_parallel_dim(loaded_tensor_dim) + + if not loaded_meta.is_tensor_parallel and loaded_tensor_dim.rank != 0: + # Loaded parameter is tensor-parallel duplicate, ignore. + continue if self_meta.tensor_parallel_size == loaded_meta.tensor_parallel_size == 1: self._copy_shard_overlaps(loaded_fsdp, shards, loaded_shards, parameter_name, counter) else: - raise NotImplementedError() + self._copy_tensor_parallel_shard_overlaps( + loaded_fsdp, shards, loaded_shards, parameter_name, counter, self_meta, loaded_meta + ) return counter @@ -500,22 +533,26 @@ def _copy_shard_overlaps( parameter_name: str, counter: dict[tuple[str, str], int], ): + # Common case: the overlap is a contiguous slice of the shards. + + # Find the slice of the parameter contained in each shard. 
self_shard_begin_in_buffer = self._fsdp_dim.rank * self._shard_size
self_shard_end_in_buffer = (self._fsdp_dim.rank + 1) * self._shard_size
self_shard_begin_in_param = self._index_buffer_to_param(self_shard_begin_in_buffer, parameter_name)
self_shard_end_in_param = self._index_buffer_to_param(self_shard_end_in_buffer, parameter_name)
-
loaded_shard_begin_in_buffer = loaded_fsdp._fsdp_dim.rank * loaded_fsdp._shard_size
loaded_shard_end_in_buffer = (loaded_fsdp._fsdp_dim.rank + 1) * loaded_fsdp._shard_size
loaded_shard_begin_in_param = loaded_fsdp._index_buffer_to_param(loaded_shard_begin_in_buffer, parameter_name)
loaded_shard_end_in_param = loaded_fsdp._index_buffer_to_param(loaded_shard_end_in_buffer, parameter_name)
+ # Calculate the overlap.
overlap_begin_in_param = max(self_shard_begin_in_param, loaded_shard_begin_in_param)
overlap_end_in_param = min(self_shard_end_in_param, loaded_shard_end_in_param)
if (overlap_size := overlap_end_in_param - overlap_begin_in_param) <= 0:
return
+ # Map the overlap back to the shards.
overlap_begin_in_self_shard = (
self._parameter_begins_in_buffer[parameter_name] + overlap_begin_in_param - self_shard_begin_in_buffer
)
@@ -535,6 +572,8 @@ def _copy_shard_overlaps(
if shard.numel() == 0:
continue
counter[(parameter_name, shard_name)] = overlap_size
+
+ # Copy the overlap.
shard[overlap_begin_in_self_shard : overlap_begin_in_self_shard + overlap_size] = (
loaded_shards[shard_name][overlap_begin_in_loaded_shard : overlap_begin_in_loaded_shard + overlap_size]
if loaded_shards[shard_name].numel() > 0
@@ -551,6 +590,14 @@ def _copy_tensor_parallel_shard_overlaps(
self_meta: TensorMeta,
loaded_meta: TensorMeta,
):
+
+ self_begin, self_end = self._get_parameter_range_in_shard(parameter_name)
+ loaded_begin, loaded_end = loaded_fsdp._get_parameter_range_in_shard(parameter_name)
+ if self_begin >= self_end or loaded_begin >= loaded_end:
+ # Parameter is not present in both shards, no overlap.
+ return
+
+ # Tensor-parallel case: the overlap cannot be represented as a slice.
if shards is None:
# Dry run. Since we only need to know if there can be overlap,
# we skip the slow computation and return a dummy value.
@@ -558,15 +605,18 @@ def _copy_tensor_parallel_shard_overlaps(
return
device = next(iter(shards.values())).device
+ # Create an array that associates each entry in the `parameter_name` slice of `shard`
+ # to the index of the same parameter entry in `loaded_shard`, or -1 if not present.
overlap_index_map = self.parameter_global_to_shard(
loaded_fsdp._get_parameter_shard_indices_in_full_weight(parameter_name, device, loaded_meta),
parameter_name,
_parameter_meta=self_meta,
)
+ # Create a mask to exclude the missing entries.
overlap_mask = overlap_index_map >= 0
overlap_index_map_masked = overlap_index_map[overlap_mask]
- overlap_count = overlap_mask.sum().item()
- if overlap_count == 0:
+ overlap_size = overlap_mask.sum().item()
+ if overlap_size == 0:
return
begin, end = self._get_parameter_range_in_shard(parameter_name)
@@ -574,28 +624,20 @@ def _copy_tensor_parallel_shard_overlaps(
# Shards can be empty (frozen weights)
if shard.numel() == 0:
continue
- if loaded_shards[shard_name].numel() == 0:
- shard[begin:end][overlap_mask] = 0
- counter += overlap_count
- continue
- shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked]
- counter += overlap_count
-
-
-def _reduce_tensor_parallelism_in_meta(meta: TensorMeta, shared_tp: int) -> tuple[TensorMeta, int]:
- # Make a `TensorMeta` look like it has less tensor parallelism.
- dims = list(meta.dims) - dim = dims[meta.tensor_parallel_dim_index] - new_size = meta.tensor_parallel_size // shared_tp - shared_rank = meta.tensor_parallel_rank - dims[meta.tensor_parallel_dim_index] = TensorDim( - dim.name, - dim.global_size // shared_tp, - dataclasses.replace( - dim.parallel_dim, - size=new_size, - rank=meta.tensor_parallel_rank % new_size, - global_rank=dim.parallel_dim.global_ranks[shared_tp * shared_rank : shared_tp * (shared_rank + 1)], - ), + counter[(parameter_name, shard_name)] = overlap_size + # Masked copy of the overlap index map. + shard[begin:end][overlap_mask] = ( + loaded_shards[shard_name][overlap_index_map_masked] if loaded_shards[shard_name].numel() > 0 else 0 + ) + + +def _reduce_tensor_parallel_size(distributed_dim: DistributedDim, shared_size: int): + new_size = distributed_dim.size // shared_size + shared_rank = distributed_dim.rank // new_size + new_dim = dataclasses.replace( + distributed_dim, + size=new_size, + rank=distributed_dim.rank % new_size, + global_ranks=distributed_dim.global_ranks[shared_size * shared_rank : shared_size * (shared_rank + 1)], ) - return TensorMeta(meta, tensor_name=meta.tensor_name, dims=tuple(dims)), shared_rank + return new_dim, new_size, shared_rank diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index b8f12de3..2f18f136 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -8,7 +8,7 @@ from fast_llm.core.distributed import check_parallel_match from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import ShardName, StageConfig, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP @@ -51,20 +51,13 @@ def __init__( parameter_metas, frozen_metas = self._get_parameter_metas() self._parameter_metas = parameter_metas + frozen_metas self._fsdps = [] - gradient_buffer_dtype = ( - self._distributed_config.optimization_dtype - if self._config.full_precision_gradients - else self._distributed_config.training_dtype - ) if parameter_metas: self._fsdps.append( FSDP( f"stage_{self._index}", parameter_metas, - self._distributed_config.get_distributed_dim(DistributedDimNames.data), - training_dtype=self._distributed_config.training_dtype, - gradient_buffer_dtype=gradient_buffer_dtype, - optimization_dtype=self._distributed_config.optimization_dtype, + self._distributed_config, + full_precision_gradient_buffer=self._config.full_precision_gradients, ) ) if frozen_metas: @@ -72,14 +65,9 @@ def __init__( FSDP( f"stage_{self._index}_frozen", frozen_metas, - self._distributed_config.get_distributed_dim(DistributedDimNames.data), - training_dtype=self._distributed_config.training_dtype, - gradient_buffer_dtype=gradient_buffer_dtype, - optimization_dtype=( - self._distributed_config.optimization_dtype - if self._config.store_frozen_weights_in_optimization_precision - else self._distributed_config.training_dtype.torch - ), + self._distributed_config, + full_precision_gradient_buffer=self._config.full_precision_gradients, + full_precision_shards=self._config.store_frozen_weights_in_optimization_precision, ) ) # TODO: Separate fsdp for tied weights? 
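For reference, the dtype selection above reduces to the following minimal sketch; `select_fsdp_dtypes` is a hypothetical helper written for illustration only, not part of the patch:

def select_fsdp_dtypes(distributed_config, full_precision_shards=True, full_precision_gradient_buffer=True):
    # Weight and grad shards use the optimization dtype unless low-precision shards are requested
    # (e.g. frozen weights not stored in optimization precision).
    shard_dtype = (
        distributed_config.optimization_dtype if full_precision_shards else distributed_config.training_dtype
    )
    # The gradient buffer follows the `full_precision_gradients` stage setting.
    gradient_buffer_dtype = (
        distributed_config.optimization_dtype
        if full_precision_gradient_buffer
        else distributed_config.training_dtype
    )
    # The weight buffer always uses the training dtype.
    return shard_dtype, gradient_buffer_dtype, distributed_config.training_dtype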
diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d2c01af0..d8425786 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -319,10 +319,10 @@ def _get_weight_and_bias_converters( class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Starcoder2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Starcoder2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "rotary", "type"),), @@ -446,10 +446,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlamaForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "LlamaForCausalLM" return super()._create_config_converters() + [ # TODO: Llama supports biases ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), @@ -498,10 +498,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Qwen2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), @@ -544,10 +544,10 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MistralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MistralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MistralForCausalLM" return super()._create_config_converters() + [ IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] @@ -568,10 +568,10 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MixtralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MixtralForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk @@ -609,13 +609,13 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" modeling_file = modeling_mtp_llama.__file__ configuration_file = 
configuration_mtp_llama.__file__ configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MTPLlamaForCausalLM" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -697,6 +697,7 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DreamModel" modeling_file = modeling_dream.__file__ configuration_file = configuration_dream.__file__ generation_utils_file = generation_utils.__file__ @@ -704,7 +705,6 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DreamModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -725,6 +725,7 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam ) format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DiffusionLlamaModel" modeling_file = modeling_diffusion_llama.__file__ configuration_file = configuration_diffusion_llama.__file__ generation_utils_file = generation_utils.__file__ @@ -732,7 +733,6 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DiffusionLlamaModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 1e3dfc75..d780e4d6 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -10,7 +10,7 @@ from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy -from fast_llm.utils import Assert, flatten_without_copy +from fast_llm.utils import Assert class _SafeTensorSliceMeta(type): @@ -187,6 +187,8 @@ def local_to_global( def global_to_local( self, tensor: torch.Tensor | SafeTensorSlice, + # Return an expanded tensor, avoiding `flatten` which copies the data. + expand: bool = False, ) -> torch.Tensor: """ Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. 
@@ -195,17 +197,13 @@ def global_to_local( tensor_ = tensor[:] assert not self._reductions - for i, dim in enumerate(self.dims): + for i, dim in reversed(list(enumerate(self.dims))): if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = flatten_without_copy( - tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank], - i, - i + len(dim.expanded_shape) - 1, - ) + tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( + dim.parallel_dim.size, i + dim.parallel_dim_index + )[dim.parallel_dim.rank] - return tensor_.view(self.shape) + return tensor_ if expand else tensor_.reshape(self.shape) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -220,6 +218,17 @@ def memory_usage(self) -> int: def validate(self, tensor: torch.Tensor, device: torch.device | None = None) -> torch.Tensor: return validate_tensor(tensor, self, device) + def replace_tensor_parallel_dim(self, distributed_dim: DistributedDim) -> "TensorMeta": + # Replace the tensor-parallel `DistributedDim` in `meta`. + # Note: This will turn `ParameterMeta` into `TensorMeta` + if not self.is_tensor_parallel: + return self + dims = list(self.dims) + dims[self.tensor_parallel_dim_index] = dims[self.tensor_parallel_dim_index].replace_parallel_dim( + distributed_dim + ) + return TensorMeta(self, tensor_name=self.tensor_name, dims=tuple(dims), reductions=self._reductions) + class ParameterMeta(TensorMeta): def __init__( diff --git a/fast_llm/utils.py b/fast_llm/utils.py index ce41fcac..7bbdd697 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -71,13 +71,6 @@ def rms_diff(x: "torch.Tensor", y: "torch.Tensor") -> "torch.Tensor": return torch.norm(x - y, 2, dtype=torch.float32) / x.numel() ** 0.5 # noqa -def flatten_without_copy(x: "torch.Tensor", start_dim: int = 0, end_dim: int = -1) -> "torch.Tensor": - # Similar to torch.flatten, but never returns a copy. - return x.view( - *x.shape[:start_dim], x.shape[start_dim:end_dim].numel(), *([] if end_dim == -1 else x.shape[end_dim + 1 :]) - ) - - class Tag: __slots__ = ("value",) diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index e247c33b..06416b6a 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -1,4 +1,5 @@ import gc +import logging import pathlib import typing @@ -14,18 +15,23 @@ ) from fast_llm.engine.distributed.distributed import ProcessGroupPool from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.utils import header from tests.models.test_checkpoint import do_get_convert_path from tests.utils.model_configs import ModelTestingConfig from tests.utils.run_test_script import parse_run_distributed_script +logger = logging.getLogger(__name__) + def _test_load_and_save_parallel( model_testing_config: ModelTestingConfig, pretrained_path: pathlib.Path, - pretrained_format: CheckpointFormat, + pretrained_format: type[CheckpointFormat], distributed_config: dict[str, typing.Any], save_path: pathlib.Path, ): + logger.info(header(save_path.name)) + logger.info(f"Loading {pretrained_format.name} checkpoint from {pretrained_path}") model = model_testing_config.model_class.from_pretrained( CheckpointLoadConfig(path=pretrained_path, format=pretrained_format), # The world size and rank are already set through environment variable. 
@@ -33,6 +39,7 @@ def _test_load_and_save_parallel(
mode=StageMode.inference,
)
for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat):
+ logger.info(f"Saving {save_format.name} checkpoint to {save_path / save_format.name}")
model.save_checkpoint(CheckpointSaveConfig(path=save_path / save_format.name, format=save_format))
del model
gc.collect()
@@ -89,37 +96,37 @@ def main(args: list[str] | None = None) -> None:
pretrained_path=pretrained_path,
pretrained_format=pretrained_format,
distributed_config={"pipeline_parallel": 2},
- save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_ppp2",
+ save_path=base_path / f"load_pretrained_{pretrained_format.name}_in_pp2",
)
dist = DistributedCheckpointFormat.name
_test_load_and_save_parallel(
model_testing_config=model_testing_config,
pretrained_path=base_path / f"load_pretrained_{dist}_in_dp2" / dist,
- pretrained_format=pretrained_format,
+ pretrained_format=DistributedCheckpointFormat,
distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True},
- save_path=base_path / f"load_pretrained_dp2_in_stp2",
+ save_path=base_path / "load_pretrained_dp2_in_stp2",
)
_test_load_and_save_parallel(
model_testing_config=model_testing_config,
pretrained_path=base_path / f"load_pretrained_{dist}_in_stp2" / dist,
- pretrained_format=pretrained_format,
+ pretrained_format=DistributedCheckpointFormat,
distributed_config={},
- save_path=base_path / f"load_pretrained_stp2_in_dp2",
+ save_path=base_path / "load_pretrained_stp2_in_dp2",
)
_test_load_and_save_parallel(
model_testing_config=model_testing_config,
pretrained_path=base_path / f"load_pretrained_{dist}_in_tp2" / dist,
- pretrained_format=pretrained_format,
+ pretrained_format=DistributedCheckpointFormat,
distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True},
- save_path=base_path / f"load_pretrained_tp2_in_stp2",
+ save_path=base_path / "load_pretrained_tp2_in_stp2",
)
_test_load_and_save_parallel(
model_testing_config=model_testing_config,
pretrained_path=base_path / f"load_pretrained_{dist}_in_pp2" / dist,
- pretrained_format=pretrained_format,
+ pretrained_format=DistributedCheckpointFormat,
distributed_config={"tensor_parallel": 2},
- save_path=base_path / f"load_pretrained_pp2_in_tp2",
+ save_path=base_path / "load_pretrained_pp2_in_tp2",
)
diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py
index 8d5928d7..623b0dd5 100644
--- a/tests/models/test_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -356,54 +356,80 @@ def test_save_and_load_in_parallel(run_distributed_script_for_all_models, load_a
)
-@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"])
-@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed)
-def test_parallel_checkpoint(model_testing_config, load_and_save_parallel_base_path, get_convert_path):
- # Check the consistency of the checkpoints saved in `test_save_and_load_in_parallel`
- checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format)
- # Compare Distributed checkpoints
- for rank in range(2):
- _compare_safetensor_files(
- *[
- load_and_save_parallel_base_path
- / f"load_pretrained_{format_.name}_in_dp2"
- / DistributedCheckpointFormat.name
- / f"rank_{rank}.safetensors"
- for format_ in checkpoint_formats
+@pytest.fixture(scope="module")
+def parallel_checkpoint_names(model_testing_config):
+ names = []
+ for format_ in (DistributedCheckpointFormat,
FastLLMCheckpointFormat, model_testing_config.checkpoint_format): + names.extend( + [ + f"load_pretrained_{format_.name}_in_dp2", + f"load_pretrained_{format_.name}_in_tp2", + f"load_pretrained_{format_.name}_in_stp2", + f"load_pretrained_{format_.name}_in_pp2", ] ) - # Compare Fast-LLM checkpoints - _compare_safetensor_files( - # Fast-LLM checkpoints are independent of the distributed configuration that saved it. - get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / f"model_0.safetensors", - *[ - load_and_save_parallel_base_path - / f"load_pretrained_{format_.name}_in_dp2" - / FastLLMCheckpointFormat.name - / f"model_0.safetensors" - for format_ in checkpoint_formats - ], + names.extend( + [ + "load_pretrained_dp2_in_stp2", + "load_pretrained_stp2_in_dp2", + "load_pretrained_tp2_in_stp2", + "load_pretrained_pp2_in_tp2", + "load_pretrained_dp2_in_stp2", + "load_pretrained_dp2_in_stp2", + ] ) + return names @pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_load_parallel_checkpoint( - model_testing_config, load_and_save_parallel_base_path, get_convert_path, load_and_compare_checkpoints +def test_load_parallel_checkpoint_in_single_gpu( + load_and_save_parallel_base_path, get_convert_path, load_and_compare_checkpoints, parallel_checkpoint_names ): # Test single-gpu loading of multi-gpu distributed checkpoints. - checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ _WEIGHT_SHARD_SAVE_NAME ] - for format_ in checkpoint_formats: + for name in parallel_checkpoint_names: load_and_compare_checkpoints( DistributedCheckpointFormat, - load_and_save_parallel_base_path - / f"load_pretrained_{format_.name}_in_dp2" - / DistributedCheckpointFormat.name, + load_and_save_parallel_base_path / name / DistributedCheckpointFormat.name, None, reference_shard, ) + + +@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) +def test_parallel_checkpoint_consistency(model_testing_config, load_and_save_parallel_base_path, get_convert_path): + # Check the consistency of the checkpoints saved in `test_save_and_load_in_parallel` + checkpoint_formats = (DistributedCheckpointFormat, FastLLMCheckpointFormat, model_testing_config.checkpoint_format) + # Compare Distributed checkpoints + for config in ("dp2", "tp2", "stp2", "pp2"): + for rank in range(2): + _compare_safetensor_files( + *[ + load_and_save_parallel_base_path + / f"load_pretrained_{format_.name}_in_{config}" + / DistributedCheckpointFormat.name + / f"rank_{rank}.safetensors" + for format_ in checkpoint_formats + ] + ) + + +@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) +def test_multi_gpu_fast_llm_checkpoint( + model_testing_config, load_and_save_parallel_base_path, get_convert_path, parallel_checkpoint_names +): + # Fast-LLM checkpoints are independent of the distributed configuration that saved it. 
+ _compare_safetensor_files( + get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / f"model_0.safetensors", + *[ + load_and_save_parallel_base_path / name / FastLLMCheckpointFormat.name / f"model_0.safetensors" + for name in parallel_checkpoint_names + ], + ) From a43db877939c1fa6398a84bdd4c0db9028d4eba6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Jun 2025 19:29:22 -0400 Subject: [PATCH 65/69] fixes --- tests/models/distributed_test_checkpoint.py | 2 +- tests/models/test_checkpoint.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index 06416b6a..9e706ebe 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -119,7 +119,7 @@ def main(args: list[str] | None = None) -> None: pretrained_path=base_path / f"load_pretrained_{dist}_in_tp2" / dist, pretrained_format=DistributedCheckpointFormat, distributed_config={"tensor_parallel": 2, "sequence_tensor_parallel": True}, - save_path=base_path / "load_pretrained_tp2_in_stp2", + save_path=base_path / "load_pretrained_tp2_in_pp2", ) _test_load_and_save_parallel( model_testing_config=model_testing_config, diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 623b0dd5..63a25747 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -373,10 +373,8 @@ def parallel_checkpoint_names(model_testing_config): [ "load_pretrained_dp2_in_stp2", "load_pretrained_stp2_in_dp2", - "load_pretrained_tp2_in_stp2", + "load_pretrained_tp2_in_pp2", "load_pretrained_pp2_in_tp2", - "load_pretrained_dp2_in_stp2", - "load_pretrained_dp2_in_stp2", ] ) return names @@ -426,10 +424,12 @@ def test_multi_gpu_fast_llm_checkpoint( model_testing_config, load_and_save_parallel_base_path, get_convert_path, parallel_checkpoint_names ): # Fast-LLM checkpoints are independent of the distributed configuration that saved it. + # TODO: Check pipeline-parallel checkpoints (two files). _compare_safetensor_files( get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / f"model_0.safetensors", *[ load_and_save_parallel_base_path / name / FastLLMCheckpointFormat.name / f"model_0.safetensors" for name in parallel_checkpoint_names + if "in_pp2" not in name ], ) From eaccd07bc052df894c62c35254357743200dad37 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 27 Jun 2025 19:36:36 -0400 Subject: [PATCH 66/69] fix --- fast_llm/config.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index cdc1dd5d..0004501b 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -379,6 +379,8 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: Validate a class and mark it as read-only This should not be overridden in derived classes. """ + if self._validated: + return self try: expected_class = self.get_subclass(self.type) except KeyError as e: @@ -392,15 +394,14 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: # Done during validation so we don't accidentally use default subtypes as updates. 
self.type = self.dynamic_type_name - if not self._validated: - try: - self._validate() - except (ValidationError, FieldTypeError) as e: - if _is_validating: - raise - else: - raise type(e)("\n".join(e.args)) from None - self._validated = True + try: + self._validate() + except (ValidationError, FieldTypeError) as e: + if _is_validating: + raise + else: + raise type(e)("\n".join(e.args)) from None + self._validated = True return self def _validate(self) -> None: From 8398f27801d8136bf7e74519f97530b748817448 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Jul 2025 10:45:45 -0400 Subject: [PATCH 67/69] fixes --- fast_llm/engine/distributed/distributed.py | 26 +++++++++++++--------- fast_llm/engine/training/trainer.py | 2 +- tests/utils/model_configs.py | 10 ++++----- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 9719ff2e..fbbf9b6a 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -90,13 +90,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): global _default_pool assert _default_pool is self _default_pool = None + self.shutdown() - def __del__(self): + def shutdown(self): # Shutdown the process group backend explicitly to prevent a nccl warning. # We can't call `destroy_process_group` directly because pytorch doesn't know about it. for group in self._process_groups.values(): - if group is not None and hasattr(group, "_shutdown"): - group._shutdown() # noqa + group.shutdown() + + def __del__(self): + self.shutdown() _default_pool: ProcessGroupPool | None = None @@ -114,7 +117,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): config_class: typing.ClassVar[type[DistributedConfig]] = DistributedConfig - def __init__(self, config: DistributedConfig, use_cpu: bool = False, pool: ProcessGroupPool | None = None): + def __init__(self, config: DistributedConfig, use_cpu: bool = False): super().__init__(config) assert self._config.reference_config is None self._use_cpu = use_cpu @@ -128,14 +131,13 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False, pool: Proce self.device = torch.device(self._config.local_rank) torch.cuda.set_device(self.device) - if pool is None and _default_pool is None: + self._local_pool = _default_pool is None + if self._local_pool: self._pool = ProcessGroupPool(self._config.rank, self._config.world_size, self._config.timeout) else: - if pool is None: - pool = _default_pool - Assert.eq(pool._world_size, self._config.world_size) - Assert.eq(pool._rank, self._config.rank) - self._pool = pool + self._pool = _default_pool + Assert.eq(self._pool._world_size, self._config.world_size) + Assert.eq(self._pool._rank, self._config.rank) self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world]) self.data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.data]) @@ -210,3 +212,7 @@ def set_step(self, step: int, phase: PhaseType) -> None: seed_shift = step * self._config.sample_seed_shift + self._phase_seeds_shifts[phase] self.pp_generator.manual_seed((self._pp_seed + seed_shift) % MAX_SEED) self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED) + + def __del__(self): + if self._local_pool: + self._pool.shutdown() diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a3cf078d..766398d0 100644 --- a/fast_llm/engine/training/trainer.py +++ 
b/fast_llm/engine/training/trainer.py @@ -96,7 +96,7 @@ def run( done = training_progress.done completed_steps = training_progress.completed_steps - if done or self.config.enabled(completed_steps): + if (done and self.config.enabled()) or self.config.enabled(completed_steps): return self.evaluator.run(training_progress, run_index=self._config.get_run_count(completed_steps - 1)) else: return EvaluationMetrics() diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b8dd29e8..199d5b72 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -433,12 +433,12 @@ def _update_and_add_testing_config( checkpoint_format=MixtralGPTHuggingfaceCheckpointFormat, # TODO: New base image broke mixtral groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.broken, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.broken, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.broken, - ModelTestingGroup.megatron: ModelTestingGroupAction.broken, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.megatron: ModelTestingGroupAction.normal, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, ) From 327837329f3f4203b23466cc587eb7592237edd7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Jul 2025 11:41:32 -0400 Subject: [PATCH 68/69] fix --- fast_llm/functional/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8c549259..d5ac1e3d 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,7 +145,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - loss = (per_sample_loss * loss_mask).sum() / torch.maximum(loss_mask.sum(), 1) + loss = (per_sample_loss * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) else: loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: From 1248d8b652a6325d716eaea557d473443d4c6346 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Jul 2025 12:06:41 -0400 Subject: [PATCH 69/69] fix --- fast_llm/functional/cross_entropy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index d5ac1e3d..513510ec 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,9 +145,9 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - loss = (per_sample_loss * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) - else: - loss = per_sample_loss.mean() + per_sample_loss = per_sample_loss * loss_mask + + loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.MEAN, group=group)
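For reference, the loss masking in the final hunk reduces to this standalone sketch, assuming `loss_mask` is a float tensor of zeros and ones with the same shape as `per_sample_loss`:

import torch

def masked_mean_loss(per_sample_loss: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
    # Zero out masked positions, then average over all positions, matching the final version above.
    if loss_mask is not None:
        per_sample_loss = per_sample_loss * loss_mask
    return per_sample_loss.mean()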