Skip to content

Commit 4cf2377

Browse files
mudabirhussain authored and markpollack committed
feature: Add common TranscriptionModel interface for audio transcription
- Created TranscriptionModel interface that extends Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> - Implemented `call(AudioTranscriptionPrompt)` method for better compatibility between OpenAI and Azure OpenAI transcription models - Added default convenience methods for handling Resource and AudioTranscriptionOptions to return transcription as a String - Adds unit tests for the OpenAiAudioTranscriptionModel using the `@RestClientTest` approach to mock the OpenAI API. - Creates OpenAiAudioTranscriptionModelTests, which consolidates logic from TranscriptionModelTests into a new, more appropriately named test class. Enhances the unit and integration test suites for the OpenAiAudioTranscriptionModel to ensure full coverage of the TranscriptionModel interface. - Adds a unit test to OpenAiAudioTranscriptionModelTests to verify the transcribe method that accepts AudioTranscriptionOptions. - Adds integration tests to OpenAiTranscriptionModelIT to exercise the transcribe convenience methods against the live API. - Renames existing integration tests for improved clarity. Authored-by: Mudabir Hussain <mudabirhussain@users.noreply.github.com>
1 parent aa590e8 commit 4cf2377

File tree

8 files changed

+333
-92
lines changed

8 files changed

+333
-92
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.StructuredResponse.Word;
3535
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.TranscriptResponseFormat;
3636
import org.springframework.ai.azure.openai.metadata.AzureOpenAiAudioTranscriptionResponseMetadata;
37-
import org.springframework.ai.model.Model;
3837
import org.springframework.ai.model.ModelOptionsUtils;
38+
import org.springframework.ai.audio.transcription.TranscriptionModel;
3939
import org.springframework.core.io.Resource;
4040
import org.springframework.util.Assert;
4141
import org.springframework.util.StringUtils;
@@ -47,7 +47,7 @@
4747
*
4848
* @author Piotr Olaszewski
4949
*/
50-
public class AzureOpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
50+
public class AzureOpenAiAudioTranscriptionModel implements TranscriptionModel {
5151

5252
private static final List<AudioTranscriptionFormat> JSON_FORMATS = List.of(AudioTranscriptionFormat.JSON,
5353
AudioTranscriptionFormat.VERBOSE_JSON);

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt;
2424
import org.springframework.ai.audio.transcription.AudioTranscriptionResponse;
2525
import org.springframework.ai.chat.metadata.RateLimit;
26-
import org.springframework.ai.model.Model;
26+
import org.springframework.ai.audio.transcription.TranscriptionModel;
2727
import org.springframework.ai.openai.api.OpenAiAudioApi;
2828
import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse;
2929
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata;
@@ -45,7 +45,7 @@
4545
* @see OpenAiAudioApi
4646
* @since 0.8.1
4747
*/
48-
public class OpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
48+
public class OpenAiAudioTranscriptionModel implements TranscriptionModel {
4949

5050
private final Logger logger = LoggerFactory.getLogger(getClass());
5151

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.audio.transcription;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt;
21+
import org.springframework.ai.audio.transcription.AudioTranscriptionResponse;
22+
import org.springframework.ai.audio.transcription.TranscriptionModel;
23+
import org.springframework.ai.model.SimpleApiKey;
24+
import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
25+
import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions;
26+
import org.springframework.ai.openai.api.OpenAiAudioApi;
27+
import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat;
28+
import org.springframework.ai.retry.RetryUtils;
29+
import org.springframework.beans.factory.annotation.Autowired;
30+
import org.springframework.boot.test.autoconfigure.web.client.RestClientTest;
31+
import org.springframework.context.annotation.Bean;
32+
import org.springframework.context.annotation.Configuration;
33+
import org.springframework.core.io.ClassPathResource;
34+
import org.springframework.http.HttpMethod;
35+
import org.springframework.http.MediaType;
36+
import org.springframework.test.web.client.MockRestServiceServer;
37+
import org.springframework.util.LinkedMultiValueMap;
38+
import org.springframework.web.client.RestClient;
39+
import org.springframework.web.reactive.function.client.WebClient;
40+
41+
import static org.assertj.core.api.Assertions.assertThat;
42+
import static org.springframework.test.web.client.match.MockRestRequestMatchers.method;
43+
import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo;
44+
import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess;
45+
46+
@RestClientTest(OpenAiAudioTranscriptionModelTests.Config.class)
47+
class OpenAiAudioTranscriptionModelTests {
48+
49+
@Autowired
50+
private MockRestServiceServer server;
51+
52+
@Autowired
53+
private TranscriptionModel transcriptionModel;
54+
55+
@Test
56+
void transcribeRequestReturnsResponseCorrectly() {
57+
String mockResponse = """
58+
{
59+
"text": "All your bases are belong to us"
60+
}
61+
""".stripIndent();
62+
63+
this.server.expect(requestTo("https://api.openai.com/v1/audio/transcriptions"))
64+
.andExpect(method(HttpMethod.POST))
65+
.andRespond(withSuccess(mockResponse, MediaType.APPLICATION_JSON));
66+
67+
String transcription = this.transcriptionModel.transcribe(new ClassPathResource("/speech.flac"));
68+
69+
assertThat(transcription).isEqualTo("All your bases are belong to us");
70+
this.server.verify();
71+
}
72+
73+
@Test
74+
void callWithDefaultOptions() {
75+
String mockResponse = """
76+
{
77+
"text": "Hello, this is a test transcription."
78+
}
79+
""".stripIndent();
80+
81+
this.server.expect(requestTo("https://api.openai.com/v1/audio/transcriptions"))
82+
.andExpect(method(HttpMethod.POST))
83+
.andRespond(withSuccess(mockResponse, MediaType.APPLICATION_JSON));
84+
85+
AudioTranscriptionPrompt prompt = new AudioTranscriptionPrompt(new ClassPathResource("/speech.flac"));
86+
AudioTranscriptionResponse response = this.transcriptionModel.call(prompt);
87+
88+
assertThat(response.getResult().getOutput()).isEqualTo("Hello, this is a test transcription.");
89+
this.server.verify();
90+
}
91+
92+
@Test
93+
void transcribeWithOptions() {
94+
String mockResponse = """
95+
{
96+
"text": "Hello, this is a test transcription with options."
97+
}
98+
""".stripIndent();
99+
100+
this.server.expect(requestTo("https://api.openai.com/v1/audio/transcriptions"))
101+
.andExpect(method(HttpMethod.POST))
102+
.andRespond(withSuccess(mockResponse, MediaType.APPLICATION_JSON));
103+
104+
OpenAiAudioTranscriptionOptions options = OpenAiAudioTranscriptionOptions.builder()
105+
.temperature(0.5f)
106+
.responseFormat(TranscriptResponseFormat.JSON)
107+
.build();
108+
109+
String transcription = this.transcriptionModel.transcribe(new ClassPathResource("/speech.flac"), options);
110+
111+
assertThat(transcription).isEqualTo("Hello, this is a test transcription with options.");
112+
this.server.verify();
113+
}
114+
115+
@Configuration
116+
static class Config {
117+
118+
@Bean
119+
public OpenAiAudioApi openAiAudioApi(RestClient.Builder builder) {
120+
return new OpenAiAudioApi("https://api.openai.com", new SimpleApiKey("test-api-key"),
121+
new LinkedMultiValueMap<>(), builder, WebClient.builder(),
122+
RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
123+
}
124+
125+
@Bean
126+
public OpenAiAudioTranscriptionModel openAiAudioTranscriptionModel(OpenAiAudioApi audioApi) {
127+
return new OpenAiAudioTranscriptionModel(audioApi);
128+
}
129+
130+
}
131+
132+
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class OpenAiTranscriptionModelIT extends AbstractIT {
4040
private Resource audioFile;
4141

4242
@Test
43-
void transcriptionTest() {
43+
void callTest() {
4444
OpenAiAudioTranscriptionOptions transcriptionOptions = OpenAiAudioTranscriptionOptions.builder()
4545
.responseFormat(TranscriptResponseFormat.TEXT)
4646
.temperature(0f)
@@ -53,7 +53,7 @@ void transcriptionTest() {
5353
}
5454

5555
@Test
56-
void transcriptionTestWithOptions() {
56+
void callTestWithOptions() {
5757
OpenAiAudioApi.TranscriptResponseFormat responseFormat = OpenAiAudioApi.TranscriptResponseFormat.VTT;
5858

5959
OpenAiAudioTranscriptionOptions transcriptionOptions = OpenAiAudioTranscriptionOptions.builder()
@@ -69,4 +69,24 @@ void transcriptionTestWithOptions() {
6969
assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
7070
}
7171

72+
@Test
73+
void transcribeTest() {
74+
String response = this.transcriptionModel.transcribe(this.audioFile);
75+
assertThat(response).isNotNull();
76+
assertThat(response.toLowerCase().contains("fellow")).isTrue();
77+
}
78+
79+
@Test
80+
void transcribeTestWithOptions() {
81+
OpenAiAudioTranscriptionOptions transcriptionOptions = OpenAiAudioTranscriptionOptions.builder()
82+
.language("en")
83+
.prompt("Ask not this, but ask that")
84+
.temperature(0f)
85+
.responseFormat(TranscriptResponseFormat.TEXT)
86+
.build();
87+
String response = this.transcriptionModel.transcribe(this.audioFile, transcriptionOptions);
88+
assertThat(response).isNotNull();
89+
assertThat(response.toLowerCase().contains("fellow")).isTrue();
90+
}
91+
7292
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java

Lines changed: 0 additions & 86 deletions
This file was deleted.

models/spring-ai-openai/src/test/resources/speech.flac

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.audio.transcription;
18+
19+
import org.springframework.ai.model.Model;
20+
import org.springframework.core.io.Resource;
21+
22+
/**
23+
* A transcription model is a type of AI model that converts audio to text. This is also
24+
* known as Speech-to-Text.
25+
*
26+
* @author Mudabir Hussain
27+
* @since 1.0.0
28+
*/
29+
public interface TranscriptionModel extends Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
30+
31+
/**
32+
* Transcribes the audio from the given prompt.
33+
* @param transcriptionPrompt The prompt containing the audio resource and options.
34+
* @return The transcription response.
35+
*/
36+
AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPrompt);
37+
38+
/**
39+
* A convenience method for transcribing an audio resource.
40+
* @param resource The audio resource to transcribe.
41+
* @return The transcribed text.
42+
*/
43+
default String transcribe(Resource resource) {
44+
AudioTranscriptionPrompt prompt = new AudioTranscriptionPrompt(resource);
45+
return this.call(prompt).getResult().getOutput();
46+
}
47+
48+
/**
49+
* A convenience method for transcribing an audio resource with the given options.
50+
* @param resource The audio resource to transcribe.
51+
* @param options The transcription options.
52+
* @return The transcribed text.
53+
*/
54+
default String transcribe(Resource resource, AudioTranscriptionOptions options) {
55+
AudioTranscriptionPrompt prompt = new AudioTranscriptionPrompt(resource, options);
56+
return this.call(prompt).getResult().getOutput();
57+
}
58+
59+
}

0 commit comments

Comments
 (0)