diff --git a/src/chat/sessionBridge.ts b/src/chat/sessionBridge.ts index beb8dbd..686a09d 100644 --- a/src/chat/sessionBridge.ts +++ b/src/chat/sessionBridge.ts @@ -11,6 +11,7 @@ export type ChatRequestContext = SessionContext & { export class ChatSessionBridge { // 这里额外保存 session -> 用户上下文,供工具桥在服务端代发真实后端请求时复用。 private readonly sessionContexts = new Map(); + private readonly sessionTitles = new Map(); constructor( private readonly registry: SessionRegistry, @@ -70,6 +71,18 @@ export class ChatSessionBridge { return this.sessionContexts.get(sessionId) ?? null; } + getSessionTitle(sessionId: string) { + return this.sessionTitles.get(sessionId); + } + + setSessionTitle(sessionId: string, title: string) { + const normalized = title.trim(); + if (!normalized) { + return; + } + this.sessionTitles.set(sessionId, normalized); + } + async abort(context: { clientSessionId?: string; accessToken?: string; @@ -167,6 +180,7 @@ export class ChatSessionBridge { const expiredSessionIds = this.registry.evictExpired(); for (const sessionId of expiredSessionIds) { this.sessionContexts.delete(sessionId); + this.sessionTitles.delete(sessionId); // 这里用 abort 做轻量清理;即使失败,也不阻断本地过期回收。 void this.runtime.abortSession(sessionId).catch((error) => { logger.debug({ sessionId, err: error }, "ignoring failed abort for expired session"); diff --git a/src/routes/chat.ts b/src/routes/chat.ts index 347995b..71111d9 100644 --- a/src/routes/chat.ts +++ b/src/routes/chat.ts @@ -174,6 +174,12 @@ export const buildChatRouter = ( res.flushHeaders?.(); const clientSessionId = requestContext.clientSessionId; + const existingSessionTitle = sessionBridge.getSessionTitle(binding.sessionId); + const sessionTitle = existingSessionTitle + ?? (await generateSessionTitle(runtime, parsed.data.message)); + if (!existingSessionTitle) { + sessionBridge.setSessionTitle(binding.sessionId, sessionTitle); + } let streamClosed = false; const abortController = new AbortController(); const handleClientClose = () => { @@ -187,6 +193,12 @@ export const buildChatRouter = ( res.on("close", handleClientClose); try { + res.write( + toSse("session_title", { + session_id: clientSessionId, + title: sessionTitle, + }), + ); await streamPromptResponse({ runtime, opencodeSessionId: binding.sessionId, @@ -556,6 +568,74 @@ const getToolProgressTitle = (tool: string, status: string) => { return `正在调用 ${toolName}`; }; +const buildSessionTitle = (message: string) => { + const normalized = message.replace(/\s+/g, " ").trim(); + if (!normalized) { + return "新对话"; + } + return normalized.length > 24 ? `${normalized.slice(0, 24)}...` : normalized; +}; + +const TITLE_PROMPT_TIMEOUT_MS = 2500; + +const generateSessionTitle = async ( + runtime: OpencodeRuntimeAdapter, + userMessage: string, +) => { + const fallback = buildSessionTitle(userMessage); + const normalized = userMessage.replace(/\s+/g, " ").trim(); + if (!normalized) { + return fallback; + } + + const titleSession = await runtime.createSession(`title-${Date.now().toString(36)}`); + const request = runtime + .prompt( + titleSession.id, + [ + "你是会话标题生成器。", + "请根据用户问题生成一个 8-16 字中文标题。", + "要求:简洁、可读、避免标点、不要引号、不要解释。", + "只输出标题本身。", + `用户问题:${normalized}`, + ].join("\n"), + ) + .then(async () => { + const messages = await runtime.messages(titleSession.id, 20); + const assistantMessage = [...messages] + .reverse() + .find((message) => message.info.role === "assistant"); + const title = collectTextContent(assistantMessage?.parts ?? []); + return normalizeGeneratedTitle(title, fallback); + }); + + const timeout = new Promise((resolve) => { + setTimeout(() => resolve(fallback), TITLE_PROMPT_TIMEOUT_MS); + }); + + try { + return await Promise.race([request, timeout]); + } catch (error) { + logger.warn({ err: error }, "failed to generate session title, using fallback"); + return fallback; + } finally { + await runtime.abortSession(titleSession.id).catch((error) => { + logger.debug({ sessionId: titleSession.id, err: error }, "failed to cleanup title session"); + }); + } +}; + +const normalizeGeneratedTitle = (rawTitle: string, fallback: string) => { + const normalized = rawTitle + .replace(/\s+/g, " ") + .replace(/["'“”‘’`]/g, "") + .trim(); + if (!normalized) { + return fallback; + } + return normalized.length > 24 ? `${normalized.slice(0, 24)}...` : normalized; +}; + const toolLabels: Record = { dynamic_http_call: "后端数据查询", locate_features: "地图定位",