55 lines
2.2 KiB
Python
55 lines
2.2 KiB
Python
from rest_framework.permissions import AllowAny
|
||
from rest_framework_simplejwt.views import TokenRefreshView
|
||
from rest_framework_simplejwt.tokens import RefreshToken
|
||
from rest_framework_simplejwt.exceptions import TokenError
|
||
from drf_spectacular.utils import extend_schema
|
||
|
||
from config.exceptions import AppError
|
||
from apps.user.utils.jwt_redis import revoke_token, is_token_revoked, get_user_invalid_before
|
||
|
||
|
||
@extend_schema(tags=['认证'])
class CustomTokenRefreshView(TokenRefreshView):
    """U8 刷新 Token — 在 simplejwt 旋转前后加入 Redis 黑名单检查 + 旧 token 吊销"""
    # NOTE: the docstring above is kept verbatim on purpose. drf-spectacular can
    # publish view docstrings as the operation description, so it is treated as
    # runtime text. In English: "U8 refresh token - add a Redis blacklist check
    # and revocation of the old token around simplejwt's rotation."

    # The endpoint is open to unauthenticated callers and runs no authenticators;
    # the refresh token in the request body is the only credential inspected.
    permission_classes = [AllowAny]
    authentication_classes = ()

    def post(self, request, *args, **kwargs):
        # Flow: validate the submitted refresh token, apply the token-level and
        # user-level revocation checks, delegate to the parent view, then
        # revoke the token that was just used. Every rejection is raised as
        # AppError('AUTH_TOKEN_INVALID', ..., status_code=401).
        raw_refresh = request.data.get('refresh', '')
        if not raw_refresh:
            raise AppError('AUTH_TOKEN_INVALID', '请提供 refresh token', status_code=401)

        # Decode the submitted token up front; the original author notes this
        # must happen before super().post() because simplejwt mutates state
        # while handling the request.
        try:
            submitted = RefreshToken(raw_refresh)
        except TokenError:
            raise AppError('AUTH_TOKEN_INVALID', 'refresh token 无效或已过期', status_code=401)

        claims = submitted.payload
        jti = claims.get('jti')
        expires_at = claims.get('exp')
        user_id = claims.get('user_id')
        issued_at = claims.get('iat')

        # Token-level check: reject a jti that has already been revoked.
        if jti and is_token_revoked(jti):
            raise AppError('AUTH_TOKEN_INVALID', 'refresh token 已被吊销', status_code=401)

        # User-level check: reject tokens issued before the user's
        # "invalid before" cutoff, when such a cutoff is set.
        if user_id and issued_at is not None:
            cutoff = get_user_invalid_before(user_id)
            if cutoff is not None and issued_at < cutoff:
                raise AppError('AUTH_TOKEN_INVALID', 'token 已失效,请重新登录', status_code=401)

        # Let simplejwt perform the actual refresh / rotation.
        rotated_response = super().post(request, *args, **kwargs)

        # The parent call returned without raising, so revoke the token that
        # was just used; it is skipped when the jti or exp claim is missing.
        if jti and expires_at:
            revoke_token(jti, expires_at)

        return rotated_response
|
||
|
||
|
||
# Function-style reference so urls.py can keep a consistent style
|
||
refresh_token = CustomTokenRefreshView.as_view()  # view callable built once at import time
|