diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreService.java index e94c8eee3c..a4603d7cf2 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreService.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/main/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreService.java @@ -26,6 +26,7 @@ import com.alibaba.cloud.ai.request.SearchRequest; import com.alibaba.cloud.ai.service.base.BaseVectorStoreService; import com.google.gson.Gson; +import org.apache.commons.collections.CollectionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; @@ -94,6 +95,7 @@ protected EmbeddingModel getEmbeddingModel() { * @param schemaInitRequest schema initialization request * @throws Exception if an error occurs */ + @Override public Boolean schema(SchemaInitRequest schemaInitRequest) throws Exception { log.info("Starting schema initialization for database: {}, schema: {}, tables: {}", schemaInitRequest.getDbConfig().getUrl(), schemaInitRequest.getDbConfig().getSchema(), @@ -173,12 +175,18 @@ private void processTable(TableInfoBO tableInfoBO, DbQueryParameter dqp, DbConfi columnInfoBO.setSamples(gson.toJson(sampleColumn)); } - ColumnInfoBO primaryColumnDO = columnInfoBOS.stream() + List targetPrimaryList = columnInfoBOS.stream() .filter(ColumnInfoBO::isPrimary) - .findFirst() - .orElse(new ColumnInfoBO()); - - tableInfoBO.setPrimaryKey(primaryColumnDO.getName()); + .collect(Collectors.toList()); + if (CollectionUtils.isNotEmpty(targetPrimaryList)) { + List columnNames = targetPrimaryList.stream() + .map(ColumnInfoBO::getName) + .collect(Collectors.toList()); + tableInfoBO.setPrimaryKeys(columnNames); + } + else { + tableInfoBO.setPrimaryKeys(new ArrayList<>()); + } tableInfoBO .setForeignKey(String.join("、", foreignKeyMap.getOrDefault(tableInfoBO.getName(), new ArrayList<>()))); } @@ -216,7 +224,7 @@ public Document convertTableToDocument(TableInfoBO tableInfoBO) { metadata.put("name", tableInfoBO.getName()); metadata.put("description", Optional.ofNullable(tableInfoBO.getDescription()).orElse("")); metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse("")); - metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKey()).orElse("")); + metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>())); metadata.put("vectorType", "table"); Document document = new Document(tableInfoBO.getName(), text, metadata); log.debug("Created table document with ID: {}", tableInfoBO.getName()); @@ -467,7 +475,7 @@ private Document convertTableToDocumentForAgent(String agentId, TableInfoBO tabl metadata.put("name", tableInfoBO.getName()); metadata.put("description", Optional.ofNullable(tableInfoBO.getDescription()).orElse("")); metadata.put("foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse("")); - metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKey()).orElse("")); + metadata.put("primaryKey", Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>())); metadata.put("vectorType", "table"); Document document = new Document(id, text, metadata); diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreServiceTest.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreServiceTest.java index f53c3df37c..803daaf031 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreServiceTest.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-chat/src/test/java/com/alibaba/cloud/ai/service/simple/SimpleVectorStoreServiceTest.java @@ -24,6 +24,7 @@ import com.alibaba.cloud.ai.request.DeleteRequest; import com.alibaba.cloud.ai.request.SchemaInitRequest; import com.alibaba.cloud.ai.request.SearchRequest; +import com.google.common.collect.Lists; import com.google.gson.Gson; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -435,7 +436,7 @@ private TableInfoBO createMockTableInfo() { .name("test_table") .description("Test table") .schema("test_schema") - .primaryKey("id") + .primaryKeys(Lists.newArrayList("id")) .foreignKey("foreign_key_info") .build(); } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-common/src/main/java/com/alibaba/cloud/ai/connector/bo/TableInfoBO.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-common/src/main/java/com/alibaba/cloud/ai/connector/bo/TableInfoBO.java index b9ccb4214a..66d365bd78 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-common/src/main/java/com/alibaba/cloud/ai/connector/bo/TableInfoBO.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-common/src/main/java/com/alibaba/cloud/ai/connector/bo/TableInfoBO.java @@ -15,6 +15,7 @@ */ package com.alibaba.cloud.ai.connector.bo; +import java.util.List; import java.util.Objects; public class TableInfoBO extends DdlBaseBO { @@ -29,19 +30,19 @@ public class TableInfoBO extends DdlBaseBO { private String foreignKey; - private String primaryKey; + private List primaryKeys; public TableInfoBO() { } public TableInfoBO(String schema, String name, String description, String type, String foreignKey, - String primaryKey) { + List primaryKeys) { this.schema = schema; this.name = name; this.description = description; this.type = type; this.foreignKey = foreignKey; - this.primaryKey = primaryKey; + this.primaryKeys = primaryKeys; } public String getSchema() { @@ -84,36 +85,38 @@ public void setForeignKey(String foreignKey) { this.foreignKey = foreignKey; } - public String getPrimaryKey() { - return primaryKey; + public List getPrimaryKeys() { + return primaryKeys; } - public void setPrimaryKey(String primaryKey) { - this.primaryKey = primaryKey; + public void setPrimaryKeys(List primaryKeys) { + this.primaryKeys = primaryKeys; } @Override public String toString() { return "TableInfoBO{" + "schema='" + schema + '\'' + ", name='" + name + '\'' + ", description='" + description - + '\'' + ", type='" + type + '\'' + ", foreignKey='" + foreignKey + '\'' + ", primaryKey='" + primaryKey - + '\'' + '}'; + + '\'' + ", type='" + type + '\'' + ", foreignKey='" + foreignKey + '\'' + ", primaryKeys=" + + primaryKeys + '}'; } @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || getClass() != o.getClass()) { return false; + } TableInfoBO that = (TableInfoBO) o; return Objects.equals(schema, that.schema) && Objects.equals(name, that.name) && Objects.equals(description, that.description) && Objects.equals(type, that.type) - && Objects.equals(foreignKey, that.foreignKey) && Objects.equals(primaryKey, that.primaryKey); + && Objects.equals(foreignKey, that.foreignKey) && Objects.equals(primaryKeys, that.primaryKeys); } @Override public int hashCode() { - return Objects.hash(schema, name, description, type, foreignKey, primaryKey); + return Objects.hash(schema, name, description, type, foreignKey, primaryKeys); } public static TableInfoBOBuilder builder() { @@ -132,7 +135,7 @@ public static final class TableInfoBOBuilder { private String foreignKey; - private String primaryKey; + private List primaryKeys; private TableInfoBOBuilder() { } @@ -166,8 +169,8 @@ public TableInfoBOBuilder foreignKey(String foreignKey) { return this; } - public TableInfoBOBuilder primaryKey(String primaryKey) { - this.primaryKey = primaryKey; + public TableInfoBOBuilder primaryKeys(List primaryKeys) { + this.primaryKeys = primaryKeys; return this; } @@ -178,7 +181,7 @@ public TableInfoBO build() { tableInfoBO.setDescription(description); tableInfoBO.setType(type); tableInfoBO.setForeignKey(foreignKey); - tableInfoBO.setPrimaryKey(primaryKey); + tableInfoBO.setPrimaryKeys(primaryKeys); return tableInfoBO; } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/AnalyticDbVectorStoreManagementService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/AnalyticDbVectorStoreManagementService.java index dd28f5b086..a2bb75587c 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/AnalyticDbVectorStoreManagementService.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/AnalyticDbVectorStoreManagementService.java @@ -36,6 +36,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.gson.Gson; +import org.apache.commons.collections.CollectionUtils; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.vectorstore.VectorStore; @@ -271,12 +272,19 @@ private void processTable(TableInfoBO tableInfoBO, DbQueryParameter dqp, DbConfi columnInfoBO.setSamples(gson.toJson(sampleColumn)); } - ColumnInfoBO primaryColumnDO = columnInfoBOS.stream() + List targetPrimaryList = columnInfoBOS.stream() .filter(ColumnInfoBO::isPrimary) - .findFirst() - .orElse(new ColumnInfoBO()); + .collect(Collectors.toList()); + if (CollectionUtils.isNotEmpty(targetPrimaryList)) { + List columnNames = targetPrimaryList.stream() + .map(ColumnInfoBO::getName) + .collect(Collectors.toList()); + tableInfoBO.setPrimaryKeys(columnNames); + } + else { + tableInfoBO.setPrimaryKeys(new ArrayList<>()); + } - tableInfoBO.setPrimaryKey(primaryColumnDO.getName()); tableInfoBO.setForeignKey(String.join("、", buildForeignKeyList(tableInfoBO.getName()))); } @@ -313,7 +321,7 @@ public Document convertTableToDocument(TableInfoBO tableInfoBO) { Map metadata = Map.of("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse(""), "name", tableInfoBO.getName(), "description", Optional.ofNullable(tableInfoBO.getDescription()).orElse(""), "foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""), "primaryKey", - Optional.ofNullable(tableInfoBO.getPrimaryKey()).orElse(""), "vectorType", "table"); + Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()), "vectorType", "table"); return new Document(tableInfoBO.getName(), text, metadata); } diff --git a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/SimpleVectorStoreManagementService.java b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/SimpleVectorStoreManagementService.java index 1a475bd333..25ee9685b0 100644 --- a/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/SimpleVectorStoreManagementService.java +++ b/spring-ai-alibaba-nl2sql/spring-ai-alibaba-nl2sql-management/src/main/java/com/alibaba/cloud/ai/service/SimpleVectorStoreManagementService.java @@ -28,6 +28,7 @@ import com.alibaba.cloud.ai.request.EvidenceRequest; import com.alibaba.cloud.ai.request.SchemaInitRequest; import com.google.gson.Gson; +import org.apache.commons.collections.CollectionUtils; import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.vectorstore.SearchRequest; @@ -150,12 +151,18 @@ private void processTable(TableInfoBO tableInfoBO, DbQueryParameter dqp, DbConfi columnInfoBO.setSamples(gson.toJson(sampleColumn)); } - ColumnInfoBO primaryColumnDO = columnInfoBOS.stream() + List targetPrimaryList = columnInfoBOS.stream() .filter(ColumnInfoBO::isPrimary) - .findFirst() - .orElse(new ColumnInfoBO()); - - tableInfoBO.setPrimaryKey(primaryColumnDO.getName()); + .collect(Collectors.toList()); + if (CollectionUtils.isNotEmpty(targetPrimaryList)) { + List columnNames = targetPrimaryList.stream() + .map(ColumnInfoBO::getName) + .collect(Collectors.toList()); + tableInfoBO.setPrimaryKeys(columnNames); + } + else { + tableInfoBO.setPrimaryKeys(new ArrayList<>()); + } tableInfoBO.setForeignKey(String.join("、", buildForeignKeyList(tableInfoBO.getName()))); } @@ -180,7 +187,7 @@ public Document convertTableToDocument(TableInfoBO tableInfoBO) { Map metadata = Map.of("schema", Optional.ofNullable(tableInfoBO.getSchema()).orElse(""), "name", tableInfoBO.getName(), "description", Optional.ofNullable(tableInfoBO.getDescription()).orElse(""), "foreignKey", Optional.ofNullable(tableInfoBO.getForeignKey()).orElse(""), "primaryKey", - Optional.ofNullable(tableInfoBO.getPrimaryKey()).orElse(""), "vectorType", "table"); + Optional.ofNullable(tableInfoBO.getPrimaryKeys()).orElse(new ArrayList<>()), "vectorType", "table"); return new Document(tableInfoBO.getName(), text, metadata); }