Skip to content

Commit e0f1362

Browse files
authored
Merge pull request #1213 from flyinfish/discussions/1206-order
Provide evaluations in the same order as they were submitted
2 parents b980304 + 4c93baf commit e0f1362

File tree

3 files changed

+90
-37
lines changed

3 files changed

+90
-37
lines changed

testing/scorer/scorer-core/src/main/java/io/quarkiverse/langchain4j/testing/scorer/EvaluationReport.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
/**
99
* Report of the evaluation of a set of samples.
1010
*/
11-
public class EvaluationReport {
11+
public class EvaluationReport<T> {
1212

13-
private final List<Scorer.EvaluationResult<?>> evaluations;
13+
private final List<Scorer.EvaluationResult<T>> evaluations;
1414
private final double score;
1515

1616
/**
1717
* Creates a new evaluation report and computes the global score.
1818
*
1919
* @param evaluations the evaluations, must not be {@code null}, must not be empty.
2020
*/
21-
public EvaluationReport(List<Scorer.EvaluationResult<?>> evaluations) {
21+
public EvaluationReport(List<Scorer.EvaluationResult<T>> evaluations) {
2222
this.evaluations = evaluations;
2323
this.score = 100.0 * evaluations.stream().filter(Scorer.EvaluationResult::passed).count() / evaluations.size();
2424
}
@@ -33,7 +33,7 @@ public double score() {
3333
/**
3434
* @return the evaluations
3535
*/
36-
public List<Scorer.EvaluationResult<?>> evaluations() {
36+
public List<Scorer.EvaluationResult<T>> evaluations() {
3737
return evaluations;
3838
}
3939

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.quarkiverse.langchain4j.testing.scorer;
22

33
import java.io.Closeable;
4+
import java.util.Comparator;
45
import java.util.List;
56
import java.util.concurrent.CopyOnWriteArrayList;
67
import java.util.concurrent.CountDownLatch;
@@ -28,50 +29,64 @@ public Scorer() {
2829
}
2930

3031
@SuppressWarnings({ "unchecked" })
31-
public <T> EvaluationReport evaluate(Samples<T> samples, Function<Parameters, T> function,
32-
EvaluationStrategy<T>... strategies) {
33-
List<EvaluationResult<?>> evaluations = new CopyOnWriteArrayList<>();
32+
public <T> EvaluationReport<T> evaluate(
33+
Samples<T> samples, Function<Parameters, T> function, EvaluationStrategy<T>... strategies) {
34+
List<OrderedEvaluationResult<T>> evaluations = new CopyOnWriteArrayList<>();
3435
CountDownLatch latch = new CountDownLatch(samples.size());
36+
var index = 0;
3537
for (EvaluationSample<T> sample : samples) {
3638
// TODO Should we handle the context somehow.
37-
executor.submit(() -> {
38-
try {
39-
var response = execute(sample, function);
40-
LOG.infof("Evaluating sample `%s`", sample.name());
41-
for (EvaluationStrategy<T> strategy : strategies) {
42-
EvaluationResult<T> evaluation = EvaluationResult.fromCompletedEvaluation(sample,
43-
response, strategy.evaluate(sample, response));
44-
LOG.infof("Evaluation of sample `%s` with strategy `%s`: %s", sample.name(),
45-
strategy.getClass().getSimpleName(),
46-
evaluation.passed() ? "OK" : "KO");
47-
evaluations.add(evaluation);
48-
}
49-
} catch (Throwable e) {
50-
LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name());
51-
evaluations.add(EvaluationResult.fromEvaluationThrowable(sample, e));
52-
} finally {
53-
latch.countDown();
54-
}
55-
});
39+
var currentIndex = index++;
40+
executor.submit(
41+
() -> {
42+
try {
43+
var response = execute(sample, function);
44+
LOG.infof("Evaluating sample `%s`", sample.name());
45+
for (EvaluationStrategy<T> strategy : strategies) {
46+
EvaluationResult<T> evaluation = EvaluationResult.fromCompletedEvaluation(
47+
sample, response, strategy.evaluate(sample, response));
48+
LOG.infof(
49+
"Evaluation of sample `%s` with strategy `%s`: %s",
50+
sample.name(),
51+
strategy.getClass().getSimpleName(),
52+
evaluation.passed() ? "OK" : "KO");
53+
evaluations.add(new OrderedEvaluationResult(currentIndex, evaluation));
54+
}
55+
} catch (Throwable e) {
56+
LOG.errorf(e, "Failed to evaluate sample `%s`", sample.name());
57+
evaluations.add(
58+
new OrderedEvaluationResult(
59+
currentIndex, EvaluationResult.fromEvaluationThrowable(sample, e)));
60+
} finally {
61+
latch.countDown();
62+
}
63+
});
5664
}
5765
try {
5866
latch.await();
5967
} catch (InterruptedException e) {
6068
Thread.currentThread().interrupt();
6169
}
62-
return new EvaluationReport(evaluations);
70+
var orderedEvalutions = evaluations.stream()
71+
.sorted(Comparator.comparing(OrderedEvaluationResult::index))
72+
.map(OrderedEvaluationResult::evaluation)
73+
.toList();
74+
return new EvaluationReport<>(orderedEvalutions);
6375
}
6476

6577
public void close() {
6678
executor.shutdown();
6779
}
6880

69-
public record EvaluationResult<T>(EvaluationSample<T> sample, T result, Throwable thrown, boolean passed) {
70-
public static <T> EvaluationResult<T> fromCompletedEvaluation(EvaluationSample<T> sample, T result, boolean passed) {
81+
public record EvaluationResult<T>(
82+
EvaluationSample<T> sample, T result, Throwable thrown, boolean passed) {
83+
public static <T> EvaluationResult<T> fromCompletedEvaluation(
84+
EvaluationSample<T> sample, T result, boolean passed) {
7185
return new EvaluationResult<>(sample, result, null, passed);
7286
}
7387

74-
public static <T> EvaluationResult<T> fromEvaluationThrowable(EvaluationSample<T> sample, Throwable thrown) {
88+
public static <T> EvaluationResult<T> fromEvaluationThrowable(
89+
EvaluationSample<T> sample, Throwable thrown) {
7590
return new EvaluationResult<>(sample, null, thrown, false);
7691
}
7792
}
@@ -84,4 +99,6 @@ private <T> T execute(EvaluationSample<T> sample, Function<Parameters, T> functi
8499
}
85100
}
86101

102+
private record OrderedEvaluationResult<T>(int index, EvaluationResult<T> evaluation) {
103+
}
87104
}

testing/scorer/scorer-core/src/test/java/io/quarkiverse/langchain4j/testing/scorer/ScorerTest.java

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import java.util.List;
66
import java.util.function.Function;
7+
import java.util.stream.Stream;
78

89
import org.junit.jupiter.api.AfterEach;
910
import org.junit.jupiter.api.Test;
@@ -40,20 +41,55 @@ void evaluateShouldReturnCorrectReport() {
4041
EvaluationStrategy<String> strategy = (sample, actual) -> actual.equals(sample.expectedOutput());
4142

4243
Samples<String> samples = new Samples<>(sample1, sample2);
43-
EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy);
44+
EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy);
4445

4546
assertThat(report).isNotNull();
4647
assertThat(report.score()).isEqualTo(50.0); // Only one sample should pass.
4748
assertThat(report.evaluations()).hasSize(2);
4849

4950
var actualEvaluations = report.evaluations().stream()
50-
.map(e -> "%s[%s;%s=%s]".formatted(e.sample().name(), e.sample().expectedOutput(), e.result(), e.passed()))
51+
.map(
52+
e -> "%s[%s;%s=%s]"
53+
.formatted(
54+
e.sample().name(), e.sample().expectedOutput(), e.result(), e.passed()))
5155
.toList();
52-
assertThat(actualEvaluations).containsExactlyInAnyOrder(
53-
"Sample1[expected1:param1;expected1:param1=true]",
54-
"Sample2[expected2;expected1:param1=false]");
56+
assertThat(actualEvaluations)
57+
.containsExactly(
58+
"Sample1[expected1:param1;expected1:param1=true]",
59+
"Sample2[expected2;expected1:param1=false]");
5560
}
5661

62+
@SuppressWarnings("unchecked")
63+
@Test
64+
void evaluateShouldReturnCorrectlyOrderedReport() {
65+
scorer = new Scorer(2);
66+
var sleeps = Stream.of(25l, 0l);
67+
var samples = new Samples<>(
68+
sleeps
69+
.map(
70+
sleep -> new EvaluationSample<>(
71+
"%s".formatted(sleep),
72+
new Parameters().add(new Parameter.UnnamedParameter(sleep)),
73+
"irrelevant-for-this-test",
74+
List.of()))
75+
.toList());
76+
77+
var actual = scorer.evaluate(samples, this::sleep, (sample, actualOutput) -> true);
78+
79+
var actualOrder = actual.evaluations().stream().map(e -> e.sample().name()).toList();
80+
assertThat(actualOrder).containsExactly("25", "0");
81+
}
82+
83+
private String sleep(Parameters params) {
84+
long ms = params.get(0);
85+
try {
86+
Thread.sleep(ms);
87+
} catch (InterruptedException e) {
88+
throw new RuntimeException(e);
89+
}
90+
return "sleeped %s".formatted(ms);
91+
};
92+
5793
@Test
5894
@SuppressWarnings("unchecked")
5995
void evaluateShouldHandleExceptionsInFunction() {
@@ -71,7 +107,7 @@ void evaluateShouldHandleExceptionsInFunction() {
71107
EvaluationStrategy<String> strategy = (s, actual) -> false;
72108

73109
Samples<String> samples = new Samples<>(sample);
74-
EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy);
110+
EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy);
75111

76112
assertThat(report).isNotNull();
77113
assertThat(report.score()).isEqualTo(0.0); // All evaluations should fail.
@@ -96,7 +132,7 @@ void evaluateShouldHandleMultipleStrategies() {
96132
EvaluationStrategy<String> strategy2 = (s, actual) -> actual.length() > 3;
97133

98134
Samples<String> samples = new Samples<>(sample);
99-
EvaluationReport report = scorer.evaluate(samples, mockFunction, strategy1, strategy2);
135+
EvaluationReport<String> report = scorer.evaluate(samples, mockFunction, strategy1, strategy2);
100136

101137
assertThat(report).isNotNull();
102138
assertThat(report.score()).isEqualTo(100.0); // Both strategies should pass for the sample.

0 commit comments

Comments
 (0)