重构爆管定位请求,移除不必要的时间参数
This commit is contained in:
@@ -30,8 +30,6 @@ class BurstLocationRequest(BaseModel):
|
||||
basic_pressure: float = 10.0
|
||||
scada_burst_start: datetime | None = None
|
||||
scada_burst_end: datetime | None = None
|
||||
scada_normal_start: datetime | None = None
|
||||
scada_normal_end: datetime | None = None
|
||||
use_scada_flow: bool = False
|
||||
scheme_name: str | None = None
|
||||
simulation_scheme_name: str | None = None
|
||||
|
||||
@@ -46,6 +46,7 @@ def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series:
|
||||
def run_burst_location_by_network(
|
||||
*,
|
||||
network: str,
|
||||
username: str,
|
||||
data_source: str = "monitoring",
|
||||
burst_leakage: float,
|
||||
pressure_scada_ids: list[str] | None = None,
|
||||
@@ -58,13 +59,10 @@ def run_burst_location_by_network(
|
||||
basic_pressure: float = 10.0,
|
||||
scada_burst_start: datetime | str | None = None,
|
||||
scada_burst_end: datetime | str | None = None,
|
||||
scada_normal_start: datetime | str | None = None,
|
||||
scada_normal_end: datetime | str | None = None,
|
||||
use_scada_flow: bool = False,
|
||||
scheme_name: str | None = None,
|
||||
simulation_scheme_name: str | None = None,
|
||||
simulation_scheme_type: str | None = None,
|
||||
username: str = "admin",
|
||||
) -> dict[str, Any]:
|
||||
if not network:
|
||||
raise ValueError("network is required.")
|
||||
@@ -88,37 +86,29 @@ def run_burst_location_by_network(
|
||||
for value in [
|
||||
scada_burst_start,
|
||||
scada_burst_end,
|
||||
scada_normal_start,
|
||||
scada_normal_end,
|
||||
]
|
||||
)
|
||||
if use_scada_pressure:
|
||||
(
|
||||
burst_start_dt,
|
||||
burst_end_dt,
|
||||
normal_start_dt,
|
||||
normal_end_dt,
|
||||
) = _validate_scada_windows(
|
||||
burst_start_dt, burst_end_dt = _validate_scada_windows(
|
||||
scada_burst_start=scada_burst_start,
|
||||
scada_burst_end=scada_burst_end,
|
||||
scada_normal_start=scada_normal_start,
|
||||
scada_normal_end=scada_normal_end,
|
||||
)
|
||||
if normalized_data_source == "simulation":
|
||||
if not simulation_scheme_name:
|
||||
raise ValueError("模拟方案模式必须提供 simulation_scheme_name。")
|
||||
burst_pressure_series, burst_pressure_samples = (
|
||||
_build_observed_series_from_simulation(
|
||||
network=network,
|
||||
sensor_ids=selected_pressure_ids,
|
||||
start_dt=burst_start_dt,
|
||||
end_dt=burst_end_dt,
|
||||
data_type="pressure",
|
||||
series_name="burst_pressure",
|
||||
simulation_source="scheme",
|
||||
simulation_scheme_name=simulation_scheme_name,
|
||||
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||
)
|
||||
(
|
||||
burst_pressure_series,
|
||||
burst_pressure_samples,
|
||||
) = _build_observed_series_from_simulation(
|
||||
network=network,
|
||||
sensor_ids=selected_pressure_ids,
|
||||
start_dt=burst_start_dt,
|
||||
end_dt=burst_end_dt,
|
||||
data_type="pressure",
|
||||
series_name="burst_pressure",
|
||||
simulation_source="scheme",
|
||||
simulation_scheme_name=simulation_scheme_name,
|
||||
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||
)
|
||||
(
|
||||
normal_pressure_series,
|
||||
@@ -126,12 +116,12 @@ def run_burst_location_by_network(
|
||||
) = _build_observed_series_from_simulation(
|
||||
network=network,
|
||||
sensor_ids=selected_pressure_ids,
|
||||
start_dt=normal_start_dt,
|
||||
end_dt=normal_end_dt,
|
||||
start_dt=burst_start_dt,
|
||||
end_dt=burst_end_dt,
|
||||
data_type="pressure",
|
||||
series_name="normal_pressure",
|
||||
simulation_source="realtime",
|
||||
simulation_scheme_name=simulation_scheme_name,
|
||||
simulation_scheme_name=None,
|
||||
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||
)
|
||||
observed_source = "simulation_scheme_burst_realtime_normal_timerange"
|
||||
@@ -149,15 +139,18 @@ def run_burst_location_by_network(
|
||||
(
|
||||
normal_pressure_series,
|
||||
normal_pressure_samples,
|
||||
) = _build_observed_series_from_scada(
|
||||
) = _build_observed_series_from_simulation(
|
||||
network=network,
|
||||
sensor_ids=selected_pressure_ids,
|
||||
start_dt=normal_start_dt,
|
||||
end_dt=normal_end_dt,
|
||||
start_dt=burst_start_dt,
|
||||
end_dt=burst_end_dt,
|
||||
data_type="pressure",
|
||||
series_name="normal_pressure",
|
||||
simulation_source="realtime",
|
||||
simulation_scheme_name=None,
|
||||
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||
)
|
||||
observed_source = "backend_timerange"
|
||||
observed_source = "scada_burst_realtime_normal_timerange"
|
||||
else:
|
||||
if burst_pressure is None or normal_pressure is None:
|
||||
raise ValueError(
|
||||
@@ -168,7 +161,7 @@ def run_burst_location_by_network(
|
||||
burst_pressure_samples = 1
|
||||
normal_pressure_samples = 1
|
||||
observed_source = "request_payload"
|
||||
burst_start_dt = burst_end_dt = normal_start_dt = normal_end_dt = None
|
||||
burst_start_dt = burst_end_dt = None
|
||||
|
||||
selected_flow_ids: list[str] | None = None
|
||||
burst_flow_series: pd.Series | None = None
|
||||
@@ -204,12 +197,12 @@ def run_burst_location_by_network(
|
||||
_build_observed_series_from_simulation(
|
||||
network=network,
|
||||
sensor_ids=selected_flow_ids,
|
||||
start_dt=normal_start_dt,
|
||||
end_dt=normal_end_dt,
|
||||
start_dt=burst_start_dt,
|
||||
end_dt=burst_end_dt,
|
||||
data_type="flow",
|
||||
series_name="normal_flow",
|
||||
simulation_source="realtime",
|
||||
simulation_scheme_name=simulation_scheme_name,
|
||||
simulation_scheme_name=None,
|
||||
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||
)
|
||||
)
|
||||
@@ -222,19 +215,26 @@ def run_burst_location_by_network(
|
||||
data_type="flow",
|
||||
series_name="burst_flow",
|
||||
)
|
||||
normal_flow_series, normal_flow_samples = _build_observed_series_from_scada(
|
||||
network=network,
|
||||
sensor_ids=selected_flow_ids,
|
||||
start_dt=normal_start_dt,
|
||||
end_dt=normal_end_dt,
|
||||
data_type="flow",
|
||||
series_name="normal_flow",
|
||||
normal_flow_series, normal_flow_samples = (
|
||||
_build_observed_series_from_simulation(
|
||||
network=network,
|
||||
sensor_ids=selected_flow_ids,
|
||||
start_dt=burst_start_dt,
|
||||
end_dt=burst_end_dt,
|
||||
data_type="flow",
|
||||
series_name="normal_flow",
|
||||
simulation_source="realtime",
|
||||
simulation_scheme_name=None,
|
||||
simulation_scheme_type=resolved_simulation_scheme_type,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if flow_scada_ids is not None:
|
||||
selected_flow_ids = _dedupe_ids(flow_scada_ids)
|
||||
burst_flow_series = (
|
||||
_normalize_series(burst_flow, "burst_flow") if burst_flow is not None else None
|
||||
_normalize_series(burst_flow, "burst_flow")
|
||||
if burst_flow is not None
|
||||
else None
|
||||
)
|
||||
normal_flow_series = (
|
||||
_normalize_series(normal_flow, "normal_flow")
|
||||
@@ -278,8 +278,6 @@ def run_burst_location_by_network(
|
||||
payload["scada_window"] = {
|
||||
"burst_start": burst_start_dt.isoformat(),
|
||||
"burst_end": burst_end_dt.isoformat(),
|
||||
"normal_start": normal_start_dt.isoformat(),
|
||||
"normal_end": normal_end_dt.isoformat(),
|
||||
}
|
||||
if normalized_data_source == "simulation":
|
||||
payload["simulation_scheme"] = {
|
||||
@@ -303,7 +301,9 @@ def list_burst_location_schemes(
|
||||
network: str, query_date: datetime | str | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
parsed_date = _to_datetime(query_date).date() if query_date is not None else None
|
||||
return query_burst_location_schemes(name=network, network=network, query_date=parsed_date)
|
||||
return query_burst_location_schemes(
|
||||
name=network, network=network, query_date=parsed_date
|
||||
)
|
||||
|
||||
|
||||
def get_burst_location_scheme_detail(network: str, scheme_name: str) -> dict[str, Any]:
|
||||
@@ -359,25 +359,19 @@ def _validate_scada_windows(
|
||||
*,
|
||||
scada_burst_start: datetime | str | None,
|
||||
scada_burst_end: datetime | str | None,
|
||||
scada_normal_start: datetime | str | None,
|
||||
scada_normal_end: datetime | str | None,
|
||||
) -> tuple[datetime, datetime, datetime, datetime]:
|
||||
values = [scada_burst_start, scada_burst_end, scada_normal_start, scada_normal_end]
|
||||
) -> tuple[datetime, datetime]:
|
||||
values = [scada_burst_start, scada_burst_end]
|
||||
if any(v is None for v in values):
|
||||
raise ValueError(
|
||||
"使用后端 SCADA 查询时,必须同时提供 scada_burst_start/scada_burst_end/scada_normal_start/scada_normal_end。"
|
||||
"使用后端 SCADA 查询时,必须同时提供 scada_burst_start/scada_burst_end。"
|
||||
)
|
||||
burst_start_dt = _to_datetime(scada_burst_start)
|
||||
burst_end_dt = _to_datetime(scada_burst_end)
|
||||
normal_start_dt = _to_datetime(scada_normal_start)
|
||||
normal_end_dt = _to_datetime(scada_normal_end)
|
||||
if burst_start_dt >= burst_end_dt:
|
||||
raise ValueError("爆管时段 SCADA 时间窗非法:scada_burst_start 必须早于 scada_burst_end。")
|
||||
if normal_start_dt >= normal_end_dt:
|
||||
raise ValueError(
|
||||
"正常时段 SCADA 时间窗非法:scada_normal_start 必须早于 scada_normal_end。"
|
||||
"爆管时段 SCADA 时间窗非法:scada_burst_start 必须早于 scada_burst_end。"
|
||||
)
|
||||
return burst_start_dt, burst_end_dt, normal_start_dt, normal_end_dt
|
||||
return burst_start_dt, burst_end_dt
|
||||
|
||||
|
||||
def _build_observed_series_from_scada(
|
||||
@@ -390,7 +384,9 @@ def _build_observed_series_from_scada(
|
||||
series_name: str,
|
||||
) -> tuple[pd.Series, int]:
|
||||
scada_mapping = _build_scada_mapping(network=network, data_type=data_type)
|
||||
missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in scada_mapping]
|
||||
missing_ids = [
|
||||
sensor_id for sensor_id in sensor_ids if sensor_id not in scada_mapping
|
||||
]
|
||||
if missing_ids:
|
||||
preview = ", ".join(missing_ids[:10])
|
||||
raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}")
|
||||
@@ -407,9 +403,7 @@ def _build_observed_series_from_scada(
|
||||
for sensor_id, query_id in zip(sensor_ids, query_ids):
|
||||
records = scada_data.get(query_id, [])
|
||||
numeric_values = [
|
||||
float(item["value"])
|
||||
for item in records
|
||||
if item.get("value") is not None
|
||||
float(item["value"]) for item in records if item.get("value") is not None
|
||||
]
|
||||
if not numeric_values:
|
||||
raise ValueError(f"{series_name} 在时间窗内无有效数据: {sensor_id}")
|
||||
@@ -432,7 +426,9 @@ def _build_observed_series_from_simulation(
|
||||
simulation_scheme_type: str,
|
||||
) -> tuple[pd.Series, int]:
|
||||
sensor_metadata = _build_sensor_metadata(network=network, data_type=data_type)
|
||||
missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in sensor_metadata]
|
||||
missing_ids = [
|
||||
sensor_id for sensor_id in sensor_ids if sensor_id not in sensor_metadata
|
||||
]
|
||||
if missing_ids:
|
||||
preview = ", ".join(missing_ids[:10])
|
||||
raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}")
|
||||
@@ -453,9 +449,7 @@ def _build_observed_series_from_simulation(
|
||||
for sensor_id in sensor_ids:
|
||||
records = simulation_data.get(sensor_id, [])
|
||||
numeric_values = [
|
||||
float(item["value"])
|
||||
for item in records
|
||||
if item.get("value") is not None
|
||||
float(item["value"]) for item in records if item.get("value") is not None
|
||||
]
|
||||
if not numeric_values:
|
||||
raise ValueError(f"{series_name} 在时间窗内无有效模拟数据: {sensor_id}")
|
||||
@@ -480,7 +474,9 @@ def _query_simulation_data_by_sensor_ids(
|
||||
if simulation_source not in {"scheme", "realtime"}:
|
||||
raise ValueError(f"Unsupported simulation_source: {simulation_source}")
|
||||
|
||||
result: dict[str, list[dict[str, Any]]] = {sensor_id: [] for sensor_id in sensor_ids}
|
||||
result: dict[str, list[dict[str, Any]]] = {
|
||||
sensor_id: [] for sensor_id in sensor_ids
|
||||
}
|
||||
if data_type == "pressure":
|
||||
result.update(
|
||||
_query_simulation_values(
|
||||
@@ -611,9 +607,7 @@ def _build_sensor_metadata(network: str, data_type: str) -> dict[str, dict[str,
|
||||
|
||||
def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
|
||||
metadata = _build_sensor_metadata(network=network, data_type=data_type)
|
||||
return {
|
||||
element_id: item["query_id"] for element_id, item in metadata.items()
|
||||
}
|
||||
return {element_id: item["query_id"] for element_id, item in metadata.items()}
|
||||
|
||||
|
||||
def _normalize_data_source(
|
||||
@@ -624,7 +618,9 @@ def _normalize_data_source(
|
||||
return "simulation" if simulation_scheme_name else "monitoring"
|
||||
if normalized not in SIMULATION_DATA_SOURCES:
|
||||
allowed_sources = ", ".join(sorted(SIMULATION_DATA_SOURCES))
|
||||
raise ValueError(f"Unsupported data_source: {data_source}. Allowed: {allowed_sources}")
|
||||
raise ValueError(
|
||||
f"Unsupported data_source: {data_source}. Allowed: {allowed_sources}"
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ DEFAULT_N_WORKERS = max(1, min((os.cpu_count() or 1) - 1, 4))
|
||||
|
||||
def run_leakage_identification(
|
||||
network: str,
|
||||
username: str,
|
||||
observed_pressure_data: (
|
||||
str | pd.DataFrame | dict[str, list[Any]] | list[dict[str, Any]] | None
|
||||
) = None,
|
||||
@@ -47,7 +48,6 @@ def run_leakage_identification(
|
||||
scada_end: datetime | str | None = None,
|
||||
sensor_nodes: list[str] | None = None,
|
||||
scheme_name: str | None = None,
|
||||
username: str = "admin",
|
||||
) -> dict[str, Any]:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
inp_path = _prepare_leakage_inp(network)
|
||||
|
||||
@@ -74,7 +74,7 @@ def _load_burst_location_module():
|
||||
return module
|
||||
|
||||
|
||||
def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, tmp_path):
|
||||
def test_run_burst_location_uses_single_timerange_with_burst_source_split(monkeypatch, tmp_path):
|
||||
module = _load_burst_location_module()
|
||||
captured = {}
|
||||
scheme_calls = []
|
||||
@@ -162,14 +162,13 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch,
|
||||
|
||||
result = module.run_burst_location_by_network(
|
||||
network="tjwater",
|
||||
username="testuser",
|
||||
data_source="simulation",
|
||||
simulation_scheme_name="BurstSchemeA",
|
||||
simulation_scheme_type="burst_analysis",
|
||||
burst_leakage=10.0,
|
||||
scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
|
||||
scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
|
||||
scada_normal_start=datetime(2025, 1, 1, 6, 0, 0),
|
||||
scada_normal_end=datetime(2025, 1, 1, 7, 0, 0),
|
||||
use_scada_flow=True,
|
||||
)
|
||||
|
||||
@@ -193,9 +192,15 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch,
|
||||
assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls)
|
||||
assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls)
|
||||
assert len(realtime_calls) == 3
|
||||
assert all(datetime.fromisoformat(call["start_time"]).hour == 8 for call in realtime_calls)
|
||||
assert all(datetime.fromisoformat(call["end_time"]).hour == 9 for call in realtime_calls)
|
||||
assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in realtime_calls)
|
||||
assert any(call["element_type"] == "link" and call["field"] == "flow" for call in realtime_calls)
|
||||
assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in realtime_calls)
|
||||
assert result["scada_window"] == {
|
||||
"burst_start": "2025-01-01T08:00:00",
|
||||
"burst_end": "2025-01-01T09:00:00",
|
||||
}
|
||||
|
||||
|
||||
def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path):
|
||||
@@ -217,10 +222,9 @@ def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_pat
|
||||
with pytest.raises(ValueError, match="simulation_scheme_name"):
|
||||
module.run_burst_location_by_network(
|
||||
network="tjwater",
|
||||
username="testuser",
|
||||
data_source="simulation",
|
||||
burst_leakage=1.0,
|
||||
scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
|
||||
scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
|
||||
scada_normal_start=datetime(2025, 1, 1, 6, 0, 0),
|
||||
scada_normal_end=datetime(2025, 1, 1, 7, 0, 0),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user