# Source: TJWaterServerBinary/app/infra/db/timescaledb/repositories/realtime.py
# (file-viewer header: 592 lines, 21 KiB, Python)

from typing import List, Any, Dict
from datetime import datetime, timedelta
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
from app.services.time_api import parse_utc_time
class RealtimeRepository:
    """Data-access helpers for the realtime simulation tables.

    Covers ``realtime.link_simulation`` and ``realtime.node_simulation``.
    Every method is a ``@staticmethod`` operating on a psycopg connection
    supplied by the caller.

    NOTE(review): the ``*_field_*`` readers index rows by column name
    (``row["time"]``), so they presumably expect a mapping-style row factory
    (e.g. psycopg's ``dict_row``) on the connection -- confirm against the
    code that creates the connection.
    NOTE(review): only ``get_links_by_time_range`` and
    ``get_nodes_by_time_range`` pass their bounds through ``parse_utc_time``;
    the other range methods use the bounds as given -- confirm this is
    intentional.
    """

    # Value columns of each table, in the order the COPY statements list them.
    _LINK_VALUE_COLUMNS = (
        "flow",
        "friction",
        "headloss",
        "quality",
        "reaction",
        "setting",
        "status",
        "velocity",
    )
    _NODE_VALUE_COLUMNS = ("actual_demand", "total_head", "pressure", "quality")

    # Whitelists checked before a caller-supplied column name reaches SQL.
    _LINK_FIELDS = frozenset(_LINK_VALUE_COLUMNS)
    _NODE_FIELDS = frozenset(_NODE_VALUE_COLUMNS)

    _LINK_DELETE_AT_SQL = "DELETE FROM realtime.link_simulation WHERE time = %s"
    _LINK_COPY_SQL = (
        "COPY realtime.link_simulation "
        "(time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) "
        "FROM STDIN"
    )
    _NODE_DELETE_AT_SQL = "DELETE FROM realtime.node_simulation WHERE time = %s"
    _NODE_COPY_SQL = (
        "COPY realtime.node_simulation "
        "(time, id, actual_demand, total_head, pressure, quality) "
        "FROM STDIN"
    )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _require_field(field: str, allowed: frozenset) -> None:
        """Raise ``ValueError`` unless *field* is a whitelisted column name."""
        if field not in allowed:
            raise ValueError(f"Invalid field: {field}")

    @staticmethod
    def _batch_times(data: List[dict]) -> list:
        """Return the distinct ``"time"`` values of *data*, first-seen order."""
        return list(dict.fromkeys(item["time"] for item in data))

    @staticmethod
    def _copy_row(item: dict, value_columns: tuple) -> tuple:
        """Build one COPY row: required ``time``/``id``, then optional values.

        Missing value columns become ``None``; a missing ``time`` or ``id``
        raises ``KeyError``.
        """
        return (
            item["time"],
            item["id"],
            *(item.get(column) for column in value_columns),
        )

    @staticmethod
    async def _replace_batch(
        conn: AsyncConnection,
        delete_sql: str,
        copy_sql: str,
        value_columns: tuple,
        data: List[dict],
    ) -> None:
        """Delete the rows at the batch's timestamps, then COPY *data* in.

        Both steps run inside one ``conn.transaction()`` block. Every distinct
        timestamp found in *data* is cleared, so a batch that mixes timestamps
        cannot leave old rows next to the new ones.
        """
        if not data:
            return
        # Resolved up front so a malformed item fails before any SQL is sent.
        times = RealtimeRepository._batch_times(data)
        async with conn.transaction():
            async with conn.cursor() as cur:
                # 1. Remove the existing rows at each timestamp being rewritten.
                for stamp in times:
                    await cur.execute(delete_sql, (stamp,))
                # 2. Bulk-load the new rows with COPY.
                async with cur.copy(copy_sql) as copy:
                    for item in data:
                        await copy.write_row(
                            RealtimeRepository._copy_row(item, value_columns)
                        )

    @staticmethod
    def _replace_batch_sync(
        conn: Connection,
        delete_sql: str,
        copy_sql: str,
        value_columns: tuple,
        data: List[dict],
    ) -> None:
        """Synchronous twin of ``_replace_batch``."""
        if not data:
            return
        # Resolved up front so a malformed item fails before any SQL is sent.
        times = RealtimeRepository._batch_times(data)
        with conn.transaction():
            with conn.cursor() as cur:
                # 1. Remove the existing rows at each timestamp being rewritten.
                for stamp in times:
                    cur.execute(delete_sql, (stamp,))
                # 2. Bulk-load the new rows with COPY.
                with cur.copy(copy_sql) as copy:
                    for item in data:
                        copy.write_row(
                            RealtimeRepository._copy_row(item, value_columns)
                        )

    @staticmethod
    def _series(rows: Any, field: str) -> List[Dict[str, Any]]:
        """Turn rows into ``{"time": <ISO string>, "value": row[field]}`` dicts."""
        return [
            {"time": row["time"].isoformat(), "value": row[field]} for row in rows
        ]

    @staticmethod
    def _series_by_id(rows: Any, field: str) -> dict:
        """Group ``{"time", "value"}`` points by the rows' ``id`` column."""
        grouped = defaultdict(list)
        for row in rows:
            grouped[row["id"]].append(
                {"time": row["time"].isoformat(), "value": row[field]}
            )
        return dict(grouped)

    @staticmethod
    def _first_period(entry: Dict[str, Any]) -> Any:
        """Return the first item of ``entry["result"]``, or ``None`` if absent/empty."""
        periods = entry.get("result") or []
        return periods[0] if periods else None

    @staticmethod
    def _node_rows(node_result_list: Any, simulation_time: Any) -> List[dict]:
        """Map raw node results to ``realtime.node_simulation`` row dicts.

        Entries without a result period are skipped.
        """
        rows = []
        for entry in node_result_list or []:
            period = RealtimeRepository._first_period(entry)
            if period is None:
                continue
            rows.append(
                {
                    "time": simulation_time,
                    "id": entry.get("node"),
                    # Result keys differ from column names for these two.
                    "actual_demand": period.get("demand"),
                    "total_head": period.get("head"),
                    "pressure": period.get("pressure"),
                    "quality": period.get("quality"),
                }
            )
        return rows

    @staticmethod
    def _link_rows(link_result_list: Any, simulation_time: Any) -> List[dict]:
        """Map raw link results to ``realtime.link_simulation`` row dicts.

        Entries without a result period are skipped.
        """
        rows = []
        for entry in link_result_list or []:
            period = RealtimeRepository._first_period(entry)
            if period is None:
                continue
            row = {"time": simulation_time, "id": entry.get("link")}
            # Link result keys match the column names one-to-one.
            for column in RealtimeRepository._LINK_VALUE_COLUMNS:
                row[column] = period.get(column)
            rows.append(row)
        return rows

    @staticmethod
    def _one_second_window(query_time: str) -> tuple:
        """Return ``(t - 1s, t + 1s)`` for ``t = parse_utc_time(query_time)``."""
        target_time = parse_utc_time(query_time, field_name="query_time")
        margin = timedelta(seconds=1)
        return target_time - margin, target_time + margin

    # --- Link Simulation ---
    @staticmethod
    async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
        """Replace ``realtime.link_simulation`` rows using DELETE then COPY.

        Args:
            conn: Async psycopg connection.
            data: Row dicts. ``"time"`` and ``"id"`` are required; the value
                columns are optional and default to ``None``.

        Rows already stored at any timestamp present in ``data`` are deleted
        first, in the same transaction as the COPY. An empty ``data`` is a
        no-op.
        """
        await RealtimeRepository._replace_batch(
            conn,
            RealtimeRepository._LINK_DELETE_AT_SQL,
            RealtimeRepository._LINK_COPY_SQL,
            RealtimeRepository._LINK_VALUE_COLUMNS,
            data,
        )

    @staticmethod
    def insert_links_batch_sync(conn: Connection, data: List[dict]):
        """Synchronous version of ``insert_links_batch`` (DELETE then COPY).

        Args:
            conn: Sync psycopg connection.
            data: Row dicts. ``"time"`` and ``"id"`` are required; the value
                columns are optional and default to ``None``.
        """
        RealtimeRepository._replace_batch_sync(
            conn,
            RealtimeRepository._LINK_DELETE_AT_SQL,
            RealtimeRepository._LINK_COPY_SQL,
            RealtimeRepository._LINK_VALUE_COLUMNS,
            data,
        )

    @staticmethod
    async def get_link_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str
    ) -> List[dict]:
        """Fetch all columns of one link for ``start_time <= time <= end_time``.

        Returns whatever ``cur.fetchall()`` yields for the connection's row
        factory.
        """
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s",
                (start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Fetch all columns of every link for ``start_time <= time <= end_time``.

        Both bounds are passed through ``parse_utc_time`` before querying.
        """
        lower = parse_utc_time(start_time, field_name="start_time")
        upper = parse_utc_time(end_time, field_name="end_time")
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
                (lower, upper),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return one link's ``field`` values as ``{"time", "value"}`` points.

        Raises:
            ValueError: If ``field`` is not a link value column.
        """
        # Whitelist the column name before it is interpolated into the query.
        RealtimeRepository._require_field(field, RealtimeRepository._LINK_FIELDS)
        query = sql.SQL(
            "SELECT time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, link_id))
            rows = await cur.fetchall()
        return RealtimeRepository._series(rows, field)

    @staticmethod
    async def get_links_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return ``field`` for every link as ``{id: [{"time", "value"}, ...]}``.

        Raises:
            ValueError: If ``field`` is not a link value column.
        """
        # Whitelist the column name before it is interpolated into the query.
        RealtimeRepository._require_field(field, RealtimeRepository._LINK_FIELDS)
        query = sql.SQL(
            "SELECT id, time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time))
            rows = await cur.fetchall()
        return RealtimeRepository._series_by_id(rows, field)

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        link_id: str,
        field: str,
        value: Any,
    ):
        """Set ``field`` to ``value`` on the link row matching ``time`` and ``link_id``.

        Raises:
            ValueError: If ``field`` is not a link value column.
        """
        RealtimeRepository._require_field(field, RealtimeRepository._LINK_FIELDS)
        query = sql.SQL(
            "UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, link_id))

    @staticmethod
    async def delete_links_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ):
        """Delete every link row with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )

    # --- Node Simulation ---
    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        """Replace ``realtime.node_simulation`` rows using DELETE then COPY.

        Args:
            conn: Async psycopg connection.
            data: Row dicts. ``"time"`` and ``"id"`` are required; the value
                columns are optional and default to ``None``.

        Rows already stored at any timestamp present in ``data`` are deleted
        first, in the same transaction as the COPY. An empty ``data`` is a
        no-op.
        """
        await RealtimeRepository._replace_batch(
            conn,
            RealtimeRepository._NODE_DELETE_AT_SQL,
            RealtimeRepository._NODE_COPY_SQL,
            RealtimeRepository._NODE_VALUE_COLUMNS,
            data,
        )

    @staticmethod
    def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
        """Synchronous version of ``insert_nodes_batch`` (DELETE then COPY).

        Args:
            conn: Sync psycopg connection.
            data: Row dicts. ``"time"`` and ``"id"`` are required; the value
                columns are optional and default to ``None``.
        """
        RealtimeRepository._replace_batch_sync(
            conn,
            RealtimeRepository._NODE_DELETE_AT_SQL,
            RealtimeRepository._NODE_COPY_SQL,
            RealtimeRepository._NODE_VALUE_COLUMNS,
            data,
        )

    @staticmethod
    async def get_node_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str
    ) -> List[dict]:
        """Fetch all columns of one node for ``start_time <= time <= end_time``.

        Returns whatever ``cur.fetchall()`` yields for the connection's row
        factory.
        """
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s",
                (start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Fetch all columns of every node for ``start_time <= time <= end_time``.

        Both bounds are passed through ``parse_utc_time`` before querying.
        """
        lower = parse_utc_time(start_time, field_name="start_time")
        upper = parse_utc_time(end_time, field_name="end_time")
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
                (lower, upper),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return one node's ``field`` values as ``{"time", "value"}`` points.

        Raises:
            ValueError: If ``field`` is not a node value column.
        """
        # Whitelist the column name before it is interpolated into the query.
        RealtimeRepository._require_field(field, RealtimeRepository._NODE_FIELDS)
        query = sql.SQL(
            "SELECT time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, node_id))
            rows = await cur.fetchall()
        return RealtimeRepository._series(rows, field)

    @staticmethod
    async def get_nodes_field_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str
    ) -> dict:
        """Return ``field`` for every node as ``{id: [{"time", "value"}, ...]}``.

        Raises:
            ValueError: If ``field`` is not a node value column.
        """
        # Whitelist the column name before it is interpolated into the query.
        RealtimeRepository._require_field(field, RealtimeRepository._NODE_FIELDS)
        query = sql.SQL(
            "SELECT id, time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time))
            rows = await cur.fetchall()
        return RealtimeRepository._series_by_id(rows, field)

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        node_id: str,
        field: str,
        value: Any,
    ):
        """Set ``field`` to ``value`` on the node row matching ``time`` and ``node_id``.

        Raises:
            ValueError: If ``field`` is not a node value column.
        """
        RealtimeRepository._require_field(field, RealtimeRepository._NODE_FIELDS)
        query = sql.SQL(
            "UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, node_id))

    @staticmethod
    async def delete_nodes_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ):
        """Delete every node row with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )

    # --- Composite queries ---
    @staticmethod
    async def store_realtime_simulation_result(
        conn: AsyncConnection,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
    ):
        """Store realtime simulation results to TimescaleDB.

        Only the first item of each entry's ``"result"`` list is stored (the
        realtime simulation has a single period); entries with no result are
        skipped. Node and link rows are written inside one transaction so a
        failure while writing links does not leave the node rows behind.

        Args:
            conn: Database connection
            node_result_list: List of node simulation results
            link_result_list: List of link simulation results
            result_start_time: Start time for the results (ISO format string)
        """
        simulation_time = parse_utc_time(
            result_start_time, field_name="result_start_time"
        )
        node_data = RealtimeRepository._node_rows(node_result_list, simulation_time)
        link_data = RealtimeRepository._link_rows(link_result_list, simulation_time)
        if not node_data and not link_data:
            return
        # The batch inserts open their own transaction blocks; nested inside
        # this one they become savepoints, so both tables commit together.
        async with conn.transaction():
            if node_data:
                await RealtimeRepository.insert_nodes_batch(conn, node_data)
            if link_data:
                await RealtimeRepository.insert_links_batch(conn, link_data)

    @staticmethod
    def store_realtime_simulation_result_sync(
        conn: Connection,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
    ):
        """Store realtime simulation results to TimescaleDB (sync version).

        Only the first item of each entry's ``"result"`` list is stored (the
        realtime simulation has a single period); entries with no result are
        skipped. Node and link rows are written inside one transaction so a
        failure while writing links does not leave the node rows behind.

        Args:
            conn: Database connection
            node_result_list: List of node simulation results
            link_result_list: List of link simulation results
            result_start_time: Start time for the results (ISO format string)
        """
        simulation_time = parse_utc_time(
            result_start_time, field_name="result_start_time"
        )
        node_data = RealtimeRepository._node_rows(node_result_list, simulation_time)
        link_data = RealtimeRepository._link_rows(link_result_list, simulation_time)
        if not node_data and not link_data:
            return
        # The batch inserts open their own transaction blocks; nested inside
        # this one they become savepoints, so both tables commit together.
        with conn.transaction():
            if node_data:
                RealtimeRepository.insert_nodes_batch_sync(conn, node_data)
            if link_data:
                RealtimeRepository.insert_links_batch_sync(conn, link_data)

    @staticmethod
    async def query_all_record_by_time_property(
        conn: AsyncConnection,
        query_time: str,
        type: str,
        property: str,
    ) -> list:
        """
        Query all records by time and property from TimescaleDB.

        Rows within one second either side of ``query_time`` are returned.

        Args:
            conn: Database connection
            query_time: Time to query (ISO format string)
            type: Type of data ("node" or "link"), case-insensitive
            property: Property/field to query
        Returns:
            List of ``{"ID": <id>, "value": <value>}`` dicts
        Raises:
            ValueError: If ``type`` is neither "node" nor "link", or
                ``property`` is not a valid column for that type.
        """
        start_time, end_time = RealtimeRepository._one_second_window(query_time)
        kind = type.lower()
        if kind == "node":
            data = await RealtimeRepository.get_nodes_field_by_time_range(
                conn, start_time, end_time, property
            )
        elif kind == "link":
            data = await RealtimeRepository.get_links_field_by_time_range(
                conn, start_time, end_time, property
            )
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
        # Flatten {id: [points]} into one record per point; timestamps are dropped.
        return [
            {"ID": element_id, "value": point["value"]}
            for element_id, points in data.items()
            for point in points
        ]

    @staticmethod
    async def query_simulation_result_by_id_time(
        conn: AsyncConnection,
        id: str,
        type: str,
        query_time: str,
    ) -> list[dict]:
        """
        Query simulation results by id and time from TimescaleDB.

        Rows within one second either side of ``query_time`` are returned.

        Args:
            conn: Database connection
            id: The id of the node or link
            type: Type of data ("node" or "link"), case-insensitive
            query_time: Time to query (ISO format string)
        Returns:
            List of records matching the criteria
        Raises:
            ValueError: If ``type`` is neither "node" nor "link".
        """
        start_time, end_time = RealtimeRepository._one_second_window(query_time)
        kind = type.lower()
        if kind == "node":
            return await RealtimeRepository.get_node_by_time_range(
                conn, start_time, end_time, id
            )
        if kind == "link":
            return await RealtimeRepository.get_link_by_time_range(
                conn, start_time, end_time, id
            )
        raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")