refactor(chat): centralize session persistence
This commit is contained in:
+51
-77
@@ -36,6 +36,7 @@ import {
|
||||
type ActiveRun,
|
||||
type RunStatus,
|
||||
type StreamSubscriber,
|
||||
appendBackendToolArtifact,
|
||||
cancelBackendTodos,
|
||||
completeBackendProgress,
|
||||
createInitialStreamingMessages,
|
||||
@@ -67,12 +68,6 @@ const forkPayloadSchema = z.object({
|
||||
keep_message_count: z.coerce.number().int().min(0),
|
||||
});
|
||||
|
||||
const sessionStateSchema = z.object({
|
||||
title: z.string().max(120).optional(),
|
||||
is_title_manually_edited: z.boolean().optional(),
|
||||
messages: z.array(z.unknown()).default([]),
|
||||
});
|
||||
|
||||
const activeRuns = new Map<string, ActiveRun>();
|
||||
const lastRunStatuses = new Map<string, RunStatus>();
|
||||
|
||||
@@ -80,6 +75,20 @@ const toSessionUiStateContext = (sessionRecord: SessionRecord) => ({
|
||||
sessionId: sessionRecord.sessionId,
|
||||
});
|
||||
|
||||
export const buildForkedSessionUiState = (
|
||||
sourceState: { messages?: unknown[] } | null | undefined,
|
||||
input: {
|
||||
keepMessageCount: number;
|
||||
targetSessionId: string;
|
||||
},
|
||||
) => ({
|
||||
sessionId: input.targetSessionId,
|
||||
isTitleManuallyEdited: false,
|
||||
messages: Array.isArray(sourceState?.messages)
|
||||
? sourceState.messages.slice(0, input.keepMessageCount)
|
||||
: [],
|
||||
});
|
||||
|
||||
const getSessionRunStatus = (sessionId: string) =>
|
||||
activeRuns.get(sessionId)?.status ?? lastRunStatuses.get(sessionId);
|
||||
|
||||
@@ -274,71 +283,6 @@ export const buildChatRouter = (
|
||||
res.on("close", cleanup);
|
||||
});
|
||||
|
||||
chatRouter.put("/session/:sessionId", async (req, res) => {
|
||||
const sessionId = req.params.sessionId?.trim();
|
||||
const parsed = sessionStateSchema.safeParse(req.body ?? {});
|
||||
if (!parsed.success) {
|
||||
res.status(400).json({
|
||||
message: "invalid request payload",
|
||||
detail: parsed.error.flatten(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const projectId = req.header("x-project-id") ?? undefined;
|
||||
const userId = req.header("x-user-id") ?? undefined;
|
||||
const actorKey = toActorKey(userId);
|
||||
const projectKey = toProjectKey(projectId);
|
||||
if (!sessionId) {
|
||||
res.status(400).json({ message: "session_id is required" });
|
||||
return;
|
||||
}
|
||||
|
||||
const { record } = await sessionMetadataStore.ensure({
|
||||
actorKey,
|
||||
projectId,
|
||||
projectKey,
|
||||
sessionId,
|
||||
userId,
|
||||
});
|
||||
const nextRecord = await sessionMetadataStore.touch(record, {
|
||||
...(parsed.data.title ? { title: parsed.data.title } : {}),
|
||||
});
|
||||
await sessionUiStateStore.write(toSessionUiStateContext(nextRecord), {
|
||||
sessionId: nextRecord.sessionId,
|
||||
isTitleManuallyEdited: parsed.data.is_title_manually_edited,
|
||||
messages: parsed.data.messages,
|
||||
});
|
||||
const latestTurn = extractLatestFrontendTurn(parsed.data.messages);
|
||||
if (latestTurn) {
|
||||
void learningOrchestrator.onTurnCompleted({
|
||||
...latestTurn,
|
||||
requestContext: {
|
||||
actorKey,
|
||||
clientSessionId: nextRecord.sessionId,
|
||||
projectId,
|
||||
projectKey,
|
||||
traceId: req.header("x-trace-id") ?? `save-${nextRecord.sessionId}`,
|
||||
userId,
|
||||
},
|
||||
sessionId: nextRecord.sessionId,
|
||||
}).catch((error) => {
|
||||
logger.warn(
|
||||
{ err: error, sessionId: nextRecord.sessionId },
|
||||
"post-save learning failed",
|
||||
);
|
||||
});
|
||||
}
|
||||
res.json({
|
||||
id: nextRecord.sessionId,
|
||||
title: nextRecord.title ?? "新对话",
|
||||
created_at: nextRecord.createdAt,
|
||||
updated_at: nextRecord.updatedAt,
|
||||
status: nextRecord.status,
|
||||
session_id: nextRecord.sessionId,
|
||||
});
|
||||
});
|
||||
|
||||
chatRouter.patch("/session/:sessionId/title", async (req, res) => {
|
||||
const sessionId = req.params.sessionId?.trim();
|
||||
const title =
|
||||
@@ -469,7 +413,7 @@ export const buildChatRouter = (
|
||||
});
|
||||
const nextSessionId = targetSessionRecord.sessionId;
|
||||
|
||||
if (sourceSessionId && parsed.data.keep_message_count > 0) {
|
||||
if (sourceSessionId) {
|
||||
await sessionTranscriptStore.cloneThread(
|
||||
{
|
||||
actorKey,
|
||||
@@ -485,12 +429,24 @@ export const buildChatRouter = (
|
||||
},
|
||||
parsed.data.keep_message_count,
|
||||
);
|
||||
if (sourceSessionRecord?.title) {
|
||||
await sessionMetadataStore.touch(targetSessionRecord, {
|
||||
title: sourceSessionRecord.title,
|
||||
});
|
||||
}
|
||||
}
|
||||
const sourceState = sourceSessionRecord
|
||||
? await sessionUiStateStore.read(toSessionUiStateContext(sourceSessionRecord))
|
||||
: null;
|
||||
const forkTitle = sourceSessionRecord?.title
|
||||
? `${sourceSessionRecord.title} 副本`
|
||||
: "新对话副本";
|
||||
const titledTargetSessionRecord = await sessionMetadataStore.touch(
|
||||
targetSessionRecord,
|
||||
{ title: forkTitle },
|
||||
);
|
||||
await sessionUiStateStore.write(
|
||||
toSessionUiStateContext(titledTargetSessionRecord),
|
||||
buildForkedSessionUiState(sourceState, {
|
||||
keepMessageCount: parsed.data.keep_message_count,
|
||||
targetSessionId: nextSessionId,
|
||||
}),
|
||||
);
|
||||
|
||||
logger.info(
|
||||
{
|
||||
@@ -789,6 +745,11 @@ export const buildChatRouter = (
|
||||
...message,
|
||||
todos: upsertBackendTodoUpdate(message.todos, payload),
|
||||
}));
|
||||
} else if (event === "tool_call") {
|
||||
activeRun.messages = updateLastAssistantMessage(activeRun.messages, (message) => ({
|
||||
...message,
|
||||
artifacts: appendBackendToolArtifact(message.artifacts, data),
|
||||
}));
|
||||
}
|
||||
|
||||
for (const subscriber of activeRun.subscribers) {
|
||||
@@ -876,6 +837,19 @@ export const buildChatRouter = (
|
||||
logger.warn({ err: error, sessionId: clientSessionId }, "failed to persist chat stream state");
|
||||
});
|
||||
}
|
||||
const latestTurn = extractLatestFrontendTurn(activeRun.messages);
|
||||
if (latestTurn) {
|
||||
void learningOrchestrator.onTurnCompleted({
|
||||
...latestTurn,
|
||||
requestContext,
|
||||
sessionId: clientSessionId,
|
||||
}).catch((error) => {
|
||||
logger.warn(
|
||||
{ err: error, sessionId: clientSessionId },
|
||||
"stream-completed learning failed",
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (abortController.signal.aborted) {
|
||||
|
||||
@@ -22,6 +22,15 @@ export type ActiveRun = {
|
||||
subscribers: Set<StreamSubscriber>;
|
||||
};
|
||||
|
||||
type ToolArtifactKind = "chart" | "map" | "panel" | "tool";
|
||||
|
||||
type ToolCallPayload = {
|
||||
session_id?: string;
|
||||
tool?: string;
|
||||
params?: unknown;
|
||||
reason?: string;
|
||||
};
|
||||
|
||||
export const isObjectRecord = (value: unknown): value is Record<string, unknown> =>
|
||||
typeof value === "object" && value !== null && !Array.isArray(value);
|
||||
|
||||
@@ -196,6 +205,59 @@ export const updateLastAssistantQuestion = (
|
||||
};
|
||||
});
|
||||
|
||||
const getToolArtifactKind = (tool: string): ToolArtifactKind => {
|
||||
if (tool === "show_chart" || tool === "chart") return "chart";
|
||||
if (
|
||||
tool === "locate_features" ||
|
||||
tool === "zoom_to_map" ||
|
||||
tool === "render_junctions" ||
|
||||
tool === "apply_layer_style" ||
|
||||
tool.startsWith("locate_")
|
||||
) {
|
||||
return "map";
|
||||
}
|
||||
if (tool === "view_history" || tool === "view_scada") return "panel";
|
||||
return "tool";
|
||||
};
|
||||
|
||||
const getToolArtifactTitle = (tool: string, params: Record<string, unknown>) => {
|
||||
if (typeof params.title === "string" && params.title.trim()) {
|
||||
return params.title.trim();
|
||||
}
|
||||
if (tool === "show_chart" || tool === "chart") return "生成图表";
|
||||
if (tool === "zoom_to_map") return "缩放到地图坐标";
|
||||
if (tool === "render_junctions") return "渲染节点分区";
|
||||
if (tool === "view_history") return "打开计算结果曲线";
|
||||
if (tool === "view_scada") return "打开 SCADA 数据面板";
|
||||
if (tool === "apply_layer_style") return "应用图层样式";
|
||||
if (tool === "locate_features" || tool.startsWith("locate_")) return "地图定位";
|
||||
return tool || "工具调用";
|
||||
};
|
||||
|
||||
export const appendBackendToolArtifact = (
|
||||
artifacts: unknown,
|
||||
payload: ToolCallPayload,
|
||||
) => {
|
||||
const tool = typeof payload.tool === "string" ? payload.tool.trim() : "";
|
||||
if (!tool) {
|
||||
return artifacts;
|
||||
}
|
||||
const params = isObjectRecord(payload.params) ? payload.params : {};
|
||||
const next = Array.isArray(artifacts) ? [...artifacts] : [];
|
||||
next.push({
|
||||
id: `${tool}-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 8)}`,
|
||||
tool,
|
||||
kind: getToolArtifactKind(tool),
|
||||
title: getToolArtifactTitle(tool, params),
|
||||
description:
|
||||
typeof payload.reason === "string" && payload.reason.trim()
|
||||
? payload.reason.trim()
|
||||
: undefined,
|
||||
params,
|
||||
});
|
||||
return next;
|
||||
};
|
||||
|
||||
export const toFrontendPermission = (
|
||||
payload: PermissionRequestPayload,
|
||||
status: "pending" | "approved_once" | "approved_always" | "rejected" | "error" = "pending",
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { describe, expect, it } from "bun:test";
|
||||
|
||||
import {
|
||||
buildForkedSessionUiState,
|
||||
} from "../../src/routes/chat.js";
|
||||
import {
|
||||
buildPromptWithLearningContext,
|
||||
extractLatestFrontendTurn,
|
||||
@@ -199,3 +202,79 @@ describe("extractLatestFrontendTurn", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("buildForkedSessionUiState", () => {
|
||||
it("copies truncated source messages and preserves tool artifacts", () => {
|
||||
const forked = buildForkedSessionUiState(
|
||||
{
|
||||
messages: [
|
||||
{ role: "user", content: "画压力曲线" },
|
||||
{
|
||||
role: "assistant",
|
||||
content: "已生成图表",
|
||||
artifacts: [
|
||||
{
|
||||
id: "chart-1",
|
||||
tool: "show_chart",
|
||||
kind: "chart",
|
||||
params: { chart_type: "line" },
|
||||
},
|
||||
],
|
||||
},
|
||||
{ role: "user", content: "继续分析" },
|
||||
],
|
||||
},
|
||||
{
|
||||
keepMessageCount: 2,
|
||||
targetSessionId: "forked-session",
|
||||
},
|
||||
);
|
||||
|
||||
expect(forked).toEqual({
|
||||
sessionId: "forked-session",
|
||||
isTitleManuallyEdited: false,
|
||||
messages: [
|
||||
{ role: "user", content: "画压力曲线" },
|
||||
{
|
||||
role: "assistant",
|
||||
content: "已生成图表",
|
||||
artifacts: [
|
||||
{
|
||||
id: "chart-1",
|
||||
tool: "show_chart",
|
||||
kind: "chart",
|
||||
params: { chart_type: "line" },
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it("creates an empty branch state when source UI state is missing or keep count is zero", () => {
|
||||
expect(
|
||||
buildForkedSessionUiState(null, {
|
||||
keepMessageCount: 3,
|
||||
targetSessionId: "forked-without-source",
|
||||
}),
|
||||
).toEqual({
|
||||
sessionId: "forked-without-source",
|
||||
isTitleManuallyEdited: false,
|
||||
messages: [],
|
||||
});
|
||||
|
||||
expect(
|
||||
buildForkedSessionUiState(
|
||||
{ messages: [{ role: "user", content: "不保留" }] },
|
||||
{
|
||||
keepMessageCount: 0,
|
||||
targetSessionId: "forked-empty",
|
||||
},
|
||||
),
|
||||
).toEqual({
|
||||
sessionId: "forked-empty",
|
||||
isTitleManuallyEdited: false,
|
||||
messages: [],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,10 +1,45 @@
|
||||
import { describe, expect, it } from "bun:test";
|
||||
|
||||
import {
|
||||
appendBackendToolArtifact,
|
||||
cancelBackendTodos,
|
||||
upsertBackendQuestion,
|
||||
} from "../../src/routes/chatUiState.js";
|
||||
|
||||
describe("appendBackendToolArtifact", () => {
|
||||
it("persists show_chart tool calls as chart artifacts", () => {
|
||||
const artifacts = appendBackendToolArtifact([], {
|
||||
session_id: "session-1",
|
||||
tool: "show_chart",
|
||||
reason: "测试折线图渲染",
|
||||
params: {
|
||||
title: "压力曲线",
|
||||
chart_type: "line",
|
||||
x_data: ["00:00", "01:00"],
|
||||
series: [{ name: "P-101", data: [0.42, 0.41] }],
|
||||
},
|
||||
}) as Array<Record<string, unknown>>;
|
||||
|
||||
expect(artifacts).toHaveLength(1);
|
||||
expect(artifacts[0]).toMatchObject({
|
||||
tool: "show_chart",
|
||||
kind: "chart",
|
||||
title: "压力曲线",
|
||||
description: "测试折线图渲染",
|
||||
params: {
|
||||
chart_type: "line",
|
||||
x_data: ["00:00", "01:00"],
|
||||
series: [{ name: "P-101", data: [0.42, 0.41] }],
|
||||
},
|
||||
});
|
||||
expect(artifacts[0]).toEqual(
|
||||
expect.objectContaining({
|
||||
id: expect.stringMatching(/^show_chart-/),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("upsertBackendQuestion", () => {
|
||||
it("replaces a tool-call placeholder with the actionable question request", () => {
|
||||
const questions = upsertBackendQuestion(
|
||||
|
||||
Reference in New Issue
Block a user