Skip to content

Commit

Permalink
Merge pull request #5222 from RasaHQ/fix-urlsplit-python3.7.6
Browse files Browse the repository at this point in the history
Fix urlsplit python3.7.6
  • Loading branch information
wochinge authored Feb 12, 2020
2 parents 22f0a99 + 28ec6c5 commit 79a5ffd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
13 changes: 9 additions & 4 deletions rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

logger = logging.getLogger(__name__)

SQLITE_SCHEME = "sqlite"


class TrackerStore:
"""Class to hold all of the TrackerStore classes"""
Expand Down Expand Up @@ -639,17 +641,20 @@ def get_db_url(
URL ready to be used with an SQLAlchemy `Engine` object.
"""
from urllib.parse import urlsplit
from urllib import parse
from sqlalchemy.engine.url import URL

# Users might specify a url in the host
parsed = urlsplit(host or "")
if parsed.scheme:
parsed = parse.urlsplit(host or "")
# We have to check `scheme` and `hostname` because Python 3.7.6 parses strings
# like `localhost:1234` as a URL with scheme `localhost`. However, `sqlite:///`
# a special case because it doesn't require a hostname.
if parsed.scheme and (parsed.hostname or parsed.scheme == SQLITE_SCHEME):
return host

if host:
# add fake scheme to properly parse components
parsed = urlsplit("schema://" + host)
parsed = parse.urlsplit(f"scheme://{host}")

# users might include the port in the url
port = parsed.port or port
Expand Down
6 changes: 6 additions & 0 deletions tests/core/test_tracker_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def pickle_serialise_tracker(_tracker):
"postgresql://localhost",
"postgresql://localhost:5432",
"postgresql://user:secret@localhost",
"sqlite:///",
],
)
def test_get_db_url_with_fully_specified_url(full_url: Text):
Expand All @@ -279,6 +280,11 @@ def test_get_db_url_with_port_in_host():
)


def test_db_get_url_with_sqlite():
expected = "sqlite:///rasa.db"
assert str(SQLTrackerStore.get_db_url(dialect="sqlite", db="rasa.db")) == expected


def test_get_db_url_with_correct_host():
expected = "postgresql://localhost:5005/mydb"

Expand Down

0 comments on commit 79a5ffd

Please sign in to comment.