🧹 Cleanup and fixes for TGI (#96)
* feat(Jetstream Pt): set torch default dtype to bfloat16

This is the dtype that should be used on TPUs.
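
A minimal sketch of the effect (standalone illustration, not the generator code itself):

    import torch

    # After this call, newly created floating-point tensors (and freshly
    # initialized model weights) default to bfloat16, the TPU-native
    # compute dtype.
    torch.set_default_dtype(torch.bfloat16)

    x = torch.ones(2, 2)
    print(x.dtype)  # torch.bfloat16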

* fix(warmup): make warmup work for smallest prefill size

Also, add timing checks to the warmup test.
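
For context, prefill sequence lengths are bucketed (the warmup test below rounds 250 up to 256), and the dummy warmup request is now created one token shorter than each bucket so that padding brings it back up to the bucket size. A rough sketch of the rounding, using a hypothetical helper name:

    def next_power_of_two(n: int) -> int:
        # Hypothetical helper: 250 -> 256, 1000 -> 1024.
        return 1 << (n - 1).bit_length()

    bucket = next_power_of_two(250)   # 256
    dummy_len = bucket - 1            # gets padded back up to 256 during prefill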

* chore(pytest): ignore common warnings

* chore(docker): add jetstream installation step to TGI image

* feat(tgi): update TGI version

- TGI version updated from v2.0.3 to
  0ff6ff60ada291840beed63d8bf458d6f9606f7f, which is essentially 2.3.0 plus
  a few fixes to get the v2 proto interface working again.
- The update was needed because debug logs were not working with the old
  version. That made debugging TGI issues painful, and so far the only
  workaround was a hack that forced the debug sink to be re-added in the
  server. With the new version, logs from the Jetstream Pt generator work
  fine.
- There is one drawback 😖: logs on threads spawned by the Pytorch/XLA
  generator became garbled and always appeared, even when debug was off.
  This has been fixed, but the workaround is not very elegant (the logger
  level is passed through an environment variable; see the sketch after
  this list). Since the multithreaded generator is likely to go away soon,
  this should not be a big deal.
- The new TGI version is built with Python 3.11, while optimum-tpu has so
  far targeted Python 3.10, which is what the Pytorch/XLA front page says
  should be used. The image therefore adds Python 3.11 support and installs
  transformers for it; this is the easier option for now, since the TGI
  side and the optimum-tpu server run in separate processes.
- The TGI build process has changed a bit, so the Dockerfile was updated
  accordingly. text-generation-router-v2 is renamed to
  text-generation-router because that is the binary name the launcher
  expects.
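
A minimal sketch of the env-var workaround for the logger level, assuming plain loguru (the actual code in cli.py and the generator, shown below, goes through a thin logger wrapper):

    import os
    import sys

    from loguru import logger

    # In the server CLI (parent process): stash the requested level.
    os.environ["LOGGER_LEVEL_GENERATOR"] = "INFO"

    # In each spawned generator process: rebuild the sink at that level,
    # so the spawned workers stop logging at DEBUG unconditionally.
    level = os.environ.get("LOGGER_LEVEL_GENERATOR", "DEBUG")
    logger.remove()
    logger.add(sys.stdout, format="{message}", filter="text_generation_server", level=level)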

* chore(install): install torch for cpu

This extra step makes the images leaner by avoiding the GPU (CUDA)
dependencies: torch is installed from the CPU-only wheel index (see
requirements.txt below).
tengomucho authored Sep 27, 2024
1 parent 01d3a42 commit f5ad698
Showing 11 changed files with 84 additions and 20 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tpu-tgi-release.yml
@@ -74,7 +74,7 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
TGI_VERSION=v2.0.3
TGI_VERSION=0ff6ff60ada291840beed63d8bf458d6f9606f7f
- name: Generate artifact attestation for TGI
@@ -95,7 +95,7 @@ jobs:
labels: ${{ steps.meta-ie.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
TGI_VERSION=v2.0.3
TGI_VERSION=0ff6ff60ada291840beed63d8bf458d6f9606f7f
target: inference-endpoint


6 changes: 4 additions & 2 deletions Makefile
@@ -19,7 +19,8 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))

.PHONY: build_dist style style_check clean

TGI_VERSION ?= v2.0.3
# This is essentially v2.3.0 plus a fix to support the v2 proto interface
TGI_VERSION ?= 0ff6ff60ada291840beed63d8bf458d6f9606f7f

rwildcard=$(wildcard $1) $(foreach d,$1,$(call rwildcard,$(addsuffix /$(notdir $d),$(wildcard $(dir $d)*))))

@@ -51,7 +52,7 @@ tpu-tgi:

tpu-tgi-ie:
docker build --rm -f text-generation-inference/docker/Dockerfile \
--target inference-endpoints \
--target inference-endpoint \
--build-arg VERSION=$(VERSION) \
--build-arg TGI_VERSION=$(TGI_VERSION) \
-t huggingface/optimum-tpu:$(VERSION)-tgi .
@@ -76,6 +77,7 @@ pypi_upload: ${PACKAGE_DIST} ${PACKAGE_WHEEL}

# Tests
test_installs:
python -m pip install -r requirements.txt
python -m pip install .[tests] -f https://storage.googleapis.com/libtpu-releases/index.html

tests: test_installs
7 changes: 6 additions & 1 deletion install-jetstream-pt.sh
@@ -1,9 +1,14 @@
#!/bin/bash
THIS_DIR=$(dirname "$0")

deps_dir=deps
rm -rf $deps_dir
mkdir -p $deps_dir


# install torch cpu to avoid GPU requirements
pip install -r $THIS_DIR/requirements.txt
cd $deps_dir
pwd
git clone https://github.com/google/jetstream-pytorch.git
cd jetstream-pytorch
git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -97,3 +97,10 @@ known-first-party = ["optimum.tpu"]
markers = [
"is_staging_test",
]
filterwarnings = [
"ignore:Some donated",
"ignore:The given NumPy array is not writable",
"ignore:`do_sample` is set",
"ignore:Device capability of jax",
"ignore:`tensorflow` can conflict",
]
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
# This is not a complete list of dependencies, but it allows installing torch without CUDA support
--index-url https://download.pytorch.org/whl/cpu
torch==2.4.0
26 changes: 19 additions & 7 deletions text-generation-inference/docker/Dockerfile
@@ -8,7 +8,7 @@ RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1

# Build cargo components (adapted from TGI original Dockerfile)
# Note that the build image is aligned on the same Linux version as the base image (Debian bookworm/ Ubuntu 22.04)
FROM lukemathwalker/cargo-chef:latest-rust-1.77-bookworm AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.79-bookworm AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -20,28 +20,32 @@ COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml
COPY --from=tgi /tgi/proto proto
COPY --from=tgi /tgi/benchmark benchmark
COPY --from=tgi /tgi/router router
COPY --from=tgi /tgi/backends backends
COPY --from=tgi /tgi/launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json

COPY --from=tgi /tgi/Cargo.toml Cargo.toml
COPY --from=tgi /tgi/Cargo.lock Cargo.lock
COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml
COPY --from=tgi /tgi/proto proto
COPY --from=tgi /tgi/benchmark benchmark
COPY --from=tgi /tgi/router router
COPY --from=tgi /tgi/backends backends
COPY --from=tgi /tgi/launcher launcher
RUN cargo build --release --workspace --exclude benchmark
RUN cargo build --profile release-opt

# Python base image
FROM ubuntu:22.04 AS base
@@ -85,6 +89,8 @@ ARG VERSION=${VERSION}
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
libpython3.10 \
libpython3.11 \
python3.11 \
git \
gnupg2 \
wget \
@@ -107,17 +113,23 @@ ENV HUGGINGFACE_HUB_CACHE=/data \

COPY . /opt/optimum-tpu

# Install requirements for TGI, that uses python3.11
RUN python3.11 -m pip install transformers==${TRANSFORMERS_VERSION}

# Install requirements for optimum-tpu, then for TGI then optimum-tpu
RUN python3 -m pip install hf_transfer safetensors==${SAFETENSORS_VERSION} && \
python3 -m pip install -e /opt/optimum-tpu[jetstream-pt] \
RUN python3 -m pip install hf_transfer safetensors==${SAFETENSORS_VERSION}
RUN bash /opt/optimum-tpu/install-jetstream-pt.sh
RUN python3 -m pip install -e /opt/optimum-tpu[jetstream-pt] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-f https://storage.googleapis.com/libtpu-releases/index.html

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
COPY --from=builder /usr/src/target/release-opt/text-generation-router-v2 /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# Install python server
COPY --from=pyserver /pyserver/build/dist dist
RUN pip install dist/text_generation_server*.tar.gz
2 changes: 1 addition & 1 deletion text-generation-inference/server/Makefile
@@ -2,7 +2,7 @@
pkg_name := text_generation_server
BUILDDIR ?= $(CURDIR)/build
VERSION ?= 0.0.1
TGI_VERSION ?= v2.0.3
TGI_VERSION ?= 0ff6ff60ada291840beed63d8bf458d6f9606f7f
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
pkg_dir := $(BUILDDIR)/$(pkg_name)
15 changes: 15 additions & 0 deletions text-generation-inference/server/text_generation_server/cli.py
@@ -18,6 +18,8 @@ def serve(
uds_path: str = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
otlp_service_name: str = "text-generation-inference.server",
max_input_tokens: Optional[int] = None,
):
"""This is the main entry-point for the server CLI.
@@ -37,6 +39,10 @@ def serve(
The server logger level. Defaults to *INFO*.
json_output (`bool`):
Use JSON format for log serialization.
otlp_service_name (`str`):
The name of the OTLP service. For now it is ignored.
max_input_tokens (`Optional[int]`):
The maximum number of tokens allowed in the input. For now it is ignored.
"""
if sharded:
raise ValueError("Sharding is not supported.")
@@ -55,6 +61,15 @@
if trust_remote_code is not None:
logger.warning("'trust_remote_code' argument is not supported and will be ignored.")

# TODO: these two parameters are used when the server is started, but they are not used yet, so just inform the
# user about that.
logger.info("'otlp_service_name' argument is not supported and will be ignored.")
logger.info("'max_input_tokens' argument is not supported and will be ignored.")

# This is a workaround to pass the logger level to other threads, it's only used in
# Pytorch/XLA generator.
os.environ["LOGGER_LEVEL_GENERATOR"] = logger_level

# Import here after the logger is added to log potential import exceptions
from optimum.tpu.model import fetch_model

@@ -814,14 +814,16 @@ def _mp_fn(
mailbox = AgentMailbox(root_mailbox)

# re-init logger for each child process
logger_level = os.environ.get("LOGGER_LEVEL_GENERATOR", "DEBUG")
logger.logger.remove()
logger.logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level="DEBUG",
level=logger_level,
backtrace=True,
diagnose=False,
)
logger.info(f'😈 here! {logger_level}')

logger.debug(
f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
@@ -340,8 +340,8 @@ def warmup(self, batch: Batch) -> int:
# Skip all the unsupported lengths
if l > bucket_seq_len:
continue
# create a dummy request with the current sequence length
dummy_request = self._create_dummy_request(l)
# create a dummy request with the current sequence length -1 (so it gets padded up to l)
dummy_request = self._create_dummy_request(l - 1)
# We define few max_new_tokens to request at least one (by prefill) and another by decode.
MAX_NEW_TOKENS = 10
dummy_request.stopping_parameters.max_new_tokens = MAX_NEW_TOKENS
@@ -671,6 +671,7 @@ def from_pretrained(cls, model_path: str, revision: str, max_batch_size: int, ma
logger.warning("Revision is not supported for JetStream/Pytorch engine, ignoring.")
logger.info("Loading model engine (this can take a few minutes).")
start = time.time()
torch.set_default_dtype(torch.bfloat16)
engine = create_engine(
model_path,
max_batch_size,
23 changes: 20 additions & 3 deletions text-generation-inference/tests/test_warmup.py
@@ -1,5 +1,7 @@


from time import time

import pytest
from helpers import create_request, prepare_model
from text_generation_server.auto_generator import AutoGenerator
@@ -14,17 +16,32 @@ def test_warmup_jetstream_pytorch():
model_id = "Maykeye/TinyLLama-v0"

# The maximum sequence length of the model is set to 1000, but warmup will round that up to the next power of two
# in prefill (1024).
sequence_length = 1000
# in prefill (256).
sequence_length = 250

model_path = prepare_model(model_id, sequence_length)
input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
max_new_tokens = 20

generator = AutoGenerator.from_pretrained(
model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
model_path, revision="", max_batch_size=2, max_sequence_length=sequence_length
)
request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
generator.warmup(batch)

# Prepare a new request with different settings. Warmup should have triggered compilation so this can be run
# quickly.
input_text = "What is Deep Learning?"
max_new_tokens = 3
max_tokens = 13
request1 = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
batch = Batch(id=1, requests=[request1], size=1, max_tokens=max_tokens)

start = time()
_generations, new_batch = generator.prefill(batch)
_generations, new_batch = generator.decode([new_batch])
end = time()

# Prefill and decode time should be less than 1 second (rather fast)
assert end - start < 1.0
