diff --git a/specfile/sources.py b/specfile/sources.py index 061d38e..469d3c7 100644 --- a/specfile/sources.py +++ b/specfile/sources.py @@ -62,23 +62,25 @@ def comments(self) -> Comments: class TagSource(Source): """Class that represents a source backed by a spec file tag.""" - def __init__(self, tag: Tag) -> None: + def __init__(self, tag: Tag, number: Optional[int] = None) -> None: """ Constructs a `TagSource` object. Args: tag: Tag that this source represents. + number: Source number (in the case of implicit numbering). Returns: Constructed instance of `TagSource` class. """ self._tag = tag + self._number = number def __repr__(self) -> str: tag = repr(self._tag) - return f"TagSource({tag})" + return f"TagSource({tag}, {self._number})" - def _get_number(self) -> Optional[str]: + def _extract_number(self) -> Optional[str]: """ Extracts source number from tag name. @@ -93,12 +95,14 @@ def _get_number(self) -> Optional[str]: @property def number(self) -> int: """Source number.""" - return int(self._get_number() or 0) + return self._number or int(self._extract_number() or 0) @property def number_digits(self) -> int: """Number of digits in the source number.""" - return len(self._get_number() or "") + if self._number: + return 0 + return len(self._extract_number() or "") @property def location(self) -> str: @@ -196,6 +200,7 @@ def __init__( tags: Tags, sourcelists: List[Sourcelist], allow_duplicates: bool = False, + default_to_implicit_numbering: bool = False, default_source_number_digits: int = 1, ) -> None: """ @@ -205,6 +210,7 @@ def __init__( tags: All spec file tags. sourcelists: List of all %sourcelist sections. allow_duplicates: Whether to allow duplicate entries when adding new sources. + default_to_implicit_numbering: Use implicit numbering (no source numbers) by default. default_source_number_digits: Default number of digits in a source number. Returns: @@ -213,17 +219,19 @@ def __init__( self._tags = tags self._sourcelists = sourcelists self._allow_duplicates = allow_duplicates + self._default_to_implicit_numbering = default_to_implicit_numbering self._default_source_number_digits = default_source_number_digits def __repr__(self) -> str: tags = repr(self._tags) sourcelists = repr(self._sourcelists) allow_duplicates = repr(self._allow_duplicates) + default_to_implicit_numbering = repr(self._default_to_implicit_numbering) # determine class name dynamically so that inherited classes # don't have to reimplement __repr__() return ( f"{self.__class__.__name__}({tags}, {sourcelists}, {allow_duplicates}, " - f"{self._default_source_number_digits})" + f"{default_to_implicit_numbering}, {self._default_source_number_digits})" ) def __contains__(self, location: object) -> bool: @@ -285,11 +293,19 @@ def _get_tags(self) -> List[Tuple[TagSource, Tags, int]]: container is the container the tag is part of and index is its index within that container. """ - return [ - (TagSource(t), self._tags, i) - for i, t in enumerate(self._tags) - if t.name.capitalize().startswith(self.PREFIX.capitalize()) - ] + result = [] + last_number = -1 + for i, tag in enumerate(self._tags): + if tag.name.capitalize() == self.PREFIX.capitalize(): + last_number += 1 + ts = TagSource(tag, last_number) + elif tag.name.capitalize().startswith(self.PREFIX.capitalize()): + ts = TagSource(tag) + last_number = ts.number + else: + continue + result.append((ts, self._tags, i)) + return result def _get_items(self) -> List[Tuple[Source, Union[Tags, Sourcelist], int]]: """ @@ -312,6 +328,21 @@ def _get_items(self) -> List[Tuple[Source, Union[Tags, Sourcelist], int]]: ) return result + def _detect_implicit_numbering(self) -> bool: + """ + Tries to detect if implicit numbering is being used, i.e. Source/Patch + tags don't have numbers. + + Returns: + True if implicit numbering is being/should be used, False otherwise. + """ + tags = self._get_tags() + if any(t._number is None for t, _, _ in tags): + return False + if len(tags) <= 1: + return self._default_to_implicit_numbering + return True + def _get_tag_format(self, reference: TagSource, number: int) -> Tuple[str, str]: """ Determines name and separator of a new source tag based on @@ -328,7 +359,11 @@ def _get_tag_format(self, reference: TagSource, number: int) -> Tuple[str, str]: Tuple in the form of (name, separator). """ prefix = self.PREFIX.capitalize() - name = f"{prefix}{number:0{reference.number_digits}}" + if self._detect_implicit_numbering(): + suffix = "" + else: + suffix = f"{number:0{reference.number_digits}}" + name = f"{prefix}{suffix}" diff = len(reference._tag.name) - len(name) if diff >= 0: return name, reference._tag._separator + " " * diff @@ -347,7 +382,10 @@ def _get_initial_tag_setup(self, number: int = 0) -> Tuple[int, str, str]: Tuple in the form of (index, name, separator). """ prefix = self.PREFIX.capitalize() - suffix = f"{number:0{self._default_source_number_digits}}" + if self._default_to_implicit_numbering: + suffix = "" + else: + suffix = f"{number:0{self._default_source_number_digits}}" return len(self._tags) if self._tags else 0, f"{prefix}{suffix}", ": " def _deduplicate_tag_names(self) -> None: diff --git a/specfile/specfile.py b/specfile/specfile.py index a450f19..d1e69d7 100644 --- a/specfile/specfile.py +++ b/specfile/specfile.py @@ -157,13 +157,17 @@ def prep(self) -> Iterator[Optional[Prep]]: @contextlib.contextmanager def sources( - self, allow_duplicates: bool = False, default_source_number_digits: int = 1 + self, + allow_duplicates: bool = False, + default_to_implicit_numbering: bool = False, + default_source_number_digits: int = 1, ) -> Iterator[Sources]: """ Context manager for accessing sources. Args: allow_duplicates: Whether to allow duplicate entries when adding new sources. + default_to_implicit_numbering: Use implicit numbering (no source numbers) by default. default_source_number_digits: Default number of digits in a source number. Yields: @@ -178,6 +182,7 @@ def sources( tags, list(zip(*sourcelists))[1] if sourcelists else [], allow_duplicates, + default_to_implicit_numbering, default_source_number_digits, ) finally: @@ -186,13 +191,17 @@ def sources( @contextlib.contextmanager def patches( - self, allow_duplicates: bool = False, default_source_number_digits: int = 1 + self, + allow_duplicates: bool = False, + default_to_implicit_numbering: bool = False, + default_source_number_digits: int = 1, ) -> Iterator[Patches]: """ Context manager for accessing patches. Args: allow_duplicates: Whether to allow duplicate entries when adding new patches. + default_to_implicit_numbering: Use implicit numbering (no source numbers) by default. default_source_number_digits: Default number of digits in a source number. Yields: @@ -207,6 +216,7 @@ def patches( tags, list(zip(*patchlists))[1] if patchlists else [], allow_duplicates, + default_to_implicit_numbering, default_source_number_digits, ) finally: diff --git a/tests/unit/test_sources.py b/tests/unit/test_sources.py index 0e01af7..923b9a6 100644 --- a/tests/unit/test_sources.py +++ b/tests/unit/test_sources.py @@ -20,9 +20,63 @@ ("Patch99999", "99999"), ], ) -def test_tag_source_get_number(tag_name, number): +def test_tag_source_extract_number(tag_name, number): ts = TagSource(Tag(tag_name, "", "", "", Comments())) - assert ts._get_number() == number + assert ts._extract_number() == number + + +@pytest.mark.parametrize( + "tags, default, result", + [ + ( + [ + ("Name", "test"), + ("Version", "0.1"), + ("Source0", "source0"), + ("Source1", "source1"), + ("Source2", "source2"), + ], + True, + False, + ), + ( + [ + ("Name", "test"), + ("Version", "0.1"), + ("Source", "source0"), + ("Source", "source1"), + ("Source", "source2"), + ], + False, + True, + ), + ( + [ + ("Name", "test"), + ("Version", "0.1"), + ("Source", "source0"), + ], + False, + False, + ), + ( + [ + ("Name", "test"), + ("Version", "0.1"), + ("Source", "source0"), + ], + True, + True, + ), + ], +) +def test_sources_detect_implicit_numbering(tags, default, result): + sources = Sources( + Tags([Tag(t, v, v, ": ", Comments()) for t, v in tags]), + [], + default_to_implicit_numbering=default, + ) + assert sources._detect_implicit_numbering() == result @pytest.mark.parametrize( @@ -33,7 +87,7 @@ def test_tag_source_get_number(tag_name, number): ], ) def test_sources_get_tag_format(ref_name, ref_separator, number, name, separator): - sources = Sources(None, []) + sources = Sources(Tags(), []) reference = TagSource(Tag(ref_name, "", "", ref_separator, Comments())) assert sources._get_tag_format(reference, number) == (name, separator) @@ -265,7 +319,7 @@ def test_sources_insert_numbered(tags, number, location, index): ], ) def test_patches_get_tag_format(ref_name, ref_separator, number, name, separator): - patches = Patches(None, []) + patches = Patches(Tags(), []) reference = TagSource(Tag(ref_name, "", "", ref_separator, Comments())) assert patches._get_tag_format(reference, number) == (name, separator) diff --git a/tests/unit/test_tags.py b/tests/unit/test_tags.py index 0816689..ad61b0f 100644 --- a/tests/unit/test_tags.py +++ b/tests/unit/test_tags.py @@ -12,9 +12,9 @@ def test_find(): tags = Tags( [ - Tag("Name", "test", "test", ": ", []), - Tag("Version", "0.1", "0.1", ": ", []), - Tag("Release", "1%{?dist}", "1.fc35", ": ", []), + Tag("Name", "test", "test", ": ", Comments()), + Tag("Version", "0.1", "0.1", ": ", Comments()), + Tag("Release", "1%{?dist}", "1.fc35", ": ", Comments()), ] ) assert tags.find("version") == 1 @@ -122,7 +122,7 @@ def test_get_raw_section_data(): def test_copy_tags(): tags = Tags( [ - Tag("Name", "test", "test", ": ", []), + Tag("Name", "test", "test", ": ", Comments()), ] ) tags_copy = copy.deepcopy(tags)