Commit 27ecc04

Merge pull request #11 from codelion/fix-dim-issue-in-training
Fix dim issue in training
2 parents 768fea1 + c834412 commit 27ecc04

File tree

10 files changed: +691 -22 lines


.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -172,3 +172,4 @@ cython_debug/
 
 **/.DS_store
 demo_classifier/
+scripts/benchmark_results/
```

README.md

Lines changed: 32 additions & 0 deletions
@@ -102,8 +102,40 @@ The system combines three key components:

- safetensors ≥ 0.3.1
- faiss-cpu ≥ 1.7.4 (or faiss-gpu for GPU support)

## Benefits of Adaptive Classification in LLM Routing

We evaluated the effectiveness of adaptive classification in optimizing LLM routing decisions. Using the arena-hard-auto-v0.1 dataset (500 queries), we compared routing performance with and without adaptation while holding the overall success rate constant.

### Key Results

| Metric | Without Adaptation | With Adaptation | Impact (With/Without) |
|--------|-------------------|-----------------|-----------------------|
| High Model Routes | 113 (22.6%) | 98 (19.6%) | 0.87x |
| Low Model Routes | 387 (77.4%) | 402 (80.4%) | 1.04x |
| High Model Success Rate | 40.71% | 29.59% | 0.73x |
| Low Model Success Rate | 16.54% | 20.15% | 1.22x |
| Overall Success Rate | 22.00% | 22.00% | 1.00x |
| Cost Savings* | 25.60% | 32.40% | 1.27x |

*Cost savings assume the high-cost model costs 2x as much as the low-cost model.
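
To make the cost arithmetic concrete, here is a minimal sketch of how the savings figures can be reproduced. It assumes (our reading; the README does not spell this out) that the reported percentage is the evaluation script's `cost_saving_ratio` (successful low-model routes over total queries) scaled by the assumed 2x cost factor:

```python
# Sketch: reproducing the table's cost-savings figures under the stated
# 2x cost assumption. The formula mirrors cost_saving_ratio
# (low_success / total_queries) from scripts/eval_llmrouter_arena.py,
# scaled by the assumed cost factor; this is our reconstruction, not a
# calculation documented in the repo.
COST_FACTOR = 2  # high-cost model assumed to cost 2x the low-cost model

def cost_savings(low_routes: int, low_success_rate: float, total_queries: int) -> float:
    low_success = round(low_routes * low_success_rate)  # e.g. 402 * 0.2015 ~= 81
    return COST_FACTOR * low_success / total_queries

print(f"{cost_savings(387, 0.1654, 500):.2%}")  # 25.60% (without adaptation)
print(f"{cost_savings(402, 0.2015, 500):.2%}")  # 32.40% (with adaptation)
```

Under this reading, both outputs match the table above to two decimal places.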
### Analysis

The results highlight several key benefits of adaptive classification:

1. **Improved Cost Efficiency**: While maintaining the same overall success rate (22%), the adaptive classifier achieved 32.40% cost savings versus 25.60% without adaptation, a 1.27x relative improvement in cost efficiency.

2. **Better Resource Utilization**: The adaptive system routed more queries to the low-cost model (402 vs. 387) while reducing high-cost model usage (98 vs. 113), demonstrating better resource allocation.

3. **Learning from Experience**: Through adaptation, the system raised the success rate of low-model routes from 16.54% to 20.15% (a 1.22x increase), showing effective learning from successful cases (see the sketch below).

4. **ROI on Adaptation**: The system adapted to 110 new examples during evaluation, yielding a 6.80 percentage-point improvement in cost savings while maintaining quality, a significant return on the adaptation investment.

This real-world evaluation demonstrates that adaptive classification can significantly improve cost efficiency in LLM routing without compromising overall performance.
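
Mechanically, the "learning" above is online example accumulation: each query whose response passes the round-trip correctness (RTC) check is fed back into the classifier under the label it was routed with. A minimal sketch of that feedback loop, mirroring `LLMRouter.route_and_evaluate` in `scripts/eval_llmrouter_arena.py` below; `run_rtc_check` is a hypothetical stand-in for the script's `perform_rtc_evaluation`:

```python
# Minimal sketch of the adapt-on-success loop; the full version is
# LLMRouter.route_and_evaluate() in scripts/eval_llmrouter_arena.py below.
from adaptive_classifier import AdaptiveClassifier

classifier = AdaptiveClassifier.load("./adaptive_router")

def run_rtc_check(query: str, label: str) -> bool:
    """Hypothetical placeholder for perform_rtc_evaluation() in the script below."""
    return True  # placeholder: accept everything in this sketch

def route_and_adapt(query: str) -> str:
    label = classifier.predict(query)[0][0]  # top prediction: "HIGH" or "LOW"
    if run_rtc_check(query, label):          # did the routed model's answer pass RTC?
        # Reinforce: add the successful query back under its routed label
        classifier.add_examples([query], [label])
    return label
```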

## References

- [RouteLLM: Learning to Route LLMs with Preference Data](https://arxiv.org/abs/2406.18665)
- [Transformer^2: Self-adaptive LLMs](https://arxiv.org/abs/2501.06252)
- [Lamini Classifier Agent Toolkit](https://www.lamini.ai/blog/classifier-agent-toolkit)
- [Protoformer: Embedding Prototypes for Transformers](https://arxiv.org/abs/2206.12710)

scripts/adaptive_router/config.json

Lines changed: 1 addition & 0 deletions (large diffs are not rendered by default)

3.39 MB binary file not shown

scripts/eval_llmrouter_arena.py

Lines changed: 323 additions & 0 deletions
@@ -0,0 +1,323 @@
```python
import argparse
import json
import os
import logging
import time
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from pathlib import Path

from openai import OpenAI
from datasets import load_dataset
from tqdm import tqdm
from adaptive_classifier import AdaptiveClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize OpenAI client
client = OpenAI(base_url="http://localhost:8000/v1", api_key=os.environ.get("OPENAI_API_KEY"))
# client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

@dataclass
class RouterConfig:
    """Configuration for the LLM Router evaluation."""
    high_model: str = "gpt-4o"
    low_model: str = "gpt-4o-mini"
    similarity_threshold: float = 0.6
    max_retries: int = 3
    retry_delay: int = 1
    adaptive_router_path: str = "./adaptive_router"

class LLMRouter:
    """Router class to direct queries to appropriate models."""

    def __init__(self, config: RouterConfig, enable_adaptation: bool = True):
        """Initialize the router with classifier and configuration."""
        self.config = config
        self.enable_adaptation = enable_adaptation
        self.classifier = AdaptiveClassifier.load(config.adaptive_router_path)
        self.stats = {
            "total_queries": 0,
            "high_routes": 0,
            "low_routes": 0,
            "high_success": 0,
            "low_success": 0,
            "adapted_examples": 0
        }

    def route_and_evaluate(self, query: str) -> Tuple[bool, Dict]:
        """Route query to appropriate model and evaluate results."""
        # Get routing decision
        predictions = self.classifier.predict(query)
        route = predictions[0][0]  # Get top prediction

        # Select model based on route
        model = self.config.high_model if route == "HIGH" else self.config.low_model

        # Update stats
        self.stats["total_queries"] += 1
        if route == "HIGH":
            self.stats["high_routes"] += 1
        else:
            self.stats["low_routes"] += 1

        # Perform RTC evaluation
        passed_rtc, similarity_score, details = perform_rtc_evaluation(
            query, model, self.config
        )

        # Update success stats
        if passed_rtc:
            if route == "HIGH":
                self.stats["high_success"] += 1
            else:
                self.stats["low_success"] += 1

        # Adapt if enabled and RTC passed
        if self.enable_adaptation and passed_rtc:
            self.adapt_to_example(query, route)
            self.stats["adapted_examples"] += 1

        evaluation_result = {
            "query": query,
            "route": route,
            "model": model,
            "passed_rtc": passed_rtc,
            "similarity_score": similarity_score,
            "evaluation_details": details
        }

        return passed_rtc, evaluation_result

    def adapt_to_example(self, query: str, label: str):
        """Add successful example to classifier."""
        if self.enable_adaptation:
            self.classifier.add_examples([query], [label])

    def save_classifier(self):
        """Save the adapted classifier."""
        if self.enable_adaptation:
            self.classifier.save(self.config.adaptive_router_path)

    def get_stats(self) -> Dict:
        """Get routing statistics."""
        stats = self.stats.copy()
        stats["high_success_rate"] = (
            stats["high_success"] / stats["high_routes"]
            if stats["high_routes"] > 0 else 0
        )
        stats["low_success_rate"] = (
            stats["low_success"] / stats["low_routes"]
            if stats["low_routes"] > 0 else 0
        )
        stats["overall_success_rate"] = (
            (stats["high_success"] + stats["low_success"]) / stats["total_queries"]
            if stats["total_queries"] > 0 else 0
        )
        stats["cost_saving_ratio"] = (
            stats["low_success"] / stats["total_queries"]
            if stats["total_queries"] > 0 else 0
        )
        return stats

def perform_rtc_evaluation(
    query: str,
    model: str,
    config: RouterConfig
) -> Tuple[bool, float, Dict]:
    """Perform Round-Trip Correctness evaluation."""
    # Get initial response
    response_1 = get_llm_response([
        {"role": "user", "content": query}
    ], model, config)

    if not response_1:
        return False, 0.0, {"error": "Failed to get initial response"}

    # Generate alternate query
    inverse_prompt = f"""Given this query and response pair, generate a new query that would lead to a similar response. Focus on the key aspects that would generate equivalent content:

Original Query: {query}
Response: {response_1}

Generate a new query that would elicit a similar response:"""

    alternate_query = get_llm_response([
        {"role": "user", "content": inverse_prompt}
    ], model, config)

    if not alternate_query:
        return False, 0.0, {"error": "Failed to generate alternate query"}

    # Get response for alternate query
    response_2 = get_llm_response([
        {"role": "user", "content": alternate_query}
    ], model, config)

    if not response_2:
        return False, 0.0, {"error": "Failed to get second response"}

    # Compute similarity
    similarity_score = compute_similarity(response_1, response_2)

    evaluation_details = {
        "original_query": query,
        "response_1": response_1,
        "alternate_query": alternate_query,
        "response_2": response_2,
        "similarity_score": similarity_score
    }

    return similarity_score >= config.similarity_threshold, similarity_score, evaluation_details

def get_llm_response(
    messages: List[Dict],
    model: str,
    config: RouterConfig
) -> Optional[str]:
    """Get response from the LLM with retry logic."""
    for attempt in range(config.max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=4096
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"Error getting LLM response (attempt {attempt + 1}): {e}")
            if attempt < config.max_retries - 1:
                time.sleep(config.retry_delay)
                continue
    return None

def compute_similarity(text1: str, text2: str) -> float:
    """Compute cosine similarity between two texts using TF-IDF."""
    try:
        vectorizer = TfidfVectorizer(stop_words='english')
        tfidf_matrix = vectorizer.fit_transform([text1, text2])
        similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
        return float(similarity)
    except Exception as e:
        logger.error(f"Error computing similarity: {e}")
        return 0.0

def extract_first_turn_content(turns: List[Dict]) -> str:
    """Extract content from first turn in conversation."""
    if not turns or not isinstance(turns, list):
        return ""
    return turns[0].get("content", "")

def evaluate_dataset(config: RouterConfig, enable_adaptation: bool, output_file: str):
    """Evaluate the dataset using the LLM router."""
    # Initialize router
    router = LLMRouter(config, enable_adaptation=enable_adaptation)

    # Load dataset
    dataset = load_dataset("lmarena-ai/arena-hard-auto-v0.1")

    results = []

    # Process each example
    for item in tqdm(dataset["train"], desc="Evaluating examples"):
        query = extract_first_turn_content(item["turns"])
        if not query:
            continue

        passed_rtc, evaluation_result = router.route_and_evaluate(query)
        results.append(evaluation_result)

        # Save intermediate results
        save_results(output_file, router, results)

    # Save final state if adaptation was enabled
    if enable_adaptation:
        router.save_classifier()

    # Print final summary
    print_summary(router)

def save_results(output_file: str, router: LLMRouter, results: List[Dict]):
    """Save evaluation results to file."""
    with open(output_file, 'w') as f:
        json.dump({
            "stats": router.get_stats(),
            "results": results
        }, f, indent=2)

def print_summary(router: LLMRouter):
    """Print evaluation summary."""
    stats = router.get_stats()

    logger.info("\nEvaluation Summary:")
    logger.info(f"Total queries processed: {stats['total_queries']}")
    logger.info(f"High-model routes: {stats['high_routes']}")
    logger.info(f"Low-model routes: {stats['low_routes']}")
    logger.info(f"High-model successes: {stats['high_success']}")
    logger.info(f"Low-model successes: {stats['low_success']}")
    logger.info(f"Adapted examples: {stats['adapted_examples']}")
    logger.info(f"High-model success rate: {stats['high_success_rate']*100:.2f}%")
    logger.info(f"Low-model success rate: {stats['low_success_rate']*100:.2f}%")
    logger.info(f"Overall success rate: {stats['overall_success_rate']*100:.2f}%")
    logger.info(f"Potential cost savings: {stats['cost_saving_ratio']*100:.2f}%")

def main():
    parser = argparse.ArgumentParser(
        description="Evaluate LLM router on arena-hard-auto dataset"
    )
    parser.add_argument(
        "--high-model",
        type=str,
        default="gpt-4o",
        help="Model to use for high-complexity queries"
    )
    parser.add_argument(
        "--low-model",
        type=str,
        default="gpt-4o-mini",
        help="Model to use for low-complexity queries"
    )
    parser.add_argument(
        "--without-adaptation",
        action="store_true",
        help="Disable adaptive learning during evaluation"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="router_eval_results.json",
        help="Output file for results"
    )
    parser.add_argument(
        "--router-path",
        type=str,
        default="./adaptive_router",
        help="Path to load/save the adaptive router"
    )

    args = parser.parse_args()

    # Create results directory
    os.makedirs("benchmark_results", exist_ok=True)
    output_file = os.path.join("benchmark_results", args.output)

    # Create configuration
    config = RouterConfig(
        high_model=args.high_model,
        low_model=args.low_model,
        adaptive_router_path=args.router_path
    )

    # Run evaluation
    evaluate_dataset(
        config,
        enable_adaptation=not args.without_adaptation,
        output_file=output_file
    )

if __name__ == "__main__":
    main()
```
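
A hedged usage sketch: the CLI flags above suggest two runs to produce the README comparison (one with `--without-adaptation`, one without), and the `LLMRouter` class can also be driven directly, assuming the script is importable as a module (hypothetical module name `eval_llmrouter_arena`) and a trained classifier already exists at `./adaptive_router`:

```python
# Sketch of programmatic use, assuming the script is importable as
# eval_llmrouter_arena and a trained classifier exists at ./adaptive_router.
# CLI equivalent (two runs, as in the README comparison):
#   python scripts/eval_llmrouter_arena.py --without-adaptation --output baseline.json
#   python scripts/eval_llmrouter_arena.py --output adaptive.json
from eval_llmrouter_arena import LLMRouter, RouterConfig

config = RouterConfig(high_model="gpt-4o", low_model="gpt-4o-mini")
router = LLMRouter(config, enable_adaptation=True)

passed, result = router.route_and_evaluate("Explain the CAP theorem with examples.")
print(result["route"], result["model"], f"similarity={result['similarity_score']:.2f}")
print(router.get_stats())

router.save_classifier()  # persist any adapted examples
```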
