Skip to content

Commit

Permalink
Refactor to move common logic to the base class
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Nov 5, 2024
1 parent 7d01620 commit 0420d2d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 22 deletions.
50 changes: 47 additions & 3 deletions neon_users_service/databases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
from abc import ABC, abstractmethod

from neon_users_service.exceptions import UserNotFoundError
from neon_users_service.exceptions import UserNotFoundError, UserExistsError
from neon_data_models.models.user.database import User


class UserDatabase(ABC):
@abstractmethod
def create_user(self, user: User) -> User:
"""
Add a new user to the database. Raises a `UserExistsError` if the input
`user` already exists in the database (by `username` or `user_id`).
@param user: `User` object to insert to the database
@return: `User` object inserted into the database
"""
if self._check_user_exists(user):
raise UserExistsError(user)
return self._db_create_user(user)

@abstractmethod
def _db_create_user(self, user: User) -> User:
"""
Add a new user to the database. The `user` object has already been
validated as unique, so this just needs to perform the database
transaction.
@param user: `User` object to insert to the database
@return: `User` object inserted into the database
"""

@abstractmethod
def read_user_by_id(self, user_id: str) -> User:
Expand Down Expand Up @@ -45,14 +57,33 @@ def read_user(self, user_spec: str) -> User:
except UserNotFoundError:
return self.read_user_by_username(user_spec)

@abstractmethod
def update_user(self, user: User) -> User:
"""
Update a user entry in the database. Raises a `UserNotFoundError` if
the input user's `user_id` is not found in the database.
@param user: `User` object to update in the database
@return: Updated `User` object read from the database
"""
# Lookup user to ensure they exist in the database
existing_id = self.read_user_by_id(user.user_id)
try:
if self.read_user_by_username(user.username) != existing_id:
raise UserExistsError(f"Another user with username "
f"'{user.username}' already exists")
except UserNotFoundError:
pass
return self._db_update_user(user)

@abstractmethod
def _db_update_user(self, user: User) -> User:
"""
Update a user entry in the database. The `user` object has already been
validated as existing and changes valid, so this just needs to perform
the database transaction.
@param user: `User` object to update in the database
@return: Updated `User` object read from the database
"""

def delete_user(self, user_id: str) -> User:
"""
Remove a user from the database if it exists. Raises a
Expand All @@ -61,6 +92,19 @@ def delete_user(self, user_id: str) -> User:
@param user_id: `user_id` to remove
@return: User object removed from the database
"""
# Lookup user to ensure they exist in the database
user_to_delete = self.read_user_by_id(user_id)
return self._db_delete_user(user_to_delete)

@abstractmethod
def _db_delete_user(self, user: User) -> User:
"""
Remove a user from the database if it exists. The `user` object has
already been validated as existing, so this just needs to perform the
database transaction.
@param user: User object to remove
@return: User object removed from the database
"""

def _check_user_exists(self, user: User) -> bool:
"""
Expand Down
27 changes: 8 additions & 19 deletions neon_users_service/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from os import makedirs
from os.path import expanduser, dirname
from sqlite3 import connect, Cursor
from sqlite3 import connect
from threading import Lock
from typing import Optional, List

from neon_users_service.databases import UserDatabase
from neon_users_service.exceptions import UserNotFoundError, UserExistsError, DatabaseError
from neon_users_service.exceptions import UserNotFoundError, DatabaseError
from neon_data_models.models.user.database import User


Expand All @@ -26,9 +26,7 @@ def __init__(self, db_path: Optional[str] = None):
)
self.connection.commit()

def create_user(self, user: User) -> User:
if self._check_user_exists(user):
raise UserExistsError(user)
def _db_create_user(self, user: User) -> User:
with self._db_lock:
self.connection.execute(
f'''INSERT INTO users VALUES
Expand Down Expand Up @@ -72,15 +70,7 @@ def read_user_by_username(self, username: str) -> User:
cursor.close()
return User(**json.loads(self._parse_lookup_results(username, rows)))

def update_user(self, user: User) -> User:
# Lookup user to ensure they exist in the database
existing_id = self.read_user_by_id(user.user_id)
try:
if self.read_user_by_username(user.username) != existing_id:
raise UserExistsError(f"Another user with username "
f"'{user.username}' already exists")
except UserNotFoundError:
pass
def _db_update_user(self, user: User) -> User:
with self._db_lock:
self.connection.execute(
f'''UPDATE users SET username = '{user.username}',
Expand All @@ -91,13 +81,12 @@ def update_user(self, user: User) -> User:
self.connection.commit()
return self.read_user_by_id(user.user_id)

def delete_user(self, user_id: str) -> User:
# Lookup user to ensure they exist in the database
user_to_delete = self.read_user_by_id(user_id)
def _db_delete_user(self, user: User) -> User:
with self._db_lock:
self.connection.execute(f"DELETE FROM users WHERE user_id = '{user_id}'")
self.connection.execute(
f"DELETE FROM users WHERE user_id = '{user.user_id}'")
self.connection.commit()
return user_to_delete
return user

def shutdown(self):
self.connection.close()

0 comments on commit 0420d2d

Please sign in to comment.