From fa2c28c1c0af32c553c717926fafe8a5d5adad6e Mon Sep 17 00:00:00 2001 From: Huarch Date: Wed, 10 Jun 2026 19:29:42 +0800 Subject: [PATCH] refactor(chat): centralize session persistence --- src/routes/chat.ts | 128 ++++++++++++------------------- src/routes/chatUiState.ts | 62 +++++++++++++++ tests/routes/chatSession.test.ts | 79 +++++++++++++++++++ tests/routes/chatUiState.test.ts | 35 +++++++++ 4 files changed, 227 insertions(+), 77 deletions(-) diff --git a/src/routes/chat.ts b/src/routes/chat.ts index 7f2f1f8..a70d016 100644 --- a/src/routes/chat.ts +++ b/src/routes/chat.ts @@ -36,6 +36,7 @@ import { type ActiveRun, type RunStatus, type StreamSubscriber, + appendBackendToolArtifact, cancelBackendTodos, completeBackendProgress, createInitialStreamingMessages, @@ -67,12 +68,6 @@ const forkPayloadSchema = z.object({ keep_message_count: z.coerce.number().int().min(0), }); -const sessionStateSchema = z.object({ - title: z.string().max(120).optional(), - is_title_manually_edited: z.boolean().optional(), - messages: z.array(z.unknown()).default([]), -}); - const activeRuns = new Map(); const lastRunStatuses = new Map(); @@ -80,6 +75,20 @@ const toSessionUiStateContext = (sessionRecord: SessionRecord) => ({ sessionId: sessionRecord.sessionId, }); +export const buildForkedSessionUiState = ( + sourceState: { messages?: unknown[] } | null | undefined, + input: { + keepMessageCount: number; + targetSessionId: string; + }, +) => ({ + sessionId: input.targetSessionId, + isTitleManuallyEdited: false, + messages: Array.isArray(sourceState?.messages) + ? sourceState.messages.slice(0, input.keepMessageCount) + : [], +}); + const getSessionRunStatus = (sessionId: string) => activeRuns.get(sessionId)?.status ?? lastRunStatuses.get(sessionId); @@ -274,71 +283,6 @@ export const buildChatRouter = ( res.on("close", cleanup); }); - chatRouter.put("/session/:sessionId", async (req, res) => { - const sessionId = req.params.sessionId?.trim(); - const parsed = sessionStateSchema.safeParse(req.body ?? {}); - if (!parsed.success) { - res.status(400).json({ - message: "invalid request payload", - detail: parsed.error.flatten(), - }); - return; - } - - const projectId = req.header("x-project-id") ?? undefined; - const userId = req.header("x-user-id") ?? undefined; - const actorKey = toActorKey(userId); - const projectKey = toProjectKey(projectId); - if (!sessionId) { - res.status(400).json({ message: "session_id is required" }); - return; - } - - const { record } = await sessionMetadataStore.ensure({ - actorKey, - projectId, - projectKey, - sessionId, - userId, - }); - const nextRecord = await sessionMetadataStore.touch(record, { - ...(parsed.data.title ? { title: parsed.data.title } : {}), - }); - await sessionUiStateStore.write(toSessionUiStateContext(nextRecord), { - sessionId: nextRecord.sessionId, - isTitleManuallyEdited: parsed.data.is_title_manually_edited, - messages: parsed.data.messages, - }); - const latestTurn = extractLatestFrontendTurn(parsed.data.messages); - if (latestTurn) { - void learningOrchestrator.onTurnCompleted({ - ...latestTurn, - requestContext: { - actorKey, - clientSessionId: nextRecord.sessionId, - projectId, - projectKey, - traceId: req.header("x-trace-id") ?? `save-${nextRecord.sessionId}`, - userId, - }, - sessionId: nextRecord.sessionId, - }).catch((error) => { - logger.warn( - { err: error, sessionId: nextRecord.sessionId }, - "post-save learning failed", - ); - }); - } - res.json({ - id: nextRecord.sessionId, - title: nextRecord.title ?? "新对话", - created_at: nextRecord.createdAt, - updated_at: nextRecord.updatedAt, - status: nextRecord.status, - session_id: nextRecord.sessionId, - }); - }); - chatRouter.patch("/session/:sessionId/title", async (req, res) => { const sessionId = req.params.sessionId?.trim(); const title = @@ -469,7 +413,7 @@ export const buildChatRouter = ( }); const nextSessionId = targetSessionRecord.sessionId; - if (sourceSessionId && parsed.data.keep_message_count > 0) { + if (sourceSessionId) { await sessionTranscriptStore.cloneThread( { actorKey, @@ -485,12 +429,24 @@ export const buildChatRouter = ( }, parsed.data.keep_message_count, ); - if (sourceSessionRecord?.title) { - await sessionMetadataStore.touch(targetSessionRecord, { - title: sourceSessionRecord.title, - }); - } } + const sourceState = sourceSessionRecord + ? await sessionUiStateStore.read(toSessionUiStateContext(sourceSessionRecord)) + : null; + const forkTitle = sourceSessionRecord?.title + ? `${sourceSessionRecord.title} 副本` + : "新对话副本"; + const titledTargetSessionRecord = await sessionMetadataStore.touch( + targetSessionRecord, + { title: forkTitle }, + ); + await sessionUiStateStore.write( + toSessionUiStateContext(titledTargetSessionRecord), + buildForkedSessionUiState(sourceState, { + keepMessageCount: parsed.data.keep_message_count, + targetSessionId: nextSessionId, + }), + ); logger.info( { @@ -789,6 +745,11 @@ export const buildChatRouter = ( ...message, todos: upsertBackendTodoUpdate(message.todos, payload), })); + } else if (event === "tool_call") { + activeRun.messages = updateLastAssistantMessage(activeRun.messages, (message) => ({ + ...message, + artifacts: appendBackendToolArtifact(message.artifacts, data), + })); } for (const subscriber of activeRun.subscribers) { @@ -876,6 +837,19 @@ export const buildChatRouter = ( logger.warn({ err: error, sessionId: clientSessionId }, "failed to persist chat stream state"); }); } + const latestTurn = extractLatestFrontendTurn(activeRun.messages); + if (latestTurn) { + void learningOrchestrator.onTurnCompleted({ + ...latestTurn, + requestContext, + sessionId: clientSessionId, + }).catch((error) => { + logger.warn( + { err: error, sessionId: clientSessionId }, + "stream-completed learning failed", + ); + }); + } } } finally { if (abortController.signal.aborted) { diff --git a/src/routes/chatUiState.ts b/src/routes/chatUiState.ts index d7f89c1..30c9135 100644 --- a/src/routes/chatUiState.ts +++ b/src/routes/chatUiState.ts @@ -22,6 +22,15 @@ export type ActiveRun = { subscribers: Set; }; +type ToolArtifactKind = "chart" | "map" | "panel" | "tool"; + +type ToolCallPayload = { + session_id?: string; + tool?: string; + params?: unknown; + reason?: string; +}; + export const isObjectRecord = (value: unknown): value is Record => typeof value === "object" && value !== null && !Array.isArray(value); @@ -196,6 +205,59 @@ export const updateLastAssistantQuestion = ( }; }); +const getToolArtifactKind = (tool: string): ToolArtifactKind => { + if (tool === "show_chart" || tool === "chart") return "chart"; + if ( + tool === "locate_features" || + tool === "zoom_to_map" || + tool === "render_junctions" || + tool === "apply_layer_style" || + tool.startsWith("locate_") + ) { + return "map"; + } + if (tool === "view_history" || tool === "view_scada") return "panel"; + return "tool"; +}; + +const getToolArtifactTitle = (tool: string, params: Record) => { + if (typeof params.title === "string" && params.title.trim()) { + return params.title.trim(); + } + if (tool === "show_chart" || tool === "chart") return "生成图表"; + if (tool === "zoom_to_map") return "缩放到地图坐标"; + if (tool === "render_junctions") return "渲染节点分区"; + if (tool === "view_history") return "打开计算结果曲线"; + if (tool === "view_scada") return "打开 SCADA 数据面板"; + if (tool === "apply_layer_style") return "应用图层样式"; + if (tool === "locate_features" || tool.startsWith("locate_")) return "地图定位"; + return tool || "工具调用"; +}; + +export const appendBackendToolArtifact = ( + artifacts: unknown, + payload: ToolCallPayload, +) => { + const tool = typeof payload.tool === "string" ? payload.tool.trim() : ""; + if (!tool) { + return artifacts; + } + const params = isObjectRecord(payload.params) ? payload.params : {}; + const next = Array.isArray(artifacts) ? [...artifacts] : []; + next.push({ + id: `${tool}-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 8)}`, + tool, + kind: getToolArtifactKind(tool), + title: getToolArtifactTitle(tool, params), + description: + typeof payload.reason === "string" && payload.reason.trim() + ? payload.reason.trim() + : undefined, + params, + }); + return next; +}; + export const toFrontendPermission = ( payload: PermissionRequestPayload, status: "pending" | "approved_once" | "approved_always" | "rejected" | "error" = "pending", diff --git a/tests/routes/chatSession.test.ts b/tests/routes/chatSession.test.ts index c8ae543..19e5611 100644 --- a/tests/routes/chatSession.test.ts +++ b/tests/routes/chatSession.test.ts @@ -1,5 +1,8 @@ import { describe, expect, it } from "bun:test"; +import { + buildForkedSessionUiState, +} from "../../src/routes/chat.js"; import { buildPromptWithLearningContext, extractLatestFrontendTurn, @@ -199,3 +202,79 @@ describe("extractLatestFrontendTurn", () => { }); }); }); + +describe("buildForkedSessionUiState", () => { + it("copies truncated source messages and preserves tool artifacts", () => { + const forked = buildForkedSessionUiState( + { + messages: [ + { role: "user", content: "画压力曲线" }, + { + role: "assistant", + content: "已生成图表", + artifacts: [ + { + id: "chart-1", + tool: "show_chart", + kind: "chart", + params: { chart_type: "line" }, + }, + ], + }, + { role: "user", content: "继续分析" }, + ], + }, + { + keepMessageCount: 2, + targetSessionId: "forked-session", + }, + ); + + expect(forked).toEqual({ + sessionId: "forked-session", + isTitleManuallyEdited: false, + messages: [ + { role: "user", content: "画压力曲线" }, + { + role: "assistant", + content: "已生成图表", + artifacts: [ + { + id: "chart-1", + tool: "show_chart", + kind: "chart", + params: { chart_type: "line" }, + }, + ], + }, + ], + }); + }); + + it("creates an empty branch state when source UI state is missing or keep count is zero", () => { + expect( + buildForkedSessionUiState(null, { + keepMessageCount: 3, + targetSessionId: "forked-without-source", + }), + ).toEqual({ + sessionId: "forked-without-source", + isTitleManuallyEdited: false, + messages: [], + }); + + expect( + buildForkedSessionUiState( + { messages: [{ role: "user", content: "不保留" }] }, + { + keepMessageCount: 0, + targetSessionId: "forked-empty", + }, + ), + ).toEqual({ + sessionId: "forked-empty", + isTitleManuallyEdited: false, + messages: [], + }); + }); +}); diff --git a/tests/routes/chatUiState.test.ts b/tests/routes/chatUiState.test.ts index d07877e..81b444c 100644 --- a/tests/routes/chatUiState.test.ts +++ b/tests/routes/chatUiState.test.ts @@ -1,10 +1,45 @@ import { describe, expect, it } from "bun:test"; import { + appendBackendToolArtifact, cancelBackendTodos, upsertBackendQuestion, } from "../../src/routes/chatUiState.js"; +describe("appendBackendToolArtifact", () => { + it("persists show_chart tool calls as chart artifacts", () => { + const artifacts = appendBackendToolArtifact([], { + session_id: "session-1", + tool: "show_chart", + reason: "测试折线图渲染", + params: { + title: "压力曲线", + chart_type: "line", + x_data: ["00:00", "01:00"], + series: [{ name: "P-101", data: [0.42, 0.41] }], + }, + }) as Array>; + + expect(artifacts).toHaveLength(1); + expect(artifacts[0]).toMatchObject({ + tool: "show_chart", + kind: "chart", + title: "压力曲线", + description: "测试折线图渲染", + params: { + chart_type: "line", + x_data: ["00:00", "01:00"], + series: [{ name: "P-101", data: [0.42, 0.41] }], + }, + }); + expect(artifacts[0]).toEqual( + expect.objectContaining({ + id: expect.stringMatching(/^show_chart-/), + }), + ); + }); +}); + describe("upsertBackendQuestion", () => { it("replaces a tool-call placeholder with the actionable question request", () => { const questions = upsertBackendQuestion(