"""Async PostgreSQL connection-pool management.

Provides a ``Database`` wrapper around ``psycopg_pool.AsyncConnectionPool``,
a default module-level instance (``db``), a cache of per-database-name
instances, and FastAPI-style dependencies that yield pooled connections.
"""

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, List, Optional

import psycopg_pool
from psycopg.rows import dict_row

import postgresql_info

# Module logger.
logger = logging.getLogger(__name__)


class Database:
    """Owns one async connection pool for a single PostgreSQL database.

    Lifecycle: :meth:`init_pool` builds the pool (not yet open), :meth:`open`
    opens it, :meth:`get_connection` borrows connections, :meth:`close`
    closes it.
    """

    def __init__(self, db_name: Optional[str] = None):
        """Remember the target database name; no pool is created here.

        Args:
            db_name: Database to connect to. When falsy, the name resolved by
                ``postgresql_info.get_pgconn_string`` is used (presumably a
                configured default -- see that module).
        """
        self.pool: Optional[psycopg_pool.AsyncConnectionPool] = None
        self.db_name = db_name

    def init_pool(self, db_name: Optional[str] = None) -> None:
        """Create the connection pool without opening it.

        Args:
            db_name: Overrides the constructor's ``db_name`` for this call.
                Falls back to ``self.db_name`` when falsy.

        Raises:
            Exception: Anything raised while constructing the pool; it is
                logged with a traceback and re-raised unchanged.
        """
        # Explicit argument wins, then the constructor value; if both are
        # falsy, postgresql_info decides which database to use.
        target_db_name = db_name or self.db_name
        conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
        try:
            self.pool = psycopg_pool.AsyncConnectionPool(
                conninfo=conn_string,
                min_size=1,
                max_size=20,
                open=False,  # Don't open immediately; open() is called later.
                kwargs={"row_factory": dict_row},  # Return rows as dictionaries.
            )
            logger.info(
                "PostgreSQL connection pool initialized for database: %s",
                target_db_name or "default",
            )
        except Exception as e:
            logger.exception("Failed to initialize postgresql connection pool: %s", e)
            raise

    async def open(self) -> None:
        """Open the pool; does nothing if :meth:`init_pool` was never called."""
        if self.pool is not None:
            await self.pool.open()

    async def close(self) -> None:
        """Close the pool if one was created; does nothing otherwise."""
        if self.pool is not None:
            await self.pool.close()
            logger.info("PostgreSQL connection pool closed.")

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator:
        """Borrow a connection from the pool for the duration of the block.

        Yields:
            A pooled connection whose rows are returned as dicts.

        Raises:
            RuntimeError: If :meth:`init_pool` has not been called.
        """
        if self.pool is None:
            # RuntimeError subclasses Exception, so existing
            # ``except Exception`` handlers still catch this.
            raise RuntimeError("Database pool is not initialized.")
        async with self.pool.connection() as conn:
            yield conn


# Default database instance. This module does not initialize or open it;
# presumably that happens at application startup.
db = Database()

# Cache of per-database instances, keyed by database name, so each database
# gets exactly one connection pool. NOTE(review): the cache is unbounded.
_database_instances: Dict[str, Database] = {}

# Serializes creation of new cached instances. It is created lazily, inside a
# running event loop, instead of at import time, so it always binds to the
# loop that actually uses it.
_instances_lock: Optional[asyncio.Lock] = None


def _get_instances_lock() -> asyncio.Lock:
    """Return the lock that guards creation of entries in ``_database_instances``."""
    global _instances_lock
    # There is no await between the check and the assignment, so this cannot
    # race within a single event loop.
    if _instances_lock is None:
        _instances_lock = asyncio.Lock()
    return _instances_lock


def create_database_instance(db_name):
    """Create a new, uninitialized ``Database`` for ``db_name``."""
    return Database(db_name=db_name)


async def get_database_instance(db_name: Optional[str] = None) -> Database:
    """Return the cached ``Database`` for ``db_name``, creating it on first use.

    A newly created instance has its pool initialized and opened before it is
    cached. A falsy ``db_name`` returns the module-level default ``db`` as-is.

    Args:
        db_name: Name of the database, or ``None``/``""`` for the default.

    Returns:
        The ``Database`` instance for that name.
    """
    if not db_name:
        return db  # Return the default database instance.

    # Fast path: already created, no locking needed.
    instance = _database_instances.get(db_name)
    if instance is not None:
        return instance

    async with _get_instances_lock():
        # Re-check under the lock. Another coroutine may have created the
        # instance while we waited. Without this, two concurrent callers could
        # each open a pool, and one pool would be overwritten and leaked.
        instance = _database_instances.get(db_name)
        if instance is None:
            instance = create_database_instance(db_name)
            instance.init_pool()
            await instance.open()
            # Cache only after a successful open, so failures are retried.
            _database_instances[db_name] = instance
            logger.info("Created new database instance for: %s", db_name)
    return instance


async def get_db_connection():
    """FastAPI dependency that yields a connection from the default ``db`` pool."""
    async with db.get_connection() as conn:
        yield conn


async def get_database_connection(db_name: Optional[str] = None):
    """Yield a connection from the pool for ``db_name`` (the default pool if falsy).

    Usage as a FastAPI dependency for a fixed database: wrap it in your own
    ``async def`` generator. Do NOT use
    ``Depends(lambda: get_database_connection("x"))``. The lambda is a plain
    function, so FastAPI would inject the un-iterated async generator object
    instead of a connection::

        async def get_reports_connection():
            instance = await get_database_instance("reports")
            async with instance.get_connection() as conn:
                yield conn

        conn: AsyncConnection = Depends(get_reports_connection)

    WARNING: ``Depends(get_database_connection)`` exposes ``db_name`` as an
    optional query parameter. Clients can then choose the database, and every
    distinct name creates and caches a new pool.
    """
    instance = await get_database_instance(db_name)
    async with instance.get_connection() as conn:
        yield conn


async def cleanup_database_instances() -> None:
    """Close all cached instances and then the default ``db`` (call on shutdown).

    Every pool gets a close attempt even if an earlier close fails. Each
    failure is logged, and the first one is re-raised after all attempts.
    """
    # Snapshot and clear first. Iterating the live dict across awaits would
    # raise RuntimeError if another coroutine inserted an entry meanwhile.
    instances = list(_database_instances.items())
    _database_instances.clear()

    errors: List[Exception] = []
    for db_name, instance in instances:
        try:
            await instance.close()
        except Exception as exc:
            logger.exception("Failed to close database instance for: %s", db_name)
            errors.append(exc)
            continue
        logger.info("Closed database instance for: %s", db_name)

    # Close the default database.
    try:
        await db.close()
    except Exception as exc:
        logger.exception("Failed to close default database instance")
        errors.append(exc)

    if errors:
        raise errors[0]
    logger.info("All database instances cleaned up.")