添加流式 Copilot 请求处理及审计中间件优化

This commit is contained in:
2026-03-24 16:01:22 +08:00
parent c184610035
commit 600ddd329c
3 changed files with 147 additions and 44 deletions
+45 -32
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
import time
import uuid
@@ -9,7 +10,8 @@ from dataclasses import dataclass
from typing import Any, Optional
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from copilot import CopilotClient, PermissionHandler
@@ -23,11 +25,21 @@ class SessionHolder:
last_used_at: float
app = FastAPI(title="TJWater Copilot Python Sidecar")
app = FastAPI(title="TJWater Copilot Sidecar")
client: Optional[CopilotClient] = None
sessions: dict[str, SessionHolder] = {}
session_ttl_seconds = int(os.getenv("COPILOT_SESSION_TTL_SECONDS", "1800"))
model = os.getenv("COPILOT_MODEL", "gpt-5.3-codex")
model = os.getenv("COPILOT_MODEL", "gpt-4.1")
logger = logging.getLogger("copilot_sidecar")
class ChatStreamRequest(BaseModel):
message: str = Field(..., min_length=1, max_length=10000)
conversation_id: Optional[str] = Field(
default=None, alias="conversationId", max_length=128
)
user_id: Optional[str] = Field(default=None, alias="userId", max_length=128)
model_config = ConfigDict(populate_by_name=True)
@app.on_event("startup")
@@ -43,8 +55,8 @@ async def shutdown_event() -> None:
for holder in sessions.values():
try:
await holder.session.disconnect()
except Exception:
pass
except Exception as exc:
logger.warning("Failed to disconnect session during shutdown: %s", exc)
sessions.clear()
await client.stop()
@@ -62,8 +74,8 @@ async def _cleanup_sessions() -> None:
continue
try:
await holder.session.disconnect()
except Exception:
pass
except Exception as exc:
logger.warning("Failed to disconnect expired session %s: %s", sid, exc)
async def _get_or_create_session(conversation_id: str):
@@ -76,11 +88,9 @@ async def _get_or_create_session(conversation_id: str):
raise RuntimeError("Copilot client is not initialized")
session = await client.create_session(
{
"model": model,
"streaming": True,
"on_permission_request": PermissionHandler.approve_all,
}
model=model,
streaming=True,
on_permission_request=PermissionHandler.approve_all,
)
sessions[conversation_id] = SessionHolder(session=session, last_used_at=time.time())
return session
@@ -92,42 +102,40 @@ async def health() -> dict[str, Any]:
@app.post("/chat/stream")
async def chat_stream(request: Request):
payload = await request.json()
message = payload.get("message")
conversation_id = payload.get("conversationId")
if not isinstance(message, str) or not message.strip():
return JSONResponse(status_code=400, content={"message": "message is required"})
async def chat_stream(payload: ChatStreamRequest, request: Request):
conv_id = (
conversation_id.strip()
if isinstance(conversation_id, str) and conversation_id.strip()
payload.conversation_id.strip()
if isinstance(payload.conversation_id, str) and payload.conversation_id.strip()
else f"conv-{uuid.uuid4().hex[:10]}"
)
message = payload.message.strip()
async def event_generator():
session = None
queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue()
done = asyncio.Event()
error_emitted = False
saw_message_delta = False
def on_event(event):
nonlocal error_emitted
nonlocal saw_message_delta
event_type = getattr(event.type, "value", str(event.type))
data = getattr(event, "data", None)
if event_type == "assistant.message_delta":
content = getattr(data, "delta_content", "") or ""
if content:
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
elif event_type == "assistant.message":
saw_message_delta = True
queue.put_nowait(
("token", {"conversationId": conv_id, "content": content})
)
elif event_type == "assistant.message" and not saw_message_delta:
content = getattr(data, "content", "") or ""
if content:
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
queue.put_nowait(
("token", {"conversationId": conv_id, "content": content})
)
elif event_type == "session.idle":
queue.put_nowait(("done", {"conversationId": conv_id}))
done.set()
elif event_type == "error":
error_emitted = True
queue.put_nowait(
(
"error",
@@ -144,22 +152,27 @@ async def chat_stream(request: Request):
session = await _get_or_create_session(conv_id)
unsubscribe = session.on(on_event)
try:
await session.send({"prompt": message})
await session.send(message)
while not done.is_set() or not queue.empty():
if await request.is_disconnected():
logger.info(
"Client disconnected during stream: conversation=%s",
conv_id,
)
return
try:
event_name, event_data = await asyncio.wait_for(queue.get(), timeout=0.2)
event_name, event_data = await asyncio.wait_for(
queue.get(), timeout=0.2
)
yield _sse(event_name, event_data)
except asyncio.TimeoutError:
continue
if not error_emitted:
yield _sse("done", {"conversationId": conv_id})
finally:
unsubscribe()
if conv_id in sessions:
sessions[conv_id].last_used_at = time.time()
except Exception as exc:
logger.exception("Copilot generation failed for %s: %s", conv_id, exc)
yield _sse(
"error",
{