Skip to content

Commit

Permalink
Move non-NF4 tensor to device prior to quantization on copy (#737)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebsmothers authored Aug 23, 2024
1 parent 68e4643 commit 0ed3090
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def copy_(func, *args, **kwargs):
# Convert Non NF4Tensor into NF4 for copy in
if not isinstance(copy_in, NF4Tensor):
copy_in_nf4 = NF4Tensor.from_tensor(
copy_in, original.block_size, original.scaler_block_size
copy_in.to(original.device), original.block_size, original.scaler_block_size
)
return original.copy_(copy_in_nf4)

Expand Down

0 comments on commit 0ed3090

Please sign in to comment.