Skip to content

Commit

Permalink
verification
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZijunZhou committed Mar 30, 2024
1 parent bf74b0d commit 68ed383
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
19 changes: 11 additions & 8 deletions jetstream/tools/maxtext/model_ckpt_conversion.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,25 @@ gcloud storage buckets create ${DATASET_PATH} --location=${BUCKET_LOCATION} || t
# Copy the downloaded checkpoints to `CHKPT_BUCKET`.
# Gemma example: gsutil -m cp -r 7b ${CHKPT_BUCKET}
# Llama2 example: gsutil -m cp -r llama-2-7b ${CHKPT_BUCKET}
gsutil -m cp -r $3 ${CHKPT_BUCKET}
sudo gsutil -m cp -r $3 ${CHKPT_BUCKET}

# Convert model checkpoints to MaxText compatible checkpoints.
if [ "$MODEL" == "gemma" ]; then
CONVERT_CKPT_SCRIPT="convert_gemma_chkpt.py"
JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
--base_model_path ${CHKPT_BUCKET} \
--maxtext_model_path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
--model_size ${MODEL_VARIATION}
else
# We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU
pip install torch --index-url https://download.pytorch.org/whl/cpu
CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py"
JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
--base-model-path ${CHKPT_BUCKET} \
--maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
--model-size ${MODEL_VARIATION}
fi

JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
--base-model-path ${CHKPT_BUCKET} \
--maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
--model-size ${MODEL_VARIATION}
echo "Writen MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}"
echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}"

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
Expand All @@ -81,7 +84,7 @@ load_parameters_path=${CONVERTED_CHECKPOINT} \
run_name=${RUN_NAME} \
model_name=${MODEL_NAME} \
force_unroll=true
echo "Writen MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"
echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"

# We will use the unscanned checkpoints by passing `UNSCANNED_CKPT_PATH` into `LOAD_PARAMETERS_PATH` in the following sections.
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items
5 changes: 4 additions & 1 deletion jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs
# Point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://${USER}-maxtext-dataset

# Prepare C4 dataset for fine tuning: https://github.com/allenai/allennlp/discussions/5056
sudo gsutil -u $3 -m cp 'gs://allennlp-tensorflow-datasets/c4/en/3.0.1/*' ${DATASET_PATH}/c4/en/3.0.1/

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items

Expand Down Expand Up @@ -73,7 +76,7 @@ load_parameters_path=${AQT_CKPT} \
run_name=${RUN_NAME} \
model_name=${MODEL_NAME} \
force_unroll=true
echo "Writen MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"
echo "Written MaxText unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints"

# We will use the unscanned checkpoints by passing `UNSCANNED_CKPT_PATH` into `LOAD_PARAMETERS_PATH` in the following sections.
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items

0 comments on commit 68ed383

Please sign in to comment.