新增模拟数据源支持,重构爆管定位逻辑
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -16,6 +18,7 @@ router = APIRouter()
|
|||||||
|
|
||||||
class BurstLocationRequest(BaseModel):
|
class BurstLocationRequest(BaseModel):
|
||||||
network: str
|
network: str
|
||||||
|
data_source: Literal["monitoring", "simulation"] = "monitoring"
|
||||||
pressure_scada_ids: list[str] | None = None
|
pressure_scada_ids: list[str] | None = None
|
||||||
burst_pressure: dict[str, float] | list[dict[str, Any]] | None = None
|
burst_pressure: dict[str, float] | list[dict[str, Any]] | None = None
|
||||||
normal_pressure: dict[str, float] | list[dict[str, Any]] | None = None
|
normal_pressure: dict[str, float] | list[dict[str, Any]] | None = None
|
||||||
@@ -31,6 +34,8 @@ class BurstLocationRequest(BaseModel):
|
|||||||
scada_normal_end: datetime | None = None
|
scada_normal_end: datetime | None = None
|
||||||
use_scada_flow: bool = False
|
use_scada_flow: bool = False
|
||||||
scheme_name: str | None = None
|
scheme_name: str | None = None
|
||||||
|
simulation_scheme_name: str | None = None
|
||||||
|
simulation_scheme_type: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/locate/")
|
@router.post("/locate/")
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ from typing import List
|
|||||||
from fastapi.logger import logger
|
from fastapi.logger import logger
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import psycopg
|
import psycopg
|
||||||
|
from psycopg import sql
|
||||||
|
from psycopg.rows import dict_row
|
||||||
import time
|
import time
|
||||||
from app.infra.db.timescaledb.schemas.scheme import SchemeRepository
|
from app.infra.db.timescaledb.schemas.scheme import SchemeRepository
|
||||||
from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository
|
from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository
|
||||||
@@ -170,3 +172,159 @@ class InternalQueries:
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def query_realtime_simulation_by_ids_timerange(
|
||||||
|
element_ids: List[str],
|
||||||
|
start_time: str | datetime,
|
||||||
|
end_time: str | datetime,
|
||||||
|
element_type: str,
|
||||||
|
field: str,
|
||||||
|
db_name: str = None,
|
||||||
|
max_retries: int = 3,
|
||||||
|
) -> dict[str, list[dict]]:
|
||||||
|
"""查询实时模拟结果,返回 {id: [{time, value}, ...]}。"""
|
||||||
|
return InternalQueries._query_simulation_by_ids_timerange(
|
||||||
|
schema_name="realtime",
|
||||||
|
element_ids=element_ids,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
element_type=element_type,
|
||||||
|
field=field,
|
||||||
|
db_name=db_name,
|
||||||
|
max_retries=max_retries,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def query_scheme_simulation_by_ids_timerange(
|
||||||
|
element_ids: List[str],
|
||||||
|
start_time: str | datetime,
|
||||||
|
end_time: str | datetime,
|
||||||
|
element_type: str,
|
||||||
|
field: str,
|
||||||
|
scheme_type: str,
|
||||||
|
scheme_name: str,
|
||||||
|
db_name: str = None,
|
||||||
|
max_retries: int = 3,
|
||||||
|
) -> dict[str, list[dict]]:
|
||||||
|
"""查询方案模拟结果,返回 {id: [{time, value}, ...]}。"""
|
||||||
|
return InternalQueries._query_simulation_by_ids_timerange(
|
||||||
|
schema_name="scheme",
|
||||||
|
element_ids=element_ids,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
element_type=element_type,
|
||||||
|
field=field,
|
||||||
|
db_name=db_name,
|
||||||
|
max_retries=max_retries,
|
||||||
|
scheme_type=scheme_type,
|
||||||
|
scheme_name=scheme_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _query_simulation_by_ids_timerange(
|
||||||
|
*,
|
||||||
|
schema_name: str,
|
||||||
|
element_ids: List[str],
|
||||||
|
start_time: str | datetime,
|
||||||
|
end_time: str | datetime,
|
||||||
|
element_type: str,
|
||||||
|
field: str,
|
||||||
|
db_name: str = None,
|
||||||
|
max_retries: int = 3,
|
||||||
|
scheme_type: str | None = None,
|
||||||
|
scheme_name: str | None = None,
|
||||||
|
) -> dict[str, list[dict]]:
|
||||||
|
if not element_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
start_dt = (
|
||||||
|
datetime.fromisoformat(start_time)
|
||||||
|
if isinstance(start_time, str)
|
||||||
|
else start_time
|
||||||
|
)
|
||||||
|
end_dt = (
|
||||||
|
datetime.fromisoformat(end_time) if isinstance(end_time, str) else end_time
|
||||||
|
)
|
||||||
|
table_name, valid_fields = InternalQueries._resolve_simulation_table(element_type)
|
||||||
|
if field not in valid_fields:
|
||||||
|
raise ValueError(f"Invalid field for {element_type}: {field}")
|
||||||
|
if schema_name not in {"realtime", "scheme"}:
|
||||||
|
raise ValueError(f"Unsupported schema_name: {schema_name}")
|
||||||
|
if schema_name == "scheme" and (not scheme_type or not scheme_name):
|
||||||
|
raise ValueError("scheme 查询必须提供 scheme_type 和 scheme_name。")
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
conn_string = (
|
||||||
|
timescaledb_info.get_pgconn_string(db_name=db_name)
|
||||||
|
if db_name
|
||||||
|
else timescaledb_info.get_pgconn_string()
|
||||||
|
)
|
||||||
|
with psycopg.Connection.connect(conn_string) as conn:
|
||||||
|
with conn.cursor(row_factory=dict_row) as cur:
|
||||||
|
if schema_name == "scheme":
|
||||||
|
query = sql.SQL(
|
||||||
|
"SELECT id, time, {} FROM {}.{} "
|
||||||
|
"WHERE scheme_type = %s AND scheme_name = %s "
|
||||||
|
"AND time >= %s AND time <= %s AND id = ANY(%s)"
|
||||||
|
).format(
|
||||||
|
sql.Identifier(field),
|
||||||
|
sql.Identifier(schema_name),
|
||||||
|
sql.Identifier(table_name),
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
query,
|
||||||
|
(
|
||||||
|
scheme_type,
|
||||||
|
scheme_name,
|
||||||
|
start_dt,
|
||||||
|
end_dt,
|
||||||
|
element_ids,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = sql.SQL(
|
||||||
|
"SELECT id, time, {} FROM {}.{} "
|
||||||
|
"WHERE time >= %s AND time <= %s AND id = ANY(%s)"
|
||||||
|
).format(
|
||||||
|
sql.Identifier(field),
|
||||||
|
sql.Identifier(schema_name),
|
||||||
|
sql.Identifier(table_name),
|
||||||
|
)
|
||||||
|
cur.execute(query, (start_dt, end_dt, element_ids))
|
||||||
|
rows = cur.fetchall()
|
||||||
|
result: dict[str, list[dict]] = {
|
||||||
|
element_id: [] for element_id in element_ids
|
||||||
|
}
|
||||||
|
for row in rows:
|
||||||
|
result.setdefault(row["id"], []).append(
|
||||||
|
{"time": row["time"].isoformat(), "value": row[field]}
|
||||||
|
)
|
||||||
|
for element_id in result:
|
||||||
|
result[element_id].sort(key=lambda item: item["time"])
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_simulation_table(element_type: str) -> tuple[str, set[str]]:
|
||||||
|
normalized_type = element_type.lower()
|
||||||
|
if normalized_type == "node":
|
||||||
|
return "node_simulation", {"actual_demand", "total_head", "pressure", "quality"}
|
||||||
|
if normalized_type == "link":
|
||||||
|
return "link_simulation", {
|
||||||
|
"flow",
|
||||||
|
"friction",
|
||||||
|
"headloss",
|
||||||
|
"quality",
|
||||||
|
"reaction",
|
||||||
|
"setting",
|
||||||
|
"status",
|
||||||
|
"velocity",
|
||||||
|
}
|
||||||
|
raise ValueError(f"Unsupported element_type: {element_type}")
|
||||||
|
|||||||
+309
-43
@@ -18,6 +18,8 @@ from app.services.tjnetwork import dump_inp, get_all_scada_info
|
|||||||
|
|
||||||
SeriesInput = pd.Series | dict[str, Any] | list[dict[str, Any]]
|
SeriesInput = pd.Series | dict[str, Any] | list[dict[str, Any]]
|
||||||
FLOW_SCADA_TYPES = {"pipe_flow", "flow", "demand"}
|
FLOW_SCADA_TYPES = {"pipe_flow", "flow", "demand"}
|
||||||
|
SIMULATION_DATA_SOURCES = {"monitoring", "simulation"}
|
||||||
|
DEFAULT_SIMULATION_SCHEME_TYPE = "burst_analysis"
|
||||||
|
|
||||||
|
|
||||||
def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series:
|
def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series:
|
||||||
@@ -44,6 +46,7 @@ def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series:
|
|||||||
def run_burst_location_by_network(
|
def run_burst_location_by_network(
|
||||||
*,
|
*,
|
||||||
network: str,
|
network: str,
|
||||||
|
data_source: str = "monitoring",
|
||||||
burst_leakage: float,
|
burst_leakage: float,
|
||||||
pressure_scada_ids: list[str] | None = None,
|
pressure_scada_ids: list[str] | None = None,
|
||||||
burst_pressure: SeriesInput | None = None,
|
burst_pressure: SeriesInput | None = None,
|
||||||
@@ -59,10 +62,18 @@ def run_burst_location_by_network(
|
|||||||
scada_normal_end: datetime | str | None = None,
|
scada_normal_end: datetime | str | None = None,
|
||||||
use_scada_flow: bool = False,
|
use_scada_flow: bool = False,
|
||||||
scheme_name: str | None = None,
|
scheme_name: str | None = None,
|
||||||
|
simulation_scheme_name: str | None = None,
|
||||||
|
simulation_scheme_type: str | None = None,
|
||||||
username: str = "admin",
|
username: str = "admin",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if not network:
|
if not network:
|
||||||
raise ValueError("network is required.")
|
raise ValueError("network is required.")
|
||||||
|
normalized_data_source = _normalize_data_source(
|
||||||
|
data_source, simulation_scheme_name=simulation_scheme_name
|
||||||
|
)
|
||||||
|
resolved_simulation_scheme_type = (
|
||||||
|
simulation_scheme_type or DEFAULT_SIMULATION_SCHEME_TYPE
|
||||||
|
)
|
||||||
|
|
||||||
selected_pressure_ids = (
|
selected_pressure_ids = (
|
||||||
_dedupe_ids(pressure_scada_ids)
|
_dedupe_ids(pressure_scada_ids)
|
||||||
@@ -93,26 +104,60 @@ def run_burst_location_by_network(
|
|||||||
scada_normal_start=scada_normal_start,
|
scada_normal_start=scada_normal_start,
|
||||||
scada_normal_end=scada_normal_end,
|
scada_normal_end=scada_normal_end,
|
||||||
)
|
)
|
||||||
burst_pressure_series, burst_pressure_samples = _build_observed_series_from_scada(
|
if normalized_data_source == "simulation":
|
||||||
network=network,
|
if not simulation_scheme_name:
|
||||||
sensor_ids=selected_pressure_ids,
|
raise ValueError("模拟方案模式必须提供 simulation_scheme_name。")
|
||||||
start_dt=burst_start_dt,
|
burst_pressure_series, burst_pressure_samples = (
|
||||||
end_dt=burst_end_dt,
|
_build_observed_series_from_simulation(
|
||||||
data_type="pressure",
|
network=network,
|
||||||
series_name="burst_pressure",
|
sensor_ids=selected_pressure_ids,
|
||||||
)
|
start_dt=burst_start_dt,
|
||||||
(
|
end_dt=burst_end_dt,
|
||||||
normal_pressure_series,
|
data_type="pressure",
|
||||||
normal_pressure_samples,
|
series_name="burst_pressure",
|
||||||
) = _build_observed_series_from_scada(
|
simulation_source="scheme",
|
||||||
network=network,
|
simulation_scheme_name=simulation_scheme_name,
|
||||||
sensor_ids=selected_pressure_ids,
|
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||||
start_dt=normal_start_dt,
|
)
|
||||||
end_dt=normal_end_dt,
|
)
|
||||||
data_type="pressure",
|
(
|
||||||
series_name="normal_pressure",
|
normal_pressure_series,
|
||||||
)
|
normal_pressure_samples,
|
||||||
observed_source = "backend_timerange"
|
) = _build_observed_series_from_simulation(
|
||||||
|
network=network,
|
||||||
|
sensor_ids=selected_pressure_ids,
|
||||||
|
start_dt=normal_start_dt,
|
||||||
|
end_dt=normal_end_dt,
|
||||||
|
data_type="pressure",
|
||||||
|
series_name="normal_pressure",
|
||||||
|
simulation_source="scheme",
|
||||||
|
simulation_scheme_name=simulation_scheme_name,
|
||||||
|
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||||
|
)
|
||||||
|
observed_source = "simulation_scheme_timerange"
|
||||||
|
else:
|
||||||
|
burst_pressure_series, burst_pressure_samples = (
|
||||||
|
_build_observed_series_from_scada(
|
||||||
|
network=network,
|
||||||
|
sensor_ids=selected_pressure_ids,
|
||||||
|
start_dt=burst_start_dt,
|
||||||
|
end_dt=burst_end_dt,
|
||||||
|
data_type="pressure",
|
||||||
|
series_name="burst_pressure",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
(
|
||||||
|
normal_pressure_series,
|
||||||
|
normal_pressure_samples,
|
||||||
|
) = _build_observed_series_from_scada(
|
||||||
|
network=network,
|
||||||
|
sensor_ids=selected_pressure_ids,
|
||||||
|
start_dt=normal_start_dt,
|
||||||
|
end_dt=normal_end_dt,
|
||||||
|
data_type="pressure",
|
||||||
|
series_name="normal_pressure",
|
||||||
|
)
|
||||||
|
observed_source = "backend_timerange"
|
||||||
else:
|
else:
|
||||||
if burst_pressure is None or normal_pressure is None:
|
if burst_pressure is None or normal_pressure is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -139,22 +184,52 @@ def run_burst_location_by_network(
|
|||||||
)
|
)
|
||||||
if not selected_flow_ids:
|
if not selected_flow_ids:
|
||||||
raise ValueError("未找到可用流量传感器,无法从 SCADA 查询流量数据。")
|
raise ValueError("未找到可用流量传感器,无法从 SCADA 查询流量数据。")
|
||||||
burst_flow_series, burst_flow_samples = _build_observed_series_from_scada(
|
if normalized_data_source == "simulation":
|
||||||
network=network,
|
if not simulation_scheme_name:
|
||||||
sensor_ids=selected_flow_ids,
|
raise ValueError("模拟方案模式必须提供 simulation_scheme_name。")
|
||||||
start_dt=burst_start_dt,
|
burst_flow_series, burst_flow_samples = (
|
||||||
end_dt=burst_end_dt,
|
_build_observed_series_from_simulation(
|
||||||
data_type="flow",
|
network=network,
|
||||||
series_name="burst_flow",
|
sensor_ids=selected_flow_ids,
|
||||||
)
|
start_dt=burst_start_dt,
|
||||||
normal_flow_series, normal_flow_samples = _build_observed_series_from_scada(
|
end_dt=burst_end_dt,
|
||||||
network=network,
|
data_type="flow",
|
||||||
sensor_ids=selected_flow_ids,
|
series_name="burst_flow",
|
||||||
start_dt=normal_start_dt,
|
simulation_source="scheme",
|
||||||
end_dt=normal_end_dt,
|
simulation_scheme_name=simulation_scheme_name,
|
||||||
data_type="flow",
|
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||||
series_name="normal_flow",
|
)
|
||||||
)
|
)
|
||||||
|
normal_flow_series, normal_flow_samples = (
|
||||||
|
_build_observed_series_from_simulation(
|
||||||
|
network=network,
|
||||||
|
sensor_ids=selected_flow_ids,
|
||||||
|
start_dt=normal_start_dt,
|
||||||
|
end_dt=normal_end_dt,
|
||||||
|
data_type="flow",
|
||||||
|
series_name="normal_flow",
|
||||||
|
simulation_source="scheme",
|
||||||
|
simulation_scheme_name=simulation_scheme_name,
|
||||||
|
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
burst_flow_series, burst_flow_samples = _build_observed_series_from_scada(
|
||||||
|
network=network,
|
||||||
|
sensor_ids=selected_flow_ids,
|
||||||
|
start_dt=burst_start_dt,
|
||||||
|
end_dt=burst_end_dt,
|
||||||
|
data_type="flow",
|
||||||
|
series_name="burst_flow",
|
||||||
|
)
|
||||||
|
normal_flow_series, normal_flow_samples = _build_observed_series_from_scada(
|
||||||
|
network=network,
|
||||||
|
sensor_ids=selected_flow_ids,
|
||||||
|
start_dt=normal_start_dt,
|
||||||
|
end_dt=normal_end_dt,
|
||||||
|
data_type="flow",
|
||||||
|
series_name="normal_flow",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if flow_scada_ids is not None:
|
if flow_scada_ids is not None:
|
||||||
selected_flow_ids = _dedupe_ids(flow_scada_ids)
|
selected_flow_ids = _dedupe_ids(flow_scada_ids)
|
||||||
@@ -186,6 +261,7 @@ def run_burst_location_by_network(
|
|||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
**result,
|
**result,
|
||||||
"network": network,
|
"network": network,
|
||||||
|
"data_source": normalized_data_source,
|
||||||
"pressure_scada_ids": selected_pressure_ids,
|
"pressure_scada_ids": selected_pressure_ids,
|
||||||
"flow_scada_ids": selected_flow_ids or [],
|
"flow_scada_ids": selected_flow_ids or [],
|
||||||
"observed_source": observed_source,
|
"observed_source": observed_source,
|
||||||
@@ -202,6 +278,11 @@ def run_burst_location_by_network(
|
|||||||
"normal_start": normal_start_dt.isoformat(),
|
"normal_start": normal_start_dt.isoformat(),
|
||||||
"normal_end": normal_end_dt.isoformat(),
|
"normal_end": normal_end_dt.isoformat(),
|
||||||
}
|
}
|
||||||
|
if normalized_data_source == "simulation":
|
||||||
|
payload["simulation_scheme"] = {
|
||||||
|
"name": simulation_scheme_name,
|
||||||
|
"type": resolved_simulation_scheme_type,
|
||||||
|
}
|
||||||
if scheme_name:
|
if scheme_name:
|
||||||
_store_burst_scheme(
|
_store_burst_scheme(
|
||||||
network=network,
|
network=network,
|
||||||
@@ -335,8 +416,174 @@ def _build_observed_series_from_scada(
|
|||||||
return pd.Series(values, dtype=float), min(sample_counts)
|
return pd.Series(values, dtype=float), min(sample_counts)
|
||||||
|
|
||||||
|
|
||||||
def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
|
def _build_observed_series_from_simulation(
|
||||||
mapping: dict[str, str] = {}
|
*,
|
||||||
|
network: str,
|
||||||
|
sensor_ids: list[str],
|
||||||
|
start_dt: datetime,
|
||||||
|
end_dt: datetime,
|
||||||
|
data_type: str,
|
||||||
|
series_name: str,
|
||||||
|
simulation_source: str,
|
||||||
|
simulation_scheme_name: str | None,
|
||||||
|
simulation_scheme_type: str,
|
||||||
|
) -> tuple[pd.Series, int]:
|
||||||
|
sensor_metadata = _build_sensor_metadata(network=network, data_type=data_type)
|
||||||
|
missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in sensor_metadata]
|
||||||
|
if missing_ids:
|
||||||
|
preview = ", ".join(missing_ids[:10])
|
||||||
|
raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}")
|
||||||
|
|
||||||
|
simulation_data = _query_simulation_data_by_sensor_ids(
|
||||||
|
network=network,
|
||||||
|
sensor_ids=sensor_ids,
|
||||||
|
sensor_metadata=sensor_metadata,
|
||||||
|
start_dt=start_dt,
|
||||||
|
end_dt=end_dt,
|
||||||
|
data_type=data_type,
|
||||||
|
simulation_source=simulation_source,
|
||||||
|
simulation_scheme_name=simulation_scheme_name,
|
||||||
|
simulation_scheme_type=simulation_scheme_type,
|
||||||
|
)
|
||||||
|
values: dict[str, float] = {}
|
||||||
|
sample_counts: list[int] = []
|
||||||
|
for sensor_id in sensor_ids:
|
||||||
|
records = simulation_data.get(sensor_id, [])
|
||||||
|
numeric_values = [
|
||||||
|
float(item["value"])
|
||||||
|
for item in records
|
||||||
|
if item.get("value") is not None
|
||||||
|
]
|
||||||
|
if not numeric_values:
|
||||||
|
raise ValueError(f"{series_name} 在时间窗内无有效模拟数据: {sensor_id}")
|
||||||
|
values[sensor_id] = float(sum(numeric_values) / len(numeric_values))
|
||||||
|
sample_counts.append(len(numeric_values))
|
||||||
|
|
||||||
|
return pd.Series(values, dtype=float), min(sample_counts)
|
||||||
|
|
||||||
|
|
||||||
|
def _query_simulation_data_by_sensor_ids(
|
||||||
|
*,
|
||||||
|
network: str,
|
||||||
|
sensor_ids: list[str],
|
||||||
|
sensor_metadata: dict[str, dict[str, str]],
|
||||||
|
start_dt: datetime,
|
||||||
|
end_dt: datetime,
|
||||||
|
data_type: str,
|
||||||
|
simulation_source: str,
|
||||||
|
simulation_scheme_name: str | None,
|
||||||
|
simulation_scheme_type: str,
|
||||||
|
) -> dict[str, list[dict[str, Any]]]:
|
||||||
|
if simulation_source not in {"scheme", "realtime"}:
|
||||||
|
raise ValueError(f"Unsupported simulation_source: {simulation_source}")
|
||||||
|
|
||||||
|
result: dict[str, list[dict[str, Any]]] = {sensor_id: [] for sensor_id in sensor_ids}
|
||||||
|
if data_type == "pressure":
|
||||||
|
result.update(
|
||||||
|
_query_simulation_values(
|
||||||
|
network=network,
|
||||||
|
element_ids=sensor_ids,
|
||||||
|
element_type="node",
|
||||||
|
field="pressure",
|
||||||
|
start_dt=start_dt,
|
||||||
|
end_dt=end_dt,
|
||||||
|
simulation_source=simulation_source,
|
||||||
|
simulation_scheme_name=simulation_scheme_name,
|
||||||
|
simulation_scheme_type=simulation_scheme_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
if data_type != "flow":
|
||||||
|
raise ValueError(f"Unsupported data_type: {data_type}")
|
||||||
|
|
||||||
|
link_ids: list[str] = []
|
||||||
|
demand_ids: list[str] = []
|
||||||
|
unsupported_ids: list[str] = []
|
||||||
|
for sensor_id in sensor_ids:
|
||||||
|
scada_type = sensor_metadata[sensor_id]["scada_type"]
|
||||||
|
if scada_type in {"pipe_flow", "flow"}:
|
||||||
|
link_ids.append(sensor_id)
|
||||||
|
elif scada_type == "demand":
|
||||||
|
demand_ids.append(sensor_id)
|
||||||
|
else:
|
||||||
|
unsupported_ids.append(f"{sensor_id}({scada_type})")
|
||||||
|
if unsupported_ids:
|
||||||
|
preview = ", ".join(unsupported_ids[:10])
|
||||||
|
raise ValueError(f"flow 模拟数据暂不支持以下 SCADA 类型: {preview}")
|
||||||
|
|
||||||
|
if link_ids:
|
||||||
|
result.update(
|
||||||
|
_query_simulation_values(
|
||||||
|
network=network,
|
||||||
|
element_ids=link_ids,
|
||||||
|
element_type="link",
|
||||||
|
field="flow",
|
||||||
|
start_dt=start_dt,
|
||||||
|
end_dt=end_dt,
|
||||||
|
simulation_source=simulation_source,
|
||||||
|
simulation_scheme_name=simulation_scheme_name,
|
||||||
|
simulation_scheme_type=simulation_scheme_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if demand_ids:
|
||||||
|
result.update(
|
||||||
|
_query_simulation_values(
|
||||||
|
network=network,
|
||||||
|
element_ids=demand_ids,
|
||||||
|
element_type="node",
|
||||||
|
field="actual_demand",
|
||||||
|
start_dt=start_dt,
|
||||||
|
end_dt=end_dt,
|
||||||
|
simulation_source=simulation_source,
|
||||||
|
simulation_scheme_name=simulation_scheme_name,
|
||||||
|
simulation_scheme_type=simulation_scheme_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _query_simulation_values(
|
||||||
|
*,
|
||||||
|
network: str,
|
||||||
|
element_ids: list[str],
|
||||||
|
element_type: str,
|
||||||
|
field: str,
|
||||||
|
start_dt: datetime,
|
||||||
|
end_dt: datetime,
|
||||||
|
simulation_source: str,
|
||||||
|
simulation_scheme_name: str | None,
|
||||||
|
simulation_scheme_type: str,
|
||||||
|
) -> dict[str, list[dict[str, Any]]]:
|
||||||
|
if not element_ids:
|
||||||
|
return {}
|
||||||
|
if simulation_source == "scheme":
|
||||||
|
if not simulation_scheme_name:
|
||||||
|
raise ValueError("读取方案模拟数据时必须提供 simulation_scheme_name。")
|
||||||
|
return InternalQueries.query_scheme_simulation_by_ids_timerange(
|
||||||
|
db_name=network,
|
||||||
|
scheme_type=simulation_scheme_type,
|
||||||
|
scheme_name=simulation_scheme_name,
|
||||||
|
element_ids=element_ids,
|
||||||
|
start_time=start_dt.isoformat(),
|
||||||
|
end_time=end_dt.isoformat(),
|
||||||
|
element_type=element_type,
|
||||||
|
field=field,
|
||||||
|
)
|
||||||
|
if simulation_source == "realtime":
|
||||||
|
return InternalQueries.query_realtime_simulation_by_ids_timerange(
|
||||||
|
db_name=network,
|
||||||
|
element_ids=element_ids,
|
||||||
|
start_time=start_dt.isoformat(),
|
||||||
|
end_time=end_dt.isoformat(),
|
||||||
|
element_type=element_type,
|
||||||
|
field=field,
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unsupported simulation_source: {simulation_source}")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_sensor_metadata(network: str, data_type: str) -> dict[str, dict[str, str]]:
|
||||||
|
metadata: dict[str, dict[str, str]] = {}
|
||||||
for item in get_all_scada_info(network):
|
for item in get_all_scada_info(network):
|
||||||
scada_type = str(item.get("type", "")).lower()
|
scada_type = str(item.get("type", "")).lower()
|
||||||
if data_type == "pressure":
|
if data_type == "pressure":
|
||||||
@@ -347,16 +594,35 @@ def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported data_type: {data_type}")
|
raise ValueError(f"Unsupported data_type: {data_type}")
|
||||||
node_id = item.get("associated_element_id")
|
element_id = item.get("associated_element_id")
|
||||||
query_id = item.get("api_query_id")
|
query_id = item.get("api_query_id")
|
||||||
if (
|
if (
|
||||||
isinstance(node_id, str)
|
isinstance(element_id, str)
|
||||||
and node_id
|
and element_id
|
||||||
and isinstance(query_id, str)
|
and isinstance(query_id, str)
|
||||||
and query_id
|
and query_id
|
||||||
):
|
):
|
||||||
mapping[node_id] = query_id
|
metadata[element_id] = {"query_id": query_id, "scada_type": scada_type}
|
||||||
return mapping
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
|
||||||
|
metadata = _build_sensor_metadata(network=network, data_type=data_type)
|
||||||
|
return {
|
||||||
|
element_id: item["query_id"] for element_id, item in metadata.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_data_source(
|
||||||
|
data_source: str | None, simulation_scheme_name: str | None = None
|
||||||
|
) -> str:
|
||||||
|
normalized = str(data_source or "").strip().lower()
|
||||||
|
if not normalized:
|
||||||
|
return "simulation" if simulation_scheme_name else "monitoring"
|
||||||
|
if normalized not in SIMULATION_DATA_SOURCES:
|
||||||
|
allowed_sources = ", ".join(sorted(SIMULATION_DATA_SOURCES))
|
||||||
|
raise ValueError(f"Unsupported data_source: {data_source}. Allowed: {allowed_sources}")
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
def _get_sensor_nodes(network: str, data_type: str) -> list[str]:
|
def _get_sensor_nodes(network: str, data_type: str) -> list[str]:
|
||||||
|
|||||||
@@ -0,0 +1,210 @@
|
|||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _load_burst_location_module():
|
||||||
|
module_path = (
|
||||||
|
Path(__file__).resolve().parents[2] / "app" / "services" / "burst_location.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
def ensure_package(name: str) -> types.ModuleType:
|
||||||
|
module = sys.modules.get(name)
|
||||||
|
if module is None:
|
||||||
|
module = types.ModuleType(name)
|
||||||
|
module.__path__ = []
|
||||||
|
sys.modules[name] = module
|
||||||
|
return module
|
||||||
|
|
||||||
|
for package_name in [
|
||||||
|
"app",
|
||||||
|
"app.algorithms",
|
||||||
|
"app.infra",
|
||||||
|
"app.infra.db",
|
||||||
|
"app.infra.db.timescaledb",
|
||||||
|
"app.services",
|
||||||
|
]:
|
||||||
|
ensure_package(package_name)
|
||||||
|
|
||||||
|
algorithms_module = types.ModuleType("app.algorithms.burst_location")
|
||||||
|
algorithms_module.run_burst_location = lambda **kwargs: {}
|
||||||
|
sys.modules["app.algorithms.burst_location"] = algorithms_module
|
||||||
|
|
||||||
|
internal_queries_module = types.ModuleType(
|
||||||
|
"app.infra.db.timescaledb.internal_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyInternalQueries:
|
||||||
|
@staticmethod
|
||||||
|
def query_scada_by_ids_timerange(**kwargs):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def query_scheme_simulation_by_ids_timerange(**kwargs):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def query_realtime_simulation_by_ids_timerange(**kwargs):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
internal_queries_module.InternalQueries = DummyInternalQueries
|
||||||
|
sys.modules["app.infra.db.timescaledb.internal_queries"] = internal_queries_module
|
||||||
|
|
||||||
|
scheme_management_module = types.ModuleType("app.services.scheme_management")
|
||||||
|
scheme_management_module.query_burst_location_scheme_detail = lambda *args, **kwargs: {}
|
||||||
|
scheme_management_module.query_burst_location_schemes = lambda *args, **kwargs: []
|
||||||
|
scheme_management_module.scheme_name_exists = lambda *args, **kwargs: False
|
||||||
|
scheme_management_module.store_scheme_info = lambda *args, **kwargs: None
|
||||||
|
sys.modules["app.services.scheme_management"] = scheme_management_module
|
||||||
|
|
||||||
|
tjnetwork_module = types.ModuleType("app.services.tjnetwork")
|
||||||
|
tjnetwork_module.dump_inp = lambda *args, **kwargs: None
|
||||||
|
tjnetwork_module.get_all_scada_info = lambda *args, **kwargs: []
|
||||||
|
sys.modules["app.services.tjnetwork"] = tjnetwork_module
|
||||||
|
|
||||||
|
module_name = "tests_burst_location_under_test"
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
assert spec and spec.loader
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, tmp_path):
|
||||||
|
module = _load_burst_location_module()
|
||||||
|
captured = {}
|
||||||
|
scheme_calls = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
module,
|
||||||
|
"get_all_scada_info",
|
||||||
|
lambda network: [
|
||||||
|
{
|
||||||
|
"type": "pressure",
|
||||||
|
"associated_element_id": "J1",
|
||||||
|
"api_query_id": "pressure-query",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "pipe_flow",
|
||||||
|
"associated_element_id": "P1",
|
||||||
|
"api_query_id": "pipe-flow-query",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "demand",
|
||||||
|
"associated_element_id": "J2",
|
||||||
|
"api_query_id": "demand-query",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp"))
|
||||||
|
|
||||||
|
def fake_run_burst_location(**kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return {
|
||||||
|
"located_pipe": "Pipe-001",
|
||||||
|
"simulation_times": 3,
|
||||||
|
"similarity_mode": "combined",
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location)
|
||||||
|
|
||||||
|
def _build_series(start_time: str, values: list[float]) -> list[dict]:
|
||||||
|
base_time = datetime.fromisoformat(start_time)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"time": (base_time + timedelta(minutes=15 * index)).isoformat(),
|
||||||
|
"value": value,
|
||||||
|
}
|
||||||
|
for index, value in enumerate(values)
|
||||||
|
]
|
||||||
|
|
||||||
|
def fake_scheme_query(**kwargs):
|
||||||
|
scheme_calls.append(kwargs)
|
||||||
|
if kwargs["element_type"] == "node" and kwargs["field"] == "pressure":
|
||||||
|
start_hour = datetime.fromisoformat(kwargs["start_time"]).hour
|
||||||
|
values = [12.0, 14.0, 16.0, 18.0] if start_hour == 8 else [8.0, 10.0, 12.0, 14.0]
|
||||||
|
return {"J1": _build_series(kwargs["start_time"], values)}
|
||||||
|
if kwargs["element_type"] == "link" and kwargs["field"] == "flow":
|
||||||
|
start_hour = datetime.fromisoformat(kwargs["start_time"]).hour
|
||||||
|
values = [5.0, 7.0, 9.0, 11.0] if start_hour == 8 else [2.0, 4.0, 6.0, 8.0]
|
||||||
|
return {"P1": _build_series(kwargs["start_time"], values)}
|
||||||
|
if kwargs["element_type"] == "node" and kwargs["field"] == "actual_demand":
|
||||||
|
start_hour = datetime.fromisoformat(kwargs["start_time"]).hour
|
||||||
|
values = [3.0, 5.0, 7.0, 9.0] if start_hour == 8 else [1.0, 3.0, 5.0, 7.0]
|
||||||
|
return {"J2": _build_series(kwargs["start_time"], values)}
|
||||||
|
raise AssertionError(f"Unexpected scheme query: {kwargs}")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
module.InternalQueries,
|
||||||
|
"query_scheme_simulation_by_ids_timerange",
|
||||||
|
staticmethod(fake_scheme_query),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
module.InternalQueries,
|
||||||
|
"query_realtime_simulation_by_ids_timerange",
|
||||||
|
staticmethod(lambda **kwargs: pytest.fail(f"Unexpected realtime query: {kwargs}")),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = module.run_burst_location_by_network(
|
||||||
|
network="tjwater",
|
||||||
|
data_source="simulation",
|
||||||
|
simulation_scheme_name="BurstSchemeA",
|
||||||
|
simulation_scheme_type="burst_analysis",
|
||||||
|
burst_leakage=10.0,
|
||||||
|
scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
|
||||||
|
scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
|
||||||
|
scada_normal_start=datetime(2025, 1, 1, 6, 0, 0),
|
||||||
|
scada_normal_end=datetime(2025, 1, 1, 7, 0, 0),
|
||||||
|
use_scada_flow=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["observed_source"] == "simulation_scheme_timerange"
|
||||||
|
assert result["simulation_scheme"] == {
|
||||||
|
"name": "BurstSchemeA",
|
||||||
|
"type": "burst_analysis",
|
||||||
|
}
|
||||||
|
assert result["pressure_samples"] == {"burst": 4, "normal": 4}
|
||||||
|
assert result["flow_samples"] == {"burst": 4, "normal": 4}
|
||||||
|
assert list(captured["burst_pressure"].index) == ["J1"]
|
||||||
|
assert captured["burst_pressure"]["J1"] == pytest.approx(15.0)
|
||||||
|
assert captured["normal_pressure"]["J1"] == pytest.approx(11.0)
|
||||||
|
assert captured["burst_flow"]["J2"] == pytest.approx(6.0)
|
||||||
|
assert captured["burst_flow"]["P1"] == pytest.approx(8.0)
|
||||||
|
assert captured["normal_flow"]["J2"] == pytest.approx(4.0)
|
||||||
|
assert captured["normal_flow"]["P1"] == pytest.approx(5.0)
|
||||||
|
assert all(call["scheme_name"] == "BurstSchemeA" for call in scheme_calls)
|
||||||
|
assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in scheme_calls)
|
||||||
|
assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls)
|
||||||
|
assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path):
|
||||||
|
module = _load_burst_location_module()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
module,
|
||||||
|
"get_all_scada_info",
|
||||||
|
lambda network: [
|
||||||
|
{
|
||||||
|
"type": "pressure",
|
||||||
|
"associated_element_id": "J1",
|
||||||
|
"api_query_id": "pressure-query",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp"))
|
||||||
|
monkeypatch.setattr(module, "run_burst_location", lambda **kwargs: {})
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="simulation_scheme_name"):
|
||||||
|
module.run_burst_location_by_network(
|
||||||
|
network="tjwater",
|
||||||
|
data_source="simulation",
|
||||||
|
burst_leakage=1.0,
|
||||||
|
scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
|
||||||
|
scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
|
||||||
|
scada_normal_start=datetime(2025, 1, 1, 6, 0, 0),
|
||||||
|
scada_normal_end=datetime(2025, 1, 1, 7, 0, 0),
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user