44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
|
|
import logging
|
|||
|
|
|
|||
|
|
from django.core.cache import cache
|
|||
|
|
from rest_framework.exceptions import APIException
|
|||
|
|
from rest_framework_simplejwt.authentication import JWTAuthentication
|
|||
|
|
from rest_framework_simplejwt.exceptions import InvalidToken
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(name=__name__)

# Error payloads used as exception details by the authentication class below.
# NOTE(review): these presumably end up in client-facing responses; treat the
# keys and messages as part of the API contract.

# jti found in the Redis blacklist ("token revoked, please log in again").
_REVOKED = dict(code='AUTH_TOKEN_REVOKED', message='token已被吊销,请重新登录')

# Token issued before the user's invalidation cut-off
# ("token invalidated, please log in again").
_INVALIDATED = dict(code='AUTH_TOKEN_INVALIDATED', message='token已失效,请重新登录')

# The cache/Redis lookup failed ("Redis service unavailable").
_REDIS_DOWN = dict(code='SYS_DEPENDENCY_DOWN', message='Redis服务不可用')
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RedisBlacklistJWTAuthentication(JWTAuthentication):
    """JWT authentication that adds cache-backed revocation checks.

    After simplejwt has validated the token (``super().get_validated_token``),
    two extra checks run against the Django cache:

    * ``jwt_blacklist:{jti}``: if this key holds a truthy value, the token
      is rejected as revoked.
    * ``jwt_user_invalid_before:{user_id}``: if set, any token whose ``iat``
      is earlier than this value is rejected as invalidated.

    The checks fail closed. If anything other than ``InvalidToken`` is raised
    while performing them (e.g. the cache backend is unreachable), the request
    is rejected with HTTP 503 rather than allowed through.
    """

    def get_validated_token(self, raw_token):
        """Validate *raw_token* and reject revoked or invalidated tokens.

        Args:
            raw_token: The raw token, as passed in by ``JWTAuthentication``.

        Returns:
            The validated token object returned by simplejwt.

        Raises:
            InvalidToken: If simplejwt rejects the token, if its ``jti`` is
                blacklisted (code ``AUTH_TOKEN_REVOKED``), or if it was issued
                before the user's invalidation cut-off
                (code ``AUTH_TOKEN_INVALIDATED``).
            APIException: With ``status_code`` 503 and code
                ``SYS_DEPENDENCY_DOWN`` if any other error occurs while the
                revocation checks run.
        """
        token = super().get_validated_token(raw_token)

        try:
            jti = token.payload.get('jti')
            # NOTE(review): hard-codes the 'user_id' claim. If simplejwt's
            # USER_ID_CLAIM setting is customised, this check is silently
            # skipped. Confirm it matches the key writer.
            uid = token.payload.get('user_id')
            iat = token.payload.get('iat')

            if jti and cache.get(f'jwt_blacklist:{jti}'):
                raise InvalidToken(_REVOKED)

            if uid and iat is not None:
                invalid_before = cache.get(f'jwt_user_invalid_before:{uid}')
                if invalid_before is not None and iat < int(invalid_before):
                    raise InvalidToken(_INVALIDATED)

        except InvalidToken:
            # Our own rejections must propagate unchanged, not become a 503.
            raise
        except Exception as e:
            # Fail closed: if the dependency is unavailable, reject the request
            # (security over availability).
            # NOTE(review): this also catches non-cache errors from the block
            # above, e.g. ValueError from int() on a malformed cached value,
            # which are then reported as "Redis down".
            logger.exception('Redis error during JWT blacklist check: %s', e)
            # Pass the payload through the constructor so DRF builds a fresh
            # detail structure. Assigning _REDIS_DOWN to ex.detail directly
            # would expose the shared module-level dict to the response. Any
            # handler that mutates response.data would then corrupt it for
            # every later request.
            ex = APIException(_REDIS_DOWN)
            ex.status_code = 503
            raise ex from e

        return token
|