添加流式 Copilot 请求处理及审计中间件优化

This commit is contained in:
2026-03-24 16:01:22 +08:00
parent c184610035
commit 600ddd329c
3 changed files with 147 additions and 44 deletions
+31 -3
View File
@@ -61,12 +61,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
"/api/v1/openproject/",
"/openproject/",
"/api/v1/copilot/chat/",
"/api/v1/copilot/chat/stream",
}
EXCLUDED_PATH_PREFIXES = (
"/api/v1/copilot/chat/",
"/copilot/chat/",
)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 提取开始时间
start_time = time.time()
# 流式 Copilot 请求前置排除,避免读取/改写 body 影响 SSE 生命周期
if self._is_excluded_path(request.url.path):
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# 1. 预判是否需要读取Body (针对写操作)
# 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag
should_capture_body = request.method in ["POST", "PUT", "PATCH"]
@@ -75,13 +87,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
if should_capture_body:
try:
# 注意:读取 body 后需要重新设置,避免影响后续处理
original_receive = request._receive
body = await request.body()
if body:
request_data = json.loads(body.decode())
# 重新构造请求以供后续使用
# 重新构造请求以供后续使用:仅回放一次,后续回落原始 receive
body_sent = False
async def receive():
return {"type": "http.request", "body": body}
nonlocal body_sent
if not body_sent:
body_sent = True
return {
"type": "http.request",
"body": body,
"more_body": False,
}
return await original_receive()
request._receive = receive
except Exception as e:
@@ -91,7 +114,7 @@ class AuditMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
# 3. 决定是否审计
if request.url.path in self.EXCLUDED_PATHS:
if self._is_excluded_path(request.url.path):
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
@@ -151,6 +174,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
def _is_excluded_path(self, path: str) -> bool:
if path in self.EXCLUDED_PATHS:
return True
return any(path.startswith(prefix) for prefix in self.EXCLUDED_PATH_PREFIXES)
def _resolve_project_id(self, request: Request) -> UUID | None:
project_header = request.headers.get("X-Project-Id")
if not project_header: