Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ description = "The simplest way to train and run adapters on top of foundation m
authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }]
license = { text = "MIT License" }
dependencies = [
"torch>=2.1.1",
"torch>=2.4.1",
"safetensors>=0.4.5",
"pillow>=10.4.0",
"jaxtyping>=0.2.23",
Expand Down
11 changes: 2 additions & 9 deletions src/refiners/training_utils/config.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
from enum import Enum
from logging import warning
from pathlib import Path
from typing import Annotated, Any, Callable, Iterable, Literal, Type, TypeVar
from typing import Annotated, Callable, Literal, Type, TypeVar

import tomli
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
from prodigyopt import Prodigy # type: ignore
from pydantic import BaseModel, BeforeValidator, ConfigDict
from torch import Tensor
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer
from torch.optim.optimizer import Optimizer, ParamsT
from torch.optim.sgd import SGD

from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field

# PyTorch optimizer parameters type
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
# See https://github.com/pytorch/pytorch/pull/111114
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]


TimeValueField = Annotated[TimeValue, BeforeValidator(parse_number_unit_field)]
IterationOrEpochField = Annotated[Iteration | Epoch, BeforeValidator(parse_number_unit_field)]
StepField = Annotated[Step, BeforeValidator(parse_number_unit_field)]
Expand Down