import asyncio
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Tuple, TypeVar
from uuid import UUID

from psycopg_pool import AsyncConnectionPool
from psycopg.rows import dict_row
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from app.core.config import settings

logger = logging.getLogger(__name__)

# Value type held by one of the LRU caches (an engine entry or a raw pool).
_V = TypeVar("_V")


@dataclass(frozen=True)
class PgEngineEntry:
    """A cached SQLAlchemy async engine paired with its session factory."""

    engine: AsyncEngine
    sessionmaker: async_sessionmaker[AsyncSession]


@dataclass(frozen=True)
class CacheKey:
    """Identifies one cached connection resource: a (project, DB role) pair."""

    project_id: UUID
    db_role: str


def _clamp_pool_sizes(pool_min_size: int, pool_max_size: int) -> Tuple[int, int]:
    """Return ``(min_size, max_size)`` with ``min_size >= 1`` and ``max_size >= min_size``."""
    min_size = max(1, pool_min_size)
    return min_size, max(min_size, pool_max_size)


def _pop_overflow(
    cache: "OrderedDict[CacheKey, _V]", limit: int
) -> List[Tuple[CacheKey, _V]]:
    """Remove and return least-recently-used entries until ``len(cache) <= limit``.

    This only mutates the dict. Closing the returned resources is left to the
    caller, so it can happen after the cache lock has been released.

    The limit is clamped to at least 1. The entry the caller has just inserted
    sits at the most-recently-used end, and evicting it would hand the caller
    an already-closed resource.
    """
    limit = max(1, limit)
    evicted: List[Tuple[CacheKey, _V]] = []
    while len(cache) > limit:
        evicted.append(cache.popitem(last=False))
    return evicted


async def _close_resource(
    close: Callable[[], Awaitable[Any]], key: CacheKey, success_msg: str
) -> None:
    """Await ``close()``, then log ``success_msg`` formatted with the key.

    Failures are logged with a traceback rather than raised. Closing is
    best-effort cleanup: one resource that fails to shut down must not stop
    the others from being closed, and must not fail the request whose cache
    insert triggered the eviction.
    """
    try:
        await close()
    except Exception:
        logger.exception(
            "Failed to close connection resource for project %s (%s)",
            key.project_id,
            key.db_role,
        )
        return
    logger.info(success_msg, key.project_id, key.db_role)


class ProjectConnectionManager:
    """Per-project LRU caches of database connection resources.

    Three independent caches are kept. Each is keyed by ``CacheKey`` and
    guarded by its own ``asyncio.Lock``, so concurrent callers asking for the
    same key share one resource instead of each creating their own:

    * SQLAlchemy async engines and sessionmakers (``get_pg_sessionmaker``)
    * raw psycopg pools for TimescaleDB (``get_timescale_pool``)
    * raw psycopg pools for PostgreSQL (``get_pg_pool``)

    When a cache grows past its configured size, its least recently used
    entries are removed and closed.

    NOTE(review): a caller that obtained a pool or engine before it was
    evicted still holds a reference to the now-closed pool or disposed
    engine. Confirm the cache sizes comfortably exceed the number of
    projects in concurrent use.
    """

    def __init__(self) -> None:
        # OrderedDict order doubles as recency order: hits are moved to the
        # end, and evictions pop from the front.
        self._pg_cache: "OrderedDict[CacheKey, PgEngineEntry]" = OrderedDict()
        self._ts_cache: "OrderedDict[CacheKey, AsyncConnectionPool]" = OrderedDict()
        self._pg_raw_cache: "OrderedDict[CacheKey, AsyncConnectionPool]" = (
            OrderedDict()
        )
        self._pg_lock = asyncio.Lock()
        self._ts_lock = asyncio.Lock()
        self._pg_raw_lock = asyncio.Lock()

    def _normalize_pg_url(self, url: str) -> str:
        """Return ``url`` with a bare ``postgresql`` driver switched to ``postgresql+psycopg``.

        ``render_as_string(hide_password=False)`` is required here. In
        SQLAlchemy 2.x, ``str(URL)`` masks the password as ``***``, which
        would make the engine authenticate with that literal string.
        """
        parsed = make_url(url)
        if parsed.drivername == "postgresql":
            parsed = parsed.set(drivername="postgresql+psycopg")
        return parsed.render_as_string(hide_password=False)

    async def get_pg_sessionmaker(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> async_sessionmaker[AsyncSession]:
        """Return the cached sessionmaker for a project/role, creating its engine on first use.

        ``connection_url`` and the pool sizes are only used on a cache miss.
        On a hit, the existing engine is reused as-is. The cache holds at
        most ``settings.PROJECT_PG_CACHE_SIZE`` engines (at least 1).
        Evicted engines are disposed after the cache lock is released.
        """
        key = CacheKey(project_id=project_id, db_role=db_role)
        async with self._pg_lock:
            entry = self._pg_cache.get(key)
            if entry is not None:
                self._pg_cache.move_to_end(key)
                return entry.sessionmaker

            min_size, max_size = _clamp_pool_sizes(pool_min_size, pool_max_size)
            engine = create_async_engine(
                self._normalize_pg_url(connection_url),
                # SQLAlchemy expresses "max" as persistent pool_size plus
                # max_overflow extra connections.
                pool_size=min_size,
                max_overflow=max_size - min_size,
                pool_pre_ping=True,
            )
            sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
            self._pg_cache[key] = PgEngineEntry(
                engine=engine,
                sessionmaker=sessionmaker,
            )
            evicted = _pop_overflow(self._pg_cache, settings.PROJECT_PG_CACHE_SIZE)

        logger.info(
            "Created PostgreSQL engine for project %s (%s)", project_id, db_role
        )
        for old_key, old_entry in evicted:
            await _close_resource(
                old_entry.engine.dispose,
                old_key,
                "Evicted PostgreSQL engine for project %s (%s)",
            )
        return sessionmaker

    async def _get_or_create_pool(
        self,
        cache: "OrderedDict[CacheKey, AsyncConnectionPool]",
        lock: asyncio.Lock,
        cache_limit: int,
        label: str,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Shared LRU lookup and creation for the raw psycopg pool caches.

        ``label`` names the database in log messages ("TimescaleDB" or
        "PostgreSQL"). The pool is opened while the lock is held, so two
        callers for the same key never create two pools. Evicted pools are
        closed after the lock is released.
        """
        key = CacheKey(project_id=project_id, db_role=db_role)
        async with lock:
            pool = cache.get(key)
            if pool is not None:
                cache.move_to_end(key)
                return pool

            min_size, max_size = _clamp_pool_sizes(pool_min_size, pool_max_size)
            pool = AsyncConnectionPool(
                conninfo=connection_url,
                min_size=min_size,
                max_size=max_size,
                open=False,
                kwargs={"row_factory": dict_row},
            )
            await pool.open()
            cache[key] = pool
            evicted = _pop_overflow(cache, cache_limit)

        logger.info(
            "Created %s pool for project %s (%s)", label, project_id, db_role
        )
        for old_key, old_pool in evicted:
            await _close_resource(
                old_pool.close,
                old_key,
                f"Evicted {label} pool for project %s (%s)",
            )
        return pool

    async def get_timescale_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return the cached, opened TimescaleDB pool for a project/role.

        The pool is created on first use. ``connection_url`` and the pool
        sizes are only used on a cache miss. The cache holds at most
        ``settings.PROJECT_TS_CACHE_SIZE`` pools (at least 1).
        """
        return await self._get_or_create_pool(
            cache=self._ts_cache,
            lock=self._ts_lock,
            cache_limit=settings.PROJECT_TS_CACHE_SIZE,
            label="TimescaleDB",
            project_id=project_id,
            db_role=db_role,
            connection_url=connection_url,
            pool_min_size=pool_min_size,
            pool_max_size=pool_max_size,
        )

    async def get_pg_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return the cached, opened raw PostgreSQL pool for a project/role.

        The pool is created on first use. ``connection_url`` and the pool
        sizes are only used on a cache miss. This cache shares the
        ``settings.PROJECT_PG_CACHE_SIZE`` limit (at least 1) with the engine
        cache, but each cache applies that limit separately.
        """
        return await self._get_or_create_pool(
            cache=self._pg_raw_cache,
            lock=self._pg_raw_lock,
            cache_limit=settings.PROJECT_PG_CACHE_SIZE,
            label="PostgreSQL",
            project_id=project_id,
            db_role=db_role,
            connection_url=connection_url,
            pool_min_size=pool_min_size,
            pool_max_size=pool_max_size,
        )

    async def close_all(self) -> None:
        """Dispose every cached engine, close every cached pool, and empty all caches.

        Each cache is detached before its resources are closed, so an
        interruption never leaves closed resources in a cache. Closing is
        best-effort: a failure is logged and the remaining resources are
        still closed.
        """
        async with self._pg_lock:
            engines = list(self._pg_cache.items())
            self._pg_cache.clear()
            for key, entry in engines:
                await _close_resource(
                    entry.engine.dispose,
                    key,
                    "Closed PostgreSQL engine for project %s (%s)",
                )

        for cache, lock, label in (
            (self._ts_cache, self._ts_lock, "TimescaleDB"),
            (self._pg_raw_cache, self._pg_raw_lock, "PostgreSQL"),
        ):
            async with lock:
                pools = list(cache.items())
                cache.clear()
                for key, pool in pools:
                    await _close_resource(
                        pool.close,
                        key,
                        f"Closed {label} pool for project %s (%s)",
                    )


project_connection_manager = ProjectConnectionManager()