import time
from collections.abc import Callable, Iterable
from datetime import datetime, timedelta
from typing import Any, List, TypeVar

import psycopg
from fastapi.logger import logger
from psycopg import sql
from psycopg.rows import dict_row

import app.infra.db.timescaledb.timescaledb_info as timescaledb_info
from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository
from app.infra.db.timescaledb.schemas.scada import ScadaRepository
from app.infra.db.timescaledb.schemas.scheme import SchemeRepository

_T = TypeVar("_T")

# Seconds to wait between a failed attempt and the next one.
_RETRY_DELAY_SECONDS = 1

# Lower-cased element_type -> (table name, columns that may be selected as `field`).
_SIMULATION_TABLES: dict[str, tuple[str, frozenset[str]]] = {
    "node": (
        "node_simulation",
        frozenset({"actual_demand", "total_head", "pressure", "quality"}),
    ),
    "link": (
        "link_simulation",
        frozenset(
            {
                "flow",
                "friction",
                "headloss",
                "quality",
                "reaction",
                "setting",
                "status",
                "velocity",
            }
        ),
    ),
}


def _pgconn_string(db_name: str | None) -> str:
    """Return the connection string for *db_name*, or the default one when falsy."""
    if db_name:
        return timescaledb_info.get_pgconn_string(db_name=db_name)
    return timescaledb_info.get_pgconn_string()


def _run_with_retries(
    work: Callable[[psycopg.Connection], _T],
    *,
    db_name: str | None,
    max_retries: int,
    action: str,
) -> _T:
    """Run ``work(conn)`` on a fresh connection, retrying on any exception.

    Every attempt rebuilds the connection string and opens a new connection
    inside a ``with`` block. Failures are logged. Between failed attempts the
    function sleeps ``_RETRY_DELAY_SECONDS``. The exception from the last
    attempt is re-raised.

    Args:
        work: Callback receiving the open connection; its result is returned.
        db_name: Target database, or None/"" for the default database.
        max_retries: Total number of attempts. Values below 1 are treated as 1
            so the work is never silently skipped.
        action: Prefix of the failure log message (e.g. "存储" or "查询").

    Returns:
        Whatever ``work`` returned on the first successful attempt.
    """
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            with psycopg.Connection.connect(_pgconn_string(db_name)) as conn:
                return work(conn)
        except Exception as e:
            logger.error("%s尝试 %d 失败: %s", action, attempt, e)
            if attempt == attempts:
                raise  # out of attempts: propagate the last error
            time.sleep(_RETRY_DELAY_SECONDS)
    raise AssertionError("unreachable")  # the loop always returns or raises


def _to_datetime(value: str | datetime) -> datetime:
    """Parse an ISO-8601 string with ``datetime.fromisoformat``; pass datetimes through."""
    return datetime.fromisoformat(value) if isinstance(value, str) else value


def _scada_value(row: dict[str, Any]) -> Any:
    """Prefer ``cleaned_value``; fall back to ``monitored_value`` when it is absent or None."""
    value = row.get("cleaned_value")
    return row.get("monitored_value") if value is None else value


def _group_time_series(
    rows: Iterable[dict[str, Any]],
    ids: List[str],
    id_key: str,
    value_of: Callable[[dict[str, Any]], Any],
) -> dict[str, list[dict]]:
    """Group rows into ``{id: [{"time": iso_string, "value": v}, ...]}``.

    Every id in *ids* gets an entry (an empty list when it has no rows). Ids
    that only appear in *rows* are appended after them. Points are sorted on
    the ``datetime`` itself rather than its ISO string, so rows carrying
    different UTC offsets still come out in chronological order.
    """
    series: dict[str, list[tuple[Any, Any]]] = {key: [] for key in ids}
    for row in rows:
        series.setdefault(row[id_key], []).append((row["time"], value_of(row)))
    return {
        key: [
            {"time": ts.isoformat(), "value": value}
            for ts, value in sorted(points, key=lambda point: point[0])
        ]
        for key, points in series.items()
    }


class InternalStorage:
    """Persist simulation results to TimescaleDB with retry on failure."""

    @staticmethod
    def store_realtime_simulation(
        node_result_list: List[dict],
        link_result_list: List[dict],
        result_start_time: str,
        db_name: str | None = None,
        max_retries: int = 3,
    ):
        """Store real-time simulation results.

        Delegates to ``RealtimeRepository.store_realtime_simulation_result_sync``
        on a fresh connection, retrying up to *max_retries* attempts (minimum 1).
        The last error is re-raised when every attempt fails.
        """
        _run_with_retries(
            lambda conn: RealtimeRepository.store_realtime_simulation_result_sync(
                conn, node_result_list, link_result_list, result_start_time
            ),
            db_name=db_name,
            max_retries=max_retries,
            action="存储",
        )

    @staticmethod
    def store_scheme_simulation(
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[dict],
        link_result_list: List[dict],
        result_start_time: str,
        num_periods: int = 1,
        db_name: str | None = None,
        max_retries: int = 3,
    ):
        """Store scheme simulation results.

        Delegates to ``SchemeRepository.store_scheme_simulation_result_sync``
        on a fresh connection, retrying up to *max_retries* attempts (minimum 1).
        The last error is re-raised when every attempt fails.
        """
        _run_with_retries(
            lambda conn: SchemeRepository.store_scheme_simulation_result_sync(
                conn,
                scheme_type,
                scheme_name,
                node_result_list,
                link_result_list,
                result_start_time,
                num_periods,
            ),
            db_name=db_name,
            max_retries=max_retries,
            action="存储",
        )


class InternalQueries:
    """Read SCADA data and simulation results from TimescaleDB with retry on failure."""

    @staticmethod
    def query_scada_by_ids_time(
        device_ids: List[str],
        query_time: str | datetime,
        db_name: str | None = None,
        max_retries: int = 3,
    ) -> dict:
        """Return ``{device_id: monitored_value or None}`` at *query_time*.

        Rows within one second on either side of *query_time* are fetched via
        ``ScadaRepository.get_scada_by_ids_time_range_sync``. For each device
        the first of its rows, in the order the repository returns them, is
        used. Devices without rows map to None. Only the DB fetch is retried.
        """
        # The time is parsed as given; the original code assumes Beijing time.
        # NOTE(review): unlike query_scada_by_ids_timerange this reads only
        # monitored_value (never cleaned_value) — confirm that is intended.
        center = _to_datetime(query_time)
        if not device_ids:
            return {}
        window = timedelta(seconds=1)
        rows = _run_with_retries(
            # Materialize inside the connection so processing never needs it.
            lambda conn: list(
                ScadaRepository.get_scada_by_ids_time_range_sync(
                    conn, device_ids, center - window, center + window
                )
            ),
            db_name=db_name,
            max_retries=max_retries,
            action="查询",
        )

        # One pass over the rows: keep the first value seen per requested device.
        wanted = set(device_ids)
        first_values: dict[str, Any] = {}
        for row in rows:
            device_id = row["device_id"]
            if device_id in wanted and device_id not in first_values:
                first_values[device_id] = row["monitored_value"]
        return {device_id: first_values.get(device_id) for device_id in device_ids}

    @staticmethod
    def query_scada_by_ids_timerange(
        device_ids: List[str],
        start_time: str | datetime,
        end_time: str | datetime,
        db_name: str | None = None,
        max_retries: int = 3,
    ) -> dict[str, list[dict]]:
        """Return SCADA data in a time window as ``{device_id: [{time, value}, ...]}``.

        Each value is ``cleaned_value`` when it is present and not None,
        otherwise ``monitored_value``. Each series is sorted by time, with
        times rendered via ``isoformat()``. Only the DB fetch is retried.
        """
        start_dt = _to_datetime(start_time)
        end_dt = _to_datetime(end_time)
        if not device_ids:
            return {}
        rows = _run_with_retries(
            lambda conn: list(
                ScadaRepository.get_scada_by_ids_time_range_sync(
                    conn, device_ids, start_dt, end_dt
                )
            ),
            db_name=db_name,
            max_retries=max_retries,
            action="查询",
        )
        return _group_time_series(rows, device_ids, "device_id", _scada_value)

    @staticmethod
    def query_realtime_simulation_by_ids_timerange(
        element_ids: List[str],
        start_time: str | datetime,
        end_time: str | datetime,
        element_type: str,
        field: str,
        db_name: str | None = None,
        max_retries: int = 3,
    ) -> dict[str, list[dict]]:
        """Return real-time simulation results as ``{id: [{time, value}, ...]}``."""
        return InternalQueries._query_simulation_by_ids_timerange(
            schema_name="realtime",
            element_ids=element_ids,
            start_time=start_time,
            end_time=end_time,
            element_type=element_type,
            field=field,
            db_name=db_name,
            max_retries=max_retries,
        )

    @staticmethod
    def query_scheme_simulation_by_ids_timerange(
        element_ids: List[str],
        start_time: str | datetime,
        end_time: str | datetime,
        element_type: str,
        field: str,
        scheme_type: str,
        scheme_name: str,
        db_name: str | None = None,
        max_retries: int = 3,
    ) -> dict[str, list[dict]]:
        """Return scheme simulation results as ``{id: [{time, value}, ...]}``."""
        return InternalQueries._query_simulation_by_ids_timerange(
            schema_name="scheme",
            element_ids=element_ids,
            start_time=start_time,
            end_time=end_time,
            element_type=element_type,
            field=field,
            db_name=db_name,
            max_retries=max_retries,
            scheme_type=scheme_type,
            scheme_name=scheme_name,
        )

    @staticmethod
    def _query_simulation_by_ids_timerange(
        *,
        schema_name: str,
        element_ids: List[str],
        start_time: str | datetime,
        end_time: str | datetime,
        element_type: str,
        field: str,
        db_name: str | None = None,
        max_retries: int = 3,
        scheme_type: str | None = None,
        scheme_name: str | None = None,
    ) -> dict[str, list[dict]]:
        """Query one result column of the node/link simulation table in *schema_name*.

        Returns ``{id: [{time, value}, ...]}`` sorted by time, or ``{}`` when
        *element_ids* is empty.

        Raises:
            ValueError: unknown element_type, field not valid for it, unknown
                schema_name, or a "scheme" query missing scheme_type/scheme_name.
        """
        if not element_ids:
            return {}
        start_dt = _to_datetime(start_time)
        end_dt = _to_datetime(end_time)
        table_name, valid_fields = InternalQueries._resolve_simulation_table(
            element_type
        )
        if field not in valid_fields:
            raise ValueError(f"Invalid field for {element_type}: {field}")
        if schema_name not in {"realtime", "scheme"}:
            raise ValueError(f"Unsupported schema_name: {schema_name}")
        if schema_name == "scheme" and (not scheme_type or not scheme_name):
            raise ValueError("scheme 查询必须提供 scheme_type 和 scheme_name。")

        # Identifiers are whitelisted above and quoted with sql.Identifier;
        # every value is passed as a bound parameter.
        if schema_name == "scheme":
            query = sql.SQL(
                "SELECT id, time, {} FROM {}.{} "
                "WHERE scheme_type = %s AND scheme_name = %s "
                "AND time >= %s AND time <= %s AND id = ANY(%s)"
            ).format(
                sql.Identifier(field),
                sql.Identifier(schema_name),
                sql.Identifier(table_name),
            )
            params: tuple = (scheme_type, scheme_name, start_dt, end_dt, element_ids)
        else:
            query = sql.SQL(
                "SELECT id, time, {} FROM {}.{} "
                "WHERE time >= %s AND time <= %s AND id = ANY(%s)"
            ).format(
                sql.Identifier(field),
                sql.Identifier(schema_name),
                sql.Identifier(table_name),
            )
            params = (start_dt, end_dt, element_ids)

        def fetch(conn: psycopg.Connection) -> list[dict[str, Any]]:
            """Execute the prepared query and return all rows as dicts."""
            with conn.cursor(row_factory=dict_row) as cur:
                cur.execute(query, params)
                return cur.fetchall()

        rows = _run_with_retries(
            fetch, db_name=db_name, max_retries=max_retries, action="查询"
        )
        return _group_time_series(rows, element_ids, "id", lambda row: row[field])

    @staticmethod
    def _resolve_simulation_table(element_type: str) -> tuple[str, set[str]]:
        """Map an element type ("node"/"link", case-insensitive) to its table and columns.

        Returns a fresh ``set`` of valid field names so callers cannot mutate
        the module-level table.

        Raises:
            ValueError: when *element_type* is not a known simulation element.
        """
        try:
            table_name, fields = _SIMULATION_TABLES[element_type.lower()]
        except KeyError:
            raise ValueError(f"Unsupported element_type: {element_type}") from None
        return table_name, set(fields)