From 600ddd329c7ab865bbe3c50e4c52b054de14d216 Mon Sep 17 00:00:00 2001 From: Jiang Date: Tue, 24 Mar 2026 16:01:22 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B5=81=E5=BC=8F=20Copilot?= =?UTF-8?q?=20=E8=AF=B7=E6=B1=82=E5=A4=84=E7=90=86=E5=8F=8A=E5=AE=A1?= =?UTF-8?q?=E8=AE=A1=E4=B8=AD=E9=97=B4=E4=BB=B6=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/infra/audit/middleware.py | 34 +++++++++++++-- copilot-sidecar/server.py | 77 +++++++++++++++++++-------------- scripts/run_server.py | 80 +++++++++++++++++++++++++++++++---- 3 files changed, 147 insertions(+), 44 deletions(-) diff --git a/app/infra/audit/middleware.py b/app/infra/audit/middleware.py index 657002f..544b29b 100644 --- a/app/infra/audit/middleware.py +++ b/app/infra/audit/middleware.py @@ -61,12 +61,24 @@ class AuditMiddleware(BaseHTTPMiddleware): "/api/v1/openproject/", "/openproject/", "/api/v1/copilot/chat/", + "/api/v1/copilot/chat/stream", } + EXCLUDED_PATH_PREFIXES = ( + "/api/v1/copilot/chat/", + "/copilot/chat/", + ) async def dispatch(self, request: Request, call_next: Callable) -> Response: # 提取开始时间 start_time = time.time() + # 流式 Copilot 请求前置排除,避免读取/改写 body 影响 SSE 生命周期 + if self._is_excluded_path(request.url.path): + response = await call_next(request) + process_time = time.time() - start_time + response.headers["X-Process-Time"] = str(process_time) + return response + # 1. 预判是否需要读取Body (针对写操作) # 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag should_capture_body = request.method in ["POST", "PUT", "PATCH"] @@ -75,13 +87,24 @@ class AuditMiddleware(BaseHTTPMiddleware): if should_capture_body: try: # 注意:读取 body 后需要重新设置,避免影响后续处理 + original_receive = request._receive body = await request.body() if body: request_data = json.loads(body.decode()) - # 重新构造请求以供后续使用 + # 重新构造请求以供后续使用:仅回放一次,后续回落原始 receive + body_sent = False + async def receive(): - return {"type": "http.request", "body": body} + nonlocal body_sent + if not body_sent: + body_sent = True + return { + "type": "http.request", + "body": body, + "more_body": False, + } + return await original_receive() request._receive = receive except Exception as e: @@ -91,7 +114,7 @@ class AuditMiddleware(BaseHTTPMiddleware): response = await call_next(request) # 3. 决定是否审计 - if request.url.path in self.EXCLUDED_PATHS: + if self._is_excluded_path(request.url.path): process_time = time.time() - start_time response.headers["X-Process-Time"] = str(process_time) return response @@ -151,6 +174,11 @@ class AuditMiddleware(BaseHTTPMiddleware): return response + def _is_excluded_path(self, path: str) -> bool: + if path in self.EXCLUDED_PATHS: + return True + return any(path.startswith(prefix) for prefix in self.EXCLUDED_PATH_PREFIXES) + def _resolve_project_id(self, request: Request) -> UUID | None: project_header = request.headers.get("X-Project-Id") if not project_header: diff --git a/copilot-sidecar/server.py b/copilot-sidecar/server.py index ed44fe6..1930c69 100644 --- a/copilot-sidecar/server.py +++ b/copilot-sidecar/server.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import json +import logging import os import time import uuid @@ -9,7 +10,8 @@ from dataclasses import dataclass from typing import Any, Optional from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field, ConfigDict from copilot import CopilotClient, PermissionHandler @@ -23,11 +25,21 @@ class SessionHolder: last_used_at: float -app = FastAPI(title="TJWater Copilot Python Sidecar") +app = FastAPI(title="TJWater Copilot Sidecar") client: Optional[CopilotClient] = None sessions: dict[str, SessionHolder] = {} session_ttl_seconds = int(os.getenv("COPILOT_SESSION_TTL_SECONDS", "1800")) -model = os.getenv("COPILOT_MODEL", "gpt-5.3-codex") +model = os.getenv("COPILOT_MODEL", "gpt-4.1") +logger = logging.getLogger("copilot_sidecar") + + +class ChatStreamRequest(BaseModel): + message: str = Field(..., min_length=1, max_length=10000) + conversation_id: Optional[str] = Field( + default=None, alias="conversationId", max_length=128 + ) + user_id: Optional[str] = Field(default=None, alias="userId", max_length=128) + model_config = ConfigDict(populate_by_name=True) @app.on_event("startup") @@ -43,8 +55,8 @@ async def shutdown_event() -> None: for holder in sessions.values(): try: await holder.session.disconnect() - except Exception: - pass + except Exception as exc: + logger.warning("Failed to disconnect session during shutdown: %s", exc) sessions.clear() await client.stop() @@ -62,8 +74,8 @@ async def _cleanup_sessions() -> None: continue try: await holder.session.disconnect() - except Exception: - pass + except Exception as exc: + logger.warning("Failed to disconnect expired session %s: %s", sid, exc) async def _get_or_create_session(conversation_id: str): @@ -76,11 +88,9 @@ async def _get_or_create_session(conversation_id: str): raise RuntimeError("Copilot client is not initialized") session = await client.create_session( - { - "model": model, - "streaming": True, - "on_permission_request": PermissionHandler.approve_all, - } + model=model, + streaming=True, + on_permission_request=PermissionHandler.approve_all, ) sessions[conversation_id] = SessionHolder(session=session, last_used_at=time.time()) return session @@ -92,42 +102,40 @@ async def health() -> dict[str, Any]: @app.post("/chat/stream") -async def chat_stream(request: Request): - payload = await request.json() - message = payload.get("message") - conversation_id = payload.get("conversationId") - if not isinstance(message, str) or not message.strip(): - return JSONResponse(status_code=400, content={"message": "message is required"}) - +async def chat_stream(payload: ChatStreamRequest, request: Request): conv_id = ( - conversation_id.strip() - if isinstance(conversation_id, str) and conversation_id.strip() + payload.conversation_id.strip() + if isinstance(payload.conversation_id, str) and payload.conversation_id.strip() else f"conv-{uuid.uuid4().hex[:10]}" ) + message = payload.message.strip() async def event_generator(): - session = None queue: asyncio.Queue[tuple[str, dict[str, Any]]] = asyncio.Queue() done = asyncio.Event() - error_emitted = False + saw_message_delta = False def on_event(event): - nonlocal error_emitted + nonlocal saw_message_delta event_type = getattr(event.type, "value", str(event.type)) data = getattr(event, "data", None) if event_type == "assistant.message_delta": content = getattr(data, "delta_content", "") or "" if content: - queue.put_nowait(("token", {"conversationId": conv_id, "content": content})) - elif event_type == "assistant.message": + saw_message_delta = True + queue.put_nowait( + ("token", {"conversationId": conv_id, "content": content}) + ) + elif event_type == "assistant.message" and not saw_message_delta: content = getattr(data, "content", "") or "" if content: - queue.put_nowait(("token", {"conversationId": conv_id, "content": content})) + queue.put_nowait( + ("token", {"conversationId": conv_id, "content": content}) + ) elif event_type == "session.idle": queue.put_nowait(("done", {"conversationId": conv_id})) done.set() elif event_type == "error": - error_emitted = True queue.put_nowait( ( "error", @@ -144,22 +152,27 @@ async def chat_stream(request: Request): session = await _get_or_create_session(conv_id) unsubscribe = session.on(on_event) try: - await session.send({"prompt": message}) + await session.send(message) while not done.is_set() or not queue.empty(): if await request.is_disconnected(): + logger.info( + "Client disconnected during stream: conversation=%s", + conv_id, + ) return try: - event_name, event_data = await asyncio.wait_for(queue.get(), timeout=0.2) + event_name, event_data = await asyncio.wait_for( + queue.get(), timeout=0.2 + ) yield _sse(event_name, event_data) except asyncio.TimeoutError: continue - if not error_emitted: - yield _sse("done", {"conversationId": conv_id}) finally: unsubscribe() if conv_id in sessions: sessions[conv_id].last_used_at = time.time() except Exception as exc: + logger.exception("Copilot generation failed for %s: %s", conv_id, exc) yield _sse( "error", { diff --git a/scripts/run_server.py b/scripts/run_server.py index 57f2c67..5b50447 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -1,21 +1,83 @@ import asyncio -import sys +import atexit import os +import signal +import subprocess +import sys +from urllib.parse import urlparse import uvicorn # 将项目根目录添加到 python 路径 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +_SIDECAR_PROCESS: subprocess.Popen | None = None + + +def _parse_sidecar_target() -> tuple[str, int]: + sidecar_url = os.getenv("COPILOT_SIDECAR_URL", "http://127.0.0.1:8787").strip() + parsed = urlparse(sidecar_url) + host = parsed.hostname or "127.0.0.1" + port = parsed.port or 8787 + return host, port + + +def _stop_sidecar() -> None: + global _SIDECAR_PROCESS + proc = _SIDECAR_PROCESS + if proc is None: + return + if proc.poll() is None: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait(timeout=3) + _SIDECAR_PROCESS = None + + +def _start_sidecar_if_needed() -> None: + global _SIDECAR_PROCESS + sidecar_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "copilot-sidecar") + ) + + host, port = _parse_sidecar_target() + cmd = [ + sys.executable, + "-m", + "uvicorn", + "server:app", + "--host", + host, + "--port", + str(port), + "--log-level", + os.getenv("COPILOT_SIDECAR_LOG_LEVEL", "warning"), + ] + _SIDECAR_PROCESS = subprocess.Popen(cmd, cwd=sidecar_dir) + print(f"[run_server] sidecar started at {host}:{port}.") + + if __name__ == "__main__": # Windows 设置事件循环策略 if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - # 用 uvicorn.run 支持 workers 参数 - uvicorn.run( - "app.main:app", - host="0.0.0.0", - port=8000, - # workers=2, # 这里可以设置多进程 - loop="asyncio", - ) + atexit.register(_stop_sidecar) + signal.signal(signal.SIGTERM, lambda *_: _stop_sidecar()) + signal.signal(signal.SIGINT, lambda *_: _stop_sidecar()) + + _start_sidecar_if_needed() + try: + # 用 uvicorn.run 支持 workers 参数 + uvicorn.run( + "app.main:app", + host="0.0.0.0", + port=8000, + # workers=2, # 这里可以设置多进程 + loop="asyncio", + ) + finally: + _stop_sidecar()