diff --git a/brewman/core/security.py b/brewman/core/security.py index a6a2565d..a276a072 100644 --- a/brewman/core/security.py +++ b/brewman/core/security.py @@ -8,6 +8,7 @@ from fastapi.security import OAuth2PasswordBearer, SecurityScopes from pydantic import BaseModel, ValidationError from sqlalchemy.orm import Session from jose import jwt +from jose.exceptions import ExpiredSignatureError from brewman.models.auth import User as UserModel from ..db.session import SessionLocal @@ -71,7 +72,9 @@ def get_user(username: str, id_: str, locked_out: bool, scopes: List[str]) -> Us ) -def authenticate_user(username: str, password: str, db: Session) -> Union[UserModel, bool]: +def authenticate_user( + username: str, password: str, db: Session +) -> Union[UserModel, bool]: found, user = UserModel.auth(username, password, db) if not found: return False @@ -79,8 +82,7 @@ def authenticate_user(username: str, password: str, db: Session) -> Union[UserMo async def get_current_user( - security_scopes: SecurityScopes, - token: str = Depends(oauth2_scheme), + security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme), ) -> UserToken: if security_scopes.scopes: authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' @@ -98,9 +100,14 @@ async def get_current_user( raise credentials_exception token_scopes = payload.get("scopes", []) token_data = TokenData(scopes=token_scopes, username=username) - except (PyJWTError, ValidationError): + except (PyJWTError, ValidationError, ExpiredSignatureError): raise credentials_exception - user = get_user(username=token_data.username, id_=payload.get("user_id", None), locked_out=payload.get("locked_out", True), scopes=token_scopes) + user = get_user( + username=token_data.username, + id_=payload.get("userId", None), + locked_out=payload.get("lockedOut", True), + scopes=token_scopes, + ) if user is None: raise credentials_exception for scope in security_scopes.scopes: diff --git a/brewman/routers/account.py b/brewman/routers/account.py index e9b245a8..51a732ed 100644 --- a/brewman/routers/account.py +++ b/brewman/routers/account.py @@ -1,6 +1,6 @@ import traceback import uuid -from typing import List +from typing import List, Optional from datetime import datetime from fastapi import APIRouter, HTTPException, status, Depends, Security @@ -19,7 +19,7 @@ router = APIRouter() # Dependency -def get_db(): +def get_db() -> Session: try: db = SessionLocal() yield db @@ -43,7 +43,7 @@ def save( cost_centre_id=data.cost_centre.id_, ).create(db) db.commit() - return account_info(item.id, db) + return account_info(item) except SQLAlchemyError as e: db.rollback() raise HTTPException( @@ -81,7 +81,7 @@ def update( item.is_starred = data.is_starred item.cost_centre_id = data.cost_centre.id_ db.commit() - return account_info(item.id, db) + return account_info(item) except SQLAlchemyError as e: db.rollback() raise HTTPException( @@ -106,9 +106,9 @@ def delete( if can_delete: delete_with_data(account, db) db.commit() - return account_info(None, db) + return account_info(None) else: - db.abort() + db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Cannot delete account because {reason}", @@ -119,7 +119,7 @@ def delete( def show_blank( db: Session = Depends(get_db), user: UserToken = Security(get_user, scopes=["accounts"]) ): - return account_info(None, db) + return account_info(None) @router.get("/list") @@ -156,7 +156,7 @@ async def show_term( list_.append({"id": item.id, "name": item.name}) if count is not None and index == count - 1: break - return {"user": current_user.name, "list": list_} + return list_ @router.get("/{id_}/balance") @@ -176,7 +176,8 @@ def show_id( db: Session = Depends(get_db), user: UserToken = Security(get_user, scopes=["accounts"]), ): - return account_info(id_, db) + item: Account = db.query(Account).filter(Account.id == id_).first() + return account_info(item) def balance(id_: uuid.UUID, date, db: Session): @@ -196,9 +197,9 @@ def balance(id_: uuid.UUID, date, db: Session): return 0 if bal is None else bal -def account_info(id_: uuid.UUID, db: Session): - if id_ is None: - account = { +def account_info(item: Optional[Account]): + if item is None: + return { "code": "(Auto)", "type": AccountType.by_name("Creditors").id, "isActive": True, @@ -207,22 +208,20 @@ def account_info(id_: uuid.UUID, db: Session): "costCentre": CostCentre.overall(), } else: - account = db.query(Account).filter(Account.id == id_).first() - account = { - "id": account.id, - "code": account.code, - "name": account.name, - "type": account.type, - "isActive": account.is_active, - "isReconcilable": account.is_reconcilable, - "isStarred": account.is_starred, - "isFixture": account.is_fixture, + return { + "id": item.id, + "code": item.code, + "name": item.name, + "type": item.type, + "isActive": item.is_active, + "isReconcilable": item.is_reconcilable, + "isStarred": item.is_starred, + "isFixture": item.is_fixture, "costCentre": { - "id": account.cost_centre_id, - "name": account.cost_centre.name, + "id": item.cost_centre_id, + "name": item.cost_centre.name, }, } - return account def delete_with_data(account: Account, db: Session): diff --git a/brewman/routers/login.py b/brewman/routers/login.py index 88c5c337..94d7d3a6 100644 --- a/brewman/routers/login.py +++ b/brewman/routers/login.py @@ -48,8 +48,8 @@ async def login_for_access_token( ] ) ), - "user_id": str(user.id), - "locked_out": user.locked_out, + "userId": str(user.id), + "lockedOut": user.locked_out, }, expires_delta=access_token_expires, )