Skip to content

Commit

Permalink
Fix #97 -- Properly parse domain names in markdown links (#98)
Browse files Browse the repository at this point in the history
The problem is described in #97.

I had to introduce a library -
[tldextract](https://github.com/john-kurkowski/tldextract/) to properly
parse all varieties of domain names.

---------

Co-authored-by: Johannes Maron <johannes@maron.family>
  • Loading branch information
amureki and codingjoe authored Aug 26, 2024
1 parent 7aa754b commit 44231ac
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 6 deletions.
3 changes: 2 additions & 1 deletion emark/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def update_url_params(self, url, **params):
return redirect_url
site_url = self.get_site_url()
# external links should not be tracked
top_level_domain = ".".join(site_url.split(".")[-2:])
top_level_domain = utils.extract_domain(site_url)

if not redirect_url_parts.netloc.endswith(top_level_domain):
return redirect_url
tracking_url = reverse("emark:email-click", kwargs={"pk": self.uuid})
Expand Down
26 changes: 26 additions & 0 deletions emark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

__all__ = ["HTML2TextParser"]

from urllib import parse

import tldextract


@dataclasses.dataclass
class Node:
Expand Down Expand Up @@ -122,3 +126,25 @@ def __str__(self) -> str:
# sanitize all wide vertical or horizontal spaces
text = self.DOUBLE_NEWLINE.sub("\n\n", text.strip())
return self.DOUBLE_SPACE.sub(" ", text)


def extract_domain(url: str) -> str:
"""Extracts the registered domain from a given URL.
If the domain is "localhost", it includes the port number in the returned string.
Args:
url (str): The URL from which to extract the domain.
Returns:
str: The registered domain or "localhost" with port if applicable.
"""
extractor = tldextract.TLDExtract(suffix_list_urls=())
extracted = extractor(url)
if extracted.domain == "localhost":
registered_domain = "localhost"
else:
registered_domain = extracted.registered_domain
if port := parse.urlparse(url).port:
return f"{registered_domain}:{port}"
return registered_domain
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [
"Framework :: Django :: 5.0",
]
requires-python = ">=3.10"
dependencies = ["django", "markdown", "premailer"]
dependencies = ["django", "markdown", "premailer", "tldextract"]

[project.optional-dependencies]
test = [
Expand Down
30 changes: 26 additions & 4 deletions tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,13 @@ def test_open_in_browser__html(self, email_message):
in email_message.html
)

def test_get_site_domain__setting(self, email_message):
assert email_message.get_site_url() == "http://www.example.com"
@pytest.mark.parametrize("domain", ["www.example.com", "example.com"])
def test_get_site_url__setting(self, email_message, settings, domain):
settings.EMARK = {"DOMAIN": domain}
assert email_message.get_site_url() == f"http://{domain}"

@pytest.mark.django_db
def test_test_get_site_domain__sites_framework(self, email_message, settings):
def test_get_site_url__sites_framework(self, email_message, settings):
settings.EMARK = {"DOMAIN": None}
settings.SITE_ID = 1
assert email_message.get_site_url() == "http://example.com"
Expand Down Expand Up @@ -330,7 +332,27 @@ def test_update_url_params__tracking_uuid(self, email_message):
"click?url=https%3A%2F%2Fwww.example.com%2F%3Futm_medium%3Dbaz%26utm_source%3Dfoo"
)

def test_update_url_params__subdomain(self, email_message):
@pytest.mark.parametrize(
"domain",
["www.example.com", "example.com", "test.example.com", "localhost:8000"],
)
def test_update_url_params__domains(self, settings, email_message, domain):
settings.EMARK["DOMAIN"] = domain
email_message.uuid = "12341234-1234-1234-1234-123412341234"
encoded_domain = domain.replace(":", "%3A")
expected_url = (
f"http://{domain}/emark/12341234-1234-1234-1234-123412341234/"
f"click?url=https%3A%2F%2F{encoded_domain}%2F%3Futm_medium%3Dbaz%26utm_source%3Dfoo"
)
assert (
email_message.update_url_params(
f"https://{domain}/?utm_source=foo", utm_medium="baz"
)
== expected_url
)

def test_update_url_params__subdomain(self, settings, email_message):
settings.EMARK["DOMAIN"] = "www.example.com"
email_message.uuid = "12341234-1234-1234-1234-123412341234"
assert (
email_message.update_url_params(
Expand Down
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,12 @@ def test_parse_html_email(self):
"--------------------------------------------------\n"
"some footer"
)


def test_extract_domain():
assert utils.extract_domain("https://example.com") == "example.com"
assert utils.extract_domain("https://www.example.com") == "example.com"
assert utils.extract_domain("https://www.example.co.uk") == "example.co.uk"
assert utils.extract_domain("https://www.example.com:1337") == "example.com:1337"
assert utils.extract_domain("https://localhost") == "localhost"
assert utils.extract_domain("https://localhost:8000") == "localhost:8000"

0 comments on commit 44231ac

Please sign in to comment.