From b23cb6acddf96d02f3946cb686921199e571267f Mon Sep 17 00:00:00 2001 From: Huarch Date: Mon, 8 Jun 2026 18:10:28 +0800 Subject: [PATCH] fix(chat): wire question and todo cards --- src/components/chat/AgentTurn.tsx | 613 ++++++++++++++++++ src/components/chat/AgentWorkspace.test.tsx | 2 + src/components/chat/AgentWorkspace.tsx | 18 +- src/components/chat/GlobalChatbox.tsx | 4 + src/components/chat/GlobalChatbox.types.ts | 7 + .../chat/hooks/useAgentChatSession.test.tsx | 368 +++++++++++ .../chat/hooks/useAgentChatSession.ts | 384 ++++++++++- src/lib/chatStream.test.ts | 100 +++ src/lib/chatStream.ts | 227 +++++++ 9 files changed, 1713 insertions(+), 10 deletions(-) diff --git a/src/components/chat/AgentTurn.tsx b/src/components/chat/AgentTurn.tsx index 9043a8d..7990d0f 100644 --- a/src/components/chat/AgentTurn.tsx +++ b/src/components/chat/AgentTurn.tsx @@ -9,12 +9,15 @@ import { Avatar, Box, Button, + Checkbox, Chip, CircularProgress, Collapse, + FormControlLabel, IconButton, Paper, Stack, + TextField, Tooltip, Typography, alpha, @@ -50,6 +53,9 @@ import BlockRounded from "@mui/icons-material/BlockRounded"; import PushPinRounded from "@mui/icons-material/PushPinRounded"; import KeyboardArrowDownRounded from "@mui/icons-material/KeyboardArrowDownRounded"; import KeyboardArrowUpRounded from "@mui/icons-material/KeyboardArrowUpRounded"; +import AssignmentTurnedInRounded from "@mui/icons-material/AssignmentTurnedInRounded"; +import HelpOutlineRounded from "@mui/icons-material/HelpOutlineRounded"; +import RadioButtonUncheckedRounded from "@mui/icons-material/RadioButtonUncheckedRounded"; import type { PermissionReply } from "@/lib/chatStream"; type AgentTurnProps = { @@ -63,6 +69,8 @@ type AgentTurnProps = { onRegenerate: (messageId: string) => void; onCreateBranch: (messageId: string) => void; onReplyPermission: (requestId: string, reply: PermissionReply) => void; + onReplyQuestion: (requestId: string, answers: string[][]) => void; + onRejectQuestion: (requestId: string) => void; }; const normalizeClipboardText = (value: string) => value.replace(/\s+$/u, ""); @@ -670,6 +678,597 @@ const PermissionRequestGroup = ({ ); }; +const getQuestionStatusLabel = ( + status: NonNullable[number]["status"], +) => { + if (status === "answered") return "已回答"; + if (status === "rejected") return "已跳过"; + if (status === "error") return "提交失败"; + if (status === "submitting") return "提交中"; + return "等待回答"; +}; + +const getQuestionStatusColor = ( + status: NonNullable[number]["status"], + theme: Theme, +) => { + if (status === "answered") return theme.palette.success.main; + if (status === "rejected") return theme.palette.text.secondary; + if (status === "error") return theme.palette.error.main; + return "#0288d1"; +}; + +const QuestionRequestCard = ({ + questionRequest, + onReply, + onReject, +}: { + questionRequest: NonNullable[number]; + onReply: (requestId: string, answers: string[][]) => void; + onReject: (requestId: string) => void; +}) => { + const theme = useTheme(); + const isEditable = + questionRequest.status === "pending" || questionRequest.status === "error"; + const isSubmitting = questionRequest.status === "submitting"; + const statusColor = getQuestionStatusColor(questionRequest.status, theme); + const [selected, setSelected] = React.useState>({}); + const [custom, setCustom] = React.useState>({}); + + const answers = React.useMemo( + () => + questionRequest.questions.map((question, index) => { + const selectedAnswers = selected[index] ?? []; + const customAnswer = custom[index]?.trim(); + return customAnswer ? [...selectedAnswers, customAnswer] : selectedAnswers; + }), + [custom, questionRequest.questions, selected], + ); + + const canSubmit = + isEditable && + questionRequest.questions.length > 0 && + questionRequest.questions.every((question, index) => { + const answer = answers[index] ?? []; + const hasInput = answer.some((item) => item.trim().length > 0); + const canAnswer = question.options.length > 0 || question.custom === true; + return canAnswer && hasInput; + }); + + const answerSummary = (questionRequest.answers ?? []) + .map((answer) => answer.join("、")) + .filter(Boolean) + .join(";"); + + return ( + + + + + + + + 需要补充信息 + + + + + + + {questionRequest.questions.map((question, index) => { + const selectedAnswers = selected[index] ?? []; + const setQuestionAnswers = (nextAnswers: string[]) => { + setSelected((current) => ({ + ...current, + [index]: nextAnswers, + })); + }; + + return ( + + + {question.header || `问题 ${index + 1}`} + + + {question.question} + + + {question.options.length ? ( + + {question.options.map((option) => { + const checked = selectedAnswers.includes(option.label); + if (question.multiple) { + return ( + { + if (event.target.checked) { + setQuestionAnswers([...selectedAnswers, option.label]); + } else { + setQuestionAnswers( + selectedAnswers.filter((item) => item !== option.label), + ); + } + }} + /> + } + label={ + + + {option.label} + + {option.description ? ( + + {option.description} + + ) : null} + + } + sx={{ alignItems: "flex-start", m: 0 }} + /> + ); + } + return ( + + ); + })} + + ) : null} + + {question.custom ? ( + + setCustom((current) => ({ + ...current, + [index]: event.target.value, + })) + } + placeholder="补充说明" + sx={{ mt: 1 }} + /> + ) : null} + + ); + })} + + {questionRequest.status === "answered" ? ( + + 已回答{answerSummary ? `:${answerSummary}` : ""} + + ) : null} + + {questionRequest.status === "rejected" ? ( + + 已跳过 + + ) : null} + + {questionRequest.error ? ( + + {questionRequest.error} + + ) : null} + + + {isEditable || isSubmitting ? ( + + + + + ) : null} + + ); +}; + +const QuestionRequestGroup = ({ + questions, + onReply, + onReject, +}: { + questions: NonNullable; + onReply: (requestId: string, answers: string[][]) => void; + onReject: (requestId: string) => void; +}) => ( + + {questions.map((question) => ( + + ))} + +); + +const TodoPlanCard = ({ + todoUpdate, +}: { + todoUpdate: NonNullable[number]; +}) => { + const theme = useTheme(); + const total = todoUpdate.todos.length; + const completed = todoUpdate.todos.filter((todo) => todo.status === "completed").length; + const running = todoUpdate.todos.find((todo) => todo.status === "in_progress"); + const cancelled = todoUpdate.todos.filter((todo) => todo.status === "cancelled").length; + const isAborted = cancelled > 0 && !running; + const [expanded, setExpanded] = React.useState( + !isAborted && todoUpdate.todos.length <= 3, + ); + React.useEffect(() => { + if (isAborted) { + setExpanded(false); + } + }, [isAborted]); + const visibleTodos = + isAborted && !expanded + ? [] + : expanded || total <= 3 + ? todoUpdate.todos + : [ + ...todoUpdate.todos.slice(0, 3), + ...(running && !todoUpdate.todos.slice(0, 3).some((todo) => todo.id === running.id) + ? [running] + : []), + ]; + + const getTodoVisual = (status: NonNullable[number]["todos"][number]["status"]) => { + if (status === "completed") { + return { icon: , color: theme.palette.success.main, label: "已完成" }; + } + if (status === "in_progress") { + return { icon: , color: "#0288d1", label: "进行中" }; + } + if (status === "cancelled") { + return { icon: , color: theme.palette.text.disabled, label: "已中止" }; + } + return { icon: , color: theme.palette.text.secondary, label: "待处理" }; + }; + + if (total === 0) { + return null; + } + + return ( + + setExpanded((value) => !value)} + onKeyDown={(event) => { + if (event.key === "Enter" || event.key === " ") { + event.preventDefault(); + setExpanded((value) => !value); + } + }} + sx={{ + px: 1.5, + py: 1.15, + cursor: "pointer", + transition: "background-color 0.2s ease", + "&:hover": { bgcolor: alpha("#000", 0.025) }, + }} + > + + + + + + 任务规划 + + + {isAborted + ? `${completed}/${total} 已完成,${cancelled} 项已中止` + : `${completed}/${total} 已完成${running ? ",1 项进行中" : ""}`} + + + + {expanded ? ( + + ) : ( + + )} + + + + {visibleTodos.length ? ( + + {visibleTodos.map((todo, index) => { + const visual = getTodoVisual(todo.status); + return ( + + + {visual.icon} + + + {todo.content} + + + + ); + })} + + ) : null} + + ); +}; + export const AgentTurn = React.memo( ({ message, @@ -682,6 +1281,8 @@ export const AgentTurn = React.memo( onRegenerate, onCreateBranch, onReplyPermission, + onReplyQuestion, + onRejectQuestion, }: AgentTurnProps) => { const theme = useTheme(); const isUser = message.role === "user"; @@ -824,6 +1425,18 @@ export const AgentTurn = React.memo( /> ) : null} + {message.questions?.length ? ( + + ) : null} + + {message.todos?.length ? ( + + ) : null} + { onRegenerate: jest.fn(), onCreateBranch: jest.fn(), onReplyPermission: jest.fn(), + onReplyQuestion: jest.fn(), + onRejectQuestion: jest.fn(), }; beforeEach(() => { diff --git a/src/components/chat/AgentWorkspace.tsx b/src/components/chat/AgentWorkspace.tsx index 37e415f..dc76f07 100644 --- a/src/components/chat/AgentWorkspace.tsx +++ b/src/components/chat/AgentWorkspace.tsx @@ -31,6 +31,8 @@ type AgentWorkspaceProps = { onRegenerate: (messageId: string) => void; onCreateBranch: (messageId: string) => void; onReplyPermission: (requestId: string, reply: PermissionReply) => void; + onReplyQuestion: (requestId: string, answers: string[][]) => void; + onRejectQuestion: (requestId: string) => void; }; type TurnListProps = { @@ -45,6 +47,8 @@ type TurnListProps = { onRegenerate: (messageId: string) => void; onCreateBranch: (messageId: string) => void; onReplyPermission: (requestId: string, reply: PermissionReply) => void; + onReplyQuestion: (requestId: string, answers: string[][]) => void; + onRejectQuestion: (requestId: string) => void; }; const sameMessages = (left: Message[], right: Message[]) => @@ -63,6 +67,8 @@ const TurnListInner = ({ onRegenerate, onCreateBranch, onReplyPermission, + onReplyQuestion, + onRejectQuestion, }: TurnListProps) => { return ( <> @@ -79,6 +85,8 @@ const TurnListInner = ({ onRegenerate={onRegenerate} onCreateBranch={onCreateBranch} onReplyPermission={onReplyPermission} + onReplyQuestion={onReplyQuestion} + onRejectQuestion={onRejectQuestion} /> ))} @@ -98,7 +106,9 @@ const TurnList = React.memo( prevProps.isTtsSupported === nextProps.isTtsSupported && prevProps.onRegenerate === nextProps.onRegenerate && prevProps.onCreateBranch === nextProps.onCreateBranch && - prevProps.onReplyPermission === nextProps.onReplyPermission, + prevProps.onReplyPermission === nextProps.onReplyPermission && + prevProps.onReplyQuestion === nextProps.onReplyQuestion && + prevProps.onRejectQuestion === nextProps.onRejectQuestion, ); TurnList.displayName = "TurnList"; @@ -231,6 +241,8 @@ export const AgentWorkspace = ({ onRegenerate, onCreateBranch, onReplyPermission, + onReplyQuestion, + onRejectQuestion, }: AgentWorkspaceProps) => { const theme = useTheme(); const latestAssistant = [...messages] @@ -278,6 +290,8 @@ export const AgentWorkspace = ({ onRegenerate={onRegenerate} onCreateBranch={onCreateBranch} onReplyPermission={onReplyPermission} + onReplyQuestion={onReplyQuestion} + onRejectQuestion={onRejectQuestion} /> {streamingMessage ? ( @@ -293,6 +307,8 @@ export const AgentWorkspace = ({ onRegenerate={onRegenerate} onCreateBranch={onCreateBranch} onReplyPermission={onReplyPermission} + onReplyQuestion={onReplyQuestion} + onRejectQuestion={onRejectQuestion} /> ) : null} diff --git a/src/components/chat/GlobalChatbox.tsx b/src/components/chat/GlobalChatbox.tsx index d1d1192..b73f62f 100644 --- a/src/components/chat/GlobalChatbox.tsx +++ b/src/components/chat/GlobalChatbox.tsx @@ -75,6 +75,8 @@ export const GlobalChatbox: React.FC = ({ open, onClose }) => { createBranch, abort, replyPermission, + replyQuestion, + rejectQuestion, createSession, renameSession, removeSession, @@ -353,6 +355,8 @@ export const GlobalChatbox: React.FC = ({ open, onClose }) => { onRegenerate={regenerate} onCreateBranch={createBranch} onReplyPermission={replyPermission} + onReplyQuestion={replyQuestion} + onRejectQuestion={rejectQuestion} /> ({ abortAgentChat: jest.fn(async () => undefined), forkAgentChat: jest.fn(async () => "forked-session"), replyAgentPermission: jest.fn(async () => undefined), + replyAgentQuestion: jest.fn(async () => undefined), resumeAgentChatStream: jest.fn(async () => undefined), streamAgentChat: jest.fn(async () => undefined), })); @@ -53,11 +55,13 @@ describe("useAgentChatSession", () => { jest.mocked(abortAgentChat).mockReset(); jest.mocked(forkAgentChat).mockReset(); jest.mocked(replyAgentPermission).mockReset(); + jest.mocked(replyAgentQuestion).mockReset(); jest.mocked(resumeAgentChatStream).mockReset(); jest.mocked(streamAgentChat).mockReset(); jest.mocked(abortAgentChat).mockImplementation(async () => undefined); jest.mocked(forkAgentChat).mockImplementation(async () => "forked-session"); jest.mocked(replyAgentPermission).mockImplementation(async () => undefined); + jest.mocked(replyAgentQuestion).mockImplementation(async () => undefined); jest.mocked(resumeAgentChatStream).mockImplementation(async () => undefined); jest.mocked(streamAgentChat).mockImplementation(async () => undefined); deleteChatSession.mockImplementation(async () => undefined); @@ -333,6 +337,337 @@ describe("useAgentChatSession", () => { ]); }); + it("applies question responses to the message that owns the request", async () => { + listChatSessions.mockResolvedValue([ + { + id: "session-streaming", + title: "运行中", + createdAt: 1, + updatedAt: 2, + isStreaming: true, + }, + ]); + jest.mocked(resumeAgentChatStream).mockImplementationOnce(async ({ onEvent }) => { + onEvent({ + type: "state", + sessionId: "session-loaded", + messages: [ + { id: "u1", role: "user", content: "继续分析" }, + { + id: "a1", + role: "assistant", + content: "需要确认", + questions: [ + { + requestId: "q-1", + sessionId: "session-loaded", + questions: [ + { + header: "范围", + question: "选择范围", + options: [], + custom: true, + }, + ], + createdAt: 123, + status: "pending", + }, + ], + }, + { id: "a2", role: "assistant", content: "后续消息" }, + ], + isStreaming: true, + runStatus: "running", + }); + onEvent({ + type: "question_response", + sessionId: "session-loaded", + requestId: "q-1", + answers: [["城区"]], + rejected: false, + }); + }); + + const { result } = renderHook(() => + useAgentChatSession({ + projectId: "project-1", + onToolCall: jest.fn(), + }), + ); + + await waitFor(() => expect(result.current.isHydrating).toBe(false)); + + expect(result.current.messages[1].questions?.[0]).toEqual( + expect.objectContaining({ + requestId: "q-1", + status: "answered", + answers: [["城区"]], + }), + ); + expect(result.current.messages[2].questions).toBeUndefined(); + }); + + it("deduplicates question requests across assistant messages", async () => { + listChatSessions.mockResolvedValue([ + { + id: "session-streaming", + title: "运行中", + createdAt: 1, + updatedAt: 2, + isStreaming: true, + }, + ]); + jest.mocked(resumeAgentChatStream).mockImplementationOnce(async ({ onEvent }) => { + onEvent({ + type: "state", + sessionId: "session-loaded", + messages: [ + { id: "u1", role: "user", content: "继续分析" }, + { + id: "a1", + role: "assistant", + content: "需要确认", + questions: [ + { + requestId: "question-1", + sessionId: "session-loaded", + questions: [ + { + header: "测试问题", + question: "你觉得这个 question 工具好用吗?", + options: [ + { + label: "非常好用", + description: "交互清晰,选项方便", + }, + ], + }, + ], + tool: { + messageID: "message-1", + callID: "call-1", + }, + createdAt: 123, + status: "pending", + }, + ], + }, + { id: "a2", role: "assistant", content: "后续消息" }, + ], + isStreaming: true, + runStatus: "running", + }); + onEvent({ + type: "question_request", + sessionId: "session-loaded", + requestId: "call-1", + questions: [ + { + header: "测试问题", + question: "你觉得这个 question 工具好用吗?", + options: [ + { + label: "非常好用", + description: "交互清晰,选项方便", + }, + ], + }, + ], + tool: { + messageID: "message-1", + callID: "call-1", + }, + createdAt: 456, + }); + }); + + const { result } = renderHook(() => + useAgentChatSession({ + projectId: "project-1", + onToolCall: jest.fn(), + }), + ); + + await waitFor(() => expect(result.current.isHydrating).toBe(false)); + + const allQuestions = result.current.messages.flatMap( + (message) => message.questions ?? [], + ); + expect(allQuestions).toHaveLength(1); + expect(result.current.messages[1].questions?.[0]).toEqual( + expect.objectContaining({ + requestId: "question-1", + tool: expect.objectContaining({ callID: "call-1" }), + }), + ); + expect(result.current.messages[2].questions).toBeUndefined(); + }); + + it("keeps the actionable question request id when a tool-part duplicate arrives later", async () => { + listChatSessions.mockResolvedValue([ + { + id: "session-streaming", + title: "运行中", + createdAt: 1, + updatedAt: 2, + isStreaming: true, + }, + ]); + jest.mocked(resumeAgentChatStream).mockImplementationOnce(async ({ onEvent }) => { + onEvent({ + type: "state", + sessionId: "session-loaded", + messages: [ + { id: "u1", role: "user", content: "继续分析" }, + { + id: "a1", + role: "assistant", + content: "需要确认", + questions: [ + { + requestId: "question-1", + sessionId: "session-loaded", + questions: [ + { + header: "测试问题", + question: "你觉得这个 question 工具好用吗?", + options: [ + { + label: "非常好用", + description: "交互清晰,选项方便", + }, + ], + }, + ], + tool: { + messageID: "message-1", + callID: "call-1", + }, + createdAt: 123, + status: "pending", + }, + ], + }, + ], + isStreaming: true, + runStatus: "running", + }); + onEvent({ + type: "question_request", + sessionId: "session-loaded", + requestId: "call-1", + questions: [ + { + header: "测试问题", + question: "你觉得这个 question 工具好用吗?", + options: [ + { + label: "非常好用", + description: "交互清晰,选项方便", + }, + ], + }, + ], + tool: { + messageID: "message-1", + callID: "call-1", + }, + createdAt: 456, + }); + }); + + const { result } = renderHook(() => + useAgentChatSession({ + projectId: "project-1", + onToolCall: jest.fn(), + }), + ); + + await waitFor(() => expect(result.current.isHydrating).toBe(false)); + + const allQuestions = result.current.messages.flatMap( + (message) => message.questions ?? [], + ); + expect(allQuestions).toHaveLength(1); + expect(allQuestions[0]).toEqual( + expect.objectContaining({ + requestId: "question-1", + tool: expect.objectContaining({ callID: "call-1" }), + }), + ); + }); + + it("deduplicates persisted duplicate questions from state events", async () => { + listChatSessions.mockResolvedValue([ + { + id: "session-streaming", + title: "运行中", + createdAt: 1, + updatedAt: 2, + isStreaming: true, + }, + ]); + const duplicateQuestion = { + sessionId: "session-loaded", + questions: [ + { + header: "测试问题", + question: "你觉得这个 question 工具好用吗?", + options: [ + { + label: "非常好用", + description: "交互清晰,选项方便", + }, + ], + }, + ], + tool: { + messageID: "message-1", + callID: "call-1", + }, + createdAt: 123, + status: "pending" as const, + }; + jest.mocked(resumeAgentChatStream).mockImplementationOnce(async ({ onEvent }) => { + onEvent({ + type: "state", + sessionId: "session-loaded", + messages: [ + { id: "u1", role: "user", content: "继续分析" }, + { + id: "a1", + role: "assistant", + content: "需要确认", + questions: [{ ...duplicateQuestion, requestId: "question-1" }], + }, + { + id: "a2", + role: "assistant", + content: "后续消息", + questions: [{ ...duplicateQuestion, requestId: "call-1" }], + }, + ], + isStreaming: true, + runStatus: "running", + }); + }); + + const { result } = renderHook(() => + useAgentChatSession({ + projectId: "project-1", + onToolCall: jest.fn(), + }), + ); + + await waitFor(() => expect(result.current.isHydrating).toBe(false)); + + expect( + result.current.messages.flatMap((message) => message.questions ?? []), + ).toHaveLength(1); + expect(result.current.messages[1].questions).toHaveLength(1); + expect(result.current.messages[2].questions).toBeUndefined(); + }); + it("aborts a resumed streaming session through the backend abort endpoint", async () => { listChatSessions.mockResolvedValue([ { @@ -433,6 +768,23 @@ describe("useAgentChatSession", () => { title: "开始分析", startedAt: 1000, } satisfies StreamEvent); + onEvent({ + type: "todo_update", + sessionId: "session-1", + todos: [ + { + id: "todo-1", + content: "分析水位", + status: "in_progress", + }, + { + id: "todo-2", + content: "生成建议", + status: "pending", + }, + ], + createdAt: 1001, + } satisfies StreamEvent); signal?.addEventListener("abort", () => { reject(new Error("aborted")); @@ -474,6 +826,22 @@ describe("useAgentChatSession", () => { endedAt: expect.any(Number), }), ], + todos: [ + expect.objectContaining({ + todos: [ + expect.objectContaining({ + id: "todo-1", + status: "cancelled", + updatedAt: expect.any(Number), + }), + expect.objectContaining({ + id: "todo-2", + status: "cancelled", + updatedAt: expect.any(Number), + }), + ], + }), + ], }), ); expect(abortAgentChat).toHaveBeenCalledWith("session-1"); diff --git a/src/components/chat/hooks/useAgentChatSession.ts b/src/components/chat/hooks/useAgentChatSession.ts index 99dab87..19e0231 100644 --- a/src/components/chat/hooks/useAgentChatSession.ts +++ b/src/components/chat/hooks/useAgentChatSession.ts @@ -5,13 +5,17 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { abortAgentChat, forkAgentChat, + rejectAgentQuestion, replyAgentPermission, + replyAgentQuestion, resumeAgentChatStream, streamAgentChat, } from "@/lib/chatStream"; import type { AgentApprovalMode, AgentModel, + AgentQuestionRequest, + AgentTodoUpdate, PermissionReply, StreamEvent, } from "@/lib/chatStream"; @@ -135,6 +139,20 @@ const completeRunningProgress = (progress: ChatProgress[] | undefined) => }; }); +const cancelRunningTodos = (todos: AgentTodoUpdate[] | undefined) => + todos?.map((todoUpdate) => ({ + ...todoUpdate, + todos: todoUpdate.todos.map((todo) => + todo.status === "pending" || todo.status === "in_progress" + ? { + ...todo, + status: "cancelled" as const, + updatedAt: Date.now(), + } + : todo, + ), + })); + const upsertPermission = ( permissions: AgentPermissionRequest[] | undefined, event: StreamEvent & { type: "permission_request" }, @@ -170,12 +188,192 @@ const toPermissionStatus = (reply: PermissionReply): AgentPermissionRequest["sta return "rejected"; }; +const isActionableQuestionRequest = (question: { + requestId: string; + tool?: AgentQuestionRequest["tool"]; +}) => Boolean(question.requestId && question.requestId !== question.tool?.callID); + +const toQuestionRequest = ( + event: StreamEvent & { type: "question_request" }, + status: AgentQuestionRequest["status"] = "pending", +): AgentQuestionRequest => ({ + requestId: event.requestId, + sessionId: event.sessionId, + questions: event.questions, + tool: event.tool, + createdAt: event.createdAt, + status, +}); + +const getQuestionContentSignature = ( + questions: AgentQuestionRequest["questions"], +) => + JSON.stringify( + questions.map((question) => ({ + header: question.header, + question: question.question, + options: question.options.map((option) => ({ + label: option.label, + description: option.description, + })), + multiple: question.multiple ?? false, + custom: question.custom ?? false, + })), + ); + +const isSameQuestionRequest = ( + question: AgentQuestionRequest, + event: StreamEvent & { type: "question_request" }, +) => { + if (question.requestId === event.requestId) return true; + if (question.tool?.callID && event.tool?.callID) { + return question.tool.callID === event.tool.callID; + } + return ( + question.status === "pending" && + question.sessionId === event.sessionId && + getQuestionContentSignature(question.questions) === + getQuestionContentSignature(event.questions) + ); +}; + +const isSameQuestionPair = ( + left: AgentQuestionRequest, + right: AgentQuestionRequest, +) => { + if (left.requestId === right.requestId) return true; + if (left.tool?.callID && right.tool?.callID) { + return left.tool.callID === right.tool.callID; + } + return ( + left.status === "pending" && + right.status === "pending" && + left.sessionId === right.sessionId && + getQuestionContentSignature(left.questions) === + getQuestionContentSignature(right.questions) + ); +}; + +const dedupeQuestionsAcrossMessages = (messages: Message[]) => { + const seen: AgentQuestionRequest[] = []; + let changed = false; + const nextMessages = messages.map((message) => { + if (!message.questions?.length) { + return message; + } + const nextQuestions = message.questions.filter((question) => { + if (seen.some((existing) => isSameQuestionPair(existing, question))) { + changed = true; + return false; + } + seen.push(question); + return true; + }); + if (nextQuestions.length === message.questions.length) { + return message; + } + return { + ...message, + questions: nextQuestions.length ? nextQuestions : undefined, + }; + }); + return changed ? nextMessages : messages; +}; + +const upsertQuestionAcrossMessages = ( + messages: Message[], + event: StreamEvent & { type: "question_request" }, + assistantMessageId: string, +) => { + let existing: AgentQuestionRequest | undefined; + for (const message of messages) { + const match = message.questions?.find((question) => + isSameQuestionRequest(question, event), + ); + if (match) { + existing = match; + break; + } + } + + const existingStatus: AgentQuestionRequest["status"] | undefined = + existing?.status === "submitting" ? "submitting" : undefined; + const nextQuestion = + existing && + isActionableQuestionRequest(existing) && + !isActionableQuestionRequest(event) + ? { + ...existing, + sessionId: event.sessionId, + questions: event.questions, + tool: event.tool ?? existing.tool, + createdAt: event.createdAt, + status: existingStatus ?? existing.status, + } + : toQuestionRequest(event, existingStatus ?? "pending"); + const targetMessageId = existing + ? messages.find((message) => + message.questions?.some((question) => isSameQuestionRequest(question, event)), + )?.id ?? assistantMessageId + : assistantMessageId; + + return messages.map((message) => { + const filteredQuestions = message.questions?.filter( + (question) => !isSameQuestionRequest(question, event), + ); + if (message.id !== targetMessageId) { + return filteredQuestions?.length === message.questions?.length + ? message + : { + ...message, + questions: filteredQuestions?.length ? filteredQuestions : undefined, + }; + } + + const nextQuestions = [...(filteredQuestions ?? []), nextQuestion]; + return { + ...message, + questions: nextQuestions, + }; + }); +}; + +const applyQuestionResponse = ( + questions: AgentQuestionRequest[] | undefined, + event: StreamEvent & { type: "question_response" }, +) => + (questions ?? []).map((question) => + question.requestId === event.requestId + ? { + ...question, + status: event.rejected ? "rejected" as const : "answered" as const, + answers: event.answers ?? question.answers, + repliedAt: Date.now(), + error: undefined, + } + : question, + ); + +const upsertTodoUpdate = ( + todos: AgentTodoUpdate[] | undefined, + event: StreamEvent & { type: "todo_update" }, +) => [ + { + sessionId: event.sessionId, + messageId: event.messageId, + todos: event.todos, + createdAt: event.createdAt, + }, +]; + const finalizeAssistantMessageAfterAbort = (message: Message): Message => { const completedProgress = completeRunningProgress(message.progress); + const cancelledTodos = cancelRunningTodos(message.todos); const hasVisibleOutput = message.content.trim().length > 0 || Boolean(message.artifacts?.length) || - Boolean(completedProgress?.length); + Boolean(completedProgress?.length) || + Boolean(cancelledTodos?.length); if (!hasVisibleOutput) { return message; @@ -186,6 +384,7 @@ const finalizeAssistantMessageAfterAbort = (message: Message): Message => { content: message.content || "⚠️ **请求已中断**", isError: true, progress: completedProgress, + todos: cancelledTodos, }; }; @@ -291,7 +490,7 @@ export const useAgentChatSession = ({ hydrationNonceRef.current += 1; titleUpdateNonceRef.current += 1; - setMessages(loadedState.messages); + setMessages(dedupeQuestionsAcrossMessages(loadedState.messages)); setSessionTitle(loadedState.title); setIsSessionTitleManuallyEdited(loadedState.isTitleManuallyEdited ?? false); setSessionId(loadedState.sessionId); @@ -401,7 +600,9 @@ export const useAgentChatSession = ({ } if (event.type === "state") { - const nextMessages = cloneMessages(event.messages as Message[]); + const nextMessages = dedupeQuestionsAcrossMessages( + cloneMessages(event.messages as Message[]), + ); messagesRef.current = nextMessages; setMessages(nextMessages); setIsStreaming(event.isStreaming); @@ -502,6 +703,32 @@ export const useAgentChatSession = ({ }; }), ); + } else if (event.type === "question_request") { + setMessages((prev) => + upsertQuestionAcrossMessages(prev, event, assistantMessageId), + ); + } else if (event.type === "question_response") { + setMessages((prev) => + prev.map((message) => + message.questions?.some((question) => question.requestId === event.requestId) + ? { + ...message, + questions: applyQuestionResponse(message.questions, event), + } + : message, + ), + ); + } else if (event.type === "todo_update") { + setMessages((prev) => + prev.map((message) => + message.id === assistantMessageId + ? { + ...message, + todos: upsertTodoUpdate(message.todos, event), + } + : message, + ), + ); } else if (event.type === "done") { setMessages((prev) => prev.map((message) => { @@ -531,6 +758,7 @@ export const useAgentChatSession = ({ content: message.content || `⚠️ **错误:** ${event.message}`, isError: true, progress: completeRunningProgress(message.progress), + todos: cancelRunningTodos(message.todos), } : message, ), @@ -621,11 +849,7 @@ export const useAgentChatSession = ({ prev .map((message) => message.id === nextAssistantMessage.id - ? { - ...message, - content: message.content || "⚠️ **请求已中断**", - isError: true, - } + ? finalizeAssistantMessageAfterAbort(message) : message, ) .filter( @@ -635,7 +859,8 @@ export const useAgentChatSession = ({ message.role === "assistant" && message.content.trim().length === 0 && !(message.artifacts?.length) && - !(message.progress?.length) + !(message.progress?.length) && + !(message.todos?.length) ), ), ); @@ -766,6 +991,145 @@ export const useAgentChatSession = ({ [], ); + const replyQuestion = useCallback( + async (requestId: string, answers: string[][]) => { + const target = messagesRef.current + .flatMap((message) => message.questions ?? []) + .find((question) => question.requestId === requestId); + if (!target || target.status === "submitting") { + return; + } + + setMessages((prev) => + prev.map((message) => + !message.questions?.some((question) => question.requestId === requestId) + ? message + : { + ...message, + questions: message.questions.map((question) => + question.requestId === requestId + ? { ...question, status: "submitting", error: undefined } + : question, + ), + }, + ), + ); + + try { + await replyAgentQuestion(target.sessionId, requestId, answers); + setMessages((prev) => + prev.map((message) => + !message.questions?.some((question) => question.requestId === requestId) + ? message + : { + ...message, + questions: message.questions.map((question) => + question.requestId === requestId + ? { + ...question, + status: "answered", + answers, + repliedAt: Date.now(), + error: undefined, + } + : question, + ), + }, + ), + ); + } catch (error) { + setMessages((prev) => + prev.map((message) => + !message.questions?.some((question) => question.requestId === requestId) + ? message + : { + ...message, + questions: message.questions.map((question) => + question.requestId === requestId + ? { + ...question, + status: "error", + error: error instanceof Error ? error.message : String(error), + } + : question, + ), + }, + ), + ); + } + }, + [], + ); + + const rejectQuestion = useCallback( + async (requestId: string) => { + const target = messagesRef.current + .flatMap((message) => message.questions ?? []) + .find((question) => question.requestId === requestId); + if (!target || target.status === "submitting") { + return; + } + + setMessages((prev) => + prev.map((message) => + !message.questions?.some((question) => question.requestId === requestId) + ? message + : { + ...message, + questions: message.questions.map((question) => + question.requestId === requestId + ? { ...question, status: "submitting", error: undefined } + : question, + ), + }, + ), + ); + + try { + await rejectAgentQuestion(target.sessionId, requestId); + setMessages((prev) => + prev.map((message) => + !message.questions?.some((question) => question.requestId === requestId) + ? message + : { + ...message, + questions: message.questions.map((question) => + question.requestId === requestId + ? { + ...question, + status: "rejected", + repliedAt: Date.now(), + error: undefined, + } + : question, + ), + }, + ), + ); + } catch (error) { + setMessages((prev) => + prev.map((message) => + !message.questions?.some((question) => question.requestId === requestId) + ? message + : { + ...message, + questions: message.questions.map((question) => + question.requestId === requestId + ? { + ...question, + status: "error", + error: error instanceof Error ? error.message : String(error), + } + : question, + ), + }, + ), + ); + } + }, + [], + ); + const createSession = useCallback(() => { if (isHydrating || isStreaming) return; @@ -1009,6 +1373,8 @@ export const useAgentChatSession = ({ createBranch, abort, replyPermission, + replyQuestion, + rejectQuestion, createSession, renameSession, removeSession, diff --git a/src/lib/chatStream.test.ts b/src/lib/chatStream.test.ts index 5477c02..fc71e5c 100644 --- a/src/lib/chatStream.test.ts +++ b/src/lib/chatStream.test.ts @@ -1,7 +1,9 @@ import { abortAgentChat, forkAgentChat, + rejectAgentQuestion, replyAgentPermission, + replyAgentQuestion, type StreamEvent, resumeAgentChatStream, streamAgentChat, @@ -218,6 +220,69 @@ describe("streamAgentChat", () => { ]); }); + it("parses question request, response, and todo update events", async () => { + apiFetch.mockResolvedValue({ + ok: true, + body: makeStream([ + 'event: question_request\ndata: {"session_id":"s1","request_id":"q-1","questions":[{"header":"范围","question":"选择范围","options":[{"label":"城区","description":"中心城区"}],"multiple":false,"custom":true}],"tool":{"message_id":"m1","call_id":"c1"},"created_at":123}\n\n', + 'event: question_response\ndata: {"session_id":"s1","request_id":"q-1","answers":[["城区","补充说明"]]}\n\n', + 'event: todo_update\ndata: {"session_id":"s1","todos":[{"id":"t1","content":"分析水位","status":"in_progress","priority":"high","updated_at":456}],"created_at":456}\n\n', + ]), + }); + + const events: StreamEvent[] = []; + + await streamAgentChat({ + message: "hi", + onEvent: (event) => events.push(event), + }); + + expect(events).toEqual([ + { + type: "question_request", + sessionId: "s1", + requestId: "q-1", + questions: [ + { + header: "范围", + question: "选择范围", + options: [{ label: "城区", description: "中心城区" }], + multiple: false, + custom: true, + }, + ], + tool: { + messageID: "m1", + callID: "c1", + }, + createdAt: 123, + }, + { + type: "question_response", + sessionId: "s1", + requestId: "q-1", + answers: [["城区", "补充说明"]], + rejected: false, + }, + { + type: "todo_update", + sessionId: "s1", + messageId: undefined, + todos: [ + { + id: "t1", + content: "分析水位", + status: "in_progress", + priority: "high", + createdAt: undefined, + updatedAt: 456, + }, + ], + createdAt: 456, + }, + ]); + }); + it("emits error when response is not ok", async () => { apiFetch.mockResolvedValue({ ok: false, @@ -314,6 +379,41 @@ describe("streamAgentChat", () => { ); }); + it("calls question reply and reject endpoints", async () => { + apiFetch.mockResolvedValue({ + ok: true, + status: 202, + text: async () => "", + }); + + await replyAgentQuestion("s1", "q-1", [["城区"]]); + await rejectAgentQuestion("s1", "q-2"); + + expect(apiFetch).toHaveBeenCalledWith( + expect.stringContaining("/api/v1/agent/chat/question/q-1/reply"), + expect.objectContaining({ + method: "POST", + projectHeaderMode: "include", + skipAuthRedirect: true, + body: JSON.stringify({ + session_id: "s1", + answers: [["城区"]], + }), + }), + ); + expect(apiFetch).toHaveBeenCalledWith( + expect.stringContaining("/api/v1/agent/chat/question/q-2/reject"), + expect.objectContaining({ + method: "POST", + projectHeaderMode: "include", + skipAuthRedirect: true, + body: JSON.stringify({ + session_id: "s1", + }), + }), + ); + }); + it("calls fork endpoint and returns new session id", async () => { apiFetch.mockResolvedValue({ ok: true, diff --git a/src/lib/chatStream.ts b/src/lib/chatStream.ts index 88f6948..b5c1e37 100644 --- a/src/lib/chatStream.ts +++ b/src/lib/chatStream.ts @@ -8,6 +8,53 @@ export type AgentModel = export type PermissionReply = "once" | "always" | "reject"; export type AgentApprovalMode = "request" | "always"; +export type AgentQuestionStatus = + | "pending" + | "submitting" + | "answered" + | "rejected" + | "error"; + +export type AgentQuestionRequest = { + requestId: string; + sessionId: string; + questions: Array<{ + header: string; + question: string; + options: Array<{ + label: string; + description: string; + }>; + multiple?: boolean; + custom?: boolean; + }>; + tool?: { + messageID: string; + callID: string; + }; + createdAt: number; + repliedAt?: number; + status: AgentQuestionStatus; + answers?: string[][]; + error?: string; +}; + +export type AgentTodoItem = { + id: string; + content: string; + status: "pending" | "in_progress" | "completed" | "cancelled"; + priority?: "low" | "medium" | "high"; + createdAt?: number; + updatedAt?: number; +}; + +export type AgentTodoUpdate = { + sessionId: string; + messageId?: string; + todos: AgentTodoItem[]; + createdAt: number; +}; + export type StreamEvent = | { type: "state"; @@ -64,6 +111,28 @@ export type StreamEvent = sessionId: string; requestId: string; reply: PermissionReply; + } + | { + type: "question_request"; + sessionId: string; + requestId: string; + questions: AgentQuestionRequest["questions"]; + tool?: AgentQuestionRequest["tool"]; + createdAt: number; + } + | { + type: "question_response"; + sessionId: string; + requestId: string; + answers?: string[][]; + rejected?: boolean; + } + | { + type: "todo_update"; + sessionId: string; + messageId?: string; + todos: AgentTodoItem[]; + createdAt: number; }; type StreamOptions = { @@ -125,6 +194,80 @@ const resolveToolParams = ( return isObjectRecord(params) ? params : {}; }; +const normalizeQuestionList = (value: unknown): AgentQuestionRequest["questions"] => { + if (!Array.isArray(value)) return []; + return value + .filter(isObjectRecord) + .map((question) => ({ + header: typeof question.header === "string" ? question.header : "", + question: typeof question.question === "string" ? question.question : "", + options: Array.isArray(question.options) + ? question.options.filter(isObjectRecord).map((option) => ({ + label: typeof option.label === "string" ? option.label : "", + description: + typeof option.description === "string" ? option.description : "", + })) + : [], + multiple: typeof question.multiple === "boolean" ? question.multiple : undefined, + custom: typeof question.custom === "boolean" ? question.custom : undefined, + })); +}; + +const normalizeAnswers = (value: unknown): string[][] | undefined => { + if (!Array.isArray(value)) return undefined; + return value.map((answer) => + Array.isArray(answer) + ? answer.filter((item): item is string => typeof item === "string") + : [], + ); +}; + +const normalizeQuestionTool = (value: unknown): AgentQuestionRequest["tool"] => { + if (!isObjectRecord(value)) return undefined; + const messageID = + typeof value.messageID === "string" + ? value.messageID + : typeof value.message_id === "string" + ? value.message_id + : undefined; + const callID = + typeof value.callID === "string" + ? value.callID + : typeof value.call_id === "string" + ? value.call_id + : undefined; + return messageID && callID ? { messageID, callID } : undefined; +}; + +const normalizeTodoStatus = (value: unknown): AgentTodoItem["status"] => { + if (value === "in_progress" || value === "completed" || value === "cancelled") { + return value; + } + return "pending"; +}; + +const normalizeTodoPriority = (value: unknown): AgentTodoItem["priority"] => { + if (value === "low" || value === "medium" || value === "high") { + return value; + } + return undefined; +}; + +const normalizeTodos = (value: unknown): AgentTodoItem[] => { + if (!Array.isArray(value)) return []; + return value.filter(isObjectRecord).map((todo, index) => ({ + id: + typeof todo.id === "string" && todo.id.trim() + ? todo.id + : `todo-${index}`, + content: typeof todo.content === "string" ? todo.content : "", + status: normalizeTodoStatus(todo.status), + priority: normalizeTodoPriority(todo.priority), + createdAt: typeof todo.created_at === "number" ? todo.created_at : undefined, + updatedAt: typeof todo.updated_at === "number" ? todo.updated_at : undefined, + })); +}; + const emitParsedStreamEvent = ( event: string, data: string, @@ -158,6 +301,11 @@ const emitParsedStreamEvent = ( always?: unknown; created_at?: number; reply?: PermissionReply; + questions?: unknown; + answers?: unknown; + rejected?: boolean; + message_id?: string; + todos?: unknown; }; if (event === "state") { onEvent({ @@ -244,6 +392,31 @@ const emitParsedStreamEvent = ( requestId: parsed.request_id ?? "", reply: parsed.reply ?? "reject", }); + } else if (event === "question_request") { + onEvent({ + type: "question_request", + sessionId: parsed.session_id ?? "", + requestId: parsed.request_id ?? "", + questions: normalizeQuestionList(parsed.questions), + tool: normalizeQuestionTool(parsed.tool), + createdAt: parsed.created_at ?? Date.now(), + }); + } else if (event === "question_response") { + onEvent({ + type: "question_response", + sessionId: parsed.session_id ?? "", + requestId: parsed.request_id ?? "", + answers: normalizeAnswers(parsed.answers), + rejected: parsed.rejected === true, + }); + } else if (event === "todo_update") { + onEvent({ + type: "todo_update", + sessionId: parsed.session_id ?? "", + messageId: parsed.message_id, + todos: normalizeTodos(parsed.todos), + createdAt: parsed.created_at ?? Date.now(), + }); } } catch { onEvent({ @@ -443,6 +616,60 @@ export const replyAgentPermission = async ( } }; +export const replyAgentQuestion = async ( + sessionId: string, + requestId: string, + answers: string[][], +) => { + const response = await apiFetch( + `${config.AGENT_URL}/api/v1/agent/chat/question/${encodeURIComponent(requestId)}/reply`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + session_id: sessionId, + answers, + }), + projectHeaderMode: "include", + userHeaderMode: "include", + skipAuthRedirect: true, + }, + ); + + if (!response.ok) { + const detail = await response.text(); + throw new Error(detail || `question reply failed: ${response.status}`); + } +}; + +export const rejectAgentQuestion = async ( + sessionId: string, + requestId: string, +) => { + const response = await apiFetch( + `${config.AGENT_URL}/api/v1/agent/chat/question/${encodeURIComponent(requestId)}/reject`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + session_id: sessionId, + }), + projectHeaderMode: "include", + userHeaderMode: "include", + skipAuthRedirect: true, + }, + ); + + if (!response.ok) { + const detail = await response.text(); + throw new Error(detail || `question reject failed: ${response.status}`); + } +}; + export const forkAgentChat = async (sessionId: string | undefined, keepMessageCount: number) => { const response = await apiFetch(`${config.AGENT_URL}/api/v1/agent/chat/fork`, { method: "POST",