Adds lm_eval to evaluations #282

Open

wants to merge 49 commits into base: main from denis/lm_eval

Changes from 10 commits

Commits (49)
cb744b2
copy from sandbox
bigximik Jun 20, 2025
0967483
changes for loss test for new tests structure
bigximik Jun 20, 2025
71ff61a
lm_eval integration changes for the new api
bigximik Jun 20, 2025
79fd43e
made lm_eval dependency lazy imported for optional dependency
bigximik Jun 20, 2025
2d9f479
removed hard coded batch size
bigximik Jun 20, 2025
7c62100
removed unnecessary set to evaluation
bigximik Jun 25, 2025
c89d269
commit wandb step after finishing logging
bigximik Jun 26, 2025
9455cd5
support for env variables for lm_eval integration
bigximik Jun 27, 2025
69180a3
merge from main
bigximik Jun 27, 2025
c9a3b18
user guide for evaluators added
bigximik Jun 27, 2025
426b5e3
fix tensor concatenation for logits from different GPUs
bigximik Jun 27, 2025
0bf8282
docs update
bigximik Jun 27, 2025
68f524b
removed manual test configs
bigximik Jun 27, 2025
a36e0be
added debug prints
bigximik Jun 27, 2025
9baa512
fix for gather_list and remove debug print
bigximik Jun 27, 2025
21678ab
removed debug print
bigximik Jun 28, 2025
7cccf9a
moved returned logits to cpu in lm_eval wrapper
bigximik Jun 28, 2025
7cd681a
fix to move all logits computations to cpu
bigximik Jun 30, 2025
59ff1e5
Merge branch 'main' of github.com:ServiceNow/Fast-LLM into denis/lm_eval
bigximik Jun 30, 2025
27e5de8
Merge branch 'main' of github.com:ServiceNow/Fast-LLM into denis/lm_eval
bigximik Jul 2, 2025
88faca0
fix typo
bigximik Jul 2, 2025
e3a4a6e
removed commented code, obsolete todo
bigximik Jul 2, 2025
89e67d2
changes to wrapper
bigximik Jul 2, 2025
6871359
refactored lm_eval integration
bigximik Jul 2, 2025
6b74739
import change
bigximik Jul 2, 2025
c398444
zero stage 3 inference warning added and TODO
bigximik Jul 2, 2025
62846d2
removed docstrings
bigximik Jul 2, 2025
e61cc3e
removed unused fields, change generate call
bigximik Jul 3, 2025
6a2ab35
changed to all fields to be private, removed properties which are use…
bigximik Jul 3, 2025
6e1704f
Simplify scatter/gather
jlamypoirier Jul 8, 2025
2499b4e
clean up, more comments
bigximik Jul 9, 2025
44aa138
fixed typo
bigximik Jul 9, 2025
f81a673
moved setting of NUMEXPR_MAX_THREADS
bigximik Jul 9, 2025
d56ce57
Evaluators renames
bigximik Jul 11, 2025
b32c91f
return change
bigximik Jul 11, 2025
93091dd
change local function to lambda
bigximik Jul 11, 2025
50e65ee
some speedup
bigximik Jul 11, 2025
d32258e
fix not to log absent head output
bigximik Jul 11, 2025
98d1d77
added lm_eval integration tests
bigximik Jul 11, 2025
9f2de97
fix not removal comment for import
bigximik Jul 11, 2025
b451543
docs update
bigximik Jul 14, 2025
910d54e
scatter fix
bigximik Jul 14, 2025
077f2ac
fix offset normalization in validation
bigximik Jul 14, 2025
ac9025d
tests polishing
bigximik Jul 14, 2025
30d85df
more tests polishing
bigximik Jul 14, 2025
f60fa35
fixes
jlamypoirier Jul 15, 2025
ada41ca
Merge branch 'main' of github.com:ServiceNow/Fast-LLM into denis/lm_eval
bigximik Jul 15, 2025
2f5d2d0
changed prepare function to just copy training runs
bigximik Jul 15, 2025
f05db2c
disabled test
bigximik Jul 16, 2025
11 changes: 0 additions & 11 deletions fast_llm/cli.py
@@ -1,6 +1,5 @@
import contextlib
import logging
import os
import sys
import traceback

@@ -9,16 +8,6 @@
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.config_utils.runnable import RunnableConfig

# This must be set before importing numexpr,
# because by default, the maximum number of threads is 64.
# On systems with more cores, numexpr logs an error and
# ignores the thread setting if it exceeds the limit.
if "NUMEXPR_MAX_THREADS" not in os.environ:
import multiprocessing

os.environ["NUMEXPR_MAX_THREADS"] = str(multiprocessing.cpu_count())


# Import these submodules to ensure classes are added to the dynamic class registry.
import fast_llm.data.auto # isort: skip
import fast_llm.engine.checkpoint.convert # isort: skip
9 changes: 9 additions & 0 deletions fast_llm/engine/config_utils/run.py
@@ -101,6 +101,15 @@ def get_run(self, distributed: "Distributed") -> "Run":
return run

def _set_external_variables(self) -> None:
# This must be set before importing numexpr,
# because by default, the maximum number of threads is 64.
# On systems with more cores, numexpr logs an error and
# ignores the thread setting if it exceeds the limit.
if "NUMEXPR_MAX_THREADS" not in os.environ:
import multiprocessing

os.environ["NUMEXPR_MAX_THREADS"] = str(multiprocessing.cpu_count())

import torch._dynamo

# TODO: Find an alternative to get reliable tensor-parallel overlap.
16 changes: 8 additions & 8 deletions fast_llm/engine/evaluation/config.py
@@ -6,7 +6,7 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, EvaluatorLoss
from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, LossEvaluator


@config_class()
@@ -40,7 +40,7 @@ def _from_dict(


@config_class(dynamic_type={EvaluatorConfig: "loss"})
class EvaluatorLossConfig(EvaluatorConfig):
class LossEvaluatorConfig(EvaluatorConfig):
_abstract: typing.ClassVar[bool] = False

iterations: int | None = Field(
@@ -58,14 +58,14 @@ def get_evaluator(
batch_config: BatchConfig,
data_load_num_proc: int,
train_iters: int | None = None,
) -> "EvaluatorLoss":
from fast_llm.engine.evaluation.evaluator import EvaluatorLoss
) -> "LossEvaluator":
from fast_llm.engine.evaluation.evaluator import LossEvaluator

return EvaluatorLoss(name, self, batch_config, data_load_num_proc, train_iters)
return LossEvaluator(name, self, batch_config, data_load_num_proc, train_iters)


@config_class(dynamic_type={EvaluatorConfig: "lm_eval"})
class EvaluatorLmEvalConfig(EvaluatorConfig):
class LmEvalEvaluatorConfig(EvaluatorConfig):
_abstract: typing.ClassVar[bool] = False

cli_args: list[str] = Field(
@@ -110,6 +110,6 @@ def get_evaluator(
data_load_num_proc: int,
train_iters: int | None = None,
) -> "EvaluatorLmEval":
from fast_llm.engine.evaluation.lm_eval.evaluator import EvaluatorLmEval
from fast_llm.engine.evaluation.lm_eval.evaluator import LmEvalEvaluator

return EvaluatorLmEval(name, self, batch_config, data_load_num_proc, train_iters)
return LmEvalEvaluator(name, self, batch_config, data_load_num_proc, train_iters)
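
The two config classes above register the dynamic types "loss" and "lm_eval" on EvaluatorConfig. A minimal sketch of selecting them through the CLI-override syntax used by the tests later in this PR follows; the evaluator names, interval, and task list are hypothetical placeholders, not taken from the diff.

# --- illustrative sketch, not part of the PR diff ---
overrides = [
    # LossEvaluatorConfig is registered under the dynamic type "loss".
    "training.evaluators.val.interval=100",
    "training.evaluators.val.evaluator.type=loss",
    "training.evaluators.val.evaluator.iterations=25",
    # LmEvalEvaluatorConfig is registered under "lm_eval"; cli_args is forwarded to the lm_eval CLI.
    "training.evaluators.lm_eval_test.interval=100",
    "training.evaluators.lm_eval_test.evaluator.type=lm_eval",
    'training.evaluators.lm_eval_test.evaluator.cli_args=["--tasks","wikitext","--limit","10"]',
]
# --- end sketch ---
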
8 changes: 4 additions & 4 deletions fast_llm/engine/evaluation/evaluator.py
@@ -10,7 +10,7 @@
from fast_llm.engine.config_utils.run import Run, log_main_rank
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, EvaluatorLossConfig
from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, LossEvaluatorConfig
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.config import BatchConfig
from fast_llm.engine.schedule.runner import ScheduleRunner
@@ -50,7 +50,7 @@ class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC):
def __init__(
self,
name: str,
eval_config: EvaluatorLossConfig,
eval_config: LossEvaluatorConfig,
batch_config: BatchConfig,
data_load_num_proc: int,
train_iters: int | None = None,
@@ -94,8 +94,8 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None:
"""


class EvaluatorLoss[ConfigType: EvaluatorLossConfig](Evaluator[ConfigType]):
config_class: typing.ClassVar[type[EvaluatorLossConfig]] = EvaluatorLossConfig
class LossEvaluator[ConfigType: LossEvaluatorConfig](Evaluator[ConfigType]):
config_class: typing.ClassVar[type[LossEvaluatorConfig]] = LossEvaluatorConfig

def setup(
self,
6 changes: 3 additions & 3 deletions fast_llm/engine/evaluation/lm_eval/evaluator.py
@@ -7,7 +7,7 @@
from fast_llm.engine.config_utils.run import Run
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.evaluation.config import EvaluatorLmEvalConfig
from fast_llm.engine.evaluation.config import LmEvalEvaluatorConfig
from fast_llm.engine.evaluation.evaluator import (
EvaluationMetrics,
Evaluator,
@@ -24,8 +24,8 @@
logger = logging.getLogger(__name__)


class EvaluatorLmEval[ConfigType: EvaluatorLmEvalConfig](Evaluator[ConfigType]):
config_class: typing.ClassVar[type[EvaluatorLmEvalConfig]] = EvaluatorLmEvalConfig
class LmEvalEvaluator[ConfigType: LmEvalEvaluatorConfig](Evaluator[ConfigType]):
config_class: typing.ClassVar[type[LmEvalEvaluatorConfig]] = LmEvalEvaluatorConfig

_hf_model: "HuggingfaceBaseModelForCausalLM" = None
_flm_wrapper: "FastLLMLmEvalWrapper" = None
Expand Down
34 changes: 15 additions & 19 deletions fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py
@@ -191,7 +191,7 @@ def _model_invoke(
# TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple.
# Messages could include types like logits, generate, finished.

# Groups is always None if world size is 1
# Group is always None if world size is 1
if self._group is None:
# Must not be called with continue_generate false on one process
assert continue_generate
@@ -208,7 +208,7 @@
if generate:
assert max_length is not None and stop is not None

# always divide by batch_size, if not full batch, some ranks will get less work or not at all
# always divide by world_size, if not full batch, some ranks will get less work or not at all
assert self._batch_size % world_size == 0
step = self._batch_size // world_size

@@ -242,8 +242,8 @@
)
)

if continue_generate == False:
return
if not continue_generate:
return None

assert len(input_ids) > 0

@@ -280,6 +280,7 @@ def worker_model_invoke(self):
)
)

# Stop signal was sent, end waiting/processing loop
if not continue_generate:
break

@@ -298,7 +299,8 @@
safe_barrier(self._distributed.world_group, "lm_eval_end")

def stop_workers(self):
if self._group is None or (world_size := self._group.size()) == 1:
# Group is always None if world size is 1
if self._group is None:
return
self._model_invoke(None, None, None, None, None, None, continue_generate=False)
safe_barrier(self._distributed.world_group, "lm_eval_end")
@@ -581,26 +583,20 @@ def _loglikelihood_tokens(
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []

def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
"""Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end

toks = req[1] + req[2]
return -len(toks), tuple(toks)

# NOTE: the group_fn Defines the key to group and lookup one-token continuations
# NOTE: for the sort_fn, the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
# NOTE: the group_fn Defines the key to group and lookup one-token continuations
# Use with group_by="contexts" (optional)
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
# speeds up some multiple-choice tasks proportionally to the number of choices.
# groups requests by context+continuation[:-1] and infer on one request/group.
re_ord = lm_eval.models.utils.Collator(
requests,
sort_fn=_collate,
sort_fn=lambda req: (-(len(req[1]) + len(req[2])), tuple(req[1]) + tuple(req[2])),
group_by="contexts" if self._backend == "causal" and self._logits_cache else None,
group_fn=lambda req: req[-2] + req[-1][:-1],
)
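
As a quick illustration of the two lambdas passed to Collator above, the standalone sketch below evaluates the sort and group keys on a toy request of the shape (context_pair, context_tokens, continuation_tokens); the token values are invented.

# --- illustrative sketch, not part of the PR diff ---
req = (("ctx", "cont"), [5, 6, 7], [8])

# sort_fn key: longest context+continuation first (negative length), ties broken by the tokens.
sort_key = (-(len(req[1]) + len(req[2])), tuple(req[1]) + tuple(req[2]))
print(sort_key)  # (-4, (5, 6, 7, 8))

# group_fn key: context + continuation[:-1]; requests sharing it can reuse one set of logits
# for one-token continuations when group_by="contexts".
group_key = req[-2] + req[-1][:-1]
print(group_key)  # [5, 6, 7]
# --- end sketch ---
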
3 changes: 2 additions & 1 deletion fast_llm/engine/multi_stage/stage.py
@@ -113,7 +113,8 @@ def forward(
losses,
metrics,
)
self._log_layer_forward(output, kwargs, i)
if output is not None:
self._log_layer_forward(output, kwargs, i)

# TODO: very slow and memory consuming, only use for debugging for now
# TODO: decide if and how we want to return
85 changes: 85 additions & 0 deletions tests/models/test_lm_eval_simple.py
@@ -0,0 +1,85 @@
import huggingface_hub
import pytest
import transformers

from tests.models.test_checkpoint import _prepare_resume_fn
from tests.utils.model_configs import ModelTestingGroup
from tests.utils.utils import requires_cuda, requires_lm_eval

# NOTE: These tests only verify that the functionality runs without crashing.
# NOTE: The tokenizer is from a LLaMA-style model, which may not be suitable for all models,
# but it should be sufficient since we are not concerned with actual accuracy in these tests.


@pytest.fixture(scope="module")
def model_path(result_path):
return huggingface_hub.snapshot_download(
repo_id="HuggingFaceTB/SmolLM2-135M-Instruct",
local_dir=result_path / "lm_eval/model",
)


def get_lm_eval_config(base_path, tokenizer_path):
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)
return [
f"data.tokenizer.path={tokenizer_path}",
f"model.base_model.vocab_size={tokenizer.vocab_size}",
"training.evaluators.evaluation_test.interval=1",
"training.evaluators.evaluation_test.evaluator.type=lm_eval",
"training.evaluators.evaluation_test.evaluator.cli_args="
f'["--tasks","gsm8k,xnli_en,wikitext","--output_path","{str(base_path / "lm_eval")}","--limit","10"]',
]


@pytest.mark.extra_slow
Collaborator: How long does this take? It would be worrying not to have any tests other than extra-slow.

Contributor Author: Very long, 40-80 seconds per test.

@requires_lm_eval
@requires_cuda
@pytest.mark.model_testing_group(ModelTestingGroup.basic)
def test_lm_eval_in_training(run_test_script_for_all_models, run_test_script_base_path, model_path):
run_test_script_for_all_models(
get_lm_eval_config(run_test_script_base_path / "test_lm_eval_in_training", model_path)
+ ["training.checkpoint.interval=1"]
)


@pytest.mark.extra_slow
@requires_lm_eval
@requires_cuda
@pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.basic)
def test_lm_eval_evaluation(run_test_script_for_all_models, run_test_script_base_path, model_path):
run_test_script_for_all_models(
get_lm_eval_config(run_test_script_base_path / "test_lm_eval_evaluation", model_path),
compare="test_lm_eval_in_training",
prepare_fn=_prepare_resume_fn,
do_compare=False,
task="evaluate",
)


@pytest.mark.extra_slow
@requires_lm_eval
@requires_cuda
@pytest.mark.model_testing_group(ModelTestingGroup.distributed)
def test_lm_eval_in_training_dp2(run_test_script_for_all_models, run_test_script_base_path, model_path):
run_test_script_for_all_models(
get_lm_eval_config(run_test_script_base_path / "test_lm_eval_in_training_dp2", model_path)
+ ["training.checkpoint.interval=1"],
num_gpus=2,
)


@pytest.mark.extra_slow
@requires_lm_eval
@requires_cuda
@pytest.mark.depends_on(on=["test_lm_eval_in_training_dp2[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.distributed)
def test_lm_eval_evaluation_dp2(run_test_script_for_all_models, run_test_script_base_path, model_path):
run_test_script_for_all_models(
get_lm_eval_config(run_test_script_base_path / "test_lm_eval_evaluation_dp2", model_path),
compare="test_lm_eval_in_training_dp2",
prepare_fn=_prepare_resume_fn,
do_compare=False,
num_gpus=2,
task="evaluate",
)
3 changes: 3 additions & 0 deletions tests/utils/model_configs.py
@@ -7,6 +7,9 @@
import pytest

from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.engine.evaluation.evaluators import ( # noqa: F401 # needed for dynamic type registration
EvaluatorsConfig,
)
from fast_llm.engine.multi_stage.config import FastLLMModelConfig
from fast_llm.engine.training.config import TrainerConfig
from fast_llm.models.gpt.config import (
5 changes: 4 additions & 1 deletion tests/utils/run_test_script.py
@@ -62,6 +62,7 @@ def do_run_test_script(
do_compare: bool = True,
rendezvous_port: int,
torchrun_port: int,
task: str = "train",
):
is_parallel = DistributedConfig.default_world_size > 1
if is_parallel:
@@ -83,7 +84,7 @@
if is_megatron:
args = ["Megatron-LM/pretrain_gpt.py", *args, f"--structured-logs-dir={path}", f"--data-cache-path={path}"]
else:
args = ["--no-python", "fast-llm", "train", model_type, *args, f"run.experiment_dir={path}"]
args = ["--no-python", "fast-llm", task, model_type, *args, f"run.experiment_dir={path}"]
get_test_dataset()
if (num_gpus == 1 or is_parallel) and not is_megatron:
print(" ".join(args[1:]))
@@ -117,6 +118,7 @@ def do_run_test_script_for_all_models(
test_name: str,
base_path: pathlib.Path,
model_testing_config: ModelTestingConfig,
task: str = "train",
):
do_run_test_script(
base_path / test_name,
@@ -131,6 +133,7 @@
do_compare=do_compare,
rendezvous_port=rendezvous_port,
torchrun_port=torchrun_port,
task=task,
)
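
For reference, with the new task parameter set to "evaluate", the non-Megatron argument list assembled in do_run_test_script above comes out roughly as below; the model type, override, and experiment path are placeholder values.

# --- illustrative sketch, not part of the PR diff ---
task, model_type = "evaluate", "gpt"
path = "/tmp/fast_llm_tests/test_lm_eval_evaluation"
config_args = ["training.evaluators.evaluation_test.evaluator.type=lm_eval"]
args = ["--no-python", "fast-llm", task, model_type, *config_args, f"run.experiment_dir={path}"]
print(" ".join(args[1:]))
# fast-llm evaluate gpt training.evaluators.evaluation_test.evaluator.type=lm_eval run.experiment_dir=/tmp/fast_llm_tests/test_lm_eval_evaluation
# --- end sketch ---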


9 changes: 9 additions & 0 deletions tests/utils/utils.py
@@ -11,6 +11,15 @@

requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")

try:
import lm_eval # noqa: F401

_lm_eval_installed = True
except ImportError:
_lm_eval_installed = False

requires_lm_eval = pytest.mark.skipif(not _lm_eval_installed, reason="lm_eval is not installed")


TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests")
