新增会话标题管理功能,优化会话标题生成

This commit is contained in:
2026-04-30 15:05:21 +08:00
parent e6b10cd603
commit c806f03d51
2 changed files with 94 additions and 0 deletions
+14
View File
@@ -11,6 +11,7 @@ export type ChatRequestContext = SessionContext & {
export class ChatSessionBridge {
// 这里额外保存 session -> 用户上下文,供工具桥在服务端代发真实后端请求时复用。
private readonly sessionContexts = new Map<string, ChatRequestContext>();
private readonly sessionTitles = new Map<string, string>();
constructor(
private readonly registry: SessionRegistry,
@@ -70,6 +71,18 @@ export class ChatSessionBridge {
return this.sessionContexts.get(sessionId) ?? null;
}
getSessionTitle(sessionId: string) {
return this.sessionTitles.get(sessionId);
}
setSessionTitle(sessionId: string, title: string) {
const normalized = title.trim();
if (!normalized) {
return;
}
this.sessionTitles.set(sessionId, normalized);
}
async abort(context: {
clientSessionId?: string;
accessToken?: string;
@@ -167,6 +180,7 @@ export class ChatSessionBridge {
const expiredSessionIds = this.registry.evictExpired();
for (const sessionId of expiredSessionIds) {
this.sessionContexts.delete(sessionId);
this.sessionTitles.delete(sessionId);
// 这里用 abort 做轻量清理;即使失败,也不阻断本地过期回收。
void this.runtime.abortSession(sessionId).catch((error) => {
logger.debug({ sessionId, err: error }, "ignoring failed abort for expired session");
+80
View File
@@ -174,6 +174,12 @@ export const buildChatRouter = (
res.flushHeaders?.();
const clientSessionId = requestContext.clientSessionId;
const existingSessionTitle = sessionBridge.getSessionTitle(binding.sessionId);
const sessionTitle = existingSessionTitle
?? (await generateSessionTitle(runtime, parsed.data.message));
if (!existingSessionTitle) {
sessionBridge.setSessionTitle(binding.sessionId, sessionTitle);
}
let streamClosed = false;
const abortController = new AbortController();
const handleClientClose = () => {
@@ -187,6 +193,12 @@ export const buildChatRouter = (
res.on("close", handleClientClose);
try {
res.write(
toSse("session_title", {
session_id: clientSessionId,
title: sessionTitle,
}),
);
await streamPromptResponse({
runtime,
opencodeSessionId: binding.sessionId,
@@ -556,6 +568,74 @@ const getToolProgressTitle = (tool: string, status: string) => {
return `正在调用 ${toolName}`;
};
const buildSessionTitle = (message: string) => {
const normalized = message.replace(/\s+/g, " ").trim();
if (!normalized) {
return "新对话";
}
return normalized.length > 24 ? `${normalized.slice(0, 24)}...` : normalized;
};
const TITLE_PROMPT_TIMEOUT_MS = 2500;
const generateSessionTitle = async (
runtime: OpencodeRuntimeAdapter,
userMessage: string,
) => {
const fallback = buildSessionTitle(userMessage);
const normalized = userMessage.replace(/\s+/g, " ").trim();
if (!normalized) {
return fallback;
}
const titleSession = await runtime.createSession(`title-${Date.now().toString(36)}`);
const request = runtime
.prompt(
titleSession.id,
[
"你是会话标题生成器。",
"请根据用户问题生成一个 8-16 字中文标题。",
"要求:简洁、可读、避免标点、不要引号、不要解释。",
"只输出标题本身。",
`用户问题:${normalized}`,
].join("\n"),
)
.then(async () => {
const messages = await runtime.messages(titleSession.id, 20);
const assistantMessage = [...messages]
.reverse()
.find((message) => message.info.role === "assistant");
const title = collectTextContent(assistantMessage?.parts ?? []);
return normalizeGeneratedTitle(title, fallback);
});
const timeout = new Promise<string>((resolve) => {
setTimeout(() => resolve(fallback), TITLE_PROMPT_TIMEOUT_MS);
});
try {
return await Promise.race([request, timeout]);
} catch (error) {
logger.warn({ err: error }, "failed to generate session title, using fallback");
return fallback;
} finally {
await runtime.abortSession(titleSession.id).catch((error) => {
logger.debug({ sessionId: titleSession.id, err: error }, "failed to cleanup title session");
});
}
};
const normalizeGeneratedTitle = (rawTitle: string, fallback: string) => {
const normalized = rawTitle
.replace(/\s+/g, " ")
.replace(/["'“”‘’`]/g, "")
.trim();
if (!normalized) {
return fallback;
}
return normalized.length > 24 ? `${normalized.slice(0, 24)}...` : normalized;
};
const toolLabels: Record<string, string> = {
dynamic_http_call: "后端数据查询",
locate_features: "地图定位",