44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
import logging
|
||
|
||
from django.core.cache import cache
|
||
from rest_framework.exceptions import APIException
|
||
from rest_framework_simplejwt.authentication import JWTAuthentication
|
||
from rest_framework_simplejwt.exceptions import InvalidToken
|
||
|
||
# Module-level logger for this authentication backend.
logger = logging.getLogger(__name__)

# Error payloads: a machine-readable `code` plus a user-facing `message`.
# Revoked token ("token has been revoked, please log in again").
_REVOKED = dict(code='AUTH_TOKEN_REVOKED', message='token已被吊销,请重新登录')
# Token invalidated ("token is no longer valid, please log in again").
_INVALIDATED = dict(code='AUTH_TOKEN_INVALIDATED', message='token已失效,请重新登录')
# Backing cache unreachable ("Redis service unavailable").
_REDIS_DOWN = dict(code='SYS_DEPENDENCY_DOWN', message='Redis服务不可用')
|
||
|
||
|
||
class _DependencyUnavailable(APIException):
    """HTTP 503 raised when a service the auth check depends on cannot be reached."""

    status_code = 503
    default_code = 'service_unavailable'


class RedisBlacklistJWTAuthentication(JWTAuthentication):
    """JWT authentication that adds a Redis blacklist check after simplejwt validation.

    Once the parent class has validated the token, up to two cache lookups run:

    * ``jwt_blacklist:<jti>``: if truthy, this specific token is revoked.
    * ``jwt_user_invalid_before:<user_id>``: a per-user cutoff. Tokens whose
      ``iat`` claim is strictly earlier than the cached value are rejected.

    The check is fail-closed. If a lookup raises, the request is rejected with
    HTTP 503 instead of being allowed through.
    """

    def get_validated_token(self, raw_token):
        """Validate *raw_token* with simplejwt, then reject revoked or invalidated tokens.

        Args:
            raw_token: The raw JWT taken from the request. It is passed unchanged
                to ``JWTAuthentication.get_validated_token``.

        Returns:
            The validated token object returned by the parent class.

        Raises:
            InvalidToken: If the parent rejects the token, if its ``jti`` is
                blacklisted (``AUTH_TOKEN_REVOKED``), or if it was issued before
                the user's invalidation cutoff (``AUTH_TOKEN_INVALIDATED``).
            APIException: With status 503 (``SYS_DEPENDENCY_DOWN``) if the
                blacklist lookup fails for any other reason.
        """
        token = super().get_validated_token(raw_token)
        payload = token.payload
        jti = payload.get('jti')
        uid = payload.get('user_id')
        iat = payload.get('iat')

        try:
            if jti and cache.get(f'jwt_blacklist:{jti}'):
                raise InvalidToken(_REVOKED)

            # Use `is not None` rather than truthiness, so a falsy but valid
            # user id (e.g. 0) is still subject to the invalidation cutoff.
            if uid is not None and iat is not None:
                invalid_before = cache.get(f'jwt_user_invalid_before:{uid}')
                # The cached value is coerced with int(). This assumes it is
                # stored as an integer timestamp comparable to `iat`; confirm
                # against the code that writes this key.
                if invalid_before is not None and iat < int(invalid_before):
                    raise InvalidToken(_INVALIDATED)
        except InvalidToken:
            # Token rejections propagate unchanged. They must not be reported
            # as a Redis outage by the handler below.
            raise
        except Exception as e:
            # Fail closed: if Redis is unavailable, reject the request (security
            # over availability). logger.exception keeps the traceback, which
            # also distinguishes a real outage from e.g. a malformed cached value.
            logger.exception('Redis error during JWT blacklist check: %s', e)
            # Pass the detail through the constructor, as InvalidToken does.
            # DRF then builds a fresh detail dict, so the shared module-level
            # _REDIS_DOWN dict is never handed to (possibly mutating) handlers.
            raise _DependencyUnavailable(_REDIS_DOWN) from e

        return token
|