Skip to content

Commit 39cfb20

Browse files
committed
add selectedTools as an attribute of the session queryConfiguration
1 parent 651389b commit 39cfb20

File tree

5 files changed

+18
-7
lines changed

5 files changed

+18
-7
lines changed

backend/src/main/java/com/cloudera/cai/rag/Types.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,12 @@ public record RagDataSource(
101101
@Nullable Long totalDocSize,
102102
boolean availableForDefaultProject) {}
103103

104+
@With
104105
public record QueryConfiguration(
105-
boolean enableHyde, boolean enableSummaryFilter, boolean enableToolCalling) {}
106+
boolean enableHyde,
107+
boolean enableSummaryFilter,
108+
boolean enableToolCalling,
109+
List<String> selectedTools) {}
106110

107111
@With
108112
@Builder

backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
@Component
5858
public class SessionRepository {
5959
public static final Types.QueryConfiguration DEFAULT_QUERY_CONFIGURATION =
60-
new Types.QueryConfiguration(false, true, false);
60+
new Types.QueryConfiguration(false, true, false, List.of());
6161
private final Jdbi jdbi;
6262
private final ObjectMapper objectMapper = new ObjectMapper();
6363

@@ -169,6 +169,9 @@ private Types.QueryConfiguration extractQueryConfiguration(RowView rowView)
169169
if (queryConfiguration == null) {
170170
return DEFAULT_QUERY_CONFIGURATION;
171171
}
172+
if (queryConfiguration.selectedTools() == null) {
173+
queryConfiguration = queryConfiguration.withSelectedTools(List.of());
174+
}
172175
return queryConfiguration;
173176
}
174177

backend/src/test/java/com/cloudera/cai/rag/TestData.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public static Types.Session createTestSessionInstance(
8282
"test-model",
8383
"test-rerank-model",
8484
3,
85-
new Types.QueryConfiguration(false, true, true));
85+
new Types.QueryConfiguration(false, true, true, List.of()));
8686
}
8787

8888
public static Types.CreateSession createSessionInstance(String sessionName) {
@@ -97,7 +97,7 @@ public static Types.CreateSession createSessionInstance(
9797
"test-model",
9898
"test-rerank-model",
9999
3,
100-
new Types.QueryConfiguration(false, true, true),
100+
new Types.QueryConfiguration(false, true, true, List.of()),
101101
projectId);
102102
}
103103

backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ void update() {
145145
.withRerankModel(updatedRerankModel)
146146
.withName(updatedName)
147147
.withProjectId(updatedProjectId)
148-
.withQueryConfiguration(new Types.QueryConfiguration(true, false, true)),
148+
.withQueryConfiguration(
149+
new Types.QueryConfiguration(true, false, true, List.of("foo"))),
149150
request);
150151

151152
assertThat(updatedSession.id()).isNotNull();
@@ -160,7 +161,7 @@ void update() {
160161
assertThat(updatedSession.createdById()).isEqualTo("test-user");
161162
assertThat(updatedSession.lastInteractionTime()).isNull();
162163
assertThat(updatedSession.queryConfiguration())
163-
.isEqualTo(new Types.QueryConfiguration(true, false, true));
164+
.isEqualTo(new Types.QueryConfiguration(true, false, true, List.of("foo")));
164165
}
165166

166167
@Test

llm-service/app/services/metadata_apis/session_metadata_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
# DATA.
3737
#
3838
import json
39-
from dataclasses import dataclass
39+
from dataclasses import dataclass, field
4040
from datetime import datetime
4141
from typing import List, Any, Optional
4242

@@ -51,6 +51,7 @@ class SessionQueryConfiguration:
5151
enable_hyde: bool
5252
enable_summary_filter: bool
5353
enable_tool_calling: bool = False
54+
selected_tools: list[str] = field(default_factory=list)
5455

5556

5657
@dataclass
@@ -116,6 +117,7 @@ def session_from_java_response(data: dict[str, Any]) -> Session:
116117
enable_tool_calling=data["queryConfiguration"].get(
117118
"enableToolCalling", False
118119
),
120+
selected_tools=data["queryConfiguration"]["selectedTools"] or []
119121
),
120122
)
121123

@@ -133,6 +135,7 @@ def update_session(session: Session, user_name: Optional[str]) -> Session:
133135
"enableHyde": session.query_configuration.enable_hyde,
134136
"enableSummaryFilter": session.query_configuration.enable_summary_filter,
135137
"enableToolCalling": session.query_configuration.enable_tool_calling,
138+
"selectedTools": session.query_configuration.selected_tools
136139
},
137140
)
138141
headers = {

0 commit comments

Comments
 (0)