diff --git a/src/routes/chat.ts b/src/routes/chat.ts index 6856500..90568ff 100644 --- a/src/routes/chat.ts +++ b/src/routes/chat.ts @@ -129,6 +129,26 @@ const createInitialStreamingMessages = (existingMessages: unknown[], userContent ]; }; +const countFrontendUserMessages = (messages: unknown[]) => + messages.filter( + (message) => isObjectRecord(message) && message.role === "user", + ).length; + +const pruneBranchGroupsForMessageIndex = ( + branchGroups: unknown[], + messageIndex: number | undefined, +) => { + if (messageIndex === undefined) { + return branchGroups; + } + return branchGroups.filter( + (group) => + !isObjectRecord(group) || + typeof group.parentCount !== "number" || + group.parentCount < messageIndex, + ); +}; + const upsertBackendProgress = ( progress: unknown, payload: Record, @@ -952,15 +972,24 @@ export const buildChatRouter = ( projectKey: requestContext.projectKey, sessionId: requestContext.clientSessionId, }; - const recentTurns = await sessionTranscriptStore.getRecentTurns(historyContext, 8); const initialSessionState = await sessionUiStateStore.read( toSessionUiStateContext(activeSessionRecord), ); const persistedMessages = initialSessionState?.messages ?? []; + const isRegenerate = + parsed.data.regenerate_from_message_index !== undefined; const baseMessages = - parsed.data.regenerate_from_message_index !== undefined + isRegenerate ? persistedMessages.slice(0, parsed.data.regenerate_from_message_index) : persistedMessages; + const targetUserOrdinal = isRegenerate + ? countFrontendUserMessages( + persistedMessages.slice( + 0, + (parsed.data.regenerate_from_message_index ?? 0) + 1, + ), + ) + : undefined; if (activeRuns.get(activeSessionRecord.sessionId)?.status === "running") { res.status(409).json({ message: "session is already streaming", @@ -968,6 +997,15 @@ export const buildChatRouter = ( }); return; } + if (isRegenerate) { + await sessionTranscriptStore.truncateThread( + historyContext, + parsed.data.regenerate_from_message_index ?? 0, + ); + } + const recentTurns = isRegenerate + ? [] + : await sessionTranscriptStore.getRecentTurns(historyContext, 8); logger.info( { @@ -976,6 +1014,7 @@ export const buildChatRouter = ( created: created || sessionCreated, model: parsed.data.model, approvalMode: parsed.data.approval_mode, + regenerateFromMessageIndex: parsed.data.regenerate_from_message_index, traceId: requestContext.traceId, projectId: requestContext.projectId, }, @@ -997,7 +1036,10 @@ export const buildChatRouter = ( baseMessages, parsed.data.message, ); - const branchGroups = initialSessionState?.branchGroups ?? []; + const branchGroups = pruneBranchGroupsForMessageIndex( + initialSessionState?.branchGroups ?? [], + parsed.data.regenerate_from_message_index, + ); const activeRun: ActiveRun = { clientSessionId, controller: abortController, @@ -1128,8 +1170,13 @@ export const buildChatRouter = ( }; try { - if (parsed.data.regenerate_from_message_index !== undefined) { - await runtime.revertLastUserMessage(binding.sessionId); + if (isRegenerate) { + if (!targetUserOrdinal || targetUserOrdinal < 1) { + throw new Error("target user message not found for regeneration"); + } + await runtime.revertToUserMessage(binding.sessionId, { + userOrdinal: targetUserOrdinal, + }); } const preparedMessage = await buildPromptWithLearningContext( diff --git a/src/runtime/opencode.ts b/src/runtime/opencode.ts index 7da7e3a..15ae31c 100644 --- a/src/runtime/opencode.ts +++ b/src/runtime/opencode.ts @@ -33,6 +33,16 @@ type RuntimeModelOverride = { export type PermissionReply = "once" | "always" | "reject"; +type RuntimeMessage = { + info: { + id: string; + role: string; + }; +}; + +const getRuntimeMessageRole = (message: RuntimeMessage) => message.info.role; +const getRuntimeMessageId = (message: RuntimeMessage) => message.info.id; + export class OpencodeRuntimeAdapter { private clientPromise: Promise | null = null; private closeServer: (() => void) | null = null; @@ -108,17 +118,45 @@ export class OpencodeRuntimeAdapter { return response.data; } - async revertLastUserMessage(sessionId: string) { - const messages = await this.messages(sessionId, 40); - const lastUserMessage = [...messages] - .reverse() - .find((message) => message.info.role === "user"); + async removeMessage(sessionId: string, messageId: string) { + const client = await this.ensureClient(); + const response = await client.session.deleteMessage({ + sessionID: sessionId, + messageID: messageId, + }); + return response.data; + } - if (!lastUserMessage) { - throw new Error("no user message found to revert"); + async revertToUserMessage(sessionId: string, options: { userOrdinal: number }) { + const messages = await this.messages(sessionId, 80); + const userMessages = messages.filter( + (message) => getRuntimeMessageRole(message) === "user", + ); + const targetUserMessage = userMessages[options.userOrdinal - 1]; + + if (!targetUserMessage) { + throw new Error("target user message not found to revert"); } - return this.revertMessage(sessionId, lastUserMessage.info.id); + const targetMessageId = getRuntimeMessageId(targetUserMessage); + const targetIndex = messages.findIndex( + (message) => getRuntimeMessageId(message) === targetMessageId, + ); + const messagesToRemove = targetIndex >= 0 ? messages.slice(targetIndex) : [targetUserMessage]; + + await this.revertMessage(sessionId, targetMessageId); + + for (const message of messagesToRemove.reverse()) { + const messageId = getRuntimeMessageId(message); + try { + await this.removeMessage(sessionId, messageId); + } catch (error) { + logger.warn( + { err: error, sessionId, messageId }, + "failed to remove reverted opencode message", + ); + } + } } async abortSession(sessionId: string) { diff --git a/src/sessions/transcriptStore.ts b/src/sessions/transcriptStore.ts index bbe3f14..09e2237 100644 --- a/src/sessions/transcriptStore.ts +++ b/src/sessions/transcriptStore.ts @@ -147,6 +147,29 @@ export class SessionTranscriptStore { return nextTranscript; } + async truncateThread( + context: SessionTranscriptContext, + keepMessageCount: number, + ) { + const key = this.filePath(context); + return this.serializeWrite(key, async () => { + const transcript = await this.readTranscript(context); + if (!transcript) { + return null; + } + + const nextTranscript: SessionTranscriptRecord = { + ...transcript, + clientSessionId: context.clientSessionId ?? transcript.clientSessionId, + sessionId: context.sessionId, + turns: projectTurnsForFork(transcript.turns, keepMessageCount), + updatedAt: new Date().toISOString(), + }; + await atomicWriteJson(key, nextTranscript); + return nextTranscript; + }); + } + async search( context: Pick, query: string,