重构 Agent 聊天,支持分支管理与消息克隆

This commit is contained in:
2026-04-30 13:05:45 +08:00
parent e5ca9e24aa
commit 36d1a8d6ea
20 changed files with 1722 additions and 586 deletions
+333 -34
View File
@@ -2,15 +2,23 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { streamAgentChat } from "@/lib/chatStream";
import { abortAgentChat, forkAgentChat, streamAgentChat } from "@/lib/chatStream";
import type { StreamEvent } from "@/lib/chatStream";
import type {
AgentArtifact,
BranchGroup,
BranchTransition,
ChatProgress,
Message,
PersistedChatState,
} from "../GlobalChatbox.types";
import { CHAT_STORAGE_KEY, createId, getInitialChatState } from "../GlobalChatbox.utils";
import {
CHAT_STORAGE_KEY,
cloneBranchGroups,
cloneMessages,
createId,
getInitialChatState,
} from "../GlobalChatbox.utils";
type UseAgentChatSessionOptions = {
onToolCall: (
@@ -23,6 +31,14 @@ type UseAgentChatSessionOptions = {
onBeforeSend?: () => void;
};
type PromptRunOptions = {
prompt: string;
sessionIdOverride?: string;
preparedMessages?: Message[];
userMessage?: Message;
assistantMessage?: Message;
};
const upsertProgress = (
progress: ChatProgress[] | undefined,
event: StreamEvent & { type: "progress" },
@@ -49,6 +65,25 @@ const completeRunningProgress = (progress: ChatProgress[] | undefined) =>
item.status === "running" ? { ...item, status: "completed" as const } : item,
);
const createUserMessage = (content: string, branchRootId?: string): Message => {
const id = createId();
return {
id,
role: "user",
content,
branchRootId: branchRootId ?? id,
};
};
const createAssistantMessage = (): Message => ({
id: createId(),
role: "assistant",
content: "",
});
const messagesEqual = (left: Message[], right: Message[]) =>
JSON.stringify(left) === JSON.stringify(right);
export const useAgentChatSession = ({
onToolCall,
onBeforeSend,
@@ -64,16 +99,65 @@ export const useAgentChatSession = ({
const [sessionId, setSessionId] = useState<string | undefined>(
initialChatStateRef.current.sessionId,
);
const [branchGroups, setBranchGroups] = useState<BranchGroup[]>(
initialChatStateRef.current.branchGroups ?? [],
);
const [branchTransition, setBranchTransition] = useState<BranchTransition | null>(null);
const [isStreaming, setIsStreaming] = useState(false);
const abortRef = useRef<AbortController | null>(null);
const sessionIdRef = useRef<string | undefined>(initialChatStateRef.current.sessionId);
const cancelPromiseRef = useRef<Promise<void> | null>(null);
useEffect(() => {
const state: PersistedChatState = { messages, sessionId };
sessionIdRef.current = sessionId;
}, [sessionId]);
useEffect(() => {
const state: PersistedChatState = { messages, sessionId, branchGroups };
try {
window.localStorage.setItem(CHAT_STORAGE_KEY, JSON.stringify(state));
} catch (error) {
console.error("[GlobalChatbox] Failed to persist chat state:", error);
}
}, [branchGroups, messages, sessionId]);
useEffect(() => {
setBranchGroups((prev) => {
let changed = false;
const next = prev.map((group) => {
const rootMessage = messages[group.parentCount];
if (
!rootMessage ||
rootMessage.role !== "user" ||
(rootMessage.branchRootId ?? rootMessage.id) !== group.rootMessageId
) {
return group;
}
const activeBranch = group.branches[group.activeIndex];
if (!activeBranch) {
return group;
}
const nextSuffix = cloneMessages(messages.slice(group.parentCount));
if (
activeBranch.sessionId === sessionId &&
messagesEqual(activeBranch.messages, nextSuffix)
) {
return group;
}
changed = true;
const branches = group.branches.map((branch, index) =>
index === group.activeIndex
? { ...branch, sessionId, messages: nextSuffix }
: branch,
);
return { ...group, branches };
});
return changed ? next : prev;
});
}, [messages, sessionId]);
const appendArtifact = useCallback((messageId: string, artifact: AgentArtifact) => {
@@ -89,21 +173,33 @@ export const useAgentChatSession = ({
);
}, []);
const sendPrompt = useCallback(
async (rawPrompt: string) => {
const runPrompt = useCallback(
async ({
prompt: rawPrompt,
sessionIdOverride,
preparedMessages,
userMessage,
assistantMessage,
}: PromptRunOptions) => {
const prompt = rawPrompt.trim();
if (!prompt || isStreaming) return;
await cancelPromiseRef.current?.catch(() => undefined);
onBeforeSend?.();
setBranchTransition(null);
const nextUserMessage = userMessage ?? createUserMessage(prompt);
const nextAssistantMessage = assistantMessage ?? createAssistantMessage();
const nextMessages =
preparedMessages ??
[...messages, nextUserMessage, nextAssistantMessage];
const userId = createId();
const assistantId = createId();
setIsStreaming(true);
setMessages((prev) => [
...prev,
{ id: userId, role: "user", content: prompt },
{ id: assistantId, role: "assistant", content: "" },
]);
setMessages(cloneMessages(nextMessages));
if (sessionIdOverride !== undefined) {
sessionIdRef.current = sessionIdOverride;
setSessionId(sessionIdOverride);
}
const controller = new AbortController();
abortRef.current = controller;
@@ -111,17 +207,18 @@ export const useAgentChatSession = ({
try {
await streamAgentChat({
message: prompt,
sessionId,
sessionId: sessionIdOverride ?? sessionIdRef.current,
signal: controller.signal,
onEvent: (event) => {
if ("sessionId" in event && !sessionId && event.sessionId) {
if ("sessionId" in event && event.sessionId && event.sessionId !== sessionIdRef.current) {
sessionIdRef.current = event.sessionId;
setSessionId(event.sessionId);
}
if (event.type === "token") {
setMessages((prev) =>
prev.map((message) =>
message.id === assistantId
message.id === nextAssistantMessage.id
? {
...message,
content: message.content + event.content,
@@ -133,20 +230,20 @@ export const useAgentChatSession = ({
} else if (event.type === "progress") {
setMessages((prev) =>
prev.map((message) =>
message.id === assistantId
message.id === nextAssistantMessage.id
? { ...message, progress: upsertProgress(message.progress, event) }
: message,
),
);
} else if (event.type === "tool_call") {
onToolCall(event, {
assistantMessageId: assistantId,
assistantMessageId: nextAssistantMessage.id,
appendArtifact,
});
} else if (event.type === "done") {
setMessages((prev) =>
prev.map((message) => {
if (message.id !== assistantId) return message;
if (message.id !== nextAssistantMessage.id) return message;
const completedProgress = completeRunningProgress(message.progress);
if (
message.content.trim().length === 0 &&
@@ -166,7 +263,7 @@ export const useAgentChatSession = ({
} else if (event.type === "error") {
setMessages((prev) =>
prev.map((message) =>
message.id === assistantId
message.id === nextAssistantMessage.id
? {
...message,
content: message.content || `⚠️ **错误:** ${event.message}`,
@@ -181,23 +278,34 @@ export const useAgentChatSession = ({
},
});
} catch (error) {
if (abortRef.current?.signal.aborted) {
if (controller.signal.aborted) {
setMessages((prev) =>
prev.filter(
(message) =>
!(
message.id === assistantId &&
message.role === "assistant" &&
message.content.trim().length === 0 &&
!(message.artifacts?.length)
),
),
prev
.map((message) =>
message.id === nextAssistantMessage.id
? {
...message,
content: message.content || "⚠️ **请求已中断**",
isError: true,
}
: message,
)
.filter(
(message) =>
!(
message.id === nextAssistantMessage.id &&
message.role === "assistant" &&
message.content.trim().length === 0 &&
!(message.artifacts?.length) &&
!(message.progress?.length)
),
),
);
return;
}
setMessages((prev) =>
prev.map((message) =>
message.id === assistantId
message.id === nextAssistantMessage.id
? {
...message,
content: `⚠️ **错误:** ${String(error)}`,
@@ -213,26 +321,217 @@ export const useAgentChatSession = ({
setIsStreaming(false);
}
},
[appendArtifact, isStreaming, onBeforeSend, onToolCall, sessionId],
[appendArtifact, isStreaming, messages, onBeforeSend, onToolCall],
);
const abort = useCallback(() => {
abortRef.current?.abort();
const controller = abortRef.current;
controller?.abort();
setIsStreaming(false);
const cancelPromise = abortAgentChat(sessionIdRef.current).catch((error) => {
console.error("[GlobalChatbox] Failed to abort agent session:", error);
});
const trackedCancelPromise = cancelPromise.finally(() => {
if (cancelPromiseRef.current === trackedCancelPromise) {
cancelPromiseRef.current = null;
}
});
cancelPromiseRef.current = trackedCancelPromise;
}, []);
const reset = useCallback(() => {
abortRef.current?.abort();
const controller = abortRef.current;
controller?.abort();
const activeSessionId = sessionIdRef.current;
if (activeSessionId) {
const cancelPromise = abortAgentChat(activeSessionId).catch((error) => {
console.error("[GlobalChatbox] Failed to abort agent session during reset:", error);
});
const trackedCancelPromise = cancelPromise.finally(() => {
if (cancelPromiseRef.current === trackedCancelPromise) {
cancelPromiseRef.current = null;
}
});
cancelPromiseRef.current = trackedCancelPromise;
}
setMessages([]);
setBranchGroups([]);
setBranchTransition(null);
setSessionId(undefined);
sessionIdRef.current = undefined;
setIsStreaming(false);
}, []);
const sendPrompt = useCallback(
async (rawPrompt: string) => {
await runPrompt({ prompt: rawPrompt });
},
[runPrompt],
);
const regenerate = useCallback(async () => {
if (isStreaming || messages.length === 0) return;
let lastUserIndex = messages.length - 1;
while (lastUserIndex >= 0 && messages[lastUserIndex].role !== "user") {
lastUserIndex--;
}
if (lastUserIndex < 0) return;
const lastUser = messages[lastUserIndex];
const lastUserContent = lastUser.content;
const nextMessages = cloneMessages(messages.slice(0, lastUserIndex));
const nextUserMessage = createUserMessage(
lastUserContent,
lastUser.branchRootId ?? lastUser.id,
);
const nextAssistantMessage = createAssistantMessage();
setMessages(nextMessages);
await runPrompt({
prompt: lastUserContent,
preparedMessages: [
...nextMessages,
nextUserMessage,
nextAssistantMessage,
],
userMessage: nextUserMessage,
assistantMessage: nextAssistantMessage,
});
}, [isStreaming, messages, runPrompt]);
const editAndResubmit = useCallback(
async (messageId: string, newContent: string) => {
if (isStreaming) return;
const trimmedContent = newContent.trim();
if (!trimmedContent) return;
const messageIndex = messages.findIndex((m) => m.id === messageId);
if (messageIndex < 0 || messages[messageIndex].role !== "user") return;
const originalMessage = messages[messageIndex];
if (trimmedContent === originalMessage.content.trim()) return;
const rootMessageId = originalMessage.branchRootId ?? originalMessage.id;
const currentSessionId = sessionIdRef.current;
const keepMessageCount = messageIndex;
const prefix = cloneMessages(messages.slice(0, messageIndex));
const originalSuffix = cloneMessages(messages.slice(messageIndex));
const forkedSessionId = await forkAgentChat(currentSessionId, keepMessageCount);
const nextUserMessage = createUserMessage(trimmedContent, rootMessageId);
const nextAssistantMessage = createAssistantMessage();
const nextSuffix = [nextUserMessage, nextAssistantMessage];
setBranchGroups((prev) => {
const next = cloneBranchGroups(prev);
const groupIndex = next.findIndex(
(group) =>
group.rootMessageId === rootMessageId && group.parentCount === messageIndex,
);
if (groupIndex >= 0) {
const group = next[groupIndex];
group.branches[group.activeIndex] = {
...group.branches[group.activeIndex],
sessionId: currentSessionId,
messages: originalSuffix,
};
group.branches.push({
id: createId(),
label: `分支 ${group.branches.length + 1}`,
sessionId: forkedSessionId,
messages: cloneMessages(nextSuffix),
});
group.activeIndex = group.branches.length - 1;
} else {
next.push({
id: rootMessageId,
rootMessageId,
parentCount: messageIndex,
activeIndex: 1,
branches: [
{
id: createId(),
label: "分支 1",
sessionId: currentSessionId,
messages: originalSuffix,
},
{
id: createId(),
label: "分支 2",
sessionId: forkedSessionId,
messages: cloneMessages(nextSuffix),
},
],
});
}
return next;
});
sessionIdRef.current = forkedSessionId;
setSessionId(forkedSessionId);
await runPrompt({
prompt: trimmedContent,
sessionIdOverride: forkedSessionId,
preparedMessages: [...prefix, ...nextSuffix],
userMessage: nextUserMessage,
assistantMessage: nextAssistantMessage,
});
},
[isStreaming, messages, runPrompt],
);
const cycleBranch = useCallback(
(rootMessageId: string, direction: -1 | 1) => {
if (isStreaming) return;
setBranchGroups((prev) => {
const next = cloneBranchGroups(prev);
const group = next.find((item) => item.rootMessageId === rootMessageId);
if (!group || group.branches.length < 2) {
return prev;
}
const nextIndex =
(group.activeIndex + direction + group.branches.length) % group.branches.length;
const selectedBranch = group.branches[nextIndex];
group.activeIndex = nextIndex;
const nextMessages = [
...cloneMessages(messages.slice(0, group.parentCount)),
...cloneMessages(selectedBranch.messages),
];
setBranchTransition({
rootMessageId,
parentCount: group.parentCount,
activeBranchId: selectedBranch.id,
nonce: Date.now(),
});
sessionIdRef.current = selectedBranch.sessionId;
setSessionId(selectedBranch.sessionId);
setMessages(nextMessages);
return next;
});
},
[isStreaming, messages],
);
return {
messages,
branchGroups,
branchTransition,
isStreaming,
sessionId,
sendPrompt,
regenerate,
editAndResubmit,
cycleBranch,
abort,
reset,
};
@@ -43,16 +43,45 @@ const LOCATE_TOOL_CONFIG: Record<
locate_tanks: { layer: "geo_tanks", geometryKind: "point", label: "水池" },
};
const LOCATE_ID_PARAM_KEYS = [
"ids",
"id",
"feature_ids",
"feature_id",
"node_ids",
"node_id",
"junction_ids",
"junction_id",
"pipe_ids",
"pipe_id",
"valve_ids",
"valve_id",
"reservoir_ids",
"reservoir_id",
"pump_ids",
"pump_id",
"tank_ids",
"tank_id",
] as const;
const normalizeIds = (params: Record<string, unknown>): string[] => {
const rawIds = params.ids;
if (Array.isArray(rawIds)) {
return rawIds.map((id) => String(id).trim()).filter(Boolean);
}
if (typeof rawIds === "string") {
return rawIds
.split(",")
.map((id) => id.trim())
.filter(Boolean);
for (const key of LOCATE_ID_PARAM_KEYS) {
const rawValue = params[key];
if (Array.isArray(rawValue)) {
const normalized = rawValue.map((id) => String(id).trim()).filter(Boolean);
if (normalized.length > 0) {
return normalized;
}
}
if (typeof rawValue === "string" || typeof rawValue === "number") {
const normalized = String(rawValue)
.split(",")
.map((id) => id.trim())
.filter(Boolean);
if (normalized.length > 0) {
return normalized;
}
}
}
return [];
};