Files
TJWaterAgent/dist/routes/chat.js
T

450 lines
17 KiB
JavaScript

import { Router } from "express";
import { z } from "zod";
import { logger } from "../logger.js";
// Shared validator for the `session_id` field: any string up to 128 characters.
const sessionIdField = z.string().max(128);
/** Body of POST /stream: the user message (1-10000 chars) and an optional session id. */
const payloadSchema = z.object({
    message: z.string().min(1).max(10000),
    session_id: sessionIdField.optional(),
});
/** Body of POST /abort: the session id to abort is required. */
const abortPayloadSchema = z.object({
    session_id: sessionIdField,
});
/**
 * Body of POST /fork: optional source session id plus the number of messages
 * to keep. `keep_message_count` is coerced to a number, so numeric strings are
 * accepted, and must be a non-negative integer.
 */
const forkPayloadSchema = z.object({
    session_id: sessionIdField.optional(),
    keep_message_count: z.coerce.number().int().min(0),
});
/**
 * Strip an optional "Bearer " prefix from an Authorization header value.
 * Values without the prefix pass through unchanged; a missing header yields undefined.
 */
const parseAccessToken = (authHeader) => authHeader?.startsWith("Bearer ")
    ? authHeader.slice("Bearer ".length)
    : authHeader;
/** Read the credential and correlation headers every chat route forwards to the session bridge. */
const readCallerContext = (req) => ({
    accessToken: parseAccessToken(req.header("authorization")),
    projectId: req.header("x-project-id") ?? undefined,
    traceId: req.header("x-trace-id") ?? undefined,
});
/** Render an unknown thrown value as a human-readable message. */
const describeError = (error) => error instanceof Error ? error.message : String(error);
/** Reply 400 with zod's flattened validation issues. */
const rejectInvalidPayload = (res, zodError) => {
    res.status(400).json({
        message: "invalid request payload",
        detail: zodError.flatten(),
    });
};
/** Log a route failure and reply 500 with `message` and the error text as `detail`. */
const failRequest = (res, error, message) => {
    const detail = describeError(error);
    logger.error({ err: error }, message);
    res.status(500).json({
        message,
        detail,
    });
};
/**
 * Build the chat router.
 *
 * Routes:
 * - POST /abort  `{ session_id }`: 204 when `sessionBridge.abort` resolves to a
 *   falsy binding, otherwise 202 `{ session_id, aborted: true }`.
 * - POST /fork   `{ session_id?, keep_message_count }`: 200 `{ session_id }` of the
 *   client session returned by `sessionBridge.fork`.
 * - POST /stream `{ message, session_id? }`: a text/event-stream response whose
 *   events are produced by `streamPromptResponse`.
 *
 * Malformed bodies get 400. Failures before any response bytes are sent get 500 JSON.
 *
 * @param sessionBridge object exposing `abort`, `fork` and `resolve` (resolves client
 *   session ids to bindings whose `sessionId` is used as the runtime session id)
 * @param runtime object passed through to `streamPromptResponse`
 * @returns {import("express").Router}
 */
export const buildChatRouter = (sessionBridge, runtime) => {
    const chatRouter = Router();
    chatRouter.post("/abort", async (req, res) => {
        const parsed = abortPayloadSchema.safeParse(req.body);
        if (!parsed.success) {
            rejectInvalidPayload(res, parsed.error);
            return;
        }
        try {
            const { accessToken, projectId, traceId } = readCallerContext(req);
            const binding = await sessionBridge.abort({
                clientSessionId: parsed.data.session_id,
                accessToken,
                projectId,
                traceId,
            });
            if (!binding) {
                res.status(204).end();
                return;
            }
            logger.info({
                clientSessionId: parsed.data.session_id,
                sessionId: binding.sessionId,
                traceId,
                projectId,
            }, "aborted chat session by client request");
            res.status(202).json({
                session_id: parsed.data.session_id,
                aborted: true,
            });
        }
        catch (error) {
            failRequest(res, error, "chat abort failed");
        }
    });
    chatRouter.post("/fork", async (req, res) => {
        const parsed = forkPayloadSchema.safeParse(req.body);
        if (!parsed.success) {
            rejectInvalidPayload(res, parsed.error);
            return;
        }
        try {
            const { accessToken, projectId, traceId } = readCallerContext(req);
            const { binding, requestContext } = await sessionBridge.fork({
                clientSessionId: parsed.data.session_id,
                accessToken,
                projectId,
                traceId,
                keepMessageCount: parsed.data.keep_message_count,
            });
            logger.info({
                sourceClientSessionId: parsed.data.session_id,
                clientSessionId: requestContext.clientSessionId,
                sessionId: binding.sessionId,
                traceId: requestContext.traceId,
                projectId: requestContext.projectId,
                keepMessageCount: parsed.data.keep_message_count,
            }, "forked chat session");
            res.status(200).json({
                session_id: requestContext.clientSessionId,
            });
        }
        catch (error) {
            failRequest(res, error, "chat fork failed");
        }
    });
    chatRouter.post("/stream", async (req, res) => {
        const parsed = payloadSchema.safeParse(req.body);
        if (!parsed.success) {
            rejectInvalidPayload(res, parsed.error);
            return;
        }
        // Hoisted so the catch block can tag an in-band SSE error event.
        let clientSessionId = parsed.data.session_id;
        try {
            const { accessToken, projectId, traceId } = readCallerContext(req);
            const { binding, requestContext, created } = await sessionBridge.resolve({
                clientSessionId: parsed.data.session_id,
                accessToken,
                projectId,
                traceId,
            });
            clientSessionId = requestContext.clientSessionId;
            logger.info({
                clientSessionId: requestContext.clientSessionId,
                sessionId: binding.sessionId,
                created,
                traceId: requestContext.traceId,
                projectId: requestContext.projectId,
            }, "processing chat request");
            res.status(200);
            res.setHeader("Content-Type", "text/event-stream; charset=utf-8");
            res.setHeader("Cache-Control", "no-cache");
            res.setHeader("Connection", "keep-alive");
            res.setHeader("X-Accel-Buffering", "no");
            res.flushHeaders?.();
            let streamClosed = false;
            const abortController = new AbortController();
            const handleClientClose = () => {
                if (streamClosed || abortController.signal.aborted) {
                    return;
                }
                abortController.abort();
            };
            // Disconnects are detected on the response only. Since Node 16,
            // IncomingMessage "close" fires once the request body has been
            // consumed, not when the client goes away. ServerResponse "close"
            // fires when the connection is terminated.
            res.on("close", handleClientClose);
            try {
                await streamPromptResponse({
                    runtime,
                    opencodeSessionId: binding.sessionId,
                    clientSessionId,
                    message: parsed.data.message,
                    signal: abortController.signal,
                    write: (event, data) => {
                        if (streamClosed || res.writableEnded || res.destroyed) {
                            return;
                        }
                        res.write(toSse(event, data));
                    },
                });
            }
            finally {
                streamClosed = true;
                res.off("close", handleClientClose);
            }
            if (!res.writableEnded && !res.destroyed) {
                res.end();
            }
        }
        catch (error) {
            const detail = describeError(error);
            logger.error({ err: error }, "chat stream failed");
            if (!res.headersSent) {
                res.status(500).json({
                    message: "chat stream failed",
                    detail,
                });
                return;
            }
            // The SSE headers were already flushed, so a status code or JSON body
            // can no longer be sent (Node throws ERR_HTTP_HEADERS_SENT). Report the
            // failure in-band as an SSE "error" event and close the stream.
            if (!res.writableEnded && !res.destroyed) {
                res.write(toSse("error", {
                    session_id: clientSessionId,
                    message: "chat stream failed",
                    detail,
                }));
                res.end();
            }
        }
    });
    return chatRouter;
};
/**
 * Serialize one Server-Sent Events frame: an `event:` line, a `data:` line
 * carrying `data` as JSON, and the blank line that terminates the frame.
 */
const toSse = (event, data) => {
    const payload = JSON.stringify(data);
    return [`event: ${event}`, `data: ${payload}`, "", ""].join("\n");
};
/**
 * Pick a display message for a session error object: `error.data.message`
 * when it is present (not null/undefined), otherwise `error.name`.
 */
const getErrorMessage = (error) => {
    const dataMessage = error.data?.message;
    if (dataMessage == null) {
        return error.name;
    }
    return dataMessage;
};
/**
 * Send `message` to runtime session `opencodeSessionId` and translate the
 * runtime's event stream into events passed to `write`:
 * - "progress": request lifecycle, session status, reasoning, tool and todo updates
 * - "token": assistant text deltas, or (if none were streamed) the text of the
 *   latest assistant message fetched via `emitFallbackMessage`
 * - "tool_call": written once per tool part id with the tool name and its input
 * - "error": a session.error event for this session
 * - "done": written after the prompt resolves, unless aborted
 *
 * Only events whose `properties.sessionID` matches `opencodeSessionId` are handled.
 * The loop ends on session.idle, session.error, end of the event stream, or abort.
 * On abort, `runtime.abortSession` is called (failures are only logged) and no
 * completion events are written.
 *
 * @param {object} options
 * @param options.runtime object exposing subscribeEvents/prompt/messages/abortSession
 * @param {string} options.opencodeSessionId runtime session to prompt and filter events by
 * @param {string} options.clientSessionId echoed as `session_id` in every written event
 * @param {string} options.message prompt text
 * @param {AbortSignal} [options.signal] stops streaming when aborted
 * @param {(event: string, data: object) => void} options.write event sink
 * @throws rejection of `runtime.prompt`, or errors raised by the event iterator
 */
const streamPromptResponse = async ({ runtime, opencodeSessionId, clientSessionId, message, signal, write, }) => {
    const eventStream = await runtime.subscribeEvents();
    const iterator = eventStream[Symbol.asyncIterator]();
    const emittedToolParts = new Set();
    const partTypes = new Map();
    // Text deltas that arrive before their part's type is known (i.e. before the
    // part's message.part.updated event) are buffered here, keyed by part id.
    const pendingTextDeltas = new Map();
    let emittedText = false;
    let done = false;
    let promptSettled = false;
    let aborted = signal?.aborted ?? false;
    // A single in-flight iterator.next() is carried across loop iterations.
    // Calling next() again after the prompt wins the race would queue a second
    // read and silently drop whatever the first read resolves with.
    let pendingEvent = null;
    let eventInFlight = false;
    const abortPromise = signal
        ? new Promise((resolve) => {
            if (signal.aborted) {
                resolve({ type: "abort" });
                return;
            }
            signal.addEventListener("abort", () => resolve({ type: "abort" }), {
                once: true,
            });
        })
        : null;
    write("progress", {
        session_id: clientSessionId,
        id: "request-received",
        phase: "start",
        status: "running",
        title: "已收到请求,正在启动 Agent 分析",
    });
    const promptPromise = runtime
        .prompt(opencodeSessionId, message)
        .then(() => {
        promptSettled = true;
    })
        .catch((error) => {
        promptSettled = true;
        throw error;
    });
    try {
        while (!done) {
            if (signal?.aborted) {
                aborted = true;
                break;
            }
            if (pendingEvent === null) {
                eventInFlight = true;
                pendingEvent = iterator.next().then((result) => {
                    eventInFlight = false;
                    return { type: "event", result };
                }, (error) => {
                    eventInFlight = false;
                    throw error;
                });
            }
            const nextPrompt = promptSettled
                ? null
                : promptPromise.then(() => ({ type: "prompt" }), (error) => ({ type: "prompt-error", error }));
            const next = await Promise.race([
                pendingEvent,
                ...(nextPrompt ? [nextPrompt] : []),
                ...(abortPromise ? [abortPromise] : []),
            ]);
            if (next.type === "abort") {
                aborted = true;
                break;
            }
            if (next.type === "prompt-error") {
                throw next.error;
            }
            if (next.type === "prompt") {
                // The pending read (if any) is kept and raced again next time.
                continue;
            }
            // The pending read has been consumed; request a fresh one next iteration.
            pendingEvent = null;
            if (next.result.done) {
                break;
            }
            const event = next.result.value;
            if (!isSessionEvent(event, opencodeSessionId)) {
                continue;
            }
            if (event.type === "session.status") {
                write("progress", {
                    session_id: clientSessionId,
                    id: "session-status",
                    phase: "session",
                    status: event.properties.status.type === "idle" ? "completed" : "running",
                    title: event.properties.status.type === "retry"
                        ? `模型请求重试中:${event.properties.status.message}`
                        : event.properties.status.type === "busy"
                            ? "Agent 正在处理请求"
                            : "Agent 已空闲",
                });
                continue;
            }
            if (event.type === "message.part.delta" && event.properties.field === "text") {
                const partType = partTypes.get(event.properties.partID);
                if (partType === "text") {
                    emittedText = true;
                    write("token", {
                        session_id: clientSessionId,
                        content: event.properties.delta,
                    });
                }
                else if (!partType) {
                    // Part type unknown yet: buffer until message.part.updated arrives.
                    const pending = pendingTextDeltas.get(event.properties.partID) ?? [];
                    pending.push(event.properties.delta);
                    pendingTextDeltas.set(event.properties.partID, pending);
                }
                continue;
            }
            if (event.type === "message.part.updated") {
                const part = event.properties.part;
                partTypes.set(part.id, part.type);
                if (part.type === "text") {
                    // Flush deltas that arrived before this part was typed.
                    const pending = pendingTextDeltas.get(part.id) ?? [];
                    pendingTextDeltas.delete(part.id);
                    for (const content of pending) {
                        emittedText = true;
                        write("token", {
                            session_id: clientSessionId,
                            content,
                        });
                    }
                }
                else if (part.type === "reasoning") {
                    // Reasoning text is not streamed as tokens; drop its buffer.
                    pendingTextDeltas.delete(part.id);
                    write("progress", {
                        session_id: clientSessionId,
                        id: part.id,
                        phase: "planning",
                        status: part.time.end ? "completed" : "running",
                        title: part.time.end ? "分析规划完成" : "正在规划分析步骤",
                    });
                }
                if (part.type === "tool") {
                    write("progress", {
                        session_id: clientSessionId,
                        id: part.id,
                        phase: "tool",
                        status: normalizeToolStatus(part.state.status),
                        title: getToolProgressTitle(part.tool, part.state.status),
                        detail: part.state.status === "error" ? part.state.error : undefined,
                    });
                    if (!emittedToolParts.has(part.id)) {
                        emittedToolParts.add(part.id);
                        write("tool_call", {
                            session_id: clientSessionId,
                            tool: part.tool,
                            params: part.state.input,
                        });
                    }
                }
                continue;
            }
            if (event.type === "todo.updated") {
                const completed = event.properties.todos.filter((todo) => todo.status === "completed").length;
                write("progress", {
                    session_id: clientSessionId,
                    id: "todo-progress",
                    phase: "planning",
                    status: completed === event.properties.todos.length ? "completed" : "running",
                    title: `计划进度 ${completed}/${event.properties.todos.length}`,
                    detail: event.properties.todos
                        .map((todo) => `${todo.status}: ${todo.content}`)
                        .join("\n"),
                });
                continue;
            }
            if (event.type === "session.error") {
                write("error", {
                    session_id: clientSessionId,
                    message: event.properties.error
                        ? getErrorMessage(event.properties.error)
                        : "opencode session error",
                    detail: event.properties.error?.name,
                });
                done = true;
                continue;
            }
            if (event.type === "session.idle") {
                write("progress", {
                    session_id: clientSessionId,
                    id: "session-status",
                    phase: "session",
                    status: "completed",
                    title: "Agent 已完成处理",
                });
                done = true;
            }
        }
        if (aborted) {
            await runtime.abortSession(opencodeSessionId).catch((error) => {
                logger.warn({ sessionId: opencodeSessionId, err: error }, "failed to abort opencode session");
            });
            return;
        }
        await promptPromise;
        if (!emittedText) {
            await emitFallbackMessage(runtime, opencodeSessionId, clientSessionId, write);
        }
        write("progress", {
            session_id: clientSessionId,
            id: "request-received",
            phase: "start",
            status: "completed",
            title: "请求处理完成",
        });
        write("progress", {
            session_id: clientSessionId,
            id: "request-completed",
            phase: "complete",
            status: "completed",
            title: "分析完成",
        });
        write("done", { session_id: clientSessionId });
    }
    finally {
        if (eventInFlight) {
            // For async generators, return() is queued behind an unresolved next(),
            // so awaiting it here could block until the runtime emits another event.
            // Close the iterator in the background instead.
            void Promise.resolve()
                .then(() => iterator.return?.(undefined))
                .catch((error) => {
                logger.warn({ sessionId: opencodeSessionId, err: error }, "failed to close opencode event stream");
            });
        }
        else {
            await iterator.return?.(undefined);
        }
        if (!promptSettled) {
            await promptPromise.catch(() => undefined);
        }
    }
};
/**
 * True when `event` carries an object `properties` with its own or inherited
 * `sessionID` key equal (===) to `sessionId`.
 */
const isSessionEvent = (event, sessionId) => {
    if (!("properties" in event)) {
        return false;
    }
    const props = event.properties;
    if (props === null || typeof props !== "object") {
        return false;
    }
    return "sessionID" in props && props.sessionID === sessionId;
};
/**
 * Fetch the session history and, if the most recent assistant message has
 * non-empty text content, write it as a single "token" event.
 * Used when no text deltas were streamed for the request.
 */
const emitFallbackMessage = async (runtime, opencodeSessionId, clientSessionId, write) => {
    const history = [...(await runtime.messages(opencodeSessionId))];
    let latestAssistant;
    // Scan from the newest message backwards for the last assistant reply.
    for (let index = history.length - 1; index >= 0; index -= 1) {
        if (history[index].info.role === "assistant") {
            latestAssistant = history[index];
            break;
        }
    }
    const text = collectTextContent(latestAssistant?.parts ?? []);
    if (!text) {
        return;
    }
    write("token", {
        session_id: clientSessionId,
        content: text,
    });
};
/**
 * Concatenate the `text` of every part whose `type` is "text", in order.
 * Missing (null/undefined) texts contribute an empty string.
 */
const collectTextContent = (parts) => parts
    .flatMap((part) => (part.type === "text" ? [part.text] : []))
    .join("");
/**
 * Map a tool part's state status onto the progress status vocabulary:
 * "completed" and "error" pass through, anything else counts as "running".
 */
const normalizeToolStatus = (status) => {
    switch (status) {
        case "completed":
        case "error":
            return status;
        default:
            return "running";
    }
};
/**
 * Build the user-facing progress title for a tool call.
 *
 * @param {string} tool tool name; replaced by its entry in `toolLabels` when one exists
 * @param {string} status tool state status ("completed", "error", "pending", or other = running)
 * @returns {string} localized title
 */
const getToolProgressTitle = (tool, status) => {
    // Own-property lookup only: a plain `toolLabels[tool]` would resolve names
    // such as "toString" or "constructor" to Object.prototype functions.
    const label = Object.prototype.hasOwnProperty.call(toolLabels, tool)
        ? toolLabels[tool]
        : undefined;
    const toolName = label ?? tool;
    if (status === "completed")
        return `${toolName} 已完成`;
    if (status === "error")
        return `${toolName} 执行失败`;
    if (status === "pending")
        return `准备调用 ${toolName}`;
    return `正在调用 ${toolName}`;
};
const toolLabels = {
dynamic_http_call: "后端数据查询",
locate_features: "地图定位",
view_history: "历史数据面板",
view_scada: "SCADA 面板",
show_chart: "图表渲染",
};