添加流式 Copilot 请求处理及审计中间件优化

This commit is contained in:
2026-03-24 16:01:22 +08:00
parent c184610035
commit 600ddd329c
3 changed files with 147 additions and 44 deletions
+31 -3
View File
@@ -61,12 +61,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
"/api/v1/openproject/", "/api/v1/openproject/",
"/openproject/", "/openproject/",
"/api/v1/copilot/chat/", "/api/v1/copilot/chat/",
"/api/v1/copilot/chat/stream",
} }
EXCLUDED_PATH_PREFIXES = (
"/api/v1/copilot/chat/",
"/copilot/chat/",
)
async def dispatch(self, request: Request, call_next: Callable) -> Response: async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 提取开始时间 # 提取开始时间
start_time = time.time() start_time = time.time()
# 流式 Copilot 请求前置排除,避免读取/改写 body 影响 SSE 生命周期
if self._is_excluded_path(request.url.path):
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# 1. 预判是否需要读取Body (针对写操作) # 1. 预判是否需要读取Body (针对写操作)
# 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag # 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag
should_capture_body = request.method in ["POST", "PUT", "PATCH"] should_capture_body = request.method in ["POST", "PUT", "PATCH"]
@@ -75,13 +87,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
if should_capture_body: if should_capture_body:
try: try:
# 注意:读取 body 后需要重新设置,避免影响后续处理 # 注意:读取 body 后需要重新设置,避免影响后续处理
original_receive = request._receive
body = await request.body() body = await request.body()
if body: if body:
request_data = json.loads(body.decode()) request_data = json.loads(body.decode())
# 重新构造请求以供后续使用 # 重新构造请求以供后续使用:仅回放一次,后续回落原始 receive
body_sent = False
async def receive(): async def receive():
return {"type": "http.request", "body": body} nonlocal body_sent
if not body_sent:
body_sent = True
return {
"type": "http.request",
"body": body,
"more_body": False,
}
return await original_receive()
request._receive = receive request._receive = receive
except Exception as e: except Exception as e:
@@ -91,7 +114,7 @@ class AuditMiddleware(BaseHTTPMiddleware):
response = await call_next(request) response = await call_next(request)
# 3. 决定是否审计 # 3. 决定是否审计
if request.url.path in self.EXCLUDED_PATHS: if self._is_excluded_path(request.url.path):
process_time = time.time() - start_time process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time) response.headers["X-Process-Time"] = str(process_time)
return response return response
@@ -151,6 +174,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response return response
def _is_excluded_path(self, path: str) -> bool:
if path in self.EXCLUDED_PATHS:
return True
return any(path.startswith(prefix) for prefix in self.EXCLUDED_PATH_PREFIXES)
def _resolve_project_id(self, request: Request) -> UUID | None: def _resolve_project_id(self, request: Request) -> UUID | None:
project_header = request.headers.get("X-Project-Id") project_header = request.headers.get("X-Project-Id")
if not project_header: if not project_header:
+45 -32
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging
import os import os
import time import time
import uuid import uuid
@@ -9,7 +10,8 @@ from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from copilot import CopilotClient, PermissionHandler from copilot import CopilotClient, PermissionHandler
@@ -23,11 +25,21 @@ class SessionHolder:
last_used_at: float last_used_at: float
app = FastAPI(title="TJWater Copilot Python Sidecar") app = FastAPI(title="TJWater Copilot Sidecar")
client: Optional[CopilotClient] = None client: Optional[CopilotClient] = None
sessions: dict[str, SessionHolder] = {} sessions: dict[str, SessionHolder] = {}
session_ttl_seconds = int(os.getenv("COPILOT_SESSION_TTL_SECONDS", "1800")) session_ttl_seconds = int(os.getenv("COPILOT_SESSION_TTL_SECONDS", "1800"))
model = os.getenv("COPILOT_MODEL", "gpt-5.3-codex") model = os.getenv("COPILOT_MODEL", "gpt-4.1")
logger = logging.getLogger("copilot_sidecar")
class ChatStreamRequest(BaseModel):
message: str = Field(..., min_length=1, max_length=10000)
conversation_id: Optional[str] = Field(
default=None, alias="conversationId", max_length=128
)
user_id: Optional[str] = Field(default=None, alias="userId", max_length=128)
model_config = ConfigDict(populate_by_name=True)
@app.on_event("startup") @app.on_event("startup")
@@ -43,8 +55,8 @@ async def shutdown_event() -> None:
for holder in sessions.values(): for holder in sessions.values():
try: try:
await holder.session.disconnect() await holder.session.disconnect()
except Exception: except Exception as exc:
pass logger.warning("Failed to disconnect session during shutdown: %s", exc)
sessions.clear() sessions.clear()
await client.stop() await client.stop()
@@ -62,8 +74,8 @@ async def _cleanup_sessions() -> None:
continue continue
try: try:
await holder.session.disconnect() await holder.session.disconnect()
except Exception: except Exception as exc:
pass logger.warning("Failed to disconnect expired session %s: %s", sid, exc)
async def _get_or_create_session(conversation_id: str): async def _get_or_create_session(conversation_id: str):
@@ -76,11 +88,9 @@ async def _get_or_create_session(conversation_id: str):
raise RuntimeError("Copilot client is not initialized") raise RuntimeError("Copilot client is not initialized")
session = await client.create_session( session = await client.create_session(
{ model=model,
"model": model, streaming=True,
"streaming": True, on_permission_request=PermissionHandler.approve_all,
"on_permission_request": PermissionHandler.approve_all,
}
) )
sessions[conversation_id] = SessionHolder(session=session, last_used_at=time.time()) sessions[conversation_id] = SessionHolder(session=session, last_used_at=time.time())
return session return session
@@ -92,42 +102,40 @@ async def health() -> dict[str, Any]:
@app.post("/chat/stream") @app.post("/chat/stream")
async def chat_stream(request: Request): async def chat_stream(payload: ChatStreamRequest, request: Request):
payload = await request.json()
message = payload.get("message")
conversation_id = payload.get("conversationId")
if not isinstance(message, str) or not message.strip():
return JSONResponse(status_code=400, content={"message": "message is required"})
conv_id = ( conv_id = (
conversation_id.strip() payload.conversation_id.strip()
if isinstance(conversation_id, str) and conversation_id.strip() if isinstance(payload.conversation_id, str) and payload.conversation_id.strip()
else f"conv-{uuid.uuid4().hex[:10]}" else f"conv-{uuid.uuid4().hex[:10]}"
) )
message = payload.message.strip()
async def event_generator(): async def event_generator():
session = None
queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue() queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue()
done = asyncio.Event() done = asyncio.Event()
error_emitted = False saw_message_delta = False
def on_event(event): def on_event(event):
nonlocal error_emitted nonlocal saw_message_delta
event_type = getattr(event.type, "value", str(event.type)) event_type = getattr(event.type, "value", str(event.type))
data = getattr(event, "data", None) data = getattr(event, "data", None)
if event_type == "assistant.message_delta": if event_type == "assistant.message_delta":
content = getattr(data, "delta_content", "") or "" content = getattr(data, "delta_content", "") or ""
if content: if content:
queue.put_nowait(("token", {"conversationId": conv_id, "content": content})) saw_message_delta = True
elif event_type == "assistant.message": queue.put_nowait(
("token", {"conversationId": conv_id, "content": content})
)
elif event_type == "assistant.message" and not saw_message_delta:
content = getattr(data, "content", "") or "" content = getattr(data, "content", "") or ""
if content: if content:
queue.put_nowait(("token", {"conversationId": conv_id, "content": content})) queue.put_nowait(
("token", {"conversationId": conv_id, "content": content})
)
elif event_type == "session.idle": elif event_type == "session.idle":
queue.put_nowait(("done", {"conversationId": conv_id})) queue.put_nowait(("done", {"conversationId": conv_id}))
done.set() done.set()
elif event_type == "error": elif event_type == "error":
error_emitted = True
queue.put_nowait( queue.put_nowait(
( (
"error", "error",
@@ -144,22 +152,27 @@ async def chat_stream(request: Request):
session = await _get_or_create_session(conv_id) session = await _get_or_create_session(conv_id)
unsubscribe = session.on(on_event) unsubscribe = session.on(on_event)
try: try:
await session.send({"prompt": message}) await session.send(message)
while not done.is_set() or not queue.empty(): while not done.is_set() or not queue.empty():
if await request.is_disconnected(): if await request.is_disconnected():
logger.info(
"Client disconnected during stream: conversation=%s",
conv_id,
)
return return
try: try:
event_name, event_data = await asyncio.wait_for(queue.get(), timeout=0.2) event_name, event_data = await asyncio.wait_for(
queue.get(), timeout=0.2
)
yield _sse(event_name, event_data) yield _sse(event_name, event_data)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
if not error_emitted:
yield _sse("done", {"conversationId": conv_id})
finally: finally:
unsubscribe() unsubscribe()
if conv_id in sessions: if conv_id in sessions:
sessions[conv_id].last_used_at = time.time() sessions[conv_id].last_used_at = time.time()
except Exception as exc: except Exception as exc:
logger.exception("Copilot generation failed for %s: %s", conv_id, exc)
yield _sse( yield _sse(
"error", "error",
{ {
+71 -9
View File
@@ -1,21 +1,83 @@
import asyncio import asyncio
import sys import atexit
import os import os
import signal
import subprocess
import sys
from urllib.parse import urlparse
import uvicorn import uvicorn
# 将项目根目录添加到 python 路径 # 将项目根目录添加到 python 路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
_SIDECAR_PROCESS: subprocess.Popen | None = None
def _parse_sidecar_target() -> tuple[str, int]:
sidecar_url = os.getenv("COPILOT_SIDECAR_URL", "http://127.0.0.1:8787").strip()
parsed = urlparse(sidecar_url)
host = parsed.hostname or "127.0.0.1"
port = parsed.port or 8787
return host, port
def _stop_sidecar() -> None:
global _SIDECAR_PROCESS
proc = _SIDECAR_PROCESS
if proc is None:
return
if proc.poll() is None:
proc.terminate()
try:
proc.wait(timeout=5)
except subprocess.TimeoutExpired:
proc.kill()
proc.wait(timeout=3)
_SIDECAR_PROCESS = None
def _start_sidecar_if_needed() -> None:
global _SIDECAR_PROCESS
sidecar_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "copilot-sidecar")
)
host, port = _parse_sidecar_target()
cmd = [
sys.executable,
"-m",
"uvicorn",
"server:app",
"--host",
host,
"--port",
str(port),
"--log-level",
os.getenv("COPILOT_SIDECAR_LOG_LEVEL", "warning"),
]
_SIDECAR_PROCESS = subprocess.Popen(cmd, cwd=sidecar_dir)
print(f"[run_server] sidecar started at {host}:{port}.")
if __name__ == "__main__": if __name__ == "__main__":
# Windows 设置事件循环策略 # Windows 设置事件循环策略
if sys.platform == "win32": if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# 用 uvicorn.run 支持 workers 参数 atexit.register(_stop_sidecar)
uvicorn.run( signal.signal(signal.SIGTERM, lambda *_: _stop_sidecar())
"app.main:app", signal.signal(signal.SIGINT, lambda *_: _stop_sidecar())
host="0.0.0.0",
port=8000, _start_sidecar_if_needed()
# workers=2, # 这里可以设置多进程 try:
loop="asyncio", # 用 uvicorn.run 支持 workers 参数
) uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
# workers=2, # 这里可以设置多进程
loop="asyncio",
)
finally:
_stop_sidecar()