brewman/brewman/brewman/core/security.py

153 lines
4.3 KiB
Python

import uuid
from datetime import datetime, timedelta
from typing import List, Optional, Tuple

from fastapi import Depends, HTTPException, Security, status
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTError
from jwt import PyJWTError
from pydantic import BaseModel, ValidationError
from sqlalchemy.orm import Session

from ..core.config import settings
from ..db.session import SessionLocal
from ..models.auth import Client
from ..models.auth import User as UserModel
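# NOTE(review): orphaned comment — the command it introduced is missing. It likely referred to generating the secret key (e.g. `openssl rand -hex 32`); SECRET_KEY is read from settings.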
from ..schemas.auth import UserToken
# Extracts the bearer token from incoming requests; tokens are issued at "/token".
# No OAuth2 scopes are declared on the scheme itself.
oauth2_scheme = OAuth2PasswordBearer(scopes={}, tokenUrl="/token")
class Token(BaseModel):
    """Pydantic model pairing an encoded access token with its token type.

    Presumably the body returned by the token endpoint — confirm against callers.
    """

    # The encoded JWT string.
    access_token: str
    # Token type label supplied by the caller (e.g. "bearer").
    token_type: str
class TokenData(BaseModel):
    """Claims pulled out of a decoded access token and validated by pydantic.

    ``get_current_user`` builds one of these from the ``sub`` and ``scopes``
    claims so that malformed claims surface as a ``ValidationError``.
    """

    # Subject (``sub`` claim). The default is ``None``, so the annotation must be
    # Optional; a bare ``str`` annotation misstates the field's type.
    username: Optional[str] = None
    # Scope strings carried by the token. Pydantic copies field defaults per
    # instance, so this ``[]`` literal is not shared between instances.
    scopes: List[str] = []
def f7(seq):
    """Return the items of ``seq`` as a list with duplicates removed.

    The first occurrence of each item is kept, and items stay in their original
    order. Items must be hashable; otherwise ``TypeError`` is raised.

    >>> f7([3, 1, 3, 2, 1])
    [3, 1, 2]
    """
    # dict keys are unique and keep insertion order. On a repeated key the
    # first-inserted key object is retained. This replaces the original
    # set-plus-side-effecting-comprehension trick.
    return list(dict.fromkeys(seq))
# Dependency
def get_db():
    """Yield a database session from ``SessionLocal`` and close it afterwards.

    Written as a generator dependency: the session is handed out by ``yield``,
    and ``close()`` runs once the generator is resumed or closed.
    """
    # Create the session *before* entering ``try``. If ``SessionLocal()`` itself
    # raises, there is nothing to close. Doing it inside ``try`` would make
    # ``finally`` hit an unbound ``db``, and that NameError would hide the real error.
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
def create_access_token(
    *, data: dict, expires_delta: Optional[timedelta] = None
) -> str:
    """Encode ``data`` as a signed JWT with an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed. A shallow copy is made, so the caller's dict is
            not modified.
        expires_delta: Token lifetime measured from now. ``None`` (the default)
            means 15 minutes. An explicit ``timedelta(0)`` is honoured instead
            of being treated as "not provided".

    Returns:
        The token produced by ``jwt.encode``, signed with ``settings.SECRET_KEY``
        using ``settings.ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and a truthiness
    # test would silently replace it with the 15-minute default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Naive UTC datetime; the JWT library converts it to a UTC timestamp.
    to_encode = {**data, "exp": datetime.utcnow() + expires_delta}
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def get_user(username: str, id_: str, locked_out: bool, scopes: List[str]) -> UserToken:
    """Build a ``UserToken`` purely from token claims (no database access).

    ``id_`` is parsed with ``uuid.UUID``, so a ``None`` value raises
    ``TypeError`` and a malformed string raises ``ValueError``. The password
    field is always left empty, and ``scopes`` becomes the permission list.
    """
    user_id = uuid.UUID(id_)
    token_fields = dict(
        id_=user_id,
        name=username,
        locked_out=locked_out,
        password="",
        permissions=scopes,
    )
    return UserToken(**token_fields)
def authenticate_user(
    username: str, password: str, client_id: Optional[int], otp: int, db: Session
) -> Optional[UserModel]:
    """Check a username/password pair via ``UserModel.auth``.

    Returns whatever ``UserModel.auth`` returns (per the annotation, a user or
    ``None``). ``client_id`` and ``otp`` are accepted but not used by this
    function.
    """
    return UserModel.auth(username, password, db)
def client_allowed(
    user: UserModel, client_id: int, otp: Optional[int] = None, db: Session = None
) -> Tuple[bool, int]:
    """Decide whether ``user`` may log in from the client identified by ``client_id``.

    Returns:
        ``(allowed, client_code)``:

        * user holds the "clients" permission -> ``(True, 0)`` (no client lookup);
        * ``client_id`` is falsy or unknown -> a new client is created with
          ``Client.create(db)`` and ``(False, new_code)`` is returned;
        * the client is enabled -> ``(True, code)``;
        * the client is disabled and a non-None ``otp`` equals its stored otp ->
          the client is enabled, its otp cleared, and ``(True, code)`` returned;
        * otherwise -> ``(False, code)``.

    ``db`` is required whenever the permission check fails. Changes made to the
    client object are not committed by this function.
    """
    # Permission names are normalised ("Some Name" -> "some-name") before matching.
    permission_names = {
        p.name.replace(" ", "-").lower() for r in user.roles for p in r.permissions
    }
    if "clients" in permission_names:
        # Privileged users skip the client check, so don't query for it.
        return True, 0

    client = (
        db.query(Client).filter(Client.code == client_id).first() if client_id else None
    )
    if client is None:
        client = Client.create(db)
        return False, client.code
    if client.enabled:
        return True, client.code
    # Require an OTP to actually be supplied. Without this guard, a disabled
    # client whose stored otp is None would be enabled by a request carrying
    # no otp, because None == None is True.
    if otp is not None and client.otp == otp:
        client.otp = None
        client.enabled = True
        return True, client.code
    return False, client.code
async def get_current_user(
    security_scopes: SecurityScopes,
    token: str = Depends(oauth2_scheme),
) -> UserToken:
    """Decode the bearer token into a ``UserToken`` and enforce required scopes.

    Raises:
        HTTPException: 401 "Could not validate credentials" if the token cannot
            be decoded or verified (including expiry), has no ``sub`` claim, has
            claims that fail ``TokenData`` validation, or has a missing or
            non-UUID ``userId``.
        HTTPException: 401 "Not enough permissions" if any scope in
            ``security_scopes`` is absent from the token's ``scopes`` claim.
    """
    if security_scopes.scopes:
        authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
    else:
        authenticate_value = "Bearer"
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": authenticate_value},
    )
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        username: Optional[str] = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_scopes = payload.get("scopes", [])
        token_data = TokenData(scopes=token_scopes, username=username)
    # ``jwt`` here is python-jose. It raises ``JWTError`` (the base class of its
    # ``ExpiredSignatureError``) for bad signatures and malformed tokens. Catching
    # only PyJWT's ``PyJWTError`` let those escape as a 500 error.
    except (JWTError, PyJWTError, ValidationError):
        raise credentials_exception
    try:
        user = get_user(
            username=token_data.username,
            id_=payload.get("userId", None),
            # A missing claim defaults to locked out (fail closed).
            locked_out=payload.get("lockedOut", True),
            scopes=token_scopes,
        )
    except (TypeError, ValueError):
        # ``userId`` absent (None -> TypeError) or not a valid UUID (ValueError).
        # Treat it as a bad token rather than a server error.
        raise credentials_exception
    if user is None:  # defensive; get_user is not expected to return None
        raise credentials_exception
    for scope in security_scopes.scopes:
        if scope not in token_data.scopes:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Not enough permissions",
                headers={"WWW-Authenticate": authenticate_value},
            )
    return user
async def get_current_active_user(
    current_user: UserToken = Security(get_current_user, scopes=["authenticated"])
) -> UserToken:
    """Return the token's user, provided the account is not locked out.

    The user is resolved by ``get_current_user`` with the "authenticated"
    scope required.

    Raises:
        HTTPException: 400 "Inactive user" when ``current_user.locked_out`` is truthy.
    """
    if not current_user.locked_out:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")