diff --git a/wrds/sql.py b/wrds/sql.py index 6d51b55..7cd2e53 100644 --- a/wrds/sql.py +++ b/wrds/sql.py @@ -30,7 +30,7 @@ class SchemaNotFoundError(FileNotFoundError): class Connection(object): - def __init__(self, autoconnect=True, **kwargs): + def __init__(self, autoconnect=True, verbose=False, **kwargs): """ Set up the connection to the WRDS database. By default, also establish the connection to the database. @@ -40,13 +40,16 @@ def __init__(self, autoconnect=True, **kwargs): *wrds_port*: database connection port number *wrds_dbname*: WRDS database name *wrds_username*: WRDS username + *autoconnect*: If false will not immediately establish the connection - The constructor will use the .pgpass file if it exists. + The constructor will use the .pgpass file if it exists and may make use of + PostgreSQL environment variables such as PGHOST, PGUSER, etc., if cooresponding + parameters are not set. If not, it will ask the user for a username and password. It will also direct the user to information on setting up .pgpass. Additionally, creating the instance will load a list of schemas - the user has permission to access. + the user has permission to access. :return: None @@ -55,84 +58,81 @@ def __init__(self, autoconnect=True, **kwargs): Loading library list... Done """ + self._verbose = verbose self._password = "" # If user passed in any of these parameters, override defaults. - self._username = kwargs.get("wrds_username", None) - self._hostname = kwargs.get("wrds_hostname", WRDS_POSTGRES_HOST) + self._username = kwargs.get("wrds_username", "") + # PGHOST if set will override default for first attempt + self._hostname = kwargs.get( + "wrds_hostname", os.environ.get('PGHOST', WRDS_POSTGRES_HOST) + ) self._port = kwargs.get("wrds_port", WRDS_POSTGRES_PORT) self._dbname = kwargs.get("wrds_dbname", WRDS_POSTGRES_DB) self._connect_args = kwargs.get("wrds_connect_args", WRDS_CONNECT_ARGS) - # If username was passed in, the URI is different. - if self._username: - pguri = "postgresql://{usr}@{host}:{port}/{dbname}" - self.engine = sa.create_engine( - pguri.format( - usr=self._username, - host=self._hostname, - port=self._port, - dbname=self._dbname, - ), - isolation_level="AUTOCOMMIT", - connect_args=self._connect_args, - ) - # No username passed in, but other parameters might have been. - else: - pguri = "postgresql://{host}:{port}/{dbname}" - self.engine = sa.create_engine( - pguri.format(host=self._hostname, port=self._port, dbname=self._dbname), - isolation_level="AUTOCOMMIT", - connect_args=self._connect_args, - ) if autoconnect: self.connect() self.load_library_list() - def connect(self): - """Make a connection to the WRDS database.""" + def __make_sa_engine_conn(self, raise_err=False): + username = self._username + hostname = self._hostname + password = urllib.parse.quote_plus(self._password) + port = self._port + dbname = self._dbname + pguri = f"postgresql://{username}:{password}@{hostname}:{port}/{dbname}" + if self._verbose: + print(f"postgresql://{username}:@{hostname}:{port}/{dbname}") try: - self.connection = self.engine.connect() - except Exception: - # These things should probably not be exported all over creation - self._username, self._password = self.__get_user_credentials() - pghost = "postgresql://{usr}:{pwd}@{host}:{port}/{dbname}" self.engine = sa.create_engine( - pghost.format( - usr=self._username, - pwd=urllib.parse.quote_plus(self._password), - host=self._hostname, - port=self._port, - dbname=self._dbname, - ), + pguri, isolation_level="AUTOCOMMIT", connect_args=self._connect_args, ) - try: - self.connection = self.engine.connect() - except Exception as e: - print("There was an error with your password.") - self._username = None - self._password = None - raise e - - # Connection successful. Offer to create a .pgpass for the user. - print("WRDS recommends setting up a .pgpass file.") - do_create_pgpass = "" - while do_create_pgpass != "y" and do_create_pgpass != "n": - do_create_pgpass = input("Create .pgpass file now [y/n]?: ") - - if do_create_pgpass == "y": - try: - self.create_pgpass_file() - print("Created .pgpass file successfully.") - except: - print( - "Failed to create .pgpass file. Please try manually with the " - "create_pgpass_file() function." - ) + self.connection = self.engine.connect() + except Exception as err: + if self._verbose: + print(f"{err=}") + self.engine = None + if raise_err: + raise err + + def connect(self): + """Make a connection to the WRDS database.""" + # first try connection using system defaults and params set in constructor + self.__make_sa_engine_conn() + + if (self.engine is None and self._hostname != WRDS_POSTGRES_HOST): + # try explicit w/ default hostname + print(f"Trying '{WRDS_POSTGRES_HOST}'...") + self._hostname = WRDS_POSTGRES_HOST + self.__make_sa_engine_conn() + + if (self.engine is None): + # Use explicit username and password + self._username, self._password = self.__get_user_credentials() + # Last attempt, raise error if Exception encountered + self.__make_sa_engine_conn(raise_err=True) + + if (self.engine is None): + print(f"Failed to connect {self._username}@{self._hostname}") else: - print("You can create this file yourself at any time") - print("with the create_pgpass_file() function.") + # Connection successful. Offer to create a .pgpass for the user. + print("WRDS recommends setting up a .pgpass file.") + do_create_pgpass = "" + while do_create_pgpass != "y" and do_create_pgpass != "n": + do_create_pgpass = input("Create .pgpass file now [y/n]?: ") + + if do_create_pgpass == "y": + try: + self.create_pgpass_file() + print("Created .pgpass file successfully.") + except Exception: + print("Failed to create .pgpass file.") + print( + "You can create this file yourself at any time " + "with the create_pgpass_file() function." + ) def close(self): """ @@ -140,6 +140,7 @@ def close(self): """ self.connection.close() self.engine.dispose() + self.engine = None def __enter__(self): self.connect() @@ -215,7 +216,7 @@ def __get_user_credentials(self): def create_pgpass_file(self): """ - Create a .pgpass file to store WRDS connection credentials.. + Create a .pgpass file to store WRDS connection credentials. Use the existing username and password if already connected to WRDS, or prompt for that information if not. diff --git a/wrds/test.py b/wrds/test.py index c0fe0bc..9af353e 100644 --- a/wrds/test.py +++ b/wrds/test.py @@ -2,6 +2,7 @@ import wrds import unittest + try: import unittest.mock as mock except ImportError: @@ -10,134 +11,147 @@ class TestInitMethod(unittest.TestCase): - """ Test the wrds.Connection.__init__() method, - with both default and custom parameters. + """Test the wrds.Connection.__init__() method, + with both default and custom parameters. """ - @mock.patch('wrds.sql.sa') + + @mock.patch("wrds.sql.sa") def test_init_calls_sqlalchemy_create_engine_defaults(self, mock_sa): wrds.Connection() - connstring = 'postgresql://{host}:{port}/{dbname}' + connstring = "postgresql://{host}:{port}/{dbname}" connstring = connstring.format( host=wrds.sql.WRDS_POSTGRES_HOST, port=wrds.sql.WRDS_POSTGRES_PORT, - dbname=wrds.sql.WRDS_POSTGRES_DB) + dbname=wrds.sql.WRDS_POSTGRES_DB, + ) mock_sa.create_engine.assert_called_with( connstring, - connect_args={'sslmode': 'require', - 'application_name': wrds.sql.appname}, - isolation_level='AUTOCOMMIT') + connect_args={"sslmode": "require", "application_name": wrds.sql.appname}, + isolation_level="AUTOCOMMIT", + ) - @mock.patch('wrds.sql.sa') + @mock.patch("wrds.sql.sa") def test_init_calls_sqlalchemy_create_engine_custom(self, mock_sa): - username = 'faketestusername' - connstring = 'postgresql://{usr}@{host}:{port}/{dbname}' + username = "faketestusername" + connstring = "postgresql://{usr}@{host}:{port}/{dbname}" connstring = connstring.format( usr=username, host=wrds.sql.WRDS_POSTGRES_HOST, port=wrds.sql.WRDS_POSTGRES_PORT, - dbname=wrds.sql.WRDS_POSTGRES_DB) + dbname=wrds.sql.WRDS_POSTGRES_DB, + ) wrds.Connection(wrds_username=username) mock_sa.create_engine.assert_called_with( connstring, - connect_args={'sslmode': 'require', - 'application_name': wrds.sql.appname}, - isolation_level='AUTOCOMMIT') + connect_args={"sslmode": "require", "application_name": wrds.sql.appname}, + isolation_level="AUTOCOMMIT", + ) - @mock.patch('wrds.sql.Connection.load_library_list') - @mock.patch('wrds.sql.Connection.connect') + @mock.patch("wrds.sql.Connection.load_library_list") + @mock.patch("wrds.sql.Connection.connect") def test_init_default_connect(self, mock_connect, mock_lll): wrds.Connection() mock_connect.assert_called_once() - @mock.patch('wrds.sql.Connection.connect') + @mock.patch("wrds.sql.Connection.connect") def test_init_autoconnect_false_no_connect(self, mock_connect): wrds.Connection(autoconnect=False) mock_connect.assert_not_called() - @mock.patch('wrds.sql.Connection.connect') - @mock.patch('wrds.sql.Connection.load_library_list') + @mock.patch("wrds.sql.Connection.connect") + @mock.patch("wrds.sql.Connection.load_library_list") def test_init_default_load_library_list(self, mock_lll, mock_connect): wrds.Connection() mock_lll.assert_called_once() - @mock.patch('wrds.sql.Connection.connect') - @mock.patch('wrds.sql.Connection.load_library_list') - def test_init_autoconnect_false_no_connect_second_function(self, mock_lll, mock_connect): + @mock.patch("wrds.sql.Connection.connect") + @mock.patch("wrds.sql.Connection.load_library_list") + def test_init_autoconnect_false_no_connect_second_function( + self, mock_lll, mock_connect + ): wrds.Connection(autoconnect=False) mock_lll.assert_not_called() class TestConnectMethod(unittest.TestCase): - """ Test the wrds.Connection.connect method. + """ + Test the wrds.Connection.connect method. - Since all exceptions are caught immediately, - I'm just not smart enough to simulate bad passwords with - the code as written. + Since all exceptions are caught immediately, + I'm just not smart enough to simulate bad passwords with + the code as written. """ def setUp(self): self.t = wrds.Connection(autoconnect=False) - self.t._hostname = 'wrds.test.private' + self.t._hostname = "wrds.test.private" self.t._port = 12345 - self.t._username = 'faketestusername' - self.t._password = 'faketestuserpass' - self.t._dbname = 'testdbname' + self.t._username = "faketestusername" + self.t._password = "faketestuserpass" + self.t._dbname = "testdbname" self.t._Connection__get_user_credentials = mock.Mock() - self.t._Connection__get_user_credentials.return_value = (self.t._username, self.t._password) + self.t._Connection__get_user_credentials.return_value = ( + self.t._username, + self.t._password, + ) def test_connect_calls_sqlalchemy_engine_connect(self): self.t.engine = mock.Mock() self.t.connect() self.t.engine.connect.assert_called_once() - @mock.patch('wrds.sql.sa') + @mock.patch("wrds.sql.sa") def test_connect_calls_get_user_credentials_on_exception(self, mock_sa): self.t.engine = mock.Mock() - self.t.engine.connect.side_effect = Exception('Fake exception for testing') + self.t.engine.connect.side_effect = Exception("Fake exception for testing") self.t.connect() self.t._Connection__get_user_credentials.assert_called_once() - @mock.patch('wrds.sql.sa') + @mock.patch("wrds.sql.sa") def test_connect_calls_sqlalchemy_create_engine_on_exception(self, mock_sa): self.t.engine = mock.Mock() - self.t.engine.connect.side_effect = Exception('Fake exception for testing') - connstring = 'postgresql://{usr}:{pwd}@{host}:{port}/{dbname}' + self.t.engine.connect.side_effect = Exception("Fake exception for testing") + connstring = "postgresql://{usr}:{pwd}@{host}:{port}/{dbname}" connstring = connstring.format( usr=self.t._username, pwd=self.t._password, host=self.t._hostname, port=self.t._port, - dbname=self.t._dbname) + dbname=self.t._dbname, + ) self.t.connect() mock_sa.create_engine.assert_called_with( connstring, - connect_args={'sslmode': 'require', - 'application_name': wrds.sql.appname}, - isolation_level='AUTOCOMMIT') + connect_args={"sslmode": "require", "application_name": wrds.sql.appname}, + isolation_level="AUTOCOMMIT", + ) class TestRawSqlMethod(unittest.TestCase): - """ Test the wrds.Connection.raw_sql method. + """Test the wrds.Connection.raw_sql method. - wrds.Connection.raw_sql() should be able to take - 'normal' and parameterized SQL, - and throw an error if not all parameters are supplied. + wrds.Connection.raw_sql() should be able to take + 'normal' and parameterized SQL, + and throw an error if not all parameters are supplied. """ def setUp(self): self.t = wrds.Connection(autoconnect=False) - self.t._hostname = 'wrds.test.private' + self.t._hostname = "wrds.test.private" self.t._port = 12345 - self.t._username = 'faketestusername' - self.t._password = 'faketestuserpass' - self.t._dbname = 'testdbname' + self.t._username = "faketestusername" + self.t._password = "faketestuserpass" + self.t._dbname = "testdbname" self.t._Connection__get_user_credentials = mock.Mock() - self.t._Connection__get_user_credentials.return_value = (self.t._username, self.t._password) + self.t._Connection__get_user_credentials.return_value = ( + self.t._username, + self.t._password, + ) self.t.connection = mock.Mock() self.t.engine = mock.Mock() - @mock.patch('wrds.sql.sa') - @mock.patch('wrds.sql.pd') + @mock.patch("wrds.sql.sa") + @mock.patch("wrds.sql.pd") def test_rawsql_takes_unparameterized_sql(self, mock_pd, mock_sa): sql = "SELECT * FROM information_schema.tables LIMIT 1" self.t.raw_sql(sql) @@ -150,10 +164,13 @@ def test_rawsql_takes_unparameterized_sql(self, mock_pd, mock_sa): params=None, ) - @mock.patch('wrds.sql.sa') - @mock.patch('wrds.sql.pd') + @mock.patch("wrds.sql.sa") + @mock.patch("wrds.sql.pd") def test_rawsql_takes_parameterized_sql(self, mock_pd, mock_sa): - sql = "SELECT * FROM information_schema.tables where table_name = %(tablename)s LIMIT 1" + sql = ( + "SELECT * FROM information_schema.tables " + "WHERE table_name = %(tablename)s LIMIT 1" + ) tablename = "pg_stat_activity" self.t.engine = mock.Mock() self.t.raw_sql(sql, params=tablename) @@ -171,7 +188,10 @@ class TestCreatePgpassFile(unittest.TestCase): def setUp(self): self.t = wrds.Connection(autoconnect=False) self.t._Connection__get_user_credentials = mock.Mock() - self.t._Connection__get_user_credentials.return_value = ('faketestusername', 'faketestpassword') + self.t._Connection__get_user_credentials.return_value = ( + "faketestusername", + "faketestpassword", + ) self.t._Connection__create_pgpass_file_win32 = mock.Mock() self.t._Connection__create_pgpass_file_unix = mock.Mock() @@ -185,16 +205,16 @@ def test_create_pgpass_calls_get_user_credentials_if_not_password(self): self.t.create_pgpass_file() self.t._Connection__get_user_credentials.assert_called_once() - @unittest.skipIf(sys.platform != 'win32', 'Windows-only test') + @unittest.skipIf(sys.platform != "win32", "Windows-only test") def test_create_pgpass_calls_win32_version_if_windows(self): self.t.create_pgpass_file() self.t._Connection__create_pgpass_file_win32.assert_called_once() - @unittest.skipIf(sys.platform == 'win32', 'Unix-only test') + @unittest.skipIf(sys.platform == "win32", "Unix-only test") def test_create_pgpass_calls_unix_version_if_unix(self): self.t.create_pgpass_file() self.t._Connection__create_pgpass_file_unix.assert_called_once() -if (__name__ == '__main__'): +if __name__ == "__main__": unittest.main()