from typing import List, Any, Optional
from datetime import datetime

from psycopg import AsyncConnection, sql


# Whitelists of result columns that may be interpolated into SQL as
# identifiers. Anything outside these sets is rejected before a query is
# built, which is what keeps the ``sql.Identifier(field)`` calls safe.
_LINK_FIELDS = frozenset(
    {
        "flow",
        "friction",
        "headloss",
        "quality",
        "reaction",
        "setting",
        "status",
        "velocity",
    }
)
_NODE_FIELDS = frozenset({"actual_demand", "total_head", "pressure", "quality"})


def _validate_field(field: str, allowed: frozenset) -> None:
    """Raise ``ValueError`` unless *field* is a member of *allowed*."""
    if field not in allowed:
        raise ValueError(f"Invalid field: {field}")


class SchemeRepository:
    """Static data-access helpers for scheme simulation result tables.

    Covers ``scheme.link_simulation`` and ``scheme.node_simulation``. Every
    method takes a caller-supplied ``AsyncConnection`` and none of them
    commits, so transaction control stays with the caller.

    NOTE(review): the ``List[dict]`` return annotations and the
    ``row[field]`` lookups assume the connection was opened with a dict row
    factory (e.g. ``psycopg.rows.dict_row``) — confirm against the
    connection setup.
    """

    # --- Link Simulation ---
    @staticmethod
    async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for scheme.link_simulation using COPY for performance.

        Each item must provide ``time``, ``scheme`` and ``id`` (a missing key
        raises ``KeyError``). The other result columns are optional and are
        written as NULL when absent. An empty *data* is a no-op.
        """
        if not data:
            return
        async with conn.cursor() as cur:
            async with cur.copy(
                "COPY scheme.link_simulation (time, scheme, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
            ) as copy:
                for item in data:
                    await copy.write_row(
                        (
                            item["time"],
                            item["scheme"],
                            item["id"],
                            item.get("flow"),
                            item.get("friction"),
                            item.get("headloss"),
                            item.get("quality"),
                            item.get("reaction"),
                            item.get("setting"),
                            item.get("status"),
                            item.get("velocity"),
                        )
                    )

    @staticmethod
    async def get_link_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
    ) -> List[dict]:
        """Return all rows for one link of *scheme* within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme, start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_scheme_and_time_range(
        conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Return all link rows of *scheme* within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
                (scheme, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> Any:
        """Return one column value for a link within [start_time, end_time].

        Only a single value is returned: that of the earliest row in the
        range. Returns ``None`` when no row matches.

        Raises:
            ValueError: if *field* is not a known link result column.
        """
        # Validate field name to prevent SQL injection
        _validate_field(field, _LINK_FIELDS)
        # Only one value is used, so fetch exactly one row from the server.
        # ORDER BY makes that row deterministic (the earliest in the range).
        query = sql.SQL(
            "SELECT {} FROM scheme.link_simulation"
            " WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s"
            " ORDER BY time LIMIT 1"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme, start_time, end_time, link_id))
            row = await cur.fetchone()
            return row[field] if row else None

    @staticmethod
    async def get_links_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> Any:
        """Return one column value from the links of *scheme* in a time range.

        Only a single value is returned: that of the first row when ordered
        by ``(time, id)``. Returns ``None`` when no row matches.

        NOTE(review): a single scalar for "all links over a range" is an
        unusual contract — confirm callers do not actually expect a series.

        Raises:
            ValueError: if *field* is not a known link result column.
        """
        # Validate field name to prevent SQL injection
        _validate_field(field, _LINK_FIELDS)
        # Deterministic single-row fetch; see get_link_field_by_scheme_and_time_range.
        query = sql.SQL(
            "SELECT {} FROM scheme.link_simulation"
            " WHERE scheme = %s AND time >= %s AND time <= %s"
            " ORDER BY time, id LIMIT 1"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme, start_time, end_time))
            row = await cur.fetchone()
            return row[field] if row else None

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        scheme: str,
        link_id: str,
        field: str,
        value: Any,
    ):
        """Set *field* to *value* on the link row keyed by (time, scheme, id).

        If no row matches, nothing is updated and no error is raised.

        Raises:
            ValueError: if *field* is not a known link result column.
        """
        _validate_field(field, _LINK_FIELDS)
        query = sql.SQL(
            "UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme, link_id))

    @staticmethod
    async def delete_links_by_scheme_and_time_range(
        conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
    ):
        """Delete all link rows of *scheme* within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
                (scheme, start_time, end_time),
            )

    # --- Node Simulation ---
    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for scheme.node_simulation using COPY.

        Each item must provide ``time``, ``scheme`` and ``id`` (a missing key
        raises ``KeyError``). The other result columns are optional and are
        written as NULL when absent. An empty *data* is a no-op.
        """
        if not data:
            return
        async with conn.cursor() as cur:
            async with cur.copy(
                "COPY scheme.node_simulation (time, scheme, id, actual_demand, total_head, pressure, quality) FROM STDIN"
            ) as copy:
                for item in data:
                    await copy.write_row(
                        (
                            item["time"],
                            item["scheme"],
                            item["id"],
                            item.get("actual_demand"),
                            item.get("total_head"),
                            item.get("pressure"),
                            item.get("quality"),
                        )
                    )

    @staticmethod
    async def get_node_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
    ) -> List[dict]:
        """Return all rows for one node of *scheme* within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme, start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_scheme_and_time_range(
        conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Return all node rows of *scheme* within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
                (scheme, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> Any:
        """Return one column value for a node within [start_time, end_time].

        Only a single value is returned: that of the earliest row in the
        range. Returns ``None`` when no row matches.

        Raises:
            ValueError: if *field* is not a known node result column.
        """
        # Validate field name to prevent SQL injection
        _validate_field(field, _NODE_FIELDS)
        # Only one value is used, so fetch exactly one row from the server.
        # ORDER BY makes that row deterministic (the earliest in the range).
        query = sql.SQL(
            "SELECT {} FROM scheme.node_simulation"
            " WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s"
            " ORDER BY time LIMIT 1"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme, start_time, end_time, node_id))
            row = await cur.fetchone()
            return row[field] if row else None

    @staticmethod
    async def get_nodes_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> Any:
        """Return one column value from the nodes of *scheme* in a time range.

        Only a single value is returned: that of the first row when ordered
        by ``(time, id)``. Returns ``None`` when no row matches.

        NOTE(review): a single scalar for "all nodes over a range" is an
        unusual contract — confirm callers do not actually expect a series.

        Raises:
            ValueError: if *field* is not a known node result column.
        """
        # Validate field name to prevent SQL injection
        _validate_field(field, _NODE_FIELDS)
        # Deterministic single-row fetch; see get_node_field_by_scheme_and_time_range.
        query = sql.SQL(
            "SELECT {} FROM scheme.node_simulation"
            " WHERE scheme = %s AND time >= %s AND time <= %s"
            " ORDER BY time, id LIMIT 1"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme, start_time, end_time))
            row = await cur.fetchone()
            return row[field] if row else None

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        scheme: str,
        node_id: str,
        field: str,
        value: Any,
    ):
        """Set *field* to *value* on the node row keyed by (time, scheme, id).

        If no row matches, nothing is updated and no error is raised.

        Raises:
            ValueError: if *field* is not a known node result column.
        """
        _validate_field(field, _NODE_FIELDS)
        query = sql.SQL(
            "UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme, node_id))

    @staticmethod
    async def delete_nodes_by_scheme_and_time_range(
        conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
    ):
        """Delete all node rows of *scheme* within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
                (scheme, start_time, end_time),
            )