-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathDockerfile.t5x
84 lines (71 loc) · 3.38 KB
/
Dockerfile.t5x
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# syntax=docker/dockerfile:1-labs
# Global build arguments. They are declared before the first FROM, so each stage
# that needs one must re-declare it with a bare `ARG <name>`.

# Image that the `mealkit` stage is built on top of.
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:upstream-t5x
# Git identity configured while the distribution is created (commits are authored
# from patches); the resulting global git config is removed again afterwards.
ARG GIT_USER_EMAIL=jax@nvidia.com
ARG GIT_USER_NAME=NVIDIA
# Must be "true" or "false". If "true", the local patches, manifest.yaml and
# create-distribution.sh are refreshed from the jax-toolbox sources (in case they
# were updated). This is useful for development if you run
# `./bump.sh -i manifest.yaml` manually and do not want to trigger a full rebuild
# all the way up to the jax build.
ARG UPDATE_PATCHES=false
# It is common for TE (TransformerEngine) developers to test a different TE against
# the LLM application. This is a knob to override what is in the manifest.
# Accepts git refs from NVIDIA/TransformerEngine or pull requests (pull/$number/head).
# Empty (the default) leaves the manifest's TE reference untouched.
ARG UPDATED_TE_REF=""
# Rosetta and, optionally, patches are pulled from this stage.
# NOTE(review): it is `scratch` (empty) here, so it is presumably replaced at build
# time by a named build context called `jax-toolbox` - confirm against the build scripts.
FROM scratch AS jax-toolbox
###############################################################################
### Download source and add auxiliary scripts
###############################################################################
FROM ${BASE_IMAGE} AS mealkit
# Re-declare the global build args so they are visible inside this stage.
ARG GIT_USER_EMAIL
ARG GIT_USER_NAME
ARG UPDATE_PATCHES
ARG UPDATED_TE_REF
ENV ENABLE_TE=1
# The jax-toolbox stage is bind-mounted at /mnt/jax-toolbox for the duration of this
# RUN only; nothing from it ends up in the layer unless it is copied explicitly below.
# The heredoc delimiter is quoted, so the script is passed to bash verbatim.
# bash flags: -e abort on error, -x trace commands, -u abort on unset variables,
# -o pipefail fail a pipeline if any stage of it fails.
# MANIFEST_FILE is not set in this file; it is assumed to come from the base image's
# environment (with -u the build aborts if it is missing).
RUN --mount=type=bind,target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu -o pipefail
MANIFEST_DIR="$(dirname "${MANIFEST_FILE}")"

# Reject anything other than the two supported values early.
if [[ "${UPDATE_PATCHES}" != "true" && "${UPDATE_PATCHES}" != "false" ]]; then
  echo "UPDATE_PATCHES can only be true or false"
  exit 1
fi

# Optionally refresh patches, manifest and helper scripts from the mounted sources.
if [[ "${UPDATE_PATCHES}" == "true" ]]; then
  cp -r /mnt/jax-toolbox/.github/container/patches "${MANIFEST_DIR}/"
  cp /mnt/jax-toolbox/.github/container/manifest.yaml "${MANIFEST_DIR}/manifest.yaml"
  cp /mnt/jax-toolbox/.github/container/create-distribution.sh "${MANIFEST_DIR}/create-distribution.sh"
  # TODO: remove
  cp /mnt/jax-toolbox/.github/container/pip-finalize.sh /usr/local/bin/
fi

# NOTE(review): if /opt/rosetta already exists in the base image, `cp -r` nests the
# copy as /opt/rosetta/rosetta - confirm the base image does not ship that directory.
cp -r /mnt/jax-toolbox/rosetta /opt/rosetta

# Optionally override the TransformerEngine revision recorded in the manifest.
if [[ -n "${UPDATED_TE_REF}" ]]; then
  TE_INSTALL_DIR=/opt/transformer-engine
  yq e ".transformer-engine.latest_verified_commit = \"${UPDATED_TE_REF}\"" -i "${MANIFEST_FILE}"
  # Install from source instead of pre-built wheel
  sed -i -E 's@( file:///opt/transformer-engine)/dist/[^ ]*@\1@' /opt/pip-tools.d/requirements-te.in
  # NOTE(review): `-a` is git's short form of `--append`, not `--all`; kept as-is to
  # preserve behavior - confirm which one was intended.
  git -C "${TE_INSTALL_DIR}" fetch -a
  if [[ "${UPDATED_TE_REF}" =~ ^pull/ ]]; then
    # pull/<number>/head -> fetch the PR head into a local branch named PR-<number>.
    PR_ID=$(cut -d/ -f2 <<<"${UPDATED_TE_REF}")
    git -C "${TE_INSTALL_DIR}" fetch origin "${UPDATED_TE_REF}:PR-${PR_ID}"
    git -C "${TE_INSTALL_DIR}" checkout "PR-${PR_ID}"
  else
    git -C "${TE_INSTALL_DIR}" checkout "${UPDATED_TE_REF}"
  fi
fi

# Setting the username/email is required to author commits from patches
git config --global user.email "${GIT_USER_EMAIL}"
git config --global user.name "${GIT_USER_NAME}"
bash "${MANIFEST_DIR}/create-distribution.sh" \
  --manifest "${MANIFEST_FILE}" \
  --package t5x
bash "${MANIFEST_DIR}/create-distribution.sh" \
  --manifest "${MANIFEST_FILE}" \
  --package flax
# Remove .gitconfig to avoid end-user authoring commits as the "build user"
rm -f ~/.gitconfig

# Append the extra package index and the editable rosetta install to the t5x
# requirements input file.
echo "--extra-index-url https://developer.download.nvidia.com/compute/redist" >> /opt/pip-tools.d/requirements-t5x.in
echo "-e file:///opt/rosetta" >> /opt/pip-tools.d/requirements-t5x.in
EOF
WORKDIR /opt/rosetta
# Trailing slash marks the destination as a directory, so the script keeps its name.
COPY --from=jax-toolbox rosetta/tests/test-vit.sh /usr/local/bin/
###############################################################################
### Install accumulated packages from the base image and the previous stage
###############################################################################
# `AS` is upper-case to match the other stages and the casing of `FROM`.
FROM mealkit AS final
# pip-finalize.sh is resolved via PATH. The mealkit stage copies it to /usr/local/bin
# only when UPDATE_PATCHES=true; otherwise it is presumably already provided by the
# base image - confirm.
RUN pip-finalize.sh