236 lines
7.8 KiB
Python
236 lines
7.8 KiB
Python
from typing import Optional, List
|
|
from datetime import datetime
|
|
from app.infra.db.postgresql.database import Database
|
|
from app.domain.schemas.user import UserCreate, UserUpdate, UserInDB
|
|
from app.domain.models.role import UserRole
|
|
from app.core.security import get_password_hash
|
|
import logging
|
|
|
|
# Per-module logger; the repository below reports database failures through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
class UserRepository:
    """Data-access layer for the ``users`` table.

    Every method opens its own connection via ``self.db.get_connection()`` and
    a cursor on it. Rows are consumed as mappings (``UserInDB(**row)`` and
    ``result['exists']``), so the cursor is assumed to use a dict-style row
    factory. NOTE(review): confirm against ``Database``.

    Whether writes are committed is decided by ``Database.get_connection()``;
    this class never commits explicitly.
    """

    # Columns needed to build a UserInDB. Shared by every SELECT and RETURNING
    # clause so the individual queries cannot drift out of sync.
    _USER_COLUMNS = (
        "id, username, email, hashed_password, role, is_active, is_superuser, "
        "created_at, updated_at"
    )

    def __init__(self, db: Database):
        """Store the database handle.

        Args:
            db: Project ``Database`` object; only its ``get_connection()``
                async context manager is used by this class.
        """
        self.db = db

    async def _fetch_user(self, query: str, params: dict) -> Optional[UserInDB]:
        """Run *query* with *params* and map the first row to a UserInDB.

        Returns None when the query yields no row. The row is converted while
        the connection context is still open, so a conversion error passes
        through the connection's exit handling like any other query error.
        """
        async with self.db.get_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, params)
                row = await cur.fetchone()
                return UserInDB(**row) if row else None

    async def create_user(self, user: UserCreate) -> Optional[UserInDB]:
        """Insert a new user row.

        The password is hashed with ``get_password_hash`` before storage, and
        the row is created with ``is_active = TRUE`` and ``is_superuser = FALSE``.

        Args:
            user: Creation payload. ``username``, ``email`` and ``password`` are
                read, and ``role`` is stored as ``role.value``.

        Returns:
            The inserted user, or None if the INSERT returned no row.

        Raises:
            Exception: any error raised while building the parameters or
                executing the INSERT is logged (with traceback) and re-raised.
        """
        hashed_password = get_password_hash(user.password)

        query = f"""
            INSERT INTO users (username, email, hashed_password, role, is_active, is_superuser)
            VALUES (%(username)s, %(email)s, %(hashed_password)s, %(role)s, TRUE, FALSE)
            RETURNING {self._USER_COLUMNS}
        """

        try:
            return await self._fetch_user(query, {
                'username': user.username,
                'email': user.email,
                'hashed_password': hashed_password,
                'role': user.role.value,
            })
        except Exception as e:
            logger.exception("Error creating user: %s", e)
            raise

    async def get_user_by_id(self, user_id: int) -> Optional[UserInDB]:
        """Return the user whose ``id`` equals *user_id*, or None if absent."""
        query = f"SELECT {self._USER_COLUMNS} FROM users WHERE id = %(user_id)s"
        return await self._fetch_user(query, {'user_id': user_id})

    async def get_user_by_username(self, username: str) -> Optional[UserInDB]:
        """Return the user whose ``username`` equals *username*, or None if absent."""
        query = f"SELECT {self._USER_COLUMNS} FROM users WHERE username = %(username)s"
        return await self._fetch_user(query, {'username': username})

    async def get_user_by_email(self, email: str) -> Optional[UserInDB]:
        """Return the user whose ``email`` equals *email*, or None if absent."""
        query = f"SELECT {self._USER_COLUMNS} FROM users WHERE email = %(email)s"
        return await self._fetch_user(query, {'email': email})

    async def get_all_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
        """Return one page of users ordered by ``created_at`` descending.

        Args:
            skip: Number of rows to skip (SQL ``OFFSET``).
            limit: Maximum number of rows to return (SQL ``LIMIT``).

        Returns:
            The users on the requested page; empty if there are none.
        """
        query = f"""
            SELECT {self._USER_COLUMNS}
            FROM users
            ORDER BY created_at DESC
            LIMIT %(limit)s OFFSET %(skip)s
        """

        async with self.db.get_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, {'skip': skip, 'limit': limit})
                rows = await cur.fetchall()
                return [UserInDB(**row) for row in rows]

    async def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[UserInDB]:
        """Apply the non-None fields of *user_update* to user *user_id*.

        Only ``email``, ``password`` (stored hashed), ``role`` (stored as
        ``role.value``) and ``is_active`` are considered. ``updated_at`` is set
        to ``CURRENT_TIMESTAMP`` whenever an UPDATE is issued.

        Args:
            user_id: Primary key of the user to update.
            user_update: Partial update; fields left as None are not touched.

        Returns:
            The updated user, or None if no row has that id. If no field is
            set, no UPDATE runs and the current row (via ``get_user_by_id``)
            is returned.

        Raises:
            Exception: any error raised while executing the UPDATE is logged
                (with traceback) and re-raised.
        """
        assignments = []
        params = {'user_id': user_id}

        if user_update.email is not None:
            assignments.append("email = %(email)s")
            params['email'] = user_update.email

        if user_update.password is not None:
            assignments.append("hashed_password = %(hashed_password)s")
            params['hashed_password'] = get_password_hash(user_update.password)

        if user_update.role is not None:
            assignments.append("role = %(role)s")
            params['role'] = user_update.role.value

        if user_update.is_active is not None:
            assignments.append("is_active = %(is_active)s")
            params['is_active'] = user_update.is_active

        if not assignments:
            # Nothing to change: skip the UPDATE and return the current row.
            return await self.get_user_by_id(user_id)

        # Only the fixed column fragments above are interpolated into the SQL;
        # every value is sent as a bound parameter.
        query = f"""
            UPDATE users
            SET {', '.join(assignments)}, updated_at = CURRENT_TIMESTAMP
            WHERE id = %(user_id)s
            RETURNING {self._USER_COLUMNS}
        """

        try:
            return await self._fetch_user(query, params)
        except Exception as e:
            logger.exception("Error updating user %s: %s", user_id, e)
            raise

    async def delete_user(self, user_id: int) -> bool:
        """Delete the user with the given id.

        Args:
            user_id: Primary key of the user to delete.

        Returns:
            True if a row was deleted. False if no row matched OR if the
            database raised. Errors are deliberately not propagated, so callers
            cannot tell the two cases apart; the traceback is logged instead.
        """
        query = "DELETE FROM users WHERE id = %(user_id)s"

        try:
            async with self.db.get_connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, {'user_id': user_id})
                    return cur.rowcount > 0
        except Exception as e:
            # Best-effort contract: report failure as False, but keep the
            # traceback in the log so the real cause is not lost.
            logger.exception("Error deleting user %s: %s", user_id, e)
            return False

    async def user_exists(self, username: Optional[str] = None, email: Optional[str] = None) -> bool:
        """Check whether any user matches *username* OR *email*.

        Empty or None arguments are ignored. If both are empty/None, False is
        returned without querying the database.

        Args:
            username: Username to look for.
            email: Email address to look for.

        Returns:
            True if at least one row matches any of the given criteria.
        """
        conditions = []
        params = {}

        if username:
            conditions.append("username = %(username)s")
            params['username'] = username

        if email:
            conditions.append("email = %(email)s")
            params['email'] = email

        if not conditions:
            return False

        # Explicit alias so the 'exists' key read below does not depend on the
        # server's default output-column naming.
        query = f"""
            SELECT EXISTS(
                SELECT 1 FROM users WHERE {' OR '.join(conditions)}
            ) AS exists
        """

        async with self.db.get_connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute(query, params)
                result = await cur.fetchone()
                return result['exists'] if result else False
|