Skip to content

Commit e318bf2

Browse files
authored
fix(nl2sql): 修复表不支持联合主键的问题 (#2206)
* fix(nl2sql): mul primary keys * fix(nl2sql): Run `spring-javaformat:apply` to fix. * fix(studio): spring-javaformat:apply
1 parent 7c519bb commit e318bf2

File tree

5 files changed

+63
-36
lines changed

5 files changed

+63
-36
lines changed

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreService.java

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.alibaba.cloud.ai.request.SearchRequest;
2727
import com.alibaba.cloud.ai.service.base.BaseVectorStoreService;
2828
import com.google.gson.Gson;
29+
import org.apache.commons.collections.CollectionUtils;
2930
import org.slf4j.Logger;
3031
import org.slf4j.LoggerFactory;
3132
import org.springframework.ai.document.Document;
@@ -94,6 +95,7 @@ protected EmbeddingModel getEmbeddingModel() {
9495
* @param schemaInitRequest schema initialization request
9596
* @throws Exception if an error occurs
9697
*/
98+
@Override
9799
public Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception {
98100
log.info("Starting schema initialization for database: {}, schema: {}, tables: {}",
99101
schemaInitRequest.getDbConfig().getUrl(), schemaInitRequest.getDbConfig().getSchema(),
@@ -173,12 +175,18 @@ private void processTable(TableInfoBO tableInfoBO, DbQueryParameter dqp, DbConfi
173175
columnInfoBO.setSamples(gson.toJson(sampleColumn));
174176
}
175177

176-
ColumnInfoBO primaryColumnDO = columnInfoBOS.stream()
178+
List<ColumnInfoBO> targetPrimaryList = columnInfoBOS.stream()
177179
.filter(ColumnInfoBO::isPrimary)
178-
.findFirst()
179-
.orElse(new ColumnInfoBO());
180-
181-
tableInfoBO.setPrimaryKey(primaryColumnDO.getName());
180+
.collect(Collectors.toList());
181+
if (CollectionUtils.isNotEmpty(targetPrimaryList)) {
182+
List<String> columnNames = targetPrimaryList.stream()
183+
.map(ColumnInfoBO::getName)
184+
.collect(Collectors.toList());
185+
tableInfoBO.setPrimaryKeys(columnNames);
186+
}
187+
else {
188+
tableInfoBO.setPrimaryKeys(new ArrayList<>());
189+
}
182190
tableInfoBO
183191
.setForeignKey(String.join("、", foreignKeyMap.getOrDefault(tableInfoBO.getName(), new ArrayList<>())));
184192
}
@@ -216,7 +224,7 @@ public Document convertTableToDocument(TableInfoBO tableInfoBO) {
216224
metadata.put("name", tableInfoBO.getName());
217225
metadata.put("description", Optional.ofNullable(tableInfoBO.getDescription()).orElse(""));
218226
metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""));
219-
metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKey()).orElse(""));
227+
metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()));
220228
metadata.put("vectorType", "table");
221229
Document document = new Document(tableInfoBO.getName(), text, metadata);
222230
log.debug("Created table document with ID: {}", tableInfoBO.getName());
@@ -467,7 +475,7 @@ private Document convertTableToDocumentForAgent(String agentId, TableInfoBO tabl
467475
metadata.put("name", tableInfoBO.getName());
468476
metadata.put("description", Optional.ofNullable(tableInfoBO.getDescription()).orElse(""));
469477
metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""));
470-
metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKey()).orElse(""));
478+
metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()));
471479
metadata.put("vectorType", "table");
472480

473481
Document document = new Document(id, text, metadata);

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreServiceTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import com.alibaba.cloud.ai.request.DeleteRequest;
2525
import com.alibaba.cloud.ai.request.SchemaInitRequest;
2626
import com.alibaba.cloud.ai.request.SearchRequest;
27+
import com.google.common.collect.Lists;
2728
import com.google.gson.Gson;
2829
import org.junit.jupiter.api.BeforeEach;
2930
import org.junit.jupiter.api.Disabled;
@@ -435,7 +436,7 @@ private TableInfoBO createMockTableInfo() {
435436
.name("test_table")
436437
.description("Test table")
437438
.schema("test_schema")
438-
.primaryKey("id")
439+
.primaryKeys(Lists.newArrayList("id"))
439440
.foreignKey("foreign_key_info")
440441
.build();
441442
}

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-common/src/main/java/com/alibaba/cloud/ai/connector/bo/TableInfoBO.java

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package com.alibaba.cloud.ai.connector.bo;
1717

18+
import java.util.List;
1819
import java.util.Objects;
1920

2021
public class TableInfoBO extends DdlBaseBO {
@@ -29,19 +30,19 @@ public class TableInfoBO extends DdlBaseBO {
2930

3031
private String foreignKey;
3132

32-
private String primaryKey;
33+
private List<String> primaryKeys;
3334

3435
public TableInfoBO() {
3536
}
3637

3738
public TableInfoBO(String schema, String name, String description, String type, String foreignKey,
38-
String primaryKey) {
39+
List<String> primaryKeys) {
3940
this.schema = schema;
4041
this.name = name;
4142
this.description = description;
4243
this.type = type;
4344
this.foreignKey = foreignKey;
44-
this.primaryKey = primaryKey;
45+
this.primaryKeys = primaryKeys;
4546
}
4647

4748
public String getSchema() {
@@ -84,36 +85,38 @@ public void setForeignKey(String foreignKey) {
8485
this.foreignKey = foreignKey;
8586
}
8687

87-
public String getPrimaryKey() {
88-
return primaryKey;
88+
public List<String> getPrimaryKeys() {
89+
return primaryKeys;
8990
}
9091

91-
public void setPrimaryKey(String primaryKey) {
92-
this.primaryKey = primaryKey;
92+
public void setPrimaryKeys(List<String> primaryKeys) {
93+
this.primaryKeys = primaryKeys;
9394
}
9495

9596
@Override
9697
public String toString() {
9798
return "TableInfoBO{" + "schema='" + schema + '\'' + ", name='" + name + '\'' + ", description='" + description
98-
+ '\'' + ", type='" + type + '\'' + ", foreignKey='" + foreignKey + '\'' + ", primaryKey='" + primaryKey
99-
+ '\'' + '}';
99+
+ '\'' + ", type='" + type + '\'' + ", foreignKey='" + foreignKey + '\'' + ", primaryKeys="
100+
+ primaryKeys + '}';
100101
}
101102

102103
@Override
103104
public boolean equals(Object o) {
104-
if (this == o)
105+
if (this == o) {
105106
return true;
106-
if (o == null || getClass() != o.getClass())
107+
}
108+
if (o == null || getClass() != o.getClass()) {
107109
return false;
110+
}
108111
TableInfoBO that = (TableInfoBO) o;
109112
return Objects.equals(schema, that.schema) && Objects.equals(name, that.name)
110113
&& Objects.equals(description, that.description) && Objects.equals(type, that.type)
111-
&& Objects.equals(foreignKey, that.foreignKey) && Objects.equals(primaryKey, that.primaryKey);
114+
&& Objects.equals(foreignKey, that.foreignKey) && Objects.equals(primaryKeys, that.primaryKeys);
112115
}
113116

114117
@Override
115118
public int hashCode() {
116-
return Objects.hash(schema, name, description, type, foreignKey, primaryKey);
119+
return Objects.hash(schema, name, description, type, foreignKey, primaryKeys);
117120
}
118121

119122
public static TableInfoBOBuilder builder() {
@@ -132,7 +135,7 @@ public static final class TableInfoBOBuilder {
132135

133136
private String foreignKey;
134137

135-
private String primaryKey;
138+
private List<String> primaryKeys;
136139

137140
private TableInfoBOBuilder() {
138141
}
@@ -166,8 +169,8 @@ public TableInfoBOBuilder foreignKey(String foreignKey) {
166169
return this;
167170
}
168171

169-
public TableInfoBOBuilder primaryKey(String primaryKey) {
170-
this.primaryKey = primaryKey;
172+
public TableInfoBOBuilder primaryKeys(List<String> primaryKeys) {
173+
this.primaryKeys = primaryKeys;
171174
return this;
172175
}
173176

@@ -178,7 +181,7 @@ public TableInfoBO build() {
178181
tableInfoBO.setDescription(description);
179182
tableInfoBO.setType(type);
180183
tableInfoBO.setForeignKey(foreignKey);
181-
tableInfoBO.setPrimaryKey(primaryKey);
184+
tableInfoBO.setPrimaryKeys(primaryKeys);
182185
return tableInfoBO;
183186
}
184187

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/AnalyticDbVectorStoreManagementService.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import com.fasterxml.jackson.core.type.TypeReference;
3737
import com.fasterxml.jackson.databind.ObjectMapper;
3838
import com.google.gson.Gson;
39+
import org.apache.commons.collections.CollectionUtils;
3940
import org.springframework.ai.document.Document;
4041
import org.springframework.ai.embedding.EmbeddingModel;
4142
import org.springframework.ai.vectorstore.VectorStore;
@@ -271,12 +272,19 @@ private void processTable(TableInfoBO tableInfoBO, DbQueryParameter dqp, DbConfi
271272
columnInfoBO.setSamples(gson.toJson(sampleColumn));
272273
}
273274

274-
ColumnInfoBO primaryColumnDO = columnInfoBOS.stream()
275+
List<ColumnInfoBO> targetPrimaryList = columnInfoBOS.stream()
275276
.filter(ColumnInfoBO::isPrimary)
276-
.findFirst()
277-
.orElse(new ColumnInfoBO());
277+
.collect(Collectors.toList());
278+
if (CollectionUtils.isNotEmpty(targetPrimaryList)) {
279+
List<String> columnNames = targetPrimaryList.stream()
280+
.map(ColumnInfoBO::getName)
281+
.collect(Collectors.toList());
282+
tableInfoBO.setPrimaryKeys(columnNames);
283+
}
284+
else {
285+
tableInfoBO.setPrimaryKeys(new ArrayList<>());
286+
}
278287

279-
tableInfoBO.setPrimaryKey(primaryColumnDO.getName());
280288
tableInfoBO.setForeignKey(String.join("、", buildForeignKeyList(tableInfoBO.getName())));
281289
}
282290

@@ -313,7 +321,7 @@ public Document convertTableToDocument(TableInfoBO tableInfoBO) {
313321
Map<String, Object> metadata = Map.of("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse(""), "name",
314322
tableInfoBO.getName(), "description", Optional.ofNullable(tableInfoBO.getDescription()).orElse(""),
315323
"foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""), "primaryKey",
316-
Optional.ofNullable(tableInfoBO.getPrimaryKey()).orElse(""), "vectorType", "table");
324+
Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()), "vectorType", "table");
317325
return new Document(tableInfoBO.getName(), text, metadata);
318326
}
319327

spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/SimpleVectorStoreManagementService.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.alibaba.cloud.ai.request.EvidenceRequest;
2929
import com.alibaba.cloud.ai.request.SchemaInitRequest;
3030
import com.google.gson.Gson;
31+
import org.apache.commons.collections.CollectionUtils;
3132
import org.springframework.ai.document.Document;
3233
import org.springframework.ai.document.MetadataMode;
3334
import org.springframework.ai.vectorstore.SearchRequest;
@@ -150,12 +151,18 @@ private void processTable(TableInfoBO tableInfoBO, DbQueryParameter dqp, DbConfi
150151
columnInfoBO.setSamples(gson.toJson(sampleColumn));
151152
}
152153

153-
ColumnInfoBO primaryColumnDO = columnInfoBOS.stream()
154+
List<ColumnInfoBO> targetPrimaryList = columnInfoBOS.stream()
154155
.filter(ColumnInfoBO::isPrimary)
155-
.findFirst()
156-
.orElse(new ColumnInfoBO());
157-
158-
tableInfoBO.setPrimaryKey(primaryColumnDO.getName());
156+
.collect(Collectors.toList());
157+
if (CollectionUtils.isNotEmpty(targetPrimaryList)) {
158+
List<String> columnNames = targetPrimaryList.stream()
159+
.map(ColumnInfoBO::getName)
160+
.collect(Collectors.toList());
161+
tableInfoBO.setPrimaryKeys(columnNames);
162+
}
163+
else {
164+
tableInfoBO.setPrimaryKeys(new ArrayList<>());
165+
}
159166
tableInfoBO.setForeignKey(String.join("、", buildForeignKeyList(tableInfoBO.getName())));
160167
}
161168

@@ -180,7 +187,7 @@ public Document convertTableToDocument(TableInfoBO tableInfoBO) {
180187
Map<String, Object> metadata = Map.of("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse(""), "name",
181188
tableInfoBO.getName(), "description", Optional.ofNullable(tableInfoBO.getDescription()).orElse(""),
182189
"foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""), "primaryKey",
183-
Optional.ofNullable(tableInfoBO.getPrimaryKey()).orElse(""), "vectorType", "table");
190+
Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()), "vectorType", "table");
184191
return new Document(tableInfoBO.getName(), text, metadata);
185192
}
186193

0 commit comments

Comments
 (0)