Use sorted set in redis (for JWT tokens)
This commit is contained in:
parent
cd27603712
commit
f640157f91
bsition
|
@ -2,10 +2,12 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from bsition.api.models.user import User
|
||||
from bsition.api.utils.jwt import write_token
|
||||
from bsition.api.utils.password import verify_password
|
||||
from bsition.api.utils.security import get_current_user, oauth2_scheme
|
||||
from bsition.backend.postgres.users import get_user_by_username
|
||||
from bsition.backend.redis.tokens import add_token
|
||||
from bsition.backend.redis.tokens import add_token, remove_token, clean_tokens
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -19,8 +21,9 @@ def login(form: OAuth2PasswordRequestForm = Depends()):
|
|||
detail="User not found.",
|
||||
)
|
||||
|
||||
token = write_token({"sub": form.username})
|
||||
add_token(token, form.username)
|
||||
token, expire = write_token({"sub": form.username})
|
||||
add_token(token, form.username, expire)
|
||||
clean_tokens(form.username)
|
||||
return JSONResponse(
|
||||
content={
|
||||
"access_token": token,
|
||||
|
@ -28,3 +31,10 @@ def login(form: OAuth2PasswordRequestForm = Depends()):
|
|||
},
|
||||
status_code=202,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/token")
|
||||
def logout(token: str = Depends(oauth2_scheme), user: User = Depends(get_current_user)):
|
||||
remove_token(user[1], token)
|
||||
return JSONResponse(content={"detail": "Token deleted."}, status_code=201)
|
||||
|
||||
|
|
|
@ -8,13 +8,11 @@ def expire_date(days: int):
|
|||
return datetime.now() + timedelta(days)
|
||||
|
||||
|
||||
# TODO: migrar a librería 'jose'
|
||||
|
||||
|
||||
def write_token(data: dict):
|
||||
expire = expire_date(1)
|
||||
return encode(
|
||||
payload={**data, "exp": expire_date(1)}, key=getenv("SECRET"), algorithm="HS256"
|
||||
)
|
||||
payload={**data, "exp": expire}, key=getenv("SECRET"), algorithm="HS256"
|
||||
), int(expire.timestamp())
|
||||
|
||||
|
||||
def validate_token(token):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from passlib.context import CryptContext
|
||||
from passlib.exc import UnknownHashError
|
||||
|
||||
password_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
@ -8,4 +9,7 @@ def get_hashed_password(password: str):
|
|||
|
||||
|
||||
def verify_password(password: str, hashed_pass: str):
|
||||
return password_context.verify(password, hashed_pass)
|
||||
try:
|
||||
return password_context.verify(password, hashed_pass)
|
||||
except UnknownHashError:
|
||||
return False
|
||||
|
|
|
@ -1,21 +1,14 @@
|
|||
from redis.exceptions import ResponseError
|
||||
import time
|
||||
from math import inf
|
||||
|
||||
from bsition.backend.redis.utils import get_client
|
||||
|
||||
max_tokens = 10
|
||||
|
||||
|
||||
def add_token(token, username):
|
||||
def add_token(token, username, expire):
|
||||
client = get_client()
|
||||
try:
|
||||
client.bf().reserve(username, 0.01, max_tokens, noScale=True)
|
||||
except ResponseError:
|
||||
pass
|
||||
|
||||
if client.bf().info(username).insertedNum == max_tokens:
|
||||
remove_tokens(username)
|
||||
|
||||
client.bf().add(username, token)
|
||||
client.zadd(username, {token: expire})
|
||||
|
||||
|
||||
def remove_tokens(username):
|
||||
|
@ -23,6 +16,21 @@ def remove_tokens(username):
|
|||
client.unlink(username)
|
||||
|
||||
|
||||
def remove_token(username, token):
|
||||
client = get_client()
|
||||
tokens = client.zrangebyscore(username, -inf, inf)
|
||||
for aux in tokens:
|
||||
if aux.decode("utf-8") == token:
|
||||
client.zrem(username, token)
|
||||
|
||||
|
||||
# Puede correr en un cron o, por ejemplo, cada vez que el usuario hace login (o logout)
|
||||
|
||||
def clean_tokens(username):
|
||||
client = get_client()
|
||||
client.zremrangebyscore(username, -inf, int(time.time()))
|
||||
|
||||
|
||||
def valid_token(token, username):
|
||||
client = get_client()
|
||||
return client.bf().exists(username, token) == 1
|
||||
return client.zscore(username, token) is not None
|
||||
|
|
Loading…
Reference in New Issue