fix(chat): restore forked context
This commit is contained in:
+23
-41
@@ -19,6 +19,7 @@ import {
|
||||
extractLatestFrontendTurn,
|
||||
generateSessionTitle,
|
||||
shouldGenerateSessionTitle,
|
||||
shouldRestoreConversationForRuntime,
|
||||
} from "./chatSession.js";
|
||||
import { registerChatAuxiliaryRoutes } from "./chatAuxiliaryRoutes.js";
|
||||
import { registerChatInteractionRoutes } from "./chatInteractionRoutes.js";
|
||||
@@ -37,10 +38,8 @@ import {
|
||||
type StreamSubscriber,
|
||||
cancelBackendTodos,
|
||||
completeBackendProgress,
|
||||
countFrontendUserMessages,
|
||||
createInitialStreamingMessages,
|
||||
isObjectRecord,
|
||||
pruneBranchGroupsForMessageIndex,
|
||||
toFrontendPermission,
|
||||
toPermissionStatus,
|
||||
updateLastAssistantMessage,
|
||||
@@ -56,7 +55,6 @@ const payloadSchema = z.object({
|
||||
session_id: z.string().max(128).optional(),
|
||||
model: z.enum(supportedModels).optional(),
|
||||
approval_mode: z.enum(["request", "always"]).optional().default("request"),
|
||||
regenerate_from_message_index: z.coerce.number().int().min(0).optional(),
|
||||
});
|
||||
|
||||
const createSessionPayloadSchema = z.object({
|
||||
@@ -86,6 +84,17 @@ const toSessionUiStateContext = (sessionRecord: SessionRecord) => ({
|
||||
const getSessionRunStatus = (sessionId: string) =>
|
||||
activeRuns.get(sessionId)?.status ?? lastRunStatuses.get(sessionId);
|
||||
|
||||
const runtimeHasConversation = async (
|
||||
runtime: OpencodeRuntimeAdapter,
|
||||
sessionId: string,
|
||||
) => {
|
||||
const messages = await runtime.messages(sessionId, 1);
|
||||
return messages.some(
|
||||
(message) =>
|
||||
message.info.role === "user" || message.info.role === "assistant",
|
||||
);
|
||||
};
|
||||
|
||||
export const buildChatRouter = (
|
||||
sessionBridge: ChatSessionBridge,
|
||||
runtime: OpencodeRuntimeAdapter,
|
||||
@@ -555,6 +564,13 @@ export const buildChatRouter = (
|
||||
userId,
|
||||
});
|
||||
const activeSessionRecord = await sessionMetadataStore.touch(ensuredSessionRecord);
|
||||
const hasRuntimeConversation = hadExistingRuntimeSession
|
||||
? await runtimeHasConversation(runtime, binding.sessionId)
|
||||
: false;
|
||||
const shouldRestoreConversation = shouldRestoreConversationForRuntime({
|
||||
hadExistingSessionRecord: hadExistingRuntimeSession,
|
||||
runtimeHasConversation: hasRuntimeConversation,
|
||||
});
|
||||
const historyContext = {
|
||||
actorKey: requestContext.actorKey,
|
||||
clientSessionId: requestContext.clientSessionId,
|
||||
@@ -565,20 +581,7 @@ export const buildChatRouter = (
|
||||
toSessionUiStateContext(activeSessionRecord),
|
||||
);
|
||||
const persistedMessages = initialSessionState?.messages ?? [];
|
||||
const isRegenerate =
|
||||
parsed.data.regenerate_from_message_index !== undefined;
|
||||
const baseMessages =
|
||||
isRegenerate
|
||||
? persistedMessages.slice(0, parsed.data.regenerate_from_message_index)
|
||||
: persistedMessages;
|
||||
const targetUserOrdinal = isRegenerate
|
||||
? countFrontendUserMessages(
|
||||
persistedMessages.slice(
|
||||
0,
|
||||
(parsed.data.regenerate_from_message_index ?? 0) + 1,
|
||||
),
|
||||
)
|
||||
: undefined;
|
||||
const baseMessages = persistedMessages;
|
||||
if (activeRuns.get(activeSessionRecord.sessionId)?.status === "running") {
|
||||
res.status(409).json({
|
||||
message: "session is already streaming",
|
||||
@@ -586,15 +589,7 @@ export const buildChatRouter = (
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (isRegenerate) {
|
||||
await sessionTranscriptStore.truncateThread(
|
||||
historyContext,
|
||||
parsed.data.regenerate_from_message_index ?? 0,
|
||||
);
|
||||
}
|
||||
const recentTurns = isRegenerate
|
||||
? []
|
||||
: await sessionTranscriptStore.getRecentTurns(historyContext, 8);
|
||||
const recentTurns = await sessionTranscriptStore.getRecentTurns(historyContext, 8);
|
||||
|
||||
logger.info(
|
||||
{
|
||||
@@ -603,7 +598,6 @@ export const buildChatRouter = (
|
||||
created: created || sessionCreated,
|
||||
model: parsed.data.model,
|
||||
approvalMode: parsed.data.approval_mode,
|
||||
regenerateFromMessageIndex: parsed.data.regenerate_from_message_index,
|
||||
traceId: requestContext.traceId,
|
||||
projectId: requestContext.projectId,
|
||||
},
|
||||
@@ -625,10 +619,7 @@ export const buildChatRouter = (
|
||||
baseMessages,
|
||||
parsed.data.message,
|
||||
);
|
||||
const branchGroups = pruneBranchGroupsForMessageIndex(
|
||||
initialSessionState?.branchGroups ?? [],
|
||||
parsed.data.regenerate_from_message_index,
|
||||
);
|
||||
const branchGroups = initialSessionState?.branchGroups ?? [];
|
||||
const activeRun: ActiveRun = {
|
||||
clientSessionId,
|
||||
controller: abortController,
|
||||
@@ -815,15 +806,6 @@ export const buildChatRouter = (
|
||||
};
|
||||
|
||||
try {
|
||||
if (isRegenerate) {
|
||||
if (!targetUserOrdinal || targetUserOrdinal < 1) {
|
||||
throw new Error("target user message not found for regeneration");
|
||||
}
|
||||
await runtime.revertToUserMessage(binding.sessionId, {
|
||||
userOrdinal: targetUserOrdinal,
|
||||
});
|
||||
}
|
||||
|
||||
const preparedMessage = await buildPromptWithLearningContext(
|
||||
memoryStore,
|
||||
requestContext.actorKey,
|
||||
@@ -832,7 +814,7 @@ export const buildChatRouter = (
|
||||
recentTurns,
|
||||
persistedMessages: baseMessages,
|
||||
message: parsed.data.message,
|
||||
restoreConversation: !hadExistingRuntimeSession,
|
||||
restoreConversation: shouldRestoreConversation,
|
||||
},
|
||||
);
|
||||
const streamResult = await streamPromptResponse({
|
||||
|
||||
@@ -212,6 +212,11 @@ export const buildPromptWithLearningContext = async (
|
||||
.join("\n\n");
|
||||
};
|
||||
|
||||
export const shouldRestoreConversationForRuntime = (options: {
|
||||
hadExistingSessionRecord: boolean;
|
||||
runtimeHasConversation: boolean;
|
||||
}) => !options.hadExistingSessionRecord || !options.runtimeHasConversation;
|
||||
|
||||
const buildRestoredConversationContext = (recentTurns: SessionTurnRecord[]) => {
|
||||
const formattedTurns = recentTurns
|
||||
.slice(-RESTORE_TURN_LIMIT)
|
||||
|
||||
@@ -63,26 +63,6 @@ export const createInitialStreamingMessages = (
|
||||
];
|
||||
};
|
||||
|
||||
export const countFrontendUserMessages = (messages: unknown[]) =>
|
||||
messages.filter(
|
||||
(message) => isObjectRecord(message) && message.role === "user",
|
||||
).length;
|
||||
|
||||
export const pruneBranchGroupsForMessageIndex = (
|
||||
branchGroups: unknown[],
|
||||
messageIndex: number | undefined,
|
||||
) => {
|
||||
if (messageIndex === undefined) {
|
||||
return branchGroups;
|
||||
}
|
||||
return branchGroups.filter(
|
||||
(group) =>
|
||||
!isObjectRecord(group) ||
|
||||
typeof group.parentCount !== "number" ||
|
||||
group.parentCount < messageIndex,
|
||||
);
|
||||
};
|
||||
|
||||
export const upsertBackendProgress = (
|
||||
progress: unknown,
|
||||
payload: Record<string, unknown>,
|
||||
|
||||
Reference in New Issue
Block a user