Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions spring-ai-alibaba-nl2sql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
<netty-resolver-dns-native-macos.version>4.1.114.Final</netty-resolver-dns-native-macos.version>
<netty-transport-native-epoll.version>4.1.114.Final</netty-transport-native-epoll.version>
<jsonschema.version>4.37.0</jsonschema.version>
<h2.version>2.3.232</h2.version>
</properties>

<dependencyManagement>
Expand All @@ -60,6 +61,12 @@
<artifactId>postgresql</artifactId>
<version>${postgresql.version}</version>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>${h2.version}</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>commons-collections</groupId>
<artifactId>commons-collections</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,23 @@
@Configuration(proxyBeanMethods = false)
public class BaseDefaultConfiguration {

private static final Logger logger = LoggerFactory.getLogger(Nl2sqlConfiguration.class);
private static final Logger logger = LoggerFactory.getLogger(BaseDefaultConfiguration.class);

private final Accessor dbAccessor;

private final DbConfig dbConfig;

private BaseDefaultConfiguration(@Qualifier("mysqlAccessor") Accessor accessor, DbConfig dbConfig) {
this.dbAccessor = accessor;
private BaseDefaultConfiguration(DbConfig dbConfig, @Qualifier("mysqlAccessor") Accessor mysqlDbAccessor,
@Qualifier("h2Accessor") Accessor h2DbAccessor, @Qualifier("postgreAccessor") Accessor postgreDbAccessor) {
if ("h2".equals(dbConfig.getDialectType())) {
dbAccessor = h2DbAccessor;
}
else if ("postgre".equals(dbConfig.getDialectType())) {
dbAccessor = postgreDbAccessor;
}
else {
dbAccessor = mysqlDbAccessor;
}
this.dbConfig = dbConfig;
}

Expand All @@ -70,4 +79,10 @@ public BaseSchemaService defaultSchemaService(
return new SimpleSchemaService(dbConfig, gson, vectorStoreService);
}

@Bean("dbAccessor")
@ConditionalOnMissingBean(name = "dbAccessor")
public Accessor dbAccessor() {
return dbAccessor;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public class Nl2sqlConfiguration {

public Nl2sqlConfiguration(@Qualifier("nl2SqlServiceImpl") BaseNl2SqlService nl2SqlService,
@Qualifier("schemaServiceImpl") BaseSchemaService schemaService,
@Qualifier("mysqlAccessor") Accessor dbAccessor, DbConfig dbConfig,
@Qualifier("dbAccessor") Accessor dbAccessor, DbConfig dbConfig,
CodeExecutorProperties codeExecutorProperties, CodePoolExecutorService codePoolExecutor,
SemanticModelRecallService semanticModelRecallService,
BusinessKnowledgeRecallService businessKnowledgeRecallService, UserPromptConfigService promptConfigService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ else if ("postgresql".equalsIgnoreCase(type)) {
"jdbc:postgresql://%s:%d/%s?useUnicode=true&characterEncoding=utf-8&useSSL=false&serverTimezone=Asia/Shanghai",
host, port, databaseName);
}
else if ("h2".equalsIgnoreCase(type)) {
this.connectionUrl = String.format(
"jdbc:h2:mem:%s;DB_CLOSE_DELAY=-1;DATABASE_TO_LOWER=true;MODE=MySQL;DB_CLOSE_ON_EXIT=FALSE",
databaseName);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ else if ("postgresql".equalsIgnoreCase(datasource.getType())) {
dbConfig.setConnectionType("jdbc");
dbConfig.setDialectType("postgresql");
}
else if ("h2".equalsIgnoreCase(datasource.getType())) {
dbConfig.setConnectionType("jdbc");
dbConfig.setDialectType("h2");
}
else {
throw new RuntimeException("不支持的数据库类型: " + datasource.getType());
}
Expand Down Expand Up @@ -181,6 +185,7 @@ private Map<String, Object> executeSqlQuery(OverAllState state, Integer currentS
// Execute business logic first - actual SQL execution
DbQueryParameter dbQueryParameter = new DbQueryParameter();
dbQueryParameter.setSql(sqlQuery);
dbQueryParameter.setSchema(dbConfig.getSchema());

try {
// Execute SQL query and get results immediately
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class AnalyticNl2SqlService extends BaseNl2SqlService {
@Autowired
public AnalyticNl2SqlService(@Qualifier("analyticVectorStoreService") BaseVectorStoreService vectorStoreService,
@Qualifier("analyticSchemaService") BaseSchemaService schemaService, LlmService aiService,
@Qualifier("mysqlAccessor") Accessor dbAccessor, DbConfig dbConfig) {
@Qualifier("dbAccessor") Accessor dbAccessor, DbConfig dbConfig) {
super(vectorStoreService, schemaService, aiService, dbAccessor, dbConfig);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class SimpleNl2SqlService extends BaseNl2SqlService {
@Autowired
public SimpleNl2SqlService(@Qualifier("simpleVectorStoreService") BaseVectorStoreService vectorStoreService,
@Qualifier("simpleSchemaService") BaseSchemaService schemaService, LlmService aiService,
@Qualifier("mysqlAccessor") Accessor accessor, DbConfig dbConfig) {
@Qualifier("dbAccessor") Accessor accessor, DbConfig dbConfig) {

super(vectorStoreService, schemaService, aiService, accessor, dbConfig);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public class SimpleVectorStoreService extends BaseVectorStoreService {

@Autowired
public SimpleVectorStoreService(EmbeddingModel embeddingModel, Gson gson,
@Qualifier("mysqlAccessor") Accessor dbAccessor, DbConfig dbConfig,
@Qualifier("dbAccessor") Accessor dbAccessor, DbConfig dbConfig,
AgentVectorStoreManager agentVectorStoreManager) {
log.info("Initializing SimpleVectorStoreService with EmbeddingModel: {}",
embeddingModel.getClass().getSimpleName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public DBConnectionPool getPoolByType(String type) {
case "mysql", "mysqljdbcconnectionpool" -> poolMap.get("mysqlJdbcConnectionPool");
case "postgresql", "postgres", "postgresqljdbcconnectionpool" ->
poolMap.get("postgreSqlJdbcConnectionPool");
case "h2", "h2jdbcconnectionpool" -> poolMap.get("h2JdbcConnectionPool");
default -> null;
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ public static ResultSetBO executeSqlAndReturnObject(Connection connection, Strin
statement.execute("set search_path = '" + schema + "';");
}
}
else if (dialect.equals(DatabaseDialectEnum.H2.code)) {
if (StringUtils.isNotEmpty(schema)) {
statement.execute("use " + schema + ";");
}
}

try (ResultSet rs = statement.executeQuery(sql)) {
return ResultSetBuilder.buildFrom(rs, schema);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.cloud.ai.connector.h2;

import com.alibaba.cloud.ai.connector.DBConnectionPool;
import com.alibaba.cloud.ai.connector.accessor.defaults.AbstractAccessor;
import com.alibaba.cloud.ai.connector.support.DdlFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;

/**
* @author HunterPorter
* @author <a href="mailto:zongpeng_hzp@163.com">HunterPorter</a>
*/

@Service("h2Accessor")
public class H2DBAccessor extends AbstractAccessor {

private final static String ACCESSOR_TYPE = "H2_Accessor";

protected H2DBAccessor(DdlFactory ddlFactory,
@Qualifier("h2JdbcConnectionPool") DBConnectionPool dbConnectionPool) {

super(ddlFactory, dbConnectionPool);
}

@Override
public String getDbAccessorType() {

return ACCESSOR_TYPE;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.connector.h2;

import com.alibaba.cloud.ai.connector.AbstractDBConnectionPool;
import com.alibaba.cloud.ai.enums.DatabaseDialectEnum;
import com.alibaba.cloud.ai.enums.ErrorCodeEnum;
import org.springframework.stereotype.Service;

import static com.alibaba.cloud.ai.enums.ErrorCodeEnum.DATABASE_NOT_EXIST_42000;
import static com.alibaba.cloud.ai.enums.ErrorCodeEnum.DATASOURCE_CONNECTION_FAILURE_08S01;
import static com.alibaba.cloud.ai.enums.ErrorCodeEnum.OTHERS;
import static com.alibaba.cloud.ai.enums.ErrorCodeEnum.PASSWORD_ERROR_28000;

@Service("h2JdbcConnectionPool")
public class H2JdbcConnectionPool extends AbstractDBConnectionPool {

@Override
public DatabaseDialectEnum getDialect() {
return DatabaseDialectEnum.H2;
}

@Override
public String getDriver() {
return "org.h2.Driver";
}

@Override
public ErrorCodeEnum errorMapping(String sqlState) {
ErrorCodeEnum ret = ErrorCodeEnum.fromCode(sqlState);
if (ret != null) {
return ret;
}
return switch (sqlState) {
case "08S01" -> DATASOURCE_CONNECTION_FAILURE_08S01;
case "28000" -> PASSWORD_ERROR_28000;
case "42000" -> DATABASE_NOT_EXIST_42000;
default -> OTHERS;
};
}

}
Loading