documentation release v1 #1012

Merged

Commits (30)
fbc0385
add optional dependency for preview to environment.yml
Titus-von-Koeller Feb 1, 2024
84b5fc0
Add additional sections, first optimizers, MacOS WIP
Titus-von-Koeller Feb 1, 2024
725d29a
drafting + refactoring new docs
Titus-von-Koeller Feb 1, 2024
58566e2
some changes
younesbelkada Feb 2, 2024
47cc3e9
run pre-commit hooks
Titus-von-Koeller Feb 2, 2024
c26645b
add mention of pre-commit to contributing
Titus-von-Koeller Feb 2, 2024
ab42c5f
fix
younesbelkada Feb 2, 2024
a71efa8
test autodoc
younesbelkada Feb 2, 2024
c1ec5f8
new additions
younesbelkada Feb 2, 2024
544114d
add subtilte
younesbelkada Feb 2, 2024
f735b35
add some content
younesbelkada Feb 2, 2024
daff94c
add more methods
younesbelkada Feb 2, 2024
301ee80
fix
younesbelkada Feb 2, 2024
683a72b
further docs updates
Titus-von-Koeller Feb 2, 2024
60a7699
Update _toctree.yml
younesbelkada Feb 2, 2024
543a7b1
fix link
Titus-von-Koeller Feb 3, 2024
2d73f4d
run pre-commit hooks
Titus-von-Koeller Feb 3, 2024
8f0fd8a
refactor + further docs
Titus-von-Koeller Feb 4, 2024
a3c45d3
Update README.md with new docs link
Titus-von-Koeller Feb 4, 2024
b370cee
list of blog posts
Titus-von-Koeller Feb 4, 2024
fd64f21
list of blog posts
Titus-von-Koeller Feb 4, 2024
38d323a
accept change suggestion
Titus-von-Koeller Feb 4, 2024
82485d0
accept suggestion
Titus-von-Koeller Feb 4, 2024
75cfb1c
accept suggestion
Titus-von-Koeller Feb 4, 2024
7a71390
Update docs/source/integrations.mdx
Titus-von-Koeller Feb 4, 2024
a84afcf
index instead of intro
Titus-von-Koeller Feb 4, 2024
d3709f4
fixup README, add docs link
Titus-von-Koeller Feb 4, 2024
e00cbc9
add instructions for creating docstrings
Titus-von-Koeller Feb 4, 2024
8a67759
final polish (except integrations)
Titus-von-Koeller Feb 4, 2024
d632531
fill out integrations section
Titus-von-Koeller Feb 4, 2024
3 changes: 0 additions & 3 deletions README.md
@@ -4,10 +4,7 @@ The bitsandbytes is a lightweight wrapper around CUDA custom functions, in parti



Resources:
- [8-bit Optimizer Paper](https://arxiv.org/abs/2110.02861) -- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) -- [Docs](https://bitsandbytes.readthedocs.io/en/latest/)

- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/)

## TL;DR
**Requirements**
85 changes: 85 additions & 0 deletions bitsandbytes/nn/modules.py
@@ -19,6 +19,9 @@


class StableEmbedding(torch.nn.Embedding):
"""
TODO: @titus fill this with some info
"""
def __init__(
self,
num_embeddings: int,
@@ -222,9 +225,50 @@ def to(self, *args, **kwargs):


class Linear4bit(nn.Linear):
"""
This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
QLoRA 4-bit linear layers use blockwise k-bit quantization under the hood, with the possibility of selecting various
quantization data types such as FP4 and NF4.

To quantize a linear layer, first load the original fp16 / bf16 weights into
the Linear4bit module, then call `quantized_model.to("cuda")` to quantize the fp16 / bf16 weights.

Example:

```python
import torch
import torch.nn as nn

import bitsandbytes as bnb
from bitsandbytes.nn import Linear4bit

fp16_model = nn.Sequential(
nn.Linear(64, 64),
nn.Linear(64, 64)
)

quantized_model = nn.Sequential(
Linear4bit(64, 64),
Linear4bit(64, 64)
)

quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
super().__init__(input_features, output_features, bias, device)
"""
Initialize Linear4bit class.

Args:
input_features (`int`):
Number of input features of the linear layer.
output_features (`int`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear layer uses a bias term.
"""
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
@@ -397,9 +441,50 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k


class Linear8bitLt(nn.Linear):
"""
This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
To read more about it, have a look at the paper.

To quantize a linear layer, first load the original fp16 / bf16 weights into
the Linear8bitLt module, then call `int8_model.to("cuda")` to quantize the fp16 weights.

Example:

```python
import torch
import torch.nn as nn

import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt

fp16_model = nn.Sequential(
nn.Linear(64, 64),
nn.Linear(64, 64)
)

int8_model = nn.Sequential(
Linear8bitLt(64, 64, has_fp16_weights=False),
Linear8bitLt(64, 64, has_fp16_weights=False)
)

int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0) # Quantization happens here
```
"""
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None, device=None):
super().__init__(input_features, output_features, bias, device)
"""
Initialize Linear8bitLt class.

Args:
input_features (`int`):
Number of input features of the linear layer.
output_features (`int`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear layer uses a bias term.
"""
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
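To complement the new docstrings above, here is a minimal end-to-end sketch of the documented quantization pattern (a sketch only: it assumes a CUDA device and the `bitsandbytes.nn` import path; layer sizes and the fp16 input are illustrative):

```python
import torch
import torch.nn as nn

from bitsandbytes.nn import Linear4bit

# Build a reference model and its 4-bit counterpart (mirroring the docstring example above).
fp16_model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
quantized_model = nn.Sequential(Linear4bit(64, 64), Linear4bit(64, 64))

# Load the original weights; moving the module to the GPU triggers quantization.
quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to("cuda")

# Run inference with fp16 activations against the quantized layers.
x = torch.randn(1, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    out = quantized_model(x)
```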
30 changes: 26 additions & 4 deletions docs/source/_toctree.yml
@@ -1,8 +1,30 @@
- sections:
- local: index
title: Bits & Bytes
- title: Get started
sections:
- local: introduction
title: Introduction
- local: quickstart
title: Quickstart
- local: installation
title: Installation
title: Get started
- local: moduletree
title: Module Tree
- title: Features & Integrations
sections:
- local: quantization
title: Quantization
- local: optimizers
title: Optimizers
- local: integrations
title: Integrations
- title: Support & Learning
sections:
- local: resources
title: Papers, related resources & how to cite
- local: faqs
title: FAQs (Frequently Asked Questions)
- title: Contributors Guidelines
sections:
- local: contributing
title: Contributing
# - local: code_of_conduct
# title: Code of Conduct
13 changes: 13 additions & 0 deletions docs/source/contributing.mdx
@@ -0,0 +1,13 @@
# Contributors guidelines
... still under construction ... (feel free to propose materials; `bitsandbytes` is a community project)

## Set up pre-commit hooks
- Install pre-commit with `pip install pre-commit`.
- Run `pre-commit install` once to set up the git hooks.
- Re-run `pre-commit autoupdate` whenever a new hook is added, to update the pinned hook versions.

All pre-commit hooks then run automatically whenever you commit. If a hook modifies files, re-stage the changed files before committing and pushing again.

## Documentation
- [Guidelines for documentation syntax](https://github.com/huggingface/doc-builder#readme) (a docstring sketch in that style follows below)
- Images should be uploaded via PR to the `bitsandbytes/` directory of the [documentation-images dataset](https://huggingface.co/datasets/huggingface/documentation-images)
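For orientation, a hypothetical sketch of the docstring style used in this PR's Python changes (the `Args` layout follows the doc-builder conventions linked above; the function itself is invented purely for illustration):

```python
import torch


def scale_rows(matrix: torch.Tensor, factor: float = 1.0) -> torch.Tensor:
    """
    Scale every row of a 2D tensor by a constant factor.

    Args:
        matrix (`torch.Tensor`):
            The 2D input tensor to scale.
        factor (`float`, defaults to `1.0`):
            The multiplicative factor applied to each row.

    Returns:
        `torch.Tensor`: The scaled tensor.
    """
    return matrix * factor
```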
7 changes: 7 additions & 0 deletions docs/source/faqs.mdx
@@ -0,0 +1,7 @@
# FAQs

Please submit your questions in [this GitHub Discussions thread](https://github.com/TimDettmers/bitsandbytes/discussions/1013) if you feel they are likely to affect many other users and haven't yet been sufficiently covered in the documentation.

We'll pick the most broadly applicable ones and post the Q&As here or integrate them into the general documentation (doc PRs are also welcome).

# ... under construction ...
10 changes: 9 additions & 1 deletion docs/source/installation.mdx
@@ -21,7 +21,7 @@ CUDA_VERSION=XXX make cuda12x
python setup.py install
```

with `XXX` being your CUDA version; for CUDA < 12.0, call `make cuda11x` instead
with `XXX` being your CUDA version; for CUDA < 12.0, call `make cuda11x` instead. Note that support for non-CUDA GPUs (e.g. AMD, Intel) is also coming soon.

</hfoption>
<hfoption id="Windows">
@@ -40,4 +40,12 @@ python -m build --wheel
Big thanks to [wkpark](https://github.com/wkpark), [Jamezo97](https://github.com/Jamezo97), [rickardp](https://github.com/rickardp), [akx](https://github.com/akx) for their amazing contributions to make bitsandbytes compatible with Windows.

</hfoption>
<hfoption id="MacOS">

## MacOS

Mac support is still a work in progress. Please keep an eye on the latest bitsandbytes issues to stay informed about progress on macOS integration.

</hfoption>

</hfoptions>
11 changes: 11 additions & 0 deletions docs/source/integrations.mdx
@@ -0,0 +1,11 @@
# Transformers

... TODO: to be filled out ...

younesbelkada (Collaborator) commented on Feb 2, 2024:

Here we point to relevant doc sections in transformers / peft / Trainer and very briefly explain how these are integrated:
e.g. for transformers state that you can load any model in 8-bit / 4-bit precision, for PEFT, you can use QLoRA out of the box with LoraConfig + 4-bit base model, for Trainer: all bnb optimizers are supported by passing the correct string in TrainingArguments : https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/training_args.py#L134

A few references:


# PEFT

... TODO: to be filled out ...

# Trainer for the optimizers

... TODO: to be filled out ...
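As a hedged illustration of the integration pattern suggested in the review comment above (the transformers / peft / Trainer APIs shown are believed current at the time of writing; the model name and hyperparameters are purely illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model

# transformers: load a causal LM in 4-bit (NF4) precision via bitsandbytes
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)

# PEFT: QLoRA out of the box, i.e. LoRA adapters on top of the 4-bit base model
peft_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)

# Trainer: select a bitsandbytes optimizer by passing its string name
training_args = TrainingArguments(output_dir="out", optim="paged_adamw_8bit")
```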
77 changes: 7 additions & 70 deletions docs/source/index.mdx → docs/source/introduction.mdx

A reviewer commented:

This file should stay as index.mdx and be referred as the first page / index page that briefly presents what bitsandbytes is, in one paragraph & lists the available documentation sections (quickstart, quantization, optimizers, integrations, resources)

@@ -1,52 +1,14 @@
# bitsandbytes
TODO: Many parts of this doc will still be redistributed among the new doc structure.

The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions.
# `bitsandbytes`

The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions.

There are ongoing efforts to support further hardware backends, i.e. Intel CPU + GPU, AMD GPU, Apple Silicon. Windows support is on its way as well.
The library includes quantization primitives for 8-bit & 4-bit operations through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit`, and 8-bit optimizers through the `bitsandbytes.optim` module.

Resources:
- [8-bit Optimizer Paper](https://arxiv.org/abs/2110.02861) -- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) -- [Docs](https://bitsandbytes.readthedocs.io/en/latest/)
**Using 8-bit optimizers**:

A collaborator commented:

I think this and below should be completely removed

The PR author replied:
ok, deleted everything from here on down

Titus-von-Koeller (author) added on Feb 2, 2024:

even though this part here below is still relevant, maybe? Wdyt?

Features

  • 8-bit Matrix multiplication with mixed precision decomposition (this is also a standalone feature, I guess)
  • LLM.int8() inference (this is covered by the HF docs, right?)
  • 8-bit quantization: Quantile, Linear, and Dynamic quantization (hmm, should we explain the differences?)
  • Fast quantile estimation: Up to 100x faster than other algorithms (this is again such standalone algorithm thing)

I think the standalone algs are worth documenting, but maybe we make a note of that for later and leave it out completely for now. Or do you think it's better to leave a placeholder and put a comment encouraging community contributions? I think it makes sense to leverage the community as much as possible for docs and doc-strings

- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/)

## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.

(Deprecated: CUDA 10.0 is deprecated; only CUDA >= 11.0 will be supported with release 0.39.0.)

**Installation**:

``pip install bitsandbytes``

In some cases you may need to compile from source. If this happens, please consider submitting a bug report with the output of `python -m bitsandbytes`. The short instructions below might work out of the box if `nvcc` is installed; if they do not, see further below.

Compilation quickstart:
```bash
git clone https://github.com/timdettmers/bitsandbytes.git
cd bitsandbytes

# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120}
# make argument in {cuda110, cuda11x, cuda12x}
# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
CUDA_VERSION=117 make cuda11x
python setup.py install
```

**Using Int8 inference with HuggingFace Transformers**

```python
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
'decapoda-research/llama-7b-hf',
device_map='auto',
load_in_8bit=True,
max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB')
```

A more detailed example can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py).

**Using 8-bit optimizer**:
1. Comment out optimizer: ``#torch.optim.Adam(....)``
2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same)
3. Replace embedding layer if necessary: ``torch.nn.Embedding(..) -> bnb.nn.Embedding(..)``
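A minimal sketch of those three steps (model, sizes, and learning rate are illustrative; `StableEmbedding` is used here as one of the provided embedding replacements):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(
    bnb.nn.StableEmbedding(1024, 64),  # step 3: swapped-in embedding layer
    torch.nn.Linear(64, 64),
).cuda()

# step 1: comment out the 32-bit optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# step 2: use the 8-bit optimizer with the same arguments
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
```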
@@ -68,6 +30,7 @@ out = linear(x.to(torch.float16))


## Features

- 8-bit Matrix multiplication with mixed precision decomposition
- LLM.int8() inference
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
@@ -89,9 +52,6 @@ The bitsandbytes library is currently only supported on Linux distributions. Win

The requirements can best be fulfilled by installing PyTorch via Anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website.

To install run:

``pip install bitsandbytes``

## Using bitsandbytes

@@ -166,26 +126,3 @@ For more detailed instruction, please follow the [compile_from_source.md](compil
The majority of bitsandbytes is licensed under MIT; however, portions of the project are available under separate license terms: PyTorch is licensed under the BSD license.

We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization.

## How to cite us
If you found this library and found LLM.int8() useful, please consider citing our work:

```bibtex
@article{dettmers2022llmint8,
title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale},
author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke},
journal={arXiv preprint arXiv:2208.07339},
year={2022}
}
```

For 8-bit optimizers or quantization routines, please consider citing the following work:

```bibtex
@article{dettmers2022optimizers,
title={8-bit Optimizers via Block-wise Quantization},
author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke},
journal={9th International Conference on Learning Representations, ICLR},
year={2022}
}
```
5 changes: 5 additions & 0 deletions docs/source/moduletree.mdx
@@ -0,0 +1,5 @@
# Module tree overview

A collaborator commented:

Is this file really needed?

Titus-von-Koeller (author) replied on Feb 2, 2024:

Hmm, I took this from some older docs I found. But you're right, it's a bit obvious. I thought it could eventually grow into a walkthrough of the library, but I agree it doesn't add value as it stands, so I'll delete it.


- **bitsandbytes.functional**: Contains quantization functions (4-bit & 8-bit) and stateless 8-bit optimizer update functions.
- **bitsandbytes.nn.modules**: Contains stable embedding layer with automatic 32-bit optimizer overrides (important for NLP stability)
- **bitsandbytes.optim**: Contains 8-bit optimizers.
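To make the module tree concrete, a small hedged sketch touching each submodule (assumes a CUDA device; exact function signatures may differ across versions):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# bitsandbytes.functional: stateless blockwise quantization round-trip
x = torch.randn(64, 64, device="cuda")
quantized, quant_state = F.quantize_blockwise(x)
x_restored = F.dequantize_blockwise(quantized, quant_state)

# bitsandbytes.nn.modules: stable embedding layer with a 32-bit optimizer override
embedding = bnb.nn.StableEmbedding(1024, 64).cuda()

# bitsandbytes.optim: 8-bit optimizer
optimizer = bnb.optim.Adam8bit(embedding.parameters(), lr=1e-3)
```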