diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 74b0b6c..1fda378 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -8,7 +8,7 @@ This is a FastAPI-based water network management system (供水管网智能管 # activate the server environment conda activate server # Install dependencies -pip install -r requirements.txt + # Start the server (default: http://0.0.0.0:8000 with 2 workers) python scripts/run_server.py diff --git a/app/api/v1/endpoints/leakage.py b/app/api/v1/endpoints/leakage.py index e8ae06c..5c6b62e 100644 --- a/app/api/v1/endpoints/leakage.py +++ b/app/api/v1/endpoints/leakage.py @@ -1,9 +1,11 @@ from typing import Any from datetime import datetime -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel +from app.auth.dependencies import get_current_user +from app.domain.schemas.user import UserInDB from app.services.leakage_identifier import ( get_leakage_identify_scheme_detail, list_leakage_identify_schemes, @@ -15,13 +17,15 @@ router = APIRouter() class LeakageIdentifyRequest(BaseModel): network: str - observed_pressure_data: str | dict[str, list[Any]] | list[dict[str, Any]] | None = None + observed_pressure_data: str | dict[str, list[Any]] | list[dict[str, Any]] | None = ( + None + ) start_time: float = 0 duration: float = 24 timestep: float = 5 q_sum: float = 0.2 q_sum_unit: str = "m3/s" - output_dir: str = "Results" + output_dir: str = "db_inp" pop_size: int = 50 max_gen: int = 100 output_flow_unit: str = "m3/s" @@ -30,13 +34,16 @@ class LeakageIdentifyRequest(BaseModel): scada_end: datetime | None = None sensor_nodes: list[str] | None = None scheme_name: str | None = None - username: str = "admin" @router.post("/identify/") -async def identify_leakage(data: LeakageIdentifyRequest) -> dict[str, Any]: +async def identify_leakage( + data: LeakageIdentifyRequest, current_user: UserInDB = Depends(get_current_user) +) -> dict[str, Any]: try: - return run_leakage_identification(**data.dict()) + return run_leakage_identification( + **data.model_dump(), username=current_user.username + ) except Exception as exc: raise HTTPException(status_code=400, detail=str(exc)) @@ -52,10 +59,10 @@ async def query_leakage_schemes( @router.get("/schemes/{scheme_name}") -async def query_leakage_scheme_detail( - network: str, scheme_name: str -) -> dict[str, Any]: +async def query_leakage_scheme_detail(network: str, scheme_name: str) -> dict[str, Any]: try: - return get_leakage_identify_scheme_detail(network=network, scheme_name=scheme_name) + return get_leakage_identify_scheme_detail( + network=network, scheme_name=scheme_name + ) except Exception as exc: raise HTTPException(status_code=400, detail=str(exc)) diff --git a/app/services/leakage_identifier.py b/app/services/leakage_identifier.py index ca19988..9dfa6da 100644 --- a/app/services/leakage_identifier.py +++ b/app/services/leakage_identifier.py @@ -6,6 +6,7 @@ from typing import Any import numpy as np import pandas as pd +import wntr from app.algorithms.leakage_identifier import LeakageIdentifier from app.infra.db.influxdb import api as influxdb_api @@ -34,7 +35,7 @@ def run_leakage_identification( timestep: float = 5, q_sum: float = 0.2, q_sum_unit: str = "m3/s", - output_dir: str = "Results", + output_dir: str = "db_inp", pop_size: int = 50, max_gen: int = 100, output_flow_unit: str = "m3/s", @@ -46,8 +47,7 @@ def run_leakage_identification( username: str = "admin", ) -> dict[str, Any]: os.makedirs(output_dir, exist_ok=True) - inp_path = os.path.join(output_dir, f"{network}.leakage.inp") - dump_inp(network, inp_path, "2") + inp_path = _prepare_leakage_inp(network) selected_sensor_nodes = ( list(dict.fromkeys([node for node in (sensor_nodes or []) if node])) @@ -500,3 +500,26 @@ def _to_datetime(value: datetime | str) -> datetime: if isinstance(value, datetime): return value return datetime.fromisoformat(value) + + +def _prepare_leakage_inp(network: str) -> str: + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + db_inp_dir = os.path.join(project_root, "db_inp") + os.makedirs(db_inp_dir, exist_ok=True) + inp_path = os.path.join(db_inp_dir, f"{network}.leakage.inp") + if _is_valid_inp_file(inp_path): + return inp_path + dump_inp(network, inp_path, "2") + if not _is_valid_inp_file(inp_path): + raise ValueError(f"漏损识别 INP 文件无效: {inp_path}") + return inp_path + + +def _is_valid_inp_file(inp_path: str) -> bool: + if not os.path.isfile(inp_path) or os.path.getsize(inp_path) <= 0: + return False + try: + wntr.network.WaterNetworkModel(inp_path) + return True + except Exception: + return False diff --git a/tests/api/test_leakage_endpoints.py b/tests/api/test_leakage_endpoints.py index 7a0cc10..bab360d 100644 --- a/tests/api/test_leakage_endpoints.py +++ b/tests/api/test_leakage_endpoints.py @@ -1,5 +1,6 @@ from fastapi import FastAPI from fastapi.testclient import TestClient +from types import SimpleNamespace from app.api.v1.endpoints import leakage as leakage_endpoint @@ -7,12 +8,16 @@ from app.api.v1.endpoints import leakage as leakage_endpoint def _build_client() -> TestClient: app = FastAPI() app.include_router(leakage_endpoint.router, prefix="/api/v1/leakage") + app.dependency_overrides[leakage_endpoint.get_current_user] = lambda: SimpleNamespace( + username="tester" + ) return TestClient(app) def test_identify_leakage_success(monkeypatch): def fake_run_leakage_identification(**kwargs): assert kwargs["network"] == "demo" + assert kwargs["username"] == "tester" return {"rows": [], "area_count": 0} monkeypatch.setattr(