From d04d29cfb33e4fd1cfb90e8ab965accc5cf7860e Mon Sep 17 00:00:00 2001 From: JIANG Date: Thu, 4 Dec 2025 12:01:07 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=20timescaledb=20=E7=9A=84=20?= =?UTF-8?q?CRUD=20=E6=96=B9=E6=B3=95=E3=80=81fastapi=20=E8=B7=AF=E7=94=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 4 ++ timescaledb/__init__.py | 0 timescaledb/database.py | 55 +++++++++++++++ timescaledb/router.py | 105 +++++++++++++++++++++++++++++ timescaledb/schemas/realtime.py | 114 ++++++++++++++++++++++++++++++++ timescaledb/schemas/scada.py | 39 +++++++++++ timescaledb/schemas/scheme.py | 78 ++++++++++++++++++++++ timescaledb/timescaledb_info.py | 36 ++++++++++ 8 files changed, 431 insertions(+) create mode 100644 timescaledb/__init__.py create mode 100644 timescaledb/database.py create mode 100644 timescaledb/router.py create mode 100644 timescaledb/schemas/realtime.py create mode 100644 timescaledb/schemas/scada.py create mode 100644 timescaledb/schemas/scheme.py create mode 100644 timescaledb/timescaledb_info.py diff --git a/main.py b/main.py index f576146..fa7809b 100644 --- a/main.py +++ b/main.py @@ -39,6 +39,7 @@ from datetime import datetime, timedelta, timezone from dateutil import parser import influxdb_info import influxdb_api +import timescaledb import py_linq import time_api import simulation @@ -115,6 +116,8 @@ async def verify_token(authorization: Annotated[str, Header()] = None): # app = FastAPI(dependencies=[Depends(global_auth)]) app = FastAPI() +app.include_router(timescaledb.router) + access_tokens = [] @@ -3444,6 +3447,7 @@ async def fastapi_run_simulation_manually_by_date( item["name"], region_result ) ) + ( globals.source_outflow_region_patterns, globals.realtime_region_pipe_flow_and_demand_patterns, diff --git a/timescaledb/__init__.py b/timescaledb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/timescaledb/database.py b/timescaledb/database.py new file mode 100644 index 0000000..8778536 --- /dev/null +++ b/timescaledb/database.py @@ -0,0 +1,55 @@ +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator +import psycopg_pool +from psycopg.rows import dict_row +import timescaledb.timescaledb_info as timescaledb_info + +# Configure logging +logger = logging.getLogger(__name__) + +class Database: + def __init__(self): + self.pool = None + + def init_pool(self): + """Initialize the connection pool.""" + conn_string = timescaledb_info.get_pgconn_string() + try: + self.pool = psycopg_pool.AsyncConnectionPool( + conninfo=conn_string, + min_size=1, + max_size=20, + open=False, # Don't open immediately, wait for startup + kwargs={"row_factory": dict_row} # Return rows as dictionaries + ) + logger.info("TimescaleDB connection pool initialized.") + except Exception as e: + logger.error(f"Failed to initialize TimescaleDB connection pool: {e}") + raise + + async def open(self): + if self.pool: + await self.pool.open() + + async def close(self): + """Close the connection pool.""" + if self.pool: + await self.pool.close() + logger.info("TimescaleDB connection pool closed.") + + @asynccontextmanager + async def get_connection(self) -> AsyncGenerator: + """Get a connection from the pool.""" + if not self.pool: + raise Exception("Database pool is not initialized.") + + async with self.pool.connection() as conn: + yield conn + +db = Database() + +async def get_db_connection(): + """Dependency for FastAPI to get a database connection.""" + async with db.get_connection() as conn: + yield conn diff --git a/timescaledb/router.py b/timescaledb/router.py new file mode 100644 index 0000000..5c28861 --- /dev/null +++ b/timescaledb/router.py @@ -0,0 +1,105 @@ +from fastapi import APIRouter, Depends, HTTPException, Query +from typing import List, Any, Dict +from datetime import datetime +from psycopg import AsyncConnection + +from .database import get_db_connection +from .schemas.realtime import RealtimeRepository +from .schemas.scheme import SchemeRepository +from .schemas.scada import ScadaRepository + +router = APIRouter(prefix="/timescaledb", tags=["TimescaleDB"]) + +# --- Realtime Endpoints --- + +@router.post("/realtime/links/batch", status_code=201) +async def insert_realtime_links( + data: List[dict], + conn: AsyncConnection = Depends(get_db_connection) +): + await RealtimeRepository.insert_links_batch(conn, data) + return {"message": f"Inserted {len(data)} records"} + +@router.get("/realtime/links") +async def get_realtime_links( + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_db_connection) +): + return await RealtimeRepository.get_links_by_time(conn, start_time, end_time) + +@router.patch("/realtime/links/{link_id}/field") +async def update_realtime_link_field( + link_id: str, + time: datetime, + field: str, + value: float, # Assuming float for now, could be Any but FastAPI needs type + conn: AsyncConnection = Depends(get_db_connection) +): + try: + await RealtimeRepository.update_link_field(conn, time, link_id, field, value) + return {"message": "Updated successfully"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.post("/realtime/nodes/batch", status_code=201) +async def insert_realtime_nodes( + data: List[dict], + conn: AsyncConnection = Depends(get_db_connection) +): + await RealtimeRepository.insert_nodes_batch(conn, data) + return {"message": f"Inserted {len(data)} records"} + +@router.get("/realtime/nodes") +async def get_realtime_nodes( + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_db_connection) +): + return await RealtimeRepository.get_nodes_by_time(conn, start_time, end_time) + +# --- Scheme Endpoints --- + +@router.post("/scheme/links/batch", status_code=201) +async def insert_scheme_links( + data: List[dict], + conn: AsyncConnection = Depends(get_db_connection) +): + await SchemeRepository.insert_links_batch(conn, data) + return {"message": f"Inserted {len(data)} records"} + +@router.get("/scheme/links") +async def get_scheme_links( + scheme: str, + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_db_connection) +): + return await SchemeRepository.get_links_by_scheme_and_time(conn, scheme, start_time, end_time) + +@router.post("/scheme/nodes/batch", status_code=201) +async def insert_scheme_nodes( + data: List[dict], + conn: AsyncConnection = Depends(get_db_connection) +): + await SchemeRepository.insert_nodes_batch(conn, data) + return {"message": f"Inserted {len(data)} records"} + +# --- SCADA Endpoints --- + +@router.post("/scada/batch", status_code=201) +async def insert_scada_data( + data: List[dict], + conn: AsyncConnection = Depends(get_db_connection) +): + await ScadaRepository.insert_batch(conn, data) + return {"message": f"Inserted {len(data)} records"} + +@router.get("/scada") +async def get_scada_data( + device_id: str, + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_db_connection) +): + return await ScadaRepository.get_data_by_time(conn, device_id, start_time, end_time) diff --git a/timescaledb/schemas/realtime.py b/timescaledb/schemas/realtime.py new file mode 100644 index 0000000..9fda2ab --- /dev/null +++ b/timescaledb/schemas/realtime.py @@ -0,0 +1,114 @@ +from typing import List, Any, Optional +from datetime import datetime +from psycopg import AsyncConnection, sql + +class RealtimeRepository: + + # --- Link Simulation --- + + @staticmethod + async def insert_links_batch(conn: AsyncConnection, data: List[dict]): + """Batch insert for realtime.link_simulation using COPY for performance.""" + if not data: + return + + async with conn.cursor() as cur: + async with cur.copy( + "COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN" + ) as copy: + for item in data: + await copy.write_row(( + item['time'], item['id'], item.get('flow'), item.get('friction'), item.get('headloss'), + item.get('quality'), item.get('reaction'), item.get('setting'), item.get('status'), item.get('velocity') + )) + + @staticmethod + async def get_links_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s", + (start_time, end_time) + ) + return await cur.fetchall() + + @staticmethod + async def get_link_field(conn: AsyncConnection, time: datetime, link_id: str, field: str) -> Any: + # Validate field name to prevent SQL injection + valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL("SELECT {} FROM realtime.link_simulation WHERE time = %s AND id = %s").format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (time, link_id)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def update_link_field(conn: AsyncConnection, time: datetime, link_id: str, field: str, value: Any): + valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL("UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s").format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (value, time, link_id)) + + @staticmethod + async def delete_links_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime): + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s", + (start_time, end_time) + ) + + # --- Node Simulation --- + + @staticmethod + async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]): + if not data: + return + + async with conn.cursor() as cur: + async with cur.copy( + "COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN" + ) as copy: + for item in data: + await copy.write_row(( + item['time'], item['id'], item.get('actual_demand'), item.get('total_head'), item.get('pressure'), item.get('quality') + )) + + @staticmethod + async def get_nodes_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s", + (start_time, end_time) + ) + return await cur.fetchall() + + @staticmethod + async def get_node_field(conn: AsyncConnection, time: datetime, node_id: str, field: str) -> Any: + valid_fields = {"actual_demand", "total_head", "pressure", "quality"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL("SELECT {} FROM realtime.node_simulation WHERE time = %s AND id = %s").format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (time, node_id)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def update_node_field(conn: AsyncConnection, time: datetime, node_id: str, field: str, value: Any): + valid_fields = {"actual_demand", "total_head", "pressure", "quality"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL("UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s").format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (value, time, node_id)) diff --git a/timescaledb/schemas/scada.py b/timescaledb/schemas/scada.py new file mode 100644 index 0000000..a4c4202 --- /dev/null +++ b/timescaledb/schemas/scada.py @@ -0,0 +1,39 @@ +from typing import List, Any +from datetime import datetime +from psycopg import AsyncConnection, sql + +class ScadaRepository: + + @staticmethod + async def insert_batch(conn: AsyncConnection, data: List[dict]): + if not data: + return + + async with conn.cursor() as cur: + async with cur.copy( + "COPY scada.scada_data (time, device_id, monitored_value, cleaned_value) FROM STDIN" + ) as copy: + for item in data: + await copy.write_row(( + item['time'], item['device_id'], item.get('monitored_value'), item.get('cleaned_value') + )) + + @staticmethod + async def get_data_by_time(conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s", + (device_id, start_time, end_time) + ) + return await cur.fetchall() + + @staticmethod + async def update_field(conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any): + valid_fields = {"monitored_value", "cleaned_value"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL("UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s").format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (value, time, device_id)) diff --git a/timescaledb/schemas/scheme.py b/timescaledb/schemas/scheme.py new file mode 100644 index 0000000..ef699e8 --- /dev/null +++ b/timescaledb/schemas/scheme.py @@ -0,0 +1,78 @@ +from typing import List, Any +from datetime import datetime +from psycopg import AsyncConnection, sql + +class SchemeRepository: + + # --- Link Simulation --- + + @staticmethod + async def insert_links_batch(conn: AsyncConnection, data: List[dict]): + if not data: + return + + async with conn.cursor() as cur: + async with cur.copy( + "COPY scheme.link_simulation (time, scheme, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN" + ) as copy: + for item in data: + await copy.write_row(( + item['time'], item['scheme'], item['id'], item.get('flow'), item.get('friction'), item.get('headloss'), + item.get('quality'), item.get('reaction'), item.get('setting'), item.get('status'), item.get('velocity') + )) + + @staticmethod + async def get_links_by_scheme_and_time(conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s", + (scheme, start_time, end_time) + ) + return await cur.fetchall() + + @staticmethod + async def update_link_field(conn: AsyncConnection, time: datetime, scheme: str, link_id: str, field: str, value: Any): + valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL("UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s").format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (value, time, scheme, link_id)) + + # --- Node Simulation --- + + @staticmethod + async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]): + if not data: + return + + async with conn.cursor() as cur: + async with cur.copy( + "COPY scheme.node_simulation (time, scheme, id, actual_demand, total_head, pressure, quality) FROM STDIN" + ) as copy: + for item in data: + await copy.write_row(( + item['time'], item['scheme'], item['id'], item.get('actual_demand'), item.get('total_head'), item.get('pressure'), item.get('quality') + )) + + @staticmethod + async def get_nodes_by_scheme_and_time(conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s", + (scheme, start_time, end_time) + ) + return await cur.fetchall() + + @staticmethod + async def update_node_field(conn: AsyncConnection, time: datetime, scheme: str, node_id: str, field: str, value: Any): + valid_fields = {"actual_demand", "total_head", "pressure", "quality"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL("UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s").format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (value, time, scheme, node_id)) diff --git a/timescaledb/timescaledb_info.py b/timescaledb/timescaledb_info.py new file mode 100644 index 0000000..3a1fbd8 --- /dev/null +++ b/timescaledb/timescaledb_info.py @@ -0,0 +1,36 @@ +from dotenv import load_dotenv +import os + +load_dotenv() + +pg_name = os.getenv("TIMESCALEDB_DB_NAME") +pg_host = os.getenv("TIMESCALEDB_DB_HOST") +pg_port = os.getenv("TIMESCALEDB_DB_PORT") +pg_user = os.getenv("TIMESCALEDB_DB_USER") +pg_password = os.getenv("TIMESCALEDB_DB_PASSWORD") + + +def get_pgconn_string( + db_name=pg_name, + db_host=pg_host, + db_port=pg_port, + db_user=pg_user, + db_password=pg_password, +): + """返回 PostgreSQL 连接字符串""" + return f"dbname={db_name} host={db_host} port={db_port} user={db_user} password={db_password}" + + +def get_pg_config(): + """返回 PostgreSQL 配置变量的字典""" + return { + "name": pg_name, + "host": pg_host, + "port": pg_port, + "user": pg_user, + } + + +def get_pg_password(): + """返回密码(谨慎使用)""" + return pg_password