From 801f611ce57634b9e8af9acbb06fbe300fcf6710 Mon Sep 17 00:00:00 2001 From: Huarch Date: Mon, 8 Jun 2026 19:33:13 +0800 Subject: [PATCH] fix(chat): restore forked context --- src/routes/chat.ts | 64 ++++++++++++-------------------- src/routes/chatSession.ts | 5 +++ src/routes/chatUiState.ts | 20 ---------- tests/routes/chatSession.test.ts | 16 ++++++++ 4 files changed, 44 insertions(+), 61 deletions(-) diff --git a/src/routes/chat.ts b/src/routes/chat.ts index abb6032..f978b10 100644 --- a/src/routes/chat.ts +++ b/src/routes/chat.ts @@ -19,6 +19,7 @@ import { extractLatestFrontendTurn, generateSessionTitle, shouldGenerateSessionTitle, + shouldRestoreConversationForRuntime, } from "./chatSession.js"; import { registerChatAuxiliaryRoutes } from "./chatAuxiliaryRoutes.js"; import { registerChatInteractionRoutes } from "./chatInteractionRoutes.js"; @@ -37,10 +38,8 @@ import { type StreamSubscriber, cancelBackendTodos, completeBackendProgress, - countFrontendUserMessages, createInitialStreamingMessages, isObjectRecord, - pruneBranchGroupsForMessageIndex, toFrontendPermission, toPermissionStatus, updateLastAssistantMessage, @@ -56,7 +55,6 @@ const payloadSchema = z.object({ session_id: z.string().max(128).optional(), model: z.enum(supportedModels).optional(), approval_mode: z.enum(["request", "always"]).optional().default("request"), - regenerate_from_message_index: z.coerce.number().int().min(0).optional(), }); const createSessionPayloadSchema = z.object({ @@ -86,6 +84,17 @@ const toSessionUiStateContext = (sessionRecord: SessionRecord) => ({ const getSessionRunStatus = (sessionId: string) => activeRuns.get(sessionId)?.status ?? lastRunStatuses.get(sessionId); +const runtimeHasConversation = async ( + runtime: OpencodeRuntimeAdapter, + sessionId: string, +) => { + const messages = await runtime.messages(sessionId, 1); + return messages.some( + (message) => + message.info.role === "user" || message.info.role === "assistant", + ); +}; + export const buildChatRouter = ( sessionBridge: ChatSessionBridge, runtime: OpencodeRuntimeAdapter, @@ -555,6 +564,13 @@ export const buildChatRouter = ( userId, }); const activeSessionRecord = await sessionMetadataStore.touch(ensuredSessionRecord); + const hasRuntimeConversation = hadExistingRuntimeSession + ? await runtimeHasConversation(runtime, binding.sessionId) + : false; + const shouldRestoreConversation = shouldRestoreConversationForRuntime({ + hadExistingSessionRecord: hadExistingRuntimeSession, + runtimeHasConversation: hasRuntimeConversation, + }); const historyContext = { actorKey: requestContext.actorKey, clientSessionId: requestContext.clientSessionId, @@ -565,20 +581,7 @@ export const buildChatRouter = ( toSessionUiStateContext(activeSessionRecord), ); const persistedMessages = initialSessionState?.messages ?? []; - const isRegenerate = - parsed.data.regenerate_from_message_index !== undefined; - const baseMessages = - isRegenerate - ? persistedMessages.slice(0, parsed.data.regenerate_from_message_index) - : persistedMessages; - const targetUserOrdinal = isRegenerate - ? countFrontendUserMessages( - persistedMessages.slice( - 0, - (parsed.data.regenerate_from_message_index ?? 0) + 1, - ), - ) - : undefined; + const baseMessages = persistedMessages; if (activeRuns.get(activeSessionRecord.sessionId)?.status === "running") { res.status(409).json({ message: "session is already streaming", @@ -586,15 +589,7 @@ export const buildChatRouter = ( }); return; } - if (isRegenerate) { - await sessionTranscriptStore.truncateThread( - historyContext, - parsed.data.regenerate_from_message_index ?? 0, - ); - } - const recentTurns = isRegenerate - ? [] - : await sessionTranscriptStore.getRecentTurns(historyContext, 8); + const recentTurns = await sessionTranscriptStore.getRecentTurns(historyContext, 8); logger.info( { @@ -603,7 +598,6 @@ export const buildChatRouter = ( created: created || sessionCreated, model: parsed.data.model, approvalMode: parsed.data.approval_mode, - regenerateFromMessageIndex: parsed.data.regenerate_from_message_index, traceId: requestContext.traceId, projectId: requestContext.projectId, }, @@ -625,10 +619,7 @@ export const buildChatRouter = ( baseMessages, parsed.data.message, ); - const branchGroups = pruneBranchGroupsForMessageIndex( - initialSessionState?.branchGroups ?? [], - parsed.data.regenerate_from_message_index, - ); + const branchGroups = initialSessionState?.branchGroups ?? []; const activeRun: ActiveRun = { clientSessionId, controller: abortController, @@ -815,15 +806,6 @@ export const buildChatRouter = ( }; try { - if (isRegenerate) { - if (!targetUserOrdinal || targetUserOrdinal < 1) { - throw new Error("target user message not found for regeneration"); - } - await runtime.revertToUserMessage(binding.sessionId, { - userOrdinal: targetUserOrdinal, - }); - } - const preparedMessage = await buildPromptWithLearningContext( memoryStore, requestContext.actorKey, @@ -832,7 +814,7 @@ export const buildChatRouter = ( recentTurns, persistedMessages: baseMessages, message: parsed.data.message, - restoreConversation: !hadExistingRuntimeSession, + restoreConversation: shouldRestoreConversation, }, ); const streamResult = await streamPromptResponse({ diff --git a/src/routes/chatSession.ts b/src/routes/chatSession.ts index 12eecb8..d6f0aa6 100644 --- a/src/routes/chatSession.ts +++ b/src/routes/chatSession.ts @@ -212,6 +212,11 @@ export const buildPromptWithLearningContext = async ( .join("\n\n"); }; +export const shouldRestoreConversationForRuntime = (options: { + hadExistingSessionRecord: boolean; + runtimeHasConversation: boolean; +}) => !options.hadExistingSessionRecord || !options.runtimeHasConversation; + const buildRestoredConversationContext = (recentTurns: SessionTurnRecord[]) => { const formattedTurns = recentTurns .slice(-RESTORE_TURN_LIMIT) diff --git a/src/routes/chatUiState.ts b/src/routes/chatUiState.ts index d8e7fe3..38218d9 100644 --- a/src/routes/chatUiState.ts +++ b/src/routes/chatUiState.ts @@ -63,26 +63,6 @@ export const createInitialStreamingMessages = ( ]; }; -export const countFrontendUserMessages = (messages: unknown[]) => - messages.filter( - (message) => isObjectRecord(message) && message.role === "user", - ).length; - -export const pruneBranchGroupsForMessageIndex = ( - branchGroups: unknown[], - messageIndex: number | undefined, -) => { - if (messageIndex === undefined) { - return branchGroups; - } - return branchGroups.filter( - (group) => - !isObjectRecord(group) || - typeof group.parentCount !== "number" || - group.parentCount < messageIndex, - ); -}; - export const upsertBackendProgress = ( progress: unknown, payload: Record, diff --git a/tests/routes/chatSession.test.ts b/tests/routes/chatSession.test.ts index fecdf1c..c8ae543 100644 --- a/tests/routes/chatSession.test.ts +++ b/tests/routes/chatSession.test.ts @@ -4,6 +4,7 @@ import { buildPromptWithLearningContext, extractLatestFrontendTurn, generateSessionTitle, + shouldRestoreConversationForRuntime, shouldGenerateSessionTitle, } from "../../src/routes/chatSession.js"; import { type SessionTurnRecord } from "../../src/sessions/transcriptStore.js"; @@ -161,6 +162,21 @@ describe("buildPromptWithLearningContext", () => { expect(prompt).not.toContain("[Previous conversation context]"); expect(prompt).toBe("基于刚才结果继续分析"); }); + + it("restores copied fork context when metadata exists but runtime has no conversation", () => { + expect( + shouldRestoreConversationForRuntime({ + hadExistingSessionRecord: true, + runtimeHasConversation: false, + }), + ).toBe(true); + expect( + shouldRestoreConversationForRuntime({ + hadExistingSessionRecord: true, + runtimeHasConversation: true, + }), + ).toBe(false); + }); }); describe("extractLatestFrontendTurn", () => {