Skip to content

Commit

Permalink
[brief] Updates the function interface for safe_torch_load.
Browse files Browse the repository at this point in the history
[detailed]
- Makes it mimic the one from torch.load itself. This should make
  auto-completion of arguments work correctly.
  • Loading branch information
marovira committed Aug 28, 2024
1 parent 8c3ae57 commit dbb51fa
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/helios/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,10 @@ def add_safe_torch_serialization_globals(safe_globals: list[typing.Any]) -> None
torch.serialization.add_safe_globals(safe_globals) # type: ignore[attr-defined]


def safe_torch_load(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
def safe_torch_load(
f: str | os.PathLike | typing.BinaryIO | typing.IO[bytes],
**kwargs: typing.Any,
) -> typing.Any:
"""
Wrap :code:`torch.load` to handle safe loading.
Expand All @@ -500,12 +503,14 @@ def safe_torch_load(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
value yourself when using this function.
Args:
*args: positional arguments to pass to :code:`torch.load`.
f: a file-like object (has to implement ``read()``, ``readline()``, ``tell()``,
and ``seek()``), or a string or a ``os.PathLike`` object containing a file
name.
**kwargs: keyword arguments to pass to :code:`torch.load`.
Returns:
The result of calling :code:`torch.load`.
"""
if enable_safe_torch_loading():
return torch.load(*args, **kwargs, weights_only=True)
return torch.load(*args, **kwargs)
return torch.load(f, **kwargs, weights_only=True)
return torch.load(f, **kwargs)

0 comments on commit dbb51fa

Please sign in to comment.