Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Singleton pattern from Database #175

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Backend/app/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from app.common.config_constants import APP, HOST, PORT
from app.common.PropertiesManager import PropertiesManager
from app.database.Database import Database
from app.database.DatabaseConnection import DatabaseConnection
from app.logging.logging_constants import LOGGING_MAIN
from app.logging.logging_schema import SpotifyElectronLogger
from app.middleware.cors_middleware_config import (
Expand Down Expand Up @@ -49,7 +49,7 @@ async def lifespan_handler(app: FastAPI):
"""
main_logger.info("Spotify Electron Backend Started")

Database()
DatabaseConnection()

app.include_router(playlist_controller.router)
app.include_router(song_controller.router)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import sys
from enum import StrEnum
from functools import wraps
from typing import Any

from gridfs import GridFS
Expand All @@ -18,7 +19,6 @@
from app.exceptions.base_exceptions_schema import SpotifyElectronException
from app.logging.logging_constants import LOGGING_DATABASE
from app.logging.logging_schema import SpotifyElectronLogger
from app.patterns.Singleton import Singleton

database_logger = SpotifyElectronLogger(LOGGING_DATABASE).getLogger()

Expand All @@ -34,31 +34,44 @@ class DatabaseCollection(StrEnum):
SONG_BLOB_DATA = "songs"


class Database(metaclass=Singleton):
"""Singleton instance of the MongoDb connection"""
def __is_connection__init__(func):
@wraps(func)
def wrapper(*args, **kwargs):
if DatabaseConnection.connection is not None:
return func(*args, **kwargs)

return wrapper


class DatabaseConnection:
"""MongoDB connection Instance"""

TESTING_COLLECTION_NAME_PREFIX = "test."
DATABASE_NAME = "SpotifyElectron"
connection = None
collection_name_prefix = None

def __init__(self):
if not hasattr(self, "connection"):
try:
uri = getattr(PropertiesManager, MONGO_URI_ENV_NAME)
self.collection_name_prefix = self._get_collection_name_prefix()
client = self._get_mongo_client_class()
self.connection = client(uri, server_api=ServerApi("1"))[self.DATABASE_NAME]
self._ping_database_connection()
except (
DatabasePingFailed,
UnexpectedDatabasePingFailed,
Exception,
) as exception:
self._handle_database_connection_error(exception)

try:
uri = getattr(PropertiesManager, MONGO_URI_ENV_NAME)
DatabaseConnection.collection_name_prefix = self._get_collection_name_prefix()
client = self._get_mongo_client_class()
DatabaseConnection.connection = client(uri, server_api=ServerApi("1"))[
self.DATABASE_NAME
]
self._ping_database_connection()
except (
DatabasePingFailed,
UnexpectedDatabasePingFailed,
Exception,
) as exception:
self._handle_database_connection_error(exception)

@__is_connection__init__
def _ping_database_connection(self):
"""Pings database connection"""
try:
ping_result = self.connection.command("ping")
ping_result = DatabaseConnection.connection.command("ping") # type: ignore
self._check_ping_result(ping_result)
except ConnectionFailure as exception:
raise DatabasePingFailed from exception
Expand Down Expand Up @@ -112,12 +125,9 @@ def _get_collection_name_prefix(self) -> str:
else ""
)

@__is_connection__init__
@staticmethod
def get_instance():
"""Method to retrieve the singleton instance"""
return Database()

def get_collection_connection(self, collection_name: DatabaseCollection) -> Collection:
def get_collection_connection(collection_name: DatabaseCollection) -> Collection:
"""Returns the connection with a collection

Args:
Expand All @@ -126,9 +136,13 @@ def get_collection_connection(self, collection_name: DatabaseCollection) -> Coll
Returns:
Any: the connection to the collection
"""
return Database().connection[self.collection_name_prefix + collection_name] # type: ignore
return DatabaseConnection.connection[ # type: ignore
DatabaseConnection.collection_name_prefix + collection_name # type: ignore
]

def get_gridfs_collection_connection(self, collection_name: DatabaseCollection) -> Any:
@__is_connection__init__
@staticmethod
def get_gridfs_collection_connection(collection_name: DatabaseCollection) -> Any:
"""Returns the connection with gridfs collection

Args:
Expand All @@ -138,8 +152,8 @@ def get_gridfs_collection_connection(self, collection_name: DatabaseCollection)
Any: the gridfs collection connection
"""
return GridFS(
Database.get_instance().connection, # type: ignore
collection=self.collection_name_prefix + collection_name,
DatabaseConnection.connection, # type: ignore
collection=DatabaseConnection.collection_name_prefix + collection_name, # type: ignore
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pymongo.collection import Collection

from app.database.Database import Database, DatabaseCollection
from app.database.DatabaseConnection import DatabaseCollection, DatabaseConnection


def get_playlist_collection() -> Collection:
Expand All @@ -13,4 +13,4 @@ def get_playlist_collection() -> Collection:
Returns:
Collection: the playlist collection
"""
return Database().get_collection_connection(DatabaseCollection.PLAYLIST)
return DatabaseConnection.get_collection_connection(DatabaseCollection.PLAYLIST)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ARCH_STREAMING_SERVERLESS_FUNCTION,
ARCHITECTURE_ENV_NAME,
)
from app.database.Database import Database, DatabaseCollection
from app.database.DatabaseConnection import DatabaseCollection, DatabaseConnection


def get_song_collection() -> Collection:
Expand All @@ -22,8 +22,10 @@ def get_song_collection() -> Collection:
Collection: the song collection depending on architecture
"""
repository_map = {
ARCH_BLOB: Database().get_collection_connection(DatabaseCollection.SONG_BLOB_FILE),
ARCH_STREAMING_SERVERLESS_FUNCTION: Database().get_collection_connection(
ARCH_BLOB: DatabaseConnection.get_collection_connection(
DatabaseCollection.SONG_BLOB_FILE
),
ARCH_STREAMING_SERVERLESS_FUNCTION: DatabaseConnection.get_collection_connection(
DatabaseCollection.SONG_STREAMING
),
}
Expand All @@ -35,4 +37,6 @@ def get_gridfs_song_collection() -> GridFS:

:return GridFS: the gridfs song collection
"""
return Database().get_gridfs_collection_connection(DatabaseCollection.SONG_BLOB_DATA)
return DatabaseConnection.get_gridfs_collection_connection(
DatabaseCollection.SONG_BLOB_DATA
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pymongo.collection import Collection

import app.spotify_electron.user.base_user_service as base_user_service
from app.database.Database import Database, DatabaseCollection
from app.database.DatabaseConnection import DatabaseCollection, DatabaseConnection
from app.logging.logging_constants import LOGGING_USER_COLLECTION_PROVIDER
from app.logging.logging_schema import SpotifyElectronLogger
from app.spotify_electron.user.user.user_schema import UserType
Expand All @@ -23,8 +23,10 @@ def get_user_associated_collection(user_name: str) -> Collection:
Collection: the user collection
"""
collection_map = {
UserType.USER: Database().get_collection_connection(DatabaseCollection.USER),
UserType.ARTIST: Database().get_collection_connection(DatabaseCollection.ARTIST),
UserType.USER: DatabaseConnection.get_collection_connection(DatabaseCollection.USER),
UserType.ARTIST: DatabaseConnection.get_collection_connection(
DatabaseCollection.ARTIST
),
}

user_type = base_user_service.get_user_type(user_name)
Expand All @@ -43,7 +45,7 @@ def get_artist_collection() -> Collection:
Returns:
Collection: the artist collection
"""
return Database().get_collection_connection(DatabaseCollection.ARTIST)
return DatabaseConnection.get_collection_connection(DatabaseCollection.ARTIST)


def get_user_collection() -> Collection:
Expand All @@ -52,7 +54,7 @@ def get_user_collection() -> Collection:
Returns:
Collection: the artist collection
"""
return Database().get_collection_connection(DatabaseCollection.USER)
return DatabaseConnection.get_collection_connection(DatabaseCollection.USER)


def get_all_collections() -> list[Collection]:
Expand All @@ -62,7 +64,9 @@ def get_all_collections() -> list[Collection]:
list[Collection]: all the users collections
"""
collection_map = {
UserType.USER: Database().get_collection_connection(DatabaseCollection.USER),
UserType.ARTIST: Database().get_collection_connection(DatabaseCollection.ARTIST),
UserType.USER: DatabaseConnection.get_collection_connection(DatabaseCollection.USER),
UserType.ARTIST: DatabaseConnection.get_collection_connection(
DatabaseCollection.ARTIST
),
}
return list(collection_map.values())
25 changes: 12 additions & 13 deletions Backend/tests/test__database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
from pytest import raises

from app.common.set_up_constants import MONGO_URI_ENV_NAME
from app.database.Database import (
Database,
from app.database.DatabaseConnection import (
DatabaseConnection,
DatabasePingFailed,
Singleton,
UnexpectedDatabasePingFailed,
)


@patch("sys.exit")
@patch("app.database.Database.Database._get_mongo_client_class")
@patch("app.database.DatabaseConnection.DatabaseConnection._get_mongo_client_class")
def test_raise_exception_connection_failure(
get_mongo_client_class_mock, sys_exit_mock, clean_modified_environments
):
Expand All @@ -24,20 +23,20 @@ def raise_exception_connection_failure(*arg):
client_mock = Mock()
database_connection_mock = Mock()
database_connection_mock.command.side_effect = raise_exception_connection_failure
client_mock.return_value = {Database.DATABASE_NAME: database_connection_mock}
client_mock.return_value = {DatabaseConnection.DATABASE_NAME: database_connection_mock}
get_mongo_client_class_mock.return_value = client_mock

os.environ[MONGO_URI_ENV_NAME] = "mongo_uri"
Singleton._instances.clear()
DatabaseConnection.connection = None
with raises(DatabasePingFailed):
Database()._ping_database_connection()
DatabaseConnection()._ping_database_connection()

assert sys_exit_mock.call_count == 1
Singleton._instances.clear()
DatabaseConnection.connection = None


@patch("sys.exit")
@patch("app.database.Database.Database._get_mongo_client_class")
@patch("app.database.DatabaseConnection.DatabaseConnection._get_mongo_client_class")
def test_raise_exception_unexpected_connection_failure(
get_mongo_client_class_mock, sys_exit_mock, clean_modified_environments
):
Expand All @@ -47,13 +46,13 @@ def raise_exception_connection_failure(*arg):
client_mock = Mock()
database_connection_mock = Mock()
database_connection_mock.command.side_effect = raise_exception_connection_failure
client_mock.return_value = {Database.DATABASE_NAME: database_connection_mock}
client_mock.return_value = {DatabaseConnection.DATABASE_NAME: database_connection_mock}
get_mongo_client_class_mock.return_value = client_mock

os.environ[MONGO_URI_ENV_NAME] = "mongo_uri"
Singleton._instances.clear()
DatabaseConnection.connection = None
with raises(UnexpectedDatabasePingFailed):
Database()._ping_database_connection()
DatabaseConnection()._ping_database_connection()

assert sys_exit_mock.call_count == 1
Singleton._instances.clear()
DatabaseConnection.connection = None