diff --git a/src/poetry/core/version/markers.py b/src/poetry/core/version/markers.py index 67e370661..a5f50cf2c 100644 --- a/src/poetry/core/version/markers.py +++ b/src/poetry/core/version/markers.py @@ -744,7 +744,7 @@ def exclude(self, marker_name: str) -> BaseMarker: if not marker.is_empty(): new_markers.append(marker) - return self.of(*new_markers) + return intersection(*new_markers) def only(self, *marker_names: str) -> BaseMarker: return self.of(*(m.only(*marker_names) for m in self._markers)) @@ -920,7 +920,7 @@ def exclude(self, marker_name: str) -> BaseMarker: # All markers were the excluded marker. return AnyMarker() - return self.of(*new_markers) + return union(*new_markers) def only(self, *marker_names: str) -> BaseMarker: return self.of(*(m.only(*marker_names) for m in self._markers)) @@ -1076,11 +1076,28 @@ def dnf(marker: BaseMarker) -> BaseMarker: def intersection(*markers: BaseMarker) -> BaseMarker: - return dnf(MultiMarker(*markers)) + # Sometimes normalization makes it more complicated instead of simple + # -> choose candidate with the least complexity + unnormalized: BaseMarker = MultiMarker(*markers) + while ( + isinstance(unnormalized, (MultiMarker, MarkerUnion)) + and len(unnormalized.markers) == 1 + ): + unnormalized = unnormalized.markers[0] + + disjunction = dnf(unnormalized) + if not isinstance(disjunction, MarkerUnion): + return disjunction + + conjunction = cnf(disjunction) + if not isinstance(conjunction, MultiMarker): + return conjunction + + return min(disjunction, conjunction, unnormalized, key=lambda x: x.complexity) def union(*markers: BaseMarker) -> BaseMarker: - # Sometimes normalization makes it more complicate instead of simple + # Sometimes normalization makes it more complicated instead of simple # -> choose candidate with the least complexity unnormalized: BaseMarker = MarkerUnion(*markers) while ( diff --git a/tests/version/test_markers.py b/tests/version/test_markers.py index d2a8fd0fb..280eecbab 100644 --- a/tests/version/test_markers.py +++ b/tests/version/test_markers.py @@ -940,9 +940,8 @@ def test_marker_union_intersect_single_marker() -> None: intersection = m.intersect(parse_marker('implementation_name == "cpython"')) assert ( - str(intersection) - == 'sys_platform == "darwin" and implementation_name == "cpython" ' - 'or python_version < "3.4" and implementation_name == "cpython"' + str(intersection) == '(sys_platform == "darwin" or python_version < "3.4")' + ' and implementation_name == "cpython"' ) @@ -968,11 +967,8 @@ def test_marker_union_intersect_marker_union() -> None: parse_marker('implementation_name == "cpython" or os_name == "Windows"') ) assert ( - str(intersection) - == 'sys_platform == "darwin" and implementation_name == "cpython" ' - 'or sys_platform == "darwin" and os_name == "Windows" or ' - 'python_version < "3.4" and implementation_name == "cpython" or ' - 'python_version < "3.4" and os_name == "Windows"' + str(intersection) == '(sys_platform == "darwin" or python_version < "3.4") and ' + '(implementation_name == "cpython" or os_name == "Windows")' ) @@ -1000,18 +996,16 @@ def test_marker_union_intersect_multi_marker() -> None: # Intersection isn't _quite_ symmetrical. expected1 = ( - 'sys_platform == "darwin" and implementation_name == "cpython" and os_name ==' - ' "Windows" or python_version < "3.4" and implementation_name == "cpython" and' - ' os_name == "Windows"' + '(sys_platform == "darwin" or python_version < "3.4")' + ' and implementation_name == "cpython" and os_name == "Windows"' ) intersection = m1.intersect(m2) assert str(intersection) == expected1 expected2 = ( - 'implementation_name == "cpython" and os_name == "Windows" and sys_platform' - ' == "darwin" or implementation_name == "cpython" and os_name == "Windows"' - ' and python_version < "3.4"' + 'implementation_name == "cpython" and os_name == "Windows"' + ' and (sys_platform == "darwin" or python_version < "3.4")' ) intersection = m2.intersect(m1) @@ -1357,6 +1351,18 @@ def test_without_extras(marker: str, expected: str) -> None: "extra", 'python_version >= "3.6"', ), + ( + ( + 'python_version >= "2.7" and (python_version < "2.8"' + ' or python_version >= "3.7") and python_version < "3.8"' + ' and extra == "foo"' + ), + "extra", + ( + 'python_version >= "2.7" and python_version < "2.8"' + ' or python_version >= "3.7" and python_version < "3.8"' + ), + ), ], ) def test_exclude(marker: str, excluded: str, expected: str) -> None: