添加 Copilot 聊天流式响应功能及相关配置
This commit is contained in:
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from copilot import CopilotClient, PermissionHandler
|
||||
|
||||
|
||||
def _sse(event: str, data: dict[str, Any]) -> str:
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionHolder:
|
||||
session: Any
|
||||
last_used_at: float
|
||||
|
||||
|
||||
app = FastAPI(title="TJWater Copilot Python Sidecar")
|
||||
client: Optional[CopilotClient] = None
|
||||
sessions: dict[str, SessionHolder] = {}
|
||||
session_ttl_seconds = int(os.getenv("COPILOT_SESSION_TTL_SECONDS", "1800"))
|
||||
model = os.getenv("COPILOT_MODEL", "gpt-5.1-codex")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
global client
|
||||
client = CopilotClient()
|
||||
await client.start()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event() -> None:
|
||||
if client is not None:
|
||||
for holder in sessions.values():
|
||||
try:
|
||||
await holder.session.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
sessions.clear()
|
||||
await client.stop()
|
||||
|
||||
|
||||
async def _cleanup_sessions() -> None:
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid
|
||||
for sid, holder in sessions.items()
|
||||
if now - holder.last_used_at > session_ttl_seconds
|
||||
]
|
||||
for sid in expired:
|
||||
holder = sessions.pop(sid, None)
|
||||
if holder is None:
|
||||
continue
|
||||
try:
|
||||
await holder.session.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _get_or_create_session(conversation_id: str):
|
||||
await _cleanup_sessions()
|
||||
if conversation_id in sessions:
|
||||
sessions[conversation_id].last_used_at = time.time()
|
||||
return sessions[conversation_id].session
|
||||
|
||||
if client is None:
|
||||
raise RuntimeError("Copilot client is not initialized")
|
||||
|
||||
session = await client.create_session(
|
||||
{
|
||||
"model": model,
|
||||
"streaming": True,
|
||||
"on_permission_request": PermissionHandler.approve_all,
|
||||
}
|
||||
)
|
||||
sessions[conversation_id] = SessionHolder(session=session, last_used_at=time.time())
|
||||
return session
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict[str, Any]:
|
||||
return {"ok": True, "model": model, "sessions": len(sessions)}
|
||||
|
||||
|
||||
@app.post("/chat/stream")
|
||||
async def chat_stream(request: Request):
|
||||
payload = await request.json()
|
||||
message = payload.get("message")
|
||||
conversation_id = payload.get("conversationId")
|
||||
if not isinstance(message, str) or not message.strip():
|
||||
return JSONResponse(status_code=400, content={"message": "message is required"})
|
||||
|
||||
conv_id = (
|
||||
conversation_id.strip()
|
||||
if isinstance(conversation_id, str) and conversation_id.strip()
|
||||
else f"conv-{uuid.uuid4().hex[:10]}"
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
session = None
|
||||
queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue()
|
||||
done = asyncio.Event()
|
||||
error_emitted = False
|
||||
|
||||
def on_event(event):
|
||||
nonlocal error_emitted
|
||||
event_type = getattr(event.type, "value", str(event.type))
|
||||
data = getattr(event, "data", None)
|
||||
if event_type == "assistant.message_delta":
|
||||
content = getattr(data, "delta_content", "") or ""
|
||||
if content:
|
||||
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
|
||||
elif event_type == "assistant.message":
|
||||
content = getattr(data, "content", "") or ""
|
||||
if content:
|
||||
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
|
||||
elif event_type == "session.idle":
|
||||
queue.put_nowait(("done", {"conversationId": conv_id}))
|
||||
done.set()
|
||||
elif event_type == "error":
|
||||
error_emitted = True
|
||||
queue.put_nowait(
|
||||
(
|
||||
"error",
|
||||
{
|
||||
"conversationId": conv_id,
|
||||
"message": "copilot session error",
|
||||
"detail": str(data),
|
||||
},
|
||||
)
|
||||
)
|
||||
done.set()
|
||||
|
||||
try:
|
||||
session = await _get_or_create_session(conv_id)
|
||||
unsubscribe = session.on(on_event)
|
||||
try:
|
||||
await session.send({"prompt": message})
|
||||
while not done.is_set() or not queue.empty():
|
||||
if await request.is_disconnected():
|
||||
return
|
||||
try:
|
||||
event_name, event_data = await asyncio.wait_for(queue.get(), timeout=0.2)
|
||||
yield _sse(event_name, event_data)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if not error_emitted:
|
||||
yield _sse("done", {"conversationId": conv_id})
|
||||
finally:
|
||||
unsubscribe()
|
||||
if conv_id in sessions:
|
||||
sessions[conv_id].last_used_at = time.time()
|
||||
except Exception as exc:
|
||||
yield _sse(
|
||||
"error",
|
||||
{
|
||||
"conversationId": conv_id,
|
||||
"message": "copilot generation failed",
|
||||
"detail": str(exc),
|
||||
},
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user