diff --git a/barker/barker/routers/product.py b/barker/barker/routers/product.py index 5fbae44..7d04923 100644 --- a/barker/barker/routers/product.py +++ b/barker/barker/routers/product.py @@ -6,13 +6,16 @@ from typing import List import barker.schemas.product as schemas from fastapi import APIRouter, Depends, HTTPException, Security, status -from sqlalchemy import and_, or_, select, update +from sqlalchemy import and_, insert, or_, select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session, contains_eager, joinedload +from sqlalchemy.sql.functions import count from ..core.security import get_current_active_user as get_user from ..db.session import SessionFuture from ..models.menu_category import MenuCategory +from ..models.modifier_categories_products import modifier_categories_products +from ..models.modifier_category import ModifierCategory from ..models.product import Product from ..models.product_version import ProductVersion from ..models.sale_category import SaleCategory @@ -88,6 +91,7 @@ def save( valid_till=None, ) db.add(product_version) + add_modifiers(item.id, product_version.menu_category_id, date_, db) db.commit() return product_info(product_version) except SQLAlchemyError as e: @@ -97,6 +101,48 @@ def save( ) +def add_modifiers(product_id: uuid.UUID, menu_category_id: uuid.UUID, date_: date, db: Session): + products_in_category = db.execute( + select(count(ProductVersion.id)).where( + and_( + ProductVersion.menu_category_id == menu_category_id, + ProductVersion.product_id != product_id, + or_( + ProductVersion.valid_from == None, # noqa: E711 + ProductVersion.valid_from <= date_, + ), + or_( + ProductVersion.valid_till == None, # noqa: E711 + ProductVersion.valid_till >= date_, + ), + ) + ) + ).scalar() + categories = db.execute( + select(count(ProductVersion.id), ModifierCategory.id) + .join(Product.versions) + .join(modifier_categories_products, ProductVersion.id == modifier_categories_products.c.product_id) + .join(ModifierCategory, ModifierCategory.id == modifier_categories_products.c.modifier_category_id) + .where( + and_( + ProductVersion.menu_category_id == menu_category_id, + or_( + ProductVersion.valid_from == None, # noqa: E711 + ProductVersion.valid_from <= date_, + ), + or_( + ProductVersion.valid_till == None, # noqa: E711 + ProductVersion.valid_till >= date_, + ), + ) + ) + .group_by(ModifierCategory.id) + ).all() + for c, mc in categories: + if c == products_in_category: + db.execute(insert(modifier_categories_products).values(product_id=product_id, modifier_category_id=mc)) + + @router.put("/{id_}", response_model=schemas.Product) def update_route( id_: uuid.UUID,