Skip to content

Commit

Permalink
Merge pull request #1 from FlipperPA/connect_defaults
Browse files Browse the repository at this point in the history
Linting changes.
  • Loading branch information
amalek215 authored Feb 23, 2023
2 parents f293f07 + 0fcff8a commit 58ed7f2
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 87 deletions.
56 changes: 28 additions & 28 deletions wrds/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,31 @@ class SchemaNotFoundError(FileNotFoundError):
class Connection(object):
def __init__(self, autoconnect=True, verbose=False, **kwargs):
"""
Set up the connection to the WRDS database.
By default, also establish the connection to the database.
Optionally, the user may specify connection parameters:
*wrds_hostname*: WRDS database hostname
*wrds_port*: database connection port number
*wrds_dbname*: WRDS database name
*wrds_username*: WRDS username
*autoconnect*: If false will not immediately establish the connection
The constructor will use the .pgpass file if it exists and may make use of
PostgreSQL environment variables such as PGHOST, PGUSER, etc., if cooresponding
parameters are not set.
If not, it will ask the user for a username and password.
It will also direct the user to information on setting up .pgpass.
Additionally, creating the instance will load a list of schemas
the user has permission to access.
:return: None
Usage::
>>> db = wrds.Connection()
Loading library list...
Done
Set up the connection to the WRDS database.
By default, also establish the connection to the database.
Optionally, the user may specify connection parameters:
*wrds_hostname*: WRDS database hostname
*wrds_port*: database connection port number
*wrds_dbname*: WRDS database name
*wrds_username*: WRDS username
*autoconnect*: If false will not immediately establish the connection
The constructor will use the .pgpass file if it exists and may make use of
PostgreSQL environment variables such as PGHOST, PGUSER, etc., if cooresponding
parameters are not set.
If not, it will ask the user for a username and password.
It will also direct the user to information on setting up .pgpass.
Additionally, creating the instance will load a list of schemas
the user has permission to access.
:return: None
Usage::
>>> db = wrds.Connection()
Loading library list...
Done
"""
self._verbose = verbose
self._password = ""
Expand All @@ -74,7 +74,7 @@ def __init__(self, autoconnect=True, verbose=False, **kwargs):
self.connect()
self.load_library_list()

def __make_sa_engine_conn(self, raise_err = False):
def __make_sa_engine_conn(self, raise_err=False):
username = self._username
hostname = self._hostname
password = urllib.parse.quote_plus(self._password)
Expand Down Expand Up @@ -127,7 +127,7 @@ def connect(self):
try:
self.create_pgpass_file()
print("Created .pgpass file successfully.")
except:
except Exception:
print("Failed to create .pgpass file.")
print(
"You can create this file yourself at any time "
Expand Down Expand Up @@ -216,7 +216,7 @@ def __get_user_credentials(self):

def create_pgpass_file(self):
"""
Create a .pgpass file to store WRDS connection credentials..
Create a .pgpass file to store WRDS connection credentials.
Use the existing username and password if already connected to WRDS,
or prompt for that information if not.
Expand Down
138 changes: 79 additions & 59 deletions wrds/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import wrds
import unittest

try:
import unittest.mock as mock
except ImportError:
Expand All @@ -10,134 +11,147 @@


class TestInitMethod(unittest.TestCase):
""" Test the wrds.Connection.__init__() method,
with both default and custom parameters.
"""Test the wrds.Connection.__init__() method,
with both default and custom parameters.
"""
@mock.patch('wrds.sql.sa')

@mock.patch("wrds.sql.sa")
def test_init_calls_sqlalchemy_create_engine_defaults(self, mock_sa):
wrds.Connection()
connstring = 'postgresql://{host}:{port}/{dbname}'
connstring = "postgresql://{host}:{port}/{dbname}"
connstring = connstring.format(
host=wrds.sql.WRDS_POSTGRES_HOST,
port=wrds.sql.WRDS_POSTGRES_PORT,
dbname=wrds.sql.WRDS_POSTGRES_DB)
dbname=wrds.sql.WRDS_POSTGRES_DB,
)
mock_sa.create_engine.assert_called_with(
connstring,
connect_args={'sslmode': 'require',
'application_name': wrds.sql.appname},
isolation_level='AUTOCOMMIT')
connect_args={"sslmode": "require", "application_name": wrds.sql.appname},
isolation_level="AUTOCOMMIT",
)

@mock.patch('wrds.sql.sa')
@mock.patch("wrds.sql.sa")
def test_init_calls_sqlalchemy_create_engine_custom(self, mock_sa):
username = 'faketestusername'
connstring = 'postgresql://{usr}@{host}:{port}/{dbname}'
username = "faketestusername"
connstring = "postgresql://{usr}@{host}:{port}/{dbname}"
connstring = connstring.format(
usr=username,
host=wrds.sql.WRDS_POSTGRES_HOST,
port=wrds.sql.WRDS_POSTGRES_PORT,
dbname=wrds.sql.WRDS_POSTGRES_DB)
dbname=wrds.sql.WRDS_POSTGRES_DB,
)
wrds.Connection(wrds_username=username)
mock_sa.create_engine.assert_called_with(
connstring,
connect_args={'sslmode': 'require',
'application_name': wrds.sql.appname},
isolation_level='AUTOCOMMIT')
connect_args={"sslmode": "require", "application_name": wrds.sql.appname},
isolation_level="AUTOCOMMIT",
)

@mock.patch('wrds.sql.Connection.load_library_list')
@mock.patch('wrds.sql.Connection.connect')
@mock.patch("wrds.sql.Connection.load_library_list")
@mock.patch("wrds.sql.Connection.connect")
def test_init_default_connect(self, mock_connect, mock_lll):
wrds.Connection()
mock_connect.assert_called_once()

@mock.patch('wrds.sql.Connection.connect')
@mock.patch("wrds.sql.Connection.connect")
def test_init_autoconnect_false_no_connect(self, mock_connect):
wrds.Connection(autoconnect=False)
mock_connect.assert_not_called()

@mock.patch('wrds.sql.Connection.connect')
@mock.patch('wrds.sql.Connection.load_library_list')
@mock.patch("wrds.sql.Connection.connect")
@mock.patch("wrds.sql.Connection.load_library_list")
def test_init_default_load_library_list(self, mock_lll, mock_connect):
wrds.Connection()
mock_lll.assert_called_once()

@mock.patch('wrds.sql.Connection.connect')
@mock.patch('wrds.sql.Connection.load_library_list')
def test_init_autoconnect_false_no_connect_second_function(self, mock_lll, mock_connect):
@mock.patch("wrds.sql.Connection.connect")
@mock.patch("wrds.sql.Connection.load_library_list")
def test_init_autoconnect_false_no_connect_second_function(
self, mock_lll, mock_connect
):
wrds.Connection(autoconnect=False)
mock_lll.assert_not_called()


class TestConnectMethod(unittest.TestCase):
""" Test the wrds.Connection.connect method.
"""
Test the wrds.Connection.connect method.
Since all exceptions are caught immediately,
I'm just not smart enough to simulate bad passwords with
the code as written.
Since all exceptions are caught immediately,
I'm just not smart enough to simulate bad passwords with
the code as written.
"""

def setUp(self):
self.t = wrds.Connection(autoconnect=False)
self.t._hostname = 'wrds.test.private'
self.t._hostname = "wrds.test.private"
self.t._port = 12345
self.t._username = 'faketestusername'
self.t._password = 'faketestuserpass'
self.t._dbname = 'testdbname'
self.t._username = "faketestusername"
self.t._password = "faketestuserpass"
self.t._dbname = "testdbname"
self.t._Connection__get_user_credentials = mock.Mock()
self.t._Connection__get_user_credentials.return_value = (self.t._username, self.t._password)
self.t._Connection__get_user_credentials.return_value = (
self.t._username,
self.t._password,
)

def test_connect_calls_sqlalchemy_engine_connect(self):
self.t.engine = mock.Mock()
self.t.connect()
self.t.engine.connect.assert_called_once()

@mock.patch('wrds.sql.sa')
@mock.patch("wrds.sql.sa")
def test_connect_calls_get_user_credentials_on_exception(self, mock_sa):
self.t.engine = mock.Mock()
self.t.engine.connect.side_effect = Exception('Fake exception for testing')
self.t.engine.connect.side_effect = Exception("Fake exception for testing")
self.t.connect()
self.t._Connection__get_user_credentials.assert_called_once()

@mock.patch('wrds.sql.sa')
@mock.patch("wrds.sql.sa")
def test_connect_calls_sqlalchemy_create_engine_on_exception(self, mock_sa):
self.t.engine = mock.Mock()
self.t.engine.connect.side_effect = Exception('Fake exception for testing')
connstring = 'postgresql://{usr}:{pwd}@{host}:{port}/{dbname}'
self.t.engine.connect.side_effect = Exception("Fake exception for testing")
connstring = "postgresql://{usr}:{pwd}@{host}:{port}/{dbname}"
connstring = connstring.format(
usr=self.t._username,
pwd=self.t._password,
host=self.t._hostname,
port=self.t._port,
dbname=self.t._dbname)
dbname=self.t._dbname,
)
self.t.connect()
mock_sa.create_engine.assert_called_with(
connstring,
connect_args={'sslmode': 'require',
'application_name': wrds.sql.appname},
isolation_level='AUTOCOMMIT')
connect_args={"sslmode": "require", "application_name": wrds.sql.appname},
isolation_level="AUTOCOMMIT",
)


class TestRawSqlMethod(unittest.TestCase):
""" Test the wrds.Connection.raw_sql method.
"""Test the wrds.Connection.raw_sql method.
wrds.Connection.raw_sql() should be able to take
'normal' and parameterized SQL,
and throw an error if not all parameters are supplied.
wrds.Connection.raw_sql() should be able to take
'normal' and parameterized SQL,
and throw an error if not all parameters are supplied.
"""

def setUp(self):
self.t = wrds.Connection(autoconnect=False)
self.t._hostname = 'wrds.test.private'
self.t._hostname = "wrds.test.private"
self.t._port = 12345
self.t._username = 'faketestusername'
self.t._password = 'faketestuserpass'
self.t._dbname = 'testdbname'
self.t._username = "faketestusername"
self.t._password = "faketestuserpass"
self.t._dbname = "testdbname"
self.t._Connection__get_user_credentials = mock.Mock()
self.t._Connection__get_user_credentials.return_value = (self.t._username, self.t._password)
self.t._Connection__get_user_credentials.return_value = (
self.t._username,
self.t._password,
)
self.t.connection = mock.Mock()
self.t.engine = mock.Mock()

@mock.patch('wrds.sql.sa')
@mock.patch('wrds.sql.pd')
@mock.patch("wrds.sql.sa")
@mock.patch("wrds.sql.pd")
def test_rawsql_takes_unparameterized_sql(self, mock_pd, mock_sa):
sql = "SELECT * FROM information_schema.tables LIMIT 1"
self.t.raw_sql(sql)
Expand All @@ -150,10 +164,13 @@ def test_rawsql_takes_unparameterized_sql(self, mock_pd, mock_sa):
params=None,
)

@mock.patch('wrds.sql.sa')
@mock.patch('wrds.sql.pd')
@mock.patch("wrds.sql.sa")
@mock.patch("wrds.sql.pd")
def test_rawsql_takes_parameterized_sql(self, mock_pd, mock_sa):
sql = "SELECT * FROM information_schema.tables where table_name = %(tablename)s LIMIT 1"
sql = (
"SELECT * FROM information_schema.tables "
"WHERE table_name = %(tablename)s LIMIT 1"
)
tablename = "pg_stat_activity"
self.t.engine = mock.Mock()
self.t.raw_sql(sql, params=tablename)
Expand All @@ -171,7 +188,10 @@ class TestCreatePgpassFile(unittest.TestCase):
def setUp(self):
self.t = wrds.Connection(autoconnect=False)
self.t._Connection__get_user_credentials = mock.Mock()
self.t._Connection__get_user_credentials.return_value = ('faketestusername', 'faketestpassword')
self.t._Connection__get_user_credentials.return_value = (
"faketestusername",
"faketestpassword",
)
self.t._Connection__create_pgpass_file_win32 = mock.Mock()
self.t._Connection__create_pgpass_file_unix = mock.Mock()

Expand All @@ -185,16 +205,16 @@ def test_create_pgpass_calls_get_user_credentials_if_not_password(self):
self.t.create_pgpass_file()
self.t._Connection__get_user_credentials.assert_called_once()

@unittest.skipIf(sys.platform != 'win32', 'Windows-only test')
@unittest.skipIf(sys.platform != "win32", "Windows-only test")
def test_create_pgpass_calls_win32_version_if_windows(self):
self.t.create_pgpass_file()
self.t._Connection__create_pgpass_file_win32.assert_called_once()

@unittest.skipIf(sys.platform == 'win32', 'Unix-only test')
@unittest.skipIf(sys.platform == "win32", "Unix-only test")
def test_create_pgpass_calls_unix_version_if_unix(self):
self.t.create_pgpass_file()
self.t._Connection__create_pgpass_file_unix.assert_called_once()


if (__name__ == '__main__'):
if __name__ == "__main__":
unittest.main()

0 comments on commit 58ed7f2

Please sign in to comment.