添加流式 Copilot 请求处理及审计中间件优化
This commit is contained in:
+45
-32
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
@@ -9,7 +10,8 @@ from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from copilot import CopilotClient, PermissionHandler
|
||||
|
||||
|
||||
@@ -23,11 +25,21 @@ class SessionHolder:
|
||||
last_used_at: float
|
||||
|
||||
|
||||
app = FastAPI(title="TJWater Copilot Python Sidecar")
|
||||
app = FastAPI(title="TJWater Copilot Sidecar")
|
||||
client: Optional[CopilotClient] = None
|
||||
sessions: dict[str, SessionHolder] = {}
|
||||
session_ttl_seconds = int(os.getenv("COPILOT_SESSION_TTL_SECONDS", "1800"))
|
||||
model = os.getenv("COPILOT_MODEL", "gpt-5.3-codex")
|
||||
model = os.getenv("COPILOT_MODEL", "gpt-4.1")
|
||||
logger = logging.getLogger("copilot_sidecar")
|
||||
|
||||
|
||||
class ChatStreamRequest(BaseModel):
|
||||
message: str = Field(..., min_length=1, max_length=10000)
|
||||
conversation_id: Optional[str] = Field(
|
||||
default=None, alias="conversationId", max_length=128
|
||||
)
|
||||
user_id: Optional[str] = Field(default=None, alias="userId", max_length=128)
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@@ -43,8 +55,8 @@ async def shutdown_event() -> None:
|
||||
for holder in sessions.values():
|
||||
try:
|
||||
await holder.session.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to disconnect session during shutdown: %s", exc)
|
||||
sessions.clear()
|
||||
await client.stop()
|
||||
|
||||
@@ -62,8 +74,8 @@ async def _cleanup_sessions() -> None:
|
||||
continue
|
||||
try:
|
||||
await holder.session.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to disconnect expired session %s: %s", sid, exc)
|
||||
|
||||
|
||||
async def _get_or_create_session(conversation_id: str):
|
||||
@@ -76,11 +88,9 @@ async def _get_or_create_session(conversation_id: str):
|
||||
raise RuntimeError("Copilot client is not initialized")
|
||||
|
||||
session = await client.create_session(
|
||||
{
|
||||
"model": model,
|
||||
"streaming": True,
|
||||
"on_permission_request": PermissionHandler.approve_all,
|
||||
}
|
||||
model=model,
|
||||
streaming=True,
|
||||
on_permission_request=PermissionHandler.approve_all,
|
||||
)
|
||||
sessions[conversation_id] = SessionHolder(session=session, last_used_at=time.time())
|
||||
return session
|
||||
@@ -92,42 +102,40 @@ async def health() -> dict[str, Any]:
|
||||
|
||||
|
||||
@app.post("/chat/stream")
|
||||
async def chat_stream(request: Request):
|
||||
payload = await request.json()
|
||||
message = payload.get("message")
|
||||
conversation_id = payload.get("conversationId")
|
||||
if not isinstance(message, str) or not message.strip():
|
||||
return JSONResponse(status_code=400, content={"message": "message is required"})
|
||||
|
||||
async def chat_stream(payload: ChatStreamRequest, request: Request):
|
||||
conv_id = (
|
||||
conversation_id.strip()
|
||||
if isinstance(conversation_id, str) and conversation_id.strip()
|
||||
payload.conversation_id.strip()
|
||||
if isinstance(payload.conversation_id, str) and payload.conversation_id.strip()
|
||||
else f"conv-{uuid.uuid4().hex[:10]}"
|
||||
)
|
||||
message = payload.message.strip()
|
||||
|
||||
async def event_generator():
|
||||
session = None
|
||||
queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue()
|
||||
done = asyncio.Event()
|
||||
error_emitted = False
|
||||
saw_message_delta = False
|
||||
|
||||
def on_event(event):
|
||||
nonlocal error_emitted
|
||||
nonlocal saw_message_delta
|
||||
event_type = getattr(event.type, "value", str(event.type))
|
||||
data = getattr(event, "data", None)
|
||||
if event_type == "assistant.message_delta":
|
||||
content = getattr(data, "delta_content", "") or ""
|
||||
if content:
|
||||
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
|
||||
elif event_type == "assistant.message":
|
||||
saw_message_delta = True
|
||||
queue.put_nowait(
|
||||
("token", {"conversationId": conv_id, "content": content})
|
||||
)
|
||||
elif event_type == "assistant.message" and not saw_message_delta:
|
||||
content = getattr(data, "content", "") or ""
|
||||
if content:
|
||||
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
|
||||
queue.put_nowait(
|
||||
("token", {"conversationId": conv_id, "content": content})
|
||||
)
|
||||
elif event_type == "session.idle":
|
||||
queue.put_nowait(("done", {"conversationId": conv_id}))
|
||||
done.set()
|
||||
elif event_type == "error":
|
||||
error_emitted = True
|
||||
queue.put_nowait(
|
||||
(
|
||||
"error",
|
||||
@@ -144,22 +152,27 @@ async def chat_stream(request: Request):
|
||||
session = await _get_or_create_session(conv_id)
|
||||
unsubscribe = session.on(on_event)
|
||||
try:
|
||||
await session.send({"prompt": message})
|
||||
await session.send(message)
|
||||
while not done.is_set() or not queue.empty():
|
||||
if await request.is_disconnected():
|
||||
logger.info(
|
||||
"Client disconnected during stream: conversation=%s",
|
||||
conv_id,
|
||||
)
|
||||
return
|
||||
try:
|
||||
event_name, event_data = await asyncio.wait_for(queue.get(), timeout=0.2)
|
||||
event_name, event_data = await asyncio.wait_for(
|
||||
queue.get(), timeout=0.2
|
||||
)
|
||||
yield _sse(event_name, event_data)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if not error_emitted:
|
||||
yield _sse("done", {"conversationId": conv_id})
|
||||
finally:
|
||||
unsubscribe()
|
||||
if conv_id in sessions:
|
||||
sessions[conv_id].last_used_at = time.time()
|
||||
except Exception as exc:
|
||||
logger.exception("Copilot generation failed for %s: %s", conv_id, exc)
|
||||
yield _sse(
|
||||
"error",
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user