Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,21 @@ public List<Document> getTableDocumentsByDatasource(Integer datasourceId, String

Filter.Expression filterExpression = DynamicFilterService.combineWithAnd(conditions);

// 执行向量检索
// 语义优先:按用户查询的向量相似度召回
SearchRequest searchRequest = SearchRequest.builder()
.query(query)
.topK(tableTopK)
.similarityThreshold(tableThreshold)
.filterExpression(filterExpression)
.build();
List<Document> results = agentVectorStoreService.similaritySearch(searchRequest);

return agentVectorStoreService.getDocumentsOnlyByFilter(filterExpression, tableTopK);
// 降级兜底:语义召回为空时,回退到全量元数据过滤
if (results.isEmpty()) {
log.info("Semantic recall returned empty for query [{}], falling back to metadata filter", query);
results = agentVectorStoreService.getDocumentsOnlyByFilter(filterExpression, tableTopK);
}
return results;
}

private List<String> getMissingTableNamesWithForeignKeySet(List<Document> tableDocuments,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.alibaba.cloud.ai.dataagent.dto.search.AgentSearchRequest;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;

import java.util.List;
Expand Down Expand Up @@ -45,6 +46,9 @@ public interface AgentVectorStoreService {
// 通过元数据过滤精确查找
List<Document> getDocumentsOnlyByFilter(Filter.Expression filterExpression, Integer topK);

// 通过完整 SearchRequest 执行向量相似度检索
List<Document> similaritySearch(SearchRequest searchRequest);

boolean hasDocuments(String agentId);

void addDocuments(String agentId, List<Document> documents);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ public List<Document> getDocumentsOnlyByFilter(Filter.Expression filterExpressio
return vectorStore.similaritySearch(searchRequest);
}

@Override
public List<Document> similaritySearch(SearchRequest searchRequest) {
Assert.notNull(searchRequest, "searchRequest cannot be null.");
return vectorStore.similaritySearch(searchRequest);
}

@Override
public boolean hasDocuments(String agentId) {
// 类似 MySQL 的 LIMIT 1,只检查是否存在文档
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,50 @@
*/
package com.alibaba.cloud.ai.dataagent.workflow.node;

import static com.alibaba.cloud.ai.dataagent.constant.Constant.AGENT_ID;
import static com.alibaba.cloud.ai.dataagent.constant.Constant.COLUMN_DOCUMENTS__FOR_SCHEMA_OUTPUT;
import static com.alibaba.cloud.ai.dataagent.constant.Constant.INPUT_KEY;
import static com.alibaba.cloud.ai.dataagent.constant.Constant.QUERY_ENHANCE_NODE_OUTPUT;
import static com.alibaba.cloud.ai.dataagent.constant.Constant.SCHEMA_RECALL_NODE_OUTPUT;
import static com.alibaba.cloud.ai.dataagent.constant.Constant.TABLE_DOCUMENTS_FOR_SCHEMA_OUTPUT;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import com.alibaba.cloud.ai.dataagent.dto.prompt.QueryEnhanceOutputDTO;
import com.alibaba.cloud.ai.dataagent.entity.AgentDatasource;
import com.alibaba.cloud.ai.dataagent.mapper.AgentDatasourceMapper;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.cloud.ai.dataagent.service.datasource.AgentDatasourceService;
import com.alibaba.cloud.ai.dataagent.service.schema.SchemaService;
import com.alibaba.cloud.ai.dataagent.util.ChatResponseUtil;
import com.alibaba.cloud.ai.dataagent.util.FluxUtil;
import com.alibaba.cloud.ai.dataagent.util.StateUtil;
import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;

import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static com.alibaba.cloud.ai.dataagent.constant.Constant.*;

/**
* Schema recall node that retrieves relevant database schema information based on
* keywords and intent.
* Schema recall node that retrieves relevant database schema information based
* on keywords and intent.
*
* This node is responsible for: - Recalling relevant tables based on user input -
* Retrieving column documents based on extracted keywords - Organizing schema information
* for subsequent processing - Providing streaming feedback during recall process
* This node is responsible for: - Recalling relevant tables based on user input
* - Retrieving column documents based on extracted keywords - Organizing schema
* information for subsequent processing - Providing streaming feedback during
* recall process
*
* @author zhangshenghang
*/
Expand All @@ -54,100 +67,209 @@
@AllArgsConstructor
public class SchemaRecallNode implements NodeAction {

private static final int MAX_DISPLAY_TABLES = 10;

private final SchemaService schemaService;

private final AgentDatasourceMapper agentDatasourceMapper;

private final AgentDatasourceService agentDatasourceService;

@Override
public Map<String, Object> apply(OverAllState state) throws Exception {
String canonicalQuery = resolveCanonicalQuery(state);
Long agentId = parseAgentId(StateUtil.getStringValue(state, AGENT_ID));
if (agentId == null) {
return buildEarlyExitResult(state, """

系统未能识别当前智能体标识,无法继续检索 Schema。
请刷新页面后重试;若仍失败,请联系管理员排查 Agent 配置。
流程已终止。
""");
}

// get input information
QueryEnhanceOutputDTO queryEnhanceOutputDTO = StateUtil.getObjectValue(state, QUERY_ENHANCE_NODE_OUTPUT,
QueryEnhanceOutputDTO.class);
String input = queryEnhanceOutputDTO.getCanonicalQuery();
String agentId = StateUtil.getStringValue(state, AGENT_ID);

// 查询 Agent 的激活数据源
Integer datasourceId = agentDatasourceMapper.selectActiveDatasourceIdByAgentId(Long.valueOf(agentId));

Integer datasourceId = agentDatasourceMapper.selectActiveDatasourceIdByAgentId(agentId);
if (datasourceId == null) {
log.warn("Agent {} has no active datasource", agentId);
// 返回空结果
String noDataSourceMessage = """
\n 该智能体没有激活的数据源
return buildEarlyExitResult(state, """
该智能体没有激活的数据源

这可能是因为:
1. 数据源尚未配置或关联。
2. 所有数据源都已被禁用。
3. 请先配置并激活数据源。
流程已终止。
""";

Flux<ChatResponse> displayFlux = Flux.create(emitter -> {
emitter.next(ChatResponseUtil.createResponse(noDataSourceMessage));
emitter.complete();
});

Flux<GraphResponse<StreamingOutput>> generator = FluxUtil
.createStreamingGeneratorWithMessages(this.getClass(), state, currentState -> {
return Map.of(TABLE_DOCUMENTS_FOR_SCHEMA_OUTPUT, Collections.emptyList(),
COLUMN_DOCUMENTS__FOR_SCHEMA_OUTPUT, Collections.emptyList());
}, displayFlux);

return Map.of(SCHEMA_RECALL_NODE_OUTPUT, generator);
""");
}

// Execute business logic first - recall schema information immediately
List<Document> tableDocuments = new ArrayList<>(
schemaService.getTableDocumentsByDatasource(datasourceId, input));
// extract table names
List<String> recalledTableNames = extractTableName(tableDocuments);
schemaService.getTableDocumentsByDatasource(datasourceId, canonicalQuery));
List<String> recalledTableNames = extractTableNames(tableDocuments);
List<Document> columnDocuments = schemaService.getColumnDocumentsByTableName(datasourceId, recalledTableNames);

String failMessage = """
\n 未检索到相关数据表

这可能是因为:
1. 数据源尚未初始化。
2. 您的提问与当前数据库中的表结构无关。
3. 请尝试点击“初始化数据源”或换一个与业务相关的问题。
4. 如果你用A嵌入模型初始化数据源,却更换为B嵌入模型,请重新初始化数据源
流程已终止。
""";

Flux<ChatResponse> displayFlux = Flux.create(emitter -> {
emitter.next(ChatResponseUtil.createResponse("开始初步召回Schema信息..."));
emitter.next(ChatResponseUtil.createResponse("开始初步召回 Schema 信息..."));
emitter.next(ChatResponseUtil.createResponse(
"初步表信息召回完成,数量: " + tableDocuments.size() + ",表名: " + String.join(", ", recalledTableNames)));
if (tableDocuments.isEmpty()) {
emitter.next(ChatResponseUtil.createResponse(failMessage));
List<String> availableTables = getAvailableTables(agentId);
String fallbackMessage = buildFallbackMessage(canonicalQuery, datasourceId, availableTables);
emitter.next(ChatResponseUtil.createResponse(fallbackMessage));
}
emitter.next(ChatResponseUtil.createResponse("初步Schema信息召回完成."));
emitter.next(ChatResponseUtil.createResponse("初步 Schema 信息召回完成。"));
emitter.complete();
});

Flux<GraphResponse<StreamingOutput>> generator = FluxUtil.createStreamingGeneratorWithMessages(this.getClass(),
state, currentState -> {
return Map.of(TABLE_DOCUMENTS_FOR_SCHEMA_OUTPUT, tableDocuments,
COLUMN_DOCUMENTS__FOR_SCHEMA_OUTPUT, columnDocuments);
}, displayFlux);
Flux<GraphResponse<StreamingOutput>> generator = FluxUtil.createStreamingGeneratorWithMessages(
this.getClass(),
state,
currentState -> Map.of(TABLE_DOCUMENTS_FOR_SCHEMA_OUTPUT, tableDocuments,
COLUMN_DOCUMENTS__FOR_SCHEMA_OUTPUT, columnDocuments),
displayFlux);

// Return the processing result
return Map.of(SCHEMA_RECALL_NODE_OUTPUT, generator);
}

private static List<String> extractTableName(List<Document> tableDocuments) {
private String resolveCanonicalQuery(OverAllState state) {
QueryEnhanceOutputDTO queryEnhanceOutputDTO = StateUtil.getObjectValue(state, QUERY_ENHANCE_NODE_OUTPUT,
QueryEnhanceOutputDTO.class, (QueryEnhanceOutputDTO) null);
if (queryEnhanceOutputDTO != null && StringUtils.hasText(queryEnhanceOutputDTO.getCanonicalQuery())) {
return queryEnhanceOutputDTO.getCanonicalQuery().trim();
}

String rawInput = StateUtil.getStringValue(state, INPUT_KEY, "");
if (StringUtils.hasText(rawInput)) {
return rawInput.trim();
}
return "(用户问题为空)";
}

private Long parseAgentId(String rawAgentId) {
if (!StringUtils.hasText(rawAgentId)) {
log.warn("Agent id is empty in workflow state");
return null;
}
try {
return Long.valueOf(rawAgentId.trim());
}
catch (NumberFormatException ex) {
log.warn("Invalid agent id in workflow state: {}", rawAgentId, ex);
return null;
}
}

private Map<String, Object> buildEarlyExitResult(OverAllState state, String message) {
Flux<ChatResponse> displayFlux = Flux.create(emitter -> {
emitter.next(ChatResponseUtil.createResponse(message));
emitter.complete();
});

Flux<GraphResponse<StreamingOutput>> generator = FluxUtil.createStreamingGeneratorWithMessages(
this.getClass(),
state,
currentState -> Map.of(TABLE_DOCUMENTS_FOR_SCHEMA_OUTPUT, Collections.emptyList(),
COLUMN_DOCUMENTS__FOR_SCHEMA_OUTPUT, Collections.emptyList()),
displayFlux);
return Map.of(SCHEMA_RECALL_NODE_OUTPUT, generator);
}

private static List<String> extractTableNames(List<Document> tableDocuments) {
List<String> tableNames = new ArrayList<>();
// metadata中的name字段
for (Document document : tableDocuments) {
String name = (String) document.getMetadata().get("name");
if (name != null && !name.isEmpty()) {
tableNames.add(name);
Object nameObject = document.getMetadata().get("name");
if (nameObject instanceof String name && StringUtils.hasText(name)) {
tableNames.add(name.trim());
}
}
log.info("At this SchemaRecallNode, Recall tables are: {}", tableNames);
log.info("At SchemaRecallNode, recalled tables are: {}", tableNames);
return tableNames;
}

private List<String> getAvailableTables(Long agentId) {
try {
AgentDatasource currentDatasource = agentDatasourceService.getCurrentAgentDatasource(agentId);
if (currentDatasource == null || currentDatasource.getSelectTables() == null
|| currentDatasource.getSelectTables().isEmpty()) {
return List.of();
}

Set<String> uniqueTables = new LinkedHashSet<>();
for (String tableName : currentDatasource.getSelectTables()) {
if (StringUtils.hasText(tableName)) {
uniqueTables.add(tableName.trim());
}
}
return List.copyOf(uniqueTables);
}
catch (Exception e) {
log.warn("Failed to load selected tables for agent {}", agentId, e);
return List.of();
}
}

private String buildFallbackMessage(String userQuery, Integer datasourceId, List<String> availableTables) {
String formattedTables = formatAvailableTables(availableTables);
List<String> suggestedQuestions = buildSuggestedQuestions(availableTables);

return ("""

未检索到与当前问题相关的数据表。

当前问题:
%s

当前可用表:
%s

建议你可以这样提问:
1. %s
2. %s
3. %s

下一步操作:
1. 确认已执行“初始化数据源”,并且初始化使用的是当前 Embedding 模型。
2. 若刚切换过 Embedding 模型,请重新初始化该数据源。
3. 在问题中补充业务关键词或表字段关键词(例如:订单、用户、金额、日期)。
4. 若是业务口径词(例如“人均 GDP”),建议在知识库补充“术语-字段映射”。
5. 如需排查,请检查数据源 ID:%s。
流程已终止。
""").formatted(userQuery, formattedTables, suggestedQuestions.get(0), suggestedQuestions.get(1),
suggestedQuestions.get(2), datasourceId);
}

private String formatAvailableTables(List<String> availableTables) {
if (availableTables.isEmpty()) {
return "暂无(当前智能体还没有配置已选表)";
}
List<String> displayTables = availableTables.stream().limit(MAX_DISPLAY_TABLES).toList();
if (availableTables.size() > MAX_DISPLAY_TABLES) {
return String.join(", ", displayTables) + " ...(共 " + availableTables.size() + " 张)";
}
return String.join(", ", displayTables);
}

private List<String> buildSuggestedQuestions(List<String> availableTables) {
if (availableTables.isEmpty()) {
return List.of("查询最近30天核心业务指标趋势", "按地区统计核心指标分布", "查询核心对象 Top10 及占比");
}

int size = availableTables.size();
String first = availableTables.get(0);
String second = size > 1 ? availableTables.get(1) : null;
String third = size > 2 ? availableTables.get(2) : null;

String q1 = "查询 " + first + " 最近30天的数量趋势";
String q2 = second != null
? "按维度统计 " + second + " 的分布情况"
: "按维度统计 " + first + " 的分布情况";
String q3 = (second != null && third != null)
? "关联 " + second + " 与 " + third + " 分析核心指标"
: second != null
? "关联 " + first + " 与 " + second + " 分析核心指标"
: "在问题中补充 " + first + " 的关键字段后重试";

return List.of(q1, q2, q3);
}

}