From d6f1c61fb6c05b8870f410ddc5ab380df8f8c205 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 30 Sep 2022 00:48:49 +0800 Subject: [PATCH] fix gpt N4C32 dp script bug (#3392) --- tests/test_tipc/benchmark_train.sh | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/test_tipc/benchmark_train.sh b/tests/test_tipc/benchmark_train.sh index 54422ba5f653..0eb23f72b685 100644 --- a/tests/test_tipc/benchmark_train.sh +++ b/tests/test_tipc/benchmark_train.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + source test_tipc/common_func.sh # run benchmark sh @@ -67,6 +82,12 @@ function set_gpu_id(){ echo $seq } +function get_world_size(){ + IFS="C" + arr=($1) + echo ${arr[1]} +} + function get_repo_name(){ IFS=";" cur_dir=$(pwd) @@ -202,10 +223,10 @@ for batch_size in ${batch_size_list[*]}; do # NOTE: Only for GPT for now. if [[ ${model_name} =~ gpt* ]]; then - num_gpu_devices=$[(${#gpu_id}+1)/2] + num_gpu_devices=`get_world_size $device_num` sed_norm_train=$norm_train - global_batch_size=$[$batch_size*$num_gpu_devices] + global_batch_size=$(($batch_size*$num_gpu_devices)) extra_params="--global_batch_size=$global_batch_size --dp_degree=$num_gpu_devices" sed_norm_train="$sed_norm_train $extra_params"