Skip to content

Commit

Permalink
Wrap torch.ops.quantized_decomposed to improve import errors (pytorch…
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Jun 5, 2024
1 parent 3d609fc commit c7cd729
Show file tree
Hide file tree
Showing 6 changed files with 630 additions and 573 deletions.
71 changes: 71 additions & 0 deletions torchao/_executorch_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import torch


def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs):
"""
Wrapper around torch.ops.quantized_decomposed.quantize_per_channel_group to mitigate
availability issue until it can be supplanted by new quantize_affine function.
torch.ops.quantized_decomposed.quantize_per_channel_group is only available
in PyTorch 2.3+ and recently changed signatures.
"""
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later.")


def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **kwargs):
"""
Wrapper around torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric to mitigate
availability issue until it can be supplanted by new choose_qparams_affine function.
torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric is only available
in PyTorch 2.3+ and recently changed signatures.
"""
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later.")


def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs):
"""
Wrapper around torch.ops.quantized_decomposed.dequantize_per_channel_group to mitigate
availability issue until it can be supplanted by new choose_qparams_affine function.
torch.ops.quantized_decomposed.dequantize_per_channel_group is only available
in PyTorch 2.3+ and recently changed signatures.
"""
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later.")


def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs):
"""
Wrapper around torch.ops.quantized_decomposed.quantize_per_token to mitigate
availability issue until it can be supplanted by new choose_qparams_affine function.
torch.ops.quantized_decomposed.quantize_per_token is only available
in PyTorch 2.3+ and recently changed signatures.
"""
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later.")


def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs):
"""
Wrapper around torch.ops.quantized_decomposed.dequantize_per_token to mitigate
availability issue until it can be supplanted by new choose_qparams_affine function.
torch.ops.quantized_decomposed.dequantize_per_token is only available
in PyTorch 2.3+ and recently changed signatures.
"""
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
if TORCH_VERSION_AFTER_2_3:
return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs)
raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later.")
Loading

0 comments on commit c7cd729

Please sign in to comment.