Skip to content

Commit 31600d7

Browse files
authored
Merge pull request #1003 from holly-cummins/guardrail-api
Move inner classes to top level classes for API concision
2 parents 12df5b3 + f618093 commit 31600d7

File tree

12 files changed

+55
-39
lines changed

12 files changed

+55
-39
lines changed

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333
import dev.langchain4j.service.UserMessage;
3434
import io.quarkiverse.langchain4j.RegisterAiService;
3535
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
36+
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
3637
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
3738
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
3839
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
40+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
3941
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
4042
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
4143
import io.quarkus.test.QuarkusUnitTest;

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndOutputGuardrailsTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
import dev.langchain4j.service.UserMessage;
2828
import io.quarkiverse.langchain4j.RegisterAiService;
2929
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
30+
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
3031
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
3132
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
3233
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
34+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
3335
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
3436
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
3537
import io.quarkus.test.QuarkusUnitTest;

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import dev.langchain4j.service.UserMessage;
3131
import io.quarkiverse.langchain4j.RegisterAiService;
3232
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
33+
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
3334
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
3435
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
3536
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
@@ -179,7 +180,7 @@ public static class MemoryCheck implements InputGuardrail {
179180
AtomicInteger spy = new AtomicInteger(0);
180181

181182
@Override
182-
public InputGuardrailResult validate(InputGuardrail.InputGuardrailParams params) {
183+
public InputGuardrailResult validate(InputGuardrailParams params) {
183184
spy.incrementAndGet();
184185
if (params.memory().messages().isEmpty()) {
185186
assertThat(params.userMessage().singleText()).isEqualTo("foo");

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import dev.langchain4j.service.SystemMessage;
3030
import io.quarkiverse.langchain4j.RegisterAiService;
3131
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
32+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
3233
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
3334
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
3435
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import dev.langchain4j.service.SystemMessage;
3131
import io.quarkiverse.langchain4j.RegisterAiService;
3232
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
33+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
3334
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
3435
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
3536
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import java.util.Arrays;
44

55
import dev.langchain4j.data.message.UserMessage;
6-
import dev.langchain4j.memory.ChatMemory;
7-
import dev.langchain4j.rag.AugmentationResult;
86
import io.smallrye.common.annotation.Experimental;
97

108
/**
@@ -14,7 +12,7 @@
1412
* Implementation should be exposed as a CDI bean, and the class name configured in {@link InputGuardrails#value()} annotation.
1513
*/
1614
@Experimental("This feature is experimental and the API is subject to change")
17-
public interface InputGuardrail extends Guardrail<InputGuardrail.InputGuardrailParams, InputGuardrailResult> {
15+
public interface InputGuardrail extends Guardrail<InputGuardrailParams, InputGuardrailResult> {
1816

1917
/**
2018
* Validates the {@code user message} that will be sent to the LLM.
@@ -42,17 +40,6 @@ default InputGuardrailResult validate(InputGuardrailParams params) {
4240
return validate(params.userMessage());
4341
}
4442

45-
/**
46-
* Represents the parameter passed to {@link #validate(InputGuardrailParams)}.
47-
*
48-
* @param userMessage the user message, cannot be {@code null}
49-
* @param memory the memory, can be {@code null} or empty
50-
* @param augmentationResult the augmentation result, can be {@code null}
51-
*/
52-
record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
53-
AugmentationResult augmentationResult) implements GuardrailParams {
54-
}
55-
5643
/**
5744
* @return The result of a successful input guardrail validation.
5845
*/
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.guardrails;
2+
3+
import dev.langchain4j.data.message.UserMessage;
4+
import dev.langchain4j.memory.ChatMemory;
5+
import dev.langchain4j.rag.AugmentationResult;
6+
7+
/**
8+
* Represents the parameter passed to {@link InputGuardrail#validate(InputGuardrailParams)}.
9+
*
10+
* @param userMessage the user message, cannot be {@code null}
11+
* @param memory the memory, can be {@code null} or empty
12+
* @param augmentationResult the augmentation result, can be {@code null}
13+
*/
14+
public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
15+
AugmentationResult augmentationResult) implements GuardrailParams {
16+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import java.util.Arrays;
44

55
import dev.langchain4j.data.message.AiMessage;
6-
import dev.langchain4j.memory.ChatMemory;
7-
import dev.langchain4j.rag.AugmentationResult;
86
import io.smallrye.common.annotation.Experimental;
97

108
/**
@@ -18,7 +16,7 @@
1816
* The maximum number of retries is configurable using {@code quarkus.langchain4j.guardrails.max-retries}, defaulting to 3.
1917
*/
2018
@Experimental("This feature is experimental and the API is subject to change")
21-
public interface OutputGuardrail extends Guardrail<OutputGuardrail.OutputGuardrailParams, OutputGuardrailResult> {
19+
public interface OutputGuardrail extends Guardrail<OutputGuardrailParams, OutputGuardrailResult> {
2220

2321
/**
2422
* Validates the response from the LLM.
@@ -45,17 +43,6 @@ default OutputGuardrailResult validate(OutputGuardrailParams params) {
4543
return validate(params.responseFromLLM());
4644
}
4745

48-
/**
49-
* Represents the parameter passed to {@link #validate(OutputGuardrailParams)}.
50-
*
51-
* @param responseFromLLM the response from the LLM
52-
* @param memory the memory, can be {@code null} or empty
53-
* @param augmentationResult the augmentation result, can be {@code null}
54-
*/
55-
record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
56-
AugmentationResult augmentationResult) implements GuardrailParams {
57-
}
58-
5946
/**
6047
* @return The result of a successful output guardrail validation.
6148
*/
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.guardrails;
2+
3+
import dev.langchain4j.data.message.AiMessage;
4+
import dev.langchain4j.memory.ChatMemory;
5+
import dev.langchain4j.rag.AugmentationResult;
6+
7+
/**
8+
* Represents the parameter passed to {@link OutputGuardrail#validate(OutputGuardrailParams)}.
9+
*
10+
* @param responseFromLLM the response from the LLM
11+
* @param memory the memory, can be {@code null} or empty
12+
* @param augmentationResult the augmentation result, can be {@code null}
13+
*/
14+
public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
15+
AugmentationResult augmentationResult) implements GuardrailParams {
16+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import io.quarkiverse.langchain4j.audit.Audit;
6363
import io.quarkiverse.langchain4j.audit.AuditService;
6464
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
65+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
6566
import io.quarkiverse.langchain4j.runtime.ContextLocals;
6667
import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser;
6768
import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil;
@@ -305,7 +306,7 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
305306

306307
response = GuardrailsSupport.invokeOutputGuardrails(methodCreateInfo, chatMemory, context.chatModel, response,
307308
toolSpecifications,
308-
new OutputGuardrail.OutputGuardrailParams(response.content(), chatMemory, augmentationResult));
309+
new OutputGuardrailParams(response.content(), chatMemory, augmentationResult));
309310

310311
// everything worked as expected so let's commit the messages
311312
chatMemory.commit();

0 commit comments

Comments
 (0)