Skip to content

Commit

Permalink
fix: refactor validate_dates function (avoid writing duplicate code)
Browse files Browse the repository at this point in the history
  • Loading branch information
M03ED committed Sep 12, 2024
1 parent d7281c2 commit aea3273
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 63 deletions.
33 changes: 18 additions & 15 deletions app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from app.db import Session, crud, get_db
from config import SUDOERS
from fastapi import Depends, HTTPException
from datetime import datetime, timezone
from datetime import datetime, timezone, timedelta
from app.utils.jwt import get_subscription_payload


Expand Down Expand Up @@ -36,20 +36,23 @@ def get_dbnode(node_id: int, db: Session = Depends(get_db)):
return dbnode


def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[str, datetime]]) -> bool:
def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[str, datetime]]) -> (datetime, datetime):
"""Validate if start and end dates are correct and if end is after start."""
try:
if start:
start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start)
start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start).astimezone(timezone.utc)
else:
start_date = None
start_date = datetime.now(timezone.utc) - timedelta(days=30)
if end:
end_date = end if isinstance(end, datetime) else datetime.fromisoformat(end)
end_date = end if isinstance(end, datetime) else datetime.fromisoformat(end).astimezone(timezone.utc)
if start_date and end_date < start_date:
return False
return True
raise HTTPException(status_code=400, detail="Start date must be before end date")
else:
end_date = datetime.now(timezone.utc)

return start_date, end_date
except ValueError:
return False
raise HTTPException(status_code=400, detail="Invalid date range or format")


def get_user_template(template_id: int, db: Session = Depends(get_db)):
Expand All @@ -61,8 +64,8 @@ def get_user_template(template_id: int, db: Session = Depends(get_db)):


def get_validated_sub(
token: str,
db: Session = Depends(get_db)
token: str,
db: Session = Depends(get_db)
) -> UserResponse:
sub = get_subscription_payload(token)
if not sub:
Expand All @@ -79,9 +82,9 @@ def get_validated_sub(


def get_validated_user(
username: str,
admin: Admin = Depends(Admin.get_current),
db: Session = Depends(get_db)
username: str,
admin: Admin = Depends(Admin.get_current),
db: Session = Depends(get_db)
) -> UserResponse:
dbuser = crud.get_user(db, username)
if not dbuser:
Expand All @@ -93,8 +96,8 @@ def get_validated_user(
return dbuser


def get_expired_users_list(db: Session, admin: Admin, expired_after: Optional[datetime] = None, expired_before: Optional[datetime] = None):

def get_expired_users_list(db: Session, admin: Admin, expired_after: Optional[datetime] = None,
expired_before: Optional[datetime] = None):
expired_before = expired_before or datetime.now(timezone.utc)
expired_after = expired_after or datetime.min.replace(tzinfo=timezone.utc)

Expand Down
12 changes: 1 addition & 11 deletions app/routers/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,7 @@ def get_usage(
_: Admin = Depends(Admin.check_sudo_admin)
):
"""Retrieve usage statistics for nodes within a specified date range."""
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

if not start:
start = datetime.now(timezone.utc) - timedelta(days=30)
else:
start = datetime.fromisoformat(start).astimezone(timezone.utc)
if not end:
end = datetime.now(timezone.utc)
else:
end = datetime.fromisoformat(end).astimezone(timezone.utc)
start, end = validate_dates(start, end)

usages = crud.get_nodes_usage(db, start, end)

Expand Down
12 changes: 1 addition & 11 deletions app/routers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,7 @@ def user_get_usage(
db: Session = Depends(get_db)
):
"""Fetches the usage statistics for the user within a specified date range."""
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

if not start:
start = datetime.now(timezone.utc) - timedelta(days=30)
else:
start = datetime.fromisoformat(start).astimezone(timezone.utc)
if not end:
end = datetime.now(timezone.utc)
else:
end = datetime.fromisoformat(end).astimezone(timezone.utc)
start, end = validate_dates(start, end)

usages = crud.get_user_usages(db, dbuser, start, end)

Expand Down
30 changes: 4 additions & 26 deletions app/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,7 @@ def get_user_usage(
db: Session = Depends(get_db)
):
"""Get users usage"""
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

if not start:
start = datetime.now(timezone.utc) - timedelta(days=30)
else:
start = datetime.fromisoformat(start).astimezone(timezone.utc)
if not end:
end = datetime.now(timezone.utc)
else:
end = datetime.fromisoformat(end).astimezone(timezone.utc)
start, end = validate_dates(start, end)

usages = crud.get_user_usages(db, dbuser, start, end)

Expand All @@ -291,17 +281,7 @@ def get_users_usage(
admin: Admin = Depends(Admin.get_current)
):
"""Get all users usage"""
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

if not start:
start = datetime.now(timezone.utc) - timedelta(days=30)
else:
start = datetime.fromisoformat(start).astimezone(timezone.utc)
if not end:
end = datetime.now(timezone.utc)
else:
end = datetime.fromisoformat(end).astimezone(timezone.utc)
start, end = validate_dates(start, end)

usages = crud.get_all_users_usages(
db=db,
Expand Down Expand Up @@ -350,8 +330,7 @@ def get_expired_users(
- If both are omitted, returns all expired users
"""

if not validate_dates(expired_after, expired_before):
raise HTTPException(status_code=400, detail="Invalid date range or format")
expired_after, expired_before = validate_dates(expired_after, expired_before)

expired_users = get_expired_users_list(db, admin, expired_after, expired_before)
return [u.username for u in expired_users]
Expand All @@ -372,8 +351,7 @@ def delete_expired_users(
- **expired_before** UTC datetime (optional)
- At least one of expired_after or expired_before must be provided
"""
if not validate_dates(expired_after, expired_before, allow_both_none=False):
raise HTTPException(status_code=400, detail="Invalid date range or format")
expired_after, expired_before = validate_dates(expired_after, expired_before)

expired_users = get_expired_users_list(db, admin, expired_after, expired_before)
removed_users = [u.username for u in expired_users]
Expand Down

0 comments on commit aea3273

Please sign in to comment.