from __future__ import annotations from datetime import datetime from typing import Any import pandas as pd from app.algorithms.burst_detection.burst_detector import BurstDetector from app.infra.db.timescaledb.internal_queries import InternalQueries from app.services.scheme_management import ( query_burst_detection_scheme_detail, query_burst_detection_schemes, scheme_name_exists, store_scheme_info, ) from app.services.tjnetwork import get_all_scada_info def run_burst_detection( *, network: str, username: str, observed_pressure_data: ( pd.DataFrame | dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]] | None ) = None, points_per_day: int = 1440, mu: int = 100, iforest_params: dict[str, Any] | None = None, scada_start: datetime | str | None = None, scada_end: datetime | str | None = None, sensor_nodes: list[str] | None = None, scheme_name: str | None = None, ) -> dict[str, Any]: """ 运行爆管侦测服务入口。 调用方式二选一: - 直接传 `observed_pressure_data` - 或传 `scada_start/scada_end` 让后端自动查询 SCADA 压力数据 `observed_pressure_data` 支持格式: - `pd.DataFrame` 行表示时间点,列表示传感器;列名应为传感器/节点 ID。 - `dict[str, list[Any]]` 键为传感器/节点 ID,值为按时间顺序排列的压力序列。 例如:`{"J1": [101.2, 101.0], "J2": [99.8, 99.7]}`。 - `list[dict[str, Any]]` 每个元素代表一个时间点的多传感器观测。 例如:`[{"J1": 101.2, "J2": 99.8}, {"J1": 101.0, "J2": 99.7}]`。 - `list[list[Any]]` 二维数组式 JSON,格式为 `(时间点数, 传感器数)`。 这是最接近原始 `burst_detector` 示例代码的调用方式。 数据约束: - 统一要求“行=时间点,列=传感器”。 - 总样本点数必须能被 `points_per_day` 整除。 - 至少要有 2 天数据,即 `sample_count >= 2 * points_per_day`。 - 若传入 `sensor_nodes`,输入数据必须包含这些列;SCADA 模式下也会只按这些节点取数。 """ if not network: raise ValueError("network is required.") selected_sensor_nodes = ( list(dict.fromkeys([node for node in (sensor_nodes or []) if node])) if sensor_nodes else None ) use_scada_source = scada_start is not None or scada_end is not None if use_scada_source: scada_sensor_nodes = ( selected_sensor_nodes if selected_sensor_nodes is not None else _get_pressure_sensor_nodes(network) ) observed_df = _build_observed_pressure_from_scada( network=network, sensor_nodes=scada_sensor_nodes, scada_start=scada_start, scada_end=scada_end, ) observed_input: pd.DataFrame | dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]] = observed_df observed_source = "backend_timerange" else: if observed_pressure_data is None: raise ValueError( "未提供 observed_pressure_data,且未提供 scada_start/scada_end。" ) observed_input = observed_pressure_data observed_source = "request_payload" detector = BurstDetector( mu=mu, points_per_day=points_per_day, iforest_params=iforest_params, ) result_df = detector.run_detection( observed_input, sensor_nodes=selected_sensor_nodes, ) resolved_sensor_nodes = list(result_df.attrs.get("sensor_nodes", [])) rows = _serialize_result_rows(result_df) payload: dict[str, Any] = { "network": network, "sensor_nodes": resolved_sensor_nodes, "observed_source": observed_source, "sample_count": int(result_df.attrs.get("sample_count", 0)), "points_per_day": int(result_df.attrs.get("points_per_day", points_per_day)), "day_count": int(result_df.attrs.get("day_count", len(result_df))), "rows": rows, "summary": _build_detection_summary(result_df), } if use_scada_source: payload["scada_window"] = { "start": _to_datetime(scada_start).isoformat(), "end": _to_datetime(scada_end).isoformat(), } if scheme_name: _store_burst_detection_scheme( network=network, scheme_name=scheme_name, username=username, payload=payload, mu=mu, points_per_day=points_per_day, iforest_params=detector.iforest_params, ) payload["scheme_name"] = scheme_name return payload def list_burst_detection_schemes( network: str, query_date: datetime | str | None = None, ) -> list[dict[str, Any]]: parsed_date = _to_datetime(query_date).date() if query_date is not None else None return query_burst_detection_schemes( name=network, network=network, query_date=parsed_date, ) def get_burst_detection_scheme_detail(network: str, scheme_name: str) -> dict[str, Any]: result = query_burst_detection_scheme_detail(network, scheme_name) if not result: raise ValueError(f"未找到爆管侦测方案: {scheme_name}") return result def _store_burst_detection_scheme( *, network: str, scheme_name: str, username: str, payload: dict[str, Any], mu: int, points_per_day: int, iforest_params: dict[str, Any], ) -> None: if scheme_name_exists(network, scheme_name): raise ValueError(f"方案名称已存在: {scheme_name}") now_iso = datetime.now().isoformat() scheme_detail = { "network": network, "sensor_nodes": payload.get("sensor_nodes", []), "observed_source": payload.get("observed_source"), "scada_window": payload.get("scada_window"), "algorithm_params": { "mu": mu, "points_per_day": points_per_day, "iforest_params": iforest_params, }, "result_summary": payload.get("summary", {}), "result_payload": payload, } store_scheme_info( name=network, scheme_name=scheme_name, scheme_type="burst_detection", username=username, scheme_start_time=now_iso, scheme_detail=scheme_detail, ) def _serialize_result_rows(result_df: pd.DataFrame) -> list[dict[str, Any]]: rows: list[dict[str, Any]] = [] for row in result_df.to_dict(orient="records"): rows.append( { "Day": int(row["Day"]), "Score": float(row["Score"]), "Prediction": int(row["Prediction"]), "IsBurst": bool(row["IsBurst"]), } ) return rows def _build_detection_summary(result_df: pd.DataFrame) -> dict[str, Any]: rows = _serialize_result_rows(result_df) if not rows: raise ValueError("爆管侦测结果为空。") score_series = result_df["Score"] most_anomalous_index = int(score_series.idxmin()) latest_row = rows[-1] anomaly_days = [row["Day"] for row in rows if row["IsBurst"]] return { "burst_detected": bool(latest_row["IsBurst"]), "latest_day": latest_row, "most_anomalous_day": int(result_df.iloc[most_anomalous_index]["Day"]), "anomaly_days": anomaly_days, "anomaly_day_count": len(anomaly_days), "latest_sensor_rankings": _build_latest_sensor_rankings(result_df), } def _build_latest_sensor_rankings(result_df: pd.DataFrame) -> list[dict[str, Any]]: feature_matrix = result_df.attrs.get("high_freq_features") sensor_nodes = list(result_df.attrs.get("sensor_nodes", [])) if feature_matrix is None or len(sensor_nodes) == 0: return [] latest_values = feature_matrix[-1] ranking = sorted( zip(sensor_nodes, latest_values, strict=False), key=lambda item: item[1], ) return [ { "sensor_node": sensor_id, "latest_high_frequency_value": float(value), } for sensor_id, value in ranking[: min(10, len(ranking))] ] def _get_pressure_sensor_nodes(network: str) -> list[str]: sensor_nodes: list[str] = [] for item in get_all_scada_info(network): if str(item.get("type", "")).lower() != "pressure": continue node_id = item.get("associated_element_id") if isinstance(node_id, str) and node_id: sensor_nodes.append(node_id) sensor_nodes = list(dict.fromkeys(sensor_nodes)) if not sensor_nodes: raise ValueError("未找到压力传感器对应节点(scada_info.type=pressure)。") return sensor_nodes def _build_observed_pressure_from_scada( *, network: str, sensor_nodes: list[str], scada_start: datetime | str | None, scada_end: datetime | str | None, ) -> pd.DataFrame: if scada_start is None or scada_end is None: raise ValueError("使用后端 SCADA 查询时必须同时提供 scada_start 与 scada_end。") start_dt = _to_datetime(scada_start) end_dt = _to_datetime(scada_end) if start_dt >= end_dt: raise ValueError("SCADA 时间窗非法:scada_start 必须早于 scada_end。") node_query_id: dict[str, str] = {} for item in get_all_scada_info(network): if str(item.get("type", "")).lower() != "pressure": continue node_id = item.get("associated_element_id") query_id = item.get("api_query_id") if ( isinstance(node_id, str) and node_id and isinstance(query_id, str) and query_id ): node_query_id[node_id] = query_id missing_nodes = [node_id for node_id in sensor_nodes if node_id not in node_query_id] if missing_nodes: preview = ", ".join(missing_nodes[:10]) raise ValueError(f"未找到可用于压力观测的 SCADA api_query_id: {preview}") query_ids = [node_query_id[node_id] for node_id in sensor_nodes] scada_data = InternalQueries.query_scada_by_ids_timerange( db_name=network, device_ids=query_ids, start_time=start_dt.isoformat(), end_time=end_dt.isoformat(), ) available_lengths = [ len(scada_data.get(query_id, [])) for query_id in query_ids if len(scada_data.get(query_id, [])) > 0 ] if not available_lengths: raise ValueError("指定时间窗内未查询到压力 SCADA 数据。") min_len = min(available_lengths) observation_df = pd.DataFrame() for node_id in sensor_nodes: query_id = node_query_id[node_id] records = scada_data.get(query_id, [])[:min_len] if len(records) < min_len: continue observation_df[node_id] = [float(item["value"]) for item in records] if observation_df.empty: raise ValueError("SCADA 压力数据无法构建观测矩阵。") return observation_df def _to_datetime(value: datetime | str) -> datetime: if isinstance(value, datetime): return value return datetime.fromisoformat(value)