添加流式 Copilot 请求处理及审计中间件优化
This commit is contained in:
@@ -61,12 +61,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
"/api/v1/openproject/",
|
||||
"/openproject/",
|
||||
"/api/v1/copilot/chat/",
|
||||
"/api/v1/copilot/chat/stream",
|
||||
}
|
||||
EXCLUDED_PATH_PREFIXES = (
|
||||
"/api/v1/copilot/chat/",
|
||||
"/copilot/chat/",
|
||||
)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 提取开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 流式 Copilot 请求前置排除,避免读取/改写 body 影响 SSE 生命周期
|
||||
if self._is_excluded_path(request.url.path):
|
||||
response = await call_next(request)
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
return response
|
||||
|
||||
# 1. 预判是否需要读取Body (针对写操作)
|
||||
# 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag
|
||||
should_capture_body = request.method in ["POST", "PUT", "PATCH"]
|
||||
@@ -75,13 +87,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
if should_capture_body:
|
||||
try:
|
||||
# 注意:读取 body 后需要重新设置,避免影响后续处理
|
||||
original_receive = request._receive
|
||||
body = await request.body()
|
||||
if body:
|
||||
request_data = json.loads(body.decode())
|
||||
|
||||
# 重新构造请求以供后续使用
|
||||
# 重新构造请求以供后续使用:仅回放一次,后续回落原始 receive
|
||||
body_sent = False
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body}
|
||||
nonlocal body_sent
|
||||
if not body_sent:
|
||||
body_sent = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": body,
|
||||
"more_body": False,
|
||||
}
|
||||
return await original_receive()
|
||||
|
||||
request._receive = receive
|
||||
except Exception as e:
|
||||
@@ -91,7 +114,7 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
response = await call_next(request)
|
||||
|
||||
# 3. 决定是否审计
|
||||
if request.url.path in self.EXCLUDED_PATHS:
|
||||
if self._is_excluded_path(request.url.path):
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
return response
|
||||
@@ -151,6 +174,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return response
|
||||
|
||||
def _is_excluded_path(self, path: str) -> bool:
    """Tell whether ``path`` is exempt from audit handling.

    Args:
        path: The request URL path to classify.

    Returns:
        ``True`` when ``path`` is an exact member of ``EXCLUDED_PATHS`` or
        begins with one of the entries in ``EXCLUDED_PATH_PREFIXES``;
        ``False`` otherwise.
    """
    # str.startswith accepts a tuple of candidates, so one call covers
    # every configured prefix.
    prefixes = tuple(self.EXCLUDED_PATH_PREFIXES)
    return path in self.EXCLUDED_PATHS or path.startswith(prefixes)
|
||||
|
||||
def _resolve_project_id(self, request: Request) -> UUID | None:
|
||||
project_header = request.headers.get("X-Project-Id")
|
||||
if not project_header:
|
||||
|
||||
+45
-32
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
@@ -9,7 +10,8 @@ from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from copilot import CopilotClient, PermissionHandler
|
||||
|
||||
|
||||
@@ -23,11 +25,21 @@ class SessionHolder:
|
||||
last_used_at: float
|
||||
|
||||
|
||||
app = FastAPI(title="TJWater Copilot Python Sidecar")
|
||||
app = FastAPI(title="TJWater Copilot Sidecar")
|
||||
client: Optional[CopilotClient] = None
|
||||
sessions: dict[str, SessionHolder] = {}
|
||||
session_ttl_seconds = int(os.getenv("COPILOT_SESSION_TTL_SECONDS", "1800"))
|
||||
model = os.getenv("COPILOT_MODEL", "gpt-5.3-codex")
|
||||
model = os.getenv("COPILOT_MODEL", "gpt-4.1")
|
||||
logger = logging.getLogger("copilot_sidecar")
|
||||
|
||||
|
||||
class ChatStreamRequest(BaseModel):
    """Request body accepted by the streaming chat endpoint.

    Each aliased field may be supplied either under its camelCase alias
    (``conversationId``, ``userId``) or under its Python attribute name.
    String values are trimmed before the length constraints are applied,
    so a message made only of whitespace fails validation instead of
    being accepted and later reduced to an empty prompt.
    """

    # Prompt text: required, 1-10000 characters once trimmed.
    message: str = Field(..., min_length=1, max_length=10000)
    # Optional conversation identifier, at most 128 characters.
    conversation_id: Optional[str] = Field(
        default=None, alias="conversationId", max_length=128
    )
    # Optional caller identifier, at most 128 characters.
    user_id: Optional[str] = Field(default=None, alias="userId", max_length=128)
    # populate_by_name keeps both the alias and the attribute name usable
    # as input keys; str_strip_whitespace trims strings before min/max
    # length are checked.
    model_config = ConfigDict(populate_by_name=True, str_strip_whitespace=True)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@@ -43,8 +55,8 @@ async def shutdown_event() -> None:
|
||||
for holder in sessions.values():
|
||||
try:
|
||||
await holder.session.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to disconnect session during shutdown: %s", exc)
|
||||
sessions.clear()
|
||||
await client.stop()
|
||||
|
||||
@@ -62,8 +74,8 @@ async def _cleanup_sessions() -> None:
|
||||
continue
|
||||
try:
|
||||
await holder.session.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to disconnect expired session %s: %s", sid, exc)
|
||||
|
||||
|
||||
async def _get_or_create_session(conversation_id: str):
|
||||
@@ -76,11 +88,9 @@ async def _get_or_create_session(conversation_id: str):
|
||||
raise RuntimeError("Copilot client is not initialized")
|
||||
|
||||
session = await client.create_session(
|
||||
{
|
||||
"model": model,
|
||||
"streaming": True,
|
||||
"on_permission_request": PermissionHandler.approve_all,
|
||||
}
|
||||
model=model,
|
||||
streaming=True,
|
||||
on_permission_request=PermissionHandler.approve_all,
|
||||
)
|
||||
sessions[conversation_id] = SessionHolder(session=session, last_used_at=time.time())
|
||||
return session
|
||||
@@ -92,42 +102,40 @@ async def health() -> dict[str, Any]:
|
||||
|
||||
|
||||
@app.post("/chat/stream")
|
||||
async def chat_stream(request: Request):
|
||||
payload = await request.json()
|
||||
message = payload.get("message")
|
||||
conversation_id = payload.get("conversationId")
|
||||
if not isinstance(message, str) or not message.strip():
|
||||
return JSONResponse(status_code=400, content={"message": "message is required"})
|
||||
|
||||
async def chat_stream(payload: ChatStreamRequest, request: Request):
|
||||
conv_id = (
|
||||
conversation_id.strip()
|
||||
if isinstance(conversation_id, str) and conversation_id.strip()
|
||||
payload.conversation_id.strip()
|
||||
if isinstance(payload.conversation_id, str) and payload.conversation_id.strip()
|
||||
else f"conv-{uuid.uuid4().hex[:10]}"
|
||||
)
|
||||
message = payload.message.strip()
|
||||
|
||||
async def event_generator():
|
||||
session = None
|
||||
queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue()
|
||||
done = asyncio.Event()
|
||||
error_emitted = False
|
||||
saw_message_delta = False
|
||||
|
||||
def on_event(event):
|
||||
nonlocal error_emitted
|
||||
nonlocal saw_message_delta
|
||||
event_type = getattr(event.type, "value", str(event.type))
|
||||
data = getattr(event, "data", None)
|
||||
if event_type == "assistant.message_delta":
|
||||
content = getattr(data, "delta_content", "") or ""
|
||||
if content:
|
||||
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
|
||||
elif event_type == "assistant.message":
|
||||
saw_message_delta = True
|
||||
queue.put_nowait(
|
||||
("token", {"conversationId": conv_id, "content": content})
|
||||
)
|
||||
elif event_type == "assistant.message" and not saw_message_delta:
|
||||
content = getattr(data, "content", "") or ""
|
||||
if content:
|
||||
queue.put_nowait(("token", {"conversationId": conv_id, "content": content}))
|
||||
queue.put_nowait(
|
||||
("token", {"conversationId": conv_id, "content": content})
|
||||
)
|
||||
elif event_type == "session.idle":
|
||||
queue.put_nowait(("done", {"conversationId": conv_id}))
|
||||
done.set()
|
||||
elif event_type == "error":
|
||||
error_emitted = True
|
||||
queue.put_nowait(
|
||||
(
|
||||
"error",
|
||||
@@ -144,22 +152,27 @@ async def chat_stream(request: Request):
|
||||
session = await _get_or_create_session(conv_id)
|
||||
unsubscribe = session.on(on_event)
|
||||
try:
|
||||
await session.send({"prompt": message})
|
||||
await session.send(message)
|
||||
while not done.is_set() or not queue.empty():
|
||||
if await request.is_disconnected():
|
||||
logger.info(
|
||||
"Client disconnected during stream: conversation=%s",
|
||||
conv_id,
|
||||
)
|
||||
return
|
||||
try:
|
||||
event_name, event_data = await asyncio.wait_for(queue.get(), timeout=0.2)
|
||||
event_name, event_data = await asyncio.wait_for(
|
||||
queue.get(), timeout=0.2
|
||||
)
|
||||
yield _sse(event_name, event_data)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if not error_emitted:
|
||||
yield _sse("done", {"conversationId": conv_id})
|
||||
finally:
|
||||
unsubscribe()
|
||||
if conv_id in sessions:
|
||||
sessions[conv_id].last_used_at = time.time()
|
||||
except Exception as exc:
|
||||
logger.exception("Copilot generation failed for %s: %s", conv_id, exc)
|
||||
yield _sse(
|
||||
"error",
|
||||
{
|
||||
|
||||
+71
-9
@@ -1,21 +1,83 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import atexit
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from urllib.parse import urlparse
|
||||
import uvicorn
|
||||
|
||||
# 将项目根目录添加到 python 路径
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
|
||||
_SIDECAR_PROCESS: subprocess.Popen | None = None
|
||||
|
||||
|
||||
def _parse_sidecar_target() -> tuple[str, int]:
    """Work out the host and port the sidecar should listen on.

    Reads ``COPILOT_SIDECAR_URL`` (default ``http://127.0.0.1:8787``) and
    extracts its hostname and port. A value without a scheme, such as
    ``localhost:9000``, is treated as an HTTP URL.

    Returns:
        A ``(host, port)`` tuple. A missing host falls back to
        ``127.0.0.1`` and a missing port to ``8787``.

    Raises:
        ValueError: If the URL carries a port that is not a valid
            integer in the range 0-65535 (raised by ``urlparse``).
    """
    sidecar_url = os.getenv("COPILOT_SIDECAR_URL", "http://127.0.0.1:8787").strip()
    # urlparse only recognises a network location after "//". Without a
    # scheme, "localhost:9000" parses as scheme="localhost", path="9000"
    # and the configured target would be silently ignored.
    if "://" not in sidecar_url:
        sidecar_url = f"http://{sidecar_url.lstrip('/')}"
    parsed = urlparse(sidecar_url)
    host = parsed.hostname or "127.0.0.1"
    port = parsed.port or 8787
    return host, port
|
||||
|
||||
|
||||
def _stop_sidecar() -> None:
    """Shut down the tracked sidecar child process, if there is one.

    Sends a terminate request and waits up to 5 seconds; if the process
    has not exited by then it is killed and waited on for up to 3 more
    seconds. Calling this when no process is tracked, or when the
    process has already exited, does nothing beyond clearing the handle.

    Raises:
        subprocess.TimeoutExpired: If the process is still alive 3
            seconds after being killed.
    """
    global _SIDECAR_PROCESS
    proc = _SIDECAR_PROCESS
    # Detach the handle before doing any work. This function is installed
    # as a signal handler and an atexit hook and is also called from a
    # `finally` clause, so it can be re-entered while a wait is still in
    # progress; clearing first makes repeat calls no-ops and guarantees
    # the global is reset even if a wait below raises.
    _SIDECAR_PROCESS = None
    if proc is None or proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # Graceful shutdown did not finish in time; force it.
        proc.kill()
        proc.wait(timeout=3)
|
||||
|
||||
|
||||
def _start_sidecar_if_needed() -> None:
    """Launch the Copilot sidecar as a child process unless one is running.

    The sidecar is started with ``python -m uvicorn server:app`` from the
    ``copilot-sidecar`` directory next to this file's parent directory.
    It binds to the host and port returned by ``_parse_sidecar_target``.
    The log level comes from ``COPILOT_SIDECAR_LOG_LEVEL`` (default
    ``warning``). The process handle is stored in ``_SIDECAR_PROCESS``.
    """
    global _SIDECAR_PROCESS
    # Do not spawn a second child while the tracked one is still alive;
    # overwriting the handle would orphan the first process.
    if _SIDECAR_PROCESS is not None and _SIDECAR_PROCESS.poll() is None:
        return

    sidecar_dir = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "copilot-sidecar")
    )

    host, port = _parse_sidecar_target()
    cmd = [
        sys.executable,
        "-m",
        "uvicorn",
        "server:app",
        "--host",
        host,
        "--port",
        str(port),
        "--log-level",
        os.getenv("COPILOT_SIDECAR_LOG_LEVEL", "warning"),
    ]
    _SIDECAR_PROCESS = subprocess.Popen(cmd, cwd=sidecar_dir)
    print(f"[run_server] sidecar started at {host}:{port}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Windows 设置事件循环策略
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
# 用 uvicorn.run 支持 workers 参数
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
# workers=2, # 这里可以设置多进程
|
||||
loop="asyncio",
|
||||
)
|
||||
atexit.register(_stop_sidecar)
|
||||
signal.signal(signal.SIGTERM, lambda *_: _stop_sidecar())
|
||||
signal.signal(signal.SIGINT, lambda *_: _stop_sidecar())
|
||||
|
||||
_start_sidecar_if_needed()
|
||||
try:
|
||||
# 用 uvicorn.run 支持 workers 参数
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
# workers=2, # 这里可以设置多进程
|
||||
loop="asyncio",
|
||||
)
|
||||
finally:
|
||||
_stop_sidecar()
|
||||
|
||||
Reference in New Issue
Block a user