diff --git a/docs/user_guide/evaluators.md b/docs/user_guide/evaluators.md new file mode 100644 index 00000000..8bc6cb21 --- /dev/null +++ b/docs/user_guide/evaluators.md @@ -0,0 +1,92 @@ +# Evaluations + +Fast-LLM allows you to perform various evaluations during training or as a separate evaluation step. In both cases, you need to use your training config with `training.evaluators` specified. + +For evaluators used during training, both `interval` and `offset` must be specified. Then, start training as usual with: + +`fast-llm train gpt --config path/to/training/config.yaml` + +To perform evaluation as a separate step, use the same training config. Depending on the training progress, either the start model or the latest checkpoint will be loaded, and `interval` and `offset` will be ignored. To start evaluation: + +`fast-llm evaluate gpt --config path/to/training/config.yaml` + +## Currently Supported Evaluators + +- `loss` +- `lm_eval` + +## Loss Evaluator + +To set up loss evaluation, specify a dataset to be used in the `data.datasets` section of the config. You must also define the loss evaluator in the `training.evaluators` config section. See example below. + +```yaml +training: + evaluations: + stack_3b: + interval: 10 + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + fineweb: + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + interval: 10 +data: + datasets: + stack_3b: + type: memmap + path: path/to/memmap/dataset + fineweb: + type: memmap + path: path/to/memmap/dataset1 +``` + +## Evaluation Harness (`lm_eval`) Evaluator + +**Note:** Only data parallelism is currently supported for the `lm_eval` evaluator. + +To run `lm_eval` evaluations, version `0.4.9` of `lm_eval` must be installed along with all dependencies required for your evaluation tasks. + +The following environment variables may need to be set: + +- `HF_HOME`: Path for Hugging Face data caching +- `WANDB_API_KEY_PATH`: Path to a file containing your Weights & Biases API key (if logging to W&B) +- `HUGGINGFACE_API_KEY_PATH`: Path to a file containing your Hugging Face hub token +- `NLTK_DATA`: Path to a directory that will contain downloaded NLTK packages (needed for some tasks) +- `HF_ALLOW_CODE_EVAL=1`: Required for some evaluation tasks + +You may need to specify additional environment variables depending on the `lm_eval` tasks you want to run. + +To specify an `lm_eval` task, the evaluator config includes the following fields: + +### Model Config + +The model instantiated for training is reused for evaluation, so you don't need to specify it separately. However, there are some parameters specific to `lm_eval`. See `fast_llm/engine/evaluation/config.EvaluatorLmEvalConfig` for details. + +### CLI Parameters for `lm_eval` + +All other parameters are specified as if you were calling the `lm_eval` CLI, using a list of strings. Some CLI parameters are ignored or restricted—specifically those related to model loading, W&B, batch sizes, and device setup, as these are managed by the rest of the Fast-LLM configuration. + +Also, the tokenizer must be specified in `data.tokenizer`. If the tokenizer does not have a `bos_token`, it must be specified explicitly in `data.tokenizer.bos_token`. Although `lm_eval` does not use the `bos_token` directly, it is still required because the same tokenizer is used by other Fast-LLM components. + +Below is an example of the config: + +```yaml +training: + evaluations: + lm_eval_tasks1: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k,xnli_en,wikitext,ifeval + - --output_path + - /path/to/lm_eval/output +data: + tokenizer: + path: path/to/the/tokenizer +``` diff --git a/fast_llm/cli.py b/fast_llm/cli.py index 66ce096d..98ea0037 100644 --- a/fast_llm/cli.py +++ b/fast_llm/cli.py @@ -1,5 +1,6 @@ import contextlib import logging +import os import sys import traceback @@ -8,6 +9,16 @@ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.runnable import RunnableConfig +# This must be set before importing numexpr, +# because by default, the maximum number of threads is 64. +# On systems with more cores, numexpr logs an error and +# ignores the thread setting if it exceeds the limit. +if "NUMEXPR_MAX_THREADS" not in os.environ: + import multiprocessing + + os.environ["NUMEXPR_MAX_THREADS"] = str(multiprocessing.cpu_count()) + + # Import these submodules to ensure classes are added to the dynamic class registry. import fast_llm.data.auto # isort: skip import fast_llm.engine.checkpoint.convert # isort: skip diff --git a/fast_llm/config.py b/fast_llm/config.py index 0004501b..c534b11f 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -735,7 +735,7 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: "Config| dict[str, typing.Any]]", + default: "Config| dict[str, typing.Any]", *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index e82e0801..185a4cbd 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -6,12 +6,17 @@ Todo: Move all core methods elsewhere (functional?). """ +import collections import contextlib import datetime +import io +import itertools import logging +import pickle import typing import torch +import torch.monitor from torch._C._distributed_c10d import Work from torch.distributed import ( # noqa ProcessGroup, @@ -26,6 +31,96 @@ logger = logging.getLogger(__name__) +def _check_single_tensor(param, param_name) -> None: + """Check that the parameter ``param_name`` is a single tensor.""" + if not isinstance(param, torch.Tensor): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor + but got {type(param)} instead.""" + ) + + +def _check_tensor_list(param, param_name) -> None: + """Check that the parameter ``param_name`` is a list of tensors.""" + if not isinstance(param, list): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} instead.""" + ) + elif not all(isinstance(p, torch.Tensor) for p in param): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} with elements of type {[type(p) for p in param]}.""" + ) + + +def _as_iterable(obj) -> collections.abc.Iterable: + return obj if isinstance(obj, list) else (obj,) + + +def _ensure_all_tensors_same_dtype(*tensors) -> None: + last_dtype = None + for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): + tensor_dtype = tensor.dtype + # Mixing complex and its element type is allowed + if tensor_dtype.is_complex: + tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + + if last_dtype is None: + last_dtype = tensor_dtype + else: + if last_dtype != tensor_dtype: + raise ValueError( + "Invalid usage of tensors with different dtypes" f"Found {last_dtype} and {tensor.dtype}" + ) + + +def _rank_not_in_group(group: typing.Optional[ProcessGroup]) -> bool: + """Check if the current process's rank is not in a given group.""" + if group is None: + return False + return group == torch.distributed.GroupMember.NON_GROUP_MEMBER + + +def _warn_not_in_group(op_name) -> None: + # TODO: get global rank + global_rank = -1 + logger.warning(f"Running {op_name} on global rank {global_rank} which does not " "belong to the given group.") + + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +def _object_to_tensor(obj, device, group): + with torch.monitor._WaitCounter("pytorch.wait_counter.c10d._object_to_tensor").guard(): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. + # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage).to(device) + + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size, group): + with torch.monitor._WaitCounter("pytorch.wait_counter.c10d._tensor_to_object").guard(): + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +def _validate_output_list_for_rank(my_rank, dst, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError("Argument ``gather_list`` must be specified on destination rank.") + elif gather_list: + raise ValueError("Argument ``gather_list`` must NOT be specified on non-destination ranks.") + + def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None: if group is not None and timeout is not None: # TODO: Only works for nccl? @@ -133,3 +228,201 @@ def set_generator(generator: torch.Generator) -> typing.Generator[None, None, No finally: generator.set_state(default_generator.get_state()) default_generator.set_state(old_state) + + +def gather( + tensor: torch.Tensor, + gather_list: typing.Optional[list[torch.Tensor]] = None, + group: typing.Optional[ProcessGroup] = None, + async_op: bool = False, + group_dst: typing.Optional[int] = None, +): + _check_single_tensor(tensor, "tensor") + + # Parameter ``gather_list`` may be left unspecified on non-dst ranks. + if gather_list: + _check_tensor_list(gather_list, "gather_list") + else: + gather_list = [] + _ensure_all_tensors_same_dtype(tensor, gather_list) + assert group is not None + if _rank_not_in_group(group): + _warn_not_in_group("gather") + return + if group_dst is None: + group_dst = 0 + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, gather_list) + output_tensors = [gather_list] if group_dst == my_group_rank else [] + input_tensors = [tensor] + + opts = torch.distributed.GatherOptions() + opts.rootRank = group_dst + # Absent in ver 2.6 + # opts.asyncOp = async_op + work = group.gather(output_tensors, input_tensors, opts) + + if async_op: + return work + elif work is not None: # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def scatter( + tensor: torch.Tensor, + scatter_list: typing.Optional[list[torch.Tensor]] = None, + group: typing.Optional[ProcessGroup] = None, + async_op: bool = False, + group_src: typing.Optional[int] = None, +): + _check_single_tensor(tensor, "tensor") + # Parameter ``scatter_list`` may be left unspecified on non-src ranks. + if scatter_list: + _check_tensor_list(scatter_list, "scatter_list") + else: + scatter_list = [] + _ensure_all_tensors_same_dtype(tensor, scatter_list) + assert group is not None + if group_src is None: + group_src = 0 + if _rank_not_in_group(group): + _warn_not_in_group("scatter") + return + scatter_list = [t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + my_group_rank = group.rank() + if group_src == my_group_rank: + if not scatter_list: + raise ValueError("Argument ``scatter_list`` must be specified on source rank.") + input_tensors = [scatter_list] + output_tensors = [tensor] + else: + if scatter_list: + raise ValueError("Argument ``scatter_list`` must NOT be specified on non-source ranks.") + input_tensors = [] + output_tensors = [tensor] + + opts = torch.distributed.ScatterOptions() + opts.rootRank = group_src + opts.asyncOp = async_op + work = group.scatter(output_tensors, input_tensors, opts) + + if async_op: + return work + elif work is not None: # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def gather_object( + current_device: torch.device | str, + obj: typing.Any, + object_gather_list: typing.Optional[list[typing.Any]] = None, + group: typing.Optional[ProcessGroup] = None, + group_dst: typing.Optional[int] = None, +): + assert group is not None + if group_dst is None: + group_dst = 0 + if _rank_not_in_group(group): + _warn_not_in_group("gather_object") + return + + # Ensure object_gather_list is specified appropriately. + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, object_gather_list) + input_tensor, local_size = _object_to_tensor(obj, current_device, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = group.size() + object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device) + object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_group_rank == group_dst: + coalesced_output_tensor = torch.empty(max_object_size * group_size, dtype=torch.uint8, device=current_device) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. + gather( + input_tensor, + gather_list=output_tensors if my_group_rank == group_dst else None, # type: ignore[possibly-undefined] + group_dst=group_dst, + group=group, + ) + if my_group_rank != group_dst: + return + + assert object_gather_list is not None, "Must provide object_gather_list on dst rank" + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +def scatter_object_list( + pg_device: torch.device | str, + scatter_object_output_list: list[typing.Any], + scatter_object_input_list: typing.Optional[list[typing.Any]] = None, + group: typing.Optional[ProcessGroup] = None, + group_src: typing.Optional[int] = None, +): + assert group is not None + if group_src is None: + group_src = 0 + if _rank_not_in_group(group): + _warn_not_in_group("scatter_object_list") + return + + if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1: + raise ValueError("Expected argument scatter_object_output_list to be a list of size at least 1.") + + my_group_rank = group.rank() + if my_group_rank == group_src: + if scatter_object_input_list is None: + raise ValueError("source rank must provide non-None scatter_object_input_list") + tensor_list, tensor_sizes = zip( + *[_object_to_tensor(obj, pg_device, group) for obj in scatter_object_input_list] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined] + for tensor in tensor_list: # type: ignore[possibly-undefined] + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + broadcast(max_tensor_size, src=group_src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device) + scatter( + output_tensor, + scatter_list=None if my_group_rank != group_src else tensor_list, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + scatter( + obj_tensor_size, + scatter_list=None if my_group_rank != group_src else tensor_sizes, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size, group) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 7223631f..265e5f98 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -6,7 +6,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLoss + from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, EvaluatorLoss @config_class() @@ -62,3 +62,54 @@ def get_evaluator( from fast_llm.engine.evaluation.evaluator import EvaluatorLoss return EvaluatorLoss(name, self, batch_config, data_load_num_proc, train_iters) + + +@config_class(dynamic_type={EvaluatorConfig: "lm_eval"}) +class EvaluatorLmEvalConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + + cli_args: list[str] = Field( + default_factory=lambda: [], + desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.", + ) + + truncation: bool = Field( + default=False, + desc="Whether to use truncation during tokenization (useful when inputs exceed model's max length);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + logits_cache: bool = Field( + default=True, + desc="Whether to enable logits caching for speedup and avoiding recomputation during repeated evaluations;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + add_bos_token: bool = Field( + default=False, + desc="Whether to prepend a beginning-of-sequence (BOS) token, required for some models like LLaMA;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + prefix_token_id: int | None = Field( + default=None, + desc="Token ID to use as a prefix to the input (e.g., for control codes or prompts);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + max_length: int | None = Field( + default=None, + desc="Maximum sequence length including both prompt and newly generated tokens." + " If not set, it is inferred from the Fast-LLM model config or tokenizer.", + ) + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "EvaluatorLmEval": + from fast_llm.engine.evaluation.lm_eval.evaluator import EvaluatorLmEval + + return EvaluatorLmEval(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 78aad230..f593883c 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -19,8 +19,6 @@ from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, get_memory_usage_mib -# from fast_llm.engine.training.lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate - logger = logging.getLogger(__name__) diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py new file mode 100644 index 00000000..3de3663e --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -0,0 +1,90 @@ +import logging +import os +import pathlib +import typing + +from fast_llm.data.data.abstract import Data +from fast_llm.engine.config_utils.run import Run +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.config import EvaluatorLmEvalConfig +from fast_llm.engine.evaluation.evaluator import ( + EvaluationMetrics, + Evaluator, + EvaluatorSamplingParameters, + TrainingProgress, +) +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.runner import ScheduleRunner + +if typing.TYPE_CHECKING: + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM + +logger = logging.getLogger(__name__) + + +class EvaluatorLmEval[ConfigType: EvaluatorLmEvalConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[EvaluatorLmEvalConfig]] = EvaluatorLmEvalConfig + + _hf_model: "HuggingfaceBaseModelForCausalLM" = None + _flm_wrapper: "FastLLMLmEvalWrapper" = None + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + if "HUGGINGFACE_API_KEY_PATH" in os.environ: + os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip() + else: + if not "HF_TOKEN" in os.environ: + logger.warning( + "No `HF_TOKEN` or `HUGGINGFACE_API_KEY_PATH` environment variable provided. " + "Assuming the user has already logged in to the Hugging Face Hub." + ) + + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + + super().setup(distributed, run, multi_stage, runner, data, phase) + + self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( + self._multi_stage, runner=self._runner + ) + + # For reporting purposes, just to indicate it is from Fast-LLM + # as lm_eval.simple_evaluate will take it for results['config']['model'] + self._hf_model.config.name_or_path = type(self._hf_model).__name__ + + self._flm_wrapper = FastLLMLmEvalWrapper( + model=self._hf_model, + tokenizer=self._data.tokenizer.tokenizer, + truncation=self._config.truncation, + logits_cache=self._config.logits_cache, + add_bos_token=self._config.add_bos_token, + prefix_token_id=self._config.prefix_token_id, + max_length=self._config.max_length, + ) + self._is_setup = True + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + + # completed_steps is added to output_path like output_path/runs/run_index/completed_steps/ + completed_steps = 0 if training_progress is None else training_progress.completed_steps + + self._flm_wrapper.run(self._config.cli_args, completed_steps, self._run.index) + + # lm_eval logs to disc, wandb and prints to screen itself + return EvaluationMetrics() + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return None diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py new file mode 100644 index 00000000..ed42d464 --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -0,0 +1,938 @@ +import copy +import logging + +import jinja2 +import lm_eval.api.instance +import lm_eval.api.model +import lm_eval.evaluator +import lm_eval.models.utils +import lm_eval.utils +import torch +import torch.nn.functional as F +import tqdm.auto +import transformers + +from fast_llm.core.distributed import gather_object, safe_barrier, scatter_object_list +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results +from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM +from fast_llm.layers.transformer.rotary.config import NoRotaryConfig + +eval_logger = logging.getLogger(__name__) + + +class FastLLMLmEvalWrapper(lm_eval.api.model.TemplateLM): + _DEFAULT_MAX_LENGTH = 2048 + _DEFAULT_MAX_GEN_TOKENS = 256 + + def __init__( + self, + model: HuggingfaceBaseModelForCausalLM, + tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, + truncation: bool | None = False, + logits_cache: bool = True, + add_bos_token: bool | None = False, + prefix_token_id: int | None = None, + max_length: int | None = None, + ): + super().__init__() + + # === Distributed setup === + self._rank = 0 # For lm_eval: always run on main rank + self._world_size = 1 + self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed + + if ( + self._distributed.config.sequence_data_rank == 0 + and self._distributed.config.pipeline_rank == 0 + and self._distributed.config.tensor_rank == 0 + ): + self._group = self._distributed.batch_data_group + else: + self._group = torch.distributed.GroupMember.NON_GROUP_MEMBER + + # === Model & tokenizer setup === + self._model = model + self._device = model.device + self._config = model.config + + assert isinstance(tokenizer, (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast)) + self._tokenizer = tokenizer + self._tokenizer = lm_eval.models.utils.configure_pad_token(self._tokenizer, model_config=self._config) + + # === Generation/configuration parameters === + self._truncation = truncation + self._logits_cache = logits_cache + self._add_bos_token = add_bos_token + self._max_length = max_length + self._custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}") + + # === Internal constants === + self._backend = "causal" + self._vocab_size = self._tokenizer.vocab_size + + # === Batch configuration === + self._batch_schedule = 1 + self._batch_sizes = {} # Not used dynamically by lm_eval + self._batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size + self._batch_size = self._batch_size_per_gpu * self._distributed.config.batch_data_parallel + self._max_batch_size = self._batch_size + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self._tokenizer.eos_token_id + + # overrides from TemplateLM, but not used externally + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self._custom_prefix_token_id is not None: + return self._custom_prefix_token_id + if self._tokenizer.bos_token_id is not None: + return self._tokenizer.bos_token_id + return self._tokenizer.eos_token_id + + @property + def max_length(self): + # if max length manually set, return it + if self._max_length: + return self._max_length + + # check if it is absolute positional encoding and return max_position_embeddings + if hasattr(self._config.fast_llm_config.base_model, "transformer"): + # NOTE: will need to extend if more relative encoding types will be added + if isinstance(self._config.fast_llm_config.base_model.transformer.rotary, NoRotaryConfig): + return self._config.fast_llm_config.base_model.max_position_embeddings + + # check if tokenizer holds model sequence leigh info + if hasattr(self._tokenizer, "model_max_length"): + if self._tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self._tokenizer.model_max_length + + # finally try to get sequence length from batch config + if hasattr(self._model._inference_runner._batch_config, "sequence_length"): + return self._model._inference_runner._batch_config.sequence_length + + return self._DEFAULT_MAX_LENGTH + + # @property + # def device(self): + # # only used for world_size when lm_eval world size > 1 and + # # should not be called with current lm_eval support implementation + # return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer(self): + return self._tokenizer + + @property + def tokenizer_name(self) -> str: + return self._tokenizer.name_or_path.replace("/", "__") + + def run(self, cli_args: list[str], completed_steps: int, run_index: int): + if self._distributed.config.rank == 0: + args, simple_eval_kwargs = prepare_lm_eval_simple_eval_params(cli_args, completed_steps, run_index) + simple_eval_kwargs["model"] = self + + # Needed for reporting as batch_size is set from args not lm for reporting in evaluate + simple_eval_kwargs["batch_size"] = self._batch_size + simple_eval_kwargs["max_batch_size"] = self._max_batch_size + + # As of lm_eval commit 758c5ed891b1ca48acd8d3a0d309a827215796b7 + # Expected to be a string even if empty and not None in simple_evaluate + simple_eval_kwargs["model_args"] = "" + + results = lm_eval.evaluator.simple_evaluate(**simple_eval_kwargs) + self.stop_workers() + + # Evaluation_tracker save expects model to be either string, but if model is passed + # LM wrapper needs to be deep copyable and json serializable + simple_eval_kwargs["evaluation_tracker"].general_config_tracker.model_source = ( + self._model.config.name_or_path + ) + + if results is not None: + process_lm_eval_results( + args, + results, + simple_eval_kwargs["evaluation_tracker"], + completed_steps, + ) + else: + self.worker_model_invoke() + + # TODO: do we need it here as self.stop_workers() and self.worker_model_invoke() + # already have barrier + safe_barrier(self._distributed.world_group, f"lm_eval Run end") + + def _model_invoke( + self, + input_ids, + attention_mask, + labels, + max_length, + stop, + generate: bool, + continue_generate: bool, + **generation_kwargs, + ): + # TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple. + # Messages could include types like logits, generate, finished. + + # Groups is always None if world size is 1 + if self._group is None: + # Must not be called with continue_generate false on one process + assert continue_generate + return self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + world_size = self._group.size() + + assert self._group.rank() == 0 + + if continue_generate: + assert input_ids is not None + if generate: + assert max_length is not None and stop is not None + + # always divide by batch_size, if not full batch, some ranks will get less work or not at all + assert self._batch_size % world_size == 0 + step = self._batch_size // world_size + + input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)] + attention_mask = [ + attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None + for i in range(world_size) + ] + labels = [labels[i * step : (i + 1) * step] if labels is not None else None for i in range(world_size)] + + scatter_list = [ + [ + input_ids[i], + attention_mask[i], + labels[i], + max_length, + stop, + generate, + continue_generate, + generation_kwargs, + ] + for i in range(world_size) + ] + else: + scatter_list = [[None, None, None, None, None, None, False, None] for _ in range(world_size)] + + obj_list = [None] + scatter_object_list( + self._distributed.device, + obj_list, + scatter_list, + group_src=0, + group=self._group, + ) + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = tuple( + obj_list[0] + ) + + if continue_generate == False: + return + + assert len(input_ids) > 0 + + res = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + gather_list = [None] * world_size + gather_object( + self._distributed.device, + res, + gather_list, + group_dst=0, + group=self._group, + ) + # Clean gather list from empty shards + gather_list = [el for el in gather_list if len(el) > 0] + + # If it was model generate tensors could be of different length + # so we aggregate results to list instead of a tensor + if generate: + res = sum((el.tolist() for el in gather_list), []) + else: + assert all(el.device.type == "cpu" for el in gather_list) + res = torch.cat(gather_list, dim=0) + + return res + + def worker_model_invoke(self): + assert self._group is not None + # if isinstance(self.group, dist.ProcessGroup): + if not isinstance(self._group, int): + # groups is None for world_size 1 + assert self._group.rank() != 0 + # on worker ranks the function need to wait to be called multiple times + while True: + scatter_list = None + obj_list = [None] + scatter_object_list( + self._distributed.device, + obj_list, + scatter_list, + group_src=0, + group=self._group, + ) + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( + tuple(obj_list[0]) + ) + + if not continue_generate: + break + + # if some data was received, work, otherwise return empty tensor + if len(input_ids) > 0: + res = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + else: + res = input_ids + + gather_list = None + gather_object( + self._distributed.device, + res, + gather_list, + group_dst=0, + group=self._group, + ) + else: + # TODO: implement distributed model support + assert self._group == torch.distributed.GroupMember.NON_GROUP_MEMBER + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def stop_workers(self): + if self._group is None or (world_size := self._group.size()) == 1: + return + self._model_invoke(None, None, None, None, None, None, continue_generate=False) + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def _model_invoke_inner( + self, input_ids, attention_mask, labels, max_length, stop, generate: bool, **generation_kwargs + ): + if generate: + return self._model_generate_inner(input_ids, attention_mask, max_length, stop, **generation_kwargs) + else: + return self._model_call_inner(input_ids, attention_mask, labels) + + def _model_call(self, input_ids, attention_mask=None, labels=None): + return self._model_invoke( + input_ids, attention_mask, labels, None, None, generate=False, continue_generate=True + ) + + def _model_generate(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + return self._model_invoke( + input_ids, + attention_mask, + None, + max_length, + stop, + generate=True, + continue_generate=True, + **generation_kwargs, + ) + + def _model_call_inner(self, input_ids, attention_mask=None, labels=None): + """ + :param input_ids: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attention_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + if attention_mask is not None or labels is not None: + assert attention_mask is not None and labels is not None + + # TODO: do we need no_grad for fast_llm model? + with torch.no_grad(): + # We move logits to the CPU because they will be copied across processes and nodes + # in a multi-GPU, multi-node setup and eventually collected on the main rank. + # We cannot afford to accumulate them on rank 0 GPU, as GPU memory may already be tight. + # CPU tensors are slower, but we typically have much more CPU RAM available. + + # TODO: Check if it's possible to move some of the _loglikelihood_tokens work here + # and pass only the results around instead of the full logits. + # Computing errors here is also preferable, as copying logits across nodes and GPUs + # is inefficient and can involve gigabytes of data. + return self._model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ).logits.cpu() + + def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self._tokenizer, stop, input_ids.shape[1], input_ids.shape[0] + ) + + return self._model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self._tokenizer.pad_token_id, + use_cache=False, + **generation_kwargs, + ) + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> list[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self._backend == "causal": + special_tokens_kwargs = {"add_special_tokens": False or self._add_bos_token} + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self._tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: list[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self._tokenizer.padding_side + self._tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self._backend == "causal": + add_special_tokens = {"add_special_tokens": False or self._add_bos_token} + + encoding = self._tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + eval_logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] + self._tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _select_cont_toks(self, logits: torch.Tensor, contlen: int = None, inplen: int = None) -> torch.Tensor: + if self._backend == "causal": + assert contlen and inplen, "Must pass input len and cont. len to select scored logits for causal LM" + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self._backend == "seq2seq": + assert contlen and not inplen, "Selecting scored logits for Seq2SeqLM requires only cont. len" + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False + ) -> list[float]: + adaptive_batch_size = None + if self._batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm.auto.tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + # The tokenizer may raise: "Token indices sequence length is longer than the specified maximum sequence length for this model" + # This is expected and fine, as the sequence will be split into chunks of max_length later. + rolling_token_windows: list[tuple[list[int], list[int]]] = list( + map( + lm_eval.utils.make_disjoint_window, + lm_eval.utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self._device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self._batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial("loglikelihood_rolling", (string,), request_total) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self._batch_schedule) + if sched in self._batch_sizes: + return self._batch_sizes[sched] + if (len(self._batch_sizes) > 1) and (self._batch_sizes[sched - 1] == self._max_batch_size): + # if previous batch size is already maximal, skip recomputation + self._batch_sizes[sched] = self._max_batch_size + return self._batch_sizes[sched] + print(f"Passed argument batch_size = auto:{self._batch_schedule}. Detecting largest batch size") + self._batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self._batch_sizes[sched]}") + return self._batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: list[tuple[tuple[str, str], list[int], list[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> list[tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + def _collate(req: tuple[tuple[str, str], list[int], list[int]]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + + toks = req[1] + req[2] + return -len(toks), tuple(toks) + + # NOTE: the group_fn Defines the key to group and lookup one-token continuations + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. + re_ord = lm_eval.models.utils.Collator( + requests, + sort_fn=_collate, + group_by="contexts" if self._backend == "causal" and self._logits_cache else None, + group_fn=lambda req: req[-2] + req[-1][:-1], + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = self._batch_size if self._batch_size != "auto" else override_bs if override_bs is not None else 0 + batch_fn = ( + self._batch_scheduler + if self._batch_size == "auto" and n_reordered_requests > 0 and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm.auto.tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self._backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if total_length > self.max_length + 1: + eval_logger.warning( + f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self._device, + ) + (inplen,) = inp.shape + elif self._backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self._device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? + # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self._device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen + + padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self._backend == "causal": + batched_inps = lm_eval.models.utils.pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self._backend == "seq2seq": + # TODO: left-pad encoder inps and mask? + batched_inps = lm_eval.models.utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp] + batched_conts = lm_eval.models.utils.pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = lm_eval.models.utils.pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attention_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), dim=-1 + ) # [batch, padding_length (inp or cont), vocab] + + # TODO: Consider moving this part to per-shard execution in a multi-GPU and multi-node setup + # to avoid copying logits between GPUs and nodes, and to enable performing logits computations on the GPU. + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = inplen + (logits.shape[0] - padding_len_inp) if self._backend == "causal" else None + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. + # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + # NOTE: Currently, computations are performed on the CPU due to limited GPU memory. + cont_toks = torch.tensor(cont_toks, dtype=torch.long, device="cpu").unsqueeze(0) # [1, seq] + + max_equal = (greedy_tokens == cont_toks).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial("loglikelihood", request_str, answer) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False) -> list[str]: + res = [] + + pbar = tqdm.auto.tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self._batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. + batch_size = ( + self._batch_size + if self._batch_size != "auto" + else adaptive_batch_size if adaptive_batch_size is not None else 0 + ) + batch_fn = self._batch_scheduler if self._batch_size == "auto" and not adaptive_batch_size else None + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + # NOTE: for sort_fn, the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + re_ords = lm_eval.models.utils.Collator( + [reg.args for reg in requests], + sort_fn=lambda req: (-len(self.tok_encode(req[0])), req[0]), + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. + if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = lm_eval.models.utils.handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError(f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}") + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self._DEFAULT_MAX_GEN_TOKENS + + # set the max length in tokens of inputs ("context_enc") + if self._backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert ( + max_ctx_len > 0 + ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." + elif self._backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + input_ids, attention_mask = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self._truncation, + ) + input_ids = input_ids.to(self._device) + attention_mask = attention_mask.to(self._device) + + if "max_length" not in kwargs: + kwargs["max_length"] = input_ids.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + input_ids=input_ids, + attention_mask=attention_mask, + stop=until, + **kwargs, + ) + + # cont_toks_list = cont.tolist() + cont_toks_list = cont + + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self._backend == "causal": + cont_toks = cont_toks[input_ids.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template(self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self._tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + eval_logger.warning("Failed to apply chat template. removing the system role in chat history.") + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self._tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated diff --git a/fast_llm/engine/evaluation/lm_eval/utils.py b/fast_llm/engine/evaluation/lm_eval/utils.py new file mode 100644 index 00000000..5f7f9a94 --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/utils.py @@ -0,0 +1,238 @@ +import argparse +import json +import logging +import os +import pathlib +import sys +from pathlib import Path + +import lm_eval.__main__ +import lm_eval.evaluator +import lm_eval.loggers +import lm_eval.tasks +import lm_eval.utils + +eval_logger = logging.getLogger(__name__) + + +def parse_eval_args(parser: argparse.ArgumentParser, args: list[str]) -> argparse.Namespace: + lm_eval.__main__.check_argument_types(parser) + return parser.parse_args(args) + + +def prepare_lm_eval_simple_eval_params( + cli_args: list[str], + completed_steps: int, + run_index: int, +) -> tuple[argparse.Namespace, dict[str, any]]: + """ + Parses CLI arguments for an LM evaluation run and prepares keyword arguments + for the `evaluate` function. + + This function wraps argument parsing, environment configuration, task resolution, + and metadata setup needed for evaluation with Fast-LLM's `lm_eval` wrapper. It also + handles special cases like hub token injection, dynamic sample loading, and task + listing commands. + + Args: + cli_args (list[str]): Command-line arguments, excluding the program name. + completed_steps (int): Current number of completed training steps, used to + uniquely tag evaluation output paths. + run_index (int): index of the current run of Fast-LLM experiment + + Returns: + tuple: + - argparse.Namespace: Parsed CLI arguments. + - dict: Keyword arguments to pass into `simple_evaluate`, including task list, + tracker, cache settings, random seeds, and generation parameters. + + Raises: + ValueError: If required fields like `--tasks` or `--output_path` are missing + when needed, or if misconfigured combinations are detected. + SystemExit: If special task listing flags are used. + """ + parser = lm_eval.__main__.setup_parser() + args = parse_eval_args(parser, cli_args) + + # NOTE: all this args are set by fast_llm on the model directly or not used here + assert not args.wandb_args # default empty string + assert not args.wandb_config_args # default empty string + assert args.model == "hf" # default value of 'hf' + assert not args.model_args # default empty string + assert args.batch_size == 1 # default value of 1 + assert args.max_batch_size is None + assert args.device is None + + # update the evaluation tracker args with the output path and the HF token + evaluation_tracker_args = "" + if args.output_path: + args.output_path = str(pathlib.Path(args.output_path) / f"runs/{run_index}/{completed_steps}") + evaluation_tracker_args += f",output_path={args.output_path}" + + evaluation_tracker_args = lm_eval.utils.simple_parse_args_string(evaluation_tracker_args) + evaluation_tracker = lm_eval.loggers.EvaluationTracker(**evaluation_tracker_args) + + if args.predict_only: + args.log_samples = True + if (args.log_samples or args.predict_only) and not args.output_path: + raise ValueError("Specify --output_path if providing --log_samples or --predict_only") + + if args.fewshot_as_multiturn and args.apply_chat_template is False: + raise ValueError( + "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)." + ) + + if args.include_path is not None: + eval_logger.info(f"Including path: {args.include_path}") + metadata = ( + lm_eval.utils.simple_parse_args_string(args.model_args) + if isinstance(args.model_args, str) + else args.model_args if isinstance(args.model_args, dict) else {} + ) | (args.metadata if isinstance(args.metadata, dict) else lm_eval.utils.simple_parse_args_string(args.metadata)) + + task_manager = lm_eval.tasks.TaskManager(include_path=args.include_path, metadata=metadata) + + if args.limit: + eval_logger.warning( + " --limit SHOULD ONLY BE USED FOR TESTING." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." + ) + if args.samples: + assert args.limit is None, "If --samples is not None, then --limit must be None." + if (samples := Path(args.samples)).is_file(): + args.samples = json.loads(samples.read_text()) + else: + args.samples = json.loads(args.samples) + + if args.tasks is None: + eval_logger.error("Need to specify task to evaluate.") + sys.exit() + elif args.tasks == "list": + print(task_manager.list_all_tasks()) + sys.exit() + elif args.tasks == "list_groups": + print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False)) + sys.exit() + elif args.tasks == "list_tags": + print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False)) + sys.exit() + elif args.tasks == "list_subtasks": + print(task_manager.list_all_tasks(list_groups=False, list_tags=False)) + sys.exit() + else: + if os.path.isdir(args.tasks): + import glob + + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = lm_eval.utils.load_yaml_config(yaml_file) + task_names.append(config) + else: + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = lm_eval.utils.load_yaml_config(task) + task_names.append(config) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + eval_logger.error( + f"Tasks were not found: {missing}\n" + f"{lm_eval.utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all" + " available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG'" + " to troubleshoot task registration issues." + ) + + ( + eval_logger.info(f"Selected Tasks: {task_names}") + if eval_logger.getEffectiveLevel() >= logging.INFO + else print(f"Selected Tasks: {task_names}") + ) + + request_caching_args = lm_eval.evaluator.request_caching_arg_to_dict(cache_requests=args.cache_requests) + + eval_kwargs = dict( + tasks=task_names, + num_fewshot=args.num_fewshot, + # batch_size=args.batch_size, + # max_batch_size=args.max_batch_size, + # device=args.device, + use_cache=args.use_cache, + limit=args.limit, + samples=args.samples, + check_integrity=args.check_integrity, + write_out=args.write_out, + log_samples=args.log_samples, + evaluation_tracker=evaluation_tracker, + system_instruction=args.system_instruction, + apply_chat_template=args.apply_chat_template, + fewshot_as_multiturn=args.fewshot_as_multiturn, + gen_kwargs=args.gen_kwargs, + task_manager=task_manager, + predict_only=args.predict_only, + random_seed=args.seed[0], + numpy_random_seed=args.seed[1], + torch_random_seed=args.seed[2], + fewshot_random_seed=args.seed[3], + confirm_run_unsafe_code=args.confirm_run_unsafe_code, + metadata=metadata, + **request_caching_args, + ) + + return args, eval_kwargs + + +def process_lm_eval_results( + args: argparse.Namespace, + results: dict[str, any], + evaluation_tracker: lm_eval.loggers.EvaluationTracker, + completed_steps: int | None, +) -> None: + if results is not None: + completed_steps = 0 if completed_steps is None else completed_steps + import wandb + + if args.log_samples: + samples = results.pop("samples") + dumped = json.dumps(results, indent=2, default=lm_eval.utils.handle_non_serializable, ensure_ascii=False) + if args.show_config: + print(dumped) + + batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + + # Add W&B logging if we have the run to log to + # we expect the rest of the fast_llm code will finish the run. + if wandb.run is not None: + try: + wandb_logger = lm_eval.loggers.WandbLogger(init_args={"step": completed_steps}) + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + eval_logger.info(f"Logging to Weights and Biases failed due to {e}") + + evaluation_tracker.save_results_aggregated(results=results, samples=samples if args.log_samples else None) + + if args.log_samples: + for task_name, config in results["configs"].items(): + evaluation_tracker.save_results_samples(task_name=task_name, samples=samples[task_name]) + + if evaluation_tracker.push_results_to_hub or evaluation_tracker.push_samples_to_hub: + evaluation_tracker.recreate_metadata_card() + + # TODO: convert to logging entries instead? + print( + f"{results["config"]["model"]}, gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " + f"batch_size: {results["config"]["batch_size"]}{f' ({batch_sizes})' if batch_sizes else ''}" + ) + print(lm_eval.utils.make_table(results)) + if "groups" in results: + print(lm_eval.utils.make_table(results, "groups")) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 3c2db428..54a82492 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -1,3 +1,4 @@ +import logging import os import pathlib import typing @@ -14,6 +15,8 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class HuggingfacePreTrainedModel(transformers.PreTrainedModel): config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig @@ -41,6 +44,8 @@ def __init__( # The HF constructor performs a deep copy of the config, # but config.fast_llm_config may contain non-picklable items like process groups. # Temporarily remove it before the call and restore it afterward. + # TODO: Find a clean solution — overriding __deepcopy__ doesn't work here + # because internally they use copy.deepcopy(self.__dict__). fast_llm_config = config.fast_llm_config config.fast_llm_config = None super().__init__(config, **kwargs) @@ -64,6 +69,11 @@ def __init__( with transformers.modeling_utils.no_init_weights(): self.post_init() + if fast_llm_model.config.multi_stage.zero_stage == 3: + logger.warning( + "zero_stage=3 is used for the model; forward and generate will be extremely slow during inference." + ) + @classmethod def from_pretrained( cls, diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 766398d0..1767a631 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -318,7 +318,7 @@ def _run_training(self) -> None: log_main_rank(formatted_metrics) self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. - self._wandb.log_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics, commit=True) def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # Tracking loss. @@ -339,6 +339,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: self._config.training.prefetch_factor, ) + has_test_phase = PhaseType.test in self._samples_per_split + log_main_rank("Training ...") # TODO: Synchronization is probably unnecessary. @@ -456,7 +458,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: ) if is_main_rank() and metrics: - self._wandb.log_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics, commit=not (done and has_test_phase)) stop = done or self._config.training.shutdown.enabled(self._completed_steps) diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index 185b89c2..724b5b71 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -44,12 +44,12 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): else: self._wandb = None - def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]]) -> None: + def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]], commit: bool) -> None: # Note: metrics modified in-place if self._wandb is not None: import wandb - wandb.log(metrics, step=completed_steps) # noqa + wandb.log(metrics, step=completed_steps, commit=commit) # noqa def alert(self, title, text, level="INFO", wait=0.001) -> None: if self._wandb is not None and self._config.alert.post_alerts: diff --git a/mkdocs.yaml b/mkdocs.yaml index ab71bc23..85fd4bff 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -179,6 +179,7 @@ nav: - Configuration: user_guide/configuration.md - Multi-Stage: user_guide/multi-stage.md - Parallelism: user_guide/parallelism.md + - Evaluators: user_guide/evaluators.md - Developer Guide: - Configuration: developer_guide/configuration.md - Model: