Skip to content

Commit

Permalink
Merge pull request #39 from amalek215/connect_defaults
Browse files Browse the repository at this point in the history
Connect defaults, will try to use PGHOST if env var exists and host not otherwise set
  • Loading branch information
amalek215 authored Mar 2, 2023
2 parents 7878ba3 + 58ed7f2 commit d2c438b
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 125 deletions.
133 changes: 67 additions & 66 deletions wrds/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SchemaNotFoundError(FileNotFoundError):


class Connection(object):
def __init__(self, autoconnect=True, **kwargs):
def __init__(self, autoconnect=True, verbose=False, **kwargs):
"""
Set up the connection to the WRDS database.
By default, also establish the connection to the database.
Expand All @@ -40,13 +40,16 @@ def __init__(self, autoconnect=True, **kwargs):
*wrds_port*: database connection port number
*wrds_dbname*: WRDS database name
*wrds_username*: WRDS username
*autoconnect*: If false will not immediately establish the connection
The constructor will use the .pgpass file if it exists.
The constructor will use the .pgpass file if it exists and may make use of
PostgreSQL environment variables such as PGHOST, PGUSER, etc., if cooresponding
parameters are not set.
If not, it will ask the user for a username and password.
It will also direct the user to information on setting up .pgpass.
Additionally, creating the instance will load a list of schemas
the user has permission to access.
the user has permission to access.
:return: None
Expand All @@ -55,91 +58,89 @@ def __init__(self, autoconnect=True, **kwargs):
Loading library list...
Done
"""
self._verbose = verbose
self._password = ""
# If user passed in any of these parameters, override defaults.
self._username = kwargs.get("wrds_username", None)
self._hostname = kwargs.get("wrds_hostname", WRDS_POSTGRES_HOST)
self._username = kwargs.get("wrds_username", "")
# PGHOST if set will override default for first attempt
self._hostname = kwargs.get(
"wrds_hostname", os.environ.get('PGHOST', WRDS_POSTGRES_HOST)
)
self._port = kwargs.get("wrds_port", WRDS_POSTGRES_PORT)
self._dbname = kwargs.get("wrds_dbname", WRDS_POSTGRES_DB)
self._connect_args = kwargs.get("wrds_connect_args", WRDS_CONNECT_ARGS)

# If username was passed in, the URI is different.
if self._username:
pguri = "postgresql://{usr}@{host}:{port}/{dbname}"
self.engine = sa.create_engine(
pguri.format(
usr=self._username,
host=self._hostname,
port=self._port,
dbname=self._dbname,
),
isolation_level="AUTOCOMMIT",
connect_args=self._connect_args,
)
# No username passed in, but other parameters might have been.
else:
pguri = "postgresql://{host}:{port}/{dbname}"
self.engine = sa.create_engine(
pguri.format(host=self._hostname, port=self._port, dbname=self._dbname),
isolation_level="AUTOCOMMIT",
connect_args=self._connect_args,
)
if autoconnect:
self.connect()
self.load_library_list()

def connect(self):
"""Make a connection to the WRDS database."""
def __make_sa_engine_conn(self, raise_err=False):
username = self._username
hostname = self._hostname
password = urllib.parse.quote_plus(self._password)
port = self._port
dbname = self._dbname
pguri = f"postgresql://{username}:{password}@{hostname}:{port}/{dbname}"
if self._verbose:
print(f"postgresql://{username}:@{hostname}:{port}/{dbname}")
try:
self.connection = self.engine.connect()
except Exception:
# These things should probably not be exported all over creation
self._username, self._password = self.__get_user_credentials()
pghost = "postgresql://{usr}:{pwd}@{host}:{port}/{dbname}"
self.engine = sa.create_engine(
pghost.format(
usr=self._username,
pwd=urllib.parse.quote_plus(self._password),
host=self._hostname,
port=self._port,
dbname=self._dbname,
),
pguri,
isolation_level="AUTOCOMMIT",
connect_args=self._connect_args,
)
try:
self.connection = self.engine.connect()
except Exception as e:
print("There was an error with your password.")
self._username = None
self._password = None
raise e

# Connection successful. Offer to create a .pgpass for the user.
print("WRDS recommends setting up a .pgpass file.")
do_create_pgpass = ""
while do_create_pgpass != "y" and do_create_pgpass != "n":
do_create_pgpass = input("Create .pgpass file now [y/n]?: ")

if do_create_pgpass == "y":
try:
self.create_pgpass_file()
print("Created .pgpass file successfully.")
except:
print(
"Failed to create .pgpass file. Please try manually with the "
"create_pgpass_file() function."
)
self.connection = self.engine.connect()
except Exception as err:
if self._verbose:
print(f"{err=}")
self.engine = None
if raise_err:
raise err

def connect(self):
"""Make a connection to the WRDS database."""
# first try connection using system defaults and params set in constructor
self.__make_sa_engine_conn()

if (self.engine is None and self._hostname != WRDS_POSTGRES_HOST):
# try explicit w/ default hostname
print(f"Trying '{WRDS_POSTGRES_HOST}'...")
self._hostname = WRDS_POSTGRES_HOST
self.__make_sa_engine_conn()

if (self.engine is None):
# Use explicit username and password
self._username, self._password = self.__get_user_credentials()
# Last attempt, raise error if Exception encountered
self.__make_sa_engine_conn(raise_err=True)

if (self.engine is None):
print(f"Failed to connect {self._username}@{self._hostname}")
else:
print("You can create this file yourself at any time")
print("with the create_pgpass_file() function.")
# Connection successful. Offer to create a .pgpass for the user.
print("WRDS recommends setting up a .pgpass file.")
do_create_pgpass = ""
while do_create_pgpass != "y" and do_create_pgpass != "n":
do_create_pgpass = input("Create .pgpass file now [y/n]?: ")

if do_create_pgpass == "y":
try:
self.create_pgpass_file()
print("Created .pgpass file successfully.")
except Exception:
print("Failed to create .pgpass file.")
print(
"You can create this file yourself at any time "
"with the create_pgpass_file() function."
)

def close(self):
"""
Close the connection to the database.
"""
self.connection.close()
self.engine.dispose()
self.engine = None

def __enter__(self):
self.connect()
Expand Down Expand Up @@ -215,7 +216,7 @@ def __get_user_credentials(self):

def create_pgpass_file(self):
"""
Create a .pgpass file to store WRDS connection credentials..
Create a .pgpass file to store WRDS connection credentials.
Use the existing username and password if already connected to WRDS,
or prompt for that information if not.
Expand Down
Loading

0 comments on commit d2c438b

Please sign in to comment.