Changes in the login so that the token works with my solution

This commit is contained in:
tanshu 2020-05-11 21:53:38 +05:30
parent 9dff72aaed
commit 07b7248b4e
3 changed files with 38 additions and 32 deletions

View File

@ -8,6 +8,7 @@ from fastapi.security import OAuth2PasswordBearer, SecurityScopes
from pydantic import BaseModel, ValidationError
from sqlalchemy.orm import Session
from jose import jwt
from jose.exceptions import ExpiredSignatureError
from brewman.models.auth import User as UserModel
from ..db.session import SessionLocal
@ -71,7 +72,9 @@ def get_user(username: str, id_: str, locked_out: bool, scopes: List[str]) -> Us
)
def authenticate_user(username: str, password: str, db: Session) -> Union[UserModel, bool]:
def authenticate_user(
username: str, password: str, db: Session
) -> Union[UserModel, bool]:
found, user = UserModel.auth(username, password, db)
if not found:
return False
@ -79,8 +82,7 @@ def authenticate_user(username: str, password: str, db: Session) -> Union[UserMo
async def get_current_user(
security_scopes: SecurityScopes,
token: str = Depends(oauth2_scheme),
security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme),
) -> UserToken:
if security_scopes.scopes:
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
@ -98,9 +100,14 @@ async def get_current_user(
raise credentials_exception
token_scopes = payload.get("scopes", [])
token_data = TokenData(scopes=token_scopes, username=username)
except (PyJWTError, ValidationError):
except (PyJWTError, ValidationError, ExpiredSignatureError):
raise credentials_exception
user = get_user(username=token_data.username, id_=payload.get("user_id", None), locked_out=payload.get("locked_out", True), scopes=token_scopes)
user = get_user(
username=token_data.username,
id_=payload.get("userId", None),
locked_out=payload.get("lockedOut", True),
scopes=token_scopes,
)
if user is None:
raise credentials_exception
for scope in security_scopes.scopes:

View File

@ -1,6 +1,6 @@
import traceback
import uuid
from typing import List
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, HTTPException, status, Depends, Security
@ -19,7 +19,7 @@ router = APIRouter()
# Dependency
def get_db():
def get_db() -> Session:
try:
db = SessionLocal()
yield db
@ -43,7 +43,7 @@ def save(
cost_centre_id=data.cost_centre.id_,
).create(db)
db.commit()
return account_info(item.id, db)
return account_info(item)
except SQLAlchemyError as e:
db.rollback()
raise HTTPException(
@ -81,7 +81,7 @@ def update(
item.is_starred = data.is_starred
item.cost_centre_id = data.cost_centre.id_
db.commit()
return account_info(item.id, db)
return account_info(item)
except SQLAlchemyError as e:
db.rollback()
raise HTTPException(
@ -106,9 +106,9 @@ def delete(
if can_delete:
delete_with_data(account, db)
db.commit()
return account_info(None, db)
return account_info(None)
else:
db.abort()
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Cannot delete account because {reason}",
@ -119,7 +119,7 @@ def delete(
def show_blank(
db: Session = Depends(get_db), user: UserToken = Security(get_user, scopes=["accounts"])
):
return account_info(None, db)
return account_info(None)
@router.get("/list")
@ -156,7 +156,7 @@ async def show_term(
list_.append({"id": item.id, "name": item.name})
if count is not None and index == count - 1:
break
return {"user": current_user.name, "list": list_}
return list_
@router.get("/{id_}/balance")
@ -176,7 +176,8 @@ def show_id(
db: Session = Depends(get_db),
user: UserToken = Security(get_user, scopes=["accounts"]),
):
return account_info(id_, db)
item: Account = db.query(Account).filter(Account.id == id_).first()
return account_info(item)
def balance(id_: uuid.UUID, date, db: Session):
@ -196,9 +197,9 @@ def balance(id_: uuid.UUID, date, db: Session):
return 0 if bal is None else bal
def account_info(id_: uuid.UUID, db: Session):
if id_ is None:
account = {
def account_info(item: Optional[Account]):
if item is None:
return {
"code": "(Auto)",
"type": AccountType.by_name("Creditors").id,
"isActive": True,
@ -207,22 +208,20 @@ def account_info(id_: uuid.UUID, db: Session):
"costCentre": CostCentre.overall(),
}
else:
account = db.query(Account).filter(Account.id == id_).first()
account = {
"id": account.id,
"code": account.code,
"name": account.name,
"type": account.type,
"isActive": account.is_active,
"isReconcilable": account.is_reconcilable,
"isStarred": account.is_starred,
"isFixture": account.is_fixture,
return {
"id": item.id,
"code": item.code,
"name": item.name,
"type": item.type,
"isActive": item.is_active,
"isReconcilable": item.is_reconcilable,
"isStarred": item.is_starred,
"isFixture": item.is_fixture,
"costCentre": {
"id": account.cost_centre_id,
"name": account.cost_centre.name,
"id": item.cost_centre_id,
"name": item.cost_centre.name,
},
}
return account
def delete_with_data(account: Account, db: Session):

View File

@ -48,8 +48,8 @@ async def login_for_access_token(
]
)
),
"user_id": str(user.id),
"locked_out": user.locked_out,
"userId": str(user.id),
"lockedOut": user.locked_out,
},
expires_delta=access_token_expires,
)