refactor(chat): centralize session persistence

This commit is contained in:
2026-06-10 19:29:42 +08:00
parent cf6cada538
commit fa2c28c1c0
4 changed files with 227 additions and 77 deletions
+51 -77
View File
@@ -36,6 +36,7 @@ import {
type ActiveRun, type ActiveRun,
type RunStatus, type RunStatus,
type StreamSubscriber, type StreamSubscriber,
appendBackendToolArtifact,
cancelBackendTodos, cancelBackendTodos,
completeBackendProgress, completeBackendProgress,
createInitialStreamingMessages, createInitialStreamingMessages,
@@ -67,12 +68,6 @@ const forkPayloadSchema = z.object({
keep_message_count: z.coerce.number().int().min(0), keep_message_count: z.coerce.number().int().min(0),
}); });
const sessionStateSchema = z.object({
title: z.string().max(120).optional(),
is_title_manually_edited: z.boolean().optional(),
messages: z.array(z.unknown()).default([]),
});
const activeRuns = new Map<string, ActiveRun>(); const activeRuns = new Map<string, ActiveRun>();
const lastRunStatuses = new Map<string, RunStatus>(); const lastRunStatuses = new Map<string, RunStatus>();
@@ -80,6 +75,20 @@ const toSessionUiStateContext = (sessionRecord: SessionRecord) => ({
sessionId: sessionRecord.sessionId, sessionId: sessionRecord.sessionId,
}); });
export const buildForkedSessionUiState = (
sourceState: { messages?: unknown[] } | null | undefined,
input: {
keepMessageCount: number;
targetSessionId: string;
},
) => ({
sessionId: input.targetSessionId,
isTitleManuallyEdited: false,
messages: Array.isArray(sourceState?.messages)
? sourceState.messages.slice(0, input.keepMessageCount)
: [],
});
const getSessionRunStatus = (sessionId: string) => const getSessionRunStatus = (sessionId: string) =>
activeRuns.get(sessionId)?.status ?? lastRunStatuses.get(sessionId); activeRuns.get(sessionId)?.status ?? lastRunStatuses.get(sessionId);
@@ -274,71 +283,6 @@ export const buildChatRouter = (
res.on("close", cleanup); res.on("close", cleanup);
}); });
chatRouter.put("/session/:sessionId", async (req, res) => {
const sessionId = req.params.sessionId?.trim();
const parsed = sessionStateSchema.safeParse(req.body ?? {});
if (!parsed.success) {
res.status(400).json({
message: "invalid request payload",
detail: parsed.error.flatten(),
});
return;
}
const projectId = req.header("x-project-id") ?? undefined;
const userId = req.header("x-user-id") ?? undefined;
const actorKey = toActorKey(userId);
const projectKey = toProjectKey(projectId);
if (!sessionId) {
res.status(400).json({ message: "session_id is required" });
return;
}
const { record } = await sessionMetadataStore.ensure({
actorKey,
projectId,
projectKey,
sessionId,
userId,
});
const nextRecord = await sessionMetadataStore.touch(record, {
...(parsed.data.title ? { title: parsed.data.title } : {}),
});
await sessionUiStateStore.write(toSessionUiStateContext(nextRecord), {
sessionId: nextRecord.sessionId,
isTitleManuallyEdited: parsed.data.is_title_manually_edited,
messages: parsed.data.messages,
});
const latestTurn = extractLatestFrontendTurn(parsed.data.messages);
if (latestTurn) {
void learningOrchestrator.onTurnCompleted({
...latestTurn,
requestContext: {
actorKey,
clientSessionId: nextRecord.sessionId,
projectId,
projectKey,
traceId: req.header("x-trace-id") ?? `save-${nextRecord.sessionId}`,
userId,
},
sessionId: nextRecord.sessionId,
}).catch((error) => {
logger.warn(
{ err: error, sessionId: nextRecord.sessionId },
"post-save learning failed",
);
});
}
res.json({
id: nextRecord.sessionId,
title: nextRecord.title ?? "新对话",
created_at: nextRecord.createdAt,
updated_at: nextRecord.updatedAt,
status: nextRecord.status,
session_id: nextRecord.sessionId,
});
});
chatRouter.patch("/session/:sessionId/title", async (req, res) => { chatRouter.patch("/session/:sessionId/title", async (req, res) => {
const sessionId = req.params.sessionId?.trim(); const sessionId = req.params.sessionId?.trim();
const title = const title =
@@ -469,7 +413,7 @@ export const buildChatRouter = (
}); });
const nextSessionId = targetSessionRecord.sessionId; const nextSessionId = targetSessionRecord.sessionId;
if (sourceSessionId && parsed.data.keep_message_count > 0) { if (sourceSessionId) {
await sessionTranscriptStore.cloneThread( await sessionTranscriptStore.cloneThread(
{ {
actorKey, actorKey,
@@ -485,12 +429,24 @@ export const buildChatRouter = (
}, },
parsed.data.keep_message_count, parsed.data.keep_message_count,
); );
if (sourceSessionRecord?.title) {
await sessionMetadataStore.touch(targetSessionRecord, {
title: sourceSessionRecord.title,
});
}
} }
const sourceState = sourceSessionRecord
? await sessionUiStateStore.read(toSessionUiStateContext(sourceSessionRecord))
: null;
const forkTitle = sourceSessionRecord?.title
? `${sourceSessionRecord.title} 副本`
: "新对话副本";
const titledTargetSessionRecord = await sessionMetadataStore.touch(
targetSessionRecord,
{ title: forkTitle },
);
await sessionUiStateStore.write(
toSessionUiStateContext(titledTargetSessionRecord),
buildForkedSessionUiState(sourceState, {
keepMessageCount: parsed.data.keep_message_count,
targetSessionId: nextSessionId,
}),
);
logger.info( logger.info(
{ {
@@ -789,6 +745,11 @@ export const buildChatRouter = (
...message, ...message,
todos: upsertBackendTodoUpdate(message.todos, payload), todos: upsertBackendTodoUpdate(message.todos, payload),
})); }));
} else if (event === "tool_call") {
activeRun.messages = updateLastAssistantMessage(activeRun.messages, (message) => ({
...message,
artifacts: appendBackendToolArtifact(message.artifacts, data),
}));
} }
for (const subscriber of activeRun.subscribers) { for (const subscriber of activeRun.subscribers) {
@@ -876,6 +837,19 @@ export const buildChatRouter = (
logger.warn({ err: error, sessionId: clientSessionId }, "failed to persist chat stream state"); logger.warn({ err: error, sessionId: clientSessionId }, "failed to persist chat stream state");
}); });
} }
const latestTurn = extractLatestFrontendTurn(activeRun.messages);
if (latestTurn) {
void learningOrchestrator.onTurnCompleted({
...latestTurn,
requestContext,
sessionId: clientSessionId,
}).catch((error) => {
logger.warn(
{ err: error, sessionId: clientSessionId },
"stream-completed learning failed",
);
});
}
} }
} finally { } finally {
if (abortController.signal.aborted) { if (abortController.signal.aborted) {
+62
View File
@@ -22,6 +22,15 @@ export type ActiveRun = {
subscribers: Set<StreamSubscriber>; subscribers: Set<StreamSubscriber>;
}; };
type ToolArtifactKind = "chart" | "map" | "panel" | "tool";
type ToolCallPayload = {
session_id?: string;
tool?: string;
params?: unknown;
reason?: string;
};
export const isObjectRecord = (value: unknown): value is Record<string, unknown> => export const isObjectRecord = (value: unknown): value is Record<string, unknown> =>
typeof value === "object" && value !== null && !Array.isArray(value); typeof value === "object" && value !== null && !Array.isArray(value);
@@ -196,6 +205,59 @@ export const updateLastAssistantQuestion = (
}; };
}); });
const getToolArtifactKind = (tool: string): ToolArtifactKind => {
if (tool === "show_chart" || tool === "chart") return "chart";
if (
tool === "locate_features" ||
tool === "zoom_to_map" ||
tool === "render_junctions" ||
tool === "apply_layer_style" ||
tool.startsWith("locate_")
) {
return "map";
}
if (tool === "view_history" || tool === "view_scada") return "panel";
return "tool";
};
const getToolArtifactTitle = (tool: string, params: Record<string, unknown>) => {
if (typeof params.title === "string" && params.title.trim()) {
return params.title.trim();
}
if (tool === "show_chart" || tool === "chart") return "生成图表";
if (tool === "zoom_to_map") return "缩放到地图坐标";
if (tool === "render_junctions") return "渲染节点分区";
if (tool === "view_history") return "打开计算结果曲线";
if (tool === "view_scada") return "打开 SCADA 数据面板";
if (tool === "apply_layer_style") return "应用图层样式";
if (tool === "locate_features" || tool.startsWith("locate_")) return "地图定位";
return tool || "工具调用";
};
export const appendBackendToolArtifact = (
artifacts: unknown,
payload: ToolCallPayload,
) => {
const tool = typeof payload.tool === "string" ? payload.tool.trim() : "";
if (!tool) {
return artifacts;
}
const params = isObjectRecord(payload.params) ? payload.params : {};
const next = Array.isArray(artifacts) ? [...artifacts] : [];
next.push({
id: `${tool}-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 8)}`,
tool,
kind: getToolArtifactKind(tool),
title: getToolArtifactTitle(tool, params),
description:
typeof payload.reason === "string" && payload.reason.trim()
? payload.reason.trim()
: undefined,
params,
});
return next;
};
export const toFrontendPermission = ( export const toFrontendPermission = (
payload: PermissionRequestPayload, payload: PermissionRequestPayload,
status: "pending" | "approved_once" | "approved_always" | "rejected" | "error" = "pending", status: "pending" | "approved_once" | "approved_always" | "rejected" | "error" = "pending",
+79
View File
@@ -1,5 +1,8 @@
import { describe, expect, it } from "bun:test"; import { describe, expect, it } from "bun:test";
import {
buildForkedSessionUiState,
} from "../../src/routes/chat.js";
import { import {
buildPromptWithLearningContext, buildPromptWithLearningContext,
extractLatestFrontendTurn, extractLatestFrontendTurn,
@@ -199,3 +202,79 @@ describe("extractLatestFrontendTurn", () => {
}); });
}); });
}); });
describe("buildForkedSessionUiState", () => {
it("copies truncated source messages and preserves tool artifacts", () => {
const forked = buildForkedSessionUiState(
{
messages: [
{ role: "user", content: "画压力曲线" },
{
role: "assistant",
content: "已生成图表",
artifacts: [
{
id: "chart-1",
tool: "show_chart",
kind: "chart",
params: { chart_type: "line" },
},
],
},
{ role: "user", content: "继续分析" },
],
},
{
keepMessageCount: 2,
targetSessionId: "forked-session",
},
);
expect(forked).toEqual({
sessionId: "forked-session",
isTitleManuallyEdited: false,
messages: [
{ role: "user", content: "画压力曲线" },
{
role: "assistant",
content: "已生成图表",
artifacts: [
{
id: "chart-1",
tool: "show_chart",
kind: "chart",
params: { chart_type: "line" },
},
],
},
],
});
});
it("creates an empty branch state when source UI state is missing or keep count is zero", () => {
expect(
buildForkedSessionUiState(null, {
keepMessageCount: 3,
targetSessionId: "forked-without-source",
}),
).toEqual({
sessionId: "forked-without-source",
isTitleManuallyEdited: false,
messages: [],
});
expect(
buildForkedSessionUiState(
{ messages: [{ role: "user", content: "不保留" }] },
{
keepMessageCount: 0,
targetSessionId: "forked-empty",
},
),
).toEqual({
sessionId: "forked-empty",
isTitleManuallyEdited: false,
messages: [],
});
});
});
+35
View File
@@ -1,10 +1,45 @@
import { describe, expect, it } from "bun:test"; import { describe, expect, it } from "bun:test";
import { import {
appendBackendToolArtifact,
cancelBackendTodos, cancelBackendTodos,
upsertBackendQuestion, upsertBackendQuestion,
} from "../../src/routes/chatUiState.js"; } from "../../src/routes/chatUiState.js";
describe("appendBackendToolArtifact", () => {
it("persists show_chart tool calls as chart artifacts", () => {
const artifacts = appendBackendToolArtifact([], {
session_id: "session-1",
tool: "show_chart",
reason: "测试折线图渲染",
params: {
title: "压力曲线",
chart_type: "line",
x_data: ["00:00", "01:00"],
series: [{ name: "P-101", data: [0.42, 0.41] }],
},
}) as Array<Record<string, unknown>>;
expect(artifacts).toHaveLength(1);
expect(artifacts[0]).toMatchObject({
tool: "show_chart",
kind: "chart",
title: "压力曲线",
description: "测试折线图渲染",
params: {
chart_type: "line",
x_data: ["00:00", "01:00"],
series: [{ name: "P-101", data: [0.42, 0.41] }],
},
});
expect(artifacts[0]).toEqual(
expect.objectContaining({
id: expect.stringMatching(/^show_chart-/),
}),
);
});
});
describe("upsertBackendQuestion", () => { describe("upsertBackendQuestion", () => {
it("replaces a tool-call placeholder with the actionable question request", () => { it("replaces a tool-call placeholder with the actionable question request", () => {
const questions = upsertBackendQuestion( const questions = upsertBackendQuestion(