187 lines
6.0 KiB
Python
187 lines
6.0 KiB
Python
from typing import Annotated
|
||
from datetime import timedelta
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from fastapi.security import OAuth2PasswordRequestForm
|
||
from app.core.config import settings
|
||
from app.core.security import create_access_token, create_refresh_token, verify_password
|
||
from app.domain.schemas.user import UserCreate, UserResponse, UserLogin, Token
|
||
from app.infra.repositories.user_repository import UserRepository
|
||
from app.auth.dependencies import get_user_repository, get_current_active_user
|
||
from app.domain.schemas.user import UserInDB
|
||
import logging
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
# Router object that the endpoint decorators in this module (@router.post / @router.get) register on.
router = APIRouter()
|
||
|
||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
    user_data: UserCreate,
    user_repo: UserRepository = Depends(get_user_repository)
) -> UserResponse:
    """
    Register a new user account.

    The username and the email are each checked against the repository
    before the user is created.

    Args:
        user_data: Registration payload; its ``username`` and ``email``
            are used for the duplicate checks.
        user_repo: Repository used for the existence checks and the insert.

    Returns:
        The created user, validated into ``UserResponse``.

    Raises:
        HTTPException: 400 when the username or the email is already
            registered; 500 with "Failed to create user" when the
            repository returns a falsy result; 500 with "Registration
            failed" when any other error occurs during creation.
    """
    # Reject duplicates up front: username first, then email.
    if await user_repo.user_exists(username=user_data.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )

    if await user_repo.user_exists(email=user_data.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    try:
        user = await user_repo.create_user(user_data)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to create user"
            )
        return UserResponse.model_validate(user)
    except HTTPException:
        # Deliberate HTTP errors (such as the one raised just above) must
        # reach the client unchanged instead of being rewritten by the
        # generic handler below.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error during user registration: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Registration failed"
        ) from e
|
||
|
||
@router.post("/login", response_model=Token)
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    user_repo: UserRepository = Depends(get_user_repository)
) -> Token:
    """
    Log a user in (OAuth2 password-form format).

    The submitted ``username`` field is tried as a username first and,
    if nothing is found, as an email address.

    Returns:
        A ``Token`` carrying a new access token, a new refresh token, the
        "bearer" token type and the access-token lifetime
        (``ACCESS_TOKEN_EXPIRE_MINUTES * 60``).

    Raises:
        HTTPException: 401 when no user matches or the password check
            fails; 403 when the matched user is not active.
    """
    identifier = form_data.username

    # Username lookup first; the email lookup only runs when it finds nothing.
    account = (
        await user_repo.get_user_by_username(identifier)
        or await user_repo.get_user_by_email(identifier)
    )

    # One shared 401 for "unknown user" and "wrong password".
    if not (account and verify_password(form_data.password, account.hashed_password)):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not account.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Inactive user account"
        )

    # Both tokens use the stored username as their subject.
    return Token(
        access_token=create_access_token(subject=account.username),
        refresh_token=create_refresh_token(subject=account.username),
        token_type="bearer",
        expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
    )
|
||
|
||
@router.post("/login/simple", response_model=Token)
async def login_simple(
    username: str,
    password: str,
    user_repo: UserRepository = Depends(get_user_repository)
) -> Token:
    """
    Simplified login endpoint, kept for backward compatibility.

    Takes ``username`` and ``password`` directly as parameters instead of
    an OAuth2 form. ``username`` is tried as a username first and, if
    nothing is found, as an email address.

    NOTE(review): as declared, FastAPI reads these scalar parameters from
    the query string, so credentials can end up in URLs and access logs.
    Prefer the form-based ``/login`` endpoint for new clients.

    Returns:
        A ``Token`` carrying a new access token, a new refresh token, the
        "bearer" token type and the access-token lifetime
        (``ACCESS_TOKEN_EXPIRE_MINUTES * 60``).

    Raises:
        HTTPException: 401 when no user matches or the password check
            fails; 403 when the matched user is not active.
    """
    # Username lookup first, then fall back to email.
    user = await user_repo.get_user_by_username(username)
    if not user:
        user = await user_repo.get_user_by_email(username)

    if not user or not verify_password(password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            # A 401 response must carry a challenge; this matches /login.
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Inactive user account"
        )

    # Both tokens use the stored username as their subject.
    access_token = create_access_token(subject=user.username)
    refresh_token = create_refresh_token(subject=user.username)

    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
        expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
    )
|
||
|
||
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
    current_user: UserInDB = Depends(get_current_active_user)
) -> UserResponse:
    """
    Return the profile of the currently logged-in user.

    The user is supplied by the ``get_current_active_user`` dependency and
    validated into ``UserResponse`` before being returned.
    """
    profile = UserResponse.model_validate(current_user)
    return profile
|
||
|
||
@router.post("/refresh", response_model=Token)
async def refresh_token(
    refresh_token: str,
    user_repo: UserRepository = Depends(get_user_repository)
) -> Token:
    """
    Exchange a refresh token for a new access token.

    The token is decoded with the configured secret key and algorithm. It
    must carry a ``sub`` claim and a ``type`` claim equal to "refresh",
    and its subject must still resolve to an active user.

    Returns:
        A ``Token`` with a new access token, the same refresh token that
        was submitted, the "bearer" token type and the access-token
        lifetime (``ACCESS_TOKEN_EXPIRE_MINUTES * 60``).

    Raises:
        HTTPException: 401 when decoding fails, a required claim is
            missing or wrong, or the user is missing or inactive.
    """
    from jose import JWTError, jwt

    # Every rejection path raises this same 401.
    invalid_token = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate refresh token",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        raise invalid_token

    subject = claims.get("sub")
    if subject is None or claims.get("type") != "refresh":
        raise invalid_token

    # The user must still exist and still be active.
    account = await user_repo.get_user_by_username(subject)
    if not (account and account.is_active):
        raise invalid_token

    # Only the access token is re-issued; the caller's refresh token is echoed back.
    return Token(
        access_token=create_access_token(subject=account.username),
        refresh_token=refresh_token,
        token_type="bearer",
        expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
    )
|