元数据库目录结构变更
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .database import get_metadata_session, close_metadata_engine
|
||||
|
||||
__all__ = ["get_metadata_session", "close_metadata_engine"]
|
||||
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.METADATA_DATABASE_URI,
|
||||
pool_size=settings.METADATA_DB_POOL_SIZE,
|
||||
max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with SessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def close_metadata_engine() -> None:
|
||||
await engine.dispose()
|
||||
logger.info("Metadata database engine disposed.")
|
||||
@@ -0,0 +1,117 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
keycloak_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), unique=True, index=True
|
||||
)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
role: Mapped[str] = mapped_column(String(20), default="user")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Project(Base):
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
code: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class ProjectDatabase(Base):
|
||||
__tablename__ = "project_databases"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
db_role: Mapped[str] = mapped_column(String(20))
|
||||
db_type: Mapped[str] = mapped_column(String(20))
|
||||
dsn_encrypted: Mapped[str] = mapped_column(Text)
|
||||
pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
|
||||
pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
|
||||
|
||||
|
||||
class ProjectGeoServerConfig(Base):
|
||||
__tablename__ = "project_geoserver_configs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
project_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), unique=True, index=True
|
||||
)
|
||||
gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
|
||||
default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
srid: Mapped[int] = mapped_column(Integer, default=4326)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class UserProjectMembership(Base):
|
||||
__tablename__ = "user_project_membership"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
project_role: Mapped[str] = mapped_column(String(20), default="viewer")
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
project_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(50))
|
||||
resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||
request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
|
||||
request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
@@ -0,0 +1,112 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.db.metadb import models
|
||||
|
||||
|
||||
class AuditRepository:
|
||||
"""审计日志数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_log(
|
||||
self,
|
||||
action: str,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
) -> AuditLogResponse:
|
||||
log = models.AuditLog(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(log)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
return AuditLogResponse.model_validate(log)
|
||||
|
||||
async def get_logs(
|
||||
self,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> List[AuditLogResponse]:
|
||||
conditions = []
|
||||
if user_id is not None:
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = (
|
||||
select(models.AuditLog)
|
||||
.where(*conditions)
|
||||
.order_by(models.AuditLog.timestamp.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
AuditLogResponse.model_validate(log)
|
||||
for log in result.scalars().all()
|
||||
]
|
||||
|
||||
async def get_log_count(
|
||||
self,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
) -> int:
|
||||
conditions = []
|
||||
if user_id is not None:
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
|
||||
result = await self.session.execute(stmt)
|
||||
return int(result.scalar() or 0)
|
||||
@@ -0,0 +1,203 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.encryption import (
|
||||
get_database_encryptor,
|
||||
get_encryptor,
|
||||
is_database_encryption_configured,
|
||||
is_encryption_configured,
|
||||
)
|
||||
from app.infra.db.metadb import models
|
||||
|
||||
|
||||
def _normalize_postgres_dsn(dsn: str) -> str:
|
||||
if not dsn or "://" not in dsn:
|
||||
return dsn
|
||||
scheme, rest = dsn.split("://", 1)
|
||||
if scheme not in ("postgresql", "postgres", "postgresql+psycopg"):
|
||||
return dsn
|
||||
if "@" not in rest:
|
||||
return dsn
|
||||
userinfo, hostinfo = rest.rsplit("@", 1)
|
||||
if ":" not in userinfo:
|
||||
return dsn
|
||||
username, password = userinfo.split(":", 1)
|
||||
if "@" not in password:
|
||||
return dsn
|
||||
password = password.replace("@", "%40")
|
||||
return f"{scheme}://{username}:{password}@{hostinfo}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectDbRouting:
|
||||
project_id: UUID
|
||||
db_role: str
|
||||
db_type: str
|
||||
dsn: str
|
||||
pool_min_size: int
|
||||
pool_max_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectGeoServerInfo:
|
||||
project_id: UUID
|
||||
gs_base_url: Optional[str]
|
||||
gs_admin_user: Optional[str]
|
||||
gs_admin_password: Optional[str]
|
||||
gs_datastore_name: str
|
||||
default_extent: Optional[dict]
|
||||
srid: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectSummary:
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
|
||||
|
||||
class MetadataRepository:
|
||||
"""元数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_user_by_keycloak_id(self, keycloak_id: UUID) -> Optional[models.User]:
|
||||
result = await self.session.execute(
|
||||
select(models.User).where(models.User.keycloak_id == keycloak_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[models.User]:
|
||||
result = await self.session.execute(
|
||||
select(models.User).where(models.User.username == username)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
|
||||
result = await self.session.execute(
|
||||
select(models.Project).where(models.Project.id == project_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_membership_role(
|
||||
self, project_id: UUID, user_id: UUID
|
||||
) -> Optional[str]:
|
||||
result = await self.session.execute(
|
||||
select(models.UserProjectMembership.project_role).where(
|
||||
models.UserProjectMembership.project_id == project_id,
|
||||
models.UserProjectMembership.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_project_db_routing(
|
||||
self, project_id: UUID, db_role: str
|
||||
) -> Optional[ProjectDbRouting]:
|
||||
result = await self.session.execute(
|
||||
select(models.ProjectDatabase).where(
|
||||
models.ProjectDatabase.project_id == project_id,
|
||||
models.ProjectDatabase.db_role == db_role,
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return None
|
||||
if not is_database_encryption_configured():
|
||||
raise ValueError("DATABASE_ENCRYPTION_KEY is not configured")
|
||||
encryptor = get_database_encryptor()
|
||||
try:
|
||||
dsn = encryptor.decrypt(record.dsn_encrypted)
|
||||
except InvalidToken:
|
||||
raise ValueError(
|
||||
"Failed to decrypt project DB DSN: DATABASE_ENCRYPTION_KEY mismatch "
|
||||
"or invalid dsn_encrypted value"
|
||||
)
|
||||
dsn = _normalize_postgres_dsn(dsn)
|
||||
return ProjectDbRouting(
|
||||
project_id=record.project_id,
|
||||
db_role=record.db_role,
|
||||
db_type=record.db_type,
|
||||
dsn=dsn,
|
||||
pool_min_size=record.pool_min_size,
|
||||
pool_max_size=record.pool_max_size,
|
||||
)
|
||||
|
||||
async def get_geoserver_config(
|
||||
self, project_id: UUID
|
||||
) -> Optional[ProjectGeoServerInfo]:
|
||||
result = await self.session.execute(
|
||||
select(models.ProjectGeoServerConfig).where(
|
||||
models.ProjectGeoServerConfig.project_id == project_id
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return None
|
||||
if record.gs_admin_password_encrypted:
|
||||
if is_encryption_configured():
|
||||
encryptor = get_encryptor()
|
||||
password = encryptor.decrypt(record.gs_admin_password_encrypted)
|
||||
else:
|
||||
password = record.gs_admin_password_encrypted
|
||||
else:
|
||||
password = None
|
||||
return ProjectGeoServerInfo(
|
||||
project_id=record.project_id,
|
||||
gs_base_url=record.gs_base_url,
|
||||
gs_admin_user=record.gs_admin_user,
|
||||
gs_admin_password=password,
|
||||
gs_datastore_name=record.gs_datastore_name,
|
||||
default_extent=record.default_extent,
|
||||
srid=record.srid,
|
||||
)
|
||||
|
||||
async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
|
||||
stmt = (
|
||||
select(models.Project, models.UserProjectMembership.project_role)
|
||||
.join(
|
||||
models.UserProjectMembership,
|
||||
models.UserProjectMembership.project_id == models.Project.id,
|
||||
)
|
||||
.where(models.UserProjectMembership.user_id == user_id)
|
||||
.order_by(models.Project.name)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
ProjectSummary(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=role,
|
||||
)
|
||||
for project, role in result.all()
|
||||
]
|
||||
|
||||
async def list_all_projects(self) -> List[ProjectSummary]:
|
||||
result = await self.session.execute(
|
||||
select(models.Project).order_by(models.Project.name)
|
||||
)
|
||||
return [
|
||||
ProjectSummary(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role="owner",
|
||||
)
|
||||
for project in result.scalars().all()
|
||||
]
|
||||
@@ -0,0 +1,235 @@
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.domain.schemas.user import UserCreate, UserUpdate, UserInDB
|
||||
from app.domain.models.role import UserRole
|
||||
from app.core.security import get_password_hash
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UserRepository:
|
||||
"""用户数据访问层"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
|
||||
async def create_user(self, user: UserCreate) -> Optional[UserInDB]:
|
||||
"""
|
||||
创建新用户
|
||||
|
||||
Args:
|
||||
user: 用户创建数据
|
||||
|
||||
Returns:
|
||||
创建的用户对象
|
||||
"""
|
||||
hashed_password = get_password_hash(user.password)
|
||||
|
||||
query = """
|
||||
INSERT INTO users (username, email, hashed_password, role, is_active, is_superuser)
|
||||
VALUES (%(username)s, %(email)s, %(hashed_password)s, %(role)s, TRUE, FALSE)
|
||||
RETURNING id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {
|
||||
'username': user.username,
|
||||
'email': user.email,
|
||||
'hashed_password': hashed_password,
|
||||
'role': user.role.value
|
||||
})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_id(self, user_id: int) -> Optional[UserInDB]:
|
||||
"""根据ID获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE id = %(user_id)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'user_id': user_id})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[UserInDB]:
|
||||
"""根据用户名获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE username = %(username)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'username': username})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> Optional[UserInDB]:
|
||||
"""根据邮箱获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE email = %(email)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'email': email})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_all_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
|
||||
"""获取所有用户(分页)"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %(limit)s OFFSET %(skip)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'skip': skip, 'limit': limit})
|
||||
rows = await cur.fetchall()
|
||||
return [UserInDB(**row) for row in rows]
|
||||
|
||||
async def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[UserInDB]:
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
user_update: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的用户对象
|
||||
"""
|
||||
# 构建动态更新语句
|
||||
update_fields = []
|
||||
params = {'user_id': user_id}
|
||||
|
||||
if user_update.email is not None:
|
||||
update_fields.append("email = %(email)s")
|
||||
params['email'] = user_update.email
|
||||
|
||||
if user_update.password is not None:
|
||||
update_fields.append("hashed_password = %(hashed_password)s")
|
||||
params['hashed_password'] = get_password_hash(user_update.password)
|
||||
|
||||
if user_update.role is not None:
|
||||
update_fields.append("role = %(role)s")
|
||||
params['role'] = user_update.role.value
|
||||
|
||||
if user_update.is_active is not None:
|
||||
update_fields.append("is_active = %(is_active)s")
|
||||
params['is_active'] = user_update.is_active
|
||||
|
||||
if not update_fields:
|
||||
return await self.get_user_by_id(user_id)
|
||||
|
||||
query = f"""
|
||||
UPDATE users
|
||||
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %(user_id)s
|
||||
RETURNING id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
async def delete_user(self, user_id: int) -> bool:
|
||||
"""
|
||||
删除用户
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
query = "DELETE FROM users WHERE id = %(user_id)s"
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'user_id': user_id})
|
||||
return cur.rowcount > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user {user_id}: {e}")
|
||||
return False
|
||||
|
||||
async def user_exists(self, username: str = None, email: str = None) -> bool:
|
||||
"""
|
||||
检查用户是否存在
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
email: 邮箱
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
if email:
|
||||
conditions.append("email = %(email)s")
|
||||
params['email'] = email
|
||||
|
||||
if not conditions:
|
||||
return False
|
||||
|
||||
query = f"""
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM users WHERE {' OR '.join(conditions)}
|
||||
)
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
result = await cur.fetchone()
|
||||
return result['exists'] if result else False
|
||||
Reference in New Issue
Block a user