Skip to content

Commit db01040

Browse files
MiggiV2jmartisk
andauthored
[FEATURE] ToolProvider | Select tools dynamically on incoming message (#989)
* ✨ (ToolProvider): Make ToolProvider Quarkus ready * ♻️ (DevUi): Feedback from jmartisk (see below) - Support StreamingChat - Inject Supplier<ToolProvider> - Ignore tools when toolProvider exists * ✏️ (DevUi): Typo in setToolsViaProviderIfAvailable * 📝 (ToolProvider): Update agent-and-tools.adoc * Update docs/modules/ROOT/pages/agent-and-tools.adoc Co-authored-by: Jan Martiska <jmartisk@redhat.com> --------- Co-authored-by: Jan Martiska <jmartisk@redhat.com>
1 parent 4851856 commit db01040

File tree

14 files changed

+455
-30
lines changed

14 files changed

+455
-30
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@
6262
import dev.langchain4j.service.Moderate;
6363
import dev.langchain4j.service.output.ServiceOutputParser;
6464
import io.quarkiverse.langchain4j.ModelName;
65+
import io.quarkiverse.langchain4j.RegisterAiService;
6566
import io.quarkiverse.langchain4j.ToolBox;
6667
import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
68+
import io.quarkiverse.langchain4j.deployment.devui.ToolProviderInfo;
6769
import io.quarkiverse.langchain4j.deployment.items.MethodParameterAllowedAnnotationsBuildItem;
6870
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
6971
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
@@ -208,13 +210,15 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
208210
BuildProducer<RequestModerationModelBeanBuildItem> requestModerationModelBeanProducer,
209211
BuildProducer<RequestImageModelBeanBuildItem> requestImageModelBeanProducer,
210212
BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer,
213+
BuildProducer<ToolProviderMetaBuildItem> toolProviderProducer,
211214
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
212215
BuildProducer<GeneratedClassBuildItem> generatedClassProducer) {
213216
IndexView index = indexBuildItem.getIndex();
214217

215218
Set<String> chatModelNames = new HashSet<>();
216219
Set<String> moderationModelNames = new HashSet<>();
217220
Set<String> imageModelNames = new HashSet<>();
221+
List<ToolProviderInfo> toolProviderInfos = new ArrayList<>();
218222
ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
219223
for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) {
220224
if (instance.target().kind() != AnnotationTarget.Kind.CLASS) {
@@ -323,6 +327,15 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
323327
validateSupplierAndRegisterForReflection(moderationModelSupplierClassName, index, reflectiveClassProducer);
324328
}
325329

330+
DotName toolProviderClassName = LangChain4jDotNames.BEAN_IF_EXISTS_TOOL_PROVIDER_SUPPLIER;
331+
AnnotationValue toolProviderValue = instance.value("toolProviderSupplier");
332+
if (toolProviderValue != null) {
333+
toolProviderClassName = toolProviderValue.asClass().name();
334+
validateSupplierAndRegisterForReflection(toolProviderClassName, index, reflectiveClassProducer);
335+
toolProviderInfos.add(new ToolProviderInfo(toolProviderClassName.toString(),
336+
declarativeAiServiceClassInfo.simpleName()));
337+
}
338+
326339
DotName imageModelSupplierClassName = LangChain4jDotNames.BEAN_IF_EXISTS_IMAGE_MODEL_SUPPLIER;
327340
AnnotationValue imageModelSupplierValue = instance.value("imageModelSupplier");
328341
if (imageModelSupplierValue != null) {
@@ -381,8 +394,10 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
381394
cdiScope,
382395
chatModelName,
383396
moderationModelName,
384-
imageModelName));
397+
imageModelName,
398+
toolProviderClassName));
385399
}
400+
toolProviderProducer.produce(new ToolProviderMetaBuildItem(toolProviderInfos));
386401

387402
for (String chatModelName : chatModelNames) {
388403
requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem(chatModelName));
@@ -462,6 +477,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
462477
boolean needsModerationModelBean = false;
463478
boolean needsImageModelBean = false;
464479
Set<DotName> allToolNames = new HashSet<>();
480+
Set<DotName> allToolProviders = new HashSet<>();
465481

466482
for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) {
467483
ClassInfo declarativeAiServiceClassInfo = bi.getServiceClassInfo();
@@ -477,6 +493,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
477493

478494
List<String> toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList());
479495

496+
String toolProviderSupplierClassName = (bi.getToolProviderClassDotName() != null
497+
? bi.getToolProviderClassDotName().toString()
498+
: null);
499+
480500
String chatMemoryProviderSupplierClassName = bi.getChatMemoryProviderSupplierClassDotName() != null
481501
? bi.getChatMemoryProviderSupplierClassDotName().toString()
482502
: null;
@@ -556,7 +576,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
556576
serviceClassName,
557577
chatLanguageModelSupplierClassName,
558578
streamingChatLanguageModelSupplierClassName,
559-
toolClassNames, chatMemoryProviderSupplierClassName, retrieverClassName,
579+
toolClassNames,
580+
toolProviderSupplierClassName,
581+
chatMemoryProviderSupplierClassName, retrieverClassName,
560582
retrievalAugmentorSupplierClassName,
561583
auditServiceClassSupplierName,
562584
moderationModelSupplierClassName,
@@ -668,6 +690,13 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
668690
needsImageModelBean = true;
669691
}
670692

693+
if (!RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName()
694+
.equals(toolProviderSupplierClassName) && toolProviderSupplierClassName != null) {
695+
DotName toolProvider = DotName.createSimple(toolProviderSupplierClassName);
696+
configurator.addInjectionPoint(ClassType.create(toolProvider));
697+
allToolProviders.add(toolProvider);
698+
}
699+
671700
configurator
672701
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
673702
new Type[] { ClassType.create(OutputGuardrail.class) }, null))
@@ -700,6 +729,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
700729
if (needsImageModelBean) {
701730
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.IMAGE_MODEL));
702731
}
732+
if (!allToolProviders.isEmpty()) {
733+
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolProviders));
734+
}
703735
if (!allToolNames.isEmpty()) {
704736
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
705737
}
@@ -795,7 +827,7 @@ public void handleAiServices(
795827
for (ClassInfo classInfo : index.getKnownUsers(LangChain4jDotNames.AI_SERVICES)) {
796828
String className = classInfo.name().toString();
797829
if (className.startsWith("io.quarkiverse.langchain4j") || className.startsWith("dev.langchain4j")) { // TODO: this can be made smarter if
798-
// needed
830+
// needed
799831
continue;
800832
}
801833
try (InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
1616
private final DotName chatLanguageModelSupplierClassDotName;
1717
private final DotName streamingChatLanguageModelSupplierClassDotName;
1818
private final List<DotName> toolDotNames;
19+
private final DotName toolProviderClassDotName;
1920

2021
private final DotName chatMemoryProviderSupplierClassDotName;
2122
private final DotName retrieverClassDotName;
@@ -46,7 +47,8 @@ public DeclarativeAiServiceBuildItem(
4647
DotName cdiScope,
4748
String chatModelName,
4849
String moderationModelName,
49-
String imageModelName) {
50+
String imageModelName,
51+
DotName toolProviderClassDotName) {
5052
this.serviceClassInfo = serviceClassInfo;
5153
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
5254
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
@@ -63,6 +65,7 @@ public DeclarativeAiServiceBuildItem(
6365
this.chatModelName = chatModelName;
6466
this.moderationModelName = moderationModelName;
6567
this.imageModelName = imageModelName;
68+
this.toolProviderClassDotName = toolProviderClassDotName;
6669
}
6770

6871
public ClassInfo getServiceClassInfo() {
@@ -128,4 +131,8 @@ public String getModerationModelName() {
128131
public String getImageModelName() {
129132
return imageModelName;
130133
}
134+
135+
public DotName getToolProviderClassDotName() {
136+
return toolProviderClassDotName;
137+
}
131138
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ public class LangChain4jDotNames {
9999
static final DotName BEAN_IF_EXISTS_IMAGE_MODEL_SUPPLIER = DotName.createSimple(
100100
RegisterAiService.BeanIfExistsImageModelSupplier.class);
101101

102+
static final DotName BEAN_IF_EXISTS_TOOL_PROVIDER_SUPPLIER = DotName.createSimple(
103+
RegisterAiService.BeanIfExistsToolProviderSupplier.class);
104+
102105
static final DotName QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER = DotName.createSimple(
103106
QuarkusAiServiceContextQualifier.class);
104107

@@ -108,4 +111,5 @@ public class LangChain4jDotNames {
108111
static final DotName WEB_SEARCH_ENGINE = DotName.createSimple(WebSearchEngine.class);
109112
static final DotName IMAGE = DotName.createSimple(Image.class);
110113
static final DotName RESULT = DotName.createSimple(Result.class);
114+
static final DotName TOOL_PROVIDER = DotName.createSimple(ToolProcessor.class);
111115
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package io.quarkiverse.langchain4j.deployment;
2+
3+
import java.util.List;
4+
5+
import io.quarkiverse.langchain4j.deployment.devui.ToolProviderInfo;
6+
import io.quarkus.builder.item.SimpleBuildItem;
7+
8+
/**
9+
* Holds metadata about toolProviders discovered at build time
10+
*/
11+
public final class ToolProviderMetaBuildItem extends SimpleBuildItem {
12+
List<ToolProviderInfo> metadata;
13+
14+
public ToolProviderMetaBuildItem(List<ToolProviderInfo> metaData) {
15+
this.metadata = metaData;
16+
}
17+
18+
public List<ToolProviderInfo> getMetadata() {
19+
return metadata;
20+
}
21+
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import io.quarkiverse.langchain4j.deployment.DeclarativeAiServiceBuildItem;
99
import io.quarkiverse.langchain4j.deployment.EmbeddingStoreBuildItem;
1010
import io.quarkiverse.langchain4j.deployment.LangChain4jDotNames;
11+
import io.quarkiverse.langchain4j.deployment.ToolProviderMetaBuildItem;
1112
import io.quarkiverse.langchain4j.deployment.ToolsMetadataBuildItem;
1213
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
1314
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
@@ -28,6 +29,7 @@ public class LangChain4jDevUIProcessor {
2829

2930
@BuildStep(onlyIf = IsDevelopment.class)
3031
CardPageBuildItem cardPage(List<DeclarativeAiServiceBuildItem> aiServices,
32+
ToolProviderMetaBuildItem toolProviderMetaBuildItem,
3133
ToolsMetadataBuildItem toolsMetadataBuildItem,
3234
List<EmbeddingModelProviderCandidateBuildItem> embeddingModelCandidateBuildItems,
3335
List<InProcessEmbeddingBuildItem> inProcessEmbeddingModelBuildItems,
@@ -60,6 +62,10 @@ CardPageBuildItem cardPage(List<DeclarativeAiServiceBuildItem> aiServices,
6062

6163
additionalDevUiCardBuildItem.getBuildTimeData().forEach((k, v) -> card.addBuildTimeData(k, v));
6264
}
65+
66+
List<ToolProviderInfo> toolProviderInfos = toolProviderMetaBuildItem.getMetadata();
67+
card.addBuildTimeData("toolProviders", toolProviderInfos);
68+
6369
return card;
6470
}
6571

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package io.quarkiverse.langchain4j.deployment.devui;
2+
3+
public class ToolProviderInfo {
4+
private String className;
5+
private String aiServiceName;
6+
7+
public ToolProviderInfo(String className, String aiServiceName) {
8+
this.className = className;
9+
this.aiServiceName = aiServiceName;
10+
}
11+
12+
public String getClassName() {
13+
return className;
14+
}
15+
16+
public void setClassName(String className) {
17+
this.className = className;
18+
}
19+
20+
public String getAiServiceName() {
21+
return aiServiceName;
22+
}
23+
24+
public void setAiServiceName(String aiServiceName) {
25+
this.aiServiceName = aiServiceName;
26+
}
27+
}
Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import { LitElement, html, css} from 'lit';
1+
import {css, html, LitElement} from 'lit';
22
import '@vaadin/grid';
33
import '@vaadin/grid/vaadin-grid-sort-column.js';
44

5-
import {tools} from 'build-time-data';
5+
import {toolProviders, tools} from 'build-time-data';
66

77

88
export class QwcTools extends LitElement {
@@ -12,6 +12,7 @@ export class QwcTools extends LitElement {
1212
height: 100%;
1313
display: flex;
1414
}
15+
1516
vaadin-grid {
1617
margin-left: 15px;
1718
margin-right: 15px;
@@ -21,38 +22,57 @@ export class QwcTools extends LitElement {
2122

2223
static properties = {
2324
"_tools": {state: true},
25+
"_toolProviders": {state: true},
2426
}
2527

2628
constructor() {
2729
super();
2830
this._tools = tools;
31+
this._toolProviders = toolProviders;
2932
}
3033

3134
render() {
32-
if (this._tools) {
35+
if (this._toolProviders.length > 0) {
36+
return this._renderToolProvider();
37+
} else if (this._tools) {
3338
return this._renderToolTable();
3439
} else {
3540
return html`<span>No tools found</span>`;
3641
}
3742
}
3843

44+
_renderToolProvider() {
45+
return html`
46+
<vaadin-grid .items="${this._toolProviders}" theme="no-border">
47+
<vaadin-grid-sort-column auto-width
48+
path="className"
49+
header="Class name">
50+
</vaadin-grid-sort-column>
51+
<vaadin-grid-column auto-width
52+
path="aiServiceName"
53+
header="AiService">
54+
</vaadin-grid-column>
55+
</vaadin-grid>`;
56+
}
57+
3958
_renderToolTable() {
4059
return html`
41-
<vaadin-grid .items="${this._tools}" theme="no-border">
42-
<vaadin-grid-sort-column auto-width
43-
path="className"
44-
header="Class name">
45-
</vaadin-grid-sort-column>
46-
<vaadin-grid-column auto-width
47-
path="name"
48-
header="Tool name">
49-
</vaadin-grid-column>
50-
<vaadin-grid-column auto-width
51-
path="description"
52-
header="Description">
53-
</vaadin-grid-column>
54-
</vaadin-grid>`;
60+
<vaadin-grid .items="${this._tools}" theme="no-border">
61+
<vaadin-grid-sort-column auto-width
62+
path="className"
63+
header="Class name">
64+
</vaadin-grid-sort-column>
65+
<vaadin-grid-column auto-width
66+
path="name"
67+
header="Tool name">
68+
</vaadin-grid-column>
69+
<vaadin-grid-column auto-width
70+
path="description"
71+
header="Description">
72+
</vaadin-grid-column>
73+
</vaadin-grid>`;
5574
}
5675

5776
}
77+
5878
customElements.define('qwc-tools', QwcTools);

0 commit comments

Comments
 (0)