Files
TJWaterServerBinary/timescaledb/database.py

116 lines
3.8 KiB
Python

import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import timescaledb.timescaledb_info as timescaledb_info
# Module-level logger, named after this module via __name__
logger = logging.getLogger(__name__)
class Database:
    """Wrapper around one async psycopg connection pool for TimescaleDB.

    Lifecycle: construct, call ``init_pool()`` to build the pool (left
    closed), ``await open()`` to open it, and ``await close()`` to shut it
    down. Connections are borrowed through ``get_connection()``.
    """

    def __init__(self, db_name=None):
        """Remember the target database name; no pool is created yet.

        Args:
            db_name: Name of the database this instance targets. A falsy
                value means the default connection string from
                ``timescaledb_info`` is used.
        """
        self.pool = None
        self.db_name = db_name

    def init_pool(self, db_name=None):
        """Initialize the connection pool without opening it.

        Args:
            db_name: Optional database name overriding the one given to the
                constructor for this pool only.

        Raises:
            Exception: Whatever the pool constructor raises; the error is
                logged and re-raised unchanged.
        """
        # Use provided db_name, or the one from constructor, or default from
        # config. Previously the name was ignored and every pool was built
        # from the default connection string.
        target_db_name = db_name or self.db_name
        if target_db_name:
            conn_string = timescaledb_info.get_pgconn_string(
                db_name=target_db_name
            )
        else:
            # Keep the original no-argument call for the default database.
            conn_string = timescaledb_info.get_pgconn_string()
        try:
            self.pool = psycopg_pool.AsyncConnectionPool(
                conninfo=conn_string,
                min_size=5,
                max_size=20,
                open=False,  # Don't open immediately, wait for startup
                kwargs={"row_factory": dict_row},  # Return rows as dictionaries
            )
            logger.info(
                "TimescaleDB connection pool initialized for database: "
                f"{target_db_name or 'default'}"
            )
        except Exception as e:
            logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
            raise

    async def open(self):
        """Open the connection pool; does nothing if no pool was initialized."""
        if self.pool:
            await self.pool.open()

    async def close(self):
        """Close the connection pool; does nothing if no pool was initialized."""
        if self.pool:
            await self.pool.close()
            logger.info("TimescaleDB connection pool closed.")

    def get_pgconn_string(self, db_name=None):
        """Get the TimescaleDB connection string.

        Args:
            db_name: Optional database name; falls back to the name given
                to the constructor.

        Returns:
            The value returned by ``timescaledb_info.get_pgconn_string``.
        """
        target_db_name = db_name or self.db_name
        return timescaledb_info.get_pgconn_string(db_name=target_db_name)

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator:
        """Yield a connection borrowed from the pool.

        The connection is handed back when the ``async with`` block exits.

        Raises:
            RuntimeError: If ``init_pool()`` has not created a pool yet.
        """
        if not self.pool:
            # RuntimeError is an Exception subclass, so existing
            # ``except Exception`` handlers keep working.
            raise RuntimeError("Database pool is not initialized.")
        async with self.pool.connection() as conn:
            yield conn
# Default database instance
db = Database()
# Cache of Database instances for other databases, keyed by database name,
# so a connection pool is not created repeatedly for the same database
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
    """Build a new Database object bound to the given database name.

    Only the object is constructed here; this function does not call
    ``init_pool()`` or ``open()`` on it.
    """
    instance = Database(db_name=db_name)
    return instance
async def get_database_instance(db_name: Optional[str] = None) -> Database:
    """Get or create the Database instance for the given database name.

    Args:
        db_name: Target database name. A falsy value selects the module's
            default instance ``db``.

    Returns:
        The cached instance for ``db_name``, created, initialized and opened
        on first use.
    """
    if not db_name:
        return db  # Default database instance

    cached = _database_instances.get(db_name)
    if cached is not None:
        return cached

    # Create a new database instance and bring its pool up.
    instance = create_database_instance(db_name)
    instance.init_pool()
    await instance.open()

    # Another coroutine may have registered an instance for the same name
    # while this one was suspended in open(). Keep the first registration
    # and close our duplicate so its pool is not leaked.
    registered = _database_instances.setdefault(db_name, instance)
    if registered is not instance:
        await instance.close()
        return registered

    logger.info(f"Created new database instance for: {db_name}")
    return instance
async def get_db_connection():
    """FastAPI dependency yielding a connection from the default database.

    The connection is borrowed from the module-level ``db`` instance and is
    released when the generator is resumed and finishes.
    """
    default_database = db
    async with default_database.get_connection() as connection:
        yield connection
async def get_database_connection(db_name: Optional[str] = None):
    """
    FastAPI dependency to get a database connection with an optional database name.

    Usage: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))
    or, in a route function: conn: AsyncConnection = Depends(get_database_connection)

    NOTE(review): the lambda form returns this async generator object from a
    plain callable; it presumably is not driven as a yield-dependency by
    FastAPI in that case - confirm before relying on it.
    """
    database = await get_database_instance(db_name)
    connection_scope = database.get_connection()
    async with connection_scope as connection:
        yield connection
async def cleanup_database_instances():
    """Clean up all database instances (call this on application shutdown).

    Closes every cached per-database instance, empties the cache, then
    closes the default instance ``db``. A failure while closing one cached
    instance is logged and does not stop the remaining ones, or the default
    instance, from being closed.
    """
    # Snapshot and clear up front: iterating the live dict across awaits
    # would raise RuntimeError if another coroutine inserted an entry
    # while a close() call was suspended.
    pending = list(_database_instances.items())
    _database_instances.clear()

    for db_name, instance in pending:
        try:
            await instance.close()
        except Exception:
            # Best-effort shutdown: record the failure and keep going.
            logger.exception(f"Failed to close database instance for: {db_name}")
        else:
            logger.info(f"Closed database instance for: {db_name}")

    # Close the default database
    await db.close()
    logger.info("All database instances cleaned up.")