cb298f2099
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
359 lines
11 KiB
TypeScript
359 lines
11 KiB
TypeScript
import { Router } from "express";
|
|
import { z } from "zod";
|
|
|
|
import { type LearningOrchestrator } from "../learning/orchestrator.js";
|
|
import { logger } from "../logger.js";
|
|
import { MemoryStore } from "../memory/store.js";
|
|
import { type ResultReferenceResolver } from "../results/resolver.js";
|
|
import { RESULT_REFERENCE_KIND } from "../results/store.js";
|
|
import { type OpencodeRuntimeAdapter } from "../runtime/opencode.js";
|
|
import { type ChatSessionBridge } from "../chat/sessionBridge.js";
|
|
import { toActorKey } from "../utils/fileStore.js";
|
|
import {
|
|
buildPromptWithLearningContext,
|
|
generateSessionTitle,
|
|
getConversationTurnStats,
|
|
} from "./chatSession.js";
|
|
import {
|
|
collectTextContent,
|
|
streamPromptResponse,
|
|
supportedModels,
|
|
type SupportedModel,
|
|
} from "./chatStream.js";
|
|
|
|
/**
 * Request body accepted by `POST /stream`.
 *
 * - `message`: the user's text; non-empty, at most 10,000 long (JS string length).
 * - `session_id`: optional client session identifier, at most 128 long.
 * - `model`: optional model selection, restricted to `supportedModels`.
 */
const payloadSchema = z.object({
  message: z.string().max(10_000).min(1),
  session_id: z.optional(z.string().max(128)),
  model: z.optional(z.enum(supportedModels)),
});
|
|
|
|
/** Request body accepted by `POST /abort`: the (required) client session id, at most 128 long. */
const abortPayloadSchema = z.object({ session_id: z.string().max(128) });
|
|
|
|
/**
 * Request body accepted by `POST /fork`.
 *
 * - `session_id`: optional client session identifier, at most 128 long.
 * - `keep_message_count`: non-negative integer. It is coerced via `Number(...)`,
 *   so numeric strings such as `"3"` are accepted.
 *   NOTE(review): coercion also turns `null` / `""` into 0 — confirm that is intended.
 */
const forkPayloadSchema = z.object({
  session_id: z.optional(z.string().max(128)),
  keep_message_count: z.coerce.number().int().nonnegative(),
});
|
|
|
|
/** Caller-supplied auth/tracing headers forwarded to the session bridge. */
type ForwardedHeaders = {
  accessToken: string | undefined;
  projectId: string | undefined;
  traceId: string | undefined;
  userId: string | undefined;
};

/** The only part of an Express request that `readForwardedHeaders` needs. */
type HeaderReader = { header(name: string): string | undefined };

const BEARER_SCHEME_PREFIX = "bearer ";

/**
 * Normalizes an `Authorization` header value into a bare access token.
 *
 * A leading `Bearer ` scheme is stripped. The scheme name is compared
 * case-insensitively because HTTP authentication schemes are case-insensitive
 * (RFC 9110 §11.1). Any other value, including one with a different scheme,
 * is passed through unchanged.
 */
const toAccessToken = (authHeader: string | undefined): string | undefined => {
  if (
    authHeader !== undefined &&
    authHeader.slice(0, BEARER_SCHEME_PREFIX.length).toLowerCase() ===
      BEARER_SCHEME_PREFIX
  ) {
    return authHeader.slice(BEARER_SCHEME_PREFIX.length);
  }
  return authHeader;
};

/** Reads the auth, project, trace and user headers shared by the session-bound routes. */
const readForwardedHeaders = (req: HeaderReader): ForwardedHeaders => ({
  accessToken: toAccessToken(req.header("authorization")),
  projectId: req.header("x-project-id"),
  traceId: req.header("x-trace-id"),
  userId: req.header("x-user-id"),
});

/** Renders an unknown thrown value as a human-readable message. */
const describeError = (error: unknown): string =>
  error instanceof Error ? error.message : String(error);

/**
 * Builds the Express router that serves the chat API.
 *
 * Routes:
 * - `GET /render-ref/:renderRef`: looks up a stored render-junctions result for
 *   the caller. `x-user-id` is required.
 * - `POST /abort`: asks the session bridge to abort a client session.
 * - `POST /fork`: forks a client session, keeping `keep_message_count` messages.
 * - `POST /stream`: sends a user message to the runtime and streams the reply
 *   back as server-sent events. It then does best-effort post-turn work
 *   (session titling and learning notification).
 *
 * @param sessionBridge Resolves/aborts/forks client sessions and stores session titles.
 * @param runtime Runtime adapter used to stream prompts and read session messages.
 * @param memoryStore Store consulted when enriching the prompt with learning context.
 * @param learningOrchestrator Notified, fire-and-forget, after each completed turn.
 * @param resultReferenceResolver Used via `getFullAuthorized` to look up render refs for the caller.
 * @returns A router for the caller to mount.
 */
export const buildChatRouter = (
  sessionBridge: ChatSessionBridge,
  runtime: OpencodeRuntimeAdapter,
  memoryStore: MemoryStore,
  learningOrchestrator: LearningOrchestrator,
  resultReferenceResolver: ResultReferenceResolver,
) => {
  const chatRouter = Router();

  chatRouter.get("/render-ref/:renderRef", async (req, res) => {
    const renderRef = req.params.renderRef?.trim();
    const userId = req.header("x-user-id")?.trim();
    const projectId = req.header("x-project-id") ?? undefined;
    const clientSessionId =
      typeof req.query.session_id === "string"
        ? req.query.session_id.trim()
        : undefined;

    if (!userId) {
      res.status(400).json({ message: "x-user-id is required" });
      return;
    }
    if (!renderRef) {
      res.status(400).json({ message: "render_ref is required" });
      return;
    }

    // Guarded like the other routes. Under Express 4 a rejected promise from an
    // async handler is not routed to error middleware, so the request would hang.
    try {
      const result = await resultReferenceResolver.getFullAuthorized(
        renderRef,
        { actorKey: toActorKey(userId), clientSessionId, projectId },
        { expectedKind: RESULT_REFERENCE_KIND.renderJunctionsPayload },
      );
      if (!result) {
        res.status(404).json({ message: "render_ref not found" });
        return;
      }
      res.json(result);
    } catch (error) {
      logger.error({ err: error, renderRef }, "render_ref lookup failed");
      res.status(500).json({
        message: "render_ref lookup failed",
        detail: describeError(error),
      });
    }
  });

  chatRouter.post("/abort", async (req, res) => {
    const parsed = abortPayloadSchema.safeParse(req.body);
    if (!parsed.success) {
      res.status(400).json({
        message: "invalid request payload",
        detail: parsed.error.flatten(),
      });
      return;
    }

    const clientSessionId = parsed.data.session_id;
    try {
      const { accessToken, projectId, traceId, userId } =
        readForwardedHeaders(req);
      const binding = await sessionBridge.abort({
        clientSessionId,
        accessToken,
        projectId,
        traceId,
        userId,
      });

      if (!binding) {
        // The bridge returned no binding for this id, so there is nothing to report.
        res.status(204).end();
        return;
      }

      logger.info(
        { clientSessionId, sessionId: binding.sessionId, traceId, projectId },
        "aborted chat session by client request",
      );
      res.status(202).json({ session_id: clientSessionId, aborted: true });
    } catch (error) {
      logger.error({ err: error }, "chat abort failed");
      res.status(500).json({
        message: "chat abort failed",
        detail: describeError(error),
      });
    }
  });

  chatRouter.post("/fork", async (req, res) => {
    const parsed = forkPayloadSchema.safeParse(req.body);
    if (!parsed.success) {
      res.status(400).json({
        message: "invalid request payload",
        detail: parsed.error.flatten(),
      });
      return;
    }

    const { session_id: sourceClientSessionId, keep_message_count: keepMessageCount } =
      parsed.data;
    try {
      const { accessToken, projectId, traceId, userId } =
        readForwardedHeaders(req);
      const { binding, requestContext } = await sessionBridge.fork({
        clientSessionId: sourceClientSessionId,
        accessToken,
        projectId,
        traceId,
        keepMessageCount,
        userId,
      });

      logger.info(
        {
          sourceClientSessionId,
          clientSessionId: requestContext.clientSessionId,
          sessionId: binding.sessionId,
          traceId: requestContext.traceId,
          projectId: requestContext.projectId,
          keepMessageCount,
        },
        "forked chat session",
      );
      res.status(200).json({ session_id: requestContext.clientSessionId });
    } catch (error) {
      logger.error({ err: error }, "chat fork failed");
      res.status(500).json({
        message: "chat fork failed",
        detail: describeError(error),
      });
    }
  });

  chatRouter.post("/stream", async (req, res) => {
    const parsed = payloadSchema.safeParse(req.body);
    if (!parsed.success) {
      res.status(400).json({
        message: "invalid request payload",
        detail: parsed.error.flatten(),
      });
      return;
    }

    const { message, model } = parsed.data;
    /** Writes one SSE frame to the response. Callers check writability first. */
    const sendEvent = (event: string, data: Record<string, unknown>) => {
      res.write(toSse(event, data));
    };

    try {
      const { accessToken, projectId, traceId, userId } =
        readForwardedHeaders(req);
      const { binding, requestContext, created } = await sessionBridge.resolve({
        clientSessionId: parsed.data.session_id,
        accessToken,
        projectId,
        traceId,
        userId,
      });

      logger.info(
        {
          clientSessionId: requestContext.clientSessionId,
          sessionId: binding.sessionId,
          created,
          model,
          traceId: requestContext.traceId,
          projectId: requestContext.projectId,
        },
        "processing chat request",
      );

      res.status(200);
      res.setHeader("Content-Type", "text/event-stream; charset=utf-8");
      res.setHeader("Cache-Control", "no-cache");
      res.setHeader("Connection", "keep-alive");
      res.setHeader("X-Accel-Buffering", "no");
      res.flushHeaders?.();

      const clientSessionId = requestContext.clientSessionId;
      const abortController = new AbortController();
      let streamClosed = false;
      const canWrite = () =>
        !streamClosed && !res.writableEnded && !res.destroyed;

      // Only the response is watched. Since Node 16, `req` emits "close" once
      // its body has been consumed, not when the client disconnects. `res` emits
      // "close" on completion or on a dropped connection, and the
      // `writableFinished` check keeps a normal completion from aborting.
      const handleClientClose = () => {
        if (
          streamClosed ||
          res.writableFinished ||
          abortController.signal.aborted
        ) {
          return;
        }
        abortController.abort();
      };
      res.on("close", handleClientClose);

      try {
        const preparedMessage = await buildPromptWithLearningContext(
          memoryStore,
          requestContext.actorKey,
          requestContext.projectKey,
          message,
        );
        const streamResult = await streamPromptResponse({
          runtime,
          opencodeSessionId: binding.sessionId,
          clientSessionId,
          message: preparedMessage,
          model,
          traceId: requestContext.traceId,
          projectId: requestContext.projectId,
          signal: abortController.signal,
          write: (event, data) => {
            if (canWrite()) {
              sendEvent(event, data);
            }
          },
        });

        if (!streamResult.aborted && !streamResult.failed) {
          // The reply has already been streamed. Follow-up bookkeeping failures
          // are logged rather than turned into a failed request.
          try {
            // NOTE(review): the second argument presumably caps how many recent
            // messages are fetched. Confirm against OpencodeRuntimeAdapter.
            const messages = await runtime.messages(binding.sessionId, 60);
            const assistantMessage = [...messages]
              .reverse()
              .find((entry) => entry.info.role === "assistant");
            const assistantText = collectTextContent(
              assistantMessage?.parts ?? [],
            );

            const existingSessionTitle = sessionBridge.getSessionTitle(
              binding.sessionId,
            );
            const { userMessageCount, assistantMessageCount } =
              await getConversationTurnStats(runtime, binding.sessionId);

            // Titles are only (re)generated during the first few exchanges.
            if (userMessageCount <= 3 && assistantMessageCount >= 1) {
              const sessionTitle = await generateSessionTitle(runtime, {
                sessionId: binding.sessionId,
                latestUserMessage: message,
                fallbackTitle: existingSessionTitle,
              });
              if (sessionTitle !== existingSessionTitle) {
                sessionBridge.setSessionTitle(binding.sessionId, sessionTitle);
                if (sessionTitle && canWrite()) {
                  sendEvent("session_title", {
                    session_id: clientSessionId,
                    title: sessionTitle,
                  });
                }
              }
            }

            if (assistantText) {
              void learningOrchestrator
                .onTurnCompleted({
                  assistantMessage: assistantText,
                  model,
                  requestContext,
                  sessionId: binding.sessionId,
                  toolCallCount: streamResult.toolCallCount,
                  userMessage: message,
                })
                .catch((error) => {
                  logger.warn(
                    { err: error, sessionId: binding.sessionId },
                    "post-turn learning failed",
                  );
                });
            }
          } catch (error) {
            logger.warn(
              { err: error, sessionId: binding.sessionId },
              "post-turn processing failed",
            );
          }
        }
      } finally {
        streamClosed = true;
        res.off("close", handleClientClose);
      }

      if (!res.writableEnded && !res.destroyed) {
        res.end();
      }
    } catch (error) {
      const detail = describeError(error);
      logger.error({ err: error }, "chat stream failed");

      if (!res.headersSent) {
        res.status(500).json({ message: "chat stream failed", detail });
        return;
      }
      // The SSE headers are already on the wire. A JSON 500 is no longer
      // possible: Node's setHeader would throw ERR_HTTP_HEADERS_SENT and the
      // stream would be left open. Report the failure in-band and close instead.
      if (!res.writableEnded && !res.destroyed) {
        sendEvent("error", { message: "chat stream failed", detail });
        res.end();
      }
    }
  });

  return chatRouter;
};
|
|
|
|
/**
 * Formats a single server-sent-event frame.
 *
 * The frame is an `event:` line, one `data:` line holding the JSON-serialized
 * payload, and the blank line that terminates it. `JSON.stringify` escapes
 * newlines inside strings, so the payload always fits on one `data:` line.
 *
 * @param event SSE event name. It must not contain newlines, or the frame breaks.
 * @param data Payload object serialized into the `data:` field.
 * @returns The wire text of the frame, ending in `\n\n`.
 */
const toSse = (event: string, data: Record<string, unknown>) =>
  [`event: ${event}`, `data: ${JSON.stringify(data)}`, "", ""].join("\n");
|