416 lines
14 KiB
Python
416 lines
14 KiB
Python
from typing import List, Any, Dict
|
|
from datetime import datetime, timedelta
|
|
from psycopg import AsyncConnection, sql
|
|
|
|
|
|
class RealtimeRepository:
    """Data-access helpers for ``realtime.link_simulation`` and
    ``realtime.node_simulation``.

    Every method is static and receives an open ``AsyncConnection``. No method
    here issues a commit, so transaction control stays with the caller.

    NOTE(review): several methods index rows by column name (``row[field]``)
    and annotate results as ``List[dict]``, so the connection is presumably
    created with a dict row factory (e.g. ``psycopg.rows.dict_row``) — confirm.
    """

    # Value columns of realtime.link_simulation, in the COPY column order used
    # by insert_links_batch (after ``time`` and ``id``). Also serves as the
    # whitelist for dynamically selected/updated column names.
    _LINK_FIELDS = (
        "flow",
        "friction",
        "headloss",
        "quality",
        "reaction",
        "setting",
        "status",
        "velocity",
    )

    # Value columns of realtime.node_simulation, in the COPY column order used
    # by insert_nodes_batch (after ``time`` and ``id``); also the whitelist.
    _NODE_FIELDS = ("actual_demand", "total_head", "pressure", "quality")

    # --- Internal helpers ---

    @staticmethod
    def _check_field(field: str, allowed: tuple) -> None:
        """Raise ``ValueError`` unless *field* is one of the *allowed* columns.

        Column names cannot be bound as query parameters, so they are checked
        against a fixed whitelist before being spliced in with
        ``sql.Identifier`` (prevents SQL injection via *field*).
        """
        if field not in allowed:
            raise ValueError(f"Invalid field: {field}")

    @staticmethod
    def _to_datetime(value: Any) -> Any:
        """Parse *value* if it is an ISO-8601 string; otherwise return it as-is.

        A ``Z`` is rewritten to ``+00:00`` first because
        ``datetime.fromisoformat`` only accepts the ``Z`` suffix from
        Python 3.11 onwards.
        """
        if isinstance(value, str):
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        return value

    @staticmethod
    def _time_window(query_time: Any, seconds: float = 1) -> tuple:
        """Return ``(start, end)``: *query_time* minus/plus *seconds*.

        *query_time* may be an ISO-8601 string or a datetime (see
        ``_to_datetime``).
        """
        center = RealtimeRepository._to_datetime(query_time)
        delta = timedelta(seconds=seconds)
        return center - delta, center + delta

    # --- Link Simulation ---

    @staticmethod
    async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for realtime.link_simulation using COPY for performance.

        Each item must contain ``"time"`` and ``"id"`` (``KeyError`` otherwise);
        value columns missing from an item are written as NULL. An empty
        *data* is a no-op.
        """
        if not data:
            return

        fields = RealtimeRepository._LINK_FIELDS
        async with conn.cursor() as cur:
            async with cur.copy(
                "COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
            ) as copy:
                for item in data:
                    # Row order must match the COPY column list above, which
                    # is "time, id" followed by _LINK_FIELDS.
                    await copy.write_row(
                        (item["time"], item["id"], *(item.get(f) for f in fields))
                    )

    @staticmethod
    async def get_link_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str
    ) -> List[dict]:
        """Return all rows for *link_id* with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s",
                (start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Return all link rows with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> Any:
        """Return a single *field* value for *link_id* in the time range.

        The earliest row with ``start_time <= time <= end_time`` is used.
        Returns ``None`` when no row matches.

        Raises:
            ValueError: if *field* is not a link_simulation value column.
        """
        RealtimeRepository._check_field(field, RealtimeRepository._LINK_FIELDS)

        # ORDER BY makes the chosen row deterministic (fetchone over an
        # unordered result would pick an arbitrary row).
        query = sql.SQL(
            "SELECT {} FROM realtime.link_simulation"
            " WHERE time >= %s AND time <= %s AND id = %s"
            " ORDER BY time LIMIT 1"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, link_id))
            row = await cur.fetchone()
        return row[field] if row else None

    @staticmethod
    async def get_links_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> List[dict]:
        """Return *field* for every link row in the time range.

        Each returned record holds the ``time``, ``id`` and *field* columns of
        one row with ``start_time <= time <= end_time``. The signature mirrors
        ``get_nodes_field_by_time_range``, which is how
        ``query_all_record_by_time_property`` calls it.

        Raises:
            ValueError: if *field* is not a link_simulation value column.
        """
        RealtimeRepository._check_field(field, RealtimeRepository._LINK_FIELDS)

        query = sql.SQL(
            "SELECT time, id, {} FROM realtime.link_simulation"
            " WHERE time >= %s AND time <= %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time))
            return await cur.fetchall()

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        link_id: str,
        field: str,
        value: Any,
    ):
        """Set *field* to *value* on the link row identified by (*time*, *link_id*).

        Raises:
            ValueError: if *field* is not a link_simulation value column.
        """
        RealtimeRepository._check_field(field, RealtimeRepository._LINK_FIELDS)

        query = sql.SQL(
            "UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, link_id))

    @staticmethod
    async def delete_links_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ):
        """Delete every link row with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )

    # --- Node Simulation ---

    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for realtime.node_simulation using COPY for performance.

        Each item must contain ``"time"`` and ``"id"`` (``KeyError`` otherwise);
        value columns missing from an item are written as NULL. An empty
        *data* is a no-op.
        """
        if not data:
            return

        fields = RealtimeRepository._NODE_FIELDS
        async with conn.cursor() as cur:
            async with cur.copy(
                "COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
            ) as copy:
                for item in data:
                    # Row order must match the COPY column list above, which
                    # is "time, id" followed by _NODE_FIELDS.
                    await copy.write_row(
                        (item["time"], item["id"], *(item.get(f) for f in fields))
                    )

    @staticmethod
    async def get_node_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str
    ) -> List[dict]:
        """Return all rows for *node_id* with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s",
                (start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Return all node rows with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> Any:
        """Return a single *field* value for *node_id* in the time range.

        The earliest row with ``start_time <= time <= end_time`` is used.
        Returns ``None`` when no row matches.

        Raises:
            ValueError: if *field* is not a node_simulation value column.
        """
        RealtimeRepository._check_field(field, RealtimeRepository._NODE_FIELDS)

        # ORDER BY makes the chosen row deterministic.
        query = sql.SQL(
            "SELECT {} FROM realtime.node_simulation"
            " WHERE time >= %s AND time <= %s AND id = %s"
            " ORDER BY time LIMIT 1"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, node_id))
            row = await cur.fetchone()
        return row[field] if row else None

    @staticmethod
    async def get_nodes_field_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str
    ) -> List[dict]:
        """Return *field* for every node row in the time range.

        Each returned record holds the ``time``, ``id`` and *field* columns of
        one row with ``start_time <= time <= end_time``.

        Raises:
            ValueError: if *field* is not a node_simulation value column.
        """
        RealtimeRepository._check_field(field, RealtimeRepository._NODE_FIELDS)

        query = sql.SQL(
            "SELECT time, id, {} FROM realtime.node_simulation"
            " WHERE time >= %s AND time <= %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time))
            return await cur.fetchall()

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        node_id: str,
        field: str,
        value: Any,
    ):
        """Set *field* to *value* on the node row identified by (*time*, *node_id*).

        Raises:
            ValueError: if *field* is not a node_simulation value column.
        """
        RealtimeRepository._check_field(field, RealtimeRepository._NODE_FIELDS)

        query = sql.SQL(
            "UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, node_id))

    @staticmethod
    async def delete_nodes_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ):
        """Delete every node row with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )

    # --- Composite queries ---

    @staticmethod
    async def store_realtime_simulation_result(
        conn: AsyncConnection,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
    ):
        """
        Store realtime simulation results to TimescaleDB.

        Every node and link result is stamped with the same timestamp and
        written via the COPY-based batch inserts.

        Args:
            conn: Database connection
            node_result_list: List of node simulation results
            link_result_list: List of link simulation results
            result_start_time: Start time for the results (ISO format string;
                a datetime is also accepted and used unchanged)
        """
        simulation_time = RealtimeRepository._to_datetime(result_start_time)

        node_data = [
            {
                "time": simulation_time,
                "id": result.get("id"),
                **{f: result.get(f) for f in RealtimeRepository._NODE_FIELDS},
            }
            for result in node_result_list
        ]
        link_data = [
            {
                "time": simulation_time,
                "id": result.get("id"),
                **{f: result.get(f) for f in RealtimeRepository._LINK_FIELDS},
            }
            for result in link_result_list
        ]

        # Both batch inserts return immediately on an empty list.
        await RealtimeRepository.insert_nodes_batch(conn, node_data)
        await RealtimeRepository.insert_links_batch(conn, link_data)

    @staticmethod
    async def query_all_record_by_time_property(
        conn: AsyncConnection,
        query_time: str,
        type: str,
        property: str,
    ) -> list:
        """
        Query all records by time and property from TimescaleDB.

        Rows within ±1 second (inclusive) of *query_time* are returned.

        Args:
            conn: Database connection
            query_time: Time to query (ISO format string or datetime)
            type: Type of data ("node" or "link", case-insensitive)
            property: Property/field to query

        Returns:
            List of records (``time``, ``id`` and *property* columns) matching
            the criteria

        Raises:
            ValueError: on an unknown *type* or *property*.
        """
        start_time, end_time = RealtimeRepository._time_window(query_time)

        kind = type.lower()
        if kind == "node":
            return await RealtimeRepository.get_nodes_field_by_time_range(
                conn, start_time, end_time, property
            )
        if kind == "link":
            return await RealtimeRepository.get_links_field_by_time_range(
                conn, start_time, end_time, property
            )
        raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")

    @staticmethod
    async def query_simulation_result_by_ID_time(
        conn: AsyncConnection,
        ID: str,
        type: str,
        query_time: str,
    ) -> list[dict]:
        """
        Query simulation results by ID and time from TimescaleDB.

        Rows for *ID* within ±1 second (inclusive) of *query_time* are returned.

        Args:
            conn: Database connection
            ID: The ID of the node or link
            type: Type of data ("node" or "link", case-insensitive)
            query_time: Time to query (ISO format string or datetime)

        Returns:
            List of records matching the criteria

        Raises:
            ValueError: on an unknown *type*.
        """
        start_time, end_time = RealtimeRepository._time_window(query_time)

        kind = type.lower()
        if kind == "node":
            return await RealtimeRepository.get_node_by_time_range(
                conn, start_time, end_time, ID
            )
        if kind == "link":
            return await RealtimeRepository.get_link_by_time_range(
                conn, start_time, end_time, ID
            )
        raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")