450 lines
17 KiB
JavaScript
450 lines
17 KiB
JavaScript
import { Router } from "express";
|
|
import { z } from "zod";
|
|
import { logger } from "../logger.js";
|
|
/**
 * Request body accepted by `POST /stream`.
 *
 * - `message`: the user prompt, 1 to 10000 characters long.
 * - `session_id`: optional client session id (at most 128 characters), handed to
 *   `sessionBridge.resolve` as `clientSessionId`.
 */
const payloadSchema = z.object({ message: z.string().min(1).max(10000), session_id: z.string().max(128).optional() });
|
|
/**
 * Request body accepted by `POST /abort`: the (required) client session id to
 * abort, at most 128 characters.
 */
const abortPayloadSchema = z.object({ session_id: z.string().max(128) });
|
|
/**
 * Request body accepted by `POST /fork`.
 *
 * - `session_id`: optional client session id to fork from (at most 128 characters).
 * - `keep_message_count`: non-negative integer passed to the bridge as
 *   `keepMessageCount`. `z.coerce` runs the raw value through `Number(...)` before
 *   validation, so numeric strings are accepted as well.
 */
const forkPayloadSchema = z.object({
    session_id: z.string().max(128).optional(),
    keep_message_count: z.coerce.number().int().min(0),
});
|
|
/**
 * Reads the credential and correlation headers shared by every chat route.
 *
 * A leading "Bearer " (exact case) is stripped from `Authorization`; any other
 * value is forwarded unchanged. Absent headers come back as `undefined`.
 *
 * @param {object} req - Express request.
 * @returns {{ accessToken: (string|undefined), projectId: (string|undefined), traceId: (string|undefined) }}
 */
const readRequestContext = (req) => {
    const authHeader = req.header("authorization");
    const accessToken = authHeader?.startsWith("Bearer ")
        ? authHeader.slice("Bearer ".length)
        : authHeader;
    const projectId = req.header("x-project-id") ?? undefined;
    const traceId = req.header("x-trace-id") ?? undefined;
    return { accessToken, projectId, traceId };
};

/**
 * Responds 400 with the flattened zod validation error.
 *
 * @param {object} res - Express response.
 * @param {object} validationError - the `error` of a failed `safeParse`.
 */
const sendInvalidPayload = (res, validationError) => {
    res.status(400).json({
        message: "invalid request payload",
        detail: validationError.flatten(),
    });
};

/** Turns a thrown value into the `detail` string reported to the client. */
const describeError = (error) => (error instanceof Error ? error.message : String(error));

/**
 * Logs `error` under `message` and reports the failure to the client.
 *
 * When nothing has been sent yet, responds 500 with `{ message, detail }`. Once
 * headers are on the wire a JSON body is no longer possible (Express's `res.send`
 * would throw ERR_HTTP_HEADERS_SENT), so the response is only ended.
 *
 * @param {object} res - Express response.
 * @param {string} message - log message and client-facing `message`.
 * @param {unknown} error - the caught value.
 */
const sendFailure = (res, message, error) => {
    logger.error({ err: error }, message);
    if (res.headersSent) {
        if (!res.writableEnded && !res.destroyed) {
            res.end();
        }
        return;
    }
    res.status(500).json({
        message,
        detail: describeError(error),
    });
};

/**
 * Builds the chat router with three POST routes:
 *
 * - `/abort`: asks the session bridge to abort `session_id`. Responds 204 when
 *   the bridge returns no binding, otherwise 202 `{ session_id, aborted: true }`.
 * - `/fork`: forks a session keeping `keep_message_count` messages and responds
 *   200 `{ session_id }` with the client session id of the fork.
 * - `/stream`: resolves the bridged session and streams the agent's response as
 *   Server-Sent Events (see `streamPromptResponse` for the event vocabulary).
 *
 * Invalid bodies get 400. Failures before any bytes are sent get 500
 * `{ message, detail }`. Failures after an SSE stream has started are reported
 * in-band as an `error` event, and then the stream is closed.
 *
 * @param {object} sessionBridge - bridge between client session ids and opencode
 *   sessions; its `abort`, `fork` and `resolve` methods are called here.
 * @param {object} runtime - opencode runtime handed to the stream relay.
 * @returns {object} the Express router.
 */
export const buildChatRouter = (sessionBridge, runtime) => {
    const chatRouter = Router();

    chatRouter.post("/abort", async (req, res) => {
        const parsed = abortPayloadSchema.safeParse(req.body);
        if (!parsed.success) {
            sendInvalidPayload(res, parsed.error);
            return;
        }
        try {
            const { accessToken, projectId, traceId } = readRequestContext(req);
            const binding = await sessionBridge.abort({
                clientSessionId: parsed.data.session_id,
                accessToken,
                projectId,
                traceId,
            });
            if (!binding) {
                // The bridge returned no binding for this client session.
                res.status(204).end();
                return;
            }
            logger.info({
                clientSessionId: parsed.data.session_id,
                sessionId: binding.sessionId,
                traceId,
                projectId,
            }, "aborted chat session by client request");
            res.status(202).json({
                session_id: parsed.data.session_id,
                aborted: true,
            });
        }
        catch (error) {
            sendFailure(res, "chat abort failed", error);
        }
    });

    chatRouter.post("/fork", async (req, res) => {
        const parsed = forkPayloadSchema.safeParse(req.body);
        if (!parsed.success) {
            sendInvalidPayload(res, parsed.error);
            return;
        }
        try {
            const { accessToken, projectId, traceId } = readRequestContext(req);
            const { binding, requestContext } = await sessionBridge.fork({
                clientSessionId: parsed.data.session_id,
                accessToken,
                projectId,
                traceId,
                keepMessageCount: parsed.data.keep_message_count,
            });
            logger.info({
                sourceClientSessionId: parsed.data.session_id,
                clientSessionId: requestContext.clientSessionId,
                sessionId: binding.sessionId,
                traceId: requestContext.traceId,
                projectId: requestContext.projectId,
                keepMessageCount: parsed.data.keep_message_count,
            }, "forked chat session");
            res.status(200).json({
                session_id: requestContext.clientSessionId,
            });
        }
        catch (error) {
            sendFailure(res, "chat fork failed", error);
        }
    });

    chatRouter.post("/stream", async (req, res) => {
        const parsed = payloadSchema.safeParse(req.body);
        if (!parsed.success) {
            sendInvalidPayload(res, parsed.error);
            return;
        }
        // Replaced by the bridge-resolved id once known; tags an in-band error event.
        let clientSessionId = parsed.data.session_id;
        try {
            const { accessToken, projectId, traceId } = readRequestContext(req);
            const { binding, requestContext, created } = await sessionBridge.resolve({
                clientSessionId: parsed.data.session_id,
                accessToken,
                projectId,
                traceId,
            });
            clientSessionId = requestContext.clientSessionId;
            logger.info({
                clientSessionId: requestContext.clientSessionId,
                sessionId: binding.sessionId,
                created,
                traceId: requestContext.traceId,
                projectId: requestContext.projectId,
            }, "processing chat request");
            res.status(200);
            res.setHeader("Content-Type", "text/event-stream; charset=utf-8");
            res.setHeader("Cache-Control", "no-cache");
            res.setHeader("Connection", "keep-alive");
            // Asks reverse proxies (nginx) not to buffer the event stream.
            res.setHeader("X-Accel-Buffering", "no");
            res.flushHeaders?.();
            let streamClosed = false;
            const abortController = new AbortController();
            // A 'close' before the relay finishes aborts the agent run.
            // NOTE(review): on recent Node versions `req` may emit 'close' once its body
            // has been consumed rather than on disconnect; `res` 'close' is the reliable
            // disconnect signal. Confirm the `req` listener is still wanted.
            const handleClientClose = () => {
                if (streamClosed || abortController.signal.aborted) {
                    return;
                }
                abortController.abort();
            };
            req.on("close", handleClientClose);
            res.on("close", handleClientClose);
            try {
                await streamPromptResponse({
                    runtime,
                    opencodeSessionId: binding.sessionId,
                    clientSessionId,
                    message: parsed.data.message,
                    signal: abortController.signal,
                    write: (event, data) => {
                        // Drop writes once the relay is over or the connection is gone.
                        if (streamClosed || res.writableEnded || res.destroyed) {
                            return;
                        }
                        res.write(toSse(event, data));
                    },
                });
            }
            finally {
                streamClosed = true;
                req.off("close", handleClientClose);
                res.off("close", handleClientClose);
            }
            if (!res.writableEnded && !res.destroyed) {
                res.end();
            }
        }
        catch (error) {
            if (res.headersSent && !res.writableEnded && !res.destroyed) {
                // The SSE stream is already open, so a JSON 500 is impossible; tell the
                // client in-band before sendFailure closes the response.
                res.write(toSse("error", {
                    session_id: clientSessionId,
                    message: "chat stream failed",
                    detail: describeError(error),
                }));
            }
            sendFailure(res, "chat stream failed", error);
        }
    });

    return chatRouter;
};
|
|
/**
 * Serialises one Server-Sent Events frame: an `event:` line, a single `data:` line
 * carrying the JSON-encoded payload, and the blank line that terminates the frame.
 * JSON.stringify escapes line breaks inside strings, so the payload always fits on
 * one `data:` line.
 */
const toSse = (event, data) => {
    const payload = JSON.stringify(data);
    return [`event: ${event}`, `data: ${payload}`, "", ""].join("\n");
};
|
|
/**
 * Picks a readable message from an opencode session error. It returns the nested
 * `data.message` when it is neither null nor undefined, and falls back to the
 * error's `name` otherwise.
 */
const getErrorMessage = (error) => {
    const nestedMessage = error.data?.message;
    return nestedMessage ?? error.name;
};
|
|
/**
 * Sends one prompt to an opencode session and relays that session's events to the
 * client through `write(event, data)`. Relaying stops when the session goes idle,
 * reports an error, the event stream ends, or `signal` aborts.
 *
 * Events written: `progress` (start/session/planning/tool/todo updates), `token`
 * (assistant text), `tool_call` (once per tool part), `error` (session errors).
 * A non-aborted run also writes the final completion `progress` events and `done`.
 * If no text was streamed, `emitFallbackMessage` is asked to emit the latest
 * assistant text as a single `token`.
 *
 * @param {object} options
 * @param {object} options.runtime - opencode runtime; `subscribeEvents`, `prompt`
 *   and `abortSession` are called here, and the fallback path reads `messages`.
 * @param {string} options.opencodeSessionId - session receiving the prompt; events
 *   for other sessions are ignored.
 * @param {string} options.clientSessionId - echoed to the client as `session_id`.
 * @param {string} options.message - prompt text.
 * @param {AbortSignal} [options.signal] - on abort the relay stops, calls
 *   `runtime.abortSession`, and writes no completion events.
 * @param {(event: string, data: object) => void} options.write - event sink.
 * @returns {Promise<void>}
 * @throws Rethrows a rejection of `runtime.prompt` and errors from the event stream.
 */
const streamPromptResponse = async ({ runtime, opencodeSessionId, clientSessionId, message, signal, write, }) => {
    const eventStream = await runtime.subscribeEvents();
    const iterator = eventStream[Symbol.asyncIterator]();
    // Tool part ids that have already produced a `tool_call` event.
    const emittedToolParts = new Set();
    // Part id -> part type, learned from `message.part.updated` events.
    const partTypes = new Map();
    // Text deltas that arrived before their part's type was known, per part id.
    const pendingTextDeltas = new Map();
    let emittedText = false;
    let done = false;
    let promptSettled = false;
    let aborted = signal?.aborted ?? false;
    // The single outstanding `iterator.next()` request. It survives loop iterations
    // in which the prompt promise wins the race and is only replaced after its result
    // has been consumed. A second `next()` while one is pending would queue another
    // request, and the event delivered to the abandoned one would be lost.
    let pendingEvent = null;
    const abortPromise = signal
        ? new Promise((resolve) => {
            if (signal.aborted) {
                resolve({ type: "abort" });
                return;
            }
            signal.addEventListener("abort", () => resolve({ type: "abort" }), {
                once: true,
            });
        })
        : null;
    write("progress", {
        session_id: clientSessionId,
        id: "request-received",
        phase: "start",
        status: "running",
        title: "已收到请求,正在启动 Agent 分析",
    });
    // `promptSettled` lets the loop stop racing against the prompt once it finished.
    const promptPromise = runtime
        .prompt(opencodeSessionId, message)
        .then(() => {
            promptSettled = true;
        })
        .catch((error) => {
            promptSettled = true;
            throw error;
        });
    try {
        while (!done) {
            if (signal?.aborted) {
                aborted = true;
                break;
            }
            if (!pendingEvent) {
                pendingEvent = iterator
                    .next()
                    .then((result) => ({ type: "event", result }));
            }
            const contenders = [pendingEvent];
            if (!promptSettled) {
                contenders.push(promptPromise.then(() => ({ type: "prompt" }), (error) => ({ type: "prompt-error", error })));
            }
            if (abortPromise) {
                contenders.push(abortPromise);
            }
            const next = await Promise.race(contenders);
            if (next.type === "abort") {
                aborted = true;
                break;
            }
            if (next.type === "prompt-error") {
                throw next.error;
            }
            if (next.type === "prompt") {
                // The prompt finished; keep waiting on the same in-flight event request.
                continue;
            }
            // The event request has been consumed; the next iteration issues a new one.
            pendingEvent = null;
            if (next.result.done) {
                break;
            }
            const event = next.result.value;
            if (!isSessionEvent(event, opencodeSessionId)) {
                continue;
            }
            if (event.type === "session.status") {
                write("progress", {
                    session_id: clientSessionId,
                    id: "session-status",
                    phase: "session",
                    status: event.properties.status.type === "idle" ? "completed" : "running",
                    title: event.properties.status.type === "retry"
                        ? `模型请求重试中:${event.properties.status.message}`
                        : event.properties.status.type === "busy"
                            ? "Agent 正在处理请求"
                            : "Agent 已空闲",
                });
                continue;
            }
            if (event.type === "message.part.delta" && event.properties.field === "text") {
                const partType = partTypes.get(event.properties.partID);
                if (partType === "text") {
                    emittedText = true;
                    write("token", {
                        session_id: clientSessionId,
                        content: event.properties.delta,
                    });
                }
                else if (!partType) {
                    // Part type unknown yet: buffer until its `message.part.updated` arrives.
                    const pending = pendingTextDeltas.get(event.properties.partID) ?? [];
                    pending.push(event.properties.delta);
                    pendingTextDeltas.set(event.properties.partID, pending);
                }
                continue;
            }
            if (event.type === "message.part.updated") {
                const part = event.properties.part;
                partTypes.set(part.id, part.type);
                if (part.type === "text") {
                    // Flush deltas that were buffered before the type was known.
                    const pending = pendingTextDeltas.get(part.id) ?? [];
                    pendingTextDeltas.delete(part.id);
                    for (const content of pending) {
                        emittedText = true;
                        write("token", {
                            session_id: clientSessionId,
                            content,
                        });
                    }
                }
                else if (part.type === "reasoning") {
                    // Reasoning text is not forwarded as tokens; only its progress is reported.
                    pendingTextDeltas.delete(part.id);
                    write("progress", {
                        session_id: clientSessionId,
                        id: part.id,
                        phase: "planning",
                        status: part.time.end ? "completed" : "running",
                        title: part.time.end ? "分析规划完成" : "正在规划分析步骤",
                    });
                }
                if (part.type === "tool") {
                    write("progress", {
                        session_id: clientSessionId,
                        id: part.id,
                        phase: "tool",
                        status: normalizeToolStatus(part.state.status),
                        title: getToolProgressTitle(part.tool, part.state.status),
                        detail: part.state.status === "error" ? part.state.error : undefined,
                    });
                    if (!emittedToolParts.has(part.id)) {
                        emittedToolParts.add(part.id);
                        write("tool_call", {
                            session_id: clientSessionId,
                            tool: part.tool,
                            params: part.state.input,
                        });
                    }
                }
                continue;
            }
            if (event.type === "todo.updated") {
                const completed = event.properties.todos.filter((todo) => todo.status === "completed").length;
                write("progress", {
                    session_id: clientSessionId,
                    id: "todo-progress",
                    phase: "planning",
                    status: completed === event.properties.todos.length ? "completed" : "running",
                    title: `计划进度 ${completed}/${event.properties.todos.length}`,
                    detail: event.properties.todos
                        .map((todo) => `${todo.status}: ${todo.content}`)
                        .join("\n"),
                });
                continue;
            }
            if (event.type === "session.error") {
                write("error", {
                    session_id: clientSessionId,
                    message: event.properties.error
                        ? getErrorMessage(event.properties.error)
                        : "opencode session error",
                    detail: event.properties.error?.name,
                });
                done = true;
                continue;
            }
            if (event.type === "session.idle") {
                write("progress", {
                    session_id: clientSessionId,
                    id: "session-status",
                    phase: "session",
                    status: "completed",
                    title: "Agent 已完成处理",
                });
                done = true;
            }
        }
        if (aborted) {
            // Best-effort: a failed remote abort is logged, not propagated.
            await runtime.abortSession(opencodeSessionId).catch((error) => {
                logger.warn({ sessionId: opencodeSessionId, err: error }, "failed to abort opencode session");
            });
            return;
        }
        await promptPromise;
        if (!emittedText) {
            await emitFallbackMessage(runtime, opencodeSessionId, clientSessionId, write);
        }
        write("progress", {
            session_id: clientSessionId,
            id: "request-received",
            phase: "start",
            status: "completed",
            title: "请求处理完成",
        });
        write("progress", {
            session_id: clientSessionId,
            id: "request-completed",
            phase: "complete",
            status: "completed",
            title: "分析完成",
        });
        write("done", { session_id: clientSessionId });
    }
    finally {
        // NOTE(review): if the iterator is an async generator and an event request is
        // still pending (abort / prompt-error paths), `return()` only settles after it.
        await iterator.return?.(undefined);
        if (!promptSettled) {
            await promptPromise.catch(() => undefined);
        }
    }
};
|
|
/**
 * Reports whether an opencode event belongs to the given session. An event matches
 * when its `properties` is a non-null object carrying a `sessionID` equal (===) to
 * `sessionId`.
 *
 * Events arrive from an external stream, so non-object values (null, strings,
 * numbers) are rejected instead of letting the `in` operator throw a TypeError.
 *
 * @param {unknown} event
 * @param {string} sessionId
 * @returns {boolean}
 */
const isSessionEvent = (event, sessionId) => {
    if (typeof event !== "object" || event === null) {
        return false;
    }
    const { properties } = event;
    return typeof properties === "object" &&
        properties !== null &&
        "sessionID" in properties &&
        properties.sessionID === sessionId;
};
|
|
/**
 * Fallback for runs that streamed no text. It fetches the session's messages, takes
 * the most recent one whose `info.role` is "assistant", and writes its concatenated
 * text parts as a single `token` event. Nothing is written when that text is empty
 * or no assistant message exists.
 *
 * @param {object} runtime - provides `messages(sessionId)`.
 * @param {string} opencodeSessionId
 * @param {string} clientSessionId - echoed as `session_id`.
 * @param {(event: string, data: object) => void} write
 */
const emitFallbackMessage = async (runtime, opencodeSessionId, clientSessionId, write) => {
    const history = [...await runtime.messages(opencodeSessionId)];
    let latestAssistant;
    for (let index = history.length - 1; index >= 0; index -= 1) {
        if (history[index].info.role === "assistant") {
            latestAssistant = history[index];
            break;
        }
    }
    const content = collectTextContent(latestAssistant?.parts ?? []);
    if (!content) {
        return;
    }
    write("token", {
        session_id: clientSessionId,
        content,
    });
};
|
|
/**
 * Concatenates, in order, the `text` of every part whose `type` is "text".
 * Other part types are skipped; a null/undefined `text` contributes "".
 */
const collectTextContent = (parts) => parts
    .flatMap((part) => (part.type === "text" ? [part.text] : []))
    .join("");
|
|
/**
 * Maps an opencode tool-part status onto the client's progress status.
 * "completed" and "error" pass through unchanged; every other value (for example
 * "pending" or "running") is reported as "running".
 */
const normalizeToolStatus = (status) => {
    switch (status) {
        case "completed":
        case "error":
            return status;
        default:
            return "running";
    }
};
|
|
/**
 * Builds the progress title shown for a tool part.
 *
 * The tool is named by its entry in `toolLabels` when one exists, otherwise by its
 * raw id. The wording depends on `status`: "completed", "error" and "pending" have
 * dedicated titles, and anything else is shown as in progress.
 *
 * @param {string} tool - tool id from the part.
 * @param {string} status - opencode tool-part status.
 * @returns {string}
 */
const getToolProgressTitle = (tool, status) => {
    // Own-property lookup: a plain `toolLabels[tool]` would return inherited
    // Object.prototype members for tool ids such as "constructor" or "toString".
    const toolName = Object.prototype.hasOwnProperty.call(toolLabels, tool)
        ? toolLabels[tool]
        : tool;
    if (status === "completed") {
        return `${toolName} 已完成`;
    }
    if (status === "error") {
        return `${toolName} 执行失败`;
    }
    if (status === "pending") {
        return `准备调用 ${toolName}`;
    }
    return `正在调用 ${toolName}`;
};
|
|
/**
 * Human-readable labels used in tool progress titles, keyed by tool id.
 * Tools without an entry are shown under their raw id.
 */
const toolLabels = Object.fromEntries([
    ["dynamic_http_call", "后端数据查询"],
    ["locate_features", "地图定位"],
    ["view_history", "历史数据面板"],
    ["view_scada", "SCADA 面板"],
    ["show_chart", "图表渲染"],
]);
|