fix: regenerate from target turn
This commit is contained in:
+52
-5
@@ -129,6 +129,26 @@ const createInitialStreamingMessages = (existingMessages: unknown[], userContent
|
|||||||
];
|
];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const countFrontendUserMessages = (messages: unknown[]) =>
|
||||||
|
messages.filter(
|
||||||
|
(message) => isObjectRecord(message) && message.role === "user",
|
||||||
|
).length;
|
||||||
|
|
||||||
|
const pruneBranchGroupsForMessageIndex = (
|
||||||
|
branchGroups: unknown[],
|
||||||
|
messageIndex: number | undefined,
|
||||||
|
) => {
|
||||||
|
if (messageIndex === undefined) {
|
||||||
|
return branchGroups;
|
||||||
|
}
|
||||||
|
return branchGroups.filter(
|
||||||
|
(group) =>
|
||||||
|
!isObjectRecord(group) ||
|
||||||
|
typeof group.parentCount !== "number" ||
|
||||||
|
group.parentCount < messageIndex,
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
const upsertBackendProgress = (
|
const upsertBackendProgress = (
|
||||||
progress: unknown,
|
progress: unknown,
|
||||||
payload: Record<string, unknown>,
|
payload: Record<string, unknown>,
|
||||||
@@ -952,15 +972,24 @@ export const buildChatRouter = (
|
|||||||
projectKey: requestContext.projectKey,
|
projectKey: requestContext.projectKey,
|
||||||
sessionId: requestContext.clientSessionId,
|
sessionId: requestContext.clientSessionId,
|
||||||
};
|
};
|
||||||
const recentTurns = await sessionTranscriptStore.getRecentTurns(historyContext, 8);
|
|
||||||
const initialSessionState = await sessionUiStateStore.read(
|
const initialSessionState = await sessionUiStateStore.read(
|
||||||
toSessionUiStateContext(activeSessionRecord),
|
toSessionUiStateContext(activeSessionRecord),
|
||||||
);
|
);
|
||||||
const persistedMessages = initialSessionState?.messages ?? [];
|
const persistedMessages = initialSessionState?.messages ?? [];
|
||||||
|
const isRegenerate =
|
||||||
|
parsed.data.regenerate_from_message_index !== undefined;
|
||||||
const baseMessages =
|
const baseMessages =
|
||||||
parsed.data.regenerate_from_message_index !== undefined
|
isRegenerate
|
||||||
? persistedMessages.slice(0, parsed.data.regenerate_from_message_index)
|
? persistedMessages.slice(0, parsed.data.regenerate_from_message_index)
|
||||||
: persistedMessages;
|
: persistedMessages;
|
||||||
|
const targetUserOrdinal = isRegenerate
|
||||||
|
? countFrontendUserMessages(
|
||||||
|
persistedMessages.slice(
|
||||||
|
0,
|
||||||
|
(parsed.data.regenerate_from_message_index ?? 0) + 1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
: undefined;
|
||||||
if (activeRuns.get(activeSessionRecord.sessionId)?.status === "running") {
|
if (activeRuns.get(activeSessionRecord.sessionId)?.status === "running") {
|
||||||
res.status(409).json({
|
res.status(409).json({
|
||||||
message: "session is already streaming",
|
message: "session is already streaming",
|
||||||
@@ -968,6 +997,15 @@ export const buildChatRouter = (
|
|||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (isRegenerate) {
|
||||||
|
await sessionTranscriptStore.truncateThread(
|
||||||
|
historyContext,
|
||||||
|
parsed.data.regenerate_from_message_index ?? 0,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const recentTurns = isRegenerate
|
||||||
|
? []
|
||||||
|
: await sessionTranscriptStore.getRecentTurns(historyContext, 8);
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
{
|
{
|
||||||
@@ -976,6 +1014,7 @@ export const buildChatRouter = (
|
|||||||
created: created || sessionCreated,
|
created: created || sessionCreated,
|
||||||
model: parsed.data.model,
|
model: parsed.data.model,
|
||||||
approvalMode: parsed.data.approval_mode,
|
approvalMode: parsed.data.approval_mode,
|
||||||
|
regenerateFromMessageIndex: parsed.data.regenerate_from_message_index,
|
||||||
traceId: requestContext.traceId,
|
traceId: requestContext.traceId,
|
||||||
projectId: requestContext.projectId,
|
projectId: requestContext.projectId,
|
||||||
},
|
},
|
||||||
@@ -997,7 +1036,10 @@ export const buildChatRouter = (
|
|||||||
baseMessages,
|
baseMessages,
|
||||||
parsed.data.message,
|
parsed.data.message,
|
||||||
);
|
);
|
||||||
const branchGroups = initialSessionState?.branchGroups ?? [];
|
const branchGroups = pruneBranchGroupsForMessageIndex(
|
||||||
|
initialSessionState?.branchGroups ?? [],
|
||||||
|
parsed.data.regenerate_from_message_index,
|
||||||
|
);
|
||||||
const activeRun: ActiveRun = {
|
const activeRun: ActiveRun = {
|
||||||
clientSessionId,
|
clientSessionId,
|
||||||
controller: abortController,
|
controller: abortController,
|
||||||
@@ -1128,8 +1170,13 @@ export const buildChatRouter = (
|
|||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (parsed.data.regenerate_from_message_index !== undefined) {
|
if (isRegenerate) {
|
||||||
await runtime.revertLastUserMessage(binding.sessionId);
|
if (!targetUserOrdinal || targetUserOrdinal < 1) {
|
||||||
|
throw new Error("target user message not found for regeneration");
|
||||||
|
}
|
||||||
|
await runtime.revertToUserMessage(binding.sessionId, {
|
||||||
|
userOrdinal: targetUserOrdinal,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const preparedMessage = await buildPromptWithLearningContext(
|
const preparedMessage = await buildPromptWithLearningContext(
|
||||||
|
|||||||
+46
-8
@@ -33,6 +33,16 @@ type RuntimeModelOverride = {
|
|||||||
|
|
||||||
export type PermissionReply = "once" | "always" | "reject";
|
export type PermissionReply = "once" | "always" | "reject";
|
||||||
|
|
||||||
|
type RuntimeMessage = {
|
||||||
|
info: {
|
||||||
|
id: string;
|
||||||
|
role: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const getRuntimeMessageRole = (message: RuntimeMessage) => message.info.role;
|
||||||
|
const getRuntimeMessageId = (message: RuntimeMessage) => message.info.id;
|
||||||
|
|
||||||
export class OpencodeRuntimeAdapter {
|
export class OpencodeRuntimeAdapter {
|
||||||
private clientPromise: Promise<OpencodeClient> | null = null;
|
private clientPromise: Promise<OpencodeClient> | null = null;
|
||||||
private closeServer: (() => void) | null = null;
|
private closeServer: (() => void) | null = null;
|
||||||
@@ -108,17 +118,45 @@ export class OpencodeRuntimeAdapter {
|
|||||||
return response.data;
|
return response.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
async revertLastUserMessage(sessionId: string) {
|
async removeMessage(sessionId: string, messageId: string) {
|
||||||
const messages = await this.messages(sessionId, 40);
|
const client = await this.ensureClient();
|
||||||
const lastUserMessage = [...messages]
|
const response = await client.session.deleteMessage({
|
||||||
.reverse()
|
sessionID: sessionId,
|
||||||
.find((message) => message.info.role === "user");
|
messageID: messageId,
|
||||||
|
});
|
||||||
|
return response.data;
|
||||||
|
}
|
||||||
|
|
||||||
if (!lastUserMessage) {
|
async revertToUserMessage(sessionId: string, options: { userOrdinal: number }) {
|
||||||
throw new Error("no user message found to revert");
|
const messages = await this.messages(sessionId, 80);
|
||||||
|
const userMessages = messages.filter(
|
||||||
|
(message) => getRuntimeMessageRole(message) === "user",
|
||||||
|
);
|
||||||
|
const targetUserMessage = userMessages[options.userOrdinal - 1];
|
||||||
|
|
||||||
|
if (!targetUserMessage) {
|
||||||
|
throw new Error("target user message not found to revert");
|
||||||
}
|
}
|
||||||
|
|
||||||
return this.revertMessage(sessionId, lastUserMessage.info.id);
|
const targetMessageId = getRuntimeMessageId(targetUserMessage);
|
||||||
|
const targetIndex = messages.findIndex(
|
||||||
|
(message) => getRuntimeMessageId(message) === targetMessageId,
|
||||||
|
);
|
||||||
|
const messagesToRemove = targetIndex >= 0 ? messages.slice(targetIndex) : [targetUserMessage];
|
||||||
|
|
||||||
|
await this.revertMessage(sessionId, targetMessageId);
|
||||||
|
|
||||||
|
for (const message of messagesToRemove.reverse()) {
|
||||||
|
const messageId = getRuntimeMessageId(message);
|
||||||
|
try {
|
||||||
|
await this.removeMessage(sessionId, messageId);
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn(
|
||||||
|
{ err: error, sessionId, messageId },
|
||||||
|
"failed to remove reverted opencode message",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async abortSession(sessionId: string) {
|
async abortSession(sessionId: string) {
|
||||||
|
|||||||
@@ -147,6 +147,29 @@ export class SessionTranscriptStore {
|
|||||||
return nextTranscript;
|
return nextTranscript;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async truncateThread(
|
||||||
|
context: SessionTranscriptContext,
|
||||||
|
keepMessageCount: number,
|
||||||
|
) {
|
||||||
|
const key = this.filePath(context);
|
||||||
|
return this.serializeWrite(key, async () => {
|
||||||
|
const transcript = await this.readTranscript(context);
|
||||||
|
if (!transcript) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const nextTranscript: SessionTranscriptRecord = {
|
||||||
|
...transcript,
|
||||||
|
clientSessionId: context.clientSessionId ?? transcript.clientSessionId,
|
||||||
|
sessionId: context.sessionId,
|
||||||
|
turns: projectTurnsForFork(transcript.turns, keepMessageCount),
|
||||||
|
updatedAt: new Date().toISOString(),
|
||||||
|
};
|
||||||
|
await atomicWriteJson(key, nextTranscript);
|
||||||
|
return nextTranscript;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
async search(
|
async search(
|
||||||
context: Pick<SessionTranscriptContext, "actorKey" | "projectKey">,
|
context: Pick<SessionTranscriptContext, "actorKey" | "projectKey">,
|
||||||
query: string,
|
query: string,
|
||||||
|
|||||||
Reference in New Issue
Block a user