Agent 初版设计

This commit is contained in:
2026-04-29 17:15:49 +08:00
parent 2c1afdc97c
commit e5ca9e24aa
13 changed files with 1819 additions and 1255 deletions
@@ -0,0 +1,239 @@
"use client";
import { useCallback, useEffect, useRef, useState } from "react";
import { streamAgentChat } from "@/lib/chatStream";
import type { StreamEvent } from "@/lib/chatStream";
import type {
AgentArtifact,
ChatProgress,
Message,
PersistedChatState,
} from "../GlobalChatbox.types";
import { CHAT_STORAGE_KEY, createId, getInitialChatState } from "../GlobalChatbox.utils";
type UseAgentChatSessionOptions = {
onToolCall: (
event: StreamEvent & { type: "tool_call" },
options: {
assistantMessageId: string;
appendArtifact: (messageId: string, artifact: AgentArtifact) => void;
},
) => void;
onBeforeSend?: () => void;
};
const upsertProgress = (
progress: ChatProgress[] | undefined,
event: StreamEvent & { type: "progress" },
) => {
const next = [...(progress ?? [])];
const index = next.findIndex((item) => item.id === event.id);
const nextItem: ChatProgress = {
id: event.id,
phase: event.phase,
status: event.status,
title: event.title,
detail: event.detail,
};
if (index >= 0) {
next[index] = nextItem;
} else {
next.push(nextItem);
}
return next;
};
const completeRunningProgress = (progress: ChatProgress[] | undefined) =>
progress?.map((item) =>
item.status === "running" ? { ...item, status: "completed" as const } : item,
);
export const useAgentChatSession = ({
onToolCall,
onBeforeSend,
}: UseAgentChatSessionOptions) => {
const initialChatStateRef = useRef<PersistedChatState | null>(null);
if (initialChatStateRef.current === null) {
initialChatStateRef.current = getInitialChatState();
}
const [messages, setMessages] = useState<Message[]>(
initialChatStateRef.current.messages,
);
const [sessionId, setSessionId] = useState<string | undefined>(
initialChatStateRef.current.sessionId,
);
const [isStreaming, setIsStreaming] = useState(false);
const abortRef = useRef<AbortController | null>(null);
useEffect(() => {
const state: PersistedChatState = { messages, sessionId };
try {
window.localStorage.setItem(CHAT_STORAGE_KEY, JSON.stringify(state));
} catch (error) {
console.error("[GlobalChatbox] Failed to persist chat state:", error);
}
}, [messages, sessionId]);
const appendArtifact = useCallback((messageId: string, artifact: AgentArtifact) => {
setMessages((prev) =>
prev.map((message) =>
message.id === messageId
? {
...message,
artifacts: [...(message.artifacts ?? []), artifact],
}
: message,
),
);
}, []);
const sendPrompt = useCallback(
async (rawPrompt: string) => {
const prompt = rawPrompt.trim();
if (!prompt || isStreaming) return;
onBeforeSend?.();
const userId = createId();
const assistantId = createId();
setIsStreaming(true);
setMessages((prev) => [
...prev,
{ id: userId, role: "user", content: prompt },
{ id: assistantId, role: "assistant", content: "" },
]);
const controller = new AbortController();
abortRef.current = controller;
try {
await streamAgentChat({
message: prompt,
sessionId,
signal: controller.signal,
onEvent: (event) => {
if ("sessionId" in event && !sessionId && event.sessionId) {
setSessionId(event.sessionId);
}
if (event.type === "token") {
setMessages((prev) =>
prev.map((message) =>
message.id === assistantId
? {
...message,
content: message.content + event.content,
isError: false,
}
: message,
),
);
} else if (event.type === "progress") {
setMessages((prev) =>
prev.map((message) =>
message.id === assistantId
? { ...message, progress: upsertProgress(message.progress, event) }
: message,
),
);
} else if (event.type === "tool_call") {
onToolCall(event, {
assistantMessageId: assistantId,
appendArtifact,
});
} else if (event.type === "done") {
setMessages((prev) =>
prev.map((message) => {
if (message.id !== assistantId) return message;
const completedProgress = completeRunningProgress(message.progress);
if (
message.content.trim().length === 0 &&
!(message.artifacts?.length)
) {
return {
...message,
content:
"Agent 已完成处理,但没有生成文本回答。请查看过程记录,或换个更具体的问题重试。",
progress: completedProgress,
};
}
return { ...message, progress: completedProgress };
}),
);
setIsStreaming(false);
} else if (event.type === "error") {
setMessages((prev) =>
prev.map((message) =>
message.id === assistantId
? {
...message,
content: message.content || `⚠️ **错误:** ${event.message}`,
isError: true,
progress: completeRunningProgress(message.progress),
}
: message,
),
);
setIsStreaming(false);
}
},
});
} catch (error) {
if (abortRef.current?.signal.aborted) {
setMessages((prev) =>
prev.filter(
(message) =>
!(
message.id === assistantId &&
message.role === "assistant" &&
message.content.trim().length === 0 &&
!(message.artifacts?.length)
),
),
);
return;
}
setMessages((prev) =>
prev.map((message) =>
message.id === assistantId
? {
...message,
content: `⚠️ **错误:** ${String(error)}`,
isError: true,
progress: completeRunningProgress(message.progress),
}
: message,
),
);
setIsStreaming(false);
} finally {
abortRef.current = null;
setIsStreaming(false);
}
},
[appendArtifact, isStreaming, onBeforeSend, onToolCall, sessionId],
);
const abort = useCallback(() => {
abortRef.current?.abort();
setIsStreaming(false);
}, []);
const reset = useCallback(() => {
abortRef.current?.abort();
setMessages([]);
setSessionId(undefined);
setIsStreaming(false);
}, []);
return {
messages,
isStreaming,
sessionId,
sendPrompt,
abort,
reset,
};
};
@@ -0,0 +1,237 @@
"use client";
import { useCallback } from "react";
import { useChatToolStore, type ChatToolAction } from "@/store/chatToolStore";
import type { StreamEvent } from "@/lib/chatStream";
import type { AgentArtifact, AgentArtifactKind } from "../GlobalChatbox.types";
type ToolCallEvent = StreamEvent & { type: "tool_call" };
type HandleToolCallOptions = {
assistantMessageId: string;
appendArtifact: (messageId: string, artifact: AgentArtifact) => void;
};
const FEATURE_TYPE_MAP: Record<
string,
{ layer: string; geometryKind: "point" | "line"; label: string }
> = {
junction: { layer: "geo_junctions_mat", geometryKind: "point", label: "节点" },
junctions: { layer: "geo_junctions_mat", geometryKind: "point", label: "节点" },
pipe: { layer: "geo_pipes_mat", geometryKind: "line", label: "管道" },
pipes: { layer: "geo_pipes_mat", geometryKind: "line", label: "管道" },
valve: { layer: "geo_valves", geometryKind: "point", label: "阀门" },
valves: { layer: "geo_valves", geometryKind: "point", label: "阀门" },
reservoir: { layer: "geo_reservoirs", geometryKind: "point", label: "水源" },
reservoirs: { layer: "geo_reservoirs", geometryKind: "point", label: "水源" },
pump: { layer: "geo_pumps", geometryKind: "point", label: "泵站" },
pumps: { layer: "geo_pumps", geometryKind: "point", label: "泵站" },
tank: { layer: "geo_tanks", geometryKind: "point", label: "水池" },
tanks: { layer: "geo_tanks", geometryKind: "point", label: "水池" },
};
const LOCATE_TOOL_CONFIG: Record<
string,
{ layer: string; geometryKind: "point" | "line"; label: string }
> = {
locate_pipes: { layer: "geo_pipes_mat", geometryKind: "line", label: "管道" },
locate_junctions: { layer: "geo_junctions_mat", geometryKind: "point", label: "节点" },
locate_valves: { layer: "geo_valves", geometryKind: "point", label: "阀门" },
locate_reservoirs: { layer: "geo_reservoirs", geometryKind: "point", label: "水源" },
locate_pumps: { layer: "geo_pumps", geometryKind: "point", label: "泵站" },
locate_tanks: { layer: "geo_tanks", geometryKind: "point", label: "水池" },
};
const normalizeIds = (params: Record<string, unknown>): string[] => {
const rawIds = params.ids;
if (Array.isArray(rawIds)) {
return rawIds.map((id) => String(id).trim()).filter(Boolean);
}
if (typeof rawIds === "string") {
return rawIds
.split(",")
.map((id) => id.trim())
.filter(Boolean);
}
return [];
};
const resolveScadaFeatureInfos = (params: Record<string, unknown>): [string, string][] => {
const rawFeatureInfos = params.feature_infos;
if (Array.isArray(rawFeatureInfos)) {
const normalizedFeatureInfos = rawFeatureInfos
.map((item) => (Array.isArray(item) ? item : null))
.filter((item): item is [unknown, unknown] => Boolean(item))
.map(
(item) =>
[String(item[0] ?? ""), String(item[1] ?? "scada")] as [
string,
string,
],
)
.filter(([id]) => id.trim().length > 0);
if (normalizedFeatureInfos.length > 0) {
return normalizedFeatureInfos;
}
}
const rawDeviceIds =
params.device_ids ??
params.deviceId ??
params.device_id ??
params.id ??
params.ids;
const deviceIds = Array.isArray(rawDeviceIds)
? rawDeviceIds.map((id) => String(id))
: typeof rawDeviceIds === "string"
? rawDeviceIds
.split(",")
.map((id) => id.trim())
.filter(Boolean)
: [];
return deviceIds.map((id) => [id, "scada"]);
};
const resolveTimeRange = (params: Record<string, unknown>) => ({
startTime:
(params.start_time as string | undefined) ??
(params.startTime as string | undefined) ??
(params.from as string | undefined) ??
(params.start as string | undefined),
endTime:
(params.end_time as string | undefined) ??
(params.endTime as string | undefined) ??
(params.to as string | undefined) ??
(params.end as string | undefined),
});
const compactNames = (names: string[]) => {
if (!names.length) return "";
return names.length > 3
? `${names.slice(0, 3).join(", ")}${names.length}`
: names.join(", ");
};
const buildLocateArtifact = (
tool: string,
params: Record<string, unknown>,
): { artifact: Omit<AgentArtifact, "id" | "params" | "tool">; action: ChatToolAction | null } => {
const ids = normalizeIds(params);
const rawType = params.feature_type;
const featureType =
typeof rawType === "string" ? rawType.trim().toLowerCase() : "";
const config = tool === "locate_features"
? FEATURE_TYPE_MAP[featureType]
: LOCATE_TOOL_CONFIG[tool];
return {
artifact: {
kind: "map",
title: config ? `地图定位${config.label}` : "地图定位",
description: compactNames(ids),
},
action: config
? {
type: "locate_features",
ids,
layer: config.layer,
geometryKind: config.geometryKind,
}
: null,
};
};
const buildToolAction = (
tool: string,
params: Record<string, unknown>,
): { action: ChatToolAction | null; kind: AgentArtifactKind; title: string; description?: string } => {
if (tool === "show_chart") {
return {
action: null,
kind: "chart",
title: (params.title as string | undefined) ?? "生成图表",
description: "已生成可视化图表",
};
}
if (tool === "locate_features" || LOCATE_TOOL_CONFIG[tool]) {
const locate = buildLocateArtifact(tool, params);
return {
action: locate.action,
kind: locate.artifact.kind,
title: locate.artifact.title,
description: locate.artifact.description,
};
}
if (tool === "view_history") {
const featureInfos = (params.feature_infos as [string, string][] | undefined) ?? [];
const { startTime, endTime } = resolveTimeRange(params);
return {
action: {
type: "view_history",
featureInfos,
dataType:
(params.data_type as "realtime" | "scheme" | "none" | undefined) ??
"realtime",
startTime,
endTime,
},
kind: "panel",
title: "打开计算结果曲线",
description: compactNames(featureInfos.map(([id]) => id)),
};
}
if (tool === "view_scada") {
const featureInfos = resolveScadaFeatureInfos(params);
const { startTime, endTime } = resolveTimeRange(params);
return {
action: {
type: "view_scada",
featureInfos,
startTime,
endTime,
},
kind: "panel",
title: "打开 SCADA 数据面板",
description: compactNames(featureInfos.map(([id]) => id)),
};
}
return {
action: null,
kind: "tool",
title: tool || "工具调用",
description: "Agent 已执行工具动作",
};
};
export const useAgentToolActions = () => {
const dispatchToolAction = useChatToolStore((s) => s.dispatch);
return useCallback(
(event: ToolCallEvent, options: HandleToolCallOptions) => {
const { action, kind, title, description } = buildToolAction(
event.tool,
event.params,
);
options.appendArtifact(options.assistantMessageId, {
id: `${event.tool}-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`,
tool: event.tool,
kind,
title,
description,
params: event.params,
});
if (action) {
dispatchToolAction(action);
}
},
[dispatchToolAction],
);
};