Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
<version>${revision}</version>
</dependency>

<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-observation-extension</artifactId>
<version>${revision}</version>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,37 @@
*/
package com.alibaba.cloud.ai.autoconfigure.arms;

import com.alibaba.cloud.ai.observation.client.prompt.PromptMetadataAwareChatClientObservationConvention;
import com.alibaba.cloud.ai.observation.model.ChatModelInputObservationHandler;
import com.alibaba.cloud.ai.observation.model.ChatModelOutputObservationHandler;
import com.alibaba.cloud.ai.observation.model.PromptMetadataAwareChatModelObservationConvention;
import com.alibaba.cloud.ai.tool.ObservableToolCallingManager;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;

/**
* @author Lumian
*/
@AutoConfiguration
@ConditionalOnClass(ChatModel.class)
@ConditionalOnProperty(prefix = ArmsCommonProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@EnableConfigurationProperties(ArmsCommonProperties.class)
public class ArmsAutoConfiguration {

@Bean
@ConditionalOnProperty(prefix = ArmsCommonProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@ConditionalOnProperty(prefix = ArmsCommonProperties.CONFIG_PREFIX, name = "tool.enabled", havingValue = "true")
ToolCallingManager toolCallingManager(ToolCallbackResolver toolCallbackResolver,
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor,
ObjectProvider<ObservationRegistry> observationRegistry) {
Expand All @@ -45,4 +56,32 @@ ToolCallingManager toolCallingManager(ToolCallbackResolver toolCallbackResolver,
.build();
}

@Bean
ChatClientObservationConvention chatClientObservationConvention() {
return new PromptMetadataAwareChatClientObservationConvention();
}

@Bean
ChatModelObservationConvention chatModelObservationConvention() {
return new PromptMetadataAwareChatModelObservationConvention();
}

@Bean
@ConditionalOnMissingBean(value = { ChatModelInputObservationHandler.class },
name = { "chatModelInputObservationHandler" })
@ConditionalOnProperty(prefix = ArmsCommonProperties.CONFIG_PREFIX, name = "model.capture-input",
havingValue = "true")
ChatModelInputObservationHandler armsChatModelInputObservationHandler(ArmsCommonProperties properties) {
return new ChatModelInputObservationHandler(properties.getModel().getMessageMode());
}

@Bean
@ConditionalOnMissingBean(value = { ChatModelOutputObservationHandler.class },
name = { "chatModelOutputObservationHandler" })
@ConditionalOnProperty(prefix = ArmsCommonProperties.CONFIG_PREFIX, name = "model.capture-output",
havingValue = "true")
ChatModelOutputObservationHandler armsChatModelOutputObservationHandler(ArmsCommonProperties properties) {
return new ChatModelOutputObservationHandler(properties.getModel().getMessageMode());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.alibaba.cloud.ai.autoconfigure.arms;

import com.alibaba.cloud.ai.observation.model.semconv.MessageMode;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
Expand All @@ -33,6 +34,10 @@ public class ArmsCommonProperties {
*/
private boolean enabled = false;

private ModelProperties model = new ModelProperties();

private ToolProperties tool = new ToolProperties();

public boolean isEnabled() {
return enabled;
}
Expand All @@ -41,4 +46,77 @@ public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public ModelProperties getModel() {
return model;
}

public void setModel(ModelProperties model) {
this.model = model;
}

public ToolProperties getTool() {
return tool;
}

public void setTool(ToolProperties tool) {
this.tool = tool;
}

public static class ModelProperties {

/**
* Enable Arms instrumentations and conventions.
*/
private boolean captureInput = false;

/**
* Enable Arms instrumentations and conventions.
*/
private boolean captureOutput = false;

/**
* Arms export type enumeration.
*/
private MessageMode messageMode = MessageMode.OPEN_TELEMETRY;

public boolean isCaptureInput() {
return captureInput;
}

public void setCaptureInput(boolean captureInput) {
this.captureInput = captureInput;
}

public boolean isCaptureOutput() {
return captureOutput;
}

public void setCaptureOutput(boolean captureOutput) {
this.captureOutput = captureOutput;
}

public MessageMode getMessageMode() {
return messageMode;
}

public void setMessageMode(MessageMode messageMode) {
this.messageMode = messageMode;
}

}

public static class ToolProperties {

private boolean enabled = true;

public boolean isEnabled() {
return enabled;
}

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

}

}
3 changes: 3 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
<module>spring-ai-alibaba-jmanus</module>
<module>spring-ai-alibaba-deepresearch</module>

<!-- Spring AI Alibaba Extension Modules -->
<module>spring-ai-alibaba-observation-extension</module>

<!-- Spring AI Alibaba Tool Call Plugins -->
<module>community/tool-calls/spring-ai-alibaba-starter-tool-calling-common</module>
<module>community/tool-calls/spring-ai-alibaba-starter-tool-calling-time</module>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
Expand All @@ -52,6 +53,7 @@
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @author yyyhhx
* @since 2024/7/31 10:57
*/
public class DashScopeEmbeddingModel extends AbstractEmbeddingModel {
Expand Down Expand Up @@ -232,4 +234,40 @@ public void setObservationConvention(EmbeddingModelObservationConvention observa
this.observationConvention = observationConvention;
}

/**
* Embed the provided texts and return the embeddings.
* @return The embeddings
*/
@Override
public List<float[]> embed(List<String> texts) {
Assert.notNull(texts, "Texts must not be null");
return this.call(new EmbeddingRequest(texts, defaultOptions))
.getResults()
.stream()
.map(Embedding::getOutput)
.toList();
}

/**
* Embed the provided documents and return the embeddings.
* @return The embeddings
*/
@Override
public List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
if (options.getModel() == null && options.getDimensions() == null && defaultOptions != null) {
options = defaultOptions;
}
return super.embed(documents, options, batchingStrategy);
}

/**
* Embed the provided documents and return the response.
* @return The embedding response
*/
@Override
public EmbeddingResponse embedForResponse(List<String> texts) {
Assert.notNull(texts, "Texts must not be null");
return this.call(new EmbeddingRequest(texts, defaultOptions));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ protected KeyValues toolDescription(KeyValues keyValues, ArmsToolCallingObservat

protected KeyValues toolParameters(KeyValues keyValues, ArmsToolCallingObservationContext context) {
if (context.getToolCall().arguments() != null) {
return keyValues.and(HighCardinalityKeyNames.TOOL_PARAMETERS.asString(), context.getToolCall().arguments());
return keyValues.and(HighCardinalityKeyNames.TOOL_PARAMETERS.asString(), context.getToolCall().arguments())
.and(HighCardinalityKeyNames.INPUT_VALUE.asString(), context.getToolCall().arguments());
}
return keyValues;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ public static class Builder {

private String enableRankerKey;

private Boolean enableRanker;
private Boolean enableRanker = false;

private String rerankModelKey;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public StateGraph nl2sqlGraph(ChatClient.Builder chatClientBuilder) throws Graph
.addNode(SQL_GENERATE_NODE, node_async(new SqlGenerateNode(chatClientBuilder, nl2SqlService)))
.addNode(PLANNER_NODE, node_async(new PlannerNode(chatClientBuilder)))
.addNode(PLAN_EXECUTOR_NODE, node_async(new PlanExecutorNode()))
.addNode(SQL_EXECUTE_NODE, node_async(new SqlExecuteNode(dbAccessor, datasourceService)))
.addNode(SQL_EXECUTE_NODE, node_async(new SqlExecuteNode(dbAccessor, datasourceService, dbConfig)))
.addNode(PYTHON_GENERATE_NODE,
node_async(new PythonGenerateNode(codeExecutorProperties, chatClientBuilder)))
.addNode(PYTHON_EXECUTE_NODE, node_async(new PythonExecuteNode(codePoolExecutor)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,13 @@ public class SqlExecuteNode extends AbstractPlanBasedNode {

private final DatasourceService datasourceService;

public SqlExecuteNode(Accessor dbAccessor, DatasourceService datasourceService) {
private final DbConfig dbConfig;

public SqlExecuteNode(Accessor dbAccessor, DatasourceService datasourceService, DbConfig dbConfig) {
super();
this.dbAccessor = dbAccessor;
this.datasourceService = datasourceService;
this.dbConfig = dbConfig;
}

@Override
Expand Down Expand Up @@ -97,7 +100,8 @@ private DbConfig getAgentDbConfig(OverAllState state) {
// Get the agent ID from the state
String agentIdStr = StateUtils.getStringValue(state, Constant.AGENT_ID);
if (agentIdStr == null || agentIdStr.trim().isEmpty()) {
throw new RuntimeException("未找到智能体ID,无法获取数据源配置");
// 返回默认数据源
return dbConfig;
}

Integer agentId = Integer.valueOf(agentIdStr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ public List<String> extractEvidences(String query) {
public List<String> extractEvidences(String query, String agentId) {
logger.debug("Extracting evidences for query: {} with agentId: {}", query, agentId);
List<Document> evidenceDocuments;
if (agentId != null) {
if (agentId != null && !agentId.trim().isEmpty()) {
evidenceDocuments = vectorStoreService.getDocumentsForAgent(agentId, query, "evidence");
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public List<Document> getTableDocuments(String query) {
* Get all table documents by keywords - supports agent isolation
*/
public List<Document> getTableDocuments(String query, String agentId) {
if (agentId != null) {
if (agentId != null && !agentId.trim().isEmpty()) {
return vectorStoreService.getDocumentsForAgent(agentId, query, "table");
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public List<BusinessKnowledgeDTO> getFieldByDataSetId(String dataSetId) {
return new BusinessKnowledgeDTO(rs.getString("business_term"), // businessTerm
rs.getString("description"), // description
rs.getString("synonyms"), // synonyms
rs.getObject("is_recall", boolean.class), // defaultRecall (convert to
rs.getObject("is_recall", Boolean.class), // defaultRecall (convert to
// Boolean)
rs.getString("data_set_id") // datasetId
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public List<SemanticModelDTO> getFieldByDataSetId(String dataSetId) {
return this.jdbcTemplate.query(FIELD_GET_BY_DATASET_IDS, new Object[] { dataSetId }, (rs, rowNum) -> {
return new SemanticModelDTO(rs.getString("agent_id"), rs.getString("origin_name"),
rs.getString("field_name"), rs.getString("synonyms"), rs.getString("description"),
rs.getObject("is_recall", boolean.class), rs.getObject("status", boolean.class),
rs.getObject("is_recall", Boolean.class), rs.getObject("status", Boolean.class),
rs.getString("type"), rs.getString("origin_description"));
});
}
Expand Down
Loading
Loading