调整代码,支持项目切换,打开不同数据库的连接

This commit is contained in:
2026-02-11 10:42:40 +08:00
parent a472639b8a
commit eb45e4aaa5
5 changed files with 60 additions and 6 deletions

View File

@@ -17,7 +17,9 @@ class Database:
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
conn_string = postgresql_info.get_pgconn_string()
target_db_name = db_name or self.db_name
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
@@ -26,7 +28,7 @@ class Database:
open=False, # Don't open immediately, wait for startup
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(f"PostgreSQL connection pool initialized for database: default")
logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}")
except Exception as e:
logger.error(f"Failed to initialize postgresql connection pool: {e}")
raise

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from typing import Optional
from psycopg import AsyncConnection
from pydantic import BaseModel
from .database import get_database_instance
from .scada_info import ScadaRepository
@@ -9,6 +10,10 @@ from .scheme import SchemeRepository
router = APIRouter()
class DatabaseConfig(BaseModel):
db_name: str
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
db_name: Optional[str] = Query(
@@ -21,6 +26,26 @@ async def get_database_connection(
yield conn
@router.post("/postgres/open-database")
async def open_database(config: DatabaseConfig):
"""
尝试连接指定数据库,如果成功则返回成功消息
"""
try:
instance = await get_database_instance(config.db_name)
# 尝试获取连接并执行简单查询以验证连接
async with instance.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute("SELECT 1")
await cur.fetchone()
return {"success": True, "message": f"Successfully connected to database: {config.db_name}"}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}"
)
@router.get("/scada-info")
async def get_scada_info_with_connection(
conn: AsyncConnection = Depends(get_database_connection),

View File

@@ -17,7 +17,8 @@ class Database:
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
conn_string = timescaledb_info.get_pgconn_string()
target_db_name = db_name or self.db_name
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
@@ -27,7 +28,7 @@ class Database:
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(
f"TimescaleDB connection pool initialized for database: default"
f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}"
)
except Exception as e:
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")

View File

@@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from datetime import datetime
from psycopg import AsyncConnection
from pydantic import BaseModel
from .database import get_database_instance
from .schemas.realtime import RealtimeRepository
@@ -15,6 +16,10 @@ from app.infra.db.postgresql.database import (
router = APIRouter()
class DatabaseConfig(BaseModel):
db_name: str
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
db_name: Optional[str] = Query(
@@ -39,6 +44,26 @@ async def get_postgres_connection(
yield conn
@router.post("/timescaledb/open-database")
async def open_database(config: DatabaseConfig):
"""
尝试连接指定数据库,如果成功则返回成功消息
"""
try:
instance = await get_database_instance(config.db_name)
# 尝试获取连接并执行简单查询以验证连接
async with instance.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute("SELECT 1")
await cur.fetchone()
return {"success": True, "message": f"Successfully connected to database: {config.db_name}"}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}"
)
# --- Realtime Endpoints ---

View File

@@ -1235,7 +1235,7 @@ def run_simulation(
starttime = time.time()
if simulation_type.upper() == "REALTIME":
TimescaleInternalStorage.store_realtime_simulation(
node_result, link_result, modify_pattern_start_time
node_result, link_result, modify_pattern_start_time, db_name=name
)
elif simulation_type.upper() == "EXTENDED":
TimescaleInternalStorage.store_scheme_simulation(
@@ -1245,6 +1245,7 @@ def run_simulation(
link_result,
modify_pattern_start_time,
num_periods_result,
db_name=name,
)
endtime = time.time()
logging.info("store time: %f", endtime - starttime)