diff --git a/src/routes/chat.ts b/src/routes/chat.ts index 08164ae..74cb94e 100644 --- a/src/routes/chat.ts +++ b/src/routes/chat.ts @@ -53,10 +53,139 @@ const sessionStateSchema = z.object({ branch_groups: z.array(z.unknown()).default([]), }); +type RunStatus = "running" | "completed" | "error" | "aborted"; + +type StreamSubscriber = { + write: (event: string, data: Record) => void; + close: () => void; +}; + +type ActiveRun = { + clientSessionId: string; + controller: AbortController; + messages: unknown[]; + status: RunStatus; + subscribers: Set; +}; + +const activeRuns = new Map(); +const lastRunStatuses = new Map(); + const toSessionUiStateContext = (sessionRecord: SessionRecord) => ({ sessionId: sessionRecord.sessionId, }); +const getSessionRunStatus = (sessionId: string) => + activeRuns.get(sessionId)?.status ?? lastRunStatuses.get(sessionId); + +const isObjectRecord = (value: unknown): value is Record => + typeof value === "object" && value !== null && !Array.isArray(value); + +const createFrontendMessageId = () => + `msg-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 8)}`; + +const createInitialStreamingMessages = (existingMessages: unknown[], userContent: string) => { + const userMessage = { + id: createFrontendMessageId(), + role: "user", + content: userContent, + }; + return [ + ...existingMessages, + { + ...userMessage, + branchRootId: userMessage.id, + }, + { + id: createFrontendMessageId(), + role: "assistant", + content: "", + progress: [ + { + id: "request-received", + phase: "start", + status: "running", + title: "已收到请求,正在启动 Agent 分析", + detail: "已接收用户消息,正在建立会话并准备进入分析、规划和工具调用阶段。", + startedAt: Date.now(), + elapsedMs: 0, + elapsedSnapshotAt: Date.now(), + }, + ], + }, + ]; +}; + +const upsertBackendProgress = ( + progress: unknown, + payload: Record, +) => { + const next = Array.isArray(progress) ? [...progress] : []; + const id = typeof payload.id === "string" ? payload.id : `progress-${Date.now()}`; + const index = next.findIndex((item) => isObjectRecord(item) && item.id === id); + const nextItem = { + id, + phase: typeof payload.phase === "string" ? payload.phase : "progress", + status: + payload.status === "completed" || payload.status === "error" + ? payload.status + : "running", + title: typeof payload.title === "string" ? payload.title : "正在处理", + detail: typeof payload.detail === "string" ? payload.detail : undefined, + startedAt: typeof payload.started_at === "number" ? payload.started_at : undefined, + endedAt: typeof payload.ended_at === "number" ? payload.ended_at : undefined, + elapsedMs: typeof payload.elapsed_ms === "number" ? payload.elapsed_ms : undefined, + elapsedSnapshotAt: + typeof payload.elapsed_ms === "number" ? Date.now() : undefined, + durationMs: typeof payload.duration_ms === "number" ? payload.duration_ms : undefined, + }; + if (index >= 0) { + next[index] = nextItem; + } else { + next.push(nextItem); + } + return next; +}; + +const completeBackendProgress = (progress: unknown) => + Array.isArray(progress) + ? progress.map((item) => { + if (!isObjectRecord(item) || item.status !== "running") { + return item; + } + const endedAt = Date.now(); + const startedAt = typeof item.startedAt === "number" ? item.startedAt : undefined; + return { + ...item, + status: "completed", + endedAt, + elapsedMs: undefined, + elapsedSnapshotAt: undefined, + durationMs: + typeof item.durationMs === "number" + ? item.durationMs + : startedAt !== undefined + ? Math.max(0, endedAt - startedAt) + : item.elapsedMs, + }; + }) + : progress; + +const updateLastAssistantMessage = ( + messages: unknown[], + updater: (message: Record) => Record, +) => { + for (let index = messages.length - 1; index >= 0; index -= 1) { + const message = messages[index]; + if (isObjectRecord(message) && message.role === "assistant") { + const next = [...messages]; + next[index] = updater(message); + return next; + } + } + return messages; +}; + export const buildChatRouter = ( sessionBridge: ChatSessionBridge, runtime: OpencodeRuntimeAdapter, @@ -124,6 +253,8 @@ export const buildChatRouter = ( updated_at: record.updatedAt, status: record.status, parent_session_id: record.parentSessionId, + is_streaming: activeRuns.get(record.sessionId)?.status === "running", + run_status: getSessionRunStatus(record.sessionId), })), }); }); @@ -167,9 +298,75 @@ export const buildChatRouter = ( messages: state?.messages ?? [], branch_groups: state?.branchGroups ?? [], parent_session_id: sessionRecord.parentSessionId, + is_streaming: activeRuns.get(sessionRecord.sessionId)?.status === "running", + run_status: getSessionRunStatus(sessionRecord.sessionId), }); }); + chatRouter.get("/session/:sessionId/stream", async (req, res) => { + const sessionId = req.params.sessionId?.trim(); + const projectId = req.header("x-project-id") ?? undefined; + const userId = req.header("x-user-id") ?? undefined; + const actorKey = toActorKey(userId); + const projectKey = toProjectKey(projectId); + if (!sessionId) { + res.status(400).json({ message: "session_id is required" }); + return; + } + + const sessionRecord = await sessionMetadataStore.get( + { actorKey, projectId, projectKey, userId }, + sessionId, + ); + if (!sessionRecord) { + res.status(404).json({ message: "session not found" }); + return; + } + + res.status(200); + res.setHeader("Content-Type", "text/event-stream; charset=utf-8"); + res.setHeader("Cache-Control", "no-cache"); + res.setHeader("Connection", "keep-alive"); + res.setHeader("X-Accel-Buffering", "no"); + res.flushHeaders?.(); + + const run = activeRuns.get(sessionRecord.sessionId); + const state = await sessionUiStateStore.read(toSessionUiStateContext(sessionRecord)); + res.write( + toSse("state", { + session_id: sessionRecord.sessionId, + messages: state?.messages ?? run?.messages ?? [], + is_streaming: run?.status === "running", + run_status: getSessionRunStatus(sessionRecord.sessionId) ?? "completed", + }), + ); + + if (!run || run.status !== "running") { + res.end(); + return; + } + + const subscriber: StreamSubscriber = { + write: (event, data) => { + if (!res.writableEnded && !res.destroyed) { + res.write(toSse(event, data)); + } + }, + close: () => { + if (!res.writableEnded && !res.destroyed) { + res.end(); + } + }, + }; + run.subscribers.add(subscriber); + + const cleanup = () => { + run.subscribers.delete(subscriber); + }; + req.on("close", cleanup); + res.on("close", cleanup); + }); + chatRouter.put("/session/:sessionId", async (req, res) => { const sessionId = req.params.sessionId?.trim(); const parsed = sessionStateSchema.safeParse(req.body ?? {}); @@ -304,6 +501,8 @@ export const buildChatRouter = ( clientSessionId: sessionRecord.sessionId, sessionId: sessionRecord.sessionId, }); + activeRuns.delete(sessionRecord.sessionId); + lastRunStatuses.delete(sessionRecord.sessionId); await sessionMetadataStore.remove(sessionRecord); res.status(204).end(); }); @@ -376,8 +575,42 @@ export const buildChatRouter = ( sessionId: sessionRecord.sessionId, }) : null; + const run = activeRuns.get(parsed.data.session_id); + if (run && run.status === "running") { + run.status = "aborted"; + lastRunStatuses.set(parsed.data.session_id, "aborted"); + run.controller.abort(); + run.messages = updateLastAssistantMessage(run.messages, (message) => ({ + ...message, + content: + typeof message.content === "string" && message.content.trim() + ? message.content + : "⚠️ **请求已中断**", + isError: true, + progress: completeBackendProgress(message.progress), + })); + if (sessionRecord) { + const currentState = await sessionUiStateStore.read( + toSessionUiStateContext(sessionRecord), + ); + await sessionUiStateStore.write(toSessionUiStateContext(sessionRecord), { + sessionId: sessionRecord.sessionId, + isTitleManuallyEdited: currentState?.isTitleManuallyEdited ?? false, + messages: run.messages, + branchGroups: currentState?.branchGroups ?? [], + }); + } + for (const subscriber of run.subscribers) { + subscriber.write("error", { + session_id: parsed.data.session_id, + message: "请求已中断", + }); + subscriber.close(); + } + run.subscribers.clear(); + } - if (!binding) { + if (!binding && !run) { res.status(204).end(); return; } @@ -385,7 +618,7 @@ export const buildChatRouter = ( logger.info( { clientSessionId: parsed.data.session_id, - sessionId: binding.sessionId, + sessionId: binding?.sessionId ?? parsed.data.session_id, }, "aborted chat session by client request", ); @@ -545,6 +778,13 @@ export const buildChatRouter = ( const initialSessionState = await sessionUiStateStore.read( toSessionUiStateContext(activeSessionRecord), ); + if (activeRuns.get(activeSessionRecord.sessionId)?.status === "running") { + res.status(409).json({ + message: "session is already streaming", + session_id: activeSessionRecord.sessionId, + }); + return; + } logger.info( { @@ -569,16 +809,95 @@ export const buildChatRouter = ( let streamClosed = false; const abortController = new AbortController(); sessionBridge.registerAbortController(clientSessionId, abortController); + const initialMessages = createInitialStreamingMessages( + initialSessionState?.messages ?? [], + parsed.data.message, + ); + const branchGroups = initialSessionState?.branchGroups ?? []; + const activeRun: ActiveRun = { + clientSessionId, + controller: abortController, + messages: initialMessages, + status: "running", + subscribers: new Set(), + }; + activeRuns.set(clientSessionId, activeRun); + lastRunStatuses.set(clientSessionId, "running"); + await sessionUiStateStore.write(toSessionUiStateContext(activeSessionRecord), { + sessionId: activeSessionRecord.sessionId, + isTitleManuallyEdited: initialSessionState?.isTitleManuallyEdited ?? false, + messages: initialMessages, + branchGroups, + }); + const primarySubscriber: StreamSubscriber = { + write: (event, data) => { + if (!streamClosed && !res.writableEnded && !res.destroyed) { + res.write(toSse(event, data)); + } + }, + close: () => { + if (!res.writableEnded && !res.destroyed) { + res.end(); + } + }, + }; + activeRun.subscribers.add(primarySubscriber); const handleClientClose = () => { - if (streamClosed || abortController.signal.aborted) { - return; - } - abortController.abort(); + streamClosed = true; + activeRun.subscribers.delete(primarySubscriber); }; req.on("close", handleClientClose); res.on("close", handleClientClose); + const publish = async (event: string, data: Record) => { + if (event === "token") { + activeRun.messages = updateLastAssistantMessage(activeRun.messages, (message) => ({ + ...message, + content: `${typeof message.content === "string" ? message.content : ""}${typeof data.content === "string" ? data.content : ""}`, + isError: false, + })); + } else if (event === "progress") { + activeRun.messages = updateLastAssistantMessage(activeRun.messages, (message) => ({ + ...message, + progress: upsertBackendProgress(message.progress, data), + })); + } else if (event === "done") { + activeRun.status = "completed"; + lastRunStatuses.set(clientSessionId, "completed"); + activeRun.messages = updateLastAssistantMessage(activeRun.messages, (message) => ({ + ...message, + content: + typeof message.content === "string" && message.content.trim() + ? message.content + : "Agent 已完成处理,但没有生成文本回答。请查看过程记录,或换个更具体的问题重试。", + progress: completeBackendProgress(message.progress), + })); + } else if (event === "error") { + activeRun.status = activeRun.status === "aborted" ? "aborted" : "error"; + lastRunStatuses.set(clientSessionId, activeRun.status); + activeRun.messages = updateLastAssistantMessage(activeRun.messages, (message) => ({ + ...message, + content: + typeof message.content === "string" && message.content.trim() + ? message.content + : `⚠️ **错误:** ${typeof data.message === "string" ? data.message : "unknown error"}`, + isError: true, + progress: completeBackendProgress(message.progress), + })); + } + + await sessionUiStateStore.write(toSessionUiStateContext(activeSessionRecord), { + sessionId: activeSessionRecord.sessionId, + isTitleManuallyEdited: initialSessionState?.isTitleManuallyEdited ?? false, + messages: activeRun.messages, + branchGroups, + }); + for (const subscriber of activeRun.subscribers) { + subscriber.write(event, data); + } + }; + try { const preparedMessage = await buildPromptWithLearningContext( memoryStore, @@ -601,10 +920,9 @@ export const buildChatRouter = ( projectId: requestContext.projectId, signal: abortController.signal, write: (event, data) => { - if (streamClosed || res.writableEnded || res.destroyed) { - return; - } - res.write(toSse(event, data)); + void publish(event, data).catch((error) => { + logger.warn({ err: error, sessionId: clientSessionId }, "failed to publish chat stream event"); + }); }, }); @@ -642,23 +960,32 @@ export const buildChatRouter = ( ? { title: sessionTitle } : {}), }); - if (!streamClosed && !res.writableEnded && !res.destroyed) { - if ( - shouldGenerateTitle && - sessionTitle && - sessionTitle !== existingSessionTitle - ) { - res.write( - toSse("session_title", { - session_id: clientSessionId, - title: sessionTitle, - }), - ); - } + if ( + shouldGenerateTitle && + sessionTitle && + sessionTitle !== existingSessionTitle + ) { + await publish("session_title", { + session_id: clientSessionId, + title: sessionTitle, + }); } } } finally { sessionBridge.finalizeRequest(clientSessionId); + activeRun.status = abortController.signal.aborted + ? activeRun.status === "aborted" + ? "aborted" + : "aborted" + : activeRun.status === "running" + ? "completed" + : activeRun.status; + lastRunStatuses.set(clientSessionId, activeRun.status); + for (const subscriber of activeRun.subscribers) { + subscriber.close(); + } + activeRun.subscribers.clear(); + activeRuns.delete(clientSessionId); streamClosed = true; req.off("close", handleClientClose); res.off("close", handleClientClose);