新增模拟数据源支持,重构爆管定位逻辑

This commit is contained in:
2026-03-07 10:50:25 +08:00
parent bc74e94fbb
commit 7f481ca261
4 changed files with 682 additions and 43 deletions
+5
View File
@@ -1,6 +1,8 @@
from typing import Any
from datetime import datetime
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
@@ -16,6 +18,7 @@ router = APIRouter()
class BurstLocationRequest(BaseModel):
network: str
data_source: Literal["monitoring", "simulation"] = "monitoring"
pressure_scada_ids: list[str] | None = None
burst_pressure: dict[str, float] | list[dict[str, Any]] | None = None
normal_pressure: dict[str, float] | list[dict[str, Any]] | None = None
@@ -31,6 +34,8 @@ class BurstLocationRequest(BaseModel):
scada_normal_end: datetime | None = None
use_scada_flow: bool = False
scheme_name: str | None = None
simulation_scheme_name: str | None = None
simulation_scheme_type: str | None = None
@router.post("/locate/")
@@ -3,6 +3,8 @@ from typing import List
from fastapi.logger import logger
from datetime import datetime, timedelta
import psycopg
from psycopg import sql
from psycopg.rows import dict_row
import time
from app.infra.db.timescaledb.schemas.scheme import SchemeRepository
from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository
@@ -170,3 +172,159 @@ class InternalQueries:
time.sleep(1)
else:
raise
@staticmethod
def query_realtime_simulation_by_ids_timerange(
element_ids: List[str],
start_time: str | datetime,
end_time: str | datetime,
element_type: str,
field: str,
db_name: str = None,
max_retries: int = 3,
) -> dict[str, list[dict]]:
"""查询实时模拟结果,返回 {id: [{time, value}, ...]}。"""
return InternalQueries._query_simulation_by_ids_timerange(
schema_name="realtime",
element_ids=element_ids,
start_time=start_time,
end_time=end_time,
element_type=element_type,
field=field,
db_name=db_name,
max_retries=max_retries,
)
@staticmethod
def query_scheme_simulation_by_ids_timerange(
element_ids: List[str],
start_time: str | datetime,
end_time: str | datetime,
element_type: str,
field: str,
scheme_type: str,
scheme_name: str,
db_name: str = None,
max_retries: int = 3,
) -> dict[str, list[dict]]:
"""查询方案模拟结果,返回 {id: [{time, value}, ...]}。"""
return InternalQueries._query_simulation_by_ids_timerange(
schema_name="scheme",
element_ids=element_ids,
start_time=start_time,
end_time=end_time,
element_type=element_type,
field=field,
db_name=db_name,
max_retries=max_retries,
scheme_type=scheme_type,
scheme_name=scheme_name,
)
@staticmethod
def _query_simulation_by_ids_timerange(
*,
schema_name: str,
element_ids: List[str],
start_time: str | datetime,
end_time: str | datetime,
element_type: str,
field: str,
db_name: str = None,
max_retries: int = 3,
scheme_type: str | None = None,
scheme_name: str | None = None,
) -> dict[str, list[dict]]:
if not element_ids:
return {}
start_dt = (
datetime.fromisoformat(start_time)
if isinstance(start_time, str)
else start_time
)
end_dt = (
datetime.fromisoformat(end_time) if isinstance(end_time, str) else end_time
)
table_name, valid_fields = InternalQueries._resolve_simulation_table(element_type)
if field not in valid_fields:
raise ValueError(f"Invalid field for {element_type}: {field}")
if schema_name not in {"realtime", "scheme"}:
raise ValueError(f"Unsupported schema_name: {schema_name}")
if schema_name == "scheme" and (not scheme_type or not scheme_name):
raise ValueError("scheme 查询必须提供 scheme_type 和 scheme_name。")
for attempt in range(max_retries):
try:
conn_string = (
timescaledb_info.get_pgconn_string(db_name=db_name)
if db_name
else timescaledb_info.get_pgconn_string()
)
with psycopg.Connection.connect(conn_string) as conn:
with conn.cursor(row_factory=dict_row) as cur:
if schema_name == "scheme":
query = sql.SQL(
"SELECT id, time, {} FROM {}.{} "
"WHERE scheme_type = %s AND scheme_name = %s "
"AND time >= %s AND time <= %s AND id = ANY(%s)"
).format(
sql.Identifier(field),
sql.Identifier(schema_name),
sql.Identifier(table_name),
)
cur.execute(
query,
(
scheme_type,
scheme_name,
start_dt,
end_dt,
element_ids,
),
)
else:
query = sql.SQL(
"SELECT id, time, {} FROM {}.{} "
"WHERE time >= %s AND time <= %s AND id = ANY(%s)"
).format(
sql.Identifier(field),
sql.Identifier(schema_name),
sql.Identifier(table_name),
)
cur.execute(query, (start_dt, end_dt, element_ids))
rows = cur.fetchall()
result: dict[str, list[dict]] = {
element_id: [] for element_id in element_ids
}
for row in rows:
result.setdefault(row["id"], []).append(
{"time": row["time"].isoformat(), "value": row[field]}
)
for element_id in result:
result[element_id].sort(key=lambda item: item["time"])
return result
except Exception as e:
logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
if attempt < max_retries - 1:
time.sleep(1)
else:
raise
@staticmethod
def _resolve_simulation_table(element_type: str) -> tuple[str, set[str]]:
normalized_type = element_type.lower()
if normalized_type == "node":
return "node_simulation", {"actual_demand", "total_head", "pressure", "quality"}
if normalized_type == "link":
return "link_simulation", {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
raise ValueError(f"Unsupported element_type: {element_type}")
+309 -43
View File
@@ -18,6 +18,8 @@ from app.services.tjnetwork import dump_inp, get_all_scada_info
SeriesInput = pd.Series | dict[str, Any] | list[dict[str, Any]]
FLOW_SCADA_TYPES = {"pipe_flow", "flow", "demand"}
SIMULATION_DATA_SOURCES = {"monitoring", "simulation"}
DEFAULT_SIMULATION_SCHEME_TYPE = "burst_analysis"
def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series:
@@ -44,6 +46,7 @@ def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series:
def run_burst_location_by_network(
*,
network: str,
data_source: str = "monitoring",
burst_leakage: float,
pressure_scada_ids: list[str] | None = None,
burst_pressure: SeriesInput | None = None,
@@ -59,10 +62,18 @@ def run_burst_location_by_network(
scada_normal_end: datetime | str | None = None,
use_scada_flow: bool = False,
scheme_name: str | None = None,
simulation_scheme_name: str | None = None,
simulation_scheme_type: str | None = None,
username: str = "admin",
) -> dict[str, Any]:
if not network:
raise ValueError("network is required.")
normalized_data_source = _normalize_data_source(
data_source, simulation_scheme_name=simulation_scheme_name
)
resolved_simulation_scheme_type = (
simulation_scheme_type or DEFAULT_SIMULATION_SCHEME_TYPE
)
selected_pressure_ids = (
_dedupe_ids(pressure_scada_ids)
@@ -93,26 +104,60 @@ def run_burst_location_by_network(
scada_normal_start=scada_normal_start,
scada_normal_end=scada_normal_end,
)
burst_pressure_series, burst_pressure_samples = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=burst_start_dt,
end_dt=burst_end_dt,
data_type="pressure",
series_name="burst_pressure",
)
(
normal_pressure_series,
normal_pressure_samples,
) = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="pressure",
series_name="normal_pressure",
)
observed_source = "backend_timerange"
if normalized_data_source == "simulation":
if not simulation_scheme_name:
raise ValueError("模拟方案模式必须提供 simulation_scheme_name。")
burst_pressure_series, burst_pressure_samples = (
_build_observed_series_from_simulation(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=burst_start_dt,
end_dt=burst_end_dt,
data_type="pressure",
series_name="burst_pressure",
simulation_source="scheme",
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=resolved_simulation_scheme_type,
)
)
(
normal_pressure_series,
normal_pressure_samples,
) = _build_observed_series_from_simulation(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="pressure",
series_name="normal_pressure",
simulation_source="scheme",
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=resolved_simulation_scheme_type,
)
observed_source = "simulation_scheme_timerange"
else:
burst_pressure_series, burst_pressure_samples = (
_build_observed_series_from_scada(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=burst_start_dt,
end_dt=burst_end_dt,
data_type="pressure",
series_name="burst_pressure",
)
)
(
normal_pressure_series,
normal_pressure_samples,
) = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="pressure",
series_name="normal_pressure",
)
observed_source = "backend_timerange"
else:
if burst_pressure is None or normal_pressure is None:
raise ValueError(
@@ -139,22 +184,52 @@ def run_burst_location_by_network(
)
if not selected_flow_ids:
raise ValueError("未找到可用流量传感器,无法从 SCADA 查询流量数据。")
burst_flow_series, burst_flow_samples = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_flow_ids,
start_dt=burst_start_dt,
end_dt=burst_end_dt,
data_type="flow",
series_name="burst_flow",
)
normal_flow_series, normal_flow_samples = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_flow_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="flow",
series_name="normal_flow",
)
if normalized_data_source == "simulation":
if not simulation_scheme_name:
raise ValueError("模拟方案模式必须提供 simulation_scheme_name。")
burst_flow_series, burst_flow_samples = (
_build_observed_series_from_simulation(
network=network,
sensor_ids=selected_flow_ids,
start_dt=burst_start_dt,
end_dt=burst_end_dt,
data_type="flow",
series_name="burst_flow",
simulation_source="scheme",
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=resolved_simulation_scheme_type,
)
)
normal_flow_series, normal_flow_samples = (
_build_observed_series_from_simulation(
network=network,
sensor_ids=selected_flow_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="flow",
series_name="normal_flow",
simulation_source="scheme",
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=resolved_simulation_scheme_type,
)
)
else:
burst_flow_series, burst_flow_samples = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_flow_ids,
start_dt=burst_start_dt,
end_dt=burst_end_dt,
data_type="flow",
series_name="burst_flow",
)
normal_flow_series, normal_flow_samples = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_flow_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="flow",
series_name="normal_flow",
)
else:
if flow_scada_ids is not None:
selected_flow_ids = _dedupe_ids(flow_scada_ids)
@@ -186,6 +261,7 @@ def run_burst_location_by_network(
payload: dict[str, Any] = {
**result,
"network": network,
"data_source": normalized_data_source,
"pressure_scada_ids": selected_pressure_ids,
"flow_scada_ids": selected_flow_ids or [],
"observed_source": observed_source,
@@ -202,6 +278,11 @@ def run_burst_location_by_network(
"normal_start": normal_start_dt.isoformat(),
"normal_end": normal_end_dt.isoformat(),
}
if normalized_data_source == "simulation":
payload["simulation_scheme"] = {
"name": simulation_scheme_name,
"type": resolved_simulation_scheme_type,
}
if scheme_name:
_store_burst_scheme(
network=network,
@@ -335,8 +416,174 @@ def _build_observed_series_from_scada(
return pd.Series(values, dtype=float), min(sample_counts)
def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
mapping: dict[str, str] = {}
def _build_observed_series_from_simulation(
*,
network: str,
sensor_ids: list[str],
start_dt: datetime,
end_dt: datetime,
data_type: str,
series_name: str,
simulation_source: str,
simulation_scheme_name: str | None,
simulation_scheme_type: str,
) -> tuple[pd.Series, int]:
sensor_metadata = _build_sensor_metadata(network=network, data_type=data_type)
missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in sensor_metadata]
if missing_ids:
preview = ", ".join(missing_ids[:10])
raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}")
simulation_data = _query_simulation_data_by_sensor_ids(
network=network,
sensor_ids=sensor_ids,
sensor_metadata=sensor_metadata,
start_dt=start_dt,
end_dt=end_dt,
data_type=data_type,
simulation_source=simulation_source,
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=simulation_scheme_type,
)
values: dict[str, float] = {}
sample_counts: list[int] = []
for sensor_id in sensor_ids:
records = simulation_data.get(sensor_id, [])
numeric_values = [
float(item["value"])
for item in records
if item.get("value") is not None
]
if not numeric_values:
raise ValueError(f"{series_name} 在时间窗内无有效模拟数据: {sensor_id}")
values[sensor_id] = float(sum(numeric_values) / len(numeric_values))
sample_counts.append(len(numeric_values))
return pd.Series(values, dtype=float), min(sample_counts)
def _query_simulation_data_by_sensor_ids(
*,
network: str,
sensor_ids: list[str],
sensor_metadata: dict[str, dict[str, str]],
start_dt: datetime,
end_dt: datetime,
data_type: str,
simulation_source: str,
simulation_scheme_name: str | None,
simulation_scheme_type: str,
) -> dict[str, list[dict[str, Any]]]:
if simulation_source not in {"scheme", "realtime"}:
raise ValueError(f"Unsupported simulation_source: {simulation_source}")
result: dict[str, list[dict[str, Any]]] = {sensor_id: [] for sensor_id in sensor_ids}
if data_type == "pressure":
result.update(
_query_simulation_values(
network=network,
element_ids=sensor_ids,
element_type="node",
field="pressure",
start_dt=start_dt,
end_dt=end_dt,
simulation_source=simulation_source,
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=simulation_scheme_type,
)
)
return result
if data_type != "flow":
raise ValueError(f"Unsupported data_type: {data_type}")
link_ids: list[str] = []
demand_ids: list[str] = []
unsupported_ids: list[str] = []
for sensor_id in sensor_ids:
scada_type = sensor_metadata[sensor_id]["scada_type"]
if scada_type in {"pipe_flow", "flow"}:
link_ids.append(sensor_id)
elif scada_type == "demand":
demand_ids.append(sensor_id)
else:
unsupported_ids.append(f"{sensor_id}({scada_type})")
if unsupported_ids:
preview = ", ".join(unsupported_ids[:10])
raise ValueError(f"flow 模拟数据暂不支持以下 SCADA 类型: {preview}")
if link_ids:
result.update(
_query_simulation_values(
network=network,
element_ids=link_ids,
element_type="link",
field="flow",
start_dt=start_dt,
end_dt=end_dt,
simulation_source=simulation_source,
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=simulation_scheme_type,
)
)
if demand_ids:
result.update(
_query_simulation_values(
network=network,
element_ids=demand_ids,
element_type="node",
field="actual_demand",
start_dt=start_dt,
end_dt=end_dt,
simulation_source=simulation_source,
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=simulation_scheme_type,
)
)
return result
def _query_simulation_values(
*,
network: str,
element_ids: list[str],
element_type: str,
field: str,
start_dt: datetime,
end_dt: datetime,
simulation_source: str,
simulation_scheme_name: str | None,
simulation_scheme_type: str,
) -> dict[str, list[dict[str, Any]]]:
if not element_ids:
return {}
if simulation_source == "scheme":
if not simulation_scheme_name:
raise ValueError("读取方案模拟数据时必须提供 simulation_scheme_name。")
return InternalQueries.query_scheme_simulation_by_ids_timerange(
db_name=network,
scheme_type=simulation_scheme_type,
scheme_name=simulation_scheme_name,
element_ids=element_ids,
start_time=start_dt.isoformat(),
end_time=end_dt.isoformat(),
element_type=element_type,
field=field,
)
if simulation_source == "realtime":
return InternalQueries.query_realtime_simulation_by_ids_timerange(
db_name=network,
element_ids=element_ids,
start_time=start_dt.isoformat(),
end_time=end_dt.isoformat(),
element_type=element_type,
field=field,
)
raise ValueError(f"Unsupported simulation_source: {simulation_source}")
def _build_sensor_metadata(network: str, data_type: str) -> dict[str, dict[str, str]]:
metadata: dict[str, dict[str, str]] = {}
for item in get_all_scada_info(network):
scada_type = str(item.get("type", "")).lower()
if data_type == "pressure":
@@ -347,16 +594,35 @@ def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
continue
else:
raise ValueError(f"Unsupported data_type: {data_type}")
node_id = item.get("associated_element_id")
element_id = item.get("associated_element_id")
query_id = item.get("api_query_id")
if (
isinstance(node_id, str)
and node_id
isinstance(element_id, str)
and element_id
and isinstance(query_id, str)
and query_id
):
mapping[node_id] = query_id
return mapping
metadata[element_id] = {"query_id": query_id, "scada_type": scada_type}
return metadata
def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
metadata = _build_sensor_metadata(network=network, data_type=data_type)
return {
element_id: item["query_id"] for element_id, item in metadata.items()
}
def _normalize_data_source(
data_source: str | None, simulation_scheme_name: str | None = None
) -> str:
normalized = str(data_source or "").strip().lower()
if not normalized:
return "simulation" if simulation_scheme_name else "monitoring"
if normalized not in SIMULATION_DATA_SOURCES:
allowed_sources = ", ".join(sorted(SIMULATION_DATA_SOURCES))
raise ValueError(f"Unsupported data_source: {data_source}. Allowed: {allowed_sources}")
return normalized
def _get_sensor_nodes(network: str, data_type: str) -> list[str]: