import logging
from datetime import datetime
from typing import Optional, List

from app.infra.db.postgresql.database import Database
from app.domain.schemas.user import UserCreate, UserUpdate, UserInDB
from app.domain.models.role import UserRole
from app.core.security import get_password_hash

logger = logging.getLogger(__name__)


class UserRepository:
    """Data-access layer for the ``users`` table.

    Every query runs on a connection obtained from ``Database.get_connection()``.
    Rows are read by column name (``UserInDB(**row)``, ``result['exists']``),
    so the cursor is assumed to yield mapping-style rows — TODO confirm
    against the ``Database`` implementation.
    """

    # Column list shared by every SELECT / RETURNING clause whose row is
    # turned into a ``UserInDB``; defined once so the queries cannot drift apart.
    _USER_COLUMNS = (
        "id, username, email, hashed_password, role, "
        "is_active, is_superuser, created_at, updated_at"
    )

    def __init__(self, db: Database):
        self.db = db

    async def _fetch_one_user(self, query: str, params: dict) -> Optional[UserInDB]:
        """Execute *query* with *params* and map the first row to ``UserInDB``.

        Returns ``None`` when the query produces no row.
        """
        async with self.db.get_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, params)
                row = await cur.fetchone()
                return UserInDB(**row) if row else None

    async def create_user(self, user: UserCreate) -> Optional[UserInDB]:
        """Insert a new user as active and non-superuser.

        Args:
            user: Creation payload. Its plain-text ``password`` is hashed with
                ``get_password_hash`` before being stored; ``role`` is stored
                via its ``.value``.

        Returns:
            The created user, or ``None`` if the INSERT returned no row.

        Raises:
            Any exception from the database layer, after logging it.
        """
        hashed_password = get_password_hash(user.password)
        query = f"""
            INSERT INTO users (username, email, hashed_password, role, is_active, is_superuser)
            VALUES (%(username)s, %(email)s, %(hashed_password)s, %(role)s, TRUE, FALSE)
            RETURNING {self._USER_COLUMNS}
        """
        try:
            return await self._fetch_one_user(query, {
                'username': user.username,
                'email': user.email,
                'hashed_password': hashed_password,
                'role': user.role.value,
            })
        except Exception as e:
            logger.exception("Error creating user: %s", e)
            raise

    async def get_user_by_id(self, user_id: int) -> Optional[UserInDB]:
        """Return the user with primary key *user_id*, or ``None`` if absent."""
        query = f"SELECT {self._USER_COLUMNS} FROM users WHERE id = %(user_id)s"
        return await self._fetch_one_user(query, {'user_id': user_id})

    async def get_user_by_username(self, username: str) -> Optional[UserInDB]:
        """Return the user whose ``username`` equals *username*, or ``None``."""
        query = f"SELECT {self._USER_COLUMNS} FROM users WHERE username = %(username)s"
        return await self._fetch_one_user(query, {'username': username})

    async def get_user_by_email(self, email: str) -> Optional[UserInDB]:
        """Return the user whose ``email`` equals *email*, or ``None``."""
        query = f"SELECT {self._USER_COLUMNS} FROM users WHERE email = %(email)s"
        return await self._fetch_one_user(query, {'email': email})

    async def get_all_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
        """Return one page of users, newest first.

        Args:
            skip: Number of rows to skip (SQL ``OFFSET``).
            limit: Maximum number of rows to return (SQL ``LIMIT``).

        Returns:
            Users ordered by ``created_at`` descending; empty list if none.
        """
        query = f"""
            SELECT {self._USER_COLUMNS}
            FROM users
            ORDER BY created_at DESC
            LIMIT %(limit)s OFFSET %(skip)s
        """
        async with self.db.get_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, {'skip': skip, 'limit': limit})
                rows = await cur.fetchall()
                return [UserInDB(**row) for row in rows]

    async def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[UserInDB]:
        """Apply the non-``None`` fields of *user_update* to user *user_id*.

        Only ``email``, ``password`` (stored hashed), ``role`` (stored via
        ``.value``) and ``is_active`` are updatable; ``updated_at`` is bumped
        to ``CURRENT_TIMESTAMP`` whenever at least one field changes.

        Args:
            user_id: Primary key of the user to update.
            user_update: Partial update; ``None`` fields are left untouched.

        Returns:
            The updated user, or ``None`` if no row has that id. When no field
            is set, no UPDATE is issued and the current row is returned.

        Raises:
            Any exception from the database layer, after logging it.
        """
        update_fields = []
        params = {'user_id': user_id}

        if user_update.email is not None:
            update_fields.append("email = %(email)s")
            params['email'] = user_update.email

        if user_update.password is not None:
            update_fields.append("hashed_password = %(hashed_password)s")
            params['hashed_password'] = get_password_hash(user_update.password)

        if user_update.role is not None:
            update_fields.append("role = %(role)s")
            params['role'] = user_update.role.value

        if user_update.is_active is not None:
            update_fields.append("is_active = %(is_active)s")
            params['is_active'] = user_update.is_active

        if not update_fields:
            # Nothing to change: return the current state instead of issuing an UPDATE.
            return await self.get_user_by_id(user_id)

        # The SET clauses come only from the fixed literals above; every value is
        # a bound parameter, so interpolating them into the f-string is safe.
        query = f"""
            UPDATE users
            SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
            WHERE id = %(user_id)s
            RETURNING {self._USER_COLUMNS}
        """
        try:
            return await self._fetch_one_user(query, params)
        except Exception as e:
            logger.exception("Error updating user %s: %s", user_id, e)
            raise

    async def delete_user(self, user_id: int) -> bool:
        """Delete user *user_id*.

        Args:
            user_id: Primary key of the user to delete.

        Returns:
            ``True`` if a row was deleted; ``False`` if no row matched or the
            delete failed (errors are logged, not raised).
        """
        query = "DELETE FROM users WHERE id = %(user_id)s"
        try:
            async with self.db.get_connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, {'user_id': user_id})
                    return cur.rowcount > 0
        except Exception as e:
            # Failures are reported as False rather than raised, so log the
            # full traceback to keep them diagnosable.
            logger.exception("Error deleting user %s: %s", user_id, e)
            return False

    async def user_exists(
        self, username: Optional[str] = None, email: Optional[str] = None
    ) -> bool:
        """Check whether any user matches *username* OR *email*.

        Args:
            username: Username to look for; falsy values (``None``, ``""``) are ignored.
            email: Email to look for; falsy values are ignored.

        Returns:
            ``True`` if at least one user matches any supplied criterion;
            ``False`` if none match or no criterion was supplied.
        """
        conditions = []
        params = {}

        if username:
            conditions.append("username = %(username)s")
            params['username'] = username

        if email:
            conditions.append("email = %(email)s")
            params['email'] = email

        if not conditions:
            return False

        query = f"""
            SELECT EXISTS(
                SELECT 1 FROM users WHERE {' OR '.join(conditions)}
            )
        """
        async with self.db.get_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, params)
                result = await cur.fetchone()
                # PostgreSQL labels an unaliased EXISTS(...) column "exists".
                return result['exists'] if result else False