212 lines
6.9 KiB
Python
212 lines
6.9 KiB
Python
import asyncio
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Tuple
from uuid import UUID

from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from app.core.config import settings
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class PgEngineEntry:
    """Immutable pairing of an async SQLAlchemy engine and a sessionmaker.

    The two are cached together so that whoever evicts the entry still has
    a handle on the engine and can dispose it. In this module the
    sessionmaker is created as ``async_sessionmaker(engine, ...)``.
    """

    # Engine that owns the underlying connection pool.
    engine: AsyncEngine
    # Factory producing AsyncSession objects (bound to ``engine`` here).
    sessionmaker: async_sessionmaker[AsyncSession]
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class CacheKey:
    """Hashable identifier of one cached engine/pool: project plus DB role.

    ``frozen=True`` together with the default ``eq=True`` gives value-based
    equality and a generated ``__hash__``, so instances can be dict keys.
    """

    # Project that owns the cached resource.
    project_id: UUID
    # Database role label supplied by the caller.
    db_role: str
|
|
|
|
|
|
class ProjectConnectionManager:
    """Cache per-project SQLAlchemy engines and psycopg connection pools.

    Three independent LRU caches are kept. Each is keyed by ``CacheKey``
    (project id + database role) and guarded by its own ``asyncio.Lock``:

    * ``_pg_cache``: SQLAlchemy async engines and their sessionmakers,
      bounded by ``settings.PROJECT_PG_CACHE_SIZE``;
    * ``_ts_cache``: psycopg pools returned by ``get_timescale_pool``,
      bounded by ``settings.PROJECT_TS_CACHE_SIZE``;
    * ``_pg_raw_cache``: psycopg pools returned by ``get_pg_pool``,
      also bounded by ``settings.PROJECT_PG_CACHE_SIZE``.

    When a cache grows past its bound, the least recently used entries are
    removed and their engine is disposed or their pool is closed.

    NOTE(review): the cache key does not include ``connection_url`` or the
    pool sizes. A different URL for an already-cached (project, role) pair
    keeps getting the old engine/pool until it is evicted. Confirm this is
    intended (e.g. under credential rotation).
    NOTE(review): evicted resources are closed while the cache lock is
    held. A caller that obtained the evicted engine/pool earlier may still
    be using it.
    """

    def __init__(self) -> None:
        # OrderedDict rather than dict: the LRU logic relies on
        # move_to_end() and popitem(last=False).
        self._pg_cache: OrderedDict[CacheKey, PgEngineEntry] = OrderedDict()
        self._ts_cache: OrderedDict[CacheKey, AsyncConnectionPool] = OrderedDict()
        self._pg_raw_cache: OrderedDict[CacheKey, AsyncConnectionPool] = OrderedDict()
        self._pg_lock = asyncio.Lock()
        self._ts_lock = asyncio.Lock()
        self._pg_raw_lock = asyncio.Lock()

    def _normalize_pg_url(self, url: str) -> str:
        """Return *url* rewritten to use SQLAlchemy's async psycopg driver.

        Bare ``postgresql://`` URLs, and the legacy ``postgres://`` alias,
        are switched to ``postgresql+psycopg://``. URLs that already name a
        driver are returned unchanged apart from re-rendering.
        """
        parsed = make_url(url)
        if parsed.drivername in ("postgresql", "postgres"):
            parsed = parsed.set(drivername="postgresql+psycopg")
        # str(URL) masks the password as "***" in SQLAlchemy 2.x. The engine
        # would then try to authenticate with that literal, so render the
        # real password explicitly.
        return parsed.render_as_string(hide_password=False)

    @staticmethod
    def _normalize_pool_sizes(pool_min_size: int, pool_max_size: int) -> Tuple[int, int]:
        """Clamp requested pool sizes so that ``1 <= min_size <= max_size``."""
        min_size = max(1, pool_min_size)
        return min_size, max(min_size, pool_max_size)

    async def get_pg_sessionmaker(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> async_sessionmaker[AsyncSession]:
        """Return the cached sessionmaker for a project/role, creating it if needed.

        Args:
            project_id: Project the engine belongs to (part of the cache key).
            db_role: Database role label (part of the cache key).
            connection_url: PostgreSQL URL. It is normalized to the async
                psycopg driver before the engine is created.
            pool_min_size: Desired SQLAlchemy ``pool_size`` (clamped to >= 1).
            pool_max_size: Desired maximum connections. The excess over the
                minimum becomes ``max_overflow``.

        Returns:
            An ``async_sessionmaker`` bound to the project's engine, created
            with ``expire_on_commit=False``.
        """
        key = CacheKey(project_id=project_id, db_role=db_role)
        async with self._pg_lock:
            entry = self._pg_cache.get(key)
            if entry is not None:
                self._pg_cache.move_to_end(key)  # mark as most recently used
                return entry.sessionmaker

            min_size, max_size = self._normalize_pool_sizes(pool_min_size, pool_max_size)
            engine = create_async_engine(
                self._normalize_pg_url(connection_url),
                pool_size=min_size,
                max_overflow=max_size - min_size,  # >= 0 after clamping
                pool_pre_ping=True,
            )
            sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
            self._pg_cache[key] = PgEngineEntry(engine=engine, sessionmaker=sessionmaker)
            await self._evict_pg_if_needed()
            logger.info(
                "Created PostgreSQL engine for project %s (%s)", project_id, db_role
            )
            return sessionmaker

    async def get_timescale_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return the cached TimescaleDB pool for a project/role, opening it if needed.

        Args:
            project_id: Project the pool belongs to (part of the cache key).
            db_role: Database role label (part of the cache key).
            connection_url: psycopg conninfo, passed through unchanged.
            pool_min_size: Minimum pool size (clamped to >= 1).
            pool_max_size: Maximum pool size (clamped to >= the minimum).

        Returns:
            An opened ``AsyncConnectionPool`` whose connections use ``dict_row``.
        """
        return await self._get_or_create_pool(
            cache=self._ts_cache,
            lock=self._ts_lock,
            evict=self._evict_ts_if_needed,
            label="TimescaleDB pool",
            project_id=project_id,
            db_role=db_role,
            connection_url=connection_url,
            pool_min_size=pool_min_size,
            pool_max_size=pool_max_size,
        )

    async def get_pg_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return the cached raw PostgreSQL pool for a project/role, opening it if needed.

        Args:
            project_id: Project the pool belongs to (part of the cache key).
            db_role: Database role label (part of the cache key).
            connection_url: psycopg conninfo, passed through unchanged.
            pool_min_size: Minimum pool size (clamped to >= 1).
            pool_max_size: Maximum pool size (clamped to >= the minimum).

        Returns:
            An opened ``AsyncConnectionPool`` whose connections use ``dict_row``.
        """
        return await self._get_or_create_pool(
            cache=self._pg_raw_cache,
            lock=self._pg_raw_lock,
            evict=self._evict_pg_raw_if_needed,
            label="PostgreSQL pool",
            project_id=project_id,
            db_role=db_role,
            connection_url=connection_url,
            pool_min_size=pool_min_size,
            pool_max_size=pool_max_size,
        )

    async def _get_or_create_pool(
        self,
        *,
        cache: "OrderedDict[CacheKey, AsyncConnectionPool]",
        lock: asyncio.Lock,
        evict: Callable[[], Awaitable[None]],
        label: str,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Shared LRU lookup/creation logic for the two psycopg pool caches."""
        key = CacheKey(project_id=project_id, db_role=db_role)
        async with lock:
            pool = cache.get(key)
            if pool is not None:
                cache.move_to_end(key)  # mark as most recently used
                return pool

            min_size, max_size = self._normalize_pool_sizes(pool_min_size, pool_max_size)
            pool = AsyncConnectionPool(
                conninfo=connection_url,
                min_size=min_size,
                max_size=max_size,
                open=False,  # opened explicitly below, inside the running loop
                kwargs={"row_factory": dict_row},
            )
            await pool.open()
            cache[key] = pool
            await evict()
            logger.info("Created %s for project %s (%s)", label, project_id, db_role)
            return pool

    async def _evict_pg_if_needed(self) -> None:
        """Dispose LRU engines until the cache fits PROJECT_PG_CACHE_SIZE."""
        # max(1, ...) always keeps the newest entry, which the caller is about
        # to return. Without it, a size <= 0 would close that very resource.
        await self._pop_and_close(
            self._pg_cache,
            max(1, settings.PROJECT_PG_CACHE_SIZE),
            lambda entry: entry.engine.dispose(),
            "Evicted",
            "PostgreSQL engine",
        )

    async def _evict_ts_if_needed(self) -> None:
        """Close LRU TimescaleDB pools until the cache fits PROJECT_TS_CACHE_SIZE."""
        await self._pop_and_close(
            self._ts_cache,
            max(1, settings.PROJECT_TS_CACHE_SIZE),  # never evict the newest entry
            lambda pool: pool.close(),
            "Evicted",
            "TimescaleDB pool",
        )

    async def _evict_pg_raw_if_needed(self) -> None:
        """Close LRU raw PostgreSQL pools until the cache fits PROJECT_PG_CACHE_SIZE."""
        # This cache shares the PROJECT_PG_CACHE_SIZE bound with the engine cache.
        await self._pop_and_close(
            self._pg_raw_cache,
            max(1, settings.PROJECT_PG_CACHE_SIZE),  # never evict the newest entry
            lambda pool: pool.close(),
            "Evicted",
            "PostgreSQL pool",
        )

    @staticmethod
    async def _pop_and_close(
        cache: "OrderedDict[CacheKey, Any]",
        keep: int,
        close: Callable[[Any], Awaitable[None]],
        action: str,
        label: str,
    ) -> None:
        """Remove and close oldest entries until ``len(cache) <= keep``.

        Each entry is popped before it is closed, so the cache never holds a
        closed resource. A failure to close one resource is logged. It does
        not stop the loop and is not propagated to the caller.

        Args:
            cache: The LRU cache to shrink (oldest entries first).
            keep: Number of entries to retain.
            close: Coroutine factory that releases one cached value.
            action: Verb for the success log line ("Evicted" / "Closed").
            label: Human-readable resource kind for log lines.
        """
        while len(cache) > keep:
            key, resource = cache.popitem(last=False)
            try:
                await close(resource)
            except Exception:
                logger.exception(
                    "Failed to close %s for project %s (%s)",
                    label,
                    key.project_id,
                    key.db_role,
                )
                continue
            logger.info(
                "%s %s for project %s (%s)", action, label, key.project_id, key.db_role
            )

    async def close_all(self) -> None:
        """Dispose every cached engine, close every cached pool, and empty all caches.

        Each cache is drained under its own lock. A failure to close one
        resource is logged and does not prevent the rest from being closed.
        """
        async with self._pg_lock:
            await self._pop_and_close(
                self._pg_cache,
                0,
                lambda entry: entry.engine.dispose(),
                "Closed",
                "PostgreSQL engine",
            )
        async with self._ts_lock:
            await self._pop_and_close(
                self._ts_cache, 0, lambda pool: pool.close(), "Closed", "TimescaleDB pool"
            )
        async with self._pg_raw_lock:
            await self._pop_and_close(
                self._pg_raw_cache, 0, lambda pool: pool.close(), "Closed", "PostgreSQL pool"
            )
|
|
|
|
|
|
# Module-level shared instance. Presumably intended to be imported and reused
# so that all callers share the same engine/pool caches. Confirm that no
# caller constructs its own ProjectConnectionManager.
project_connection_manager: ProjectConnectionManager = ProjectConnectionManager()
|