Skip to content

Commit bb5c722

Browse files
authored
Merge pull request #1132 from dastrobu/feature/no-auth-model-auth-provider
Support no authorization header being set when null is returned from ModelAuthProvider
2 parents 9b0f287 + 942c830 commit bb5c722

File tree

9 files changed

+209
-14
lines changed

9 files changed

+209
-14
lines changed

core/runtime/src/main/java/io/quarkiverse/langchain4j/auth/ModelAuthProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public interface ModelAuthProvider {
2323
*
2424
* @param input representation of an HTTP request to the model provider.
2525
* @return authorization data which must include an HTTP Authorization scheme value, for example: "Bearer the_access_token".
26+
* Returning null will result in no Authorization header being set.
2627
*/
2728
String getAuthorization(Input input);
2829

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaRestApi.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,13 @@ public OllamaRestAPIFilter(ModelAuthProvider authorizer) {
170170

171171
@Override
172172
public void filter(ClientRequestContext context) {
173-
context.getHeaders().putSingle(
174-
"Authorization",
175-
authorizer
176-
.getAuthorization(new AuthInputImpl(context.getMethod(), context.getUri(), context.getHeaders())));
173+
String authValue = authorizer.getAuthorization(new AuthInputImpl(
174+
context.getMethod(),
175+
context.getUri(),
176+
context.getHeaders()));
177+
if (authValue != null) {
178+
context.getHeaders().putSingle("Authorization", authValue);
179+
}
177180
}
178181

179182
private record AuthInputImpl(
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package io.quarkiverse.langchain4j.ollama;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertNull;
5+
import static org.mockito.Mockito.*;
6+
7+
import jakarta.ws.rs.client.ClientRequestContext;
8+
import jakarta.ws.rs.core.MultivaluedHashMap;
9+
10+
import org.junit.jupiter.api.BeforeEach;
11+
import org.junit.jupiter.api.Test;
12+
13+
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
14+
15+
class OllamaRestApiFilterTest {
16+
17+
private ModelAuthProvider authProvider;
18+
private ClientRequestContext context;
19+
private MultivaluedHashMap<Object, Object> headers;
20+
private OllamaRestApi.OllamaRestAPIFilter ollamaRestApiFilter;
21+
22+
@BeforeEach
23+
void setUpFilter() {
24+
context = mock(ClientRequestContext.class);
25+
headers = new MultivaluedHashMap<>();
26+
doReturn(headers).when(context).getHeaders();
27+
28+
authProvider = mock(ModelAuthProvider.class);
29+
30+
ollamaRestApiFilter = new OllamaRestApi.OllamaRestAPIFilter(authProvider);
31+
}
32+
33+
@Test
34+
void nullDoesNotSetAuthorization() {
35+
doReturn(null).when(authProvider).getAuthorization(any());
36+
37+
ollamaRestApiFilter.filter(context);
38+
39+
assertNull(headers.getFirst("Authorization"));
40+
}
41+
42+
@Test
43+
void valueDoesSetAuthorization() {
44+
var token = "token";
45+
doReturn(token).when(authProvider).getAuthorization(any());
46+
47+
ollamaRestApiFilter.filter(context);
48+
49+
assertEquals(token, headers.getFirst("Authorization"));
50+
}
51+
52+
}

model-providers/openai/openai-common/runtime/pom.xml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@
2626
<artifactId>quarkus-langchain4j-core</artifactId>
2727
<version>${project.version}</version>
2828
</dependency>
29+
<dependency>
30+
<groupId>io.quarkus</groupId>
31+
<artifactId>quarkus-junit5</artifactId>
32+
<scope>test</scope>
33+
</dependency>
34+
<dependency>
35+
<groupId>io.quarkus</groupId>
36+
<artifactId>quarkus-junit5-mockito</artifactId>
37+
<scope>test</scope>
38+
</dependency>
2939

3040
<dependency>
3141
<groupId>dev.langchain4j</groupId>

model-providers/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/common/OpenAiRestApi.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,11 @@ public OpenAIRestAPIFilter(ModelAuthProvider authorizer) {
192192

193193
@Override
194194
public void filter(ResteasyReactiveClientRequestContext requestContext) {
195-
requestContext
196-
.getHeaders()
197-
.putSingle("Authorization", authorizer.getAuthorization(new AuthInputImpl(requestContext.getMethod(),
198-
requestContext.getUri(), requestContext.getHeaders())));
195+
String authValue = authorizer.getAuthorization(new AuthInputImpl(requestContext.getMethod(),
196+
requestContext.getUri(), requestContext.getHeaders()));
197+
if (authValue != null) {
198+
requestContext.getHeaders().putSingle("Authorization", authValue);
199+
}
199200
}
200201

201202
private record AuthInputImpl(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package io.quarkiverse.langchain4j.openai.common;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertNull;
5+
import static org.mockito.Mockito.*;
6+
7+
import jakarta.ws.rs.core.MultivaluedHashMap;
8+
9+
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
10+
import org.junit.jupiter.api.BeforeEach;
11+
import org.junit.jupiter.api.Test;
12+
13+
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
14+
15+
public class OpenAIRestAPIFilterTest {
16+
17+
private ModelAuthProvider authProvider;
18+
private ResteasyReactiveClientRequestContext context;
19+
private MultivaluedHashMap<Object, Object> headers;
20+
private OpenAiRestApi.OpenAIRestAPIFilter ollamaRestApiFilter;
21+
22+
@BeforeEach
23+
void setUpFilter() {
24+
context = mock(ResteasyReactiveClientRequestContext.class);
25+
headers = new MultivaluedHashMap<>();
26+
doReturn(headers).when(context).getHeaders();
27+
28+
authProvider = mock(ModelAuthProvider.class);
29+
30+
ollamaRestApiFilter = new OpenAiRestApi.OpenAIRestAPIFilter(authProvider);
31+
}
32+
33+
@Test
34+
void nullDoesNotSetAuthorization() {
35+
doReturn(null).when(authProvider).getAuthorization(any());
36+
37+
ollamaRestApiFilter.filter(context);
38+
39+
assertNull(headers.getFirst("Authorization"));
40+
}
41+
42+
@Test
43+
void valueDoesSetAuthorization() {
44+
var token = "token";
45+
doReturn(token).when(authProvider).getAuthorization(any());
46+
47+
ollamaRestApiFilter.filter(context);
48+
49+
assertEquals(token, headers.getFirst("Authorization"));
50+
}
51+
52+
}

model-providers/vertex-ai-gemini/runtime/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
<scope>test</scope>
3636
</dependency>
3737
<dependency>
38-
<groupId>org.mockito</groupId>
39-
<artifactId>mockito-core</artifactId>
38+
<groupId>io.quarkus</groupId>
39+
<artifactId>quarkus-junit5-mockito</artifactId>
4040
<scope>test</scope>
4141
</dependency>
4242
<dependency>

model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,10 @@ public void filter(ResteasyReactiveClientRequestContext context) {
146146
public void run() {
147147
try {
148148
final Input authInput = new AuthInputImpl(context.getMethod(), context.getUri(), context.getHeaders());
149-
String authorization = authorizer != null ? authorizer.getAuthorization(authInput) : null;
150-
if (authorization == null) {
151-
authorization = defaultAuthorizer.getAuthorization(authInput);
149+
var auth = (authorizer != null ? authorizer : defaultAuthorizer).getAuthorization(authInput);
150+
if (auth != null) {
151+
context.getHeaders().add("Authorization", auth);
152152
}
153-
context.getHeaders().add("Authorization", authorization);
154153
context.resume();
155154
} catch (Exception e) {
156155
context.resume(e);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package io.quarkiverse.langchain4j.vertexai.runtime.gemini;
2+
3+
import static org.junit.jupiter.api.Assertions.*;
4+
import static org.mockito.ArgumentMatchers.any;
5+
import static org.mockito.Mockito.*;
6+
7+
import jakarta.enterprise.context.ApplicationScoped;
8+
import jakarta.ws.rs.core.MultivaluedHashMap;
9+
10+
import org.eclipse.microprofile.context.ManagedExecutor;
11+
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
12+
import org.junit.jupiter.api.BeforeEach;
13+
import org.junit.jupiter.api.Test;
14+
15+
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
16+
import io.quarkus.test.InjectMock;
17+
import io.quarkus.test.junit.QuarkusTest;
18+
19+
@QuarkusTest
20+
class VertxAiGeminiRestApiTokenFilterTest {
21+
22+
@InjectMock
23+
MyModelAuthProvider authProvider;
24+
25+
private ResteasyReactiveClientRequestContext context;
26+
private MultivaluedHashMap<Object, Object> headers;
27+
private VertxAiGeminiRestApi.TokenFilter ollamaRestApiFilter;
28+
29+
@BeforeEach
30+
void setUpFilter() {
31+
context = mock(ResteasyReactiveClientRequestContext.class);
32+
headers = new MultivaluedHashMap<>();
33+
doReturn(headers).when(context).getHeaders();
34+
35+
ollamaRestApiFilter = new VertxAiGeminiRestApi.TokenFilter(mockSyncExecutor());
36+
}
37+
38+
private static ManagedExecutor mockSyncExecutor() {
39+
ManagedExecutor executor = mock(ManagedExecutor.class);
40+
// execute the runnable immediately, to avoid any async issues
41+
doAnswer(invocation -> {
42+
((Runnable) invocation.getArgument(0)).run();
43+
return null;
44+
}).when(executor).submit(any(Runnable.class));
45+
return executor;
46+
}
47+
48+
@Test
49+
void nullDoesNotSetAuthorization() {
50+
doReturn(null).when(authProvider).getAuthorization(any());
51+
52+
ollamaRestApiFilter.filter(context);
53+
54+
assertNull(headers.getFirst("Authorization"));
55+
}
56+
57+
@Test
58+
void valueDoesSetAuthorization() {
59+
var token = "token";
60+
doReturn(token).when(authProvider).getAuthorization(any());
61+
62+
ollamaRestApiFilter.filter(context);
63+
64+
assertEquals(token, headers.getFirst("Authorization"));
65+
}
66+
67+
@ApplicationScoped
68+
static class MyModelAuthProvider implements ModelAuthProvider {
69+
70+
@Override
71+
public String getAuthorization(Input input) {
72+
fail("should never be called");
73+
return null;
74+
}
75+
}
76+
77+
}

0 commit comments

Comments
 (0)