重构 Agent 聊天,支持分支管理与消息克隆
This commit is contained in:
@@ -2,15 +2,23 @@
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
|
||||
import { streamAgentChat } from "@/lib/chatStream";
|
||||
import { abortAgentChat, forkAgentChat, streamAgentChat } from "@/lib/chatStream";
|
||||
import type { StreamEvent } from "@/lib/chatStream";
|
||||
import type {
|
||||
AgentArtifact,
|
||||
BranchGroup,
|
||||
BranchTransition,
|
||||
ChatProgress,
|
||||
Message,
|
||||
PersistedChatState,
|
||||
} from "../GlobalChatbox.types";
|
||||
import { CHAT_STORAGE_KEY, createId, getInitialChatState } from "../GlobalChatbox.utils";
|
||||
import {
|
||||
CHAT_STORAGE_KEY,
|
||||
cloneBranchGroups,
|
||||
cloneMessages,
|
||||
createId,
|
||||
getInitialChatState,
|
||||
} from "../GlobalChatbox.utils";
|
||||
|
||||
type UseAgentChatSessionOptions = {
|
||||
onToolCall: (
|
||||
@@ -23,6 +31,14 @@ type UseAgentChatSessionOptions = {
|
||||
onBeforeSend?: () => void;
|
||||
};
|
||||
|
||||
type PromptRunOptions = {
|
||||
prompt: string;
|
||||
sessionIdOverride?: string;
|
||||
preparedMessages?: Message[];
|
||||
userMessage?: Message;
|
||||
assistantMessage?: Message;
|
||||
};
|
||||
|
||||
const upsertProgress = (
|
||||
progress: ChatProgress[] | undefined,
|
||||
event: StreamEvent & { type: "progress" },
|
||||
@@ -49,6 +65,25 @@ const completeRunningProgress = (progress: ChatProgress[] | undefined) =>
|
||||
item.status === "running" ? { ...item, status: "completed" as const } : item,
|
||||
);
|
||||
|
||||
const createUserMessage = (content: string, branchRootId?: string): Message => {
|
||||
const id = createId();
|
||||
return {
|
||||
id,
|
||||
role: "user",
|
||||
content,
|
||||
branchRootId: branchRootId ?? id,
|
||||
};
|
||||
};
|
||||
|
||||
const createAssistantMessage = (): Message => ({
|
||||
id: createId(),
|
||||
role: "assistant",
|
||||
content: "",
|
||||
});
|
||||
|
||||
const messagesEqual = (left: Message[], right: Message[]) =>
|
||||
JSON.stringify(left) === JSON.stringify(right);
|
||||
|
||||
export const useAgentChatSession = ({
|
||||
onToolCall,
|
||||
onBeforeSend,
|
||||
@@ -64,16 +99,65 @@ export const useAgentChatSession = ({
|
||||
const [sessionId, setSessionId] = useState<string | undefined>(
|
||||
initialChatStateRef.current.sessionId,
|
||||
);
|
||||
const [branchGroups, setBranchGroups] = useState<BranchGroup[]>(
|
||||
initialChatStateRef.current.branchGroups ?? [],
|
||||
);
|
||||
const [branchTransition, setBranchTransition] = useState<BranchTransition | null>(null);
|
||||
const [isStreaming, setIsStreaming] = useState(false);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const sessionIdRef = useRef<string | undefined>(initialChatStateRef.current.sessionId);
|
||||
const cancelPromiseRef = useRef<Promise<void> | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const state: PersistedChatState = { messages, sessionId };
|
||||
sessionIdRef.current = sessionId;
|
||||
}, [sessionId]);
|
||||
|
||||
useEffect(() => {
|
||||
const state: PersistedChatState = { messages, sessionId, branchGroups };
|
||||
try {
|
||||
window.localStorage.setItem(CHAT_STORAGE_KEY, JSON.stringify(state));
|
||||
} catch (error) {
|
||||
console.error("[GlobalChatbox] Failed to persist chat state:", error);
|
||||
}
|
||||
}, [branchGroups, messages, sessionId]);
|
||||
|
||||
useEffect(() => {
|
||||
setBranchGroups((prev) => {
|
||||
let changed = false;
|
||||
const next = prev.map((group) => {
|
||||
const rootMessage = messages[group.parentCount];
|
||||
if (
|
||||
!rootMessage ||
|
||||
rootMessage.role !== "user" ||
|
||||
(rootMessage.branchRootId ?? rootMessage.id) !== group.rootMessageId
|
||||
) {
|
||||
return group;
|
||||
}
|
||||
|
||||
const activeBranch = group.branches[group.activeIndex];
|
||||
if (!activeBranch) {
|
||||
return group;
|
||||
}
|
||||
|
||||
const nextSuffix = cloneMessages(messages.slice(group.parentCount));
|
||||
if (
|
||||
activeBranch.sessionId === sessionId &&
|
||||
messagesEqual(activeBranch.messages, nextSuffix)
|
||||
) {
|
||||
return group;
|
||||
}
|
||||
|
||||
changed = true;
|
||||
const branches = group.branches.map((branch, index) =>
|
||||
index === group.activeIndex
|
||||
? { ...branch, sessionId, messages: nextSuffix }
|
||||
: branch,
|
||||
);
|
||||
return { ...group, branches };
|
||||
});
|
||||
|
||||
return changed ? next : prev;
|
||||
});
|
||||
}, [messages, sessionId]);
|
||||
|
||||
const appendArtifact = useCallback((messageId: string, artifact: AgentArtifact) => {
|
||||
@@ -89,21 +173,33 @@ export const useAgentChatSession = ({
|
||||
);
|
||||
}, []);
|
||||
|
||||
const sendPrompt = useCallback(
|
||||
async (rawPrompt: string) => {
|
||||
const runPrompt = useCallback(
|
||||
async ({
|
||||
prompt: rawPrompt,
|
||||
sessionIdOverride,
|
||||
preparedMessages,
|
||||
userMessage,
|
||||
assistantMessage,
|
||||
}: PromptRunOptions) => {
|
||||
const prompt = rawPrompt.trim();
|
||||
if (!prompt || isStreaming) return;
|
||||
|
||||
await cancelPromiseRef.current?.catch(() => undefined);
|
||||
onBeforeSend?.();
|
||||
setBranchTransition(null);
|
||||
|
||||
const nextUserMessage = userMessage ?? createUserMessage(prompt);
|
||||
const nextAssistantMessage = assistantMessage ?? createAssistantMessage();
|
||||
const nextMessages =
|
||||
preparedMessages ??
|
||||
[...messages, nextUserMessage, nextAssistantMessage];
|
||||
|
||||
const userId = createId();
|
||||
const assistantId = createId();
|
||||
setIsStreaming(true);
|
||||
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{ id: userId, role: "user", content: prompt },
|
||||
{ id: assistantId, role: "assistant", content: "" },
|
||||
]);
|
||||
setMessages(cloneMessages(nextMessages));
|
||||
if (sessionIdOverride !== undefined) {
|
||||
sessionIdRef.current = sessionIdOverride;
|
||||
setSessionId(sessionIdOverride);
|
||||
}
|
||||
|
||||
const controller = new AbortController();
|
||||
abortRef.current = controller;
|
||||
@@ -111,17 +207,18 @@ export const useAgentChatSession = ({
|
||||
try {
|
||||
await streamAgentChat({
|
||||
message: prompt,
|
||||
sessionId,
|
||||
sessionId: sessionIdOverride ?? sessionIdRef.current,
|
||||
signal: controller.signal,
|
||||
onEvent: (event) => {
|
||||
if ("sessionId" in event && !sessionId && event.sessionId) {
|
||||
if ("sessionId" in event && event.sessionId && event.sessionId !== sessionIdRef.current) {
|
||||
sessionIdRef.current = event.sessionId;
|
||||
setSessionId(event.sessionId);
|
||||
}
|
||||
|
||||
if (event.type === "token") {
|
||||
setMessages((prev) =>
|
||||
prev.map((message) =>
|
||||
message.id === assistantId
|
||||
message.id === nextAssistantMessage.id
|
||||
? {
|
||||
...message,
|
||||
content: message.content + event.content,
|
||||
@@ -133,20 +230,20 @@ export const useAgentChatSession = ({
|
||||
} else if (event.type === "progress") {
|
||||
setMessages((prev) =>
|
||||
prev.map((message) =>
|
||||
message.id === assistantId
|
||||
message.id === nextAssistantMessage.id
|
||||
? { ...message, progress: upsertProgress(message.progress, event) }
|
||||
: message,
|
||||
),
|
||||
);
|
||||
} else if (event.type === "tool_call") {
|
||||
onToolCall(event, {
|
||||
assistantMessageId: assistantId,
|
||||
assistantMessageId: nextAssistantMessage.id,
|
||||
appendArtifact,
|
||||
});
|
||||
} else if (event.type === "done") {
|
||||
setMessages((prev) =>
|
||||
prev.map((message) => {
|
||||
if (message.id !== assistantId) return message;
|
||||
if (message.id !== nextAssistantMessage.id) return message;
|
||||
const completedProgress = completeRunningProgress(message.progress);
|
||||
if (
|
||||
message.content.trim().length === 0 &&
|
||||
@@ -166,7 +263,7 @@ export const useAgentChatSession = ({
|
||||
} else if (event.type === "error") {
|
||||
setMessages((prev) =>
|
||||
prev.map((message) =>
|
||||
message.id === assistantId
|
||||
message.id === nextAssistantMessage.id
|
||||
? {
|
||||
...message,
|
||||
content: message.content || `⚠️ **错误:** ${event.message}`,
|
||||
@@ -181,23 +278,34 @@ export const useAgentChatSession = ({
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
if (abortRef.current?.signal.aborted) {
|
||||
if (controller.signal.aborted) {
|
||||
setMessages((prev) =>
|
||||
prev.filter(
|
||||
(message) =>
|
||||
!(
|
||||
message.id === assistantId &&
|
||||
message.role === "assistant" &&
|
||||
message.content.trim().length === 0 &&
|
||||
!(message.artifacts?.length)
|
||||
),
|
||||
),
|
||||
prev
|
||||
.map((message) =>
|
||||
message.id === nextAssistantMessage.id
|
||||
? {
|
||||
...message,
|
||||
content: message.content || "⚠️ **请求已中断**",
|
||||
isError: true,
|
||||
}
|
||||
: message,
|
||||
)
|
||||
.filter(
|
||||
(message) =>
|
||||
!(
|
||||
message.id === nextAssistantMessage.id &&
|
||||
message.role === "assistant" &&
|
||||
message.content.trim().length === 0 &&
|
||||
!(message.artifacts?.length) &&
|
||||
!(message.progress?.length)
|
||||
),
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
setMessages((prev) =>
|
||||
prev.map((message) =>
|
||||
message.id === assistantId
|
||||
message.id === nextAssistantMessage.id
|
||||
? {
|
||||
...message,
|
||||
content: `⚠️ **错误:** ${String(error)}`,
|
||||
@@ -213,26 +321,217 @@ export const useAgentChatSession = ({
|
||||
setIsStreaming(false);
|
||||
}
|
||||
},
|
||||
[appendArtifact, isStreaming, onBeforeSend, onToolCall, sessionId],
|
||||
[appendArtifact, isStreaming, messages, onBeforeSend, onToolCall],
|
||||
);
|
||||
|
||||
const abort = useCallback(() => {
|
||||
abortRef.current?.abort();
|
||||
const controller = abortRef.current;
|
||||
controller?.abort();
|
||||
setIsStreaming(false);
|
||||
|
||||
const cancelPromise = abortAgentChat(sessionIdRef.current).catch((error) => {
|
||||
console.error("[GlobalChatbox] Failed to abort agent session:", error);
|
||||
});
|
||||
const trackedCancelPromise = cancelPromise.finally(() => {
|
||||
if (cancelPromiseRef.current === trackedCancelPromise) {
|
||||
cancelPromiseRef.current = null;
|
||||
}
|
||||
});
|
||||
cancelPromiseRef.current = trackedCancelPromise;
|
||||
}, []);
|
||||
|
||||
const reset = useCallback(() => {
|
||||
abortRef.current?.abort();
|
||||
const controller = abortRef.current;
|
||||
controller?.abort();
|
||||
const activeSessionId = sessionIdRef.current;
|
||||
if (activeSessionId) {
|
||||
const cancelPromise = abortAgentChat(activeSessionId).catch((error) => {
|
||||
console.error("[GlobalChatbox] Failed to abort agent session during reset:", error);
|
||||
});
|
||||
const trackedCancelPromise = cancelPromise.finally(() => {
|
||||
if (cancelPromiseRef.current === trackedCancelPromise) {
|
||||
cancelPromiseRef.current = null;
|
||||
}
|
||||
});
|
||||
cancelPromiseRef.current = trackedCancelPromise;
|
||||
}
|
||||
setMessages([]);
|
||||
setBranchGroups([]);
|
||||
setBranchTransition(null);
|
||||
setSessionId(undefined);
|
||||
sessionIdRef.current = undefined;
|
||||
setIsStreaming(false);
|
||||
}, []);
|
||||
|
||||
const sendPrompt = useCallback(
|
||||
async (rawPrompt: string) => {
|
||||
await runPrompt({ prompt: rawPrompt });
|
||||
},
|
||||
[runPrompt],
|
||||
);
|
||||
|
||||
const regenerate = useCallback(async () => {
|
||||
if (isStreaming || messages.length === 0) return;
|
||||
|
||||
let lastUserIndex = messages.length - 1;
|
||||
while (lastUserIndex >= 0 && messages[lastUserIndex].role !== "user") {
|
||||
lastUserIndex--;
|
||||
}
|
||||
|
||||
if (lastUserIndex < 0) return;
|
||||
|
||||
const lastUser = messages[lastUserIndex];
|
||||
const lastUserContent = lastUser.content;
|
||||
const nextMessages = cloneMessages(messages.slice(0, lastUserIndex));
|
||||
const nextUserMessage = createUserMessage(
|
||||
lastUserContent,
|
||||
lastUser.branchRootId ?? lastUser.id,
|
||||
);
|
||||
const nextAssistantMessage = createAssistantMessage();
|
||||
|
||||
setMessages(nextMessages);
|
||||
await runPrompt({
|
||||
prompt: lastUserContent,
|
||||
preparedMessages: [
|
||||
...nextMessages,
|
||||
nextUserMessage,
|
||||
nextAssistantMessage,
|
||||
],
|
||||
userMessage: nextUserMessage,
|
||||
assistantMessage: nextAssistantMessage,
|
||||
});
|
||||
}, [isStreaming, messages, runPrompt]);
|
||||
|
||||
const editAndResubmit = useCallback(
|
||||
async (messageId: string, newContent: string) => {
|
||||
if (isStreaming) return;
|
||||
|
||||
const trimmedContent = newContent.trim();
|
||||
if (!trimmedContent) return;
|
||||
|
||||
const messageIndex = messages.findIndex((m) => m.id === messageId);
|
||||
if (messageIndex < 0 || messages[messageIndex].role !== "user") return;
|
||||
|
||||
const originalMessage = messages[messageIndex];
|
||||
if (trimmedContent === originalMessage.content.trim()) return;
|
||||
|
||||
const rootMessageId = originalMessage.branchRootId ?? originalMessage.id;
|
||||
const currentSessionId = sessionIdRef.current;
|
||||
const keepMessageCount = messageIndex;
|
||||
const prefix = cloneMessages(messages.slice(0, messageIndex));
|
||||
const originalSuffix = cloneMessages(messages.slice(messageIndex));
|
||||
const forkedSessionId = await forkAgentChat(currentSessionId, keepMessageCount);
|
||||
|
||||
const nextUserMessage = createUserMessage(trimmedContent, rootMessageId);
|
||||
const nextAssistantMessage = createAssistantMessage();
|
||||
const nextSuffix = [nextUserMessage, nextAssistantMessage];
|
||||
|
||||
setBranchGroups((prev) => {
|
||||
const next = cloneBranchGroups(prev);
|
||||
const groupIndex = next.findIndex(
|
||||
(group) =>
|
||||
group.rootMessageId === rootMessageId && group.parentCount === messageIndex,
|
||||
);
|
||||
|
||||
if (groupIndex >= 0) {
|
||||
const group = next[groupIndex];
|
||||
group.branches[group.activeIndex] = {
|
||||
...group.branches[group.activeIndex],
|
||||
sessionId: currentSessionId,
|
||||
messages: originalSuffix,
|
||||
};
|
||||
group.branches.push({
|
||||
id: createId(),
|
||||
label: `分支 ${group.branches.length + 1}`,
|
||||
sessionId: forkedSessionId,
|
||||
messages: cloneMessages(nextSuffix),
|
||||
});
|
||||
group.activeIndex = group.branches.length - 1;
|
||||
} else {
|
||||
next.push({
|
||||
id: rootMessageId,
|
||||
rootMessageId,
|
||||
parentCount: messageIndex,
|
||||
activeIndex: 1,
|
||||
branches: [
|
||||
{
|
||||
id: createId(),
|
||||
label: "分支 1",
|
||||
sessionId: currentSessionId,
|
||||
messages: originalSuffix,
|
||||
},
|
||||
{
|
||||
id: createId(),
|
||||
label: "分支 2",
|
||||
sessionId: forkedSessionId,
|
||||
messages: cloneMessages(nextSuffix),
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
|
||||
return next;
|
||||
});
|
||||
|
||||
sessionIdRef.current = forkedSessionId;
|
||||
setSessionId(forkedSessionId);
|
||||
await runPrompt({
|
||||
prompt: trimmedContent,
|
||||
sessionIdOverride: forkedSessionId,
|
||||
preparedMessages: [...prefix, ...nextSuffix],
|
||||
userMessage: nextUserMessage,
|
||||
assistantMessage: nextAssistantMessage,
|
||||
});
|
||||
},
|
||||
[isStreaming, messages, runPrompt],
|
||||
);
|
||||
|
||||
const cycleBranch = useCallback(
|
||||
(rootMessageId: string, direction: -1 | 1) => {
|
||||
if (isStreaming) return;
|
||||
|
||||
setBranchGroups((prev) => {
|
||||
const next = cloneBranchGroups(prev);
|
||||
const group = next.find((item) => item.rootMessageId === rootMessageId);
|
||||
if (!group || group.branches.length < 2) {
|
||||
return prev;
|
||||
}
|
||||
|
||||
const nextIndex =
|
||||
(group.activeIndex + direction + group.branches.length) % group.branches.length;
|
||||
const selectedBranch = group.branches[nextIndex];
|
||||
group.activeIndex = nextIndex;
|
||||
|
||||
const nextMessages = [
|
||||
...cloneMessages(messages.slice(0, group.parentCount)),
|
||||
...cloneMessages(selectedBranch.messages),
|
||||
];
|
||||
setBranchTransition({
|
||||
rootMessageId,
|
||||
parentCount: group.parentCount,
|
||||
activeBranchId: selectedBranch.id,
|
||||
nonce: Date.now(),
|
||||
});
|
||||
sessionIdRef.current = selectedBranch.sessionId;
|
||||
setSessionId(selectedBranch.sessionId);
|
||||
setMessages(nextMessages);
|
||||
|
||||
return next;
|
||||
});
|
||||
},
|
||||
[isStreaming, messages],
|
||||
);
|
||||
|
||||
return {
|
||||
messages,
|
||||
branchGroups,
|
||||
branchTransition,
|
||||
isStreaming,
|
||||
sessionId,
|
||||
sendPrompt,
|
||||
regenerate,
|
||||
editAndResubmit,
|
||||
cycleBranch,
|
||||
abort,
|
||||
reset,
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user