From cb744b2ceeba420e205e922868008dd0a2781dee Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 20 Jun 2025 11:14:40 +0000 Subject: [PATCH 01/26] copy from sandbox --- examples/qwen_evaluate.yaml | 87 ++ examples/smol_evaluate.yaml | 86 ++ fast_llm/core/distributed.py | 519 ++++++++++ fast_llm/engine/evaluation/config.py | 46 + fast_llm/engine/evaluation/evaluator.py | 102 +- .../evaluation/lm_eval/fast_llm_wrapper.py | 931 ++++++++++++++++++ fast_llm/engine/evaluation/lm_eval/utils.py | 246 +++++ 7 files changed, 2014 insertions(+), 3 deletions(-) create mode 100644 examples/qwen_evaluate.yaml create mode 100644 examples/smol_evaluate.yaml create mode 100644 fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py create mode 100644 fast_llm/engine/evaluation/lm_eval/utils.py diff --git a/examples/qwen_evaluate.yaml b/examples/qwen_evaluate.yaml new file mode 100644 index 00000000..f11c2bf3 --- /dev/null +++ b/examples/qwen_evaluate.yaml @@ -0,0 +1,87 @@ +training: + train_iters: 100_000 + logs: + interval: 10 + evaluations: + gsm8k: + run_interval: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k + - --output_path + - /mnt/checkpoints/test/denis/smol_eval_experiment/lm_eval + stack_3b: + run_interval: + interval: 10 + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + fineweb: + run_interval: + interval: 10 + evaluator: + iterations: 10 + dataset_name: fineweb + checkpoint: + interval: 1000 + keep: 5 + test_iters: 0 + export: # (1)! + format: llama + interval: 20_000 +batch: + micro_batch_size: 16 + sequence_length: 4096 + batch_size: 32 +data: + tokenizer: + path: /mnt/checkpoints/pretrained_models/Qwen2-1.5B-Instruct + bos_token: "<|endoftext|>" + datasets: + # Bad dataset they are tokenized with different tokenizer, then llama + training: + type: file + path: /mnt/datasets/test/denis/fineweb_the_stack_3b.yaml + stack_3b: + type: memmap + path: /mnt/datasets/data_collections/the_stack_3b/tokens/stack_3b/default/train/99 + fineweb: + type: memmap + path: /mnt/datasets/data_collections/standalone_datasets/tokens/HuggingFaceFW/fineweb/default/train/9_1000 +optimizer: + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 + learning_rate: + base: 1.0e-04 # (3)! + minimum: 1.0e-05 + decay_style: cosine + decay_iterations: 100_000 + warmup_iterations: 2000 +pretrained: # (4)! + format: qwen2 + path: /mnt/checkpoints/pretrained_models/Qwen2-1.5B-Instruct + model_weights: yes # (5)! +model: + base_model: + transformer: + use_flash_attention: yes + cross_entropy_impl: fused + multi_stage: + zero_stage: 2 + distributed: + training_dtype: bf16 + +run: + experiment_dir: "/mnt/checkpoints/test/denis/qwen_eval_experiment" + +# training: +# logs: +# interval: 10 +# wandb: +# project_name: ${job.project_name} +# group_name: ${job.project_version} diff --git a/examples/smol_evaluate.yaml b/examples/smol_evaluate.yaml new file mode 100644 index 00000000..1d8822c0 --- /dev/null +++ b/examples/smol_evaluate.yaml @@ -0,0 +1,86 @@ +training: + train_iters: 100_000 + logs: + interval: 10 + evaluations: + gsm8k: + run_interval: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k + - --output_path + - /mnt/checkpoints/test/denis/smol_eval_experiment/lm_eval + stack_3b: + run_interval: + interval: 10 + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + fineweb: + run_interval: + interval: 10 + evaluator: + iterations: 10 + dataset_name: fineweb + checkpoint: + interval: 1000 + keep: 5 + test_iters: 0 + export: # (1)! 
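+    # Export checkpoints in `llama` (Hugging Face) format every 20,000 steps, in addition to the regular Fast-LLM checkpoints above.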
+ format: llama + interval: 20_000 +batch: + micro_batch_size: 16 + sequence_length: 4096 + batch_size: 32 +data: + tokenizer: + path: /mnt/checkpoints/pretrained_models/SmolLM2-135M-Instruct + datasets: + # Bad dataset they are tokenized with different tokenizer, then llama + training: + type: file + path: /mnt/datasets/test/denis/fineweb_the_stack_3b.yaml + stack_3b: + type: memmap + path: /mnt/datasets/data_collections/the_stack_3b/tokens/stack_3b/default/train/99 + fineweb: + type: memmap + path: /mnt/datasets/data_collections/standalone_datasets/tokens/HuggingFaceFW/fineweb/default/train/9_1000 +optimizer: + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 + learning_rate: + base: 1.0e-04 # (3)! + minimum: 1.0e-05 + decay_style: cosine + decay_iterations: 100_000 + warmup_iterations: 2000 +pretrained: # (4)! + format: llama + path: /mnt/checkpoints/pretrained_models/SmolLM2-135M-Instruct/ + model_weights: yes # (5)! +model: + base_model: + transformer: + use_flash_attention: yes + cross_entropy_impl: fused + multi_stage: + zero_stage: 2 + distributed: + training_dtype: bf16 + +run: + experiment_dir: "/mnt/checkpoints/test/denis/smol_eval_experiment" + +# training: +# logs: +# interval: 10 +# wandb: +# project_name: ${job.project_name} +# group_name: ${job.project_version} diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index e82e0801..ffbeab39 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -6,12 +6,17 @@ Todo: Move all core methods elsewhere (functional?). """ +import collections import contextlib import datetime +import io +import itertools import logging +import pickle import typing import torch +import torch.monitor from torch._C._distributed_c10d import Work from torch.distributed import ( # noqa ProcessGroup, @@ -26,6 +31,117 @@ logger = logging.getLogger(__name__) +def _as_iterable(obj) -> collections.abc.Iterable: + return obj if isinstance(obj, list) else (obj,) + + +def _check_single_tensor(param, param_name) -> None: + """Check that the parameter ``param_name`` is a single tensor.""" + if not isinstance(param, torch.Tensor): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor + but got {type(param)} instead.""" + ) + + +def _check_tensor_list(param, param_name) -> None: + """Check that the parameter ``param_name`` is a list of tensors.""" + if not isinstance(param, list): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} instead.""" + ) + elif not all(isinstance(p, torch.Tensor) for p in param): + raise TypeError( + f"""Invalid function argument. 
Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} with elements of type {[type(p) for p in param]}.""" + ) + + +def _ensure_all_tensors_same_dtype(*tensors) -> None: + last_dtype = None + for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): + tensor_dtype = tensor.dtype + # Mixing complex and its element type is allowed + if tensor_dtype.is_complex: + tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + + if last_dtype is None: + last_dtype = tensor_dtype + else: + if last_dtype != tensor_dtype: + raise ValueError( + "Invalid usage of tensors with different dtypes" f"Found {last_dtype} and {tensor.dtype}" + ) + + +def _rank_not_in_group(group: typing.Optional[ProcessGroup]) -> bool: + """Check if the current process's rank is not in a given group.""" + if group is None: + return False + return group == torch.distributed.GroupMember.NON_GROUP_MEMBER + + +def _warn_not_in_group(op_name) -> None: + # TODO: get global rank + global_rank = -1 + logger.warning(f"Running {op_name} on global rank {global_rank} which does not " "belong to the given group.") + + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +def _object_to_tensor(obj, device, group): + with torch.monitor._WaitCounter("pytorch.wait_counter.c10d._object_to_tensor").guard(): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. + # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage).to(device) + + # TODO: do we need to log this level of details? + # if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + # backend = get_backend(group) + # if backend == Backend.NCCL: + # hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) + # logger.warning( + # "_object_to_tensor size: %s hash value: %s", + # byte_tensor.numel(), + # hash, + # ) + + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size, group): + with torch.monitor._WaitCounter("pytorch.wait_counter.c10d._tensor_to_object").guard(): + + # TODO: do we need to log this level of details? + # if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + # backend = get_backend(group) + # if backend == Backend.NCCL: + # hash = torch._C._distributed_c10d._hash_tensors([tensor]) + # logger.warning( + # "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash + # ) + + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +def _validate_output_list_for_rank(my_rank, dst, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError("Argument ``gather_list`` must be specified on destination rank.") + elif gather_list: + raise ValueError("Argument ``gather_list`` must NOT be specified on non-destination ranks.") + + def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None: if group is not None and timeout is not None: # TODO: Only works for nccl? 
@@ -133,3 +249,406 @@ def set_generator(generator: torch.Generator) -> typing.Generator[None, None, No finally: generator.set_state(default_generator.get_state()) default_generator.set_state(old_state) + + +def gather( + tensor: torch.Tensor, + gather_list: typing.Optional[list[torch.Tensor]] = None, + group: typing.Optional[ProcessGroup] = None, + async_op: bool = False, + group_dst: typing.Optional[int] = None, +): + """ + Gathers a list of tensors in a single process. + + This function requires all tensors to be the same size on each process. + + Args: + tensor (Tensor): Input tensor. + gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) + group (ProcessGroup, optional): The process group to work on. + async_op (bool, optional): Whether this op should be an async op + group_dst (int, optional): Destination rank on ``group``. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in gather_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> # We have 2 process groups, 2 ranks. + >>> tensor_size = 2 + >>> device = torch.device(f'cuda:{rank}') + >>> tensor = torch.ones(tensor_size, device=device) + rank + >>> if dist.get_rank() == 0: + >>> gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)] + >>> else: + >>> gather_list = None + >>> dist.gather(tensor, gather_list, dst=0) + >>> # Rank 0 gets gathered data. + >>> gather_list + [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0 + None # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + + # Parameter ``gather_list`` may be left unspecified on non-dst ranks. + if gather_list: + _check_tensor_list(gather_list, "gather_list") + else: + gather_list = [] + _ensure_all_tensors_same_dtype(tensor, gather_list) + assert group is not None + if _rank_not_in_group(group): + _warn_not_in_group("gather") + return + if group_dst is None: + group_dst = 0 + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, gather_list) + output_tensors = [gather_list] if group_dst == my_group_rank else [] + input_tensors = [tensor] + + opts = torch.distributed.GatherOptions() + opts.rootRank = group_dst + # Absent in ver 2.6 + # opts.asyncOp = async_op + work = group.gather(output_tensors, input_tensors, opts) + + if async_op: + return work + elif work is not None: # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def scatter( + tensor: torch.Tensor, + scatter_list: typing.Optional[list[torch.Tensor]] = None, + group: typing.Optional[ProcessGroup] = None, + async_op: bool = False, + group_src: typing.Optional[int] = None, +): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Complex tensors are supported. + + Args: + tensor (Tensor): Output tensor. + scatter_list (list[Tensor]): List of tensors to scatter (default is + None, must be specified on the source rank) + group (ProcessGroup, optional): The process group to work on. + async_op (bool, optional): Whether this op should be an async op + group_src (int, optional): Source rank on ``group``. + + Returns: + Async work handle, if async_op is set to True. 
+ None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in scatter_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> tensor_size = 2 + >>> device = torch.device(f'cuda:{rank}') + >>> output_tensor = torch.zeros(tensor_size, device=device) + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> # Only tensors, all of which must be the same size. + >>> t_ones = torch.ones(tensor_size, device=device) + >>> t_fives = torch.ones(tensor_size, device=device) * 5 + >>> scatter_list = [t_ones, t_fives] + >>> else: + >>> scatter_list = None + >>> dist.scatter(output_tensor, scatter_list, src=0) + >>> # Rank i gets scatter_list[i]. + >>> output_tensor + tensor([1., 1.], device='cuda:0') # Rank 0 + tensor([5., 5.], device='cuda:1') # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + # Parameter ``scatter_list`` may be left unspecified on non-src ranks. + if scatter_list: + _check_tensor_list(scatter_list, "scatter_list") + else: + scatter_list = [] + _ensure_all_tensors_same_dtype(tensor, scatter_list) + assert group is not None + if group_src is None: + group_src = 0 + if _rank_not_in_group(group): + _warn_not_in_group("scatter") + return + scatter_list = [t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + my_group_rank = group.rank() + if group_src == my_group_rank: + if not scatter_list: + raise ValueError("Argument ``scatter_list`` must be specified on source rank.") + input_tensors = [scatter_list] + output_tensors = [tensor] + else: + if scatter_list: + raise ValueError("Argument ``scatter_list`` must NOT be specified on non-source ranks.") + input_tensors = [] + output_tensors = [tensor] + + opts = torch.distributed.ScatterOptions() + opts.rootRank = group_src + opts.asyncOp = async_op + work = group.scatter(output_tensors, input_tensors, opts) + + if async_op: + return work + elif work is not None: # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def gather_object( + current_device: torch.device | str, + obj: typing.Any, + object_gather_list: typing.Optional[list[typing.Any]] = None, + group: typing.Optional[ProcessGroup] = None, + group_dst: typing.Optional[int] = None, +): + """ + Gathers picklable objects from the whole group in a single process. + + Similar to :func:`gather`, but Python objects can be passed in. Note that the + object must be picklable in order to be gathered. + + Args: + current_device: (torch.device | str): device to use for object serialization to + tensor, must be this process assigned gpu for nccl backend. + obj (Any): Input object. Must be picklable. + object_gather_list (list[Any]): Output list. On the ``dst`` rank, it + should be correctly sized as the size of the group for this + collective and will contain the output. Must be ``None`` on non-dst + ranks. (default is ``None``) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). + (If both ``dst`` and ``group_dst`` are None, default is global rank 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + group_dst (int, optional): Destination rank on ``group``. 
Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + None. On the ``dst`` rank, ``object_gather_list`` will contain the + output of the collective. + + .. note:: Note that this API differs slightly from the gather collective + since it does not provide an async_op handle and thus will be a blocking + call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + ... gather_objects[dist.get_rank()], + ... output if dist.get_rank() == 0 else None, + ... dst=0 + ... ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] + """ + assert group is not None + if group_dst is None: + group_dst = 0 + if _rank_not_in_group(group): + _warn_not_in_group("gather_object") + return + + # Ensure object_gather_list is specified appropriately. + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, object_gather_list) + input_tensor, local_size = _object_to_tensor(obj, current_device, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = group.size() + object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device) + object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_group_rank == group_dst: + coalesced_output_tensor = torch.empty(max_object_size * group_size, dtype=torch.uint8, device=current_device) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. 
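+    # Non-destination ranks pass gather_list=None; only group_dst receives the coalesced output views.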
+ gather( + input_tensor, + gather_list=output_tensors if my_group_rank == group_dst else None, # type: ignore[possibly-undefined] + group_dst=group_dst, + group=group, + ) + if my_group_rank != group_dst: + return + + assert object_gather_list is not None, "Must provide object_gather_list on dst rank" + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +def scatter_object_list( + pg_device: torch.device | str, + scatter_object_output_list: list[typing.Any], + scatter_object_input_list: typing.Optional[list[typing.Any]] = None, + group: typing.Optional[ProcessGroup] = None, + group_src: typing.Optional[int] = None, +): + """ + Scatters picklable objects in ``scatter_object_input_list`` to the whole group. + + Similar to :func:`scatter`, but Python objects can be passed in. On + each rank, the scattered object will be stored as the first element of + ``scatter_object_output_list``. Note that all objects in + ``scatter_object_input_list`` must be picklable in order to be scattered. + + Args: + pg_device: (torch.device | str): device to use for object serialization to + tensor, must be this process assigned gpu for nccl backend. + scatter_object_output_list (List[Any]): Non-empty list whose first + element will store the object scattered to this rank. + scatter_object_input_list (List[Any], optional): List of input objects to scatter. + Each object must be picklable. Only objects on the ``src`` rank will + be scattered, and the argument can be ``None`` for non-src ranks. + group: (ProcessGroup, optional): The process group to work on. + group_src (int, optional): Source rank on ``group``. + + Returns: + ``None``. If rank is part of the group, ``scatter_object_output_list`` + will have its first element set to the scattered object for this rank. + + .. note:: Note that this API differs slightly from the scatter collective + since it does not provide an ``async_op`` handle and thus will be a + blocking call. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`scatter_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`scatter_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`scatter` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> # Can be any list on non-src ranks, elements are not used. + >>> objects = [None, None, None] + >>> output_list = [None] + >>> dist.scatter_object_list(output_list, objects, src=0) + >>> # Rank i gets objects[i]. 
For example, on rank 2: + >>> output_list + [{1: 2}] + """ + assert group is not None + if group_src is None: + group_src = 0 + if _rank_not_in_group(group): + _warn_not_in_group("scatter_object_list") + return + + if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1: + raise ValueError("Expected argument scatter_object_output_list to be a list of size at least 1.") + + my_group_rank = group.rank() + if my_group_rank == group_src: + if scatter_object_input_list is None: + raise ValueError("source rank must provide non-None scatter_object_input_list") + tensor_list, tensor_sizes = zip( + *[_object_to_tensor(obj, pg_device, group) for obj in scatter_object_input_list] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined] + for tensor in tensor_list: # type: ignore[possibly-undefined] + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + broadcast(max_tensor_size, src=group_src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device) + scatter( + output_tensor, + scatter_list=None if my_group_rank != group_src else tensor_list, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + scatter( + obj_tensor_size, + scatter_list=None if my_group_rank != group_src else tensor_sizes, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size, group) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 7223631f..d0adf7b2 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -62,3 +62,49 @@ def get_evaluator( from fast_llm.engine.evaluation.evaluator import EvaluatorLoss return EvaluatorLoss(name, self, batch_config, data_load_num_proc, train_iters) + + +@config_class() +class EvaluatorLmEvalConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "lm_eval" + + cli_args: list[str] = Field( + default_factory=lambda: [], + desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.", + ) + + truncation: bool = Field( + default=False, + desc="Whether to use truncation during tokenization (useful when inputs exceed model's max length);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + logits_cache: bool = Field( + default=True, + desc="Whether to enable logits caching for speedup and avoiding recomputation during repeated evaluations;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + add_bos_token: bool = Field( + default=False, + desc="Whether to prepend a beginning-of-sequence (BOS) token, required for some models like LLaMA;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + prefix_token_id: int | None = Field( + default=None, + desc="Token ID to use as a prefix to the input (e.g., for control codes or prompts);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + def 
get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "EvaluatorLmEval": + from fast_llm.engine.evaluation.evaluator import EvaluatorLmEval + + return EvaluatorLmEval(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index f07a8c48..eaad6299 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -4,13 +4,22 @@ import time import typing +from lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate + from fast_llm.config import Configurable from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.engine.config_utils.run import Run, log_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, EvaluatorLossConfig +from fast_llm.engine.evaluation.config import ( + EvaluatorConfig, + EvaluatorConfigBase, + EvaluatorLmEvalConfig, + EvaluatorLossConfig, +) +from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper +from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner @@ -19,8 +28,6 @@ from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, get_memory_usage_mib -# from fast_llm.engine.training.lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate - logger = logging.getLogger(__name__) @@ -243,6 +250,95 @@ def _get_data_iterator( ) +class EvaluatorLmEval[ConfigType: EvaluatorLmEvalConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[EvaluatorLmEvalConfig]] = EvaluatorLmEvalConfig + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + super().setup(distributed, run, multi_stage, runner, data, phase) + + # TODO: pass mini and batch size of the same length for lm_eval not to crash during training + # or implement min batch sequential awareness in fas_llm_wrapper for lm_eval + self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class().from_model( + self._multi_stage, self._batch_config.micro_batch_size, self._runner + ) + + # For reporting purposes, just to indicate it is from Fast-LLM + # as lm_eval.simple_evaluate will take it for results['config']['model'] + self._hf_model.config.name_or_path = type(self._hf_model).__name__ + + self._flm_wrapper = FastLLMLmEvalWrapper( + model=self._hf_model, + tokenizer=self._data.tokenizer.tokenizer, + truncation=self._config.truncation, + logits_cache=self._config.logits_cache, + add_bos_token=self._config.add_bos_token, + prefix_token_id=self._config.prefix_token_id, + ) + self._is_setup = True + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + + # TODO: use run_index instead? 
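+        # lm_eval itself runs only on the main rank; all other ranks fall through to
+        # self._flm_wrapper.worker_model_invoke() below and serve model calls until stop_workers() releases them.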
+ # completed_steps is added to output_path like output_path/runs/run_index/completed_steps/ + completed_steps = 0 if training_progress is None else training_progress.completed_steps + + if self._run.is_main_rank: + args, simple_eval_kwargs = prepare_lm_eval_simple_eval_params( + self._config.cli_args, completed_steps, self._run.index + ) + simple_eval_kwargs["model"] = self._flm_wrapper + + # Needed for reporting as batch_size is set from args not lm for reporting in evaluate + simple_eval_kwargs["batch_size"] = self._flm_wrapper.batch_size + simple_eval_kwargs["max_batch_size"] = self._flm_wrapper.max_batch_size + + # As of lm_eval commit 758c5ed891b1ca48acd8d3a0d309a827215796b7 + # Expected to be a string even if empty and not None in simple_evaluate + simple_eval_kwargs["model_args"] = "" + + results = lm_eval_simple_evaluate(**simple_eval_kwargs) + self._flm_wrapper.stop_workers() + + # Evaluation_tracker save expects model to be either string, but if model is passed + # LM wrapper needs to be deep copyable and json serializable + simple_eval_kwargs["evaluation_tracker"].general_config_tracker.model_source = ( + self._hf_model.config.name_or_path + ) + + if results is not None: + process_lm_eval_results( + args, + results, + simple_eval_kwargs["evaluation_tracker"], + completed_steps, + ) + else: + self._flm_wrapper.worker_model_invoke() + + # TODO: do we need it here as self._flm_wrapper.stop_workers() and self._flm_wrapper.worker_model_invoke() + # already have barrier + safe_barrier(self._distributed.world_group, f"Evaluation Harness Run end") + + # lm_eval logs to disc, wandb and prints to screen itself + return EvaluationMetrics() + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return None + + # NOTE: This is not a standalone runnable; it's a submodule of Trainer used for code encapsulation. 
class EvaluatorRunner: _is_setup: bool = False diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py new file mode 100644 index 00000000..6080f38b --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -0,0 +1,931 @@ +import copy +import logging + +import jinja2 +import lm_eval.api.instance +import lm_eval.api.model +import lm_eval.models.utils +import lm_eval.utils +import torch +import torch.nn.functional as F +import transformers +from tqdm.auto import tqdm + +from fast_llm.core.distributed import gather_object, safe_barrier, scatter_object_list +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM + +eval_logger = logging.getLogger(__name__) + + +class FastLLMLmEvalWrapper(lm_eval.api.model.TemplateLM): + _DEFAULT_MAX_LENGTH = 2048 + + def __init__( + self, + model: HuggingfaceBaseModelForCausalLM, + tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, + truncation: bool | None = False, + logits_cache: bool = True, + add_bos_token: bool | None = False, + prefix_token_id: int | None = None, + ): + super().__init__() + # This is for lm_eval sake, we always run lm_eval on one main rank + self._rank = 0 + self._world_size = 1 + + self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed + dist_config: DistributedConfig = self._distributed.config + # get batch_data_parallel group leaders + if dist_config.sequence_data_rank == 0 and dist_config.pipeline_rank == 0 and dist_config.tensor_rank == 0: + self.group = self._distributed.batch_data_group + else: + self.group = torch.distributed.GroupMember.NON_GROUP_MEMBER + + # TODO: clean code which does not used parts from HFLM + backend = "causal" + revision = "main" + delta = None + peft = None + + # set some inputs which are expected in HFLM but are set by our model config + self.backend = backend + + # set tokenizer object + assert isinstance(tokenizer, transformers.PreTrainedTokenizer) or isinstance( + tokenizer, transformers.PreTrainedTokenizerFast + ) + self.tokenizer = tokenizer + + # initialize model fields + self._model = model + self._device = self._model.device + self._config = self._model.config + + # access self._model through self.model property outside this method + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + # select (or create) a pad token to use + self.tokenizer = lm_eval.models.utils.configure_pad_token(self.tokenizer, model_config=self.config) + + self.add_bos_token = add_bos_token + # TODO: do we support gemma models? + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS" + " token will be used as Gemma underperforms without it." 
+ ) + + self._max_length = model._inference_runner._batch_config.sequence_length + self.pretrained = model + self.delta = delta + self.peft = peft + self.revision = revision + + self.batch_schedule = 1 + self.batch_sizes = {} + self.batch_size_per_gpu = 16 # model._inference_runner._batch_config.micro_batch_size + self.batch_size = self.batch_size_per_gpu * dist_config.batch_data_parallel + self.max_batch_size = self.batch_size + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}") + + def _model_invoke( + self, + input_ids, + attention_mask, + labels, + max_length, + stop, + generate: bool, + continue_generate: bool, + **generation_kwargs, + ): + if self.group is None or (world_size := self.group.size()) == 1: + # Must not be called with continue_generate false on one process + assert continue_generate + return self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + rank = self.group.rank() + assert rank == 0 + + if continue_generate: + assert input_ids is not None + if generate: + assert max_length is not None and stop is not None + + # always divide by batch_size, if not full batch, some ranks will get less work or not at all + step = self.batch_size // world_size + + input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)] + attention_mask = [ + attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None + for i in range(world_size) + ] + labels = [labels[i * step : (i + 1) * step] if labels is not None else None for i in range(world_size)] + + scatter_list = [ + [ + input_ids[i], + attention_mask[i], + labels[i], + max_length, + stop, + generate, + continue_generate, + generation_kwargs, + ] + for i in range(world_size) + ] + else: + scatter_list = [[None, None, None, None, None, None, False, None] for _ in range(world_size)] + + obj_list = [None] + scatter_object_list( + self._distributed.device, + obj_list, + scatter_list, + group_src=0, + group=self.group, + ) + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = tuple( + obj_list[0] + ) + + if continue_generate == False: + return + + assert len(input_ids) > 0 + + res = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + gather_list = [None] * world_size + gather_object( + self._distributed.device, + res, + gather_list, + group_dst=0, + group=self.group, + ) + + # If it was model generate tensors could be of different length + # so we aggregate results to list instead of a tensor + if generate: + res = sum((el.tolist() for el in gather_list), []) + else: + res = torch.cat(gather_list, dim=0) + + return res + + def worker_model_invoke(self): + assert self.group is not None + # if isinstance(self.group, dist.ProcessGroup): + if not isinstance(self.group, int): + assert self.group.size() > 1 and self.group.rank() != 0 + # on worker ranks the function need to wait to be called multiple times + while True: + scatter_list = None + obj_list = [None] + scatter_object_list( + self._distributed.device, + obj_list, + scatter_list, + group_src=0, + group=self.group, + ) + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( + tuple(obj_list[0]) + ) + + if continue_generate == False: + break + + # if some data was received, work, otherwise return empty 
tensor + if len(input_ids) > 0: + res = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + else: + res = input_ids + + gather_list = None + gather_object( + self._distributed.device, + res, + gather_list, + group_dst=0, + group=self.group, + ) + else: + # TODO: implement distributed model support + assert self.group == torch.distributed.GroupMember.NON_GROUP_MEMBER + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def stop_workers(self): + if self.group is None or (world_size := self.group.size()) == 1: + return + self._model_invoke(None, None, None, None, None, None, continue_generate=False) + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def _model_invoke_inner( + self, input_ids, attention_mask, labels, max_length, stop, generate: bool, **generation_kwargs + ): + if generate: + return self._model_generate_inner(input_ids, attention_mask, max_length, stop, **generation_kwargs) + else: + return self._model_call_inner(input_ids, attention_mask, labels) + + def _model_call(self, input_ids, attention_mask=None, labels=None): + return self._model_invoke( + input_ids, attention_mask, labels, None, None, generate=False, continue_generate=True + ) + + def _model_generate(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + return self._model_invoke( + input_ids, + attention_mask, + None, + max_length, + stop, + generate=True, + continue_generate=True, + **generation_kwargs, + ) + + def _model_call_inner(self, input_ids, attention_mask=None, labels=None): + """ + :param input_ids: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attention_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + # TODO: do we need no_grad for our model? 
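+        # Inference-only path: run the forward pass without building an autograd graph.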
+ with torch.no_grad(): + if attention_mask is not None or labels is not None: + assert attention_mask is not None and labels is not None + return self.model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ).logits + else: + return self.model( + input_ids=input_ids, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ).logits + + def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0] + ) + if attention_mask is None: + return self.model.generate( + input_ids=input_ids, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=False, + **generation_kwargs, + ) + else: + return self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=False, + **generation_kwargs, + ) + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. 
+ return self._config + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: # if max length manually set, return it + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + # TODO: check removing this does not affect lm_eval + # @property + # def batch_size(self): + # return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> list[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs = {"add_special_tokens": False or self.add_bos_token} + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: list[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self.backend == "causal": + add_special_tokens = {"add_special_tokens": False or self.add_bos_token} + + encoding = self.tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + eval_logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. 
Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] + self.tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _select_cont_toks(self, logits: torch.Tensor, contlen: int = None, inplen: int = None) -> torch.Tensor: + if self.backend == "causal": + assert contlen and inplen, "Must pass input len and cont. len to select scored logits for causal LM" + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self.backend == "seq2seq": + assert contlen and not inplen, "Selecting scored logits for Seq2SeqLM requires only cont. len" + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False + ) -> list[float]: + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + rolling_token_windows: list[tuple[list[int], list[int]]] = list( + map( + lm_eval.utils.make_disjoint_window, + lm_eval.utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self.device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self.batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in 
request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial("loglikelihood_rolling", (string,), request_total) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self.batch_schedule) + if sched in self.batch_sizes: + return self.batch_sizes[sched] + if (len(self.batch_sizes) > 1) and (self.batch_sizes[sched - 1] == self.max_batch_size): + # if previous batch size is already maximal, skip recomputation + self.batch_sizes[sched] = self.max_batch_size + return self.batch_sizes[sched] + print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size") + self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self.batch_sizes[sched]}") + return self.batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: list[tuple[tuple[str, str], list[int], list[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> list[tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + def _collate(req: tuple[tuple[str, str], list[int], list[int]]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + + toks = req[1] + req[2] + return -len(toks), tuple(toks) + + def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]): + """Defines the key to group and lookup one-token continuations""" + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. 
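+            # req is ((context_str, continuation_str), context_tokens, continuation_tokens),
+            # so the lookup key is context_tokens plus all but the last continuation token.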
+ return req[-2] + req[-1][:-1] + + re_ord = lm_eval.models.utils.Collator( + requests, + sort_fn=_collate, + group_by="contexts" if self.backend == "causal" and self.logits_cache else None, + group_fn=_lookup_one_token_cont, + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0 + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self.backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if total_length > self.max_length + 1: + eval_logger.warn( + f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + elif self.backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? + # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self.device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen + + padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self.backend == "causal": + batched_inps = lm_eval.models.utils.pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self.backend == "seq2seq": + # TODO: left-pad encoder inps and mask? 
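+                # seq2seq: encoder inputs, continuation labels, and the encoder attention mask are padded
+                # separately; the conts and mask are passed to the model via `labels` / `attention_mask` below.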
+ batched_inps = lm_eval.models.utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp] + batched_conts = lm_eval.models.utils.pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = lm_eval.models.utils.pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attention_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), dim=-1 + ) # [batch, padding_length (inp or cont), vocab] + + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = inplen + (logits.shape[0] - padding_len_inp) if self.backend == "causal" else None + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. + # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + cont_toks = torch.tensor(cont_toks, dtype=torch.long, device=self.device).unsqueeze(0) # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial("loglikelihood", request_str, answer) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False) -> list[str]: + res = [] + + def _collate(req: tuple[str, dict]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. 
this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(req[0]) + return -len(toks), req[0] + + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else adaptive_batch_size if adaptive_batch_size is not None else 0 + ) + batch_fn = self._batch_scheduler if self.batch_size == "auto" and not adaptive_batch_size else None + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + re_ords = lm_eval.models.utils.Collator( + [reg.args for reg in requests], + sort_fn=_collate, + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. + if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = lm_eval.models.utils.handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError(f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}") + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self.max_gen_toks + + # set the max length in tokens of inputs ("context_enc") + if self.backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert ( + max_ctx_len > 0 + ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." 
+ elif self.backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + input_ids, attention_mask = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self.truncation, + ) + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + if "max_length" not in kwargs: + kwargs["max_length"] = input_ids.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + input_ids=input_ids, + attention_mask=attention_mask, + stop=until, + **kwargs, + ) + + # cont_toks_list = cont.tolist() + cont_toks_list = cont + + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self.backend == "causal": + cont_toks = cont_toks[input_ids.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template(self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + eval_logger.warning("Failed to apply chat template. removing the system role in chat history.") + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated diff --git a/fast_llm/engine/evaluation/lm_eval/utils.py b/fast_llm/engine/evaluation/lm_eval/utils.py new file mode 100644 index 00000000..f7b0cd50 --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/utils.py @@ -0,0 +1,246 @@ +import argparse +import json +import logging +import os +import pathlib +import sys +from pathlib import Path + +import lm_eval.__main__ +import lm_eval.evaluator +import lm_eval.loggers +import lm_eval.tasks +import lm_eval.utils + +eval_logger = logging.getLogger(__name__) + + +def parse_eval_args(parser: argparse.ArgumentParser, args: list[str]) -> argparse.Namespace: + lm_eval.__main__.check_argument_types(parser) + return parser.parse_args(args) + + +def prepare_lm_eval_simple_eval_params( + cli_args: list[str], + completed_steps: int, + run_index: int, +) -> tuple[argparse.Namespace, dict[str, any]]: + """ + Parses CLI arguments for an LM evaluation run and prepares keyword arguments + for the `evaluate` function. + + This function wraps argument parsing, environment configuration, task resolution, + and metadata setup needed for evaluation with Fast-LLM's `lm_eval` wrapper. It also + handles special cases like hub token injection, dynamic sample loading, and task + listing commands. 
+ + Args: + cli_args (list[str]): Command-line arguments, excluding the program name. + completed_steps (int): Current number of completed training steps, used to + uniquely tag evaluation output paths. + run_index (int): index of the current run of Fast-LLM experiment + + Returns: + tuple: + - argparse.Namespace: Parsed CLI arguments. + - dict: Keyword arguments to pass into `simple_evaluate`, including task list, + tracker, cache settings, random seeds, and generation parameters. + + Raises: + ValueError: If required fields like `--tasks` or `--output_path` are missing + when needed, or if misconfigured combinations are detected. + SystemExit: If special task listing flags are used. + """ + parser = lm_eval.__main__.setup_parser() + args = parse_eval_args(parser, cli_args) + + # NOTE: all this args are set by fast_llm on the model directly or not used here + assert not args.wandb_args # default empty string + assert not args.wandb_config_args # default empty string + assert args.model == "hf" # default value of 'hf' + assert not args.model_args # default empty string + assert args.batch_size == 1 # default value of 1 + assert args.max_batch_size is None + assert args.device is None + # if args.wandb_args: + # wandb_args_dict = simple_parse_args_string(args.wandb_args) + # wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args) + # wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict) + + # TODO: change logging levels from fast_llm to lm_eval and then back? + # utils.setup_logging(args.verbosity) + # eval_logger = logging.getLogger(__name__) + + # update the evaluation tracker args with the output path and the HF token + evaluation_tracker_args = "" + if args.output_path: + args.output_path = str(pathlib.Path(args.output_path) / f"runs/{run_index}/{completed_steps}") + evaluation_tracker_args += f",output_path={args.output_path}" + + evaluation_tracker_args = lm_eval.utils.simple_parse_args_string(evaluation_tracker_args) + evaluation_tracker = lm_eval.loggers.EvaluationTracker(**evaluation_tracker_args) + + if args.predict_only: + args.log_samples = True + if (args.log_samples or args.predict_only) and not args.output_path: + raise ValueError("Specify --output_path if providing --log_samples or --predict_only") + + if args.fewshot_as_multiturn and args.apply_chat_template is False: + raise ValueError( + "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)." + ) + + if args.include_path is not None: + eval_logger.info(f"Including path: {args.include_path}") + metadata = ( + lm_eval.utils.simple_parse_args_string(args.model_args) + if isinstance(args.model_args, str) + else args.model_args if isinstance(args.model_args, dict) else {} + ) | (args.metadata if isinstance(args.metadata, dict) else lm_eval.utils.simple_parse_args_string(args.metadata)) + + task_manager = lm_eval.tasks.TaskManager(include_path=args.include_path, metadata=metadata) + + if args.limit: + eval_logger.warning( + " --limit SHOULD ONLY BE USED FOR TESTING." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." + ) + if args.samples: + assert args.limit is None, "If --samples is not None, then --limit must be None." 
+ if (samples := Path(args.samples)).is_file(): + args.samples = json.loads(samples.read_text()) + else: + args.samples = json.loads(args.samples) + + if args.tasks is None: + eval_logger.error("Need to specify task to evaluate.") + sys.exit() + elif args.tasks == "list": + print(task_manager.list_all_tasks()) + sys.exit() + elif args.tasks == "list_groups": + print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False)) + sys.exit() + elif args.tasks == "list_tags": + print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False)) + sys.exit() + elif args.tasks == "list_subtasks": + print(task_manager.list_all_tasks(list_groups=False, list_tags=False)) + sys.exit() + else: + if os.path.isdir(args.tasks): + import glob + + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = lm_eval.utils.load_yaml_config(yaml_file) + task_names.append(config) + else: + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = lm_eval.utils.load_yaml_config(task) + task_names.append(config) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + eval_logger.error( + f"Tasks were not found: {missing}\n" + f"{lm_eval.utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all" + " available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG'" + " to troubleshoot task registration issues." 
+ ) + + ( + eval_logger.info(f"Selected Tasks: {task_names}") + if eval_logger.getEffectiveLevel() >= logging.INFO + else print(f"Selected Tasks: {task_names}") + ) + + request_caching_args = lm_eval.evaluator.request_caching_arg_to_dict(cache_requests=args.cache_requests) + + eval_kwargs = dict( + tasks=task_names, + num_fewshot=args.num_fewshot, + # batch_size=args.batch_size, + # max_batch_size=args.max_batch_size, + # device=args.device, + use_cache=args.use_cache, + limit=args.limit, + samples=args.samples, + check_integrity=args.check_integrity, + write_out=args.write_out, + log_samples=args.log_samples, + evaluation_tracker=evaluation_tracker, + system_instruction=args.system_instruction, + apply_chat_template=args.apply_chat_template, + fewshot_as_multiturn=args.fewshot_as_multiturn, + gen_kwargs=args.gen_kwargs, + task_manager=task_manager, + predict_only=args.predict_only, + random_seed=args.seed[0], + numpy_random_seed=args.seed[1], + torch_random_seed=args.seed[2], + fewshot_random_seed=args.seed[3], + confirm_run_unsafe_code=args.confirm_run_unsafe_code, + metadata=metadata, + **request_caching_args, + ) + + return args, eval_kwargs + + +def process_lm_eval_results( + args: argparse.Namespace, + results: dict[str, any], + evaluation_tracker: lm_eval.loggers.EvaluationTracker, + completed_steps: int | None, +) -> None: + if results is not None: + completed_steps = 0 if completed_steps is None else completed_steps + import wandb + + if args.log_samples: + samples = results.pop("samples") + dumped = json.dumps(results, indent=2, default=lm_eval.utils.handle_non_serializable, ensure_ascii=False) + if args.show_config: + print(dumped) + + batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + + # Add W&B logging if we have the run to log to + # we expect the rest of the fast_llm code will finish the run. + if wandb.run is not None: + try: + wandb_logger = lm_eval.loggers.WandbLogger(init_args={"step": completed_steps}) + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + eval_logger.info(f"Logging to Weights and Biases failed due to {e}") + + evaluation_tracker.save_results_aggregated(results=results, samples=samples if args.log_samples else None) + + if args.log_samples: + for task_name, config in results["configs"].items(): + evaluation_tracker.save_results_samples(task_name=task_name, samples=samples[task_name]) + + if evaluation_tracker.push_results_to_hub or evaluation_tracker.push_samples_to_hub: + evaluation_tracker.recreate_metadata_card() + + # TODO: convert to logging entries instead? 
+ print( + f"{results["config"]["model"]}, gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " + f"batch_size: {results["config"]["batch_size"]}{f' ({batch_sizes})' if batch_sizes else ''}" + ) + print(lm_eval.utils.make_table(results)) + if "groups" in results: + print(lm_eval.utils.make_table(results, "groups")) From 0967483e1cf0e818cdb3760a688a9423a8a5cb4c Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 20 Jun 2025 12:43:01 +0000 Subject: [PATCH 02/26] changes for loss test for new tests structure --- tests/{ => models}/test_gpt_loss.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) rename tests/{ => models}/test_gpt_loss.py (90%) diff --git a/tests/test_gpt_loss.py b/tests/models/test_gpt_loss.py similarity index 90% rename from tests/test_gpt_loss.py rename to tests/models/test_gpt_loss.py index 89262eca..ba53b57a 100644 --- a/tests/test_gpt_loss.py +++ b/tests/models/test_gpt_loss.py @@ -12,7 +12,7 @@ from fast_llm.engine.schedule.schedule import Schedule from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig -from tests.test_gpt_generate_and_forward import model_and_tokenizer # noqa: F401 +from tests.models.test_generate import model_path # noqa: F401 from tests.utils.utils import requires_cuda @@ -109,12 +109,10 @@ def _test_for_phase(model_path, fast_llm_checkpoint_format, phase): # @pytest.mark.extra_slow @requires_cuda -def test_loss_validation_vs_inference(model_and_tokenizer): - model_path, _, fast_llm_checkpoint_format = model_and_tokenizer +def test_loss_validation_vs_inference(model_path): + iter_losses_validation = _test_for_phase(model_path, LlamaGPTHuggingfaceCheckpointFormat, PhaseType.validation) - iter_losses_validation = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.validation) - - iter_losses_inference = _test_for_phase(model_path, fast_llm_checkpoint_format, PhaseType.inference) + iter_losses_inference = _test_for_phase(model_path, LlamaGPTHuggingfaceCheckpointFormat, PhaseType.inference) assert len(iter_losses_validation) == len(iter_losses_inference) for key in iter_losses_validation.keys(): From 71ff61ad2994edab2d4d219de0f17b40ca113ea9 Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 20 Jun 2025 12:44:03 +0000 Subject: [PATCH 03/26] lm_eval integration changes for the new api --- fast_llm/engine/evaluation/config.py | 5 ++--- fast_llm/engine/evaluation/evaluator.py | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index d0adf7b2..a3839588 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -6,7 +6,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLoss + from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, EvaluatorLoss @config_class() @@ -64,10 +64,9 @@ def get_evaluator( return EvaluatorLoss(name, self, batch_config, data_load_num_proc, train_iters) -@config_class() +@config_class(dynamic_type={EvaluatorConfig: "lm_eval"}) class EvaluatorLmEvalConfig(EvaluatorConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "lm_eval" cli_args: list[str] = Field( default_factory=lambda: [], diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py 
index eaad6299..dffd8c13 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -264,10 +264,8 @@ def setup( ) -> None: super().setup(distributed, run, multi_stage, runner, data, phase) - # TODO: pass mini and batch size of the same length for lm_eval not to crash during training - # or implement min batch sequential awareness in fas_llm_wrapper for lm_eval - self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class().from_model( - self._multi_stage, self._batch_config.micro_batch_size, self._runner + self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( + self._multi_stage, runner=self._runner ) # For reporting purposes, just to indicate it is from Fast-LLM From 79fd43ee4fe412b7795c4af29c2b95525ef77b4a Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 20 Jun 2025 13:17:54 +0000 Subject: [PATCH 04/26] made lm_eval dependency lazy imported for optional dependency --- fast_llm/engine/evaluation/evaluator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index dffd8c13..5ca40c50 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -4,8 +4,6 @@ import time import typing -from lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate - from fast_llm.config import Configurable from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data @@ -18,8 +16,6 @@ EvaluatorLmEvalConfig, EvaluatorLossConfig, ) -from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper -from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner @@ -262,6 +258,8 @@ def setup( data: Data, phase: PhaseType, ) -> None: + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + super().setup(distributed, run, multi_stage, runner, data, phase) self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( @@ -287,6 +285,13 @@ def run( training_progress: TrainingProgress | None = None, run_index: int | None = None, ) -> EvaluationMetrics: + from lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate + + from fast_llm.engine.evaluation.lm_eval.utils import ( + prepare_lm_eval_simple_eval_params, + process_lm_eval_results, + ) + assert self._is_setup # TODO: use run_index instead? 
From 2d9f4791d4421e13ac56b3b352ceb8e3f4778ca4 Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 20 Jun 2025 13:18:24 +0000 Subject: [PATCH 05/26] removed hard coded batch size --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 6080f38b..553e97f6 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -92,7 +92,7 @@ def __init__( self.batch_schedule = 1 self.batch_sizes = {} - self.batch_size_per_gpu = 16 # model._inference_runner._batch_config.micro_batch_size + self.batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size self.batch_size = self.batch_size_per_gpu * dist_config.batch_data_parallel self.max_batch_size = self.batch_size From 7c6210061e342f19e1e4132756efa1e86889af54 Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 25 Jun 2025 14:11:55 +0000 Subject: [PATCH 06/26] remved unncecessary set to evaluatation --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 553e97f6..83fe3e42 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -64,11 +64,6 @@ def __init__( self._device = self._model.device self._config = self._model.config - # access self._model through self.model property outside this method - if isinstance(self.model, torch.nn.Module): - self.model.eval() - self.model.tie_weights() - self.truncation = truncation self.logits_cache = logits_cache self.vocab_size = self.tokenizer.vocab_size From c89d2697c2b99ff9a031eb42fc85f029d78b1822 Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 26 Jun 2025 16:21:42 +0000 Subject: [PATCH 07/26] commit wandb step after finishing logging --- fast_llm/engine/training/trainer.py | 6 ++++-- fast_llm/engine/training/wandb.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a3cf078d..8f6b1498 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -318,7 +318,7 @@ def _run_training(self) -> None: log_main_rank(formatted_metrics) self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. - self._wandb.log_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics, commit=True) def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # Tracking loss. @@ -339,6 +339,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: self._config.training.prefetch_factor, ) + has_test_phase = PhaseType.test in self._samples_per_split + log_main_rank("Training ...") # TODO: Synchronization is probably unnecessary. 
@@ -456,7 +458,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: ) if is_main_rank() and metrics: - self._wandb.log_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics, commit=not (done and has_test_phase)) stop = done or self._config.training.shutdown.enabled(self._completed_steps) diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index 185b89c2..724b5b71 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -44,12 +44,12 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): else: self._wandb = None - def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]]) -> None: + def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]], commit: bool) -> None: # Note: metrics modified in-place if self._wandb is not None: import wandb - wandb.log(metrics, step=completed_steps) # noqa + wandb.log(metrics, step=completed_steps, commit=commit) # noqa def alert(self, title, text, level="INFO", wait=0.001) -> None: if self._wandb is not None and self._config.alert.post_alerts: From 9455cd514b83839ac0a469e1dfa425a9c1993d72 Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 27 Jun 2025 09:19:43 +0000 Subject: [PATCH 08/26] support for env varieables for lm_eval integration --- fast_llm/cli.py | 11 +++++++++++ fast_llm/engine/evaluation/evaluator.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/fast_llm/cli.py b/fast_llm/cli.py index 34546120..f9fbaa80 100644 --- a/fast_llm/cli.py +++ b/fast_llm/cli.py @@ -1,4 +1,5 @@ import logging +import os import sys import traceback @@ -7,6 +8,16 @@ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.runnable import RunnableConfig +# This must be set before importing numexpr, +# because by default, the maximum number of threads is 64. +# On systems with more cores, numexpr logs an error and +# ignores the thread setting if it exceeds the limit. +if "NUMEXPR_MAX_THREADS" not in os.environ: + import multiprocessing as mp + + os.environ["NUMEXPR_MAX_THREADS"] = str(mp.cpu_count()) + + # Import these submodules to ensure classes are added to the dynamic class registry. 
import fast_llm.data.auto # isort: skip import fast_llm.engine.checkpoint.convert # isort: skip diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 5ca40c50..9f24294e 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -1,6 +1,8 @@ import abc import dataclasses import logging +import os +import pathlib import time import typing @@ -258,6 +260,8 @@ def setup( data: Data, phase: PhaseType, ) -> None: + os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip() + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper super().setup(distributed, run, multi_stage, runner, data, phase) From c9a3b183dc0e8e09581abca48b9d57c4a1cd539b Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 27 Jun 2025 11:38:12 +0000 Subject: [PATCH 09/26] user guide for evaluators added --- docs/user_guide/evaluators.md | 90 +++++++++++++++++++++++++++++++++++ mkdocs.yaml | 1 + 2 files changed, 91 insertions(+) create mode 100644 docs/user_guide/evaluators.md diff --git a/docs/user_guide/evaluators.md b/docs/user_guide/evaluators.md new file mode 100644 index 00000000..c640e4ef --- /dev/null +++ b/docs/user_guide/evaluators.md @@ -0,0 +1,90 @@ +# Evaluations + +Fast-LLM allows you to perform various evaluations during training or as a separate evaluation step. In both cases, you need to use your training config with `training.evaluators` specified. + +For evaluators used during training, both `interval` and `offset` must be specified. Then, start training as usual with: + +`fast-llm train gpt --config path/to/training/config.yaml` + +To perform evaluation as a separate step, use the same training config. Depending on the training progress, either the start model or the latest checkpoint will be loaded, and `interval` and `offset` will be ignored. To start evaluation: + +`fast-llm evaluate gpt --config path/to/training/config.yaml` + +## Currently Supported Evaluators + +- `loss` +- `lm_eval` + +## Loss Evaluator + +To set up loss evaluation, specify a dataset to be used in the `data.datasets` section of the config. You must also define the loss evaluator in the `training.evaluators` config section. See example below. + +```yaml +training: + evaluations: + stack_3b: + interval: 10 + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + fineweb: + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + interval: 10 +data: + datasets: + stack_3b: + type: memmap + path: path/to/memmap/dataset + fineweb: + type: memmap + path: path/to/memmap/dataset1 +``` + +## Evaluation Harness (`lm_eval`) Evaluator + +To run `lm_eval` evaluations, version `0.4.9` of `lm_eval` must be installed along with all dependencies required for your evaluation tasks. + +The following environment variables may need to be set: + +- `HF_HOME`: Path for Hugging Face data caching +- `WANDB_API_KEY_PATH`: Path to a file containing your Weights & Biases API key (if logging to W&B) +- `HUGGINGFACE_API_KEY_PATH`: Path to a file containing your Hugging Face hub token +- `NLTK_DATA`: Path to a directory that will contain downloaded NLTK packages (needed for some tasks) +- `HF_ALLOW_CODE_EVAL=1`: Required for some evaluation tasks + +You may need to specify additional environment variables depending on the `lm_eval` tasks you want to run. 
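+
+For example, a typical launch with these variables set might look like the sketch below (all paths are placeholders to adapt to your environment):
+
+```bash
+export HF_HOME=/path/to/hf_cache
+export HUGGINGFACE_API_KEY_PATH=/path/to/file/containing/hf_token
+export WANDB_API_KEY_PATH=/path/to/file/containing/wandb_api_key  # only if logging to W&B
+export NLTK_DATA=/path/to/nltk_data
+export HF_ALLOW_CODE_EVAL=1
+fast-llm train gpt --config path/to/training/config.yaml
+```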
+ +To specify an `lm_eval` task, the evaluator config includes the following fields: + +### Model Config + +The model instantiated for training is reused for evaluation, so you don't need to specify it separately. However, there are some parameters specific to `lm_eval`. See `fast_llm/engine/evaluation/config.EvaluatorLmEvalConfig` for details. + +### CLI Parameters for `lm_eval` + +All other parameters are specified as if you were calling the `lm_eval` CLI, using a list of strings. Some CLI parameters are ignored or restricted—specifically those related to model loading, W&B, batch sizes, and device setup, as these are managed by the rest of the Fast-LLM configuration. + +Also, the tokenizer must be specified in `data.tokenizer`. If the tokenizer does not have a `bos_token`, it must be specified explicitly in `data.tokenizer.bos_token`. Although `lm_eval` does not use the `bos_token` directly, it is still required because the same tokenizer is used by other Fast-LLM components. + +Below is an example of the config: + +```yaml +training: + evaluations: + lm_eval_tasks1: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k,xnli_en,wikitext,ifeval + - --output_path + - /path/to/lm_eval/output +data: + tokenizer: + path: path/to/the/tokenizer +``` diff --git a/mkdocs.yaml b/mkdocs.yaml index ab71bc23..85fd4bff 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -179,6 +179,7 @@ nav: - Configuration: user_guide/configuration.md - Multi-Stage: user_guide/multi-stage.md - Parallelism: user_guide/parallelism.md + - Evaluators: user_guide/evaluators.md - Developer Guide: - Configuration: developer_guide/configuration.md - Model: From 426b5e357355efd10a653029a243b77f7b8a8081 Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 27 Jun 2025 14:17:17 +0000 Subject: [PATCH 10/26] fix tensor concatination for logits from different gpus --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 83fe3e42..5c87235d 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -182,6 +182,9 @@ def _model_invoke( if generate: res = sum((el.tolist() for el in gather_list), []) else: + # Tensors gathered via gather_object will remain on their original GPUs, + # even if they came from another node. Move them to the current GPU. + gather_list = [el.to(self.device) for el in gather_list] res = torch.cat(gather_list, dim=0) return res From 0bf8282d7062524c2cfd3718ed4b05376f8bafb6 Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 27 Jun 2025 14:44:28 +0000 Subject: [PATCH 11/26] docs update --- docs/user_guide/evaluators.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/user_guide/evaluators.md b/docs/user_guide/evaluators.md index c640e4ef..8bc6cb21 100644 --- a/docs/user_guide/evaluators.md +++ b/docs/user_guide/evaluators.md @@ -46,6 +46,8 @@ data: ## Evaluation Harness (`lm_eval`) Evaluator +**Note:** Only data parallelism is currently supported for the `lm_eval` evaluator. + To run `lm_eval` evaluations, version `0.4.9` of `lm_eval` must be installed along with all dependencies required for your evaluation tasks. 
The following environment variables may need to be set: From 68f524b21ef8ad8e013a97d45697f7316f2c8cb3 Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 27 Jun 2025 14:46:07 +0000 Subject: [PATCH 12/26] removed manual test configs --- examples/qwen_evaluate.yaml | 87 ------------------------------------- examples/smol_evaluate.yaml | 86 ------------------------------------ 2 files changed, 173 deletions(-) delete mode 100644 examples/qwen_evaluate.yaml delete mode 100644 examples/smol_evaluate.yaml diff --git a/examples/qwen_evaluate.yaml b/examples/qwen_evaluate.yaml deleted file mode 100644 index f11c2bf3..00000000 --- a/examples/qwen_evaluate.yaml +++ /dev/null @@ -1,87 +0,0 @@ -training: - train_iters: 100_000 - logs: - interval: 10 - evaluations: - gsm8k: - run_interval: - interval: 10 - evaluator: - type: lm_eval - cli_args: - - --tasks - - gsm8k - - --output_path - - /mnt/checkpoints/test/denis/smol_eval_experiment/lm_eval - stack_3b: - run_interval: - interval: 10 - evaluator: - type: loss - iterations: 10 - dataset_name: stack_3b - fineweb: - run_interval: - interval: 10 - evaluator: - iterations: 10 - dataset_name: fineweb - checkpoint: - interval: 1000 - keep: 5 - test_iters: 0 - export: # (1)! - format: llama - interval: 20_000 -batch: - micro_batch_size: 16 - sequence_length: 4096 - batch_size: 32 -data: - tokenizer: - path: /mnt/checkpoints/pretrained_models/Qwen2-1.5B-Instruct - bos_token: "<|endoftext|>" - datasets: - # Bad dataset they are tokenized with different tokenizer, then llama - training: - type: file - path: /mnt/datasets/test/denis/fineweb_the_stack_3b.yaml - stack_3b: - type: memmap - path: /mnt/datasets/data_collections/the_stack_3b/tokens/stack_3b/default/train/99 - fineweb: - type: memmap - path: /mnt/datasets/data_collections/standalone_datasets/tokens/HuggingFaceFW/fineweb/default/train/9_1000 -optimizer: - weight_decay: 0.1 - beta_1: 0.9 - beta_2: 0.95 - learning_rate: - base: 1.0e-04 # (3)! - minimum: 1.0e-05 - decay_style: cosine - decay_iterations: 100_000 - warmup_iterations: 2000 -pretrained: # (4)! - format: qwen2 - path: /mnt/checkpoints/pretrained_models/Qwen2-1.5B-Instruct - model_weights: yes # (5)! -model: - base_model: - transformer: - use_flash_attention: yes - cross_entropy_impl: fused - multi_stage: - zero_stage: 2 - distributed: - training_dtype: bf16 - -run: - experiment_dir: "/mnt/checkpoints/test/denis/qwen_eval_experiment" - -# training: -# logs: -# interval: 10 -# wandb: -# project_name: ${job.project_name} -# group_name: ${job.project_version} diff --git a/examples/smol_evaluate.yaml b/examples/smol_evaluate.yaml deleted file mode 100644 index 1d8822c0..00000000 --- a/examples/smol_evaluate.yaml +++ /dev/null @@ -1,86 +0,0 @@ -training: - train_iters: 100_000 - logs: - interval: 10 - evaluations: - gsm8k: - run_interval: - interval: 10 - evaluator: - type: lm_eval - cli_args: - - --tasks - - gsm8k - - --output_path - - /mnt/checkpoints/test/denis/smol_eval_experiment/lm_eval - stack_3b: - run_interval: - interval: 10 - evaluator: - type: loss - iterations: 10 - dataset_name: stack_3b - fineweb: - run_interval: - interval: 10 - evaluator: - iterations: 10 - dataset_name: fineweb - checkpoint: - interval: 1000 - keep: 5 - test_iters: 0 - export: # (1)! 
- format: llama - interval: 20_000 -batch: - micro_batch_size: 16 - sequence_length: 4096 - batch_size: 32 -data: - tokenizer: - path: /mnt/checkpoints/pretrained_models/SmolLM2-135M-Instruct - datasets: - # Bad dataset they are tokenized with different tokenizer, then llama - training: - type: file - path: /mnt/datasets/test/denis/fineweb_the_stack_3b.yaml - stack_3b: - type: memmap - path: /mnt/datasets/data_collections/the_stack_3b/tokens/stack_3b/default/train/99 - fineweb: - type: memmap - path: /mnt/datasets/data_collections/standalone_datasets/tokens/HuggingFaceFW/fineweb/default/train/9_1000 -optimizer: - weight_decay: 0.1 - beta_1: 0.9 - beta_2: 0.95 - learning_rate: - base: 1.0e-04 # (3)! - minimum: 1.0e-05 - decay_style: cosine - decay_iterations: 100_000 - warmup_iterations: 2000 -pretrained: # (4)! - format: llama - path: /mnt/checkpoints/pretrained_models/SmolLM2-135M-Instruct/ - model_weights: yes # (5)! -model: - base_model: - transformer: - use_flash_attention: yes - cross_entropy_impl: fused - multi_stage: - zero_stage: 2 - distributed: - training_dtype: bf16 - -run: - experiment_dir: "/mnt/checkpoints/test/denis/smol_eval_experiment" - -# training: -# logs: -# interval: 10 -# wandb: -# project_name: ${job.project_name} -# group_name: ${job.project_version} From a36e0be02366e6591682ef147b97d66a1fb66e0b Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 27 Jun 2025 15:18:01 +0000 Subject: [PATCH 13/26] added debug prints --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 5c87235d..26f7fb3f 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -122,9 +122,13 @@ def _model_invoke( assert max_length is not None and stop is not None # always divide by batch_size, if not full batch, some ranks will get less work or not at all + assert self.batch_size % world_size == 0 step = self.batch_size // world_size + orig_size = input_ids.shape[0] input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)] + if orig_size < self.batch_size: + print("input_ids", input_ids) attention_mask = [ attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None for i in range(world_size) @@ -184,6 +188,8 @@ def _model_invoke( else: # Tensors gathered via gather_object will remain on their original GPUs, # even if they came from another node. Move them to the current GPU. 
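+ # A possible message layout (purely hypothetical, not implemented here) could be:
+ #   ("logits", input_ids, attention_mask, labels)
+ #   ("generate", input_ids, attention_mask, max_length, stop, generation_kwargs)
+ #   ("finished",)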
+ if orig_size < self.batch_size: + print("gather_list", gather_list) gather_list = [el.to(self.device) for el in gather_list] res = torch.cat(gather_list, dim=0) From 9baa5123f1c2d44bdc24935a820b7d272d2574ec Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 27 Jun 2025 15:40:35 +0000 Subject: [PATCH 14/26] fix for gather_list and remove debug print --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 26f7fb3f..918e952c 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -180,6 +180,8 @@ def _model_invoke( group_dst=0, group=self.group, ) + # Clean gather list from empty shards + gather_list = [el for el in gather_list if len(el) > 0] # If it was model generate tensors could be of different length # so we aggregate results to list instead of a tensor @@ -188,8 +190,6 @@ def _model_invoke( else: # Tensors gathered via gather_object will remain on their original GPUs, # even if they came from another node. Move them to the current GPU. - if orig_size < self.batch_size: - print("gather_list", gather_list) gather_list = [el.to(self.device) for el in gather_list] res = torch.cat(gather_list, dim=0) From 21678ab8d3de12af2a306007b5b77d3f7358f2b0 Mon Sep 17 00:00:00 2001 From: bigximik Date: Sat, 28 Jun 2025 15:45:01 +0000 Subject: [PATCH 15/26] removed debug print --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 918e952c..24700594 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -124,11 +124,8 @@ def _model_invoke( # always divide by batch_size, if not full batch, some ranks will get less work or not at all assert self.batch_size % world_size == 0 step = self.batch_size // world_size - orig_size = input_ids.shape[0] input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)] - if orig_size < self.batch_size: - print("input_ids", input_ids) attention_mask = [ attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None for i in range(world_size) From 7cccf9a481f4f7587b8a4e2d9ca7a389cd7242d1 Mon Sep 17 00:00:00 2001 From: bigximik Date: Sat, 28 Jun 2025 16:08:20 +0000 Subject: [PATCH 16/26] moved returned logits to cpu in lm_eval wrapper --- .../evaluation/lm_eval/fast_llm_wrapper.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 24700594..f264aa01 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -106,6 +106,9 @@ def _model_invoke( continue_generate: bool, **generation_kwargs, ): + # TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple. + # Messages could include types like logits, generate, finished. 
+ if self.group is None or (world_size := self.group.size()) == 1: # Must not be called with continue_generate false on one process assert continue_generate @@ -185,9 +188,7 @@ def _model_invoke( if generate: res = sum((el.tolist() for el in gather_list), []) else: - # Tensors gathered via gather_object will remain on their original GPUs, - # even if they came from another node. Move them to the current GPU. - gather_list = [el.to(self.device) for el in gather_list] + assert all(el.device.type == "cpu" for el in gather_list) res = torch.cat(gather_list, dim=0) return res @@ -282,8 +283,17 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): A torch tensor of shape [batch, sequence, vocab] with the logits returned from the model's decoder """ - # TODO: do we need no_grad for our model? + # TODO: do we need no_grad for fast_llm model? with torch.no_grad(): + # We move logits to the CPU because they will be copied across processes and nodes + # in a multi-GPU, multi-node setup and eventually collected on the main rank. + # We cannot afford to accumulate them on rank 0 GPU, as GPU memory may already be tight. + # CPU tensors are slower, but we typically have much more CPU RAM available. + + # TODO: Check if it's possible to move some of the _loglikelihood_tokens work here + # and pass only the results around instead of the full logits. + # Computing errors here is also preferable, as copying logits across nodes and GPUs + # is inefficient and can involve gigabytes of data. if attention_mask is not None or labels is not None: assert attention_mask is not None and labels is not None return self.model( @@ -297,7 +307,7 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): output_attentions=False, output_hidden_states=False, return_dict=True, - ).logits + ).logits.cpu() else: return self.model( input_ids=input_ids, @@ -310,7 +320,7 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): output_attentions=False, output_hidden_states=False, return_dict=True, - ).logits + ).logits.cpu() def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): # temperature = 0.0 if not set From 7cd681ac4138522b723e0dd28d57e91621be711d Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 30 Jun 2025 10:37:29 +0000 Subject: [PATCH 17/26] fix to move all logits computations to cpu --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index f264aa01..3a57e708 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -745,6 +745,8 @@ def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]): self._model_call(batched_inps, **call_kwargs), dim=-1 ) # [batch, padding_length (inp or cont), vocab] + # TODO: Consider moving this part to per-shard execution in a multi-GPU and multi-node setup + # to avoid copying logits between GPUs and nodes, and to enable performing logits computations on the GPU. 
for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( chunk, multi_logits, inplens, cont_toks_list ): @@ -772,7 +774,9 @@ def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]): cont_toks=cont_toks, logits=logits, ): - cont_toks = torch.tensor(cont_toks, dtype=torch.long, device=self.device).unsqueeze(0) # [1, seq] + # NOTE: Currently, computations are performed on the CPU due to limited GPU memory. + cont_toks = torch.tensor(cont_toks, dtype=torch.long, device="cpu").unsqueeze(0) # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() # Obtain log-probs at the corresponding continuation token indices From 88faca04e1242ff0d46ac6003119ddc408e2819b Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 2 Jul 2025 10:05:51 +0000 Subject: [PATCH 18/26] fix typo --- fast_llm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 0004501b..c534b11f 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -735,7 +735,7 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: "Config| dict[str, typing.Any]]", + default: "Config| dict[str, typing.Any]", *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, From e3a4a6ef06f302b87e4ba52726d37a2172199f66 Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 2 Jul 2025 10:06:23 +0000 Subject: [PATCH 19/26] removed commented code, obsolete todo --- fast_llm/engine/evaluation/lm_eval/utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/utils.py b/fast_llm/engine/evaluation/lm_eval/utils.py index f7b0cd50..5f7f9a94 100644 --- a/fast_llm/engine/evaluation/lm_eval/utils.py +++ b/fast_llm/engine/evaluation/lm_eval/utils.py @@ -62,14 +62,6 @@ def prepare_lm_eval_simple_eval_params( assert args.batch_size == 1 # default value of 1 assert args.max_batch_size is None assert args.device is None - # if args.wandb_args: - # wandb_args_dict = simple_parse_args_string(args.wandb_args) - # wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args) - # wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict) - - # TODO: change logging levels from fast_llm to lm_eval and then back? 
- # utils.setup_logging(args.verbosity) - # eval_logger = logging.getLogger(__name__) # update the evaluation tracker args with the output path and the HF token evaluation_tracker_args = "" From 89e67d2c9a126921c04dbac8578931da1b79af9a Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 2 Jul 2025 11:48:43 +0000 Subject: [PATCH 20/26] changes to wrapper --- .../evaluation/lm_eval/fast_llm_wrapper.py | 124 +++++++----------- 1 file changed, 47 insertions(+), 77 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 3a57e708..e81e7675 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -8,8 +8,8 @@ import lm_eval.utils import torch import torch.nn.functional as F +import tqdm.auto import transformers -from tqdm.auto import tqdm from fast_llm.core.distributed import gather_object, safe_barrier, scatter_object_list from fast_llm.engine.distributed.config import DistributedConfig @@ -213,7 +213,7 @@ def worker_model_invoke(self): tuple(obj_list[0]) ) - if continue_generate == False: + if not continue_generate: break # if some data was received, work, otherwise return empty tensor @@ -283,6 +283,9 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): A torch tensor of shape [batch, sequence, vocab] with the logits returned from the model's decoder """ + if attention_mask is not None or labels is not None: + assert attention_mask is not None and labels is not None + # TODO: do we need no_grad for fast_llm model? with torch.no_grad(): # We move logits to the CPU because they will be copied across processes and nodes @@ -294,33 +297,18 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): # and pass only the results around instead of the full logits. # Computing errors here is also preferable, as copying logits across nodes and GPUs # is inefficient and can involve gigabytes of data. 
- if attention_mask is not None or labels is not None: - assert attention_mask is not None and labels is not None - return self.model( - input_ids=input_ids, - attention_mask=attention_mask, - labels=labels, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - use_cache=False, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ).logits.cpu() - else: - return self.model( - input_ids=input_ids, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=False, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ).logits.cpu() + return self.model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ).logits.cpu() def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): # temperature = 0.0 if not set @@ -340,25 +328,19 @@ def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **g stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0] ) - if attention_mask is None: - return self.model.generate( - input_ids=input_ids, - max_length=max_length, - stopping_criteria=stopping_criteria, - pad_token_id=self.tokenizer.pad_token_id, - use_cache=False, - **generation_kwargs, - ) - else: - return self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_length=max_length, - stopping_criteria=stopping_criteria, - pad_token_id=self.tokenizer.pad_token_id, - use_cache=False, - **generation_kwargs, - ) + + kwargs = { + "input_ids": input_ids, + "max_length": max_length, + "stopping_criteria": stopping_criteria, + "pad_token_id": self.tokenizer.pad_token_id, + "use_cache": False, + **generation_kwargs, + } + if attention_mask is not None: + kwargs["attention_mask"] = attention_mask + + return self.model.generate(**kwargs) @property def config(self): @@ -367,11 +349,7 @@ def config(self): @property def model(self): - # returns the model, unwrapping it if using Accelerate - if hasattr(self, "accelerator"): - return self.accelerator.unwrap_model(self._model) - else: - return self._model + return self._model @property def eot_token_id(self): @@ -516,7 +494,7 @@ def loglikelihood_rolling( request_window_counts = [] # Track number of windows per request for req_idx, (string,) in enumerate( - tqdm( + tqdm.auto.tqdm( [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0)), ) @@ -618,19 +596,16 @@ def _collate(req: tuple[tuple[str, str], list[int], list[int]]): toks = req[1] + req[2] return -len(toks), tuple(toks) - def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]): - """Defines the key to group and lookup one-token continuations""" - # Use with group_by="contexts" (optional)" - # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. - # speeds up some multiple-choice tasks proportionally to the number of choices. - # groups requests by context+continuation[:-1] and infer on one request/group. - return req[-2] + req[-1][:-1] - + # NOTE: the group_fn Defines the key to group and lookup one-token continuations + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. 
+ # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. re_ord = lm_eval.models.utils.Collator( requests, sort_fn=_collate, group_by="contexts" if self.backend == "causal" and self.logits_cache else None, - group_fn=_lookup_one_token_cont, + group_fn=lambda req: req[-2] + req[-1][:-1], ) # automatic (variable) batch size detection for vectorization @@ -644,7 +619,7 @@ def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]): ) chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) - pbar = tqdm( + pbar = tqdm.auto.tqdm( total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running loglikelihood requests", @@ -802,18 +777,7 @@ def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]): def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False) -> list[str]: res = [] - def _collate(req: tuple[str, dict]): - """Defines the key for the sorted method""" - # the negative sign on len(toks) sorts descending - this has a few advantages: - # - time estimates will always be over not underestimates, which is more useful for planning - # - to know the size of a batch when going through the list, you know the first one is always the batch - # padded context length. this is useful to simplify the batching logic and more importantly to make - # automatic adaptive batches much much easier to implement - # - any OOMs will happen right away rather than near the end - toks = self.tok_encode(req[0]) - return -len(toks), req[0] - - pbar = tqdm( + pbar = tqdm.auto.tqdm( total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running generate_until requests", @@ -837,9 +801,15 @@ def _collate(req: tuple[str, dict]): # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # in the same batch. # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + # NOTE: for sort_fn, the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. 
this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end re_ords = lm_eval.models.utils.Collator( [reg.args for reg in requests], - sort_fn=_collate, + sort_fn=lambda req: (-len(self.tok_encode(req[0])), req[0]), group_by="gen_kwargs", group_fn=lambda x: x[1], ) From 6871359da3db5faf381a1b07c1c6a3cdc0ccc3fd Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 2 Jul 2025 13:42:28 +0000 Subject: [PATCH 21/26] refactorred lm_eval integration --- fast_llm/engine/evaluation/config.py | 2 +- fast_llm/engine/evaluation/evaluator.py | 107 +----------------- .../engine/evaluation/lm_eval/evaluator.py | 89 +++++++++++++++ .../evaluation/lm_eval/fast_llm_wrapper.py | 66 ++++++++--- 4 files changed, 142 insertions(+), 122 deletions(-) create mode 100644 fast_llm/engine/evaluation/lm_eval/evaluator.py diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index a3839588..9f79b906 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -104,6 +104,6 @@ def get_evaluator( data_load_num_proc: int, train_iters: int | None = None, ) -> "EvaluatorLmEval": - from fast_llm.engine.evaluation.evaluator import EvaluatorLmEval + from fast_llm.engine.evaluation.lm_eval.evaluator import EvaluatorLmEval return EvaluatorLmEval(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 9083c256..f593883c 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -1,8 +1,6 @@ import abc import dataclasses import logging -import os -import pathlib import time import typing @@ -12,12 +10,7 @@ from fast_llm.engine.config_utils.run import Run, log_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.evaluation.config import ( - EvaluatorConfig, - EvaluatorConfigBase, - EvaluatorLmEvalConfig, - EvaluatorLossConfig, -) +from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, EvaluatorLossConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner @@ -251,104 +244,6 @@ def _get_data_iterator( ) -class EvaluatorLmEval[ConfigType: EvaluatorLmEvalConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[EvaluatorLmEvalConfig]] = EvaluatorLmEvalConfig - - def setup( - self, - distributed: Distributed, - run: Run, - multi_stage: FastLLMModel, - runner: ScheduleRunner, - data: Data, - phase: PhaseType, - ) -> None: - os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip() - - from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper - - super().setup(distributed, run, multi_stage, runner, data, phase) - - self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( - self._multi_stage, runner=self._runner - ) - - # For reporting purposes, just to indicate it is from Fast-LLM - # as lm_eval.simple_evaluate will take it for results['config']['model'] - self._hf_model.config.name_or_path = type(self._hf_model).__name__ - - self._flm_wrapper = FastLLMLmEvalWrapper( - model=self._hf_model, - 
tokenizer=self._data.tokenizer.tokenizer, - truncation=self._config.truncation, - logits_cache=self._config.logits_cache, - add_bos_token=self._config.add_bos_token, - prefix_token_id=self._config.prefix_token_id, - ) - self._is_setup = True - - def run( - self, - training_progress: TrainingProgress | None = None, - run_index: int | None = None, - ) -> EvaluationMetrics: - from lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate - - from fast_llm.engine.evaluation.lm_eval.utils import ( - prepare_lm_eval_simple_eval_params, - process_lm_eval_results, - ) - - assert self._is_setup - - # TODO: use run_index instead? - # completed_steps is added to output_path like output_path/runs/run_index/completed_steps/ - completed_steps = 0 if training_progress is None else training_progress.completed_steps - - if self._run.is_main_rank: - args, simple_eval_kwargs = prepare_lm_eval_simple_eval_params( - self._config.cli_args, completed_steps, self._run.index - ) - simple_eval_kwargs["model"] = self._flm_wrapper - - # Needed for reporting as batch_size is set from args not lm for reporting in evaluate - simple_eval_kwargs["batch_size"] = self._flm_wrapper.batch_size - simple_eval_kwargs["max_batch_size"] = self._flm_wrapper.max_batch_size - - # As of lm_eval commit 758c5ed891b1ca48acd8d3a0d309a827215796b7 - # Expected to be a string even if empty and not None in simple_evaluate - simple_eval_kwargs["model_args"] = "" - - results = lm_eval_simple_evaluate(**simple_eval_kwargs) - self._flm_wrapper.stop_workers() - - # Evaluation_tracker save expects model to be either string, but if model is passed - # LM wrapper needs to be deep copyable and json serializable - simple_eval_kwargs["evaluation_tracker"].general_config_tracker.model_source = ( - self._hf_model.config.name_or_path - ) - - if results is not None: - process_lm_eval_results( - args, - results, - simple_eval_kwargs["evaluation_tracker"], - completed_steps, - ) - else: - self._flm_wrapper.worker_model_invoke() - - # TODO: do we need it here as self._flm_wrapper.stop_workers() and self._flm_wrapper.worker_model_invoke() - # already have barrier - safe_barrier(self._distributed.world_group, f"Evaluation Harness Run end") - - # lm_eval logs to disc, wandb and prints to screen itself - return EvaluationMetrics() - - def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: - return None - - # NOTE: This is not a standalone runnable; it's a submodule of Trainer used for code encapsulation. 
class EvaluatorRunner: _is_setup: bool = False diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py new file mode 100644 index 00000000..404549a3 --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -0,0 +1,89 @@ +import logging +import os +import pathlib +import typing + +from fast_llm.data.data.abstract import Data +from fast_llm.engine.config_utils.run import Run +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.config import EvaluatorLmEvalConfig +from fast_llm.engine.evaluation.evaluator import ( + EvaluationMetrics, + Evaluator, + EvaluatorSamplingParameters, + TrainingProgress, +) +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.runner import ScheduleRunner + +if typing.TYPE_CHECKING: + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM + +logger = logging.getLogger(__name__) + + +class EvaluatorLmEval[ConfigType: EvaluatorLmEvalConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[EvaluatorLmEvalConfig]] = EvaluatorLmEvalConfig + + _hf_model: "HuggingfaceBaseModelForCausalLM" = None + _flm_wrapper: "FastLLMLmEvalWrapper" = None + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + if "HUGGINGFACE_API_KEY_PATH" in os.environ: + os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip() + else: + if not "HF_TOKEN" in os.environ: + logger.warning( + "No `HF_TOKEN` or `HUGGINGFACE_API_KEY_PATH` environment variable provided. " + "Assuming the user has already logged in to the Hugging Face Hub." 
+ ) + + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + + super().setup(distributed, run, multi_stage, runner, data, phase) + + self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( + self._multi_stage, runner=self._runner + ) + + # For reporting purposes, just to indicate it is from Fast-LLM + # as lm_eval.simple_evaluate will take it for results['config']['model'] + self._hf_model.config.name_or_path = type(self._hf_model).__name__ + + self._flm_wrapper = FastLLMLmEvalWrapper( + model=self._hf_model, + tokenizer=self._data.tokenizer.tokenizer, + truncation=self._config.truncation, + logits_cache=self._config.logits_cache, + add_bos_token=self._config.add_bos_token, + prefix_token_id=self._config.prefix_token_id, + ) + self._is_setup = True + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + + # completed_steps is added to output_path like output_path/runs/run_index/completed_steps/ + completed_steps = 0 if training_progress is None else training_progress.completed_steps + + self._flm_wrapper.run(self._config.cli_args, completed_steps, self._run.index) + + # lm_eval logs to disc, wandb and prints to screen itself + return EvaluationMetrics() + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return None diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index e81e7675..2babfad2 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -4,6 +4,7 @@ import jinja2 import lm_eval.api.instance import lm_eval.api.model +import lm_eval.evaluator import lm_eval.models.utils import lm_eval.utils import torch @@ -12,8 +13,8 @@ import transformers from fast_llm.core.distributed import gather_object, safe_barrier, scatter_object_list -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM eval_logger = logging.getLogger(__name__) @@ -37,9 +38,12 @@ def __init__( self._world_size = 1 self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed - dist_config: DistributedConfig = self._distributed.config # get batch_data_parallel group leaders - if dist_config.sequence_data_rank == 0 and dist_config.pipeline_rank == 0 and dist_config.tensor_rank == 0: + if ( + self._distributed.config.sequence_data_rank == 0 + and self._distributed.config.pipeline_rank == 0 + and self._distributed.config.tensor_rank == 0 + ): self.group = self._distributed.batch_data_group else: self.group = torch.distributed.GroupMember.NON_GROUP_MEMBER @@ -71,13 +75,6 @@ def __init__( self.tokenizer = lm_eval.models.utils.configure_pad_token(self.tokenizer, model_config=self.config) self.add_bos_token = add_bos_token - # TODO: do we support gemma models? - if "gemma" in getattr(self.config, "model_type", ""): - self.add_bos_token = True - eval_logger.info( - f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS" - " token will be used as Gemma underperforms without it." 
- ) self._max_length = model._inference_runner._batch_config.sequence_length self.pretrained = model @@ -88,13 +85,49 @@ def __init__( self.batch_schedule = 1 self.batch_sizes = {} self.batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size - self.batch_size = self.batch_size_per_gpu * dist_config.batch_data_parallel + self.batch_size = self.batch_size_per_gpu * self._distributed.config.batch_data_parallel self.max_batch_size = self.batch_size self.custom_prefix_token_id = prefix_token_id if prefix_token_id is not None: eval_logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}") + def run(self, cli_args: list[str], completed_steps: int, run_index: int): + if self._distributed.config.rank == 0: + args, simple_eval_kwargs = prepare_lm_eval_simple_eval_params(cli_args, completed_steps, run_index) + simple_eval_kwargs["model"] = self + + # Needed for reporting as batch_size is set from args not lm for reporting in evaluate + simple_eval_kwargs["batch_size"] = self.batch_size + simple_eval_kwargs["max_batch_size"] = self.max_batch_size + + # As of lm_eval commit 758c5ed891b1ca48acd8d3a0d309a827215796b7 + # Expected to be a string even if empty and not None in simple_evaluate + simple_eval_kwargs["model_args"] = "" + + results = lm_eval.evaluator.simple_evaluate(**simple_eval_kwargs) + self.stop_workers() + + # Evaluation_tracker save expects model to be either string, but if model is passed + # LM wrapper needs to be deep copyable and json serializable + simple_eval_kwargs["evaluation_tracker"].general_config_tracker.model_source = ( + self._model.config.name_or_path + ) + + if results is not None: + process_lm_eval_results( + args, + results, + simple_eval_kwargs["evaluation_tracker"], + completed_steps, + ) + else: + self.worker_model_invoke() + + # TODO: do we need it here as self.stop_workers() and self.worker_model_invoke() + # already have barrier + safe_barrier(self._distributed.world_group, f"lm_eval Run end") + def _model_invoke( self, input_ids, @@ -109,15 +142,17 @@ def _model_invoke( # TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple. # Messages could include types like logits, generate, finished. 
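# Illustration: the branch below implements a leader/worker protocol — the group
# leader shards the batch, scatters one shard per rank, runs its own shard, then
# gathers and concatenates the per-rank results in rank order. A minimal sketch of
# that pattern using stock torch.distributed object collectives (the patch itself
# uses Fast-LLM's own gather_object/scatter_object_list variants); `sharded_call`
# and `run_shard` are hypothetical names, and an already-initialized default
# process group is assumed.
import torch.distributed as dist

def sharded_call(batch: list, run_shard):
    rank, world_size = dist.get_rank(), dist.get_world_size()
    if rank == 0:
        step = (len(batch) + world_size - 1) // world_size
        shards = [batch[i * step : (i + 1) * step] for i in range(world_size)]
    else:
        shards = None
    my_shard = [None]  # scatter_object_list writes the received shard into element 0
    dist.scatter_object_list(my_shard, shards, src=0)
    local_result = run_shard(my_shard[0])  # e.g. a forward pass over this shard (may be empty)
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local_result, gathered, dst=0)
    if rank == 0:
        return [item for part in gathered for item in part]  # flatten, preserving rank order
    return None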
- if self.group is None or (world_size := self.group.size()) == 1: + # Groups is always None if world size is 1 + if self.group is None: # Must not be called with continue_generate false on one process assert continue_generate return self._model_invoke_inner( input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs ) - rank = self.group.rank() - assert rank == 0 + world_size = self.group.size() + + assert self.group.rank() == 0 if continue_generate: assert input_ids is not None @@ -197,7 +232,8 @@ def worker_model_invoke(self): assert self.group is not None # if isinstance(self.group, dist.ProcessGroup): if not isinstance(self.group, int): - assert self.group.size() > 1 and self.group.rank() != 0 + # groups is None for world_size 1 + assert self.group.rank() != 0 # on worker ranks the function need to wait to be called multiple times while True: scatter_list = None From 6b747390ff3fbe5e5f59c8a1029417be04c5a0ca Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 2 Jul 2025 13:43:17 +0000 Subject: [PATCH 22/26] import change --- fast_llm/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/cli.py b/fast_llm/cli.py index 8cb9d912..98ea0037 100644 --- a/fast_llm/cli.py +++ b/fast_llm/cli.py @@ -14,9 +14,9 @@ # On systems with more cores, numexpr logs an error and # ignores the thread setting if it exceeds the limit. if "NUMEXPR_MAX_THREADS" not in os.environ: - import multiprocessing as mp + import multiprocessing - os.environ["NUMEXPR_MAX_THREADS"] = str(mp.cpu_count()) + os.environ["NUMEXPR_MAX_THREADS"] = str(multiprocessing.cpu_count()) # Import these submodules to ensure classes are added to the dynamic class registry. From c3984447ef726350cf4989316e988a2bcf65eb07 Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 2 Jul 2025 14:00:58 +0000 Subject: [PATCH 23/26] zero stage 3 inference warning added and TODO --- fast_llm/engine/inference/huggingface.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 3c2db428..54a82492 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -1,3 +1,4 @@ +import logging import os import pathlib import typing @@ -14,6 +15,8 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class HuggingfacePreTrainedModel(transformers.PreTrainedModel): config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig @@ -41,6 +44,8 @@ def __init__( # The HF constructor performs a deep copy of the config, # but config.fast_llm_config may contain non-picklable items like process groups. # Temporarily remove it before the call and restore it afterward. + # TODO: Find a clean solution — overriding __deepcopy__ doesn't work here + # because internally they use copy.deepcopy(self.__dict__). fast_llm_config = config.fast_llm_config config.fast_llm_config = None super().__init__(config, **kwargs) @@ -64,6 +69,11 @@ def __init__( with transformers.modeling_utils.no_init_weights(): self.post_init() + if fast_llm_model.config.multi_stage.zero_stage == 3: + logger.warning( + "zero_stage=3 is used for the model; forward and generate will be extremely slow during inference." 
+ ) + @classmethod def from_pretrained( cls, From 62846d22061b7f407fb6aec876967b1e856b0f82 Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 2 Jul 2025 14:16:44 +0000 Subject: [PATCH 24/26] removed docstrings --- fast_llm/core/distributed.py | 234 +---------------------------------- 1 file changed, 4 insertions(+), 230 deletions(-) diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index ffbeab39..185a4cbd 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -31,10 +31,6 @@ logger = logging.getLogger(__name__) -def _as_iterable(obj) -> collections.abc.Iterable: - return obj if isinstance(obj, list) else (obj,) - - def _check_single_tensor(param, param_name) -> None: """Check that the parameter ``param_name`` is a single tensor.""" if not isinstance(param, torch.Tensor): @@ -58,6 +54,10 @@ def _check_tensor_list(param, param_name) -> None: ) +def _as_iterable(obj) -> collections.abc.Iterable: + return obj if isinstance(obj, list) else (obj,) + + def _ensure_all_tensors_same_dtype(*tensors) -> None: last_dtype = None for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): @@ -102,33 +102,12 @@ def _object_to_tensor(obj, device, group): # See: https://github.com/pytorch/pytorch/issues/65696 byte_tensor = torch.ByteTensor(byte_storage).to(device) - # TODO: do we need to log this level of details? - # if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): - # backend = get_backend(group) - # if backend == Backend.NCCL: - # hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) - # logger.warning( - # "_object_to_tensor size: %s hash value: %s", - # byte_tensor.numel(), - # hash, - # ) - local_size = torch.LongTensor([byte_tensor.numel()]).to(device) return byte_tensor, local_size def _tensor_to_object(tensor, tensor_size, group): with torch.monitor._WaitCounter("pytorch.wait_counter.c10d._tensor_to_object").guard(): - - # TODO: do we need to log this level of details? - # if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): - # backend = get_backend(group) - # if backend == Backend.NCCL: - # hash = torch._C._distributed_c10d._hash_tensors([tensor]) - # logger.warning( - # "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash - # ) - tensor = tensor.cpu() buf = tensor.numpy().tobytes()[:tensor_size] return _unpickler(io.BytesIO(buf)).load() @@ -258,43 +237,6 @@ def gather( async_op: bool = False, group_dst: typing.Optional[int] = None, ): - """ - Gathers a list of tensors in a single process. - - This function requires all tensors to be the same size on each process. - - Args: - tensor (Tensor): Input tensor. - gather_list (list[Tensor], optional): List of appropriately, - same-sized tensors to use for gathered data - (default is None, must be specified on the destination rank) - group (ProcessGroup, optional): The process group to work on. - async_op (bool, optional): Whether this op should be an async op - group_dst (int, optional): Destination rank on ``group``. - - Returns: - Async work handle, if async_op is set to True. - None, if not async_op or if not part of the group - - .. note:: Note that all Tensors in gather_list must have the same size. - - Example:: - >>> # xdoctest: +SKIP("no rank") - >>> # We have 2 process groups, 2 ranks. 
- >>> tensor_size = 2 - >>> device = torch.device(f'cuda:{rank}') - >>> tensor = torch.ones(tensor_size, device=device) + rank - >>> if dist.get_rank() == 0: - >>> gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)] - >>> else: - >>> gather_list = None - >>> dist.gather(tensor, gather_list, dst=0) - >>> # Rank 0 gets gathered data. - >>> gather_list - [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0 - None # Rank 1 - - """ _check_single_tensor(tensor, "tensor") # Parameter ``gather_list`` may be left unspecified on non-dst ranks. @@ -334,50 +276,6 @@ def scatter( async_op: bool = False, group_src: typing.Optional[int] = None, ): - """ - Scatters a list of tensors to all processes in a group. - - Each process will receive exactly one tensor and store its data in the - ``tensor`` argument. - - Complex tensors are supported. - - Args: - tensor (Tensor): Output tensor. - scatter_list (list[Tensor]): List of tensors to scatter (default is - None, must be specified on the source rank) - group (ProcessGroup, optional): The process group to work on. - async_op (bool, optional): Whether this op should be an async op - group_src (int, optional): Source rank on ``group``. - - Returns: - Async work handle, if async_op is set to True. - None, if not async_op or if not part of the group - - .. note:: Note that all Tensors in scatter_list must have the same size. - - Example:: - >>> # xdoctest: +SKIP("need process group init") - >>> # Note: Process group initialization omitted on each rank. - >>> import torch.distributed as dist - >>> tensor_size = 2 - >>> device = torch.device(f'cuda:{rank}') - >>> output_tensor = torch.zeros(tensor_size, device=device) - >>> if dist.get_rank() == 0: - >>> # Assumes world_size of 2. - >>> # Only tensors, all of which must be the same size. - >>> t_ones = torch.ones(tensor_size, device=device) - >>> t_fives = torch.ones(tensor_size, device=device) * 5 - >>> scatter_list = [t_ones, t_fives] - >>> else: - >>> scatter_list = None - >>> dist.scatter(output_tensor, scatter_list, src=0) - >>> # Rank i gets scatter_list[i]. - >>> output_tensor - tensor([1., 1.], device='cuda:0') # Rank 0 - tensor([5., 5.], device='cuda:1') # Rank 1 - - """ _check_single_tensor(tensor, "tensor") # Parameter ``scatter_list`` may be left unspecified on non-src ranks. if scatter_list: @@ -425,72 +323,6 @@ def gather_object( group: typing.Optional[ProcessGroup] = None, group_dst: typing.Optional[int] = None, ): - """ - Gathers picklable objects from the whole group in a single process. - - Similar to :func:`gather`, but Python objects can be passed in. Note that the - object must be picklable in order to be gathered. - - Args: - current_device: (torch.device | str): device to use for object serialization to - tensor, must be this process assigned gpu for nccl backend. - obj (Any): Input object. Must be picklable. - object_gather_list (list[Any]): Output list. On the ``dst`` rank, it - should be correctly sized as the size of the group for this - collective and will contain the output. Must be ``None`` on non-dst - ranks. (default is ``None``) - dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). - (If both ``dst`` and ``group_dst`` are None, default is global rank 0) - group: (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - group_dst (int, optional): Destination rank on ``group``. 
Invalid to specify both ``dst`` and ``group_dst`` - - Returns: - None. On the ``dst`` rank, ``object_gather_list`` will contain the - output of the collective. - - .. note:: Note that this API differs slightly from the gather collective - since it does not provide an async_op handle and thus will be a blocking - call. - - .. note:: For NCCL-based processed groups, internal tensor representations - of objects must be moved to the GPU device before communication takes - place. In this case, the device used is given by - ``torch.cuda.current_device()`` and it is the user's responsiblity to - ensure that this is set so that each rank has an individual GPU, via - ``torch.cuda.set_device()``. - - .. warning:: - Object collectives have a number of serious performance and scalability - limitations. See :ref:`object_collectives` for details. - - .. warning:: - :func:`gather_object` uses ``pickle`` module implicitly, which is - known to be insecure. It is possible to construct malicious pickle data - which will execute arbitrary code during unpickling. Only call this - function with data you trust. - - .. warning:: - Calling :func:`gather_object` with GPU tensors is not well supported - and inefficient as it incurs GPU -> CPU transfer since tensors would be - pickled. Please consider using :func:`gather` instead. - - Example:: - >>> # xdoctest: +SKIP("need process group init") - >>> # Note: Process group initialization omitted on each rank. - >>> import torch.distributed as dist - >>> # Assumes world_size of 3. - >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object - >>> output = [None for _ in gather_objects] - >>> dist.gather_object( - ... gather_objects[dist.get_rank()], - ... output if dist.get_rank() == 0 else None, - ... dst=0 - ... ) - >>> # On rank 0 - >>> output - ['foo', 12, {1: 2}] - """ assert group is not None if group_dst is None: group_dst = 0 @@ -546,64 +378,6 @@ def scatter_object_list( group: typing.Optional[ProcessGroup] = None, group_src: typing.Optional[int] = None, ): - """ - Scatters picklable objects in ``scatter_object_input_list`` to the whole group. - - Similar to :func:`scatter`, but Python objects can be passed in. On - each rank, the scattered object will be stored as the first element of - ``scatter_object_output_list``. Note that all objects in - ``scatter_object_input_list`` must be picklable in order to be scattered. - - Args: - pg_device: (torch.device | str): device to use for object serialization to - tensor, must be this process assigned gpu for nccl backend. - scatter_object_output_list (List[Any]): Non-empty list whose first - element will store the object scattered to this rank. - scatter_object_input_list (List[Any], optional): List of input objects to scatter. - Each object must be picklable. Only objects on the ``src`` rank will - be scattered, and the argument can be ``None`` for non-src ranks. - group: (ProcessGroup, optional): The process group to work on. - group_src (int, optional): Source rank on ``group``. - - Returns: - ``None``. If rank is part of the group, ``scatter_object_output_list`` - will have its first element set to the scattered object for this rank. - - .. note:: Note that this API differs slightly from the scatter collective - since it does not provide an ``async_op`` handle and thus will be a - blocking call. - - .. warning:: - Object collectives have a number of serious performance and scalability - limitations. See :ref:`object_collectives` for details. - - .. 
warning:: - :func:`scatter_object_list` uses ``pickle`` module implicitly, which - is known to be insecure. It is possible to construct malicious pickle - data which will execute arbitrary code during unpickling. Only call this - function with data you trust. - - .. warning:: - Calling :func:`scatter_object_list` with GPU tensors is not well supported - and inefficient as it incurs GPU -> CPU transfer since tensors would be - pickled. Please consider using :func:`scatter` instead. - - Example:: - >>> # xdoctest: +SKIP("need process group init") - >>> # Note: Process group initialization omitted on each rank. - >>> import torch.distributed as dist - >>> if dist.get_rank() == 0: - >>> # Assumes world_size of 3. - >>> objects = ["foo", 12, {1: 2}] # any picklable object - >>> else: - >>> # Can be any list on non-src ranks, elements are not used. - >>> objects = [None, None, None] - >>> output_list = [None] - >>> dist.scatter_object_list(output_list, objects, src=0) - >>> # Rank i gets objects[i]. For example, on rank 2: - >>> output_list - [{1: 2}] - """ assert group is not None if group_src is None: group_src = 0 From e61cc3ed3212241b44467fc27f9817db4c0475cf Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 3 Jul 2025 09:48:18 +0000 Subject: [PATCH 25/26] removed unused fields, change generate call --- .../evaluation/lm_eval/fast_llm_wrapper.py | 30 ++++++------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 2babfad2..af3b441f 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -48,14 +48,8 @@ def __init__( else: self.group = torch.distributed.GroupMember.NON_GROUP_MEMBER - # TODO: clean code which does not used parts from HFLM - backend = "causal" - revision = "main" - delta = None - peft = None - # set some inputs which are expected in HFLM but are set by our model config - self.backend = backend + self.backend = "causal" # set tokenizer object assert isinstance(tokenizer, transformers.PreTrainedTokenizer) or isinstance( @@ -78,9 +72,6 @@ def __init__( self._max_length = model._inference_runner._batch_config.sequence_length self.pretrained = model - self.delta = delta - self.peft = peft - self.revision = revision self.batch_schedule = 1 self.batch_sizes = {} @@ -365,18 +356,15 @@ def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **g self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0] ) - kwargs = { - "input_ids": input_ids, - "max_length": max_length, - "stopping_criteria": stopping_criteria, - "pad_token_id": self.tokenizer.pad_token_id, - "use_cache": False, + return self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=False, **generation_kwargs, - } - if attention_mask is not None: - kwargs["attention_mask"] = attention_mask - - return self.model.generate(**kwargs) + ) @property def config(self): From 6a2ab3588adbfb13d553802e267a4e4d58c3cccd Mon Sep 17 00:00:00 2001 From: bigximik Date: Thu, 3 Jul 2025 14:29:19 +0000 Subject: [PATCH 26/26] changed to all fields to be private, removed properties which are used only internally, made max_lenght settable --- fast_llm/engine/evaluation/config.py | 6 + .../engine/evaluation/lm_eval/evaluator.py | 1 + .../evaluation/lm_eval/fast_llm_wrapper.py | 328 
+++++++++--------- 3 files changed, 170 insertions(+), 165 deletions(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 9f79b906..265e5f98 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -97,6 +97,12 @@ class EvaluatorLmEvalConfig(EvaluatorConfig): " passed to the Fast-LLM lm_eval model wrapper.", ) + max_length: int | None = Field( + default=None, + desc="Maximum sequence length including both prompt and newly generated tokens." + " If not set, it is inferred from the Fast-LLM model config or tokenizer.", + ) + def get_evaluator( self, name: str, diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 404549a3..3de3663e 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -67,6 +67,7 @@ def setup( logits_cache=self._config.logits_cache, add_bos_token=self._config.add_bos_token, prefix_token_id=self._config.prefix_token_id, + max_length=self._config.max_length, ) self._is_setup = True diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index af3b441f..ed42d464 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -16,12 +16,14 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM +from fast_llm.layers.transformer.rotary.config import NoRotaryConfig eval_logger = logging.getLogger(__name__) class FastLLMLmEvalWrapper(lm_eval.api.model.TemplateLM): _DEFAULT_MAX_LENGTH = 2048 + _DEFAULT_MAX_GEN_TOKENS = 256 def __init__( self, @@ -31,57 +33,113 @@ def __init__( logits_cache: bool = True, add_bos_token: bool | None = False, prefix_token_id: int | None = None, + max_length: int | None = None, ): super().__init__() - # This is for lm_eval sake, we always run lm_eval on one main rank - self._rank = 0 - self._world_size = 1 + # === Distributed setup === + self._rank = 0 # For lm_eval: always run on main rank + self._world_size = 1 self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed - # get batch_data_parallel group leaders + if ( self._distributed.config.sequence_data_rank == 0 and self._distributed.config.pipeline_rank == 0 and self._distributed.config.tensor_rank == 0 ): - self.group = self._distributed.batch_data_group + self._group = self._distributed.batch_data_group else: - self.group = torch.distributed.GroupMember.NON_GROUP_MEMBER + self._group = torch.distributed.GroupMember.NON_GROUP_MEMBER - # set some inputs which are expected in HFLM but are set by our model config - self.backend = "causal" + # === Model & tokenizer setup === + self._model = model + self._device = model.device + self._config = model.config + + assert isinstance(tokenizer, (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast)) + self._tokenizer = tokenizer + self._tokenizer = lm_eval.models.utils.configure_pad_token(self._tokenizer, model_config=self._config) + + # === Generation/configuration parameters === + self._truncation = truncation + self._logits_cache = logits_cache + self._add_bos_token = add_bos_token + self._max_length = max_length + self._custom_prefix_token_id = prefix_token_id + if 
prefix_token_id is not None: + eval_logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}") - # set tokenizer object - assert isinstance(tokenizer, transformers.PreTrainedTokenizer) or isinstance( - tokenizer, transformers.PreTrainedTokenizerFast - ) - self.tokenizer = tokenizer + # === Internal constants === + self._backend = "causal" + self._vocab_size = self._tokenizer.vocab_size - # initialize model fields - self._model = model - self._device = self._model.device - self._config = self._model.config + # === Batch configuration === + self._batch_schedule = 1 + self._batch_sizes = {} # Not used dynamically by lm_eval + self._batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size + self._batch_size = self._batch_size_per_gpu * self._distributed.config.batch_data_parallel + self._max_batch_size = self._batch_size - self.truncation = truncation - self.logits_cache = logits_cache - self.vocab_size = self.tokenizer.vocab_size - # select (or create) a pad token to use - self.tokenizer = lm_eval.models.utils.configure_pad_token(self.tokenizer, model_config=self.config) + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self._tokenizer.eos_token_id - self.add_bos_token = add_bos_token + # overrides from TemplateLM, but not used externally + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self._custom_prefix_token_id is not None: + return self._custom_prefix_token_id + if self._tokenizer.bos_token_id is not None: + return self._tokenizer.bos_token_id + return self._tokenizer.eos_token_id - self._max_length = model._inference_runner._batch_config.sequence_length - self.pretrained = model + @property + def max_length(self): + # if max length manually set, return it + if self._max_length: + return self._max_length - self.batch_schedule = 1 - self.batch_sizes = {} - self.batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size - self.batch_size = self.batch_size_per_gpu * self._distributed.config.batch_data_parallel - self.max_batch_size = self.batch_size + # check if it is absolute positional encoding and return max_position_embeddings + if hasattr(self._config.fast_llm_config.base_model, "transformer"): + # NOTE: will need to extend if more relative encoding types will be added + if isinstance(self._config.fast_llm_config.base_model.transformer.rotary, NoRotaryConfig): + return self._config.fast_llm_config.base_model.max_position_embeddings - self.custom_prefix_token_id = prefix_token_id - if prefix_token_id is not None: - eval_logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}") + # check if tokenizer holds model sequence leigh info + if hasattr(self._tokenizer, "model_max_length"): + if self._tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self._tokenizer.model_max_length + + # finally try to get sequence length from batch config + if hasattr(self._model._inference_runner._batch_config, "sequence_length"): + return self._model._inference_runner._batch_config.sequence_length + + return self._DEFAULT_MAX_LENGTH + + # @property + # def device(self): + # # only used for world_size when lm_eval world size > 1 and + # # should not be called with current lm_eval support implementation + # return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return 
self._world_size + + @property + def tokenizer(self): + return self._tokenizer + + @property + def tokenizer_name(self) -> str: + return self._tokenizer.name_or_path.replace("/", "__") def run(self, cli_args: list[str], completed_steps: int, run_index: int): if self._distributed.config.rank == 0: @@ -89,8 +147,8 @@ def run(self, cli_args: list[str], completed_steps: int, run_index: int): simple_eval_kwargs["model"] = self # Needed for reporting as batch_size is set from args not lm for reporting in evaluate - simple_eval_kwargs["batch_size"] = self.batch_size - simple_eval_kwargs["max_batch_size"] = self.max_batch_size + simple_eval_kwargs["batch_size"] = self._batch_size + simple_eval_kwargs["max_batch_size"] = self._max_batch_size # As of lm_eval commit 758c5ed891b1ca48acd8d3a0d309a827215796b7 # Expected to be a string even if empty and not None in simple_evaluate @@ -134,16 +192,16 @@ def _model_invoke( # Messages could include types like logits, generate, finished. # Groups is always None if world size is 1 - if self.group is None: + if self._group is None: # Must not be called with continue_generate false on one process assert continue_generate return self._model_invoke_inner( input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs ) - world_size = self.group.size() + world_size = self._group.size() - assert self.group.rank() == 0 + assert self._group.rank() == 0 if continue_generate: assert input_ids is not None @@ -151,8 +209,8 @@ def _model_invoke( assert max_length is not None and stop is not None # always divide by batch_size, if not full batch, some ranks will get less work or not at all - assert self.batch_size % world_size == 0 - step = self.batch_size // world_size + assert self._batch_size % world_size == 0 + step = self._batch_size // world_size input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)] attention_mask = [ @@ -183,7 +241,7 @@ def _model_invoke( obj_list, scatter_list, group_src=0, - group=self.group, + group=self._group, ) input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = tuple( obj_list[0] @@ -204,7 +262,7 @@ def _model_invoke( res, gather_list, group_dst=0, - group=self.group, + group=self._group, ) # Clean gather list from empty shards gather_list = [el for el in gather_list if len(el) > 0] @@ -220,11 +278,11 @@ def _model_invoke( return res def worker_model_invoke(self): - assert self.group is not None + assert self._group is not None # if isinstance(self.group, dist.ProcessGroup): - if not isinstance(self.group, int): + if not isinstance(self._group, int): # groups is None for world_size 1 - assert self.group.rank() != 0 + assert self._group.rank() != 0 # on worker ranks the function need to wait to be called multiple times while True: scatter_list = None @@ -234,7 +292,7 @@ def worker_model_invoke(self): obj_list, scatter_list, group_src=0, - group=self.group, + group=self._group, ) input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( tuple(obj_list[0]) @@ -257,15 +315,15 @@ def worker_model_invoke(self): res, gather_list, group_dst=0, - group=self.group, + group=self._group, ) else: # TODO: implement distributed model support - assert self.group == torch.distributed.GroupMember.NON_GROUP_MEMBER + assert self._group == torch.distributed.GroupMember.NON_GROUP_MEMBER safe_barrier(self._distributed.world_group, "lm_eval_end") def stop_workers(self): - if self.group is None or (world_size := self.group.size()) 
== 1: + if self._group is None or (world_size := self._group.size()) == 1: return self._model_invoke(None, None, None, None, None, None, continue_generate=False) safe_barrier(self._distributed.world_group, "lm_eval_end") @@ -324,7 +382,7 @@ def _model_call_inner(self, input_ids, attention_mask=None, labels=None): # and pass only the results around instead of the full logits. # Computing errors here is also preferable, as copying logits across nodes and GPUs # is inefficient and can involve gigabytes of data. - return self.model( + return self._model( input_ids=input_ids, attention_mask=attention_mask, labels=labels, @@ -353,81 +411,19 @@ def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **g generation_kwargs.pop("temperature") # build stopping criteria stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( - self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0] + self._tokenizer, stop, input_ids.shape[1], input_ids.shape[0] ) - return self.model.generate( + return self._model.generate( input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, stopping_criteria=stopping_criteria, - pad_token_id=self.tokenizer.pad_token_id, + pad_token_id=self._tokenizer.pad_token_id, use_cache=False, **generation_kwargs, ) - @property - def config(self): - # return the associated transformers.AutoConfig for the given pretrained model. - return self._config - - @property - def model(self): - return self._model - - @property - def eot_token_id(self): - # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* - return self.tokenizer.eos_token_id - - @property - def prefix_token_id(self): - # it is used as prefix for loglikelihood - if self.custom_prefix_token_id is not None: - return self.custom_prefix_token_id - if self.tokenizer.bos_token_id is not None: - return self.tokenizer.bos_token_id - return self.tokenizer.eos_token_id - - @property - def max_length(self): - if self._max_length: # if max length manually set, return it - return self._max_length - seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") - for attr in seqlen_config_attrs: - if hasattr(self.model.config, attr): - return getattr(self.model.config, attr) - if hasattr(self.tokenizer, "model_max_length"): - if self.tokenizer.model_max_length == 1000000000000000019884624838656: - return self._DEFAULT_MAX_LENGTH - return self.tokenizer.model_max_length - return self._DEFAULT_MAX_LENGTH - - @property - def max_gen_toks(self) -> int: - return 256 - - # TODO: check removing this does not affect lm_eval - # @property - # def batch_size(self): - # return self.batch_size_per_gpu - - @property - def device(self): - return self._device - - @property - def rank(self): - return self._rank - - @property - def world_size(self): - return self._world_size - - @property - def tokenizer_name(self) -> str: - return self.tokenizer.name_or_path.replace("/", "__") - def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> list[int]: """ """ # default for None - empty dict, use predefined tokenizer param @@ -436,13 +432,13 @@ def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=Non # by default for CausalLM - false or self.add_bos_token is set if add_special_tokens is None: - if self.backend == "causal": - special_tokens_kwargs = {"add_special_tokens": False or self.add_bos_token} + if self._backend == "causal": + special_tokens_kwargs = {"add_special_tokens": False or self._add_bos_token} # 
otherwise the method explicitly defines the value else: special_tokens_kwargs = {"add_special_tokens": add_special_tokens} - encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + encoding = self._tokenizer.encode(string, **special_tokens_kwargs) # left-truncate the encoded context to be at most `left_truncate_len` tokens long if left_truncate_len: @@ -458,14 +454,14 @@ def tok_batch_encode( truncation: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. - old_padding_side = self.tokenizer.padding_side - self.tokenizer.padding_side = padding_side + old_padding_side = self._tokenizer.padding_side + self._tokenizer.padding_side = padding_side add_special_tokens = {} - if self.backend == "causal": - add_special_tokens = {"add_special_tokens": False or self.add_bos_token} + if self._backend == "causal": + add_special_tokens = {"add_special_tokens": False or self._add_bos_token} - encoding = self.tokenizer( + encoding = self._tokenizer( strings, truncation=truncation, padding="longest", @@ -481,20 +477,20 @@ def tok_batch_encode( ) encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] - self.tokenizer.padding_side = old_padding_side + self._tokenizer.padding_side = old_padding_side return encoding["input_ids"], encoding["attention_mask"] def tok_decode(self, tokens, skip_special_tokens=True): - return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) def _select_cont_toks(self, logits: torch.Tensor, contlen: int = None, inplen: int = None) -> torch.Tensor: - if self.backend == "causal": + if self._backend == "causal": assert contlen and inplen, "Must pass input len and cont. len to select scored logits for causal LM" # discard right-padding. # also discard the input/context tokens. we'll only score continuations. logits = logits[inplen - contlen : inplen] - elif self.backend == "seq2seq": + elif self._backend == "seq2seq": assert contlen and not inplen, "Selecting scored logits for Seq2SeqLM requires only cont. len" # only discard right-padding. # the logits input to this fn only contain decoder-side tokens. @@ -506,7 +502,7 @@ def loglikelihood_rolling( self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False ) -> list[float]: adaptive_batch_size = None - if self.batch_size == "auto": + if self._batch_size == "auto": # using rolling window with maximum context print("Passed argument batch_size = auto. Detecting largest batch size") batch_size = self._detect_batch_size() @@ -523,6 +519,8 @@ def loglikelihood_rolling( disable=(disable_tqdm or (self.rank != 0)), ) ): + # The tokenizer may raise: "Token indices sequence length is longer than the specified maximum sequence length for this model" + # This is expected and fine, as the sequence will be split into chunks of max_length later. 
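# Illustration: loglikelihood_rolling scores every token of a document by sliding a
# max_length window over it; each window re-uses at most `context_len` tokens of
# context and predicts the rest, so the whole sequence is covered without re-scoring.
# A small, self-contained sketch of the two lm_eval helpers used just below (token
# ids and window sizes are made up; the real call uses self.prefix_token_id and
# self.max_length):
import lm_eval.utils

tokens = list(range(10))  # pretend this is a tokenized document
windows = list(
    map(
        lm_eval.utils.make_disjoint_window,
        lm_eval.utils.get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-1,  # stand-in for the BOS/EOS prefix token id
            max_seq_len=4,
            context_len=1,
        ),
    )
)
for context, continuation in windows:
    # each document token lands in exactly one `continuation`; `context` keeps only
    # what is needed to condition the first predicted token of that window
    print(context, continuation)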
rolling_token_windows: list[tuple[list[int], list[int]]] = list( map( lm_eval.utils.make_disjoint_window, @@ -545,14 +543,14 @@ def loglikelihood_rolling( # Handle distributed case padding pad_amnt = 0 if self.world_size > 1: - mytensor = torch.tensor(len(all_windows), device=self.device) + mytensor = torch.tensor(len(all_windows), device=self._device) gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() pad_amnt = max(gathered) - gathered[self.rank] if pad_amnt > 0: all_windows += pad_amnt * [all_windows[0]] all_nlls = [] - batch_size = adaptive_batch_size or self.batch_size + batch_size = adaptive_batch_size or self._batch_size for i in range(0, len(all_windows), batch_size): batch = all_windows[i : i + batch_size] # Extract just the windows for processing, keeping track of request indices @@ -587,17 +585,17 @@ def loglikelihood_rolling( return loglikelihoods def _batch_scheduler(self, pos, n_reordered_requests): - sched = pos // int(len(n_reordered_requests) / self.batch_schedule) - if sched in self.batch_sizes: - return self.batch_sizes[sched] - if (len(self.batch_sizes) > 1) and (self.batch_sizes[sched - 1] == self.max_batch_size): + sched = pos // int(len(n_reordered_requests) / self._batch_schedule) + if sched in self._batch_sizes: + return self._batch_sizes[sched] + if (len(self._batch_sizes) > 1) and (self._batch_sizes[sched - 1] == self._max_batch_size): # if previous batch size is already maximal, skip recomputation - self.batch_sizes[sched] = self.max_batch_size - return self.batch_sizes[sched] - print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size") - self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) - print(f"Determined largest batch size: {self.batch_sizes[sched]}") - return self.batch_sizes[sched] + self._batch_sizes[sched] = self._max_batch_size + return self._batch_sizes[sched] + print(f"Passed argument batch_size = auto:{self._batch_schedule}. 
Detecting largest batch size") + self._batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self._batch_sizes[sched]}") + return self._batch_sizes[sched] def _loglikelihood_tokens( self, @@ -628,17 +626,17 @@ def _collate(req: tuple[tuple[str, str], list[int], list[int]]): re_ord = lm_eval.models.utils.Collator( requests, sort_fn=_collate, - group_by="contexts" if self.backend == "causal" and self.logits_cache else None, + group_by="contexts" if self._backend == "causal" and self._logits_cache else None, group_fn=lambda req: req[-2] + req[-1][:-1], ) # automatic (variable) batch size detection for vectorization # pull longest context sample from request n_reordered_requests = len(re_ord) - batch_size = self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0 + batch_size = self._batch_size if self._batch_size != "auto" else override_bs if override_bs is not None else 0 batch_fn = ( self._batch_scheduler - if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs + if self._batch_size == "auto" and n_reordered_requests > 0 and not override_bs else None ) @@ -676,10 +674,10 @@ def _collate(req: tuple[tuple[str, str], list[int], list[int]]): # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice # when too long to fit in context, truncate from the left - if self.backend == "causal": + if self._backend == "causal": total_length = len(context_enc) + len(continuation_enc) if total_length > self.max_length + 1: - eval_logger.warn( + eval_logger.warning( f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " f"exceeds model's maximum length ({self.max_length}). " f"Truncating {total_length - self.max_length + 1} tokens from the left." @@ -687,14 +685,14 @@ def _collate(req: tuple[tuple[str, str], list[int], list[int]]): inp = torch.tensor( (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], dtype=torch.long, - device=self.device, + device=self._device, ) (inplen,) = inp.shape - elif self.backend == "seq2seq": + elif self._backend == "seq2seq": inp = torch.tensor( (context_enc)[-self.max_length :], dtype=torch.long, - device=self.device, + device=self._device, ) (inplen,) = inp.shape @@ -706,7 +704,7 @@ def _collate(req: tuple[tuple[str, str], list[int], list[int]]): # TODO: left-shift these? # TODO: our code assumes we never end up truncating conts for either model type dtype=torch.long, - device=self.device, + device=self._device, ) (contlen,) = cont.shape @@ -722,11 +720,11 @@ def _collate(req: tuple[tuple[str, str], list[int], list[int]]): # create encoder attn mask and batched conts, if seq2seq call_kwargs = {} - if self.backend == "causal": + if self._backend == "causal": batched_inps = lm_eval.models.utils.pad_and_concat( padding_len_inp, inps, padding_side="right" ) # [batch, padding_len_inp] - elif self.backend == "seq2seq": + elif self._backend == "seq2seq": # TODO: left-pad encoder inps and mask? 
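# Illustration: pad_and_concat stacks 1-D token tensors of different lengths into a
# single [batch, max_len] tensor by padding every row out to `max_len`. Roughly the
# plain-torch sketch below (`pad_and_concat_sketch` is a hypothetical name; the pad
# value does not affect scores here because padded positions are not read back —
# for the causal path they are sliced away in _select_cont_toks):
import torch
import torch.nn.functional as F

def pad_and_concat_sketch(max_len: int, rows: list[torch.Tensor], padding_side: str = "right") -> torch.Tensor:
    padded = []
    for row in rows:
        pad = max_len - row.shape[0]
        # F.pad takes (left, right) pad amounts for the last dimension
        padded.append(F.pad(row, (0, pad) if padding_side == "right" else (pad, 0)))
    return torch.stack(padded)  # [batch, max_len]

# e.g. pad_and_concat_sketch(4, [torch.tensor([1, 2]), torch.tensor([3, 4, 5])])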
batched_inps = lm_eval.models.utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp] batched_conts = lm_eval.models.utils.pad_and_concat( @@ -755,7 +753,7 @@ def _collate(req: tuple[tuple[str, str], list[int], list[int]]): # (discard context toks if decoder-only ; discard right-padding) # also discards + checks for "virtual tokens" in the causal LM's input window # from prompt/prefix tuning tokens, if applicable - ctx_len = inplen + (logits.shape[0] - padding_len_inp) if self.backend == "causal" else None + ctx_len = inplen + (logits.shape[0] - padding_len_inp) if self._backend == "causal" else None logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) logits = logits.unsqueeze(0) # [1, seq, vocab] @@ -807,7 +805,7 @@ def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_ desc="Running generate_until requests", ) adaptive_batch_size = None - if self.batch_size == "auto": + if self._batch_size == "auto": # using rolling window with maximum context print("Passed argument batch_size = auto. Detecting largest batch size") batch_size = self._detect_batch_size() @@ -815,11 +813,11 @@ def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_ adaptive_batch_size = batch_size # for each different set of kwargs, we execute all requests, by batch. batch_size = ( - self.batch_size - if self.batch_size != "auto" + self._batch_size + if self._batch_size != "auto" else adaptive_batch_size if adaptive_batch_size is not None else 0 ) - batch_fn = self._batch_scheduler if self.batch_size == "auto" and not adaptive_batch_size else None + batch_fn = self._batch_scheduler if self._batch_size == "auto" and not adaptive_batch_size else None # we group requests by their generation_kwargs, # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling @@ -855,16 +853,16 @@ def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_ if "max_gen_toks" in kwargs.keys(): max_gen_toks = kwargs.pop("max_gen_toks") else: - max_gen_toks = self.max_gen_toks + max_gen_toks = self._DEFAULT_MAX_GEN_TOKENS # set the max length in tokens of inputs ("context_enc") - if self.backend == "causal": + if self._backend == "causal": # max len for inputs = max length, minus room to generate the max new tokens max_ctx_len = self.max_length - max_gen_toks assert ( max_ctx_len > 0 ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." 
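# Illustration of the budget computed above and applied just below via
# left_truncate_len=max_ctx_len: the prompt is left-truncated so that prompt plus
# newly generated tokens always fit in the model window (numbers are made up):
max_length, max_gen_toks = 4096, 256      # model window and generation budget
max_ctx_len = max_length - max_gen_toks   # 3840 tokens left for the prompt
prompt_ids = list(range(5000))            # pretend token ids of an over-long prompt
prompt_ids = prompt_ids[-max_ctx_len:]    # keep only the rightmost 3840 tokens
assert len(prompt_ids) + max_gen_toks <= max_length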
- elif self.backend == "seq2seq": + elif self._backend == "seq2seq": # max len for inputs = encoder's whole max_length max_ctx_len = self.max_length @@ -872,10 +870,10 @@ def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_ input_ids, attention_mask = self.tok_batch_encode( contexts, left_truncate_len=max_ctx_len, - truncation=self.truncation, + truncation=self._truncation, ) - input_ids = input_ids.to(self.device) - attention_mask = attention_mask.to(self.device) + input_ids = input_ids.to(self._device) + attention_mask = attention_mask.to(self._device) if "max_length" not in kwargs: kwargs["max_length"] = input_ids.shape[1] + max_gen_toks @@ -893,7 +891,7 @@ def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_ for cont_toks, context in zip(cont_toks_list, contexts): # discard context + left-padding toks if using causal decoder-only LM - if self.backend == "causal": + if self._backend == "causal": cont_toks = cont_toks[input_ids.shape[1] :] s = self.tok_decode(cont_toks) @@ -921,7 +919,7 @@ def apply_chat_template(self, chat_history: list[dict[str, str]], add_generation Method to apply a chat template to a list of chat history between user and model. """ try: - chat_templated = self.tokenizer.apply_chat_template( + chat_templated = self._tokenizer.apply_chat_template( chat_history, tokenize=False, add_generation_prompt=add_generation_prompt, @@ -930,7 +928,7 @@ def apply_chat_template(self, chat_history: list[dict[str, str]], add_generation except jinja2.exceptions.TemplateError: eval_logger.warning("Failed to apply chat template. removing the system role in chat history.") chat_history = [msg for msg in chat_history if msg["role"] != "system"] - chat_templated = self.tokenizer.apply_chat_template( + chat_templated = self._tokenizer.apply_chat_template( chat_history, tokenize=False, add_generation_prompt=add_generation_prompt,