Skip to content

Commit c0e4ade

Browse files
committed
multithread jmh
1 parent 8030f9f commit c0e4ade

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

onnxruntime-benchmark/src/jmh/java/com/jyuzawa/onnxruntime_benchmark/Microbenchmark.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.openjdk.jmh.annotations.Param;
2626
import org.openjdk.jmh.annotations.Setup;
2727
import org.openjdk.jmh.annotations.State;
28+
import org.openjdk.jmh.annotations.Threads;
2829
import org.openjdk.jmh.annotations.Warmup;
2930
import org.openjdk.jmh.infra.Blackhole;
3031

@@ -39,9 +40,10 @@
3940
public class Microbenchmark {
4041

4142
private static final String ONNXRUNTIME_JAVA = "onnxruntime-java";
43+
private static final String ONNXRUNTIME_JAVA_ARENA = "onnxruntime-java-arena";
4244
private static final String MICROSOFT = "microsoft";
4345

44-
@Param(value = {ONNXRUNTIME_JAVA, MICROSOFT})
46+
@Param(value = {ONNXRUNTIME_JAVA, ONNXRUNTIME_JAVA_ARENA, MICROSOFT})
4547
private String implementation;
4648

4749
@Param({"16", "256", "4096"})
@@ -77,12 +79,14 @@ public void setup() throws Exception {
7779
input[i] = random.nextLong();
7880
}
7981
wrapper = switch (implementation) {
80-
case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes);
82+
case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes, false);
83+
case ONNXRUNTIME_JAVA_ARENA -> new OnnxruntimeJava(bytes, true);
8184
case MICROSOFT -> new Microsoft(bytes);
8285
default -> throw new IllegalArgumentException();};
8386
}
8487

8588
@Benchmark
89+
@Threads(Threads.MAX)
8690
public void run(Blackhole bh) throws Exception {
8791
bh.consume(wrapper.evaluate(input));
8892
}

onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/Benchmark.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public static final void main(String[] args) throws Exception {
4040
.addOutput(ValueInfoProto.newBuilder().setName("output").setType(type)))
4141
.build();
4242
byte[] bytes = model.toByteArray();
43-
List<Wrapper> wrappers = List.of(new OnnxruntimeJava(bytes), new Microsoft(bytes));
43+
List<Wrapper> wrappers = List.of(new OnnxruntimeJava(bytes, false), new Microsoft(bytes));
4444
long i = 0;
4545
long startMs = System.currentTimeMillis();
4646
while (i >= 0) {

onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJava.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
package com.jyuzawa.onnxruntime_benchmark;
66

77
import com.jyuzawa.onnxruntime.Environment;
8+
import com.jyuzawa.onnxruntime.ExecutionProvider;
89
import com.jyuzawa.onnxruntime.NamedCollection;
910
import com.jyuzawa.onnxruntime.OnnxRuntime;
1011
import com.jyuzawa.onnxruntime.OnnxRuntimeLoggingLevel;
1112
import com.jyuzawa.onnxruntime.OnnxValue;
1213
import com.jyuzawa.onnxruntime.Session;
1314
import com.jyuzawa.onnxruntime.Transaction;
1415
import java.io.IOException;
16+
import java.util.Map;
1517

1618
final class OnnxruntimeJava implements Wrapper {
1719

@@ -23,8 +25,12 @@ final class OnnxruntimeJava implements Wrapper {
2325

2426
private final Session session;
2527

26-
OnnxruntimeJava(byte[] bytes) throws IOException {
27-
this.session = ENVIRONMENT.newSession().setByteArray(bytes).build();
28+
OnnxruntimeJava(byte[] bytes, boolean arena) throws IOException {
29+
this.session = ENVIRONMENT
30+
.newSession()
31+
.setByteArray(bytes)
32+
.addProvider(ExecutionProvider.CPU_EXECUTION_PROVIDER, Map.of("use_arena", arena ? "1" : "0"))
33+
.build();
2834
}
2935

3036
@Override

0 commit comments

Comments
 (0)