新增 scheme 表下的字段 scheme_type scheme_name

This commit is contained in:
JIANG
2025-12-05 18:27:58 +08:00
parent 4231243b96
commit 4fbdea435b
5 changed files with 324 additions and 111 deletions

View File

@@ -1,6 +1,7 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from psycopg import AsyncConnection, sql
import globals
# 定义UTC+8时区
UTC_8 = timezone(timedelta(hours=8))
@@ -18,13 +19,14 @@ class SchemeRepository:
async with conn.cursor() as cur:
async with cur.copy(
"COPY scheme.link_simulation (time, scheme, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
"COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["scheme"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("flow"),
item.get("friction"),
@@ -40,33 +42,39 @@ class SchemeRepository:
@staticmethod
async def get_link_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
link_id: str,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s",
(scheme, start_time, end_time, link_id),
"SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
(scheme_type, scheme_name, start_time, end_time, link_id),
)
return await cur.fetchall()
@staticmethod
async def get_links_by_scheme_and_time_range(
conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time),
"SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_link_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
link_id: str,
@@ -87,18 +95,21 @@ class SchemeRepository:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT {} FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s"
"SELECT {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme, start_time, end_time, link_id))
await cur.execute(
query, (scheme_type, scheme_name, start_time, end_time, link_id)
)
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def get_links_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
field: str,
@@ -118,11 +129,11 @@ class SchemeRepository:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT {} FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s"
"SELECT {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme, start_time, end_time))
await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
row = await cur.fetchone()
return row[field] if row else None
@@ -130,7 +141,8 @@ class SchemeRepository:
async def update_link_field(
conn: AsyncConnection,
time: datetime,
scheme: str,
scheme_type: str,
scheme_name: str,
link_id: str,
field: str,
value: Any,
@@ -149,20 +161,24 @@ class SchemeRepository:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s"
"UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, scheme, link_id))
await cur.execute(query, (value, time, scheme_type, scheme_name, link_id))
@staticmethod
async def delete_links_by_scheme_and_time_range(
conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time),
"DELETE FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
# --- Node Simulation ---
@@ -174,13 +190,14 @@ class SchemeRepository:
async with conn.cursor() as cur:
async with cur.copy(
"COPY scheme.node_simulation (time, scheme, id, actual_demand, total_head, pressure, quality) FROM STDIN"
"COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["scheme"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
@@ -192,33 +209,39 @@ class SchemeRepository:
@staticmethod
async def get_node_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
node_id: str,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s",
(scheme, start_time, end_time, node_id),
"SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
(scheme_type, scheme_name, start_time, end_time, node_id),
)
return await cur.fetchall()
@staticmethod
async def get_nodes_by_scheme_and_time_range(
conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time),
"SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_node_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
node_id: str,
@@ -230,18 +253,21 @@ class SchemeRepository:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT {} FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s"
"SELECT {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme, start_time, end_time, node_id))
await cur.execute(
query, (scheme_type, scheme_name, start_time, end_time, node_id)
)
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def get_nodes_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
field: str,
@@ -252,11 +278,11 @@ class SchemeRepository:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT {} FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s"
"SELECT {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme, start_time, end_time))
await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
row = await cur.fetchone()
return row[field] if row else None
@@ -264,7 +290,8 @@ class SchemeRepository:
async def update_node_field(
conn: AsyncConnection,
time: datetime,
scheme: str,
scheme_type: str,
scheme_name: str,
node_id: str,
field: str,
value: Any,
@@ -274,20 +301,24 @@ class SchemeRepository:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s"
"UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, scheme, node_id))
await cur.execute(query, (value, time, scheme_type, scheme_name, node_id))
@staticmethod
async def delete_nodes_by_scheme_and_time_range(
conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time),
"DELETE FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
# --- 复合查询 ---
@@ -295,17 +326,20 @@ class SchemeRepository:
@staticmethod
async def store_scheme_simulation_result(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
num_periods: int = 1,
):
"""
Store scheme simulation results to TimescaleDB.
Args:
conn: Database connection
scheme: Scheme name
scheme_type: Scheme type
scheme_name: Scheme name
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
@@ -313,9 +347,11 @@ class SchemeRepository:
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串解析并转换为UTC+8
if result_start_time.endswith('Z'):
if result_start_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(result_start_time.replace("Z", "+00:00"))
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
@@ -327,39 +363,56 @@ class SchemeRepository:
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
timestep_parts = globals.hydraulic_timestep.split(":")
timestep = timedelta(
hours=int(timestep_parts[0]),
minutes=int(timestep_parts[1]),
seconds=int(timestep_parts[2]),
)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_data.append(
{
"time": simulation_time,
"scheme": scheme,
"id": node_result.get("id"),
"actual_demand": node_result.get("actual_demand"),
"total_head": node_result.get("total_head"),
"pressure": node_result.get("pressure"),
"quality": node_result.get("quality"),
}
)
node_id = node_result.get("node")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = node_result.get("result", [])[period_index]
node_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_data.append(
{
"time": simulation_time,
"scheme": scheme,
"id": link_result.get("id"),
"flow": link_result.get("flow"),
"friction": link_result.get("friction"),
"headloss": link_result.get("headloss"),
"quality": link_result.get("quality"),
"reaction": link_result.get("reaction"),
"setting": link_result.get("setting"),
"status": link_result.get("status"),
"velocity": link_result.get("velocity"),
}
)
link_id = link_result.get("link")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = link_result.get("result", [])[period_index]
link_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
@@ -371,7 +424,8 @@ class SchemeRepository:
@staticmethod
async def query_all_record_by_scheme_time_property(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
query_time: str,
type: str,
property: str,
@@ -381,7 +435,8 @@ class SchemeRepository:
Args:
conn: Database connection
scheme: Scheme name
scheme_type: Scheme type
scheme_name: Scheme name
query_time: Time to query (ISO format string)
type: Type of data ("node" or "link")
property: Property/field to query
@@ -391,7 +446,7 @@ class SchemeRepository:
"""
# Convert query_time string to datetime
if isinstance(query_time, str):
if query_time.endswith('Z'):
if query_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
target_time = utc_time.astimezone(UTC_8)
@@ -412,11 +467,11 @@ class SchemeRepository:
# Query based on type
if type.lower() == "node":
return await SchemeRepository.get_nodes_field_by_scheme_and_time_range(
conn, scheme, start_time, end_time, property
conn, scheme_type, scheme_name, start_time, end_time, property
)
elif type.lower() == "link":
return await SchemeRepository.get_links_field_by_scheme_and_time_range(
conn, scheme, start_time, end_time, property
conn, scheme_type, scheme_name, start_time, end_time, property
)
else:
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
@@ -424,7 +479,8 @@ class SchemeRepository:
@staticmethod
async def query_scheme_simulation_result_by_ID_time(
conn: AsyncConnection,
scheme: str,
scheme_type: str,
scheme_name: str,
ID: str,
type: str,
query_time: str,
@@ -434,7 +490,8 @@ class SchemeRepository:
Args:
conn: Database connection
scheme: Scheme name
scheme_type: Scheme type
scheme_name: Scheme name
ID: The ID of the node or link
type: Type of data ("node" or "link")
query_time: Time to query (ISO format string)
@@ -444,7 +501,7 @@ class SchemeRepository:
"""
# Convert query_time string to datetime
if isinstance(query_time, str):
if query_time.endswith('Z'):
if query_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
target_time = utc_time.astimezone(UTC_8)
@@ -465,11 +522,11 @@ class SchemeRepository:
# Query based on type
if type.lower() == "node":
return await SchemeRepository.get_node_by_scheme_and_time_range(
conn, scheme, start_time, end_time, ID
conn, scheme_type, scheme_name, start_time, end_time, ID
)
elif type.lower() == "link":
return await SchemeRepository.get_link_by_scheme_and_time_range(
conn, scheme, start_time, end_time, ID
conn, scheme_type, scheme_name, start_time, end_time, ID
)
else:
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")