diff --git a/specfile/rpm.py b/specfile/rpm.py index 60482de..f0d5556 100644 --- a/specfile/rpm.py +++ b/specfile/rpm.py @@ -8,7 +8,6 @@ import re import sys import tempfile -import urllib.parse from enum import IntEnum from pathlib import Path from typing import Iterator, List, Optional, Tuple @@ -16,6 +15,7 @@ import rpm from specfile.exceptions import MacroRemovalException, RPMException +from specfile.utils import get_filename_from_location MAX_REMOVAL_RETRIES = 20 @@ -267,7 +267,7 @@ def make_dummy_sources(sources: List[str], sourcedir: Path) -> Iterator[List[Pat MAGIC_LENGTH = 13 dummy_sources = [] for source in sources: - filename = Path(urllib.parse.urlsplit(source).path).name + filename = get_filename_from_location(source) if not filename: continue path = sourcedir / filename diff --git a/specfile/sources.py b/specfile/sources.py index 445cb03..b861a65 100644 --- a/specfile/sources.py +++ b/specfile/sources.py @@ -5,13 +5,13 @@ import re import urllib.parse from abc import ABC, abstractmethod -from pathlib import Path from typing import Iterable, List, Optional, Tuple, Union, cast, overload from specfile.exceptions import DuplicateSourceException from specfile.rpm import Macros from specfile.sourcelist import Sourcelist, SourcelistEntry from specfile.tags import Comments, Tag, Tags +from specfile.utils import get_filename_from_location class Source(ABC): @@ -137,12 +137,12 @@ def expanded_location(self) -> str: @property def filename(self) -> str: """Literal filename of the source.""" - return Path(urllib.parse.urlsplit(self._tag.value).path).name + return get_filename_from_location(self._tag.value) @property def expanded_filename(self) -> str: """Filename of the source after expanding macros.""" - return Path(urllib.parse.urlsplit(self._tag.expanded_value).path).name + return get_filename_from_location(self._tag.expanded_value) @property def comments(self) -> Comments: @@ -193,12 +193,12 @@ def expanded_location(self) -> str: @property def filename(self) -> str: """Literal filename of the source.""" - return Path(urllib.parse.urlsplit(self._source.location).path).name + return get_filename_from_location(self._source.location) @property def expanded_filename(self) -> str: """Filename of the source after expanding macros.""" - return Path(urllib.parse.urlsplit(self._source.expanded_location).path).name + return get_filename_from_location(self._source.expanded_location) @property def comments(self) -> Comments: diff --git a/specfile/utils.py b/specfile/utils.py new file mode 100644 index 0000000..6a3b490 --- /dev/null +++ b/specfile/utils.py @@ -0,0 +1,25 @@ +# Copyright Contributors to the Packit project. +# SPDX-License-Identifier: MIT + +import urllib.parse +from pathlib import Path + + +def get_filename_from_location(location: str) -> str: + """ + Extracts filename from given source location. + + Follows RPM logic - target filename can be specified in URL fragment. + + Args: + location: Location to extract filename from. + + Returns: + Extracted filename that can be empty if there is none. + """ + url = urllib.parse.urlsplit(location) + if url.fragment: + if "/" in url.fragment: + return Path(url.fragment).name.split("=")[-1] + return Path(f"{url.path}#{url.fragment}").name + return Path(url.path).name diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 0000000..d190f35 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,31 @@ +# Copyright Contributors to the Packit project. +# SPDX-License-Identifier: MIT + +import pytest + +from specfile.utils import get_filename_from_location + + +@pytest.mark.parametrize( + "location, filename", + [ + ("", ""), + ("tarball-0.1.tar.gz", "tarball-0.1.tar.gz"), + ("https://example.com", ""), + ("https://example.com/archive/tarball-0.1.tar.gz", "tarball-0.1.tar.gz"), + ( + "https://example.com/archive/tarball-0.1.tar.gz#fragment", + "tarball-0.1.tar.gz#fragment", + ), + ( + "https://example.com/download_tarball.cgi#/tarball-0.1.tar.gz", + "tarball-0.1.tar.gz", + ), + ( + "https://example.com/tarball-latest.tar.gz#/file=tarball-0.1.tar.gz", + "tarball-0.1.tar.gz", + ), + ], +) +def test_get_filename_from_location(location, filename): + assert get_filename_from_location(location) == filename