Changes in the login so that the token works with my solution

This commit is contained in:
tanshu 2020-05-11 21:53:38 +05:30
parent 9dff72aaed
commit 07b7248b4e
3 changed files with 38 additions and 32 deletions

@ -8,6 +8,7 @@ from fastapi.security import OAuth2PasswordBearer, SecurityScopes
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from jose import jwt from jose import jwt
from jose.exceptions import ExpiredSignatureError
from brewman.models.auth import User as UserModel from brewman.models.auth import User as UserModel
from ..db.session import SessionLocal from ..db.session import SessionLocal
@ -71,7 +72,9 @@ def get_user(username: str, id_: str, locked_out: bool, scopes: List[str]) -> Us
) )
def authenticate_user(username: str, password: str, db: Session) -> Union[UserModel, bool]: def authenticate_user(
username: str, password: str, db: Session
) -> Union[UserModel, bool]:
found, user = UserModel.auth(username, password, db) found, user = UserModel.auth(username, password, db)
if not found: if not found:
return False return False
@ -79,8 +82,7 @@ def authenticate_user(username: str, password: str, db: Session) -> Union[UserMo
async def get_current_user( async def get_current_user(
security_scopes: SecurityScopes, security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme),
token: str = Depends(oauth2_scheme),
) -> UserToken: ) -> UserToken:
if security_scopes.scopes: if security_scopes.scopes:
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
@ -98,9 +100,14 @@ async def get_current_user(
raise credentials_exception raise credentials_exception
token_scopes = payload.get("scopes", []) token_scopes = payload.get("scopes", [])
token_data = TokenData(scopes=token_scopes, username=username) token_data = TokenData(scopes=token_scopes, username=username)
except (PyJWTError, ValidationError): except (PyJWTError, ValidationError, ExpiredSignatureError):
raise credentials_exception raise credentials_exception
user = get_user(username=token_data.username, id_=payload.get("user_id", None), locked_out=payload.get("locked_out", True), scopes=token_scopes) user = get_user(
username=token_data.username,
id_=payload.get("userId", None),
locked_out=payload.get("lockedOut", True),
scopes=token_scopes,
)
if user is None: if user is None:
raise credentials_exception raise credentials_exception
for scope in security_scopes.scopes: for scope in security_scopes.scopes:

@ -1,6 +1,6 @@
import traceback import traceback
import uuid import uuid
from typing import List from typing import List, Optional
from datetime import datetime from datetime import datetime
from fastapi import APIRouter, HTTPException, status, Depends, Security from fastapi import APIRouter, HTTPException, status, Depends, Security
@ -19,7 +19,7 @@ router = APIRouter()
# Dependency # Dependency
def get_db(): def get_db() -> Session:
try: try:
db = SessionLocal() db = SessionLocal()
yield db yield db
@ -43,7 +43,7 @@ def save(
cost_centre_id=data.cost_centre.id_, cost_centre_id=data.cost_centre.id_,
).create(db) ).create(db)
db.commit() db.commit()
return account_info(item.id, db) return account_info(item)
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
raise HTTPException( raise HTTPException(
@ -81,7 +81,7 @@ def update(
item.is_starred = data.is_starred item.is_starred = data.is_starred
item.cost_centre_id = data.cost_centre.id_ item.cost_centre_id = data.cost_centre.id_
db.commit() db.commit()
return account_info(item.id, db) return account_info(item)
except SQLAlchemyError as e: except SQLAlchemyError as e:
db.rollback() db.rollback()
raise HTTPException( raise HTTPException(
@ -106,9 +106,9 @@ def delete(
if can_delete: if can_delete:
delete_with_data(account, db) delete_with_data(account, db)
db.commit() db.commit()
return account_info(None, db) return account_info(None)
else: else:
db.abort() db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Cannot delete account because {reason}", detail=f"Cannot delete account because {reason}",
@ -119,7 +119,7 @@ def delete(
def show_blank( def show_blank(
db: Session = Depends(get_db), user: UserToken = Security(get_user, scopes=["accounts"]) db: Session = Depends(get_db), user: UserToken = Security(get_user, scopes=["accounts"])
): ):
return account_info(None, db) return account_info(None)
@router.get("/list") @router.get("/list")
@ -156,7 +156,7 @@ async def show_term(
list_.append({"id": item.id, "name": item.name}) list_.append({"id": item.id, "name": item.name})
if count is not None and index == count - 1: if count is not None and index == count - 1:
break break
return {"user": current_user.name, "list": list_} return list_
@router.get("/{id_}/balance") @router.get("/{id_}/balance")
@ -176,7 +176,8 @@ def show_id(
db: Session = Depends(get_db), db: Session = Depends(get_db),
user: UserToken = Security(get_user, scopes=["accounts"]), user: UserToken = Security(get_user, scopes=["accounts"]),
): ):
return account_info(id_, db) item: Account = db.query(Account).filter(Account.id == id_).first()
return account_info(item)
def balance(id_: uuid.UUID, date, db: Session): def balance(id_: uuid.UUID, date, db: Session):
@ -196,9 +197,9 @@ def balance(id_: uuid.UUID, date, db: Session):
return 0 if bal is None else bal return 0 if bal is None else bal
def account_info(id_: uuid.UUID, db: Session): def account_info(item: Optional[Account]):
if id_ is None: if item is None:
account = { return {
"code": "(Auto)", "code": "(Auto)",
"type": AccountType.by_name("Creditors").id, "type": AccountType.by_name("Creditors").id,
"isActive": True, "isActive": True,
@ -207,22 +208,20 @@ def account_info(id_: uuid.UUID, db: Session):
"costCentre": CostCentre.overall(), "costCentre": CostCentre.overall(),
} }
else: else:
account = db.query(Account).filter(Account.id == id_).first() return {
account = { "id": item.id,
"id": account.id, "code": item.code,
"code": account.code, "name": item.name,
"name": account.name, "type": item.type,
"type": account.type, "isActive": item.is_active,
"isActive": account.is_active, "isReconcilable": item.is_reconcilable,
"isReconcilable": account.is_reconcilable, "isStarred": item.is_starred,
"isStarred": account.is_starred, "isFixture": item.is_fixture,
"isFixture": account.is_fixture,
"costCentre": { "costCentre": {
"id": account.cost_centre_id, "id": item.cost_centre_id,
"name": account.cost_centre.name, "name": item.cost_centre.name,
}, },
} }
return account
def delete_with_data(account: Account, db: Session): def delete_with_data(account: Account, db: Session):

@ -48,8 +48,8 @@ async def login_for_access_token(
] ]
) )
), ),
"user_id": str(user.id), "userId": str(user.id),
"locked_out": user.locked_out, "lockedOut": user.locked_out,
}, },
expires_delta=access_token_expires, expires_delta=access_token_expires,
) )