From aea3273cd7c327d9870c435cb6a6c3d0741737d0 Mon Sep 17 00:00:00 2001 From: Random Guy <50927468+M03ED@users.noreply.github.com> Date: Thu, 12 Sep 2024 13:18:51 +0330 Subject: [PATCH] fix: refactor validate_dates function (avoid writing duplicate code) --- app/dependencies.py | 33 ++++++++++++++++++--------------- app/routers/node.py | 12 +----------- app/routers/subscription.py | 12 +----------- app/routers/user.py | 30 ++++-------------------------- 4 files changed, 24 insertions(+), 63 deletions(-) diff --git a/app/dependencies.py b/app/dependencies.py index 45c4faa7..0d9ce058 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -4,7 +4,7 @@ from app.db import Session, crud, get_db from config import SUDOERS from fastapi import Depends, HTTPException -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta from app.utils.jwt import get_subscription_payload @@ -36,20 +36,23 @@ def get_dbnode(node_id: int, db: Session = Depends(get_db)): return dbnode -def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[str, datetime]]) -> bool: +def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[str, datetime]]) -> (datetime, datetime): """Validate if start and end dates are correct and if end is after start.""" try: if start: - start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start) + start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start).astimezone(timezone.utc) else: - start_date = None + start_date = datetime.now(timezone.utc) - timedelta(days=30) if end: - end_date = end if isinstance(end, datetime) else datetime.fromisoformat(end) + end_date = end if isinstance(end, datetime) else datetime.fromisoformat(end).astimezone(timezone.utc) if start_date and end_date < start_date: - return False - return True + raise HTTPException(status_code=400, detail="Start date must be before end date") + else: + end_date = datetime.now(timezone.utc) + + return start_date, end_date except ValueError: - return False + raise HTTPException(status_code=400, detail="Invalid date range or format") def get_user_template(template_id: int, db: Session = Depends(get_db)): @@ -61,8 +64,8 @@ def get_user_template(template_id: int, db: Session = Depends(get_db)): def get_validated_sub( - token: str, - db: Session = Depends(get_db) + token: str, + db: Session = Depends(get_db) ) -> UserResponse: sub = get_subscription_payload(token) if not sub: @@ -79,9 +82,9 @@ def get_validated_sub( def get_validated_user( - username: str, - admin: Admin = Depends(Admin.get_current), - db: Session = Depends(get_db) + username: str, + admin: Admin = Depends(Admin.get_current), + db: Session = Depends(get_db) ) -> UserResponse: dbuser = crud.get_user(db, username) if not dbuser: @@ -93,8 +96,8 @@ def get_validated_user( return dbuser -def get_expired_users_list(db: Session, admin: Admin, expired_after: Optional[datetime] = None, expired_before: Optional[datetime] = None): - +def get_expired_users_list(db: Session, admin: Admin, expired_after: Optional[datetime] = None, + expired_before: Optional[datetime] = None): expired_before = expired_before or datetime.now(timezone.utc) expired_after = expired_after or datetime.min.replace(tzinfo=timezone.utc) diff --git a/app/routers/node.py b/app/routers/node.py index 45979544..e32b5088 100644 --- a/app/routers/node.py +++ b/app/routers/node.py @@ -198,17 +198,7 @@ def get_usage( _: Admin = Depends(Admin.check_sudo_admin) ): """Retrieve usage statistics for nodes within a specified date range.""" - if not validate_dates(start, end): - raise HTTPException(status_code=400, detail="Invalid date range or format") - - if not start: - start = datetime.now(timezone.utc) - timedelta(days=30) - else: - start = datetime.fromisoformat(start).astimezone(timezone.utc) - if not end: - end = datetime.now(timezone.utc) - else: - end = datetime.fromisoformat(end).astimezone(timezone.utc) + start, end = validate_dates(start, end) usages = crud.get_nodes_usage(db, start, end) diff --git a/app/routers/subscription.py b/app/routers/subscription.py index a601a7f6..9117df19 100644 --- a/app/routers/subscription.py +++ b/app/routers/subscription.py @@ -144,17 +144,7 @@ def user_get_usage( db: Session = Depends(get_db) ): """Fetches the usage statistics for the user within a specified date range.""" - if not validate_dates(start, end): - raise HTTPException(status_code=400, detail="Invalid date range or format") - - if not start: - start = datetime.now(timezone.utc) - timedelta(days=30) - else: - start = datetime.fromisoformat(start).astimezone(timezone.utc) - if not end: - end = datetime.now(timezone.utc) - else: - end = datetime.fromisoformat(end).astimezone(timezone.utc) + start, end = validate_dates(start, end) usages = crud.get_user_usages(db, dbuser, start, end) diff --git a/app/routers/user.py b/app/routers/user.py index 282e1081..4446fc6b 100644 --- a/app/routers/user.py +++ b/app/routers/user.py @@ -265,17 +265,7 @@ def get_user_usage( db: Session = Depends(get_db) ): """Get users usage""" - if not validate_dates(start, end): - raise HTTPException(status_code=400, detail="Invalid date range or format") - - if not start: - start = datetime.now(timezone.utc) - timedelta(days=30) - else: - start = datetime.fromisoformat(start).astimezone(timezone.utc) - if not end: - end = datetime.now(timezone.utc) - else: - end = datetime.fromisoformat(end).astimezone(timezone.utc) + start, end = validate_dates(start, end) usages = crud.get_user_usages(db, dbuser, start, end) @@ -291,17 +281,7 @@ def get_users_usage( admin: Admin = Depends(Admin.get_current) ): """Get all users usage""" - if not validate_dates(start, end): - raise HTTPException(status_code=400, detail="Invalid date range or format") - - if not start: - start = datetime.now(timezone.utc) - timedelta(days=30) - else: - start = datetime.fromisoformat(start).astimezone(timezone.utc) - if not end: - end = datetime.now(timezone.utc) - else: - end = datetime.fromisoformat(end).astimezone(timezone.utc) + start, end = validate_dates(start, end) usages = crud.get_all_users_usages( db=db, @@ -350,8 +330,7 @@ def get_expired_users( - If both are omitted, returns all expired users """ - if not validate_dates(expired_after, expired_before): - raise HTTPException(status_code=400, detail="Invalid date range or format") + expired_after, expired_before = validate_dates(expired_after, expired_before) expired_users = get_expired_users_list(db, admin, expired_after, expired_before) return [u.username for u in expired_users] @@ -372,8 +351,7 @@ def delete_expired_users( - **expired_before** UTC datetime (optional) - At least one of expired_after or expired_before must be provided """ - if not validate_dates(expired_after, expired_before, allow_both_none=False): - raise HTTPException(status_code=400, detail="Invalid date range or format") + expired_after, expired_before = validate_dates(expired_after, expired_before) expired_users = get_expired_users_list(db, admin, expired_after, expired_before) removed_users = [u.username for u in expired_users]