重构现代化 FastAPI 后端项目框架
This commit is contained in:
1
app/infra/db/postgresql/__init__.py
Normal file
1
app/infra/db/postgresql/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .router import router
|
||||
108
app/infra/db/postgresql/database.py
Normal file
108
app/infra/db/postgresql/database.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Dict, Optional
|
||||
import psycopg_pool
|
||||
from psycopg.rows import dict_row
|
||||
import app.native.api.postgresql_info as postgresql_info
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_name=None):
|
||||
self.pool = None
|
||||
self.db_name = db_name
|
||||
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
conn_string = postgresql_info.get_pgconn_string()
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
min_size=5,
|
||||
max_size=20,
|
||||
open=False, # Don't open immediately, wait for startup
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(f"PostgreSQL connection pool initialized for database: default")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize postgresql connection pool: {e}")
|
||||
raise
|
||||
|
||||
async def open(self):
|
||||
if self.pool:
|
||||
await self.pool.open()
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection pool."""
|
||||
if self.pool:
|
||||
await self.pool.close()
|
||||
logger.info("PostgreSQL connection pool closed.")
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self) -> AsyncGenerator:
|
||||
"""Get a connection from the pool."""
|
||||
if not self.pool:
|
||||
raise Exception("Database pool is not initialized.")
|
||||
|
||||
async with self.pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
# 默认数据库实例
|
||||
db = Database()
|
||||
|
||||
# 缓存不同数据库的实例 - 避免重复创建连接池
|
||||
_database_instances: Dict[str, Database] = {}
|
||||
|
||||
|
||||
def create_database_instance(db_name):
|
||||
"""Create a new Database instance for a specific database."""
|
||||
return Database(db_name=db_name)
|
||||
|
||||
|
||||
async def get_database_instance(db_name: Optional[str] = None) -> Database:
|
||||
"""Get or create a database instance for the specified database name."""
|
||||
if not db_name:
|
||||
return db # 返回默认数据库实例
|
||||
|
||||
if db_name not in _database_instances:
|
||||
# 创建新的数据库实例
|
||||
instance = create_database_instance(db_name)
|
||||
instance.init_pool()
|
||||
await instance.open()
|
||||
_database_instances[db_name] = instance
|
||||
logger.info(f"Created new database instance for: {db_name}")
|
||||
|
||||
return _database_instances[db_name]
|
||||
|
||||
|
||||
async def get_db_connection():
|
||||
"""Dependency for FastAPI to get a database connection."""
|
||||
async with db.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def get_database_connection(db_name: Optional[str] = None):
|
||||
"""
|
||||
FastAPI dependency to get database connection with optional database name.
|
||||
使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))
|
||||
或在路由函数中: conn: AsyncConnection = Depends(get_database_connection)
|
||||
"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def cleanup_database_instances():
|
||||
"""Clean up all database instances (call this on application shutdown)."""
|
||||
for db_name, instance in _database_instances.items():
|
||||
await instance.close()
|
||||
logger.info(f"Closed database instance for: {db_name}")
|
||||
_database_instances.clear()
|
||||
|
||||
# 关闭默认数据库
|
||||
await db.close()
|
||||
logger.info("All database instances cleaned up.")
|
||||
83
app/infra/db/postgresql/internal_queries.py
Normal file
83
app/infra/db/postgresql/internal_queries.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi.logger import logger
|
||||
import postgresql_info
|
||||
import psycopg
|
||||
|
||||
|
||||
class InternalQueries:
|
||||
@staticmethod
|
||||
def get_links_by_property(
|
||||
fields: Optional[List[str]] = None,
|
||||
property_conditions: Optional[dict] = None,
|
||||
db_name: str = None,
|
||||
max_retries: int = 3,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
查询pg数据库中,pipes 的指定字段记录或根据属性筛选
|
||||
:param fields: 要查询的字段列表,如 ["id", "diameter", "status"],默认查询所有字段
|
||||
:param property: 可选的筛选条件字典,如 {"status": "Open"} 或 {"diameter": 300}
|
||||
:param db_name: 数据库名称
|
||||
:param max_retries: 最大重试次数
|
||||
:return: 包含所有记录的列表,每条记录为一个字典
|
||||
"""
|
||||
# 如果未指定字段,查询所有字段
|
||||
if not fields:
|
||||
fields = [
|
||||
"id",
|
||||
"node1",
|
||||
"node2",
|
||||
"length",
|
||||
"diameter",
|
||||
"roughness",
|
||||
"minor_loss",
|
||||
"status",
|
||||
]
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn_string = (
|
||||
postgresql_info.get_pgconn_string(db_name=db_name)
|
||||
if db_name
|
||||
else postgresql_info.get_pgconn_string()
|
||||
)
|
||||
with psycopg.Connection.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 构建SELECT子句
|
||||
select_fields = ", ".join(fields)
|
||||
base_query = f"""
|
||||
SELECT {select_fields}
|
||||
FROM public.pipes
|
||||
"""
|
||||
|
||||
# 如果提供了筛选条件,构建WHERE子句
|
||||
if property_conditions:
|
||||
conditions = []
|
||||
params = []
|
||||
for key, value in property_conditions.items():
|
||||
conditions.append(f"{key} = %s")
|
||||
params.append(value)
|
||||
|
||||
query = base_query + " WHERE " + " AND ".join(conditions)
|
||||
cur.execute(query, params)
|
||||
else:
|
||||
cur.execute(base_query)
|
||||
|
||||
records = cur.fetchall()
|
||||
# 将查询结果转换为字典列表
|
||||
pipes = []
|
||||
for record in records:
|
||||
pipe_dict = {}
|
||||
for idx, field in enumerate(fields):
|
||||
pipe_dict[field] = record[idx]
|
||||
pipes.append(pipe_dict)
|
||||
|
||||
return pipes
|
||||
break # 成功
|
||||
except Exception as e:
|
||||
logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(1)
|
||||
else:
|
||||
raise
|
||||
90
app/infra/db/postgresql/router.py
Normal file
90
app/infra/db/postgresql/router.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Optional
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .scada_info import ScadaRepository
|
||||
from .scheme import SchemeRepository
|
||||
|
||||
router = APIRouter(prefix="/postgresql", tags=["postgresql"])
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
@router.get("/scada-info")
|
||||
async def get_scada_info_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有SCADA信息
|
||||
"""
|
||||
try:
|
||||
# 使用ScadaRepository查询SCADA信息
|
||||
scada_data = await ScadaRepository.get_scadas(conn)
|
||||
return {"success": True, "data": scada_data, "count": len(scada_data)}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"查询SCADA信息时发生错误: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/scheme-list")
|
||||
async def get_scheme_list_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有方案信息
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询方案信息
|
||||
scheme_data = await SchemeRepository.get_schemes(conn)
|
||||
return {"success": True, "data": scheme_data, "count": len(scheme_data)}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询方案信息时发生错误: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/burst-locate-result")
|
||||
async def get_burst_locate_result_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有爆管定位结果
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询爆管定位结果
|
||||
burst_data = await SchemeRepository.get_burst_locate_results(conn)
|
||||
return {"success": True, "data": burst_data, "count": len(burst_data)}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"查询爆管定位结果时发生错误: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/burst-locate-result/{burst_incident}")
|
||||
async def get_burst_locate_result_by_incident(
|
||||
burst_incident: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
根据 burst_incident 查询爆管定位结果
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询爆管定位结果
|
||||
return await SchemeRepository.get_burst_locate_result_by_incident(
|
||||
conn, burst_incident
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"根据 burst_incident 查询爆管定位结果时发生错误: {str(e)}",
|
||||
)
|
||||
36
app/infra/db/postgresql/scada_info.py
Normal file
36
app/infra/db/postgresql/scada_info.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import List, Optional, Any
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
|
||||
class ScadaRepository:
|
||||
|
||||
@staticmethod
|
||||
async def get_scadas(conn: AsyncConnection) -> List[dict]:
|
||||
"""
|
||||
查询pg数据库中,scada_info 的所有记录
|
||||
:param conn: 异步数据库连接
|
||||
:return: 包含所有记录的列表,每条记录为一个字典
|
||||
"""
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT id, type, associated_element_id, transmission_mode, transmission_frequency, reliability
|
||||
FROM public.scada_info
|
||||
"""
|
||||
)
|
||||
records = await cur.fetchall()
|
||||
# 将查询结果转换为字典列表(假设 record 是字典)
|
||||
scada_infos = []
|
||||
for record in records:
|
||||
scada_infos.append(
|
||||
{
|
||||
"id": record["id"], # 使用字典键
|
||||
"type": record["type"],
|
||||
"associated_element_id": record["associated_element_id"],
|
||||
"transmission_mode": record["transmission_mode"],
|
||||
"transmission_frequency": record["transmission_frequency"],
|
||||
"reliability": record["reliability"],
|
||||
}
|
||||
)
|
||||
|
||||
return scada_infos
|
||||
104
app/infra/db/postgresql/scheme.py
Normal file
104
app/infra/db/postgresql/scheme.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from typing import List, Optional, Any
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
|
||||
class SchemeRepository:
|
||||
|
||||
@staticmethod
|
||||
async def get_schemes(conn: AsyncConnection) -> List[dict]:
|
||||
"""
|
||||
查询pg数据库中, scheme_list 的所有记录
|
||||
:param conn: 异步数据库连接
|
||||
:return: 包含所有记录的列表, 每条记录为一个字典
|
||||
"""
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT scheme_id, scheme_name, scheme_type, username, create_time, scheme_start_time, scheme_detail
|
||||
FROM public.scheme_list
|
||||
"""
|
||||
)
|
||||
records = await cur.fetchall()
|
||||
|
||||
scheme_list = []
|
||||
for record in records:
|
||||
scheme_list.append(
|
||||
{
|
||||
"scheme_id": record["scheme_id"],
|
||||
"scheme_name": record["scheme_name"],
|
||||
"scheme_type": record["scheme_type"],
|
||||
"username": record["username"],
|
||||
"create_time": record["create_time"],
|
||||
"scheme_start_time": record["scheme_start_time"],
|
||||
"scheme_detail": record["scheme_detail"],
|
||||
}
|
||||
)
|
||||
|
||||
return scheme_list
|
||||
|
||||
@staticmethod
|
||||
async def get_burst_locate_results(conn: AsyncConnection) -> List[dict]:
|
||||
"""
|
||||
查询pg数据库中, burst_locate_result 的所有记录
|
||||
:param conn: 异步数据库连接
|
||||
:return: 包含所有记录的列表, 每条记录为一个字典
|
||||
"""
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT id, type, burst_incident, leakage, detect_time, locate_result
|
||||
FROM public.burst_locate_result
|
||||
"""
|
||||
)
|
||||
records = await cur.fetchall()
|
||||
|
||||
results = []
|
||||
for record in records:
|
||||
results.append(
|
||||
{
|
||||
"id": record["id"],
|
||||
"type": record["type"],
|
||||
"burst_incident": record["burst_incident"],
|
||||
"leakage": record["leakage"],
|
||||
"detect_time": record["detect_time"],
|
||||
"locate_result": record["locate_result"],
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
async def get_burst_locate_result_by_incident(
|
||||
conn: AsyncConnection, burst_incident: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
根据 burst_incident 查询爆管定位结果
|
||||
:param conn: 异步数据库连接
|
||||
:param burst_incident: 爆管事件标识
|
||||
:return: 包含匹配记录的列表
|
||||
"""
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT id, type, burst_incident, leakage, detect_time, locate_result
|
||||
FROM public.burst_locate_result
|
||||
WHERE burst_incident = %s
|
||||
""",
|
||||
(burst_incident,),
|
||||
)
|
||||
records = await cur.fetchall()
|
||||
|
||||
results = []
|
||||
for record in records:
|
||||
results.append(
|
||||
{
|
||||
"id": record["id"],
|
||||
"type": record["type"],
|
||||
"burst_incident": record["burst_incident"],
|
||||
"leakage": record["leakage"],
|
||||
"detect_time": record["detect_time"],
|
||||
"locate_result": record["locate_result"],
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
Reference in New Issue
Block a user