281 lines
9.8 KiB
TypeScript
281 lines
9.8 KiB
TypeScript
import { randomUUID } from "node:crypto";
|
|
|
|
import { logger } from "../logger.js";
|
|
import { type OpencodeRuntimeAdapter } from "../runtime/opencode.js";
|
|
import { type SessionBinding, type SessionContext, SessionRegistry } from "../session/registry.js";
|
|
import { ToolSessionContextStore } from "../session/toolContextStore.js";
|
|
import { toActorKey, toProjectKey } from "../utils/fileStore.js";
|
|
|
|
/**
 * Per-request context handed around by the chat bridge: the caller's
 * {@link SessionContext} extended with the derived actor/project keys and the
 * trace id assigned to this request.
 */
export type ChatRequestContext = SessionContext & {
  /** Trace identifier for the current request (caller-provided or generated). */
  traceId: string;
  /** Key derived from the user id via `toActorKey`. */
  actorKey: string;
  /** Key derived from the project id via `toProjectKey`. */
  projectKey: string;
};
|
|
|
|
/** Caller-supplied identity fields accepted by the {@link ChatSessionBridge} entry points. */
export type ChatContextInput = {
  clientSessionId?: string;
  accessToken?: string;
  projectId?: string;
  traceId?: string;
  userId?: string;
};

/**
 * Bridges client-side chat sessions (identified by `clientSessionId`) to opencode
 * runtime sessions, keeping the registry binding, the per-session request context,
 * the tool context store and optional session titles in sync.
 */
export class ChatSessionBridge {
  // Extra session -> user context mapping, so the tool bridge can reuse it when it
  // issues real backend requests server-side on the user's behalf.
  private readonly sessionContexts = new Map<string, ChatRequestContext>();

  // Display titles keyed by opencode session id.
  private readonly sessionTitles = new Map<string, string>();

  private readonly toolContextStore = new ToolSessionContextStore();

  constructor(
    private readonly registry: SessionRegistry,
    private readonly runtime: OpencodeRuntimeAdapter,
  ) {}

  /**
   * Returns the opencode session bound to the caller's client session, creating a new
   * one when there is no binding or the bound session can no longer be fetched.
   *
   * A missing/blank `clientSessionId` or `traceId` is replaced with a generated id.
   *
   * @returns the binding, the normalized request context, and whether a session was created.
   */
  async resolve(context: ChatContextInput): Promise<{
    binding: SessionBinding;
    requestContext: ChatRequestContext;
    created: boolean;
  }> {
    const requestContext = this.buildRequestContext(
      context,
      context.clientSessionId?.trim() || ChatSessionBridge.shortId("agent"),
    );

    this.cleanupExpired();

    const current = this.registry.get(requestContext);
    if (current) {
      // Only reuse the local mapping while the opencode-side session still exists.
      // Checking first avoids writing context for a session that is about to be discarded.
      if (await this.sessionStillExists(current.sessionId, requestContext.clientSessionId)) {
        await this.attachContext(current.sessionId, requestContext);
        return { binding: current, requestContext, created: false };
      }
      // The binding is about to be replaced, so the registry will never report this id as
      // expired again; release its local state now instead of leaking it.
      this.releaseSession(current.sessionId, "stale");
    }

    const binding = await this.createBinding(requestContext);
    return { binding, requestContext, created: true };
  }

  /** Number of bindings currently tracked by the registry. */
  count(): number {
    return this.registry.count();
  }

  /** Last request context recorded for `sessionId`, or `null` if none is known. */
  getSessionContext(sessionId: string): ChatRequestContext | null {
    return this.sessionContexts.get(sessionId) ?? null;
  }

  /** Title stored for `sessionId`, or `undefined` if none was set. */
  getSessionTitle(sessionId: string): string | undefined {
    return this.sessionTitles.get(sessionId);
  }

  /** Stores a trimmed title for `sessionId`; blank titles are ignored. */
  setSessionTitle(sessionId: string, title: string) {
    const normalized = title.trim();
    if (!normalized) {
      return;
    }
    this.sessionTitles.set(sessionId, normalized);
  }

  /** Copies the title of `sourceSessionId` (if any) onto `targetSessionId`. */
  cloneSessionTitle(sourceSessionId: string, targetSessionId: string) {
    const existingTitle = this.sessionTitles.get(sourceSessionId);
    if (!existingTitle) {
      return;
    }
    this.sessionTitles.set(targetSessionId, existingTitle);
  }

  /**
   * Aborts the opencode session bound to the caller's client session.
   *
   * @returns the aborted binding, or `null` when `clientSessionId` is blank or unbound.
   */
  async abort(context: ChatContextInput): Promise<SessionBinding | null> {
    const clientSessionId = context.clientSessionId?.trim();
    if (!clientSessionId) {
      return null;
    }

    const requestContext = this.buildRequestContext(context, clientSessionId);

    this.cleanupExpired();

    const binding = this.registry.get(requestContext);
    if (!binding) {
      return null;
    }

    await this.attachContext(binding.sessionId, requestContext);
    await this.runtime.abortSession(binding.sessionId);
    return binding;
  }

  /**
   * Forks the caller's session so the new session keeps the first `keepMessageCount`
   * user/assistant messages, under a freshly generated `clientSessionId`.
   *
   * Falls back to creating an empty session when there is no current client session,
   * `keepMessageCount <= 0`, or the current client session is unbound.
   *
   * @throws Error when the current session has fewer than `keepMessageCount` chat messages
   *   in the fetched window; also propagates runtime errors (e.g. `getSession` failing).
   */
  async fork(context: ChatContextInput & { keepMessageCount: number }): Promise<{
    binding: SessionBinding;
    requestContext: ChatRequestContext;
    created: boolean;
  }> {
    const currentClientSessionId = context.clientSessionId?.trim();
    const nextRequestContext = this.buildRequestContext(context, ChatSessionBridge.shortId("agent"));

    this.cleanupExpired();

    if (!currentClientSessionId || context.keepMessageCount <= 0) {
      const binding = await this.createBinding(nextRequestContext);
      return { binding, requestContext: nextRequestContext, created: true };
    }

    // Same identity as the new context, but addressing the session being forked.
    const currentContext: ChatRequestContext = {
      ...nextRequestContext,
      clientSessionId: currentClientSessionId,
    };

    const current = this.registry.get(currentContext);
    if (!current) {
      const binding = await this.createBinding(nextRequestContext);
      return { binding, requestContext: nextRequestContext, created: true };
    }

    await this.runtime.getSession(current.sessionId);
    const messages = await this.runtime.messages(
      current.sessionId,
      Math.max(100, context.keepMessageCount + 20),
    );
    const chatMessages = messages.filter(
      (message) => message.info.role === "user" || message.info.role === "assistant",
    );
    const keepMessage = chatMessages[context.keepMessageCount - 1];
    if (!keepMessage) {
      throw new Error(`fork keep point not found for message count ${context.keepMessageCount}`);
    }

    const session = await this.runtime.forkSession(current.sessionId, keepMessage.info.id);
    const binding = this.registry.upsert(nextRequestContext, session.id);
    await this.attachContext(binding.sessionId, nextRequestContext);
    this.cloneSessionTitle(current.sessionId, binding.sessionId);
    return { binding, requestContext: nextRequestContext, created: true };
  }

  /** Evicts expired registry bindings and releases every piece of local state they owned. */
  cleanupExpired(): void {
    for (const sessionId of this.registry.evictExpired()) {
      this.releaseSession(sessionId, "expired");
    }
  }

  /** Generates a short random id of the form `<prefix>-<12 hex-ish chars>`. */
  private static shortId(prefix: string): string {
    return `${prefix}-${randomUUID().slice(0, 12)}`;
  }

  /** Normalizes caller input into a request context bound to `clientSessionId`. */
  private buildRequestContext(context: ChatContextInput, clientSessionId: string): ChatRequestContext {
    return {
      clientSessionId,
      accessToken: context.accessToken,
      // NOTE(review): actorKey is derived from the untrimmed userId while `userId` below is
      // trimmed; kept as-is so existing actor keys do not change — confirm toActorKey trims.
      actorKey: toActorKey(context.userId),
      projectId: context.projectId,
      projectKey: toProjectKey(context.projectId),
      traceId: context.traceId?.trim() || ChatSessionBridge.shortId("trace"),
      userId: context.userId?.trim(),
    };
  }

  /** Creates a new opencode session for `requestContext`, binds it, and records its context. */
  private async createBinding(requestContext: ChatRequestContext): Promise<SessionBinding> {
    const session = await this.runtime.createSession(requestContext.clientSessionId);
    const binding = this.registry.upsert(requestContext, session.id);
    await this.attachContext(binding.sessionId, requestContext);
    return binding;
  }

  /** Records `requestContext` for `sessionId` both in memory and in the tool context store. */
  private async attachContext(sessionId: string, requestContext: ChatRequestContext): Promise<void> {
    this.sessionContexts.set(sessionId, requestContext);
    await this.toolContextStore.write({
      actorKey: requestContext.actorKey,
      allowLearningWrite: true,
      clientSessionId: requestContext.clientSessionId,
      learningMode: "interactive",
      projectId: requestContext.projectId,
      projectKey: requestContext.projectKey,
      sessionId,
      traceId: requestContext.traceId,
    });
  }

  /**
   * Returns whether the opencode runtime can still fetch `sessionId`; a failed lookup is
   * logged and reported as `false`.
   */
  private async sessionStillExists(sessionId: string, clientSessionId: string): Promise<boolean> {
    try {
      await this.runtime.getSession(sessionId);
      return true;
    } catch (error) {
      logger.warn(
        {
          clientSessionId,
          sessionId,
          err: error,
        },
        "existing opencode session lookup failed, creating a new session",
      );
      return false;
    }
  }

  /**
   * Drops all local state for a session that is no longer bound and best-effort aborts it.
   * Failures of the asynchronous cleanup steps are logged at debug level and never rethrown,
   * so they neither block local reclamation nor surface as unhandled promise rejections.
   */
  private releaseSession(sessionId: string, reason: "expired" | "stale"): void {
    this.sessionContexts.delete(sessionId);
    this.sessionTitles.delete(sessionId);
    void Promise.resolve(this.toolContextStore.remove(sessionId)).catch((error: unknown) => {
      logger.debug(
        { sessionId, reason, err: error },
        `ignoring failed tool context removal for ${reason} session`,
      );
    });
    // Abort is used as a lightweight cleanup; a failure must not block local reclamation.
    void this.runtime.abortSession(sessionId).catch((error: unknown) => {
      logger.debug({ sessionId, err: error }, `ignoring failed abort for ${reason} session`);
    });
  }
}
|