diff --git a/upath/_protocol.py b/upath/_protocol.py index a3827bdd..d333dd6a 100644 --- a/upath/_protocol.py +++ b/upath/_protocol.py @@ -3,11 +3,16 @@ import os import re from pathlib import PurePath +from typing import TYPE_CHECKING from typing import Any +if TYPE_CHECKING: + from upath.core import UPath + __all__ = [ "get_upath_protocol", "normalize_empty_netloc", + "compatible_protocol", ] # Regular expression to match fsspec style protocols. @@ -59,3 +64,15 @@ def normalize_empty_netloc(pth: str) -> str: path = m.group("path") pth = f"{protocol}:///{path}" return pth + + +def compatible_protocol(protocol: str, *args: str | os.PathLike[str] | UPath) -> bool: + """check if UPath protocols are compatible""" + for arg in args: + other_protocol = get_upath_protocol(arg) + # consider protocols equivalent if they match up to the first "+" + other_protocol = other_protocol.partition("+")[0] + # protocols: only identical (or empty "") protocols can combine + if other_protocol and other_protocol != protocol: + return False + return True diff --git a/upath/core.py b/upath/core.py index 4c86e28d..714bfa3d 100644 --- a/upath/core.py +++ b/upath/core.py @@ -35,6 +35,7 @@ from upath._flavour import LazyFlavourDescriptor from upath._flavour import upath_get_kwargs_from_url from upath._flavour import upath_urijoin +from upath._protocol import compatible_protocol from upath._protocol import get_upath_protocol from upath._stat import UPathStatResult from upath.registry import get_upath_class @@ -251,23 +252,12 @@ def __init__( self._storage_options = storage_options.copy() # check that UPath subclasses in args are compatible - # --> ensures items in _raw_paths are compatible - for arg in args: - if not isinstance(arg, UPath): - continue - # protocols: only identical (or empty "") protocols can combine - if arg.protocol and arg.protocol != self._protocol: - raise TypeError("can't combine different UPath protocols as parts") - # storage_options: args may not define other storage_options - if any( - self._storage_options.get(key) != value - for key, value in arg.storage_options.items() - ): - # TODO: - # Future versions of UPath could verify that storage_options - # can be combined between UPath instances. Not sure if this - # is really necessary though. A warning might be enough... - pass + # TODO: + # Future versions of UPath could verify that storage_options + # can be combined between UPath instances. Not sure if this + # is really necessary though. A warning might be enough... + if not compatible_protocol(self._protocol, *args): + raise ValueError("can't combine incompatible UPath protocols") # fill ._raw_paths if hasattr(self, "_raw_paths"): diff --git a/upath/implementations/cloud.py b/upath/implementations/cloud.py index e2f4cb98..455fca6b 100644 --- a/upath/implementations/cloud.py +++ b/upath/implementations/cloud.py @@ -22,6 +22,13 @@ class CloudPath(UPath): __slots__ = () + def __init__( + self, *args, protocol: str | None = None, **storage_options: Any + ) -> None: + super().__init__(*args, protocol=protocol, **storage_options) + if not self.drive and len(self.parts) > 1: + raise ValueError("non key-like path provided (bucket/container missing)") + @classmethod def _transform_init_args( cls, diff --git a/upath/implementations/local.py b/upath/implementations/local.py index 4552585f..a0961cea 100644 --- a/upath/implementations/local.py +++ b/upath/implementations/local.py @@ -12,6 +12,7 @@ from typing import MutableMapping from urllib.parse import SplitResult +from upath._protocol import compatible_protocol from upath.core import UPath __all__ = [ @@ -141,6 +142,8 @@ def __new__( raise NotImplementedError( f"cannot instantiate {cls.__name__} on your system" ) + if not compatible_protocol("", *args): + raise ValueError("can't combine incompatible UPath protocols") obj = super().__new__(cls, *args) obj._protocol = "" return obj # type: ignore[return-value] @@ -152,6 +155,11 @@ def __init__( self._drv, self._root, self._parts = type(self)._parse_args(args) _upath_init(self) + def _make_child(self, args): + if not compatible_protocol(self._protocol, *args): + raise ValueError("can't combine incompatible UPath protocols") + return super()._make_child(args) + @classmethod def _from_parts(cls, *args, **kwargs): obj = super(Path, cls)._from_parts(*args, **kwargs) @@ -205,6 +213,8 @@ def __new__( raise NotImplementedError( f"cannot instantiate {cls.__name__} on your system" ) + if not compatible_protocol("", *args): + raise ValueError("can't combine incompatible UPath protocols") obj = super().__new__(cls, *args) obj._protocol = "" return obj # type: ignore[return-value] @@ -216,6 +226,11 @@ def __init__( self._drv, self._root, self._parts = self._parse_args(args) _upath_init(self) + def _make_child(self, args): + if not compatible_protocol(self._protocol, *args): + raise ValueError("can't combine incompatible UPath protocols") + return super()._make_child(args) + @classmethod def _from_parts(cls, *args, **kwargs): obj = super(Path, cls)._from_parts(*args, **kwargs) diff --git a/upath/tests/test_core.py b/upath/tests/test_core.py index f52e6b52..92d608df 100644 --- a/upath/tests/test_core.py +++ b/upath/tests/test_core.py @@ -410,3 +410,32 @@ def test_query_string(uri, query_str): p = UPath(uri) assert str(p).endswith(query_str) assert p.path.endswith(query_str) + + +@pytest.mark.parametrize( + "base,join", + [ + ("/a", "s3://bucket/b"), + ("s3://bucket/a", "gs://b/c"), + ("gs://bucket/a", "memory://b/c"), + ("memory://bucket/a", "s3://b/c"), + ], +) +def test_joinpath_on_protocol_mismatch(base, join): + with pytest.raises(ValueError): + UPath(base).joinpath(UPath(join)) + with pytest.raises(ValueError): + UPath(base) / UPath(join) + + +@pytest.mark.parametrize( + "base,join", + [ + ("/a", "s3://bucket/b"), + ("s3://bucket/a", "gs://b/c"), + ("gs://bucket/a", "memory://b/c"), + ("memory://bucket/a", "s3://b/c"), + ], +) +def test_joinuri_on_protocol_mismatch(base, join): + assert UPath(base).joinuri(UPath(join)) == UPath(join)