from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from psycopg import AsyncConnection, Connection, sql
import globals

# UTC+8 time zone. Naive datetimes, and ISO strings without an explicit
# offset, are interpreted as UTC+8 throughout this module.
UTC_8 = timezone(timedelta(hours=8))


class SchemeRepository:
    """Static data-access helpers for scheme simulation results.

    Reads and writes ``scheme.link_simulation`` and ``scheme.node_simulation``.
    The batch upserts use ``(time, scheme_type, scheme_name, id)`` as their
    ``ON CONFLICT`` target. No method commits; transaction control is left to
    the caller.

    NOTE(review): the field getters index rows by column name and the range
    getters are annotated ``List[dict]``, so connections are presumably
    created with a dict row factory (e.g. ``psycopg.rows.dict_row``) --
    confirm against the connection setup.
    """

    # Whitelists of column names that may be interpolated into SQL as
    # identifiers; anything else is rejected before a query is built.
    _LINK_FIELDS = frozenset(
        {
            "flow",
            "friction",
            "headloss",
            "quality",
            "reaction",
            "setting",
            "status",
            "velocity",
        }
    )
    _NODE_FIELDS = frozenset({"actual_demand", "total_head", "pressure", "quality"})

    # Upsert statements shared by the async and sync batch inserts. A row that
    # already exists for the same key has its measurement columns overwritten.
    _LINK_UPSERT_SQL = """
        INSERT INTO scheme.link_simulation
            (time, scheme_type, scheme_name, id, flow, friction, headloss,
             quality, reaction, setting, status, velocity)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON CONFLICT (time, scheme_type, scheme_name, id) DO UPDATE SET
            flow = EXCLUDED.flow,
            friction = EXCLUDED.friction,
            headloss = EXCLUDED.headloss,
            quality = EXCLUDED.quality,
            reaction = EXCLUDED.reaction,
            setting = EXCLUDED.setting,
            status = EXCLUDED.status,
            velocity = EXCLUDED.velocity
    """

    _NODE_UPSERT_SQL = """
        INSERT INTO scheme.node_simulation
            (time, scheme_type, scheme_name, id, actual_demand, total_head,
             pressure, quality)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        ON CONFLICT (time, scheme_type, scheme_name, id) DO UPDATE SET
            actual_demand = EXCLUDED.actual_demand,
            total_head = EXCLUDED.total_head,
            pressure = EXCLUDED.pressure,
            quality = EXCLUDED.quality
    """

    # --- Internal helpers ---

    @staticmethod
    def _check_field(field: str, valid_fields: frozenset) -> None:
        """Raise ValueError unless ``field`` is a whitelisted column name.

        This guards every place where a caller-supplied column name is turned
        into an SQL identifier.
        """
        if field not in valid_fields:
            raise ValueError(f"Invalid field: {field}")

    @staticmethod
    def _link_params(data: List[dict]) -> List[tuple]:
        """Convert link row dicts into positional tuples for ``_LINK_UPSERT_SQL``.

        ``time``, ``scheme_type``, ``scheme_name`` and ``id`` are required
        (KeyError otherwise); missing measurement keys are sent as NULL.
        """
        return [
            (
                item["time"],
                item["scheme_type"],
                item["scheme_name"],
                item["id"],
                item.get("flow"),
                item.get("friction"),
                item.get("headloss"),
                item.get("quality"),
                item.get("reaction"),
                item.get("setting"),
                item.get("status"),
                item.get("velocity"),
            )
            for item in data
        ]

    @staticmethod
    def _node_params(data: List[dict]) -> List[tuple]:
        """Convert node row dicts into positional tuples for ``_NODE_UPSERT_SQL``.

        ``time``, ``scheme_type``, ``scheme_name`` and ``id`` are required
        (KeyError otherwise); missing measurement keys are sent as NULL.
        """
        return [
            (
                item["time"],
                item["scheme_type"],
                item["scheme_name"],
                item["id"],
                item.get("actual_demand"),
                item.get("total_head"),
                item.get("pressure"),
                item.get("quality"),
            )
            for item in data
        ]

    @staticmethod
    def _to_utc8(value) -> datetime:
        """Normalize an ISO-8601 string or a datetime to an aware datetime.

        Strings ending in ``Z`` are parsed as UTC and converted to UTC+8.
        Other strings are parsed with ``datetime.fromisoformat`` and keep any
        offset they carry. A naive result (string or datetime) is tagged as
        UTC+8 without shifting the wall-clock value.
        """
        if isinstance(value, str):
            if value.endswith("Z"):
                # fromisoformat() accepts a trailing "Z" only from Python 3.11,
                # so spell the UTC offset out explicitly.
                utc_time = datetime.fromisoformat(value.replace("Z", "+00:00"))
                return utc_time.astimezone(UTC_8)
            parsed = datetime.fromisoformat(value)
        else:
            parsed = value
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=UTC_8)
        return parsed

    @staticmethod
    def _hydraulic_timestep() -> timedelta:
        """Return ``globals.hydraulic_timestep`` (an ``"H:M:S"`` string) as a timedelta."""
        parts = globals.hydraulic_timestep.split(":")
        return timedelta(
            hours=int(parts[0]),
            minutes=int(parts[1]),
            seconds=int(parts[2]),
        )

    @staticmethod
    def _build_simulation_rows(
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time,
        num_periods: int,
    ) -> tuple:
        """Flatten per-element simulation results into node and link row dicts.

        Period ``i`` is stamped ``start + i * hydraulic_timestep``. Each input
        element is read as ``{"node"|"link": id, "result": [period dicts]}``.
        Raises IndexError if a ``result`` list has fewer than ``num_periods``
        entries.

        Returns:
            ``(node_rows, link_rows)`` ready for the batch insert methods.
        """
        simulation_time = SchemeRepository._to_utc8(result_start_time)
        timestep = SchemeRepository._hydraulic_timestep()
        # Timestamps are identical for every element, so compute them once.
        period_times = [
            simulation_time + (timestep * period_index)
            for period_index in range(num_periods)
        ]

        node_data = []
        for node_result in node_result_list:
            node_id = node_result.get("node")
            results = node_result.get("result", [])
            for period_index, current_time in enumerate(period_times):
                data = results[period_index]
                node_data.append(
                    {
                        "time": current_time,
                        "scheme_type": scheme_type,
                        "scheme_name": scheme_name,
                        "id": node_id,
                        # Simulator keys differ from column names for these two.
                        "actual_demand": data.get("demand"),
                        "total_head": data.get("head"),
                        "pressure": data.get("pressure"),
                        "quality": data.get("quality"),
                    }
                )

        link_data = []
        for link_result in link_result_list:
            link_id = link_result.get("link")
            results = link_result.get("result", [])
            for period_index, current_time in enumerate(period_times):
                data = results[period_index]
                link_data.append(
                    {
                        "time": current_time,
                        "scheme_type": scheme_type,
                        "scheme_name": scheme_name,
                        "id": link_id,
                        "flow": data.get("flow"),
                        "friction": data.get("friction"),
                        "headloss": data.get("headloss"),
                        "quality": data.get("quality"),
                        "reaction": data.get("reaction"),
                        "setting": data.get("setting"),
                        "status": data.get("status"),
                        "velocity": data.get("velocity"),
                    }
                )

        return node_data, link_data

    @staticmethod
    async def _fetch_field_records(
        conn: AsyncConnection,
        table: str,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> list:
        """Return every ``(id, field)`` row of ``scheme.<table>`` in the window.

        ``field`` must already be validated by the caller. Rows are ordered by
        time then id so the output is deterministic.
        """
        query = sql.SQL(
            "SELECT id, {field} FROM {table} WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s ORDER BY time, id"
        ).format(
            field=sql.Identifier(field),
            table=sql.Identifier("scheme", table),
        )
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            return await cur.fetchall()

    # --- Link Simulation ---

    @staticmethod
    async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
        """Upsert link rows into scheme.link_simulation via one executemany call.

        Args:
            conn: Open async connection (not committed here).
            data: Row dicts; see ``_link_params`` for the expected keys.
        """
        if not data:
            return
        params = SchemeRepository._link_params(data)
        async with conn.cursor() as cur:
            await cur.executemany(SchemeRepository._LINK_UPSERT_SQL, params)

    @staticmethod
    def insert_links_batch_sync(conn: Connection, data: List[dict]):
        """Synchronous counterpart of ``insert_links_batch``."""
        if not data:
            return
        params = SchemeRepository._link_params(data)
        with conn.cursor() as cur:
            cur.executemany(SchemeRepository._LINK_UPSERT_SQL, params)

    @staticmethod
    async def get_link_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
    ) -> List[dict]:
        """Return all columns of one link's rows with start_time <= time <= end_time."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme_type, scheme_name, start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all columns of every link row with start_time <= time <= end_time."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> Any:
        """Return one column of a link's earliest row in the window, or None.

        Raises:
            ValueError: if ``field`` is not a link_simulation measurement column.
        """
        SchemeRepository._check_field(field, SchemeRepository._LINK_FIELDS)
        # ORDER BY makes the single returned row deterministic.
        query = sql.SQL(
            "SELECT {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s ORDER BY time LIMIT 1"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(
                query, (scheme_type, scheme_name, start_time, end_time, link_id)
            )
            row = await cur.fetchone()
            return row[field] if row else None

    @staticmethod
    async def get_links_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> Any:
        """Return ``field`` from the first matching link row in the window, or None.

        NOTE(review): despite the plural name this returns a single value from
        an unspecified row (fetchone, no ORDER BY). Kept for backward
        compatibility; use ``query_all_record_by_scheme_time_property`` to get
        every link's value.

        Raises:
            ValueError: if ``field`` is not a link_simulation measurement column.
        """
        SchemeRepository._check_field(field, SchemeRepository._LINK_FIELDS)
        query = sql.SQL(
            "SELECT {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            row = await cur.fetchone()
            return row[field] if row else None

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        scheme_type: str,
        scheme_name: str,
        link_id: str,
        field: str,
        value: Any,
    ):
        """Set one column of the link row identified by (time, scheme, id).

        Raises:
            ValueError: if ``field`` is not a link_simulation measurement column.
        """
        SchemeRepository._check_field(field, SchemeRepository._LINK_FIELDS)
        query = sql.SQL(
            "UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme_type, scheme_name, link_id))

    @staticmethod
    async def delete_links_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ):
        """Delete every link row of the scheme with start_time <= time <= end_time."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )

    # --- Node Simulation ---

    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        """Upsert node rows into scheme.node_simulation via one executemany call.

        Args:
            conn: Open async connection (not committed here).
            data: Row dicts; see ``_node_params`` for the expected keys.
        """
        if not data:
            return
        params = SchemeRepository._node_params(data)
        async with conn.cursor() as cur:
            await cur.executemany(SchemeRepository._NODE_UPSERT_SQL, params)

    @staticmethod
    def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
        """Synchronous counterpart of ``insert_nodes_batch``."""
        if not data:
            return
        params = SchemeRepository._node_params(data)
        with conn.cursor() as cur:
            cur.executemany(SchemeRepository._NODE_UPSERT_SQL, params)

    @staticmethod
    async def get_node_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
    ) -> List[dict]:
        """Return all columns of one node's rows with start_time <= time <= end_time."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme_type, scheme_name, start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all columns of every node row with start_time <= time <= end_time."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> Any:
        """Return one column of a node's earliest row in the window, or None.

        Raises:
            ValueError: if ``field`` is not a node_simulation measurement column.
        """
        SchemeRepository._check_field(field, SchemeRepository._NODE_FIELDS)
        # ORDER BY makes the single returned row deterministic.
        query = sql.SQL(
            "SELECT {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s ORDER BY time LIMIT 1"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(
                query, (scheme_type, scheme_name, start_time, end_time, node_id)
            )
            row = await cur.fetchone()
            return row[field] if row else None

    @staticmethod
    async def get_nodes_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> Any:
        """Return ``field`` from the first matching node row in the window, or None.

        NOTE(review): despite the plural name this returns a single value from
        an unspecified row (fetchone, no ORDER BY). Kept for backward
        compatibility; use ``query_all_record_by_scheme_time_property`` to get
        every node's value.

        Raises:
            ValueError: if ``field`` is not a node_simulation measurement column.
        """
        SchemeRepository._check_field(field, SchemeRepository._NODE_FIELDS)
        query = sql.SQL(
            "SELECT {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            row = await cur.fetchone()
            return row[field] if row else None

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        scheme_type: str,
        scheme_name: str,
        node_id: str,
        field: str,
        value: Any,
    ):
        """Set one column of the node row identified by (time, scheme, id).

        Raises:
            ValueError: if ``field`` is not a node_simulation measurement column.
        """
        SchemeRepository._check_field(field, SchemeRepository._NODE_FIELDS)
        query = sql.SQL(
            "UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme_type, scheme_name, node_id))

    @staticmethod
    async def delete_nodes_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ):
        """Delete every node row of the scheme with start_time <= time <= end_time."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )

    # --- Composite queries ---

    @staticmethod
    async def store_scheme_simulation_result(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
        num_periods: int = 1,
    ):
        """
        Store scheme simulation results to TimescaleDB.

        Args:
            conn: Database connection (not committed here).
            scheme_type: Scheme type.
            scheme_name: Scheme name.
            node_result_list: Per-node results, each ``{"node": id, "result": [...]}``.
            link_result_list: Per-link results, each ``{"link": id, "result": [...]}``.
            result_start_time: Timestamp of period 0, as an ISO string (``Z``
                suffix = UTC, no offset = UTC+8) or a datetime.
            num_periods: Number of periods to store per element; period ``i``
                is stamped ``start + i * globals.hydraulic_timestep``.

        Raises:
            IndexError: if any element has fewer than ``num_periods`` results.
        """
        node_data, link_data = SchemeRepository._build_simulation_rows(
            scheme_type,
            scheme_name,
            node_result_list,
            link_result_list,
            result_start_time,
            num_periods,
        )
        if node_data:
            await SchemeRepository.insert_nodes_batch(conn, node_data)
        if link_data:
            await SchemeRepository.insert_links_batch(conn, link_data)

    @staticmethod
    def store_scheme_simulation_result_sync(
        conn: Connection,
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
        num_periods: int = 1,
    ):
        """
        Store scheme simulation results to TimescaleDB (sync version).

        Same arguments and behaviour as ``store_scheme_simulation_result``,
        but uses a synchronous connection.
        """
        node_data, link_data = SchemeRepository._build_simulation_rows(
            scheme_type,
            scheme_name,
            node_result_list,
            link_result_list,
            result_start_time,
            num_periods,
        )
        if node_data:
            SchemeRepository.insert_nodes_batch_sync(conn, node_data)
        if link_data:
            SchemeRepository.insert_links_batch_sync(conn, link_data)

    @staticmethod
    async def query_all_record_by_scheme_time_property(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        query_time: str,
        type: str,
        property: str,
    ) -> list:
        """
        Query one property of every node or link at a given time.

        Args:
            conn: Database connection.
            scheme_type: Scheme type.
            scheme_name: Scheme name.
            query_time: Time to query (ISO string or datetime); matched with a
                +/- 1 second tolerance.
            type: "node" or "link" (case-insensitive).
            property: Column to return; must be valid for the chosen table.

        Returns:
            List of ``(id, property)`` rows ordered by time then id. The
            previous implementation returned a single scalar from one
            arbitrary row, contradicting this documented list contract.

        Raises:
            ValueError: for an unknown ``type`` or ``property``.
        """
        target_time = SchemeRepository._to_utc8(query_time)
        # Stored timestamps may not match the request exactly; allow +/- 1 s.
        start_time = target_time - timedelta(seconds=1)
        end_time = target_time + timedelta(seconds=1)

        kind = type.lower()
        if kind == "node":
            table, valid_fields = "node_simulation", SchemeRepository._NODE_FIELDS
        elif kind == "link":
            table, valid_fields = "link_simulation", SchemeRepository._LINK_FIELDS
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")

        SchemeRepository._check_field(property, valid_fields)
        return await SchemeRepository._fetch_field_records(
            conn, table, scheme_type, scheme_name, start_time, end_time, property
        )

    @staticmethod
    async def query_scheme_simulation_result_by_ID_time(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        ID: str,
        type: str,
        query_time: str,
    ) -> list[dict]:
        """
        Query the full simulation rows of one node or link at a given time.

        Args:
            conn: Database connection.
            scheme_type: Scheme type.
            scheme_name: Scheme name.
            ID: The ID of the node or link.
            type: "node" or "link" (case-insensitive).
            query_time: Time to query (ISO string or datetime); matched with a
                +/- 1 second tolerance.

        Returns:
            List of matching rows (all columns).

        Raises:
            ValueError: for an unknown ``type``.
        """
        target_time = SchemeRepository._to_utc8(query_time)
        # Stored timestamps may not match the request exactly; allow +/- 1 s.
        start_time = target_time - timedelta(seconds=1)
        end_time = target_time + timedelta(seconds=1)

        kind = type.lower()
        if kind == "node":
            return await SchemeRepository.get_node_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, ID
            )
        if kind == "link":
            return await SchemeRepository.get_link_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, ID
            )
        raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")