
Introduce QLoRA-like technique #19356

Merged: 12 commits merged into keras-team:master on Mar 25, 2024
Conversation

@james77777778 (Contributor) commented Mar 22, 2024

Highlights

This PR enables training with frozen int8 weights for Dense and EinsumDense.

Overall, training speed is similar to floating-point LoRA, with lower memory usage (roughly 50~68% of the floating-point LoRA baseline).
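
As a quick orientation, here is a minimal sketch of the intended workflow, condensed from the benchmark script further below (the layer sizes and LoRA rank are illustrative only):

import keras
from keras import layers

# Build a small float model with layers that support int8 quantization.
inputs = keras.Input([784])
x = layers.Dense(256, activation="relu")(inputs)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs)

# Freeze the Dense/EinsumDense kernels as int8...
model.quantize("int8")
# ...then attach trainable floating-point LoRA adapters on top.
for layer in model.layers:
    if hasattr(layer, "enable_lora"):
        layer.enable_lora(2)

model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")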

Notes

Similar to QLoRA, but this PR lacks the following:

  • NF4 is not used because no backend supports it.
  • Double quantization is not used.
  • The paged optimizer is not included because it is more of a low-level hardware optimization.

Training with the torch backend is slower due to the lack of hardware-accelerated matmul/einsum.

Results

  • MNIST classification
  • backend: tensorflow (to measure GPU memory usage)
| layer | quantized | compute_dtype | acc. (LoRA unmerged / merged) | inference time (LoRA unmerged / merged) | peak GPU memory (fitting with LoRA) |
| --- | --- | --- | --- | --- | --- |
| Dense | | float32 | 0.95790 / 0.95790 | 0.00395s / 0.00338s | 0.528GB |
| Dense | | bfloat16 | 0.96030 / 0.96060 | 0.00270s / 0.00246s | 0.452GB |
| Dense | int8 | float32 | 0.95860† / 0.95920 | 0.00282s / 0.00208s | 0.264GB |
| Dense | int8 | bfloat16 | 0.95790† / 0.95830 | 0.00254s / 0.00207s | 0.263GB |
| Einsum | | float32 | 0.96030 / 0.96030 | 0.00385s / 0.00331s | 0.526GB |
| Einsum | | bfloat16 | 0.96300 / 0.96310 | 0.00258s / 0.00237s | 0.451GB |
| Einsum | int8 | float32 | 0.95940† / 0.95860 | 0.00499s* / 0.00191s | 0.360GB |
| Einsum | int8 | bfloat16 | 0.96400† / 0.96370 | 0.00364s* / 0.00200s | 0.285GB |

*: The performance of the quantized Einsum with lora_enabled=True is suboptimal due to the current implementation of the LoRA calculation.
†: Merging LoRA weights into the int8 kernel is a lossy operation, leading to slight differences in the final outputs.
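
The † effect can be reproduced with plain NumPy. The sketch below assumes a simple symmetric abs-max int8 scheme purely for illustration; it is not the exact quantizer used in this PR:

import numpy as np

# Merging dequant(W_q) + A @ B back into an int8 kernel requires re-quantizing
# the merged matrix, and that rounding step is what makes the merge lossy.
rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64)).astype("float32")        # frozen kernel
a = 0.01 * rng.normal(size=(64, 2)).astype("float32")  # LoRA A
b = 0.01 * rng.normal(size=(2, 64)).astype("float32")  # LoRA B

def quantize_int8(x):
    scale = np.abs(x).max(axis=0, keepdims=True) / 127.0
    return np.clip(np.round(x / scale), -127, 127).astype("int8"), scale

w_q, w_scale = quantize_int8(w)
merged = w_q.astype("float32") * w_scale + a @ b       # exact merged kernel
m_q, m_scale = quantize_int8(merged)                   # re-quantize to store
roundtrip = m_q.astype("float32") * m_scale
print("max merge error:", float(np.abs(roundtrip - merged).max()))  # > 0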

Standalone benchmark script:

# Usage
# Train a float model
python3 benchmark.py --type train [--use-einsum] [--dtype-policy mixed_bfloat16]
# Finetune with quantized weights
python3 benchmark.py --type finetune --path [model_int8.keras|model_fp32.keras] [--use-einsum] [--dtype-policy mixed_bfloat16]
benchmark.py
import argparse
import os
import time

import numpy as np
import tensorflow as tf

import keras
from keras import backend
from keras import dtype_policies
from keras import layers
from keras import models
from keras import ops
from keras import saving
from keras.utils.traceback_utils import disable_traceback_filtering

disable_traceback_filtering()


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--type",
        default="train",
        choices=["train", "finetune"],
    )
    parser.add_argument("--path")
    parser.add_argument(
        "--dtype-policy",
        default="float32",
        choices=["float32", "mixed_bfloat16"],
    )
    parser.add_argument("--use-einsum", action="store_true")
    return parser.parse_args()


def build_model(num_layers=32, units=1024, use_einsum=False):
    inputs = layers.Input([28, 28])
    x = layers.Flatten()(inputs)
    for _ in range(num_layers):
        if use_einsum:
            x = layers.EinsumDense("ab,bc->ac", output_shape=[units])(x)
        else:
            x = layers.Dense(units)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    outputs = layers.Dense(10, use_bias=True, activation="softmax")(x)
    model = models.Model(inputs, outputs)
    return model


def benchmark(model, batch_size=1024, input_shape=(28, 28), iterations=200):
    def fn(x):
        return model(x, training=False)

    if backend.backend() == "tensorflow":
        jit_fn = tf.function(fn, jit_compile=True)
    elif backend.backend() == "jax":
        import jax

        jit_fn = jax.jit(fn)
    else:
        # torch (and any other backend): run the function eagerly.
        jit_fn = fn

    # warmup
    x = ops.ones([batch_size, *input_shape])
    for _ in range(10):
        _ = ops.convert_to_numpy(jit_fn(x))

    times = []
    for _ in range(iterations):
        t0 = time.time()
        _ = ops.convert_to_numpy(jit_fn(x))
        t1 = time.time()
        times.append(t1 - t0)
    avg_time = sum(times) / len(times)
    return avg_time


class GPUMemoryCallback(keras.callbacks.Callback):
    def __init__(self, target_batches, **kwargs):
        super().__init__(**kwargs)
        self.target_batches = target_batches
        self.memory_usage = []

    def _compute_memory_usage(self):
        try:
            memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        except ValueError:
            memory_stats = {"peak": 0}
        # Convert bytes to GB and store in list.
        peak_usage = round(memory_stats["peak"] / (2**30), 3)
        self.memory_usage.append(peak_usage)

    def on_epoch_begin(self, epoch, logs=None):
        self._compute_memory_usage()

    def on_train_batch_begin(self, batch, logs=None):
        if batch in self.target_batches:
            self._compute_memory_usage()

    def on_epoch_end(self, epoch, logs=None):
        self._compute_memory_usage()


def train(args, dtype):
    # Model / data parameters
    use_einsum = args.use_einsum
    num_classes = 10
    input_shape = (28, 28, 1)
    epochs = 1

    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)

    model = build_model(num_layers=32, units=1024, use_einsum=use_einsum)
    model.compile(
        loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
    )

    """Train float model"""
    print("=====Start training float model=====")
    model.fit(
        x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1
    )
    print(f"Performance of {dtype}:")
    score = model.evaluate(x_test, y_test, verbose=0)
    print(f"  Test accuracy: {score[1]:.5f}")
    avg_time = benchmark(model, input_shape=input_shape)
    print(f"  Avg. inference time (batch_size=1024): {avg_time:.5f}s")

    """Save trained model"""
    model.save("model_fp32.keras")
    model.quantize("int8")
    model.save("model_int8.keras")
    print("Size of saved model:")
    print(f"  fp32: {os.path.getsize('model_fp32.keras') >> 20}MB")
    print(f"  int8: {os.path.getsize('model_int8.keras') >> 20}MB")


def finetune(args, model):
    use_einsum = args.use_einsum
    """Enable LoRA"""
    for layer in model.layers:
        if hasattr(layer, "enable_lora"):
            layer.enable_lora(2)

    # Model / data parameters
    num_classes = 10
    input_shape = (28, 28, 1)
    epochs = 1

    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)

    gpu_memory_callback = GPUMemoryCallback(
        target_batches=[5, 10, 25, 50, 100, 150, 200, 300, 400, 500],
    )
    model.compile(
        loss="categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    model.fit(
        x_train,
        y_train,
        batch_size=128,
        epochs=epochs,
        callbacks=[gpu_memory_callback],
        validation_split=0.1,
    )
    lora_memory_usage = gpu_memory_callback.memory_usage
    print("Performance of fine-tuned lora weights:")
    score = model.evaluate(x_test, y_test, verbose=0)
    print(f"  Test accuracy: {score[1]:.5f}")
    avg_time = benchmark(model, input_shape=input_shape)
    print(f"  Avg. inference time (batch_size=1024): {avg_time:.5f}s")
    print(f"  GPU Memory Usage (in GB): {max(lora_memory_usage)}")

    """Saving & loading"""
    model_path = "finetune.keras"
    weights_path = "finetune.weights.h5"
    model.save(model_path)
    model.save_weights(weights_path)
    reloaded_model = saving.load_model(model_path)
    reloaded_score = reloaded_model.evaluate(x_test, y_test, verbose=0)
    print(f"Reloaded model test accuracy: {reloaded_score[1]:.5f}")
    # Load the file into a fresh, non-lora model
    new_model = build_model(num_layers=32, units=1024, use_einsum=use_einsum)
    new_model.build(input_shape)
    if isinstance(
        model.layers[2].dtype_policy, dtype_policies.QuantizedDTypePolicy
    ):
        new_model.quantize("int8")
    new_model.load_weights(weights_path)
    new_model.compile(
        loss="categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    reloaded_score = new_model.evaluate(x_test, y_test, verbose=0)
    print("Non-lora model:")
    print(f"  Test accuracy: {reloaded_score[1]:.5f}")
    avg_time = benchmark(new_model, input_shape=input_shape)
    print(f"  Avg. inference time (batch_size=1024): {avg_time:.5f}s")
    # Try loading a normal checkpoint into a lora model
    new_model.save_weights(weights_path)
    model.load_weights(weights_path)
    reloaded_score = model.evaluate(x_test, y_test, verbose=0)
    print(f"Lora model test accuracy: {reloaded_score[1]:.5f}")


def main():
    args = get_args()

    # Set dtype policy
    dtype = args.dtype_policy
    dtype_policies.dtype_policy.set_dtype_policy(dtype)
    print(f"Global dtype policy: {dtype_policies.dtype_policy.dtype_policy()}")

    if args.type == "train":
        train(args, dtype)
    elif args.type == "finetune":
        model = saving.load_model(args.path)
        finetune(args, model)


if __name__ == "__main__":
    main()

GPT2 & Gemma Finetuning

Try this PR with GPT2/Gemma and LoRA

| model | quantized | compute_dtype | training speed | peak memory usage |
| --- | --- | --- | --- | --- |
| GPT2 | | float32 | 120ms | 4.728GB |
| GPT2 | int8 | float32 | 112ms | 3.994GB |
| GPT2 | | bfloat16 | 86ms | 3.423GB |
| GPT2 | int8 | bfloat16 | 88ms | 2.745GB |
# Usage
python3 qlora.py [--model gpt2_base_en|gpt2_medium_en|gemma_2b_en] [--bfloat16] [--qlora]
qlora.py
import argparse
import json

import kagglehub
import keras_nlp
import tensorflow as tf

import keras
import keras.saving


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="gpt2_base_en",
        choices=["gpt2_base_en", "gpt2_medium_en", "gemma_2b_en"],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--qlora",
        action="store_true",
        help="Whether to use QLoRA-like technique",
    )
    parser.add_argument(
        "--bfloat16",
        action="store_true",
        help="Whether to use mixed bfloat16",
    )
    args = parser.parse_args()
    return args


class GPUMemoryCallback(keras.callbacks.Callback):
    def __init__(self, target_batches, **kwargs):
        super().__init__(**kwargs)
        self.target_batches = target_batches
        self.memory_usage = []

    def _compute_memory_usage(self):
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        # Convert bytes to GB and store in list.
        peak_usage = round(memory_stats["peak"] / (2**30), 3)
        self.memory_usage.append(peak_usage)

    def on_epoch_begin(self, epoch, logs=None):
        self._compute_memory_usage()

    def on_train_batch_begin(self, batch, logs=None):
        if batch in self.target_batches:
            self._compute_memory_usage()

    def on_epoch_end(self, epoch, logs=None):
        self._compute_memory_usage()


def get_optimizer_and_loss():
    optimizer = keras.optimizers.AdamW(
        learning_rate=5e-5,
        weight_decay=0.01,
        epsilon=1e-6,
        global_clipnorm=1.0,  # Gradient clipping.
    )
    # Exclude layernorm and bias terms from weight decay.
    optimizer.exclude_from_weight_decay(var_names=["bias"])
    optimizer.exclude_from_weight_decay(var_names=["gamma"])
    optimizer.exclude_from_weight_decay(var_names=["beta"])

    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return optimizer, loss


if __name__ == "__main__":
    EPOCHS = 1
    RANK = 4
    args = get_args()
    if args.model == "gemma_2b_en":
        kagglehub.login()
    if args.bfloat16:
        keras.mixed_precision.set_global_policy("mixed_bfloat16")

    # Setup dataset
    data = []
    with open("databricks-dolly-15k.jsonl") as file:
        for line in file:
            features = json.loads(line)
            # Filter out examples with context, to keep it simple.
            if features["context"]:
                continue
            # Format the entire example as a single string.
            template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
            data.append(template.format(**features))

    # Only use 16000 training examples, to keep it fast.
    data = data[:16000]

    if args.model == "gemma_2b_en":
        preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset(
            args.model, sequence_length=128
        )
        lora_model = keras_nlp.models.GemmaCausalLM.from_preset(
            args.model, preprocessor=preprocessor
        )
    elif "gpt2" in args.model:
        preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
            args.model, sequence_length=128
        )
        lora_model = keras_nlp.models.GPT2CausalLM.from_preset(
            args.model, preprocessor=preprocessor
        )
    if args.qlora:
        lora_model.quantize("int8")
    lora_model.backbone.enable_lora(RANK)
    lora_model.summary()
    optimizer, loss = get_optimizer_and_loss()
    lora_model.compile(
        optimizer=optimizer, loss=loss, weighted_metrics=["accuracy"]
    )
    gpu_memory_callback = GPUMemoryCallback(
        target_batches=[5, 10, 25, 50, 100, 150, 200, 300, 400, 500],
    )
    lora_model.fit(data, epochs=EPOCHS, callbacks=[gpu_memory_callback])
    lora_model_memory_usage = gpu_memory_callback.memory_usage
    print(f"GPU Memory Usage (in GB): {max(lora_model_memory_usage)}")

@codecov-commenter commented Mar 22, 2024

Codecov Report

Attention: Patch coverage is 80.50847%, with 23 lines in your changes missing coverage. Please review.

Project coverage is 75.91%. Comparing base (576aa8d) to head (8891da4).

| Files | Patch % | Lines |
| --- | --- | --- |
| keras/layers/core/einsum_dense.py | 70.58% | 10 Missing and 10 partials ⚠️ |
| keras/layers/core/dense.py | 94.00% | 1 Missing and 2 partials ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19356      +/-   ##
==========================================
+ Coverage   75.86%   75.91%   +0.04%     
==========================================
  Files         366      366              
  Lines       40479    40532      +53     
  Branches     7869     7884      +15     
==========================================
+ Hits        30711    30768      +57     
+ Misses       8068     8066       -2     
+ Partials     1700     1698       -2     
| Flag | Coverage Δ |
| --- | --- |
| keras | 75.76% <80.50%> (+0.04%) ⬆️ |
| keras-jax | 60.05% <72.03%> (+0.03%) ⬆️ |
| keras-numpy | 54.01% <14.40%> (-0.28%) ⬇️ |
| keras-tensorflow | 61.38% <80.50%> (+0.06%) ⬆️ |
| keras-torch | 60.34% <72.03%> (+0.03%) ⬆️ |


@james77777778 marked this pull request as draft on March 22, 2024, 09:34
@james77777778 (Contributor, Author) commented:
We should wait for #19302

@fchollet (Member) left a review comment:

Thanks for the PR! The custom gradient PR has been merged. Please fix merge conflicts.

Review comments were left on keras/layers/layer.py (outdated) and keras/layers/core/einsum_dense_test.py.
@mattdangerw (Member) commented Mar 22, 2024

Is there a way we could try this out with an LLM? E.g. adapt this guide -> https://ai.google.dev/gemma/docs/lora_tuning

IIUC, this should probably allow us to fine-tune a Gemma 7B checkpoint on less than 16GB of GPU RAM, because LoRA essentially zeroes out the size of the optimizer variables relative to the model weights, and quantizing the weights to int8 should bring us to a little over 8GB.

An end to end test with a massive model might validate a lot.
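
As rough arithmetic behind that estimate (all numbers here are approximations for illustration, not measurements from this PR):

# Back-of-envelope memory estimate; approximate numbers only.
params = 8.5e9                           # Gemma "7B" has roughly 8.5B parameters
int8_weights_gb = params * 1 / 2**30     # ~7.9 GB stored as int8
bf16_weights_gb = params * 2 / 2**30     # ~15.8 GB if kept in bfloat16
lora_params = 20e6                       # illustrative LoRA adapter size
# fp32 LoRA weights plus two Adam slots (m, v) per parameter.
lora_optimizer_gb = lora_params * 4 * 3 / 2**30
print(int8_weights_gb, bf16_weights_gb, lora_optimizer_gb)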

@mattdangerw (Member) commented:

How do we want to handle embeddings and quantization? Embeddings are usually the biggest individual memory hogs for models that might want quantization. We might want to add some quantization support to our Embedding layer (though it does not need to be in this PR!).

https://github.com/google/gemma_pytorch/blob/cf8658c186255379194ba5b62612321eacde1b6b/gemma/model.py#L132-L154
https://pytorch.org/docs/stable/generated/torch.ao.nn.quantized.Embedding.html

@fchollet (Member) commented:

Since Embedding is just a lookup table, presumably just storing the weights in int8 should do it? There's no input scaling either (inputs are integer indices).
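
Presumably something like the following, purely as a hypothetical sketch (not what this PR implements): store the table as int8 plus a per-row scale, gather, then dequantize the gathered rows.

import numpy as np
from keras import ops

# Hypothetical int8 embedding lookup: int8 table + per-row scale.
vocab_size, dim = 1000, 64
table_fp32 = np.random.randn(vocab_size, dim).astype("float32")
scale = np.abs(table_fp32).max(axis=-1, keepdims=True) / 127.0
table_int8 = np.clip(np.round(table_fp32 / scale), -127, 127).astype("int8")

token_ids = ops.convert_to_tensor([[1, 42, 7]], dtype="int32")
rows = ops.take(table_int8, token_ids, axis=0)        # int8 gather
row_scales = ops.take(scale, token_ids, axis=0)       # matching per-row scales
embeddings = ops.cast(rows, "float32") * row_scales   # dequantize on the fly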

@mattdangerw (Member) commented:

> Since Embedding is just a lookup table, presumably just storing the weights in int8 should do it? There's no input scaling either (inputs are integer indices).

Yeah I think it should be pretty simple.

Then there's the question of LoRA + quantization + an Embedding layer. I don't think LoRA + quantization will be that important in practice, as most people probably use LoRA with any embeddings frozen, but it might be worth adding for consistency (since we have enable_lora on the layer).

A contributor left a review comment on this diff excerpt:

ops.cast(kernel, dtype=self.compute_dtype),
kernel_scale,
)
# From https://stackoverflow.com/a/47609896

Wow. Brilliant! Never thought of it this way!! A very beautiful way to exploit Einsteinian Tensor Summation.

@fchollet (Member) left a review:

LGTM, thank you!

@james77777778 (Contributor, Author) commented Mar 23, 2024

> Is there a way we could try this out with an LLM? E.g. adapt this guide -> https://ai.google.dev/gemma/docs/lora_tuning

I encountered an issue when trying this PR with KerasNLP:
When loading a quantized model using keras.saving.load_model(...), KerasNLP drops the saved dtype and always uses the global dtype policy during loading.

This issue will cause the following to fail:

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en", sequence_length=128
)
lora_model = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)
lora_model.quantize("int8")
lora_model.save("model_int8.keras")
reloaded_model = keras.saving.load_model("model_int8.keras")  # <- this line

The above is necessary to accurately record peak GPU memory usage, because TensorFlow doesn't release GPU memory after calling quantize("int8") on a float model.

> How do we want to handle embeddings and quantization?

I can add quantize to Embedding layer in another PR. It should be feasible.

@james77777778 marked this pull request as ready for review on March 23, 2024, 18:14
@james77777778 (Contributor, Author) commented:

I just realized that I need to manually call gc.collect() to release GPU memory after the quantization.
I have added the result of finetuning GPT2. Please refer to #19356 (comment)
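
For reference, the workaround looks roughly like this (the reset_memory_stats call is just one possible way to restart peak tracking and is an assumption on my part, not something this PR adds):

import gc

import tensorflow as tf

import keras

# Sketch of the workaround: quantize, then force a GC pass so TensorFlow can
# actually release the float weights before peak memory is measured.
model = keras.saving.load_model("model_fp32.keras")  # produced by benchmark.py
model.quantize("int8")
gc.collect()
# Optionally reset the peak statistic so later measurements start from here.
tf.config.experimental.reset_memory_stats("GPU:0")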

@fchollet (Member) commented:

Thanks for the update and for the GPT-2 numbers!

> I just realized that I need to manually call gc.collect() to release GPU memory after the quantization.

Is this something we should do in the framework code, then?

Have you been able to try Gemma 2B?

@james77777778 (Contributor, Author) commented Mar 25, 2024

> Is this something we should do in the framework code, then?

This should be a harmless addition. I have updated the code and verified its effectiveness.

> Have you been able to try Gemma 2B?

Unfortunately, I failed to fit Gemma 2B on my rig (12GB 4070...).
I have added a [--model gpt2_base_en|gpt2_medium_en|gemma_2b_en] option to the script. Perhaps someone can try it out to obtain the result.

BTW, the training speed has improved since #19368:

| model | quantized | compute_dtype | training speed | peak memory usage |
| --- | --- | --- | --- | --- |
| GPT2 | | float32 | 120ms | 4.728GB |
| GPT2 | int8 | float32 | 112ms | 3.994GB |
| GPT2 | | bfloat16 | 86ms | 3.423GB |
| GPT2 | int8 | bfloat16 | 88ms | 2.745GB |

Finetuning a quantized model might be faster in float32 and is competitive in bfloat16.

@fchollet (Member) left a review:

LGTM -- thank you for the great contribution. Let's merge!

@fchollet merged commit f40e443 into keras-team:master on Mar 25, 2024
8 checks passed
@james77777778 deleted the qlora-like branch on March 25, 2024, 04:20
Labels: Gemma (Gemma model specific issues), size:L
7 participants