
Adds lm_eval to evaluations #282


Open

wants to merge 49 commits into main from denis/lm_eval
Commits (49)

Changes from 29 commits
cb744b2
copy from sandbox
bigximik Jun 20, 2025
0967483
changes for loss test for new tests structure
bigximik Jun 20, 2025
71ff61a
lm_eval integration changes for the new api
bigximik Jun 20, 2025
79fd43e
made lm_eval a lazily imported optional dependency
bigximik Jun 20, 2025
2d9f479
removed hard coded batch size
bigximik Jun 20, 2025
7c62100
removed unnecessary set to evaluation
bigximik Jun 25, 2025
c89d269
commit wandb step after finishing logging
bigximik Jun 26, 2025
9455cd5
support for env variables for lm_eval integration
bigximik Jun 27, 2025
69180a3
merge from main
bigximik Jun 27, 2025
c9a3b18
user guide for evaluators added
bigximik Jun 27, 2025
426b5e3
fix tensor concatenation for logits from different GPUs
bigximik Jun 27, 2025
0bf8282
docs update
bigximik Jun 27, 2025
68f524b
removed manual test configs
bigximik Jun 27, 2025
a36e0be
added debug prints
bigximik Jun 27, 2025
9baa512
fix for gather_list and remove debug print
bigximik Jun 27, 2025
21678ab
removed debug print
bigximik Jun 28, 2025
7cccf9a
moved returned logits to cpu in lm_eval wrapper
bigximik Jun 28, 2025
7cd681a
fix to move all logits computations to cpu
bigximik Jun 30, 2025
59ff1e5
Merge branch 'main' of github.com:ServiceNow/Fast-LLM into denis/lm_eval
bigximik Jun 30, 2025
27e5de8
Merge branch 'main' of github.com:ServiceNow/Fast-LLM into denis/lm_eval
bigximik Jul 2, 2025
88faca0
fix typo
bigximik Jul 2, 2025
e3a4a6e
removed commented code, obsolete todo
bigximik Jul 2, 2025
89e67d2
changes to wrapper
bigximik Jul 2, 2025
6871359
refactored lm_eval integration
bigximik Jul 2, 2025
6b74739
import change
bigximik Jul 2, 2025
c398444
added ZeRO stage 3 inference warning and TODO
bigximik Jul 2, 2025
62846d2
removed docstrings
bigximik Jul 2, 2025
e61cc3e
removed unused fields, change generate call
bigximik Jul 3, 2025
6a2ab35
changed all fields to be private, removed properties which are use…
bigximik Jul 3, 2025
6e1704f
Simplify scatter/gather
jlamypoirier Jul 8, 2025
2499b4e
clean up, more comments
bigximik Jul 9, 2025
44aa138
fixed typo
bigximik Jul 9, 2025
f81a673
moved setting of NUMEXPR_MAX_THREADS
bigximik Jul 9, 2025
d56ce57
Evaluators renames
bigximik Jul 11, 2025
b32c91f
return change
bigximik Jul 11, 2025
93091dd
change local function to lambda
bigximik Jul 11, 2025
50e65ee
some speedup
bigximik Jul 11, 2025
d32258e
fix not to log absent head output
bigximik Jul 11, 2025
98d1d77
added lm_eval integration tests
bigximik Jul 11, 2025
9f2de97
fix to not remove comment for import
bigximik Jul 11, 2025
b451543
docs update
bigximik Jul 14, 2025
910d54e
scatter fix
bigximik Jul 14, 2025
077f2ac
fix offset normalization in validation
bigximik Jul 14, 2025
ac9025d
tests polishing
bigximik Jul 14, 2025
30d85df
more tests polishing
bigximik Jul 14, 2025
f60fa35
fixes
jlamypoirier Jul 15, 2025
ada41ca
Merge branch 'main' of github.com:ServiceNow/Fast-LLM into denis/lm_eval
bigximik Jul 15, 2025
2f5d2d0
changed prepare function to just copy training runs
bigximik Jul 15, 2025
f05db2c
disabled test
bigximik Jul 16, 2025
92 changes: 92 additions & 0 deletions docs/user_guide/evaluators.md
@@ -0,0 +1,92 @@
# Evaluations

Fast-LLM allows you to perform various evaluations during training or as a separate evaluation step. In both cases, you need to use your training config with `training.evaluators` specified.

For evaluators used during training, both `interval` and `offset` must be specified. Then, start training as usual with:

`fast-llm train gpt --config path/to/training/config.yaml`
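A minimal sketch of these scheduling fields (the evaluator name `my_eval`, the dataset name `my_dataset`, and all values are hypothetical; `interval` runs the evaluator every N training iterations, and `offset` shifts the iteration at which it first runs):

```yaml
training:
  evaluators:
    my_eval:              # hypothetical evaluator name
      interval: 100       # run every 100 training iterations
      offset: 50          # first run at iteration 50
      evaluator:
        type: loss
        iterations: 10
        dataset_name: my_dataset
```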

To perform evaluation as a separate step, use the same training config. Depending on the training progress, either the start model or the latest checkpoint will be loaded, and `interval` and `offset` will be ignored. To start evaluation:

`fast-llm evaluate gpt --config path/to/training/config.yaml`

## Currently Supported Evaluators

- `loss`
- `lm_eval`

## Loss Evaluator

To set up loss evaluation, specify a dataset to be used in the `data.datasets` section of the config. You must also define the loss evaluator in the `training.evaluators` config section. See example below.

```yaml
training:
  evaluators:
    stack_3b:
      interval: 10
      evaluator:
        type: loss
        iterations: 10
        dataset_name: stack_3b
    fineweb:
      interval: 10
      evaluator:
        type: loss
        iterations: 10
        dataset_name: fineweb
data:
  datasets:
    stack_3b:
      type: memmap
      path: path/to/memmap/dataset
    fineweb:
      type: memmap
      path: path/to/memmap/dataset1
```

## Evaluation Harness (`lm_eval`) Evaluator

**Note:** Only data parallelism is currently supported for the `lm_eval` evaluator.

To run `lm_eval` evaluations, version `0.4.9` of `lm_eval` must be installed along with all dependencies required for your evaluation tasks.
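For example, assuming the package is installed straight from PyPI (your setup may pin it through its own dependency management instead):

```bash
pip install "lm_eval==0.4.9"
```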

The following environment variables may need to be set:

- `HF_HOME`: Path for Hugging Face data caching
- `WANDB_API_KEY_PATH`: Path to a file containing your Weights & Biases API key (if logging to W&B)
- `HUGGINGFACE_API_KEY_PATH`: Path to a file containing your Hugging Face hub token
- `NLTK_DATA`: Path to a directory that will contain downloaded NLTK packages (needed for some tasks)
- `HF_ALLOW_CODE_EVAL=1`: Required for some evaluation tasks

You may need to specify additional environment variables depending on the `lm_eval` tasks you want to run.
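For example, a launch script might set them like this (all paths here are placeholders):

```bash
export HF_HOME=/path/to/hf_home
export WANDB_API_KEY_PATH=/path/to/wandb_api_key_file
export HUGGINGFACE_API_KEY_PATH=/path/to/hf_token_file
export NLTK_DATA=/path/to/nltk_data
export HF_ALLOW_CODE_EVAL=1

fast-llm train gpt --config path/to/training/config.yaml
```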

The `lm_eval` evaluator config is specified with the following fields:

### Model Config

The model instantiated for training is reused for evaluation, so you don't need to specify it separately. However, there are some parameters specific to `lm_eval`. See `EvaluatorLmEvalConfig` in `fast_llm/engine/evaluation/config.py` for details.

### CLI Parameters for `lm_eval`

All other parameters are specified as if you were calling the `lm_eval` CLI, using a list of strings. Some CLI parameters are ignored or restricted, specifically those related to model loading, W&B, batch sizes, and device setup, as these are managed by the rest of the Fast-LLM configuration.

Also, the tokenizer must be specified in `data.tokenizer`. If the tokenizer does not have a `bos_token`, it must be specified explicitly in `data.tokenizer.bos_token`. Although `lm_eval` does not use the `bos_token` directly, it is still required because the same tokenizer is used by other Fast-LLM components.
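A minimal sketch of that tokenizer section, assuming a tokenizer that defines no BOS token of its own (`<s>` is a hypothetical choice):

```yaml
data:
  tokenizer:
    path: path/to/the/tokenizer
    bos_token: "<s>"  # required only if the tokenizer itself defines no BOS token
```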

Below is an example of the config:

```yaml
training:
  evaluators:
    lm_eval_tasks1:
      interval: 10
      evaluator:
        type: lm_eval
        cli_args:
          - --tasks
          - gsm8k,xnli_en,wikitext,ifeval
          - --output_path
          - /path/to/lm_eval/output
data:
  tokenizer:
    path: path/to/the/tokenizer
```
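For reference, the `cli_args` above correspond roughly to this standalone `lm_eval` invocation (model, batch-size, and device flags are omitted since Fast-LLM manages them):

```bash
lm_eval --tasks gsm8k,xnli_en,wikitext,ifeval --output_path /path/to/lm_eval/output
```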
11 changes: 11 additions & 0 deletions fast_llm/cli.py
@@ -1,5 +1,6 @@
import contextlib
import logging
import os
import sys
import traceback

@@ -8,6 +9,16 @@
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.config_utils.runnable import RunnableConfig

# This must be set before importing numexpr,
# because by default, the maximum number of threads is 64.
# On systems with more cores, numexpr logs an error and
# ignores the thread setting if it exceeds the limit.
if "NUMEXPR_MAX_THREADS" not in os.environ:
import multiprocessing

os.environ["NUMEXPR_MAX_THREADS"] = str(multiprocessing.cpu_count())


# Import these submodules to ensure classes are added to the dynamic class registry.
import fast_llm.data.auto # isort: skip
import fast_llm.engine.checkpoint.convert # isort: skip
2 changes: 1 addition & 1 deletion fast_llm/config.py
@@ -735,7 +735,7 @@ def _get_class_name(cls) -> str:
    @classmethod
    def from_dict(
        cls,
-        default: "Config| dict[str, typing.Any]]",
+        default: "Config| dict[str, typing.Any]",
        *updates: "Config| dict[str | tuple[str, ...], typing.Any]",
        strict: bool = True,
        update_type: UpdateType = UpdateType.override,