From dbb51fa0cc116713d19455ea6ce361dd199eecbd Mon Sep 17 00:00:00 2001 From: "Mauricio A. Rovira Galvez" <8482308+marovira@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:40:38 -0700 Subject: [PATCH] [brief] Updates the function interface for safe_torch_load. [detailed] - Makes it mimic the one from torch.load itself. This should make auto-completion of arguments work correctly. --- src/helios/core/utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/helios/core/utils.py b/src/helios/core/utils.py index 245a116..ab45327 100644 --- a/src/helios/core/utils.py +++ b/src/helios/core/utils.py @@ -487,7 +487,10 @@ def add_safe_torch_serialization_globals(safe_globals: list[typing.Any]) -> None torch.serialization.add_safe_globals(safe_globals) # type: ignore[attr-defined] -def safe_torch_load(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: +def safe_torch_load( + f: str | os.PathLike | typing.BinaryIO | typing.IO[bytes], + **kwargs: typing.Any, +) -> typing.Any: """ Wrap :code:`torch.load` to handle safe loading. @@ -500,12 +503,14 @@ def safe_torch_load(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: value yourself when using this function. Args: - *args: positional arguments to pass to :code:`torch.load`. + f: a file-like object (has to implement ``read()``, ``readline()``, ``tell()``, + and ``seek()``), or a string or a ``os.PathLike`` object containing a file + name. **kwargs: keyword arguments to pass to :code:`torch.load`. Returns: The result of calling :code:`torch.load`. """ if enable_safe_torch_loading(): - return torch.load(*args, **kwargs, weights_only=True) - return torch.load(*args, **kwargs) + return torch.load(f, **kwargs, weights_only=True) + return torch.load(f, **kwargs)