添加流式 Copilot 请求处理及审计中间件优化

This commit is contained in:
2026-03-24 16:01:22 +08:00
parent c184610035
commit 600ddd329c
3 changed files with 147 additions and 44 deletions
+31 -3
View File
@@ -61,12 +61,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
"/api/v1/openproject/",
"/openproject/",
"/api/v1/copilot/chat/",
"/api/v1/copilot/chat/stream",
}
EXCLUDED_PATH_PREFIXES = (
"/api/v1/copilot/chat/",
"/copilot/chat/",
)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 提取开始时间
start_time = time.time()
# 流式 Copilot 请求前置排除,避免读取/改写 body 影响 SSE 生命周期
if self._is_excluded_path(request.url.path):
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# 1. 预判是否需要读取Body (针对写操作)
# Note: apart from the excluded-path short-circuit above, do not return early here; the tag check can only run after route matching
should_capture_body = request.method in ["POST", "PUT", "PATCH"]
@@ -75,13 +87,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
if should_capture_body:
try:
# 注意:读取 body 后需要重新设置,避免影响后续处理
original_receive = request._receive
body = await request.body()
if body:
request_data = json.loads(body.decode())
# 重新构造请求以供后续使用
# 重新构造请求以供后续使用:仅回放一次,后续回落原始 receive
body_sent = False
async def receive():
return {"type": "http.request", "body": body}
nonlocal body_sent
if not body_sent:
body_sent = True
return {
"type": "http.request",
"body": body,
"more_body": False,
}
return await original_receive()
request._receive = receive
except Exception as e:
@@ -91,7 +114,7 @@ class AuditMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
# 3. 决定是否审计
if request.url.path in self.EXCLUDED_PATHS:
if self._is_excluded_path(request.url.path):
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
@@ -151,6 +174,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
def _is_excluded_path(self, path: str) -> bool:
if path in self.EXCLUDED_PATHS:
return True
return any(path.startswith(prefix) for prefix in self.EXCLUDED_PATH_PREFIXES)
def _resolve_project_id(self, request: Request) -> UUID | None:
project_header = request.headers.get("X-Project-Id")
if not project_header:
+45 -32
View File
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
import time
import uuid
@@ -9,7 +10,8 @@ from dataclasses import dataclass
from typing import Any, Optional
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from copilot import CopilotClient, PermissionHandler
@@ -23,11 +25,21 @@ class SessionHolder:
last_used_at: float
app = FastAPI(title="TJWater Copilot Python Sidecar")
app = FastAPI(title="TJWater Copilot Sidecar")
client: Optional[CopilotClient] = None
sessions: dict[str, SessionHolder] = {}
session_ttl_seconds = int(os.getenv("COPILOT_SESSION_TTL_SECONDS", "1800"))
model = os.getenv("COPILOT_MODEL", "gpt-5.3-codex")
model = os.getenv("COPILOT_MODEL", "gpt-4.1")
logger = logging.getLogger("copilot_sidecar")
class ChatStreamRequest(BaseModel):
    """Request body accepted by ``POST /chat/stream``.

    Fields can be supplied either under their camelCase aliases
    (``conversationId``, ``userId``) or under their Python names, because
    ``populate_by_name`` is enabled.
    """

    # Surrounding whitespace is stripped before the length constraints run, so
    # a whitespace-only ``message`` fails ``min_length=1`` during validation
    # instead of reaching the handler and becoming an empty prompt after the
    # handler's own ``.strip()``.
    model_config = ConfigDict(populate_by_name=True, str_strip_whitespace=True)

    # Prompt text; must contain at least one non-whitespace character.
    message: str = Field(..., min_length=1, max_length=10000)
    # Optional conversation key; the handler generates one when it is absent
    # or blank.
    conversation_id: Optional[str] = Field(
        default=None, alias="conversationId", max_length=128
    )
    # Optional caller-supplied user identifier.
    user_id: Optional[str] = Field(default=None, alias="userId", max_length=128)
@app.on_event("startup")
@@ -43,8 +55,8 @@ async def shutdown_event() -> None:
for holder in sessions.values():
try:
await holder.session.disconnect()
except Exception:
pass
except Exception as exc:
logger.warning("Failed to disconnect session during shutdown: %s", exc)
sessions.clear()
await client.stop()
@@ -62,8 +74,8 @@ async def _cleanup_sessions() -> None:
continue
try:
await holder.session.disconnect()
except Exception:
pass
except Exception as exc:
logger.warning("Failed to disconnect expired session %s: %s", sid, exc)
async def _get_or_create_session(conversation_id: str):
@@ -76,11 +88,9 @@ async def _get_or_create_session(conversation_id: str):
raise RuntimeError("Copilot client is not initialized")
session = await client.create_session(
{
"model": model,
"streaming": True,
"on_permission_request": PermissionHandler.approve_all,
}
model=model,
streaming=True,
on_permission_request=PermissionHandler.approve_all,
)
sessions[conversation_id] = SessionHolder(session=session, last_used_at=time.time())
return session
@@ -92,42 +102,40 @@ async def health() -> dict[str, Any]:
@app.post("/chat/stream")
async def chat_stream(request: Request):
payload = await request.json()
message = payload.get("message")
conversation_id = payload.get("conversationId")
if not isinstance(message, str) or not message.strip():
return JSONResponse(status_code=400, content={"message": "message is required"})
async def chat_stream(payload: ChatStreamRequest, request: Request):
conv_id = (
conversation_id.strip()
if isinstance(conversation_id, str) and conversation_id.strip()
payload.conversation_id.strip()
if isinstance(payload.conversation_id, str) and payload.conversation_id.strip()
else f"conv-{uuid.uuid4().hex[:10]}"
)
message = payload.message.strip()
async def event_generator():
session = None
queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue()
done = asyncio.Event()
error_emitted = False
saw_message_delta = False
def on_event(event):
nonlocal error_emitted
nonlocal saw_message_delta
event_type = getattr(event.type, "value", str(event.type))
data = getattr(event, "data", None)
if event_type == "assistant.message_delta":
content = getattr(data, "delta_content", "") or ""
if content:
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
elif event_type == "assistant.message":
saw_message_delta = True
queue.put_nowait(
("token", {"conversationId": conv_id, "content": content})
)
elif event_type == "assistant.message" and not saw_message_delta:
content = getattr(data, "content", "") or ""
if content:
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
queue.put_nowait(
("token", {"conversationId": conv_id, "content": content})
)
elif event_type == "session.idle":
queue.put_nowait(("done", {"conversationId": conv_id}))
done.set()
elif event_type == "error":
error_emitted = True
queue.put_nowait(
(
"error",
@@ -144,22 +152,27 @@ async def chat_stream(request: Request):
session = await _get_or_create_session(conv_id)
unsubscribe = session.on(on_event)
try:
await session.send({"prompt": message})
await session.send(message)
while not done.is_set() or not queue.empty():
if await request.is_disconnected():
logger.info(
"Client disconnected during stream: conversation=%s",
conv_id,
)
return
try:
event_name, event_data = await asyncio.wait_for(queue.get(), timeout=0.2)
event_name, event_data = await asyncio.wait_for(
queue.get(), timeout=0.2
)
yield _sse(event_name, event_data)
except asyncio.TimeoutError:
continue
if not error_emitted:
yield _sse("done", {"conversationId": conv_id})
finally:
unsubscribe()
if conv_id in sessions:
sessions[conv_id].last_used_at = time.time()
except Exception as exc:
logger.exception("Copilot generation failed for %s: %s", conv_id, exc)
yield _sse(
"error",
{
+71 -9
View File
@@ -1,21 +1,83 @@
import asyncio
import sys
import atexit
import os
import signal
import subprocess
import sys
from urllib.parse import urlparse
import uvicorn
# 将项目根目录添加到 python 路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
_SIDECAR_PROCESS: subprocess.Popen | None = None
def _parse_sidecar_target() -> tuple[str, int]:
sidecar_url = os.getenv("COPILOT_SIDECAR_URL", "http://127.0.0.1:8787").strip()
parsed = urlparse(sidecar_url)
host = parsed.hostname or "127.0.0.1"
port = parsed.port or 8787
return host, port
def _stop_sidecar() -> None:
global _SIDECAR_PROCESS
proc = _SIDECAR_PROCESS
if proc is None:
return
if proc.poll() is None:
proc.terminate()
try:
proc.wait(timeout=5)
except subprocess.TimeoutExpired:
proc.kill()
proc.wait(timeout=3)
_SIDECAR_PROCESS = None
def _start_sidecar_if_needed() -> None:
global _SIDECAR_PROCESS
sidecar_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "copilot-sidecar")
)
host, port = _parse_sidecar_target()
cmd = [
sys.executable,
"-m",
"uvicorn",
"server:app",
"--host",
host,
"--port",
str(port),
"--log-level",
os.getenv("COPILOT_SIDECAR_LOG_LEVEL", "warning"),
]
_SIDECAR_PROCESS = subprocess.Popen(cmd, cwd=sidecar_dir)
print(f"[run_server] sidecar started at {host}:{port}.")
if __name__ == "__main__":
# Windows 设置事件循环策略
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# 用 uvicorn.run 支持 workers 参数
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
# workers=2, # 这里可以设置多进程
loop="asyncio",
)
atexit.register(_stop_sidecar)
signal.signal(signal.SIGTERM, lambda *_: _stop_sidecar())
signal.signal(signal.SIGINT, lambda *_: _stop_sidecar())
_start_sidecar_if_needed()
try:
# 用 uvicorn.run 支持 workers 参数
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
# workers=2, # 这里可以设置多进程
loop="asyncio",
)
finally:
_stop_sidecar()