Remove shim for bitsandbytes versions that don't support the device parameter #228
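The shim being removed probed, at import time, whether the installed `bitsandbytes` exposes a `device` parameter on its Linear modules. A standalone sketch of that same probe (the `TypeError` check the old `_compat.py` performed) can be used to verify that an installed build satisfies the new assumption; the print messages are illustrative only:

import bitsandbytes

try:
    # The removed shim performed exactly this probe to detect `device` support.
    bitsandbytes.nn.Linear8bitLt(1, 1, device=None)
    bitsandbytes.nn.Linear4bit(1, 1, device=None)
    print("This bitsandbytes build accepts `device`; no shim is needed.")
except TypeError:
    print("This bitsandbytes build predates the `device` parameter; upgrade it.")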

Merged 4 commits on Jul 12, 2023
48 changes: 37 additions & 11 deletions curated_transformers/_compat.py
@@ -1,3 +1,5 @@
+import warnings
+
 try:
     import transformers
 
@@ -6,19 +8,43 @@
     transformers = None  # type: ignore
     has_hf_transformers = False
 
-try:
-    import bitsandbytes
 
-    has_bitsandbytes = True
+def _check_scipy_presence() -> bool:
+    try:
+        from scipy.stats import norm
+
+        return True
+    except ImportError:
+        return False
+
 
-    # Check if we have the correct version (that exposes the `device` parameter).
+def _check_bnb_presence() -> bool:
     try:
-        _ = bitsandbytes.nn.Linear8bitLt(1, 1, device=None)
-        _ = bitsandbytes.nn.Linear4bit(1, 1, device=None)
-        has_bitsandbytes_linear_device = True
-    except TypeError:
-        has_bitsandbytes_linear_device = False
-except ImportError:
+        import bitsandbytes
+    except ModuleNotFoundError:
+        return False
+    except ImportError:
+        pass
+    return True
+
+
+# As of v0.40.0, `bitsandbytes` doesn't correctly specify `scipy` as an installation
+# dependency. This can lead to situations where the former is installed but
+# the latter isn't and the ImportError gets masked. So, we additionally check
+# for the presence of `scipy`.
+if _check_scipy_presence():
+    try:
+        import bitsandbytes
+
+        has_bitsandbytes = True
+    except ImportError:
+        bitsandbytes = None  # type: ignore
+        has_bitsandbytes = False
+else:
+    if _check_bnb_presence():
+        warnings.warn(
+            "The `bitsandbytes` library is installed but its dependency "
+            "`scipy` isn't. Please install `scipy` to correctly load `bitsandbytes`."
+        )
     bitsandbytes = None  # type: ignore
     has_bitsandbytes = False
-    has_bitsandbytes_linear_device = False
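For reference, `_check_bnb_presence` above distinguishes "the `bitsandbytes` package itself is missing" (`ModuleNotFoundError`) from "the package is present but raised some other `ImportError` while being imported". A minimal standalone sketch of that try/except pattern; the messages are illustrative, not from the library:

try:
    import bitsandbytes  # noqa: F401
except ModuleNotFoundError:
    # The package itself is not installed.
    print("bitsandbytes is not installed")
except ImportError as exc:
    # The package exists but importing it failed for another reason,
    # e.g. the undeclared `scipy` dependency mentioned in the comment above.
    print(f"bitsandbytes is present but could not be imported: {exc}")
else:
    print("bitsandbytes imported successfully")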
50 changes: 18 additions & 32 deletions curated_transformers/quantization/bnb/impl.py
@@ -6,7 +6,7 @@
 from torch.nn import Module, Parameter
 
 from ..._compat import bitsandbytes as bnb
-from ..._compat import has_bitsandbytes, has_bitsandbytes_linear_device
+from ..._compat import has_bitsandbytes
 from ...util.pytorch import ModuleIterator, apply_to_module
 from ...util.serde import TensorToParameterConverterT
 from .config import BitsAndBytesConfig, _4BitConfig, _8BitConfig
@@ -134,36 +134,30 @@ def _init_8bit_linear(
     source: Module, config: Union[_8BitConfig, _4BitConfig], device: torch.device
 ) -> "bnb.nn.Linear8bitLt":
     assert isinstance(config, _8BitConfig)
-    kwargs: Dict[str, Any] = {
-        "input_features": source.in_features,
-        "output_features": source.out_features,
-        "bias": source.bias is not None,
-        "has_fp16_weights": config.finetunable,
-        "threshold": config.outlier_threshold,
-    }
-    if has_bitsandbytes_linear_device:
-        kwargs["device"] = device
-
-    quantized_module = bnb.nn.Linear8bitLt(**kwargs)
+    quantized_module = bnb.nn.Linear8bitLt(
+        input_features=source.in_features,
+        output_features=source.out_features,
+        bias=source.bias is not None,
+        has_fp16_weights=config.finetunable,
+        threshold=config.outlier_threshold,
+        device=device,
+    )
     return quantized_module
 
 
 def _init_4bit_linear(
     source: Module, config: Union[_8BitConfig, _4BitConfig], device: torch.device
 ) -> "bnb.nn.Linear4bit":
     assert isinstance(config, _4BitConfig)
-    kwargs: Dict[str, Any] = {
-        "input_features": source.in_features,
-        "output_features": source.out_features,
-        "bias": source.bias is not None,
-        "compute_dtype": config.compute_dtype,
-        "compress_statistics": config.double_quantization,
-        "quant_type": config.quantization_dtype.value,
-    }
-    if has_bitsandbytes_linear_device:
-        kwargs["device"] = device
-
-    quantized_module = bnb.nn.Linear4bit(**kwargs)
+    quantized_module = bnb.nn.Linear4bit(
+        input_features=source.in_features,
+        output_features=source.out_features,
+        bias=source.bias is not None,
+        compute_dtype=config.compute_dtype,
+        compress_statistics=config.double_quantization,
+        quant_type=config.quantization_dtype.value,
+        device=device,
+    )
     return quantized_module
 
 
@@ -172,11 +166,3 @@ def _assert_bitsandbytes_installed():
         raise ValueError(
             "The `bitsandbytes` Python library is required for quantization support"
         )
-    elif not has_bitsandbytes_linear_device:
-        msg = (
-            "The installed version of the `bitsandbytes` library does not support passing "
-            "device parameters to its modules. This can result in increased memory usage "
-            "during model initialization. Please refer to the README for more information."
-        )
-        warnings.filterwarnings("once", msg)
-        warnings.warn(msg)
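With the shim gone, `_init_8bit_linear` and `_init_4bit_linear` pass `device` straight to the `bitsandbytes` constructors, so the layer can be materialized on the requested device from the start rather than allocated first and moved later (the memory issue the removed warning described). A minimal usage sketch, assuming a `bitsandbytes` version whose Linear modules accept `device`; the sizes and threshold are illustrative:

import torch
import bitsandbytes as bnb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mirrors the simplified _init_8bit_linear above: the device is passed directly
# to the constructor instead of being added conditionally through kwargs.
quantized = bnb.nn.Linear8bitLt(
    input_features=768,      # illustrative size
    output_features=3072,    # illustrative size
    bias=True,
    has_fp16_weights=False,  # corresponds to config.finetunable in the diff
    threshold=6.0,           # corresponds to config.outlier_threshold
    device=device,
)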