diff --git a/tests/unit/test_find_urls.py b/tests/unit/test_find_urls.py index 9ee46eb..952d761 100644 --- a/tests/unit/test_find_urls.py +++ b/tests/unit/test_find_urls.py @@ -153,3 +153,46 @@ def test_find_urls_schema_only(urlextract, text, expected): :param list(str) expected: list of URLs that has to be found in text """ assert urlextract.find_urls(text, with_schema_only=True) == expected + + +@pytest.mark.parametrize( + "text, expected", + [ + ("multiple protocols, job:https://example.co", ["https://example.co"]), + ( + "more multiple protocols, link:job:https://example.com/r", + ["https://example.com/r"], + ), + ("svn+ssh://example.com", ["svn+ssh://example.com"]), + ], +) +def test_find_urls_multiple_protocol(urlextract, text, expected): + """ + Testing find_urls returning all URLs + + :param fixture urlextract: fixture holding URLExtract object + :param str text: text in which we should find links + :param list(str) expected: list of URLs that has to be found in text + """ + assert urlextract.find_urls(text) == expected + + +@pytest.mark.parametrize( + "text, expected", + [ + ("svn+ssh://example.com", ["ssh://example.com"]), + ("multiple protocols, job:https://example.co", ["https://example.co"]), + ("test link:job:https://example.com/r", ["https://example.com/r"]), + ], +) +def test_find_urls_multiple_protocol_custom(urlextract, text, expected): + """ + Testing find_urls returning all URLs + + :param fixture urlextract: fixture holding URLExtract object + :param str text: text in which we should find links + :param list(str) expected: list of URLs that has to be found in text + """ + stop_chars = urlextract.get_stop_chars_left_from_scheme() | {"+"} + urlextract.set_stop_chars_left_from_scheme(stop_chars) + assert urlextract.find_urls(text) == expected diff --git a/urlextract/urlextract_core.py b/urlextract/urlextract_core.py index 8283024..24bceb0 100644 --- a/urlextract/urlextract_core.py +++ b/urlextract/urlextract_core.py @@ -111,6 +111,9 @@ def __init__( self._stop_chars_left = set(string.whitespace) self._stop_chars_left |= general_stop_chars | {"|", "=", "]", ")", "}"} + # default stop characters on left side from schema + self._stop_chars_left_from_schema = self._stop_chars_left.copy() | {":"} + # defining default stop chars left self._stop_chars_right = set(string.whitespace) self._stop_chars_right |= general_stop_chars @@ -334,6 +337,31 @@ def set_stop_chars_left(self, stop_chars: Set[str]): self._stop_chars_left = stop_chars + def get_stop_chars_left_from_scheme(self) -> Set[str]: + """ + Returns set of stop chars for text on left from TLD. + + :return: set of stop chars + :rtype: set + """ + return self._stop_chars_left_from_schema + + def set_stop_chars_left_from_scheme(self, stop_chars: Set[str]): + """ + Set stop characters for text on left from scheme. + Stop characters are used when determining end of URL. + + :param set stop_chars: set of characters + :raises: TypeError + """ + if not isinstance(stop_chars, set): + raise TypeError( + "stop_chars should be type set " + "but {} was given".format(type(stop_chars)) + ) + + self._stop_chars_left_from_schema = stop_chars + def get_stop_chars_right(self) -> Set[str]: """ Returns set of stop chars for text on right from TLD. @@ -420,12 +448,18 @@ def _complete_url( max_len = len(text) - 1 end_pos = tld_pos start_pos = tld_pos + in_scheme = False while left_ok or right_ok: if left_ok: if start_pos <= 0: left_ok = False else: - if text[start_pos - 1] not in self._stop_chars_left: + if ( + in_scheme + and text[start_pos - 1] in self._stop_chars_left_from_schema + ): + left_ok = False + if left_ok and text[start_pos - 1] not in self._stop_chars_left: start_pos -= 1 else: left_ok = False @@ -438,6 +472,9 @@ def _complete_url( else: right_ok = False + if text[start_pos : start_pos + 3] == "://": + in_scheme = True + complete_url = text[start_pos : end_pos + 1].lstrip("/") # remove last character from url # when it is allowed character right after TLD (e.g. dot, comma)