Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generation using Jetstream Pytorch #94

Merged
merged 6 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch

on:
push:
branches: [ main ]
paths:
- "text-generation-inference/**"
pull_request:
branches: [ main ]
paths:
- "text-generation-inference/**"

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
do-the-job:
name: Run TGI tests - Jetstream Pytorch
runs-on: optimum-tpu
container:
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
options: --shm-size "16gb" --ipc host --privileged
env:
PJRT_DEVICE: TPU
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Build and test TGI server
run: |
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test_jetstream
11 changes: 0 additions & 11 deletions .github/workflows/test-pytorch-xla-tpu-tgi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
name: Run TGI tests
runs-on: optimum-tpu
container:
# Use a nightly image that works with TPU (release was not working)
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
options: --shm-size "16gb" --ipc host --privileged
env:
Expand All @@ -31,13 +30,3 @@ jobs:
- name: Build and test TGI server
run: |
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test

# Use a different step to test the Jetstream Pytorch version, to avoid conflicts with torch-xla[tpu]
- name: Install and test TGI server (Jetstream Pytorch)
run: |
pip install -U .[jetstream-pt] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html
JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
pytest -sv text-generation-inference/tests -k jetstream
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,6 @@ dmypy.json
*.pt

.vscode
.idea/
.idea/

jetstream-pt-deps
14 changes: 13 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ $(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)
python -m build

clean:
rm -rf dist
rm -rf dist deps
make -C text-generation-inference/server/ clean

tpu-tgi:
Expand Down Expand Up @@ -87,6 +87,18 @@ tgi_server:
make -C text-generation-inference/server clean
VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server

jetstream_requirements:
bash install-jetstream-pt.sh
python -m pip install .[jetstream-pt] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html

tgi_test_jetstream: test_installs jetstream_requirements tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
-exec python -m pip install --force-reinstall {} \;
JETSTREAM_PT=1 python -m pytest -sv text-generation-inference/tests -k jetstream

tgi_test: test_installs tgi_server
find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \
-exec python -m pip install --force-reinstall {} \;
Expand Down
13 changes: 13 additions & 0 deletions install-jetstream-pt.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
deps_dir=deps
rm -rf $deps_dir
mkdir -p $deps_dir
cd $deps_dir
pwd
git clone https://github.com/google/jetstream-pytorch.git
cd jetstream-pytorch
git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921
git submodule update --init --recursive
# We cannot install in a temporary directory because the directory should not be deleted after the script finishes,
# because it will install its dependendencies from that directory.
pip install -e .
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ build-backend = "setuptools.build_meta"
[project.optional-dependencies]
tests = ["pytest", "safetensors"]
quality = ["black", "ruff", "isort"]
# Jetstream/Pytorch support is experimental for now, requires installation from fixed commit.
# Jetstream/Pytorch support is experimental for now, it needs to be installed manually.
# Pallas is pulled because it will install a compatible version of jax[tpu].
jetstream-pt = [
"jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@ec4ac8f6b180ade059a2284b8b7d843b3cab0921",
"jetstream-pt",
"torch-xla[pallas] == 2.4.0"
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,7 @@ def instantiate_model_from_repo_id(
env.device = "meta"
model = create_model(model_dir, env)
weights = fetch_models._load_weights(model_dir)
updated_keys = model.get_hf_names_to_real_name()
for name, updated in updated_keys.items():
if name in weights:
val = weights.pop(name)
weights[updated] = val
weights = model.convert_hf_weights(weights)

model.load_state_dict(weights, assign=True, strict=False)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import logging
import os
import time
from enum import Enum
from typing import List, Optional, Tuple
Expand All @@ -9,7 +10,7 @@
import numpy as np
import torch
import torch_xla2
from jetstream.engine.token_utils import pad_tokens, take_nearest_length, DEFAULT_PREFILL_BUCKETS
from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS, pad_tokens, take_nearest_length
from jetstream_pt.engine import PyTorchEngine
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase
Expand Down Expand Up @@ -330,6 +331,9 @@ def warmup(self, batch: Batch) -> int:
# Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
# batch sizes and sequence lengths.
seq_len = self.model.config.sequence_length
if os.environ.get("SKIP_WARMUP", "0") == "1":
logger.debug("Skipping warmup")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this is used mostly for debug only? Or can it be turned on for other reasons? In the later case I would use logger.warning if not, logger.debug is fine

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say it's only for debugging. Warmup will check if model can fit in memory and prepare inference so prefill and decode is rather fast afterwards, but it can last around 4-5 minutes, so it can be annoying for debugging the container.

return batch_size * seq_len
bucket_seq_len = take_nearest_length(DEFAULT_PREFILL_BUCKETS, seq_len)
decode_done = False
for l in reversed(DEFAULT_PREFILL_BUCKETS):
Expand Down
Loading