Source code for aiida_restapi.routers.auth
"""Handle API authentication and authorization."""
from __future__ import annotations

import typing as t
from datetime import datetime, timedelta, timezone

import bcrypt
from aiida import orm
from argon2 import PasswordHasher
from argon2.exceptions import VerificationError, VerifyMismatchError
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from pydantic import BaseModel

from aiida_restapi import config
[docs]
class Token(BaseModel):
    # Response model of the token (login) endpoint, login_for_access_token.
    # Kept docstring-free on purpose: pydantic would publish a class docstring
    # as the schema description in the OpenAPI document.
    access_token: str  # encoded JWT produced by create_access_token
    token_type: str  # login_for_access_token always returns 'bearer'
[docs]
class TokenData(BaseModel):
    # Validated claim data extracted from a decoded JWT in get_current_user.
    email: str  # taken from the token's 'sub' claim; used as the user-db key
[docs]
class UserInDB(orm.User.ReadModel):
    # A user record as built by get_user from the user database
    # (config.fake_users_db): the fields of orm.User.ReadModel plus
    # authentication-only data.
    # NOTE(review): carries the password hash; endpoints should keep declaring
    # orm.User.ReadModel as their response model so it is never serialized.
    hashed_password: str  # argon2 or bcrypt hash, checked by verify_password
    disabled: t.Optional[bool] = None  # truthy => get_current_active_user rejects the user
# Argon2 hasher: creates new hashes in get_password_hash and verifies
# '$argon2...' hashes in verify_password.
pwd_context = PasswordHasher()
# Security dependency that supplies the bearer token to get_current_user;
# tokenUrl names the login endpoint (API prefix + '/auth' router prefix +
# the '/token' route registered on write_router).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f'{config.API_CONFIG["PREFIX"]}/auth/token')
# Both routers share the '/auth' prefix: '/token' (login) is registered on
# write_router and '/me/' on read_router.
# NOTE(review): presumably split so write endpoints can be mounted separately
# from read-only ones — confirm against the application setup.
read_router = APIRouter(prefix='/auth')
write_router = APIRouter(prefix='/auth')
[docs]
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against a stored argon2 or bcrypt hash.

    The scheme is chosen from the hash prefix.

    - ``$argon2...`` hashes are verified with the module-level ``pwd_context``.
    - bcrypt hashes (``$2a$``, ``$2b$`` and ``$2y$``, the last e.g. as written
      by ``htpasswd -B``) are verified with ``bcrypt.checkpw``.

    :param plain_password: the password submitted by the client.
    :param hashed_password: the stored password hash.
    :return: ``True`` only if the password matches. ``False`` for a mismatch,
        a malformed stored hash, or an unrecognized hash scheme.
    """
    if hashed_password.startswith('$argon2'):
        try:
            return pwd_context.verify(hashed_password, plain_password)
        except VerifyMismatchError:
            return False
        except (VerificationError, ValueError):
            # Corrupt or unparseable stored hash. argon2 signals this with
            # VerificationError or InvalidHash(Error), a ValueError subclass.
            # Deny the login instead of surfacing an HTTP 500.
            return False
    if hashed_password.startswith(('$2a$', '$2b$', '$2y$')):
        try:
            return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
        except ValueError:
            # bcrypt rejects malformed hashes (e.g. "Invalid salt") with ValueError.
            return False
    # Unknown or unsupported hash scheme: never authenticate.
    return False
[docs]
def get_password_hash(password: str) -> str:
    """Hash ``password`` with the module's argon2 ``pwd_context``.

    :param password: the plain-text password to hash.
    :return: the encoded hash string produced by ``pwd_context.hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
[docs]
def get_user(db: dict, email: str) -> UserInDB | None:
    """Look up ``email`` in ``db`` and build a ``UserInDB`` from the stored record.

    :param db: mapping from e-mail address to a mapping of ``UserInDB`` fields.
    :param email: the key to look up.
    :return: the validated ``UserInDB``, or ``None`` when ``email`` is not a key of ``db``.
    """
    if email not in db:
        return None
    record = db[email]
    return UserInDB(**record)
[docs]
def authenticate_user(fake_db: dict, email: str, password: str) -> UserInDB | None:
    """Return the user for ``email`` if ``password`` matches their stored hash.

    :param fake_db: user database passed through to ``get_user``.
    :param email: the login name (e-mail address) to look up.
    :param password: the plain-text password to check.
    :return: the matching ``UserInDB``. ``None`` if the user is unknown or the
        password does not verify.
    """
    candidate = get_user(fake_db, email)
    # Short-circuit: the password is only checked when a user record exists.
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return None
[docs]
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Encode ``data`` as a signed JWT carrying an ``exp`` (expiry) claim.

    :param data: claims to embed (e.g. ``{'sub': email}``). They are copied
        into a new dict, so the caller's mapping is not mutated.
    :param expires_delta: token lifetime measured from now (UTC). It defaults
        to 15 minutes only when ``None``. An explicit ``timedelta(0)`` is
        honoured; the previous truthiness test treated it as "not given",
        because a zero ``timedelta`` is falsy.
    :return: the encoded JWT, signed with ``config.SECRET_KEY`` using ``config.ALGORITHM``.
    """
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    claims = {**data, 'exp': datetime.now(timezone.utc) + lifetime}
    return jwt.encode(claims, config.SECRET_KEY, algorithm=config.ALGORITHM)
[docs]
async def get_current_user(token: t.Annotated[str, Depends(oauth2_scheme)]) -> UserInDB:
    """Resolve the request's bearer token to the stored user record.

    The token is decoded with ``config.SECRET_KEY`` and ``config.ALGORITHM``.
    Its ``sub`` claim is used as the e-mail key into ``config.fake_users_db``.

    :param token: the raw bearer token supplied by ``oauth2_scheme``.
    :return: the matching ``UserInDB`` (a subclass of ``orm.User.ReadModel``).
    :raises HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
        the token cannot be decoded, lacks a string ``sub`` claim, or names an
        unknown user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail='Could not validate credentials',
        headers={'WWW-Authenticate': 'Bearer'},
    )
    try:
        payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
    except JWTError as exc:
        raise credentials_exception from exc
    email = payload.get('sub')
    # A missing or non-string 'sub' would otherwise make TokenData raise a
    # pydantic ValidationError (an HTTP 500); treat it as bad credentials.
    if not isinstance(email, str):
        raise credentials_exception
    token_data = TokenData(email=email)
    user = get_user(config.fake_users_db, email=token_data.email)
    if user is None:
        raise credentials_exception
    return user
[docs]
async def get_current_active_user(
    current_user: t.Annotated[UserInDB, Depends(get_current_user)],
) -> UserInDB:
    """Pass ``current_user`` through unless its ``disabled`` flag is set.

    :param current_user: the user resolved by ``get_current_user``.
    :return: ``current_user`` unchanged when ``disabled`` is falsy (``False`` or ``None``).
    :raises HTTPException: 400 ``'Inactive user'`` when ``current_user.disabled`` is truthy.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Inactive user')
[docs]
@write_router.post('/token', response_model=Token)
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
) -> dict[str, t.Any]:
    """Login to get access token."""
    # The docstring above is runtime text: it becomes the endpoint description
    # in the OpenAPI schema. The OAuth2 form's 'username' field carries the
    # e-mail address used as the user-db key.
    user = authenticate_user(config.fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail='Incorrect email or password',
            headers={'WWW-Authenticate': 'Bearer'},
        )
    token = create_access_token(
        data={'sub': user.email},
        expires_delta=timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {'access_token': token, 'token_type': 'bearer'}
[docs]
@read_router.get(
    '/me/',
    response_model=orm.User.ReadModel,
)
async def read_users_me(
    current_user: t.Annotated[orm.User.ReadModel, Depends(get_current_active_user)],
) -> orm.User.ReadModel:
    """Get the current authenticated user."""
    # NOTE: the docstring above doubles as the endpoint description in the
    # OpenAPI schema, so it is left as-is. The value returned here is the
    # UserInDB produced by get_current_user, and
    # response_model=orm.User.ReadModel is what keeps hashed_password and
    # disabled out of the JSON response (FastAPI response-model filtering).
    # Confirm before changing the response model.
    return current_user