Skip to content

Commit

Permalink
Update JetStream instructions (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
yeandy authored Aug 23, 2024
1 parent 59538fc commit 647ab24
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 48 deletions.
179 changes: 139 additions & 40 deletions docs/online-inference-with-maxtext-engine.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ Follow the steps in [Manage TPU resources | Google Cloud](https://cloud.google.c
## Step 1: Download JetStream and the MaxText github repository

```bash
git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
git clone -b v0.2.2 https://github.com/google/JetStream.git
git clone https://github.com/google/maxtext.git
git clone https://github.com/google/JetStream.git
```

## Step 2: Setup MaxText
## Step 2: Setup MaxText and JetStream

```bash
# Create a python virtual environment for the demo.
Expand All @@ -36,6 +36,12 @@ source .env/bin/activate
# Setup MaxText.
cd maxtext/
bash setup.sh

# Setup JetStream
cd JetStream
pip install -e .
cd benchmarks
pip install -r requirements.in
```

## Step 3: Convert Model Checkpoints
Expand All @@ -45,16 +51,16 @@ You can run the JetStream MaxText Server with Gemma and Llama2 models. This sect
### Use a Gemma model checkpoint

* You can download a [Gemma checkpoint from Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText/variations/7b).
* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`.
* After downloading orbax Gemma checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively.
* `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
* Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`.
* Then, using the following command to convert the Gemma checkpoint into a MaxText compatible unscanned checkpoint.

```bash
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET}
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}

# For gemma-7b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}
```

Note: For more information about the Gemma model and checkpoints, see [About Gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma/Run_Gemma.md).
Expand All @@ -63,25 +69,25 @@ Note: For more information about the Gemma model and checkpoints, see [About Gem
### Use a Llama2 model checkpoint

* You can use a Llama2 checkpoint you have generated or one from [the open source community](https://llama.meta.com/llama-downloads/).
* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`.
* After downloading PyTorch checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively.
* `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
* Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`.
* Then, using the following command to convert the Llama2 checkpoint into a MaxText compatible unscanned checkpoint.

```bash
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET}
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}

# For llama2-7b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}

# For llama2-13b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED}
```

Note: For more information about the Llama2 model and checkpoints, see [About Llama2](https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md).


## Step4: Run the JetStream MaxText server
## Step 4: Run the JetStream MaxText server


### Create model config environment variables for server flags
Expand All @@ -104,8 +110,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=1
export ICI_TENSOR_PARALLELISM=-1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Expand All @@ -122,17 +128,15 @@ export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=1
export ICI_TENSOR_PARALLELISM=-1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
```

#### Create Llama2-13b environment variables for server flags



* Configure the [flags](#jetstream-maxtext-server-flag-descriptions) passing into the JetStream MaxText server

```bash
Expand All @@ -142,8 +146,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=1
export ICI_TENSOR_PARALLELISM=-1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4
Expand Down Expand Up @@ -187,7 +191,8 @@ python MaxText/maxengine_server.py \
Note: these flags are from [MaxText config](https://github.com/google/maxtext/blob/f9e04cdc1eec74a0e648411857c09403c3358461/MaxText/configs/base.yml)


## Step 5: Send test request to JetStream MaxText server
## Step 5: Send a test request to JetStream MaxText server
In a new tab in your terminal, run the following command

```bash
cd ~
Expand All @@ -207,34 +212,125 @@ Response: to be a fan

## Step 6: Run benchmarks with JetStream MaxText server

Note: The JetStream MaxText Server is not running with quantization optimization in Step 3. To get best benchmark results, we need to enable quantization (Please use AQT trained or fine tuned checkpoints to ensure accuracy) for both weights and KV cache, please add the quantization flags and restart the server as following:
Note: The JetStream MaxText Server commands from Step 4 are not running with any quantization optimizations. To get the best benchmark results, we need to enable quantization for weights and KV cache. To do this, first generate AQT trained or fine-tuned checkpoints. Then, add the quantization flags and restart the server.

### Generating a quantized checkpoint

First, define the path to which the quantized checkpoint
```bash
export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-7b-chat
```

There are several different quantization configurations to choose from:

#### int8 DRQ quantized checkpoint
```bash
# Enable int8 quantization for both weights and KV cache
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
```

#### Weights-only int8 quantized checkpoint
```bash
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8w save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
```

#### Mixed precision weight-only quantized checkpoint
First, update the mixed precision config file (`MaxText/configs/quantization/mp_scale.json`) in MaxText repo to the mixed-precision-config defined below.
```
{
".*/query": {"bits": 4, "scale": 0.8},
".*/key": {"bits": 4, "scale": 0.9},
".*/value": {"bits": 8},
".*/out": {"bits": 4},
".*/wi_0": {"bits": 4},
".*/wo": {"bits": 8}
}
```
Then run the following command:
```bash
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMETERS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=intmp
quant_cfg_path=configs/quantization/mp_scale.json save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
```

### Restart the server with quantization flags

#### Set flags

Setting base quantization flags
```bash
# To load an int8 DRQcheckpoint
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true
export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH}
export CHECKPOINT_IS_QUANTIZED=True

# To load an int8 weight-only checkpoint
export QUANTIZATION=int8w
export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH}
export CHECKPOINT_IS_QUANTIZED=True

# To load a Mixed-Precision quantized checkpoint
# If using Mixed-Precision mode, make sure to update the mixed precision config file to the same file as used for quantizing the checkpoint (MaxText/configs/quantization/mp_scale.json)
export QUANTIZATION=intmp
export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH}
export CHECKPOINT_IS_QUANTIZED=True
export QUANT_CFG_PATH=configs/quantization/mp_scale.json
```

The KV-cache is quantized to int8 by using the following config params
```bash
export QUANTIZE_KVCACHE=True
```
If you don't want to quantize the KV-cache, set
```bash
export QUANTIZE_KVCACHE=False
```


#### Restart server
```bash
# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE} \
checkpoint_is_quantized=${CHECKPOINT_IS_QUANTIZED}
```

For the mixed precision quantized model
```bash
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE} \
checkpoint_is_quantized=${CHECKPOINT_IS_QUANTIZED} \
quant_cfg_path=${QUANT_CFG_PATH}
```


### Benchmarking Gemma-7b

Instructions
Expand All @@ -261,11 +357,12 @@ python JetStream/benchmarks/benchmark_serving.py \
--request-rate 5 \
--warmup-mode sampled
```
For details, please see https://github.com/google/JetStream/blob/main/benchmarks/README.md

### Benchmarking Llama2-\*b
### Benchmarking Llama2

```bash
# Same as Gemma-7b except for the tokenizer (must use a tokenizer that matches your model, which should now be tokenizer.llama2).
# The command is the same as that for the Gemma-7b, except for the tokenizer. Since we need to use a tokenizer that matches the model, it should now be tokenizer.llama2.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
Expand All @@ -276,17 +373,19 @@ python JetStream/benchmarks/benchmark_serving.py \
--request-rate 5 \
--warmup-mode sampled
```
For details, please see https://github.com/google/JetStream/blob/main/benchmarks/README.md

## Clean Up

```bash
# Clean up gcs buckets.
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up python virtual environment
rm -rf .env
```
12 changes: 4 additions & 8 deletions jetstream/tools/maxtext/model_ckpt_conversion.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,21 @@ export MODEL=$1
export MODEL_VARIATION=$2
export MODEL_NAME=${MODEL}-${MODEL_VARIATION}

# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET
# Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
# Point these variables to a GCS bucket that you created.
# An example of CHKPT_BUCKET could be: gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION}
export CHKPT_BUCKET=$3
export MODEL_BUCKET=gs://${USER}-maxtext
export MODEL_BUCKET=$4

# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs

# Point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH=gs://${USER}-maxtext-dataset
# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint.
export BASE_OUTPUT_DIRECTORY=$5

export BUCKET_LOCATION=US

# Create three GCS buckets for the demo.
gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true
gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true
gcloud storage buckets create ${DATASET_PATH} --location=${BUCKET_LOCATION} || true

# Convert model checkpoints to MaxText compatible checkpoints.
if [ "$MODEL" == "gemma" ]; then
Expand Down

0 comments on commit 647ab24

Please sign in to comment.