重构现代化 FastAPI 后端项目框架

This commit is contained in:
2026-01-21 16:50:57 +08:00
parent 9e06e68a15
commit c56f2fd1db
352 changed files with 176 additions and 70 deletions

0
app/infra/db/__init__.py Normal file
View File

View File

4964
app/infra/db/influxdb/api.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
# InfluxDB connection settings.
#
# Every value can be overridden through an environment variable so that a
# deployment never has to edit (or commit) credentials in this file. The
# defaults keep the previous hard-coded behaviour.
import os

# Base URL of the InfluxDB instance.
url = os.environ.get("INFLUXDB_URL", "http://127.0.0.1:8086")

# API token used to authenticate against InfluxDB.
# SECURITY: the fallback below is a credential committed to source control.
# Set INFLUXDB_TOKEN in the environment and rotate this token.
token = os.environ.get(
    "INFLUXDB_TOKEN",
    "kMPX2V5HsbzPpUT2B9HPBu1sTG1Emf-lPlT2UjxYnGAuocpXq_f_0lK4HHs-TbbKyjsZpICkMsyXG_V2D7P7yQ==",
)

# InfluxDB organization name.
org = os.environ.get("INFLUXDB_ORG", "TJWATERORG")

View File

@@ -0,0 +1,33 @@
from influxdb_client import InfluxDBClient, Point, WriteOptions
from influxdb_client.client.query_api import QueryApi
import influxdb_info
# InfluxDB connection settings, taken from the shared influxdb_info module.
url = influxdb_info.url
token = influxdb_info.token
org = influxdb_info.org
bucket = "SCADA_data"

# Create the InfluxDB client.
client = InfluxDBClient(url=url, token=token, org=org)
try:
    # Query API bound to the client above.
    query_api = client.query_api()

    # Flux query: everything in the bucket from the last hour.
    query = f'''
from(bucket: "{bucket}")
|> range(start: -1h)
'''

    # Run the query and dump the raw result.
    result = query_api.query(query)
    print(result)

    # Print one line per record.
    for table in result:
        for record in table.records:
            print(f"Time: {record.get_time()}, Value: {record.get_value()}, Measurement: {record.get_measurement()}, Field: {record.get_field()}")
finally:
    # Always release the client, even when the query or printing fails.
    client.close()

View File

@@ -0,0 +1 @@
from .router import router

View File

@@ -0,0 +1,108 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import app.native.api.postgresql_info as postgresql_info
# Configure logging
logger = logging.getLogger(__name__)
class Database:
    """Owns one psycopg async connection pool for a PostgreSQL database."""

    def __init__(self, db_name=None):
        # The pool is created by init_pool() and opened by open().
        self.pool = None
        self.db_name = db_name

    def init_pool(self, db_name=None):
        """Create (but do not open) the connection pool.

        :param db_name: database to connect to. Falls back to the name given
            to the constructor, then to the default from configuration.
        """
        target_db_name = db_name or self.db_name
        # Only pass db_name when one is set, so the configured default applies
        # otherwise.
        conn_string = (
            postgresql_info.get_pgconn_string(db_name=target_db_name)
            if target_db_name
            else postgresql_info.get_pgconn_string()
        )
        try:
            self.pool = psycopg_pool.AsyncConnectionPool(
                conninfo=conn_string,
                min_size=5,
                max_size=20,
                open=False,  # Don't open immediately, wait for startup
                kwargs={"row_factory": dict_row},  # Return rows as dictionaries
            )
            logger.info(
                "PostgreSQL connection pool initialized for database: %s",
                target_db_name or "default",
            )
        except Exception as e:
            logger.error(f"Failed to initialize postgresql connection pool: {e}")
            raise

    async def open(self):
        """Open the pool; does nothing when init_pool() has not been called."""
        if self.pool:
            await self.pool.open()

    async def close(self):
        """Close the connection pool; does nothing when no pool exists."""
        if self.pool:
            await self.pool.close()
            logger.info("PostgreSQL connection pool closed.")

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator:
        """Yield a connection borrowed from the pool.

        :raises RuntimeError: when the pool has not been initialized.
        """
        if not self.pool:
            raise RuntimeError("Database pool is not initialized.")
        async with self.pool.connection() as conn:
            yield conn
# Default database instance
db = Database()
# Cache of per-database instances, keyed by database name, so that a
# connection pool is not created again for a name that was already requested
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
    """Build an unopened Database wrapper bound to *db_name*."""
    new_instance = Database(db_name=db_name)
    return new_instance
async def get_database_instance(db_name: Optional[str] = None) -> Database:
    """Return the Database for *db_name*, creating and opening it on first use.

    An empty or missing name yields the module-level default instance.
    """
    if not db_name:
        # Fall back to the default database instance.
        return db
    try:
        return _database_instances[db_name]
    except KeyError:
        pass
    # First request for this name: build the pool, open it, then cache it.
    new_instance = create_database_instance(db_name)
    new_instance.init_pool()
    await new_instance.open()
    _database_instances[db_name] = new_instance
    logger.info(f"Created new database instance for: {db_name}")
    return new_instance
async def get_db_connection():
    """FastAPI dependency that yields a connection from the default pool."""
    connection_ctx = db.get_connection()
    async with connection_ctx as pooled_conn:
        yield pooled_conn
async def get_database_connection(db_name: Optional[str] = None):
    """
    FastAPI dependency yielding a connection for an optionally named database.

    Usage in a route function:
        conn: AsyncConnection = Depends(get_database_connection)

    NOTE(review): to bind a fixed db_name, wrap this in another async-generator
    dependency rather than ``Depends(lambda: get_database_connection("name"))``;
    a lambda appears to hand back the generator object itself instead of a
    connection - confirm against FastAPI's dependency documentation.
    """
    database = await get_database_instance(db_name)
    async with database.get_connection() as pooled_conn:
        yield pooled_conn
async def cleanup_database_instances():
    """Close every cached pool and the default pool (call on application shutdown).

    A failure while closing one cached pool is logged and does not prevent the
    remaining pools, or the default pool, from being closed.
    """
    # Detach the cache before awaiting: iterating the live dict across awaits
    # fails if another task registers an instance in the meantime.
    instances = dict(_database_instances)
    _database_instances.clear()
    for db_name, instance in instances.items():
        try:
            await instance.close()
            logger.info(f"Closed database instance for: {db_name}")
        except Exception:
            logger.exception(f"Failed to close database instance for: {db_name}")
    # Close the default database.
    await db.close()
    logger.info("All database instances cleaned up.")

View File

@@ -0,0 +1,83 @@
import time
from typing import List, Optional
from fastapi.logger import logger
import postgresql_info
import psycopg
class InternalQueries:
    """Synchronous helper queries against the PostgreSQL network database."""

    # Columns returned when the caller does not request specific fields.
    _DEFAULT_PIPE_FIELDS = (
        "id",
        "node1",
        "node2",
        "length",
        "diameter",
        "roughness",
        "minor_loss",
        "status",
    )

    @staticmethod
    def _check_identifier(name):
        """Return *name* if it is a plain SQL identifier, else raise ValueError."""
        import re

        if not isinstance(name, str) or not re.fullmatch(
            r"[A-Za-z_][A-Za-z0-9_]*", name
        ):
            raise ValueError(f"Invalid column name: {name!r}")
        return name

    @staticmethod
    def get_links_by_property(
        fields: Optional[List[str]] = None,
        property_conditions: Optional[dict] = None,
        db_name: Optional[str] = None,
        max_retries: int = 3,
    ) -> List[dict]:
        """
        Query ``public.pipes`` for the requested columns, optionally filtered.

        :param fields: column names to select, e.g. ["id", "diameter", "status"];
            a default column set is used when empty or omitted
        :param property_conditions: optional equality filters, e.g.
            {"status": "Open"} or {"diameter": 300}; all of them must match
        :param db_name: database name; the configured default when omitted
        :param max_retries: maximum number of attempts
        :return: one dict per row, keyed by the selected column names
        :raises ValueError: when a column name is not a plain identifier
        """
        if not fields:
            fields = list(InternalQueries._DEFAULT_PIPE_FIELDS)

        # Column names cannot be sent as query parameters, so they are
        # validated before being placed in the statement. This happens outside
        # the retry loop so that bad input fails immediately.
        columns = [InternalQueries._check_identifier(name) for name in fields]
        conditions = property_conditions or {}
        where_columns = [InternalQueries._check_identifier(key) for key in conditions]
        params = list(conditions.values())

        query = f"""
            SELECT {", ".join(columns)}
            FROM public.pipes
        """
        if where_columns:
            query += " WHERE " + " AND ".join(f"{col} = %s" for col in where_columns)

        for attempt in range(max_retries):
            try:
                conn_string = (
                    postgresql_info.get_pgconn_string(db_name=db_name)
                    if db_name
                    else postgresql_info.get_pgconn_string()
                )
                with psycopg.Connection.connect(conn_string) as conn:
                    with conn.cursor() as cur:
                        if where_columns:
                            cur.execute(query, params)
                        else:
                            cur.execute(query)
                        records = cur.fetchall()
                # Rows come back in SELECT order, so pair them with the columns.
                return [dict(zip(columns, record)) for record in records]
            except Exception as e:
                logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)
                else:
                    raise

View File

@@ -0,0 +1,90 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from psycopg import AsyncConnection
from .database import get_database_instance
from .scada_info import ScadaRepository
from .scheme import SchemeRepository
router = APIRouter(prefix="/postgresql", tags=["postgresql"])
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
    db_name: Optional[str] = Query(
        None, description="指定要连接的数据库名称,为空时使用默认数据库"
    )
):
    """Yield a pooled connection; the database is chosen by the ``db_name`` query parameter."""
    database = await get_database_instance(db_name)
    async with database.get_connection() as pooled_conn:
        yield pooled_conn
@router.get("/scada-info")
async def get_scada_info_with_connection(
    conn: AsyncConnection = Depends(get_database_connection),
):
    """
    Return all SCADA records, read through a pooled connection.

    Any failure is reported as HTTP 500 with the error text in ``detail``.
    """
    try:
        rows = await ScadaRepository.get_scadas(conn)
        payload = {"success": True, "data": rows, "count": len(rows)}
        return payload
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"查询SCADA信息时发生错误: {str(e)}"
        )
@router.get("/scheme-list")
async def get_scheme_list_with_connection(
    conn: AsyncConnection = Depends(get_database_connection),
):
    """
    Return all scheme records, read through a pooled connection.

    Any failure is reported as HTTP 500 with the error text in ``detail``.
    """
    try:
        rows = await SchemeRepository.get_schemes(conn)
        payload = {"success": True, "data": rows, "count": len(rows)}
        return payload
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"查询方案信息时发生错误: {str(e)}")
@router.get("/burst-locate-result")
async def get_burst_locate_result_with_connection(
    conn: AsyncConnection = Depends(get_database_connection),
):
    """
    Return all burst-location results, read through a pooled connection.

    Any failure is reported as HTTP 500 with the error text in ``detail``.
    """
    try:
        rows = await SchemeRepository.get_burst_locate_results(conn)
        payload = {"success": True, "data": rows, "count": len(rows)}
        return payload
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"查询爆管定位结果时发生错误: {str(e)}"
        )
@router.get("/burst-locate-result/{burst_incident}")
async def get_burst_locate_result_by_incident(
    burst_incident: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """
    Return the burst-location results recorded for one ``burst_incident``.

    Any failure is reported as HTTP 500 with the error text in ``detail``.
    """
    try:
        matches = await SchemeRepository.get_burst_locate_result_by_incident(
            conn, burst_incident
        )
        return matches
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"根据 burst_incident 查询爆管定位结果时发生错误: {str(e)}",
        )

View File

@@ -0,0 +1,36 @@
from typing import List, Optional, Any
from psycopg import AsyncConnection
class ScadaRepository:
    """Read access to the ``public.scada_info`` table."""

    @staticmethod
    async def get_scadas(conn: AsyncConnection) -> List[dict]:
        """
        Fetch every row of ``public.scada_info``.

        :param conn: async database connection whose rows are addressable by
            column name
        :return: one dict per row, holding the six selected columns
        """
        wanted = (
            "id",
            "type",
            "associated_element_id",
            "transmission_mode",
            "transmission_frequency",
            "reliability",
        )
        async with conn.cursor() as cur:
            await cur.execute(
                """
                SELECT id, type, associated_element_id, transmission_mode, transmission_frequency, reliability
                FROM public.scada_info
                """
            )
            rows = await cur.fetchall()
            # Copy only the selected columns out of each row.
            return [{name: row[name] for name in wanted} for row in rows]

View File

@@ -0,0 +1,104 @@
from typing import List, Optional, Any
from psycopg import AsyncConnection
class SchemeRepository:
    """Read access to the scheme list and burst-location result tables."""

    # Columns copied out of each scheme_list row.
    _SCHEME_COLUMNS = (
        "scheme_id",
        "scheme_name",
        "scheme_type",
        "username",
        "create_time",
        "scheme_start_time",
        "scheme_detail",
    )
    # Columns copied out of each burst_locate_result row.
    _BURST_COLUMNS = (
        "id",
        "type",
        "burst_incident",
        "leakage",
        "detect_time",
        "locate_result",
    )

    @staticmethod
    def _project(rows, columns):
        """Copy only *columns* out of each name-addressable row."""
        return [{name: row[name] for name in columns} for row in rows]

    @staticmethod
    async def get_schemes(conn: AsyncConnection) -> List[dict]:
        """
        Fetch every row of ``public.scheme_list``.

        :param conn: async database connection
        :return: one dict per row
        """
        async with conn.cursor() as cur:
            await cur.execute(
                """
                SELECT scheme_id, scheme_name, scheme_type, username, create_time, scheme_start_time, scheme_detail
                FROM public.scheme_list
                """
            )
            rows = await cur.fetchall()
            return SchemeRepository._project(rows, SchemeRepository._SCHEME_COLUMNS)

    @staticmethod
    async def get_burst_locate_results(conn: AsyncConnection) -> List[dict]:
        """
        Fetch every row of ``public.burst_locate_result``.

        :param conn: async database connection
        :return: one dict per row
        """
        async with conn.cursor() as cur:
            await cur.execute(
                """
                SELECT id, type, burst_incident, leakage, detect_time, locate_result
                FROM public.burst_locate_result
                """
            )
            rows = await cur.fetchall()
            return SchemeRepository._project(rows, SchemeRepository._BURST_COLUMNS)

    @staticmethod
    async def get_burst_locate_result_by_incident(
        conn: AsyncConnection, burst_incident: str
    ) -> List[dict]:
        """
        Fetch the burst-location results recorded for one incident.

        :param conn: async database connection
        :param burst_incident: incident identifier to match
        :return: one dict per matching row
        """
        async with conn.cursor() as cur:
            await cur.execute(
                """
                SELECT id, type, burst_incident, leakage, detect_time, locate_result
                FROM public.burst_locate_result
                WHERE burst_incident = %s
                """,
                (burst_incident,),
            )
            rows = await cur.fetchall()
            return SchemeRepository._project(rows, SchemeRepository._BURST_COLUMNS)

View File

@@ -0,0 +1,4 @@
from .router import router
from .database import *
from .timescaledb_info import *
from .composite_queries import CompositeQueries

View File

@@ -0,0 +1,606 @@
import time
from typing import List, Optional, Any, Dict, Tuple
from datetime import datetime, timedelta
from psycopg import AsyncConnection
import pandas as pd
import numpy as np
from api_ex.Fdataclean import clean_flow_data_df_kf
from api_ex.Pdataclean import clean_pressure_data_df_km
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
from postgresql.internal_queries import InternalQueries
from postgresql.scada_info import ScadaRepository as PostgreScadaRepository
from timescaledb.schemas.realtime import RealtimeRepository
from timescaledb.schemas.scheme import SchemeRepository
from timescaledb.schemas.scada import ScadaRepository
class CompositeQueries:
    """
    Composite queries that combine PostgreSQL metadata with TimescaleDB series.
    """

    @staticmethod
    def _index_scadas_by_id(scada_infos: List[dict]) -> Dict[Any, dict]:
        """Map SCADA id -> info, keeping the first record when an id repeats."""
        by_id: Dict[Any, dict] = {}
        for info in scada_infos:
            by_id.setdefault(info["id"], info)
        return by_id

    @staticmethod
    def _latest_value(series_by_id: Dict[Any, list], key: Any) -> Any:
        """Return the ``value`` of the last sample stored under *key*, or 0 if none."""
        samples = series_by_id.get(key)
        return samples[-1]["value"] if samples else 0

    @staticmethod
    async def _clean_and_store(
        timescale_conn: AsyncConnection,
        df: pd.DataFrame,
        device_ids: List[str],
        cleaner,
    ) -> None:
        """Clean the *device_ids* columns of *df* and write them back as cleaned_value."""
        # Turn the time index into an ordinary column.
        subset = df[device_ids].reset_index()
        # The cleaning routine receives the value columns only.
        cleaned_value_df = cleaner(subset.drop(columns=["time"]))
        # Put the time column back in front of the cleaned values.
        cleaned_df = pd.concat([subset["time"], cleaned_value_df], axis=1)
        for device_id in device_ids:
            if device_id not in cleaned_df.columns:
                continue
            cleaned_values = cleaned_df[device_id].tolist()
            time_values = cleaned_df["time"].tolist()
            for time_str, value in zip(time_values, cleaned_values):
                await ScadaRepository.update_scada_field(
                    timescale_conn,
                    datetime.fromisoformat(time_str),
                    device_id,
                    "cleaned_value",
                    value,
                )

    @staticmethod
    async def get_scada_associated_realtime_simulation_data(
        timescale_conn: AsyncConnection,
        postgres_conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get realtime simulation values for the link/node behind each SCADA device.

        Each device id is resolved to its associated element; "pipe_flow"
        devices read the link "flow" field and "pressure" devices read the
        node "pressure" field.

        Args:
            timescale_conn: TimescaleDB async connection
            postgres_conn: PostgreSQL async connection
            device_ids: SCADA device ids
            start_time: range start
            end_time: range end

        Returns:
            Dict keyed by device_id; every item in each list also gets a
            ``scada_id`` entry.

        Raises:
            ValueError: when a device is not found or its type is unknown
        """
        result = {}
        scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
        scada_by_id = CompositeQueries._index_scadas_by_id(scada_infos)
        for device_id in device_ids:
            target_scada = scada_by_id.get(device_id)
            if not target_scada:
                raise ValueError(f"SCADA device {device_id} not found")
            element_id = target_scada["associated_element_id"]
            scada_type = target_scada["type"]
            kind = scada_type.lower()
            if kind == "pipe_flow":
                res = await RealtimeRepository.get_link_field_by_time_range(
                    timescale_conn, start_time, end_time, element_id, "flow"
                )
            elif kind == "pressure":
                res = await RealtimeRepository.get_node_field_by_time_range(
                    timescale_conn, start_time, end_time, element_id, "pressure"
                )
            else:
                raise ValueError(f"Unknown SCADA type: {scada_type}")
            # Tag every sample with the device it was requested for.
            for item in res:
                item["scada_id"] = device_id
            result[device_id] = res
        return result

    @staticmethod
    async def get_scada_associated_scheme_simulation_data(
        timescale_conn: AsyncConnection,
        postgres_conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
        scheme_type: str,
        scheme_name: str,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get scheme simulation values for the link/node behind each SCADA device.

        Args:
            timescale_conn: TimescaleDB async connection
            postgres_conn: PostgreSQL async connection
            device_ids: SCADA device ids
            start_time: range start
            end_time: range end
            scheme_type: scheme type
            scheme_name: scheme name

        Returns:
            Dict keyed by device_id; every item in each list also gets a
            ``scada_id`` entry.

        Raises:
            ValueError: when a device is not found or its type is unknown
        """
        result = {}
        scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
        scada_by_id = CompositeQueries._index_scadas_by_id(scada_infos)
        for device_id in device_ids:
            target_scada = scada_by_id.get(device_id)
            if not target_scada:
                raise ValueError(f"SCADA device {device_id} not found")
            element_id = target_scada["associated_element_id"]
            scada_type = target_scada["type"]
            kind = scada_type.lower()
            if kind == "pipe_flow":
                res = await SchemeRepository.get_link_field_by_scheme_and_time_range(
                    timescale_conn,
                    scheme_type,
                    scheme_name,
                    start_time,
                    end_time,
                    element_id,
                    "flow",
                )
            elif kind == "pressure":
                res = await SchemeRepository.get_node_field_by_scheme_and_time_range(
                    timescale_conn,
                    scheme_type,
                    scheme_name,
                    start_time,
                    end_time,
                    element_id,
                    "pressure",
                )
            else:
                raise ValueError(f"Unknown SCADA type: {scada_type}")
            # Tag every sample with the device it was requested for.
            for item in res:
                item["scada_id"] = device_id
            result[device_id] = res
        return result

    @staticmethod
    async def get_realtime_simulation_data(
        timescale_conn: AsyncConnection,
        featureInfos: List[Tuple[str, str]],
        start_time: datetime,
        end_time: datetime,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get realtime simulation values for the given links/nodes.

        Args:
            timescale_conn: TimescaleDB async connection
            featureInfos: (element_id, type) pairs; type is "pipe" or
                "junction" (case-insensitive)
            start_time: range start
            end_time: range end

        Returns:
            Dict keyed by feature id; every item in each list also gets a
            ``feature_id`` entry.

        Raises:
            ValueError: when a feature type is unknown
        """
        result = {}
        for feature_id, feature_type in featureInfos:
            kind = feature_type.lower()
            if kind == "pipe":
                res = await RealtimeRepository.get_link_field_by_time_range(
                    timescale_conn, start_time, end_time, feature_id, "flow"
                )
            elif kind == "junction":
                res = await RealtimeRepository.get_node_field_by_time_range(
                    timescale_conn, start_time, end_time, feature_id, "pressure"
                )
            else:
                raise ValueError(f"Unknown type: {feature_type}")
            # Tag every sample with the feature it was requested for.
            for item in res:
                item["feature_id"] = feature_id
            result[feature_id] = res
        return result

    @staticmethod
    async def get_scheme_simulation_data(
        timescale_conn: AsyncConnection,
        featureInfos: List[Tuple[str, str]],
        start_time: datetime,
        end_time: datetime,
        scheme_type: str,
        scheme_name: str,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get scheme simulation values for the given links/nodes.

        Args:
            timescale_conn: TimescaleDB async connection
            featureInfos: (element_id, type) pairs; type is "pipe" or
                "junction" (case-insensitive)
            start_time: range start
            end_time: range end
            scheme_type: scheme type
            scheme_name: scheme name

        Returns:
            Dict keyed by feature id; every item in each list also gets a
            ``feature_id`` entry.

        Raises:
            ValueError: when a feature type is unknown
        """
        result = {}
        for feature_id, feature_type in featureInfos:
            kind = feature_type.lower()
            if kind == "pipe":
                res = await SchemeRepository.get_link_field_by_scheme_and_time_range(
                    timescale_conn,
                    scheme_type,
                    scheme_name,
                    start_time,
                    end_time,
                    feature_id,
                    "flow",
                )
            elif kind == "junction":
                res = await SchemeRepository.get_node_field_by_scheme_and_time_range(
                    timescale_conn,
                    scheme_type,
                    scheme_name,
                    start_time,
                    end_time,
                    feature_id,
                    "pressure",
                )
            else:
                raise ValueError(f"Unknown type: {feature_type}")
            # Tag every sample with the feature it was requested for.
            for item in res:
                item["feature_id"] = feature_id
            result[feature_id] = res
        return result

    @staticmethod
    async def get_element_associated_scada_data(
        timescale_conn: AsyncConnection,
        postgres_conn: AsyncConnection,
        element_id: str,
        start_time: datetime,
        end_time: datetime,
        use_cleaned: bool = False,
    ) -> Optional[Any]:
        """
        Get the SCADA monitoring values associated with a link/node.

        Args:
            timescale_conn: TimescaleDB async connection
            postgres_conn: PostgreSQL async connection
            element_id: link or node id
            start_time: range start
            end_time: range end
            use_cleaned: read "cleaned_value" when True, "monitored_value"
                otherwise

        Returns:
            ``{element_id: samples}``, or None when no SCADA device is
            associated with the element.
        """
        scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
        # The first SCADA device attached to this element wins.
        associated_scada = next(
            (
                scada
                for scada in scada_infos
                if scada["associated_element_id"] == element_id
            ),
            None,
        )
        if not associated_scada:
            return None
        device_id = associated_scada["id"]
        data_field = "cleaned_value" if use_cleaned else "monitored_value"
        # The repository expects a list of device ids.
        res = await ScadaRepository.get_scada_field_by_id_time_range(
            timescale_conn, [device_id], start_time, end_time, data_field
        )
        # Report the samples under the element id rather than the device id.
        return {element_id: res.get(device_id, [])}

    @staticmethod
    async def clean_scada_data(
        timescale_conn: AsyncConnection,
        postgres_conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
    ) -> str:
        """
        Clean SCADA data: read monitored_value, clean it, store cleaned_value.

        Args:
            timescale_conn: TimescaleDB connection
            postgres_conn: PostgreSQL connection
            device_ids: device ids; every known SCADA device when empty
            start_time: range start
            end_time: range end

        Returns:
            "success", or a string starting with "error: ". Exceptions are
            reported through the return value, not raised.
        """
        try:
            scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
            scada_device_info_dict = {info["id"]: info for info in scada_infos}
            if not device_ids:
                device_ids = list(scada_device_info_dict.keys())
            data = await ScadaRepository.get_scada_field_by_id_time_range(
                timescale_conn, device_ids, start_time, end_time, "monitored_value"
            )
            if not data:
                return "error: fetch none scada data"
            # Flatten {device_id: [{"time": ..., "value": ...}, ...]} into
            # long-form rows.
            all_records = [
                {
                    "time": record["time"],
                    "device_id": device_id,
                    "value": record["value"],
                }
                for device_id, records in data.items()
                for record in records
            ]
            if not all_records:
                return "error: fetch none scada data"
            # Pivot so that time is the index and each device is a column.
            df_long = pd.DataFrame(all_records)
            df = df_long.pivot(index="time", columns="device_id", values="value")
            # Split the devices by SCADA type.
            pressure_ids = [
                col
                for col in df.columns
                if scada_device_info_dict.get(col, {}).get("type") == "pressure"
            ]
            flow_ids = [
                col
                for col in df.columns
                if scada_device_info_dict.get(col, {}).get("type") == "pipe_flow"
            ]
            if pressure_ids:
                await CompositeQueries._clean_and_store(
                    timescale_conn, df, pressure_ids, clean_pressure_data_df_km
                )
            if flow_ids:
                await CompositeQueries._clean_and_store(
                    timescale_conn, df, flow_ids, clean_flow_data_df_kf
                )
            return "success"
        except Exception as e:
            return f"error: {str(e)}"

    @staticmethod
    async def predict_pipeline_health(
        timescale_conn: AsyncConnection,
        network_name: str,
        query_time: datetime,
    ) -> List[Dict[str, Any]]:
        """
        Predict pipeline health from pipe attributes and realtime data.

        Args:
            timescale_conn: TimescaleDB async connection
            network_name: network database name, used for the pipe lookup
            query_time: samples within one second either side of this time
                are used

        Returns:
            One entry per pipe, holding ``link_id`` and its
            ``survival_function`` as ``{"x": [...], "y": [...]}``.

        Raises:
            ValueError: on any failure; the original error is chained
        """
        try:
            # 1. Window of one second either side of the query time.
            start_time = query_time - timedelta(seconds=1)
            end_time = query_time + timedelta(seconds=1)

            # 2. Velocity data determines which links are considered.
            velocity_data = await RealtimeRepository.get_links_field_by_time_range(
                timescale_conn, start_time, end_time, "velocity"
            )
            if not velocity_data:
                raise ValueError("未找到流速数据")
            valid_link_ids = list(velocity_data.keys())

            # 3. Pipe attributes.
            # NOTE(review): this call is synchronous (and sleeps between
            # retries), so it blocks the event loop while it runs - consider
            # moving it to a worker thread.
            fields = ["id", "diameter", "node1", "node2"]
            all_links = InternalQueries.get_links_by_property(
                fields=fields,
                db_name=network_name,
            )
            links_dict = {link["id"]: link for link in all_links}

            # 4. Pressure data for the nodes.
            pressure_data = await RealtimeRepository.get_nodes_field_by_time_range(
                timescale_conn, start_time, end_time, "pressure"
            )

            # 5. Assemble the model inputs.
            materials = []
            diameters = []
            velocities = []
            pressures = []
            link_ids = []
            for link_id in valid_link_ids:
                # Skip ids that are not pipes (e.g. pumps).
                link = links_dict.get(link_id)
                if link is None:
                    continue
                velocity = CompositeQueries._latest_value(velocity_data, link_id)
                # Average the latest pressure at both end nodes; a node
                # without data counts as 0.
                node1_pressure = CompositeQueries._latest_value(
                    pressure_data, link["node1"]
                )
                node2_pressure = CompositeQueries._latest_value(
                    pressure_data, link["node2"]
                )
                avg_pressure = (node1_pressure + node2_pressure) / 2

                link_ids.append(link_id)
                materials.append(7)  # default material type
                diameters.append(link["diameter"])
                velocities.append(velocity)
                pressures.append(avg_pressure)
            if not link_ids:
                raise ValueError("没有找到有效的管道数据用于预测")

            # 6. Model input frame.
            data = pd.DataFrame(
                {
                    "Material": materials,
                    "Diameter": diameters,
                    "Flow Velocity": velocities,
                    "Pressure": pressures,
                }
            )

            # 7. Run the prediction.
            analyzer = PipelineHealthAnalyzer(
                model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
            )
            survival_functions = analyzer.predict_survival(data)

            # 8. Pair each survival function with its link id.
            return [
                {
                    "link_id": link_id,
                    "survival_function": {
                        "x": sf.x.tolist(),
                        "y": sf.y.tolist(),
                    },
                }
                for link_id, sf in zip(link_ids, survival_functions)
            ]
        except Exception as e:
            raise ValueError(f"管道健康预测失败: {str(e)}") from e

View File

@@ -0,0 +1,115 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import app.infra.db.timescaledb.timescaledb_info as timescaledb_info
# Configure logging
logger = logging.getLogger(__name__)
class Database:
    """Owns one psycopg async connection pool for a TimescaleDB database."""

    def __init__(self, db_name=None):
        # The pool is created by init_pool() and opened by open().
        self.pool = None
        self.db_name = db_name

    def init_pool(self, db_name=None):
        """Create (but do not open) the connection pool.

        :param db_name: database to connect to. Falls back to the name given
            to the constructor, then to the default from configuration.
        """
        target_db_name = db_name or self.db_name
        # Only pass db_name when one is set, so the configured default applies
        # otherwise.
        conn_string = (
            timescaledb_info.get_pgconn_string(db_name=target_db_name)
            if target_db_name
            else timescaledb_info.get_pgconn_string()
        )
        try:
            self.pool = psycopg_pool.AsyncConnectionPool(
                conninfo=conn_string,
                min_size=5,
                max_size=20,
                open=False,  # Don't open immediately, wait for startup
                kwargs={"row_factory": dict_row},  # Return rows as dictionaries
            )
            logger.info(
                "TimescaleDB connection pool initialized for database: %s",
                target_db_name or "default",
            )
        except Exception as e:
            logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
            raise

    async def open(self):
        """Open the pool; does nothing when init_pool() has not been called."""
        if self.pool:
            await self.pool.open()

    async def close(self):
        """Close the connection pool; does nothing when no pool exists."""
        if self.pool:
            await self.pool.close()
            logger.info("TimescaleDB connection pool closed.")

    def get_pgconn_string(self, db_name=None):
        """Get the TimescaleDB connection string for *db_name* or this instance's database."""
        target_db_name = db_name or self.db_name
        return timescaledb_info.get_pgconn_string(db_name=target_db_name)

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator:
        """Yield a connection borrowed from the pool.

        :raises RuntimeError: when the pool has not been initialized.
        """
        if not self.pool:
            raise RuntimeError("Database pool is not initialized.")
        async with self.pool.connection() as conn:
            yield conn
# Default database instance
db = Database()
# Cache of per-database instances, keyed by database name, so that a
# connection pool is not created again for a name that was already requested
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
    """Build an unopened Database wrapper bound to *db_name*."""
    new_instance = Database(db_name=db_name)
    return new_instance
async def get_database_instance(db_name: Optional[str] = None) -> Database:
    """Return the Database for *db_name*, creating and opening it on first use.

    An empty or missing name yields the module-level default instance.
    """
    if not db_name:
        # Fall back to the default database instance.
        return db
    try:
        return _database_instances[db_name]
    except KeyError:
        pass
    # First request for this name: build the pool, open it, then cache it.
    new_instance = create_database_instance(db_name)
    new_instance.init_pool()
    await new_instance.open()
    _database_instances[db_name] = new_instance
    logger.info(f"Created new database instance for: {db_name}")
    return new_instance
async def get_db_connection():
    """FastAPI dependency that yields a connection from the default pool."""
    connection_ctx = db.get_connection()
    async with connection_ctx as pooled_conn:
        yield pooled_conn
async def get_database_connection(db_name: Optional[str] = None):
    """
    FastAPI dependency yielding a connection for an optionally named database.

    Usage in a route function:
        conn: AsyncConnection = Depends(get_database_connection)

    NOTE(review): to bind a fixed db_name, wrap this in another async-generator
    dependency rather than ``Depends(lambda: get_database_connection("name"))``;
    a lambda appears to hand back the generator object itself instead of a
    connection - confirm against FastAPI's dependency documentation.
    """
    database = await get_database_instance(db_name)
    async with database.get_connection() as pooled_conn:
        yield pooled_conn
async def cleanup_database_instances():
    """Close every cached pool and the default pool (call on application shutdown).

    A failure while closing one cached pool is logged and does not prevent the
    remaining pools, or the default pool, from being closed.
    """
    # Detach the cache before awaiting: iterating the live dict across awaits
    # fails if another task registers an instance in the meantime.
    instances = dict(_database_instances)
    _database_instances.clear()
    for db_name, instance in instances.items():
        try:
            await instance.close()
            logger.info(f"Closed database instance for: {db_name}")
        except Exception:
            logger.exception(f"Failed to close database instance for: {db_name}")
    # Close the default database.
    await db.close()
    logger.info("All database instances cleaned up.")

View File

@@ -0,0 +1,122 @@
from typing import List
from fastapi.logger import logger
from timescaledb.schemas.scheme import SchemeRepository
from timescaledb.schemas.realtime import RealtimeRepository
import timescaledb.timescaledb_info as timescaledb_info
from datetime import datetime, timedelta
from timescaledb.schemas.scada import ScadaRepository
import psycopg
import time
class InternalStorage:
    """Synchronous writers for simulation results, one connection per attempt."""

    @staticmethod
    def _store_with_retry(store, db_name, max_retries):
        """Open a connection and run ``store(conn)``, retrying after failures.

        Each failed attempt is logged; the last failure is re-raised. Nothing
        runs when *max_retries* is 0 or less.
        """
        for attempt in range(max_retries):
            try:
                conn_string = (
                    timescaledb_info.get_pgconn_string(db_name=db_name)
                    if db_name
                    else timescaledb_info.get_pgconn_string()
                )
                with psycopg.Connection.connect(conn_string) as conn:
                    store(conn)
                break  # success
            except Exception as e:
                logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)  # wait before retrying
                else:
                    raise  # out of attempts

    @staticmethod
    def store_realtime_simulation(
        node_result_list: List[dict],
        link_result_list: List[dict],
        result_start_time: str,
        db_name: str = None,
        max_retries: int = 3,
    ):
        """Store realtime simulation results.

        :param node_result_list: node results to store
        :param link_result_list: link results to store
        :param result_start_time: start time of the results
        :param db_name: database name; the configured default when omitted
        :param max_retries: maximum number of attempts
        """

        def store(conn):
            RealtimeRepository.store_realtime_simulation_result_sync(
                conn, node_result_list, link_result_list, result_start_time
            )

        InternalStorage._store_with_retry(store, db_name, max_retries)

    @staticmethod
    def store_scheme_simulation(
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[dict],
        link_result_list: List[dict],
        result_start_time: str,
        num_periods: int = 1,
        db_name: str = None,
        max_retries: int = 3,
    ):
        """Store scheme simulation results.

        :param scheme_type: scheme type
        :param scheme_name: scheme name
        :param node_result_list: node results to store
        :param link_result_list: link results to store
        :param result_start_time: start time of the results
        :param num_periods: forwarded unchanged to the repository
        :param db_name: database name; the configured default when omitted
        :param max_retries: maximum number of attempts
        """

        def store(conn):
            SchemeRepository.store_scheme_simulation_result_sync(
                conn,
                scheme_type,
                scheme_name,
                node_result_list,
                link_result_list,
                result_start_time,
                num_periods,
            )

        InternalStorage._store_with_retry(store, db_name, max_retries)
class InternalQueries:
    """Synchronous SCADA lookups against TimescaleDB."""

    @staticmethod
    def query_scada_by_ids_time(
        device_ids: List[str],
        query_time: str,
        db_name: str = None,
        max_retries: int = 3,
    ) -> dict:
        """Query SCADA data at a point in time.

        :param device_ids: SCADA device ids to look up
        :param query_time: ISO-format time string (the original comment
            assumes Beijing time)
        :param db_name: database name; the configured default when omitted
        :param max_retries: maximum number of attempts
        :return: ``{device_id: monitored_value}`` taken from the first row
            returned for each device, or None for a device without rows
        :raises ValueError: when *query_time* is not a valid ISO time
        """
        # Search within one second either side of the requested time.
        beijing_time = datetime.fromisoformat(query_time)
        start_time = beijing_time - timedelta(seconds=1)
        end_time = beijing_time + timedelta(seconds=1)
        for attempt in range(max_retries):
            try:
                conn_string = (
                    timescaledb_info.get_pgconn_string(db_name=db_name)
                    if db_name
                    else timescaledb_info.get_pgconn_string()
                )
                with psycopg.Connection.connect(conn_string) as conn:
                    rows = ScadaRepository.get_scada_by_ids_time_range_sync(
                        conn, device_ids, start_time, end_time
                    )
                    # One pass over the rows, remembering the first row seen
                    # for each device.
                    first_row = {}
                    for row in rows:
                        first_row.setdefault(row["device_id"], row)
                    return {
                        device_id: (
                            first_row[device_id]["monitored_value"]
                            if device_id in first_row
                            else None
                        )
                        for device_id in device_ids
                    }
            except Exception as e:
                logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)
                else:
                    raise

View File

@@ -0,0 +1,627 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from datetime import datetime
from psycopg import AsyncConnection
from .database import get_database_instance
from .schemas.realtime import RealtimeRepository
from .schemas.scheme import SchemeRepository
from .schemas.scada import ScadaRepository
from .composite_queries import CompositeQueries
from postgresql.database import get_database_instance as get_postgres_database_instance
# Router for the TimescaleDB endpoints; every route registered on it is served
# under the "/timescaledb" prefix and tagged "TimescaleDB".
router = APIRouter(prefix="/timescaledb", tags=["TimescaleDB"])
# Connection dependency that lets the caller pick the target database.
async def get_database_connection(
    db_name: Optional[str] = Query(
        None, description="指定要连接的数据库名称,为空时使用默认数据库"
    )
):
    """Yield a connection from the database instance selected by ``db_name``.

    ``db_name`` is an optional query parameter and is passed unchanged
    (including ``None``) to ``get_database_instance``.
    """
    database = await get_database_instance(db_name)
    async with database.get_connection() as connection:
        yield connection
# Connection dependency for the PostgreSQL database.
async def get_postgres_connection(
    db_name: Optional[str] = Query(
        None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
    )
):
    """Yield a PostgreSQL connection for the database selected by ``db_name``.

    ``db_name`` is an optional query parameter and is passed unchanged
    (including ``None``) to ``get_postgres_database_instance``.
    """
    database = await get_postgres_database_instance(db_name)
    async with database.get_connection() as connection:
        yield connection
# --- Realtime Endpoints ---
@router.post("/realtime/links/batch", status_code=201)
async def insert_realtime_links(
    data: List[dict],
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Batch-insert realtime link records and report how many were submitted."""
    await RealtimeRepository.insert_links_batch(conn, data)
    submitted = len(data)
    return {"message": f"Inserted {submitted} records"}
@router.get("/realtime/links")
async def get_realtime_links(
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return the realtime link records between ``start_time`` and ``end_time``."""
    records = await RealtimeRepository.get_links_by_time_range(
        conn, start_time, end_time
    )
    return records
@router.delete("/realtime/links")
async def delete_realtime_links(
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Delete the realtime link records between ``start_time`` and ``end_time``."""
    await RealtimeRepository.delete_links_by_time_range(conn, start_time, end_time)
    response = {"message": "Deleted successfully"}
    return response
@router.patch("/realtime/links/{link_id}/field")
async def update_realtime_link_field(
    link_id: str,
    time: datetime,
    field: str,
    value: float,  # float for now; could be Any, but FastAPI needs a concrete type
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Update one field of a realtime link record.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        await RealtimeRepository.update_link_field(conn, time, link_id, field, value)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"message": "Updated successfully"}
@router.post("/realtime/nodes/batch", status_code=201)
async def insert_realtime_nodes(
    data: List[dict],
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Batch-insert realtime node records and report how many were submitted."""
    await RealtimeRepository.insert_nodes_batch(conn, data)
    submitted = len(data)
    return {"message": f"Inserted {submitted} records"}
@router.get("/realtime/nodes")
async def get_realtime_nodes(
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return the realtime node records between ``start_time`` and ``end_time``."""
    records = await RealtimeRepository.get_nodes_by_time_range(
        conn, start_time, end_time
    )
    return records
@router.delete("/realtime/nodes")
async def delete_realtime_nodes(
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Delete the realtime node records between ``start_time`` and ``end_time``."""
    await RealtimeRepository.delete_nodes_by_time_range(conn, start_time, end_time)
    response = {"message": "Deleted successfully"}
    return response
@router.post("/realtime/simulation/store", status_code=201)
async def store_realtime_simulation_result(
    node_result_list: List[dict],
    link_result_list: List[dict],
    result_start_time: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Store realtime simulation results to TimescaleDB.

    Args:
        node_result_list: Node simulation results.
        link_result_list: Link simulation results.
        result_start_time: Start time of the results as an ISO-8601 string.
        conn: TimescaleDB connection supplied by the dependency.

    Raises:
        HTTPException: 400 when the repository raises ``ValueError``, e.g.
            because ``result_start_time`` is not a valid ISO-8601 string.
    """
    try:
        await RealtimeRepository.store_realtime_simulation_result(
            conn, node_result_list, link_result_list, result_start_time
        )
    except ValueError as e:
        # Bad client input is a 400, matching the other endpoints of this
        # router, rather than an unhandled 500.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"message": "Simulation results stored successfully"}
@router.get("/realtime/query/by-time-property")
async def query_realtime_records_by_time_property(
    query_time: str,
    type: str,
    property: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Query all realtime records by time and property.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        records = await RealtimeRepository.query_all_record_by_time_property(
            conn, query_time, type, property
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"results": records}
@router.get("/realtime/query/by-id-time")
async def query_realtime_simulation_by_id_time(
    id: str,
    type: str,
    query_time: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Query realtime simulation results by element id and time.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        records = await RealtimeRepository.query_simulation_result_by_id_time(
            conn, id, type, query_time
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"results": records}
# --- Scheme Endpoints ---
@router.post("/scheme/links/batch", status_code=201)
async def insert_scheme_links(
    data: List[dict],
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Batch-insert scheme link records and report how many were submitted."""
    await SchemeRepository.insert_links_batch(conn, data)
    submitted = len(data)
    return {"message": f"Inserted {submitted} records"}
@router.get("/scheme/links")
async def get_scheme_links(
    scheme_type: str,
    scheme_name: str,
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return the link records of one scheme for the given time range."""
    records = await SchemeRepository.get_links_by_scheme_and_time_range(
        conn, scheme_type, scheme_name, start_time, end_time
    )
    return records
@router.get("/scheme/links/{link_id}/field")
async def get_scheme_link_field(
    scheme_type: str,
    scheme_name: str,
    link_id: str,
    start_time: datetime,
    end_time: datetime,
    field: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return one field of a scheme link over the given time range.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        series = await SchemeRepository.get_link_field_by_scheme_and_time_range(
            conn, scheme_type, scheme_name, start_time, end_time, link_id, field
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return series
@router.patch("/scheme/links/{link_id}/field")
async def update_scheme_link_field(
    scheme_type: str,
    scheme_name: str,
    link_id: str,
    time: datetime,
    field: str,
    value: float,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Update one field of a scheme link record.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        await SchemeRepository.update_link_field(
            conn, time, scheme_type, scheme_name, link_id, field, value
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"message": "Updated successfully"}
@router.delete("/scheme/links")
async def delete_scheme_links(
    scheme_type: str,
    scheme_name: str,
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Delete the link records of one scheme for the given time range."""
    await SchemeRepository.delete_links_by_scheme_and_time_range(
        conn, scheme_type, scheme_name, start_time, end_time
    )
    response = {"message": "Deleted successfully"}
    return response
@router.post("/scheme/nodes/batch", status_code=201)
async def insert_scheme_nodes(
    data: List[dict],
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Batch-insert scheme node records and report how many were submitted."""
    await SchemeRepository.insert_nodes_batch(conn, data)
    submitted = len(data)
    return {"message": f"Inserted {submitted} records"}
@router.get("/scheme/nodes/{node_id}/field")
async def get_scheme_node_field(
    scheme_type: str,
    scheme_name: str,
    node_id: str,
    start_time: datetime,
    end_time: datetime,
    field: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return one field of a scheme node over the given time range.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        series = await SchemeRepository.get_node_field_by_scheme_and_time_range(
            conn, scheme_type, scheme_name, start_time, end_time, node_id, field
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return series
@router.patch("/scheme/nodes/{node_id}/field")
async def update_scheme_node_field(
    scheme_type: str,
    scheme_name: str,
    node_id: str,
    time: datetime,
    field: str,
    value: float,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Update one field of a scheme node record.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        await SchemeRepository.update_node_field(
            conn, time, scheme_type, scheme_name, node_id, field, value
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"message": "Updated successfully"}
@router.delete("/scheme/nodes")
async def delete_scheme_nodes(
    scheme_type: str,
    scheme_name: str,
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Delete the node records of one scheme for the given time range."""
    await SchemeRepository.delete_nodes_by_scheme_and_time_range(
        conn, scheme_type, scheme_name, start_time, end_time
    )
    response = {"message": "Deleted successfully"}
    return response
@router.post("/scheme/simulation/store", status_code=201)
async def store_scheme_simulation_result(
    scheme_type: str,
    scheme_name: str,
    node_result_list: List[dict],
    link_result_list: List[dict],
    result_start_time: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Store scheme simulation results to TimescaleDB."""
    # Positional order expected by the repository call.
    store_args = (
        scheme_type,
        scheme_name,
        node_result_list,
        link_result_list,
        result_start_time,
    )
    await SchemeRepository.store_scheme_simulation_result(conn, *store_args)
    return {"message": "Scheme simulation results stored successfully"}
@router.get("/scheme/query/by-scheme-time-property")
async def query_scheme_records_by_scheme_time_property(
    scheme_type: str,
    scheme_name: str,
    query_time: str,
    type: str,
    property: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Query all scheme records by scheme, time and property.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        records = await SchemeRepository.query_all_record_by_scheme_time_property(
            conn, scheme_type, scheme_name, query_time, type, property
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"results": records}
@router.get("/scheme/query/by-id-time")
async def query_scheme_simulation_by_id_time(
    scheme_type: str,
    scheme_name: str,
    id: str,
    type: str,
    query_time: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Query scheme simulation results by element id and time.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        record = await SchemeRepository.query_scheme_simulation_result_by_id_time(
            conn, scheme_type, scheme_name, id, type, query_time
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"result": record}
# --- SCADA Endpoints ---
@router.post("/scada/batch", status_code=201)
async def insert_scada_data(
    data: List[dict],
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Batch-insert SCADA records and report how many were submitted."""
    await ScadaRepository.insert_scada_batch(conn, data)
    submitted = len(data)
    return {"message": f"Inserted {submitted} records"}
@router.get("/scada/by-ids-time-range")
async def get_scada_by_ids_time_range(
    start_time: datetime,
    end_time: datetime,
    device_ids: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return SCADA rows for a comma-separated list of device ids.

    ``device_ids`` is split on ``,``; surrounding whitespace is stripped and
    empty entries are dropped.
    """
    device_ids_list = []
    if device_ids:
        for raw_id in device_ids.split(","):
            cleaned_id = raw_id.strip()
            if cleaned_id:
                device_ids_list.append(cleaned_id)
    rows = await ScadaRepository.get_scada_by_ids_time_range(
        conn, device_ids_list, start_time, end_time
    )
    return rows
@router.get("/scada/by-ids-field-time-range")
async def get_scada_field_by_ids_time_range(
    start_time: datetime,
    end_time: datetime,
    field: str,
    device_ids: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return one SCADA field for a comma-separated list of device ids.

    ``device_ids`` is split on ``,``; surrounding whitespace is stripped and
    empty entries are dropped.  A ``ValueError`` from the repository is
    reported as HTTP 400.
    """
    try:
        device_ids_list = []
        if device_ids:
            for raw_id in device_ids.split(","):
                cleaned_id = raw_id.strip()
                if cleaned_id:
                    device_ids_list.append(cleaned_id)
        series = await ScadaRepository.get_scada_field_by_id_time_range(
            conn, device_ids_list, start_time, end_time, field
        )
        return series
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
@router.patch("/scada/{device_id}/field")
async def update_scada_field(
    device_id: str,
    time: datetime,
    field: str,
    value: float,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Update one field of a SCADA record.

    A ``ValueError`` from the repository is reported as HTTP 400.
    """
    try:
        await ScadaRepository.update_scada_field(conn, time, device_id, field, value)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"message": "Updated successfully"}
@router.delete("/scada/by-id-time-range")
async def delete_scada_data(
    device_id: str,
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Delete the SCADA records of one device for the given time range."""
    await ScadaRepository.delete_scada_by_id_time_range(
        conn, device_id, start_time, end_time
    )
    response = {"message": "Deleted successfully"}
    return response
# --- Composite Query Endpoints ---
@router.get("/composite/scada-simulation")
async def get_scada_associated_simulation_data(
    start_time: datetime,
    end_time: datetime,
    device_ids: str,
    scheme_type: Optional[str] = Query(
        None, description="指定方案类型,若为空则查询实时数据"
    ),
    scheme_name: Optional[str] = Query(
        None, description="指定方案名称,若为空则查询实时数据"
    ),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
    postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
    """Get the simulated link/node values associated with SCADA devices.

    The SCADA ``device_ids`` are resolved to their associated link/node and
    the matching simulation data is queried according to its type.

    Args:
        start_time: Start of the query window.
        end_time: End of the query window.
        device_ids: Comma-separated device ids; whitespace is stripped and
            empty entries are dropped.
        scheme_type: Scheme type.  Scheme data is queried only when both
            ``scheme_type`` and ``scheme_name`` are given; otherwise realtime
            data is queried.
        scheme_name: Scheme name (see ``scheme_type``).
        timescale_conn: TimescaleDB connection supplied by the dependency.
        postgres_conn: PostgreSQL connection supplied by the dependency.

    Raises:
        HTTPException: 404 when the query returns ``None``; 400 when the
            query layer raises ``ValueError``.
    """
    # Parse device_ids into a list of stripped, non-empty ids.
    device_ids_list = (
        [part.strip() for part in device_ids.split(",") if part.strip()]
        if device_ids
        else []
    )
    try:
        if scheme_type and scheme_name:
            result = await CompositeQueries.get_scada_associated_scheme_simulation_data(
                timescale_conn,
                postgres_conn,
                device_ids_list,
                start_time,
                end_time,
                scheme_type,
                scheme_name,
            )
        else:
            result = (
                await CompositeQueries.get_scada_associated_realtime_simulation_data(
                    timescale_conn,
                    postgres_conn,
                    device_ids_list,
                    start_time,
                    end_time,
                )
            )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    if result is None:
        raise HTTPException(status_code=404, detail="No simulation data found")
    return result
@router.get("/composite/element-simulation")
async def get_feature_simulation_data(
    start_time: datetime,
    end_time: datetime,
    feature_infos: str = Query(
        ..., description="特征信息,格式: id1:type1,id2:type2type为pipe或junction"
    ),
    scheme_type: str = Query(None, description="指定方案类型,若为空则查询实时数据"),
    scheme_name: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
):
    """Get simulated link/node values for the requested elements.

    Args:
        feature_infos: Comma-separated ``element_id:type`` pairs, for example
            ``"P1:pipe,J1:junction"``.  Entries without a ``:`` are ignored.

    Scheme data is queried when both ``scheme_type`` and ``scheme_name`` are
    given, realtime data otherwise.  Responds with 400 when no usable pair
    remains or the query layer raises ``ValueError``, and with 404 when the
    query returns ``None``.
    """
    # Turn "id:type" entries into (id, type) tuples.
    feature_infos_list = []
    for chunk in (feature_infos or "").split(","):
        element_id, separator, element_type = chunk.strip().partition(":")
        if separator:
            feature_infos_list.append((element_id.strip(), element_type.strip()))
    if not feature_infos_list:
        raise HTTPException(status_code=400, detail="feature_infos cannot be empty")

    use_scheme = bool(scheme_type and scheme_name)
    try:
        if use_scheme:
            result = await CompositeQueries.get_scheme_simulation_data(
                timescale_conn,
                feature_infos_list,
                start_time,
                end_time,
                scheme_type,
                scheme_name,
            )
        else:
            result = await CompositeQueries.get_realtime_simulation_data(
                timescale_conn,
                feature_infos_list,
                start_time,
                end_time,
            )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    if result is None:
        raise HTTPException(status_code=404, detail="No simulation data found")
    return result
@router.get("/composite/element-scada")
async def get_element_associated_scada_data(
    element_id: str,
    start_time: datetime,
    end_time: datetime,
    use_cleaned: bool = Query(False, description="是否使用清洗后的数据"),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
    postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
    """Get the SCADA monitoring values associated with a link/node.

    The link/node id is matched against the SCADA information; when an
    associated SCADA device exists, its monitoring data is returned.
    Responds with 404 when the query returns ``None`` and with 400 when the
    query layer raises ``ValueError``.
    """
    try:
        scada_data = await CompositeQueries.get_element_associated_scada_data(
            timescale_conn, postgres_conn, element_id, start_time, end_time, use_cleaned
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    if scada_data is None:
        raise HTTPException(status_code=404, detail="No associated SCADA data found")
    return scada_data
@router.post("/composite/clean-scada")
async def clean_scada_data(
    device_ids: str,
    start_time: datetime = Query(...),
    end_time: datetime = Query(...),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
    postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
    """Clean SCADA data.

    Queries ``monitored_value`` for the given devices and updates
    ``cleaned_value`` after cleaning.  ``device_ids`` is a comma-separated
    list; the literal ``"all"`` (like an empty value) is forwarded as an empty
    list.  A ``ValueError`` from the query layer is reported as HTTP 400.
    """
    try:
        device_ids_list = []
        if device_ids != "all" and device_ids:
            for raw_id in device_ids.split(","):
                cleaned_id = raw_id.strip()
                if cleaned_id:
                    device_ids_list.append(cleaned_id)
        outcome = await CompositeQueries.clean_scada_data(
            timescale_conn, postgres_conn, device_ids_list, start_time, end_time
        )
        return outcome
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
@router.get("/composite/pipeline-health-prediction")
async def predict_pipeline_health(
    query_time: datetime = Query(..., description="查询时间"),
    network_name: str = Query(..., description="管网数据库名称"),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
):
    """Predict pipeline health.

    Looks up pipe information and realtime data for the given network and
    time, and predicts each pipe's survival probability with a random
    survival forest model.

    Args:
        query_time: Time to query.
        network_name: Name of the pipe-network database.
        timescale_conn: TimescaleDB connection supplied by the dependency.

    Returns:
        A list of predictions, each holding a ``link_id`` and its survival
        function.

    Raises:
        HTTPException: 400 on ``ValueError``, 404 on ``FileNotFoundError``,
            500 on any other unexpected error.
    """
    try:
        return await CompositeQueries.predict_pipeline_health(
            timescale_conn, network_name, query_time
        )
    except HTTPException:
        # Keep the status code of an HTTP error raised further down instead of
        # letting the catch-all below rewrite it as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"内部服务器错误: {str(e)}"
        ) from e

View File

@@ -0,0 +1,647 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
# Fixed UTC+8 time zone; attached to naive timestamps and used as the
# conversion target for "Z"-suffixed (UTC) timestamps in this module.
UTC_8 = timezone(timedelta(hours=8))
class RealtimeRepository:
    """Static data-access helpers for the realtime simulation tables.

    Reads and writes ``realtime.link_simulation`` and
    ``realtime.node_simulation``.  Coroutine methods expect a psycopg
    ``AsyncConnection``; the ``*_sync`` variants expect a blocking
    ``Connection``.  Methods that read ``row["..."]`` assume the connection
    yields mapping-style rows -- TODO confirm the row factory used by callers.
    """

    # Value columns of each table.  The tuples are the whitelist that guards
    # the ``field`` arguments against SQL injection, and their order is the
    # column order of the COPY statements below.
    _LINK_FIELDS = (
        "flow",
        "friction",
        "headloss",
        "quality",
        "reaction",
        "setting",
        "status",
        "velocity",
    )
    _NODE_FIELDS = ("actual_demand", "total_head", "pressure", "quality")

    _LINK_DELETE_AT_TIME = "DELETE FROM realtime.link_simulation WHERE time = %s"
    _LINK_COPY = "COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
    _NODE_DELETE_AT_TIME = "DELETE FROM realtime.node_simulation WHERE time = %s"
    _NODE_COPY = "COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"

    # --- Private helpers ---

    @staticmethod
    def _require_field(field: str, allowed: tuple) -> None:
        """Raise ``ValueError`` unless ``field`` is one of ``allowed``."""
        if field not in allowed:
            raise ValueError(f"Invalid field: {field}")

    @staticmethod
    def _copy_row(item: dict, fields: tuple) -> tuple:
        """Build one COPY row: required time and id, then the optional fields."""
        return (item["time"], item["id"], *(item.get(name) for name in fields))

    @staticmethod
    def _to_utc8(value) -> datetime:
        """Normalise an ISO-8601 string or datetime to an aware datetime.

        A trailing ``Z`` marks UTC and is converted to UTC+8.  Naive values
        are assumed to already be UTC+8.  Other aware values are kept as-is.
        """
        if isinstance(value, str):
            if value.endswith("Z"):
                utc_time = datetime.fromisoformat(value.replace("Z", "+00:00"))
                return utc_time.astimezone(UTC_8)
            value = datetime.fromisoformat(value)
        if value.tzinfo is None:
            value = value.replace(tzinfo=UTC_8)
        return value

    @staticmethod
    def _query_window(query_time) -> tuple:
        """Return the ``(start, end)`` window of ``query_time`` +/- 1 second."""
        target_time = RealtimeRepository._to_utc8(query_time)
        one_second = timedelta(seconds=1)
        return target_time - one_second, target_time + one_second

    @staticmethod
    def _series_by_id(rows, field: str) -> dict:
        """Group rows into ``{id: [{"time": iso_string, "value": ...}, ...]}``."""
        grouped = defaultdict(list)
        for row in rows:
            grouped[row["id"]].append(
                {"time": row["time"].isoformat(), "value": row[field]}
            )
        return dict(grouped)

    @staticmethod
    def _first_period(entry: dict, id_key: str) -> dict:
        """Return the first (only) period of a realtime simulation entry.

        Raises:
            ValueError: If the entry carries no result periods.
        """
        periods = entry.get("result")
        if not periods:
            raise ValueError(
                f"Simulation result for {id_key} {entry.get(id_key)!r} has no periods"
            )
        return periods[0]

    @staticmethod
    def _build_batches(
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        simulation_time: datetime,
    ) -> tuple:
        """Map raw simulation results to ``(node_rows, link_rows)`` dicts."""
        node_data = []
        for node_result in node_result_list:
            data = RealtimeRepository._first_period(node_result, "node")
            node_data.append(
                {
                    "time": simulation_time,
                    "id": node_result.get("node"),
                    "actual_demand": data.get("demand"),
                    "total_head": data.get("head"),
                    "pressure": data.get("pressure"),
                    "quality": data.get("quality"),
                }
            )
        link_data = []
        for link_result in link_result_list:
            data = RealtimeRepository._first_period(link_result, "link")
            row = {"time": simulation_time, "id": link_result.get("link")}
            # Link results use the same key names as the table columns.
            row.update(
                (name, data.get(name)) for name in RealtimeRepository._LINK_FIELDS
            )
            link_data.append(row)
        return node_data, link_data

    @staticmethod
    async def _replace_batch(
        conn: AsyncConnection,
        data: List[dict],
        delete_sql: str,
        copy_sql: str,
        fields: tuple,
    ):
        """Delete the rows at the batch's time, then COPY the batch in."""
        if not data:
            return
        # The whole batch is assumed to share one timestamp.
        target_time = data[0]["time"]
        # One transaction so the delete and the insert are atomic.
        async with conn.transaction():
            async with conn.cursor() as cur:
                await cur.execute(delete_sql, (target_time,))
                async with cur.copy(copy_sql) as copy:
                    for item in data:
                        await copy.write_row(
                            RealtimeRepository._copy_row(item, fields)
                        )

    @staticmethod
    def _replace_batch_sync(
        conn: Connection,
        data: List[dict],
        delete_sql: str,
        copy_sql: str,
        fields: tuple,
    ):
        """Blocking counterpart of ``_replace_batch``."""
        if not data:
            return
        # The whole batch is assumed to share one timestamp.
        target_time = data[0]["time"]
        with conn.transaction():
            with conn.cursor() as cur:
                cur.execute(delete_sql, (target_time,))
                with cur.copy(copy_sql) as copy:
                    for item in data:
                        copy.write_row(RealtimeRepository._copy_row(item, fields))

    # --- Link Simulation ---

    @staticmethod
    async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for realtime.link_simulation using DELETE then COPY.

        Rows at the time of ``data[0]`` are deleted first, so the batch
        replaces existing data for that timestamp.  Each item needs ``time``
        and ``id``; missing value fields are written as ``None``.
        """
        await RealtimeRepository._replace_batch(
            conn,
            data,
            RealtimeRepository._LINK_DELETE_AT_TIME,
            RealtimeRepository._LINK_COPY,
            RealtimeRepository._LINK_FIELDS,
        )

    @staticmethod
    def insert_links_batch_sync(conn: Connection, data: List[dict]):
        """Blocking version of ``insert_links_batch``."""
        RealtimeRepository._replace_batch_sync(
            conn,
            data,
            RealtimeRepository._LINK_DELETE_AT_TIME,
            RealtimeRepository._LINK_COPY,
            RealtimeRepository._LINK_FIELDS,
        )

    @staticmethod
    async def get_link_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str
    ) -> List[dict]:
        """Return all rows of one link with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s",
                (start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Return all link rows with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return ``[{"time": iso_string, "value": ...}]`` for one link field.

        Raises:
            ValueError: If ``field`` is not a link value column.
        """
        # Whitelist check keeps the interpolated identifier safe.
        RealtimeRepository._require_field(field, RealtimeRepository._LINK_FIELDS)
        query = sql.SQL(
            "SELECT time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, link_id))
            rows = await cur.fetchall()
        return [
            {"time": row["time"].isoformat(), "value": row[field]} for row in rows
        ]

    @staticmethod
    async def get_links_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return one field for all links, grouped by link id.

        Raises:
            ValueError: If ``field`` is not a link value column.
        """
        RealtimeRepository._require_field(field, RealtimeRepository._LINK_FIELDS)
        query = sql.SQL(
            "SELECT id, time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time))
            rows = await cur.fetchall()
        return RealtimeRepository._series_by_id(rows, field)

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        link_id: str,
        field: str,
        value: Any,
    ):
        """Set one field of the link row identified by ``time`` and ``link_id``.

        Raises:
            ValueError: If ``field`` is not a link value column.
        """
        RealtimeRepository._require_field(field, RealtimeRepository._LINK_FIELDS)
        query = sql.SQL(
            "UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, link_id))

    @staticmethod
    async def delete_links_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ):
        """Delete all link rows with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )

    # --- Node Simulation ---

    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for realtime.node_simulation using DELETE then COPY.

        Rows at the time of ``data[0]`` are deleted first, so the batch
        replaces existing data for that timestamp.  Each item needs ``time``
        and ``id``; missing value fields are written as ``None``.
        """
        await RealtimeRepository._replace_batch(
            conn,
            data,
            RealtimeRepository._NODE_DELETE_AT_TIME,
            RealtimeRepository._NODE_COPY,
            RealtimeRepository._NODE_FIELDS,
        )

    @staticmethod
    def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
        """Blocking version of ``insert_nodes_batch``."""
        RealtimeRepository._replace_batch_sync(
            conn,
            data,
            RealtimeRepository._NODE_DELETE_AT_TIME,
            RealtimeRepository._NODE_COPY,
            RealtimeRepository._NODE_FIELDS,
        )

    @staticmethod
    async def get_node_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str
    ) -> List[dict]:
        """Return all rows of one node with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s",
                (start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Return all node rows with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return ``[{"time": iso_string, "value": ...}]`` for one node field.

        Raises:
            ValueError: If ``field`` is not a node value column.
        """
        RealtimeRepository._require_field(field, RealtimeRepository._NODE_FIELDS)
        query = sql.SQL(
            "SELECT time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, node_id))
            rows = await cur.fetchall()
        return [
            {"time": row["time"].isoformat(), "value": row[field]} for row in rows
        ]

    @staticmethod
    async def get_nodes_field_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str
    ) -> dict:
        """Return one field for all nodes, grouped by node id.

        Raises:
            ValueError: If ``field`` is not a node value column.
        """
        RealtimeRepository._require_field(field, RealtimeRepository._NODE_FIELDS)
        query = sql.SQL(
            "SELECT id, time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time))
            rows = await cur.fetchall()
        return RealtimeRepository._series_by_id(rows, field)

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        node_id: str,
        field: str,
        value: Any,
    ):
        """Set one field of the node row identified by ``time`` and ``node_id``.

        Raises:
            ValueError: If ``field`` is not a node value column.
        """
        RealtimeRepository._require_field(field, RealtimeRepository._NODE_FIELDS)
        query = sql.SQL(
            "UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, node_id))

    @staticmethod
    async def delete_nodes_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ):
        """Delete all node rows with ``start_time <= time <= end_time``."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )

    # --- Composite operations ---

    @staticmethod
    async def store_realtime_simulation_result(
        conn: AsyncConnection,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
    ):
        """Store realtime simulation results to TimescaleDB.

        Args:
            conn: Database connection.
            node_result_list: Node results; each entry has ``node`` and a
                ``result`` list whose first element is used.
            link_result_list: Link results; each entry has ``link`` and a
                ``result`` list whose first element is used.
            result_start_time: Start time of the results (ISO-8601 string; a
                datetime is also accepted).

        Raises:
            ValueError: If the time string is malformed or an entry has no
                result periods.
        """
        simulation_time = RealtimeRepository._to_utc8(result_start_time)
        node_data, link_data = RealtimeRepository._build_batches(
            node_result_list, link_result_list, simulation_time
        )
        # One outer transaction so nodes and links are stored together or not
        # at all; the per-batch transactions nest inside it.
        async with conn.transaction():
            if node_data:
                await RealtimeRepository.insert_nodes_batch(conn, node_data)
            if link_data:
                await RealtimeRepository.insert_links_batch(conn, link_data)

    @staticmethod
    def store_realtime_simulation_result_sync(
        conn: Connection,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
    ):
        """Store realtime simulation results to TimescaleDB (sync version).

        Arguments and errors are the same as for
        ``store_realtime_simulation_result``.
        """
        simulation_time = RealtimeRepository._to_utc8(result_start_time)
        node_data, link_data = RealtimeRepository._build_batches(
            node_result_list, link_result_list, simulation_time
        )
        # See the async version: keep node and link writes atomic.
        with conn.transaction():
            if node_data:
                RealtimeRepository.insert_nodes_batch_sync(conn, node_data)
            if link_data:
                RealtimeRepository.insert_links_batch_sync(conn, link_data)

    @staticmethod
    async def query_all_record_by_time_property(
        conn: AsyncConnection,
        query_time: str,
        type: str,
        property: str,
    ) -> list:
        """Query one property of every node or link around a point in time.

        Args:
            conn: Database connection.
            query_time: Time to query (ISO-8601 string); the window is
                ``query_time`` +/- 1 second.
            type: ``"node"`` or ``"link"`` (case-insensitive).
            property: Column to read.

        Returns:
            A list of ``{"ID": element_id, "value": value}`` dicts.

        Raises:
            ValueError: On a malformed time, unknown ``type`` or invalid
                ``property``.
        """
        start_time, end_time = RealtimeRepository._query_window(query_time)
        kind = type.lower()
        if kind == "node":
            data = await RealtimeRepository.get_nodes_field_by_time_range(
                conn, start_time, end_time, property
            )
        elif kind == "link":
            data = await RealtimeRepository.get_links_field_by_time_range(
                conn, start_time, end_time, property
            )
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
        return [
            {"ID": element_id, "value": item["value"]}
            for element_id, items in data.items()
            for item in items
        ]

    @staticmethod
    async def query_simulation_result_by_id_time(
        conn: AsyncConnection,
        id: str,
        type: str,
        query_time: str,
    ) -> list[dict]:
        """Query the simulation rows of one node or link around a point in time.

        Args:
            conn: Database connection.
            id: Id of the node or link.
            type: ``"node"`` or ``"link"`` (case-insensitive).
            query_time: Time to query (ISO-8601 string); the window is
                ``query_time`` +/- 1 second.

        Returns:
            The matching rows.

        Raises:
            ValueError: On a malformed time or unknown ``type``.
        """
        start_time, end_time = RealtimeRepository._query_window(query_time)
        kind = type.lower()
        if kind == "node":
            return await RealtimeRepository.get_node_by_time_range(
                conn, start_time, end_time, id
            )
        if kind == "link":
            return await RealtimeRepository.get_link_by_time_range(
                conn, start_time, end_time, id
            )
        raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")

View File

@@ -0,0 +1,106 @@
from typing import List, Any
from datetime import datetime
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
class ScadaRepository:
    """Data-access helpers for the ``scada.scada_data`` table.

    Every method is static and runs on a connection supplied by the caller.
    """

    # Value columns that may be selected or updated by name. Field names are
    # checked against this set before being placed into SQL as identifiers.
    _VALUE_COLUMNS = frozenset({"monitored_value", "cleaned_value"})

    # Shared by the async and sync range queries.
    _SELECT_BY_IDS_SQL = "SELECT * FROM scada.scada_data WHERE device_id = ANY(%s) AND time >= %s AND time <= %s"

    @staticmethod
    async def insert_scada_batch(conn: AsyncConnection, data: List[dict]):
        """Bulk-load SCADA rows with COPY; an empty batch does nothing.

        Each item must provide "time" and "device_id"; "monitored_value" and
        "cleaned_value" are optional and written as None when absent.
        """
        if not data:
            return
        async with conn.cursor() as cursor:
            async with cursor.copy(
                "COPY scada.scada_data (time, device_id, monitored_value, cleaned_value) FROM STDIN"
            ) as writer:
                for record in data:
                    row = (
                        record["time"],
                        record["device_id"],
                        record.get("monitored_value"),
                        record.get("cleaned_value"),
                    )
                    await writer.write_row(row)

    @staticmethod
    async def get_scada_by_ids_time_range(
        conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all rows of the given devices within [start_time, end_time]."""
        params = (device_ids, start_time, end_time)
        async with conn.cursor() as cursor:
            await cursor.execute(ScadaRepository._SELECT_BY_IDS_SQL, params)
            return await cursor.fetchall()

    @staticmethod
    def get_scada_by_ids_time_range_sync(
        conn: Connection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Synchronous counterpart of ``get_scada_by_ids_time_range``."""
        params = (device_ids, start_time, end_time)
        with conn.cursor() as cursor:
            cursor.execute(ScadaRepository._SELECT_BY_IDS_SQL, params)
            return cursor.fetchall()

    @staticmethod
    async def get_scada_field_by_id_time_range(
        conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return one value column for several devices, grouped by device id.

        Returns:
            ``{device_id: [{"time": <ISO string>, "value": <field value>}, ...]}``

        Raises:
            ValueError: If ``field`` is not "monitored_value" or "cleaned_value"
        """
        if field not in ScadaRepository._VALUE_COLUMNS:
            raise ValueError(f"Invalid field: {field}")
        statement = sql.SQL(
            "SELECT device_id, time, {} FROM scada.scada_data WHERE time >= %s AND time <= %s AND device_id = ANY(%s)"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cursor:
            await cursor.execute(statement, (start_time, end_time, device_ids))
            fetched = await cursor.fetchall()
        grouped = {}
        for entry in fetched:
            point = {"time": entry["time"].isoformat(), "value": entry[field]}
            grouped.setdefault(entry["device_id"], []).append(point)
        return grouped

    @staticmethod
    async def update_scada_field(
        conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any
    ):
        """Set one value column of the row identified by (time, device_id).

        Raises:
            ValueError: If ``field`` is not "monitored_value" or "cleaned_value"
        """
        if field not in ScadaRepository._VALUE_COLUMNS:
            raise ValueError(f"Invalid field: {field}")
        statement = sql.SQL(
            "UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cursor:
            await cursor.execute(statement, (value, time, device_id))

    @staticmethod
    async def delete_scada_by_id_time_range(
        conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime
    ):
        """Delete the rows of one device within [start_time, end_time]."""
        params = (device_id, start_time, end_time)
        async with conn.cursor() as cursor:
            await cursor.execute(
                "DELETE FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s",
                params,
            )

View File

@@ -0,0 +1,760 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
import globals
# Fixed-offset UTC+8 time zone used when normalising timestamps
UTC_8 = timezone(offset=timedelta(hours=8))
class SchemeRepository:
    """Data-access helpers for the scheme simulation tables.

    Covers ``scheme.link_simulation`` and ``scheme.node_simulation``. Every
    method is static and runs on a connection supplied by the caller.

    NOTE(review): the field queries read rows by column name
    (``row["time"]``), so the connection presumably uses a dict-style row
    factory -- confirm against the connection setup.
    """

    # Key columns present in every simulation row, in COPY order.
    _KEY_COLUMNS = ("time", "scheme_type", "scheme_name", "id")

    # Value columns in COPY order. The frozensets are the whitelists checked
    # before a caller-supplied field name is placed into SQL as an identifier.
    _LINK_VALUE_COLUMNS = (
        "flow",
        "friction",
        "headloss",
        "quality",
        "reaction",
        "setting",
        "status",
        "velocity",
    )
    _NODE_VALUE_COLUMNS = ("actual_demand", "total_head", "pressure", "quality")
    _LINK_FIELDS = frozenset(_LINK_VALUE_COLUMNS)
    _NODE_FIELDS = frozenset(_NODE_VALUE_COLUMNS)

    # (table column, key in the simulator's per-period result dict)
    _LINK_RESULT_KEYS = tuple((name, name) for name in _LINK_VALUE_COLUMNS)
    _NODE_RESULT_KEYS = (
        ("actual_demand", "demand"),
        ("total_head", "head"),
        ("pressure", "pressure"),
        ("quality", "quality"),
    )

    # Statements shared by the async and sync batch inserts.
    _LINK_BATCH_DELETE_SQL = "DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s"
    _LINK_BATCH_COPY_SQL = "COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
    _NODE_BATCH_DELETE_SQL = "DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s"
    _NODE_BATCH_COPY_SQL = "COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"

    # --- Private helpers ---
    @staticmethod
    def _require_field(field: str, allowed: frozenset) -> None:
        """Raise ValueError unless ``field`` is one of the whitelisted columns."""
        if field not in allowed:
            raise ValueError(f"Invalid field: {field}")

    @staticmethod
    def _batch_scope(data: List[dict]) -> tuple:
        """Return (distinct times, scheme_type, scheme_name) for a batch.

        The scheme is taken from the first row only, so a batch is expected
        to belong to a single scheme.
        """
        all_times = list(set(item["time"] for item in data))
        return all_times, data[0]["scheme_type"], data[0]["scheme_name"]

    @staticmethod
    def _copy_row(item: dict, value_columns: tuple) -> tuple:
        """Build one COPY row: required key columns, then optional values (None if absent)."""
        keys = tuple(item[name] for name in SchemeRepository._KEY_COLUMNS)
        return keys + tuple(item.get(name) for name in value_columns)

    @staticmethod
    def _to_utc8(value):
        """Normalise an ISO string or datetime into an aware datetime.

        A string ending in "Z" is parsed as UTC and converted to UTC+8. Any
        other value without an offset is assumed to be UTC+8 already; values
        that carry an offset keep it.
        """
        if isinstance(value, str):
            if value.endswith("Z"):
                utc_time = datetime.fromisoformat(value.replace("Z", "+00:00"))
                return utc_time.astimezone(UTC_8)
            value = datetime.fromisoformat(value)
        if value.tzinfo is None:
            value = value.replace(tzinfo=UTC_8)
        return value

    @staticmethod
    def _one_second_window(query_time) -> tuple:
        """Return (start, end) spanning one second either side of ``query_time``."""
        target_time = SchemeRepository._to_utc8(query_time)
        margin = timedelta(seconds=1)
        return target_time - margin, target_time + margin

    @staticmethod
    def _field_points(rows, field: str) -> List[Dict[str, Any]]:
        """Convert rows into ``{"time": ISO string, "value": field value}`` dicts."""
        return [{"time": row["time"].isoformat(), "value": row[field]} for row in rows]

    @staticmethod
    def _field_points_by_id(rows, field: str) -> dict:
        """Group ``{"time", "value"}`` points by the row's ``id`` column."""
        grouped = defaultdict(list)
        for row in rows:
            grouped[row["id"]].append(
                {"time": row["time"].isoformat(), "value": row[field]}
            )
        return dict(grouped)

    @staticmethod
    def _expand_results(
        result_list, id_key, result_keys, scheme_type, scheme_name, start, step, num_periods
    ) -> List[dict]:
        """Flatten simulator output into one table row per element and period.

        Raises IndexError when an element's "result" list has fewer than
        ``num_periods`` entries.
        """
        rows = []
        for entry in result_list:
            element_id = entry.get(id_key)
            for period_index in range(num_periods):
                current_time = start + (step * period_index)
                values = entry.get("result", [])[period_index]
                row = {
                    "time": current_time,
                    "scheme_type": scheme_type,
                    "scheme_name": scheme_name,
                    "id": element_id,
                }
                for column, source_key in result_keys:
                    row[column] = values.get(source_key)
                rows.append(row)
        return rows

    @staticmethod
    def _prepare_simulation_rows(
        scheme_type,
        scheme_name,
        node_result_list,
        link_result_list,
        result_start_time,
        num_periods,
    ) -> tuple:
        """Build (node rows, link rows) for a simulation run.

        Period ``i`` is stamped ``result_start_time + i * timestep`` where the
        timestep is parsed from ``globals.hydraulic_timestep`` ("H:M:S").
        """
        simulation_time = SchemeRepository._to_utc8(result_start_time)
        timestep_parts = globals.hydraulic_timestep.split(":")
        timestep = timedelta(
            hours=int(timestep_parts[0]),
            minutes=int(timestep_parts[1]),
            seconds=int(timestep_parts[2]),
        )
        node_data = SchemeRepository._expand_results(
            node_result_list,
            "node",
            SchemeRepository._NODE_RESULT_KEYS,
            scheme_type,
            scheme_name,
            simulation_time,
            timestep,
            num_periods,
        )
        link_data = SchemeRepository._expand_results(
            link_result_list,
            "link",
            SchemeRepository._LINK_RESULT_KEYS,
            scheme_type,
            scheme_name,
            simulation_time,
            timestep,
            num_periods,
        )
        return node_data, link_data

    @staticmethod
    async def _replace_batch(conn, data, delete_sql, copy_sql, value_columns):
        """Delete the batch's existing rows, then COPY the new ones, in one transaction."""
        all_times, scheme_type, scheme_name = SchemeRepository._batch_scope(data)
        async with conn.transaction():
            async with conn.cursor() as cur:
                # 1. Remove old rows for every time point of this scheme
                await cur.execute(delete_sql, (all_times, scheme_type, scheme_name))
                # 2. Stream the new rows with COPY
                async with cur.copy(copy_sql) as copy:
                    for item in data:
                        await copy.write_row(
                            SchemeRepository._copy_row(item, value_columns)
                        )

    @staticmethod
    def _replace_batch_sync(conn, data, delete_sql, copy_sql, value_columns):
        """Synchronous counterpart of ``_replace_batch``."""
        all_times, scheme_type, scheme_name = SchemeRepository._batch_scope(data)
        with conn.transaction():
            with conn.cursor() as cur:
                # 1. Remove old rows for every time point of this scheme
                cur.execute(delete_sql, (all_times, scheme_type, scheme_name))
                # 2. Stream the new rows with COPY
                with cur.copy(copy_sql) as copy:
                    for item in data:
                        copy.write_row(SchemeRepository._copy_row(item, value_columns))

    # --- Link Simulation ---
    @staticmethod
    async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for scheme.link_simulation using DELETE then COPY for performance.

        Existing rows with the batch's time points and the first row's
        scheme_type/scheme_name are deleted first. An empty batch is a no-op.
        """
        if not data:
            return
        await SchemeRepository._replace_batch(
            conn,
            data,
            SchemeRepository._LINK_BATCH_DELETE_SQL,
            SchemeRepository._LINK_BATCH_COPY_SQL,
            SchemeRepository._LINK_VALUE_COLUMNS,
        )

    @staticmethod
    def insert_links_batch_sync(conn: Connection, data: List[dict]):
        """Batch insert for scheme.link_simulation using DELETE then COPY for performance (sync version)."""
        if not data:
            return
        SchemeRepository._replace_batch_sync(
            conn,
            data,
            SchemeRepository._LINK_BATCH_DELETE_SQL,
            SchemeRepository._LINK_BATCH_COPY_SQL,
            SchemeRepository._LINK_VALUE_COLUMNS,
        )

    @staticmethod
    async def get_link_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
    ) -> List[dict]:
        """Return all rows of one link of a scheme within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme_type, scheme_name, start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all link rows of a scheme within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return one link column as ``[{"time": ISO string, "value": ...}, ...]``.

        Raises:
            ValueError: If ``field`` is not a link value column
        """
        # Validate field name to prevent SQL injection
        SchemeRepository._require_field(field, SchemeRepository._LINK_FIELDS)
        query = sql.SQL(
            "SELECT time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(
                query, (scheme_type, scheme_name, start_time, end_time, link_id)
            )
            rows = await cur.fetchall()
        return SchemeRepository._field_points(rows, field)

    @staticmethod
    async def get_links_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return one link column for every link, grouped as ``{id: [points]}``.

        Raises:
            ValueError: If ``field`` is not a link value column
        """
        # Validate field name to prevent SQL injection
        SchemeRepository._require_field(field, SchemeRepository._LINK_FIELDS)
        query = sql.SQL(
            "SELECT id, time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            rows = await cur.fetchall()
        return SchemeRepository._field_points_by_id(rows, field)

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        scheme_type: str,
        scheme_name: str,
        link_id: str,
        field: str,
        value: Any,
    ):
        """Set one value column of the link row at ``time`` for the given scheme.

        Raises:
            ValueError: If ``field`` is not a link value column
        """
        SchemeRepository._require_field(field, SchemeRepository._LINK_FIELDS)
        query = sql.SQL(
            "UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme_type, scheme_name, link_id))

    @staticmethod
    async def delete_links_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ):
        """Delete all link rows of a scheme within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )

    # --- Node Simulation ---
    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for scheme.node_simulation using DELETE then COPY.

        Existing rows with the batch's time points and the first row's
        scheme_type/scheme_name are deleted first. An empty batch is a no-op.
        """
        if not data:
            return
        await SchemeRepository._replace_batch(
            conn,
            data,
            SchemeRepository._NODE_BATCH_DELETE_SQL,
            SchemeRepository._NODE_BATCH_COPY_SQL,
            SchemeRepository._NODE_VALUE_COLUMNS,
        )

    @staticmethod
    def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
        """Batch insert for scheme.node_simulation using DELETE then COPY (sync version)."""
        if not data:
            return
        SchemeRepository._replace_batch_sync(
            conn,
            data,
            SchemeRepository._NODE_BATCH_DELETE_SQL,
            SchemeRepository._NODE_BATCH_COPY_SQL,
            SchemeRepository._NODE_VALUE_COLUMNS,
        )

    @staticmethod
    async def get_node_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
    ) -> List[dict]:
        """Return all rows of one node of a scheme within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme_type, scheme_name, start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all node rows of a scheme within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return one node column as ``[{"time": ISO string, "value": ...}, ...]``.

        Raises:
            ValueError: If ``field`` is not a node value column
        """
        # Validate field name to prevent SQL injection
        SchemeRepository._require_field(field, SchemeRepository._NODE_FIELDS)
        query = sql.SQL(
            "SELECT time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(
                query, (scheme_type, scheme_name, start_time, end_time, node_id)
            )
            rows = await cur.fetchall()
        return SchemeRepository._field_points(rows, field)

    @staticmethod
    async def get_nodes_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return one node column for every node, grouped as ``{id: [points]}``.

        Raises:
            ValueError: If ``field`` is not a node value column
        """
        # Validate field name to prevent SQL injection
        SchemeRepository._require_field(field, SchemeRepository._NODE_FIELDS)
        query = sql.SQL(
            "SELECT id, time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            rows = await cur.fetchall()
        return SchemeRepository._field_points_by_id(rows, field)

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        scheme_type: str,
        scheme_name: str,
        node_id: str,
        field: str,
        value: Any,
    ):
        """Set one value column of the node row at ``time`` for the given scheme.

        Raises:
            ValueError: If ``field`` is not a node value column
        """
        SchemeRepository._require_field(field, SchemeRepository._NODE_FIELDS)
        query = sql.SQL(
            "UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme_type, scheme_name, node_id))

    @staticmethod
    async def delete_nodes_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ):
        """Delete all node rows of a scheme within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )

    # --- Composite queries ---
    @staticmethod
    async def store_scheme_simulation_result(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
        num_periods: int = 1,
    ):
        """
        Store scheme simulation results to TimescaleDB.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            node_result_list: List of node simulation results; each entry
                holds the id under "node" and per-period dicts under "result"
            link_result_list: List of link simulation results; each entry
                holds the id under "link" and per-period dicts under "result"
            result_start_time: Start time for the results (ISO format string;
                a datetime is also accepted)
            num_periods: Number of periods to store from each "result" list

        Raises:
            IndexError: If a "result" list has fewer than num_periods entries
        """
        node_data, link_data = SchemeRepository._prepare_simulation_rows(
            scheme_type,
            scheme_name,
            node_result_list,
            link_result_list,
            result_start_time,
            num_periods,
        )
        # Insert data using batch methods
        if node_data:
            await SchemeRepository.insert_nodes_batch(conn, node_data)
        if link_data:
            await SchemeRepository.insert_links_batch(conn, link_data)

    @staticmethod
    def store_scheme_simulation_result_sync(
        conn: Connection,
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
        num_periods: int = 1,
    ):
        """
        Store scheme simulation results to TimescaleDB (sync version).

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            node_result_list: List of node simulation results; each entry
                holds the id under "node" and per-period dicts under "result"
            link_result_list: List of link simulation results; each entry
                holds the id under "link" and per-period dicts under "result"
            result_start_time: Start time for the results (ISO format string;
                a datetime is also accepted)
            num_periods: Number of periods to store from each "result" list

        Raises:
            IndexError: If a "result" list has fewer than num_periods entries
        """
        node_data, link_data = SchemeRepository._prepare_simulation_rows(
            scheme_type,
            scheme_name,
            node_result_list,
            link_result_list,
            result_start_time,
            num_periods,
        )
        # Insert data using batch methods
        if node_data:
            SchemeRepository.insert_nodes_batch_sync(conn, node_data)
        if link_data:
            SchemeRepository.insert_links_batch_sync(conn, link_data)

    @staticmethod
    async def query_all_record_by_scheme_time_property(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        query_time: str,
        type: str,
        property: str,
    ) -> list:
        """
        Query all records by scheme, time and property from TimescaleDB.

        The lookup window is ``query_time`` plus/minus one second.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            query_time: Time to query (ISO format string)
            type: Type of data ("node" or "link"), compared case-insensitively
            property: Property/field to query

        Returns:
            List of ``{"ID": <element id>, "value": <property value>}`` dicts

        Raises:
            ValueError: If ``type`` is neither "node" nor "link", or if
                ``property`` is not a valid column for that type
        """
        start_time, end_time = SchemeRepository._one_second_window(query_time)
        # Query based on type
        kind = type.lower()
        if kind == "node":
            data = await SchemeRepository.get_nodes_field_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, property
            )
        elif kind == "link":
            data = await SchemeRepository.get_links_field_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, property
            )
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
        # Format the results: one entry per (element, matching row)
        return [
            {"ID": element_id, "value": item["value"]}
            for element_id, items in data.items()
            for item in items
        ]

    @staticmethod
    async def query_scheme_simulation_result_by_id_time(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        id: str,
        type: str,
        query_time: str,
    ) -> list[dict]:
        """
        Query scheme simulation results by id and time from TimescaleDB.

        The lookup window is ``query_time`` plus/minus one second.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            id: The id of the node or link
            type: Type of data ("node" or "link"), compared case-insensitively
            query_time: Time to query (ISO format string)

        Returns:
            List of records matching the criteria

        Raises:
            ValueError: If ``type`` is neither "node" nor "link"
        """
        start_time, end_time = SchemeRepository._one_second_window(query_time)
        # Query based on type
        kind = type.lower()
        if kind == "node":
            return await SchemeRepository.get_node_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, id
            )
        if kind == "link":
            return await SchemeRepository.get_link_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, id
            )
        raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")

View File

@@ -0,0 +1,36 @@
from dotenv import load_dotenv
import os
# Let python-dotenv populate the environment before the lookups below.
load_dotenv()

# TimescaleDB connection settings; each one is None when its variable is unset.
pg_name = os.environ.get("TIMESCALEDB_DB_NAME")
pg_host = os.environ.get("TIMESCALEDB_DB_HOST")
pg_port = os.environ.get("TIMESCALEDB_DB_PORT")
pg_user = os.environ.get("TIMESCALEDB_DB_USER")
pg_password = os.environ.get("TIMESCALEDB_DB_PASSWORD")
def get_pgconn_string(
    db_name=pg_name,
    db_host=pg_host,
    db_port=pg_port,
    db_user=pg_user,
    db_password=pg_password,
):
    """Return a PostgreSQL keyword/value connection string.

    Each argument defaults to the matching module-level setting captured
    when this module was imported. Values are converted with ``str()``, so
    an unset (None) setting still appears as the text "None".

    Values that are empty or contain whitespace, a single quote or a
    backslash are single-quoted and escaped, because libpq would otherwise
    mis-parse them. Plain values are emitted unchanged.

    Returns:
        A string of the form
        ``dbname=... host=... port=... user=... password=...``
    """

    def _quote(value):
        # Quote/escape one value only when libpq's conninfo syntax needs it.
        text = str(value)
        needs_quoting = not text or any(
            ch.isspace() or ch in "'\\" for ch in text
        )
        if not needs_quoting:
            return text
        escaped = text.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    settings = (
        ("dbname", db_name),
        ("host", db_host),
        ("port", db_port),
        ("user", db_user),
        ("password", db_password),
    )
    return " ".join(f"{keyword}={_quote(value)}" for keyword, value in settings)
def get_pg_config():
    """Return the PostgreSQL settings (password excluded) as a dictionary."""
    labels = ("name", "host", "port", "user")
    settings = (pg_name, pg_host, pg_port, pg_user)
    return dict(zip(labels, settings))
def get_pg_password():
    """Return the configured database password (use with care)."""
    secret = pg_password
    return secret