torch.load without weights_only parameter is unsafe #1

Closed
kit1980 opened this issue Feb 21, 2024 · 2 comments
Labels: bug (Something isn't working)

Comments

kit1980 commented Feb 21, 2024

This was found via https://github.com/pytorch-labs/torchfix/

torch.load without weights_only parameter is unsafe. Explicitly set weights_only to False only if you trust the data you load and full pickle functionality is needed, otherwise set weights_only=True.

gemma/model.py:562:13

--- /home/sdym/repos/google/gemma_pytorch/gemma/model.py
+++ /home/sdym/repos/google/gemma_pytorch/gemma/model.py
@@ -557,9 +557,9 @@
         # If a string was provided as input, return a string as output.
         return results[0] if is_str_prompt else results
 
     def load_weights(self, model_path: str):
         self.load_state_dict(
-            torch.load(model_path, mmap=True)['model_state_dict'],
+            torch.load(model_path, mmap=True, weights_only=True)['model_state_dict'],
             strict=False,
         )

gemma/model_xla.py:517:22

--- /home/sdym/repos/google/gemma_pytorch/gemma/model_xla.py
+++ /home/sdym/repos/google/gemma_pytorch/gemma/model_xla.py
@@ -512,11 +512,11 @@
             top_ks=top_ks,
         )
         return next_tokens
 
     def load_weights(self, model_path: str):
-        checkpoint = torch.load(model_path)
+        checkpoint = torch.load(model_path, weights_only=True)
         model_state_dict = checkpoint['model_state_dict']
 
         num_attn_heads = self.config.num_attention_heads
         num_kv_heads = self.config.num_key_value_heads
         head_dim = self.config.head_dim
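
A minimal sketch of why the change matters, using a hypothetical checkpoint path: with weights_only=True, torch.load restricts unpickling to tensors and a small set of allowed types instead of running the full pickle machinery, so a maliciously crafted checkpoint cannot execute arbitrary code during deserialization.

    import torch

    ckpt_path = "model_checkpoint.pt"  # hypothetical path, not from the repo

    # Unsafe: plain torch.load uses full pickle deserialization, so a crafted
    # checkpoint file can run arbitrary code while it is being loaded.
    # checkpoint = torch.load(ckpt_path, mmap=True)

    # Safer: weights_only=True limits unpickling to tensors and other allowed
    # types and raises an error on anything else found in the file.
    checkpoint = torch.load(ckpt_path, mmap=True, weights_only=True)
    state_dict = checkpoint["model_state_dict"]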

@pengchongjin (Collaborator) commented

Thanks for reporting it. We are working on it.

@michaelmoynihan (Collaborator) commented

Hi Sergii, thanks for the suggestion! We have updated the code accordingly. I tested it locally, it works, and I have created and merged a PR.
