Skip to content

Commit fb9c64b

Browse files
author
Judgment Release Bot
committed
[Bump Minor Version] Release: Merge staging to main
2 parents e7e28ec + 7d8386f commit fb9c64b

File tree

6 files changed

+562
-391
lines changed

6 files changed

+562
-391
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ Judgeval's agent monitoring infra provides a simple harness for integrating GRPO
3636
await trainer.train(
3737
agent_function=your_agent_function, # entry point to your agent
3838
scorers=[RewardScorer()], # Custom scorer you define based on task criteria, acts as reward
39-
prompts=training_prompts, # Tasks
40-
rft_provider="fireworks"
39+
prompts=training_prompts # Tasks
4140
)
4241
```
4342

src/judgeval/trainer/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from judgeval.trainer.trainer import JudgmentTrainer
22
from judgeval.trainer.config import TrainerConfig, ModelConfig
33
from judgeval.trainer.trainable_model import TrainableModel
4+
from judgeval.trainer.base_trainer import BaseTrainer
5+
from judgeval.trainer.fireworks_trainer import FireworksTrainer
46

5-
__all__ = ["JudgmentTrainer", "TrainerConfig", "ModelConfig", "TrainableModel"]
7+
__all__ = [
8+
"JudgmentTrainer",
9+
"TrainerConfig",
10+
"ModelConfig",
11+
"TrainableModel",
12+
"BaseTrainer",
13+
"FireworksTrainer",
14+
]
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Callable, List, Optional, Union, Dict, TYPE_CHECKING
3+
from .config import TrainerConfig, ModelConfig
4+
from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
5+
6+
if TYPE_CHECKING:
7+
from judgeval.tracer import Tracer
8+
from .trainable_model import TrainableModel
9+
10+
11+
class BaseTrainer(ABC):
12+
"""
13+
Abstract base class for training providers.
14+
15+
This class defines the interface that all training provider implementations
16+
must follow. Each provider (Fireworks, Verifiers, etc.) will have its own
17+
concrete implementation of this interface.
18+
"""
19+
20+
def __init__(
21+
self,
22+
config: TrainerConfig,
23+
trainable_model: "TrainableModel",
24+
tracer: "Tracer",
25+
project_name: Optional[str] = None,
26+
):
27+
"""
28+
Initialize the base trainer.
29+
30+
Args:
31+
config: TrainerConfig instance with training parameters
32+
trainable_model: TrainableModel instance to use for training
33+
tracer: Tracer for observability
34+
project_name: Project name for organizing training runs
35+
"""
36+
self.config = config
37+
self.trainable_model = trainable_model
38+
self.tracer = tracer
39+
self.project_name = project_name or "judgment_training"
40+
41+
@abstractmethod
42+
async def generate_rollouts_and_rewards(
43+
self,
44+
agent_function: Callable[[Any], Any],
45+
scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
46+
prompts: List[Any],
47+
num_prompts_per_step: Optional[int] = None,
48+
num_generations_per_prompt: Optional[int] = None,
49+
concurrency: Optional[int] = None,
50+
) -> Any:
51+
"""
52+
Generate rollouts and compute rewards using the current model snapshot.
53+
54+
Args:
55+
agent_function: Function/agent to call for generating responses
56+
scorers: List of scorer objects to evaluate responses
57+
prompts: List of prompts to use for training
58+
num_prompts_per_step: Number of prompts to use per step
59+
num_generations_per_prompt: Generations per prompt
60+
concurrency: Concurrency limit
61+
62+
Returns:
63+
Provider-specific dataset format for training
64+
"""
65+
pass
66+
67+
@abstractmethod
68+
async def run_reinforcement_learning(
69+
self,
70+
agent_function: Callable[[Any], Any],
71+
scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
72+
prompts: List[Any],
73+
) -> ModelConfig:
74+
"""
75+
Run the iterative reinforcement learning fine-tuning loop.
76+
77+
Args:
78+
agent_function: Function/agent to call for generating responses
79+
scorers: List of scorer objects to evaluate responses
80+
prompts: List of prompts to use for training
81+
82+
Returns:
83+
ModelConfig: Configuration of the trained model
84+
"""
85+
pass
86+
87+
@abstractmethod
88+
async def train(
89+
self,
90+
agent_function: Callable[[Any], Any],
91+
scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
92+
prompts: List[Any],
93+
) -> ModelConfig:
94+
"""
95+
Start the reinforcement learning fine-tuning process.
96+
97+
This is the main entry point for running the training.
98+
99+
Args:
100+
agent_function: Function/agent to call for generating responses
101+
scorers: List of scorer objects to evaluate responses
102+
prompts: List of prompts to use for training
103+
104+
Returns:
105+
ModelConfig: Configuration of the trained model
106+
"""
107+
pass
108+
109+
@abstractmethod
110+
def _extract_message_history_from_spans(self) -> List[Dict[str, str]]:
111+
"""
112+
Extract message history from spans for training purposes.
113+
114+
Returns:
115+
List of message dictionaries with 'role' and 'content' keys
116+
"""
117+
pass

src/judgeval/trainer/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TrainerConfig:
1616
user_id: str
1717
model_id: str
1818
base_model_name: str = "qwen2p5-7b-instruct"
19-
rft_provider: str = "fireworks"
19+
rft_provider: str = "fireworks" # Supported: "fireworks", "verifiers" (future)
2020
num_steps: int = 5
2121
num_generations_per_prompt: int = 4
2222
num_prompts_per_step: int = 4

0 commit comments

Comments
 (0)