[AutoTuner] Add GBS search, gpu memory usage #55466

Merged · 19 commits · Aug 14, 2023
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -122,6 +122,8 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=None):
     """
     micro_batch_size = cur_cfg.get("micro_batch_size", None)
     global_batch_size = tuner_cfg["model_cfg"].get("global_batch_size", None)
+    if global_batch_size == "auto":
+        global_batch_size = cur_cfg["global_batch_size"]
     if global_batch_size:
         local_batch_size = (
             global_batch_size
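When the model config leaves global_batch_size as "auto", the pruner now resolves it from the candidate config being tested. A standalone sketch of that resolution order (the dict literals are illustrative, not from the PR):

    tuner_cfg = {"model_cfg": {"global_batch_size": "auto"}}
    cur_cfg = {"micro_batch_size": 2, "global_batch_size": 16}
    global_batch_size = tuner_cfg["model_cfg"].get("global_batch_size", None)
    if global_batch_size == "auto":
        # fall back to the value carried by the candidate itself
        global_batch_size = cur_cfg["global_batch_size"]
    assert global_batch_size == 16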
10 changes: 8 additions & 2 deletions python/paddle/distributed/auto_tuner/recorder.py
@@ -19,7 +19,7 @@
 import pandas as pd


-class History_recorder:
+class HistoryRecorder:
     # NOTE: increase extensible ability
     def __init__(self) -> None:
         self.history = []
@@ -63,7 +63,9 @@ def store_history(self, path="./history.csv"):
         cols = df.columns.tolist()
         cols.insert(0, cols.pop(cols.index('job_id')))
         df = df.reindex(columns=cols)
-        df = df.drop(columns=['time'])
+        # check if 'time' exists
+        if 'time' in df.columns:
+            df = df.drop(columns=['time'])
         # write to csv
         df.to_csv(self.store_path, index=False)

@@ -79,3 +81,7 @@ def load_history(self, path="./history.csv") -> Tuple[list, bool]:
             reader = csv.reader(f)
             self.history = list(reader)
             return (self.history, err)
+
+    def clean_history(self) -> None:
+        """Clean history."""
+        self.history = []
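Taken together, the recorder now tolerates a missing 'time' column and can be reset between search rounds. A minimal usage sketch (appending to .history directly is an assumption; the class may expose a dedicated add method not shown in this diff):

    recorder = HistoryRecorder()
    recorder.history.append({"job_id": 0, "dp_degree": 2, "metric": 1.234})
    recorder.store_history("./history.csv")                 # job_id becomes the first CSV column
    history, err = recorder.load_history("./history.csv")   # returns (rows, error_flag)
    recorder.clean_history()                                # empty the in-memory list for the next round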
23 changes: 22 additions & 1 deletion python/paddle/distributed/auto_tuner/search.py
@@ -16,7 +16,7 @@
 from abc import ABC, abstractmethod

 from .prune import _PRUNE_FUNC
-from .utils import search_all
+from .utils import gbs_search_all, search_all


 class SearchAlgo(ABC):
@@ -52,3 +52,24 @@ def search_once(self, history_cfgs):
             else:
                 return None
         return new_cfg
+
+
+class GBSSearch(SearchAlgo):
+    def __init__(self, tuner_cfg):
+        super().__init__(tuner_cfg)
+        self.idx = 0
+        self.all_tasks = gbs_search_all(tuner_cfg)
+
+    def search_once(self, history_cfgs):
+        new_cfg = None
+        stop = False
+        while not stop:
+            if self.idx < len(self.all_tasks):
+                new_cfg = self.all_tasks[self.idx]
+                self.idx += 1
+                glb = new_cfg.get("global_batch_size", None)
+                self.tuner_cfg["model_cfg"]["global_batch_size"] = glb
+                stop = not self.prune(self.tuner_cfg, new_cfg, history_cfgs)
+            else:
+                return None
+        return new_cfg
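GBSSearch.search_once pops candidates in order, publishes each candidate's global batch size into tuner_cfg["model_cfg"] so the pruners see a concrete value, and skips pruned configs; None signals exhaustion. A hypothetical driver loop (run_trial is a placeholder, not part of the PR):

    algo = GBSSearch(tuner_cfg)
    history_cfgs = []
    while True:
        cfg = algo.search_once(history_cfgs)
        if cfg is None:
            break              # every candidate was tried or pruned
        run_trial(cfg)         # placeholder: launch training with this config
        history_cfgs.append(cfg)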
9 changes: 7 additions & 2 deletions python/paddle/distributed/auto_tuner/tuner.py
@@ -13,7 +13,7 @@
 # limitations under the License.


-from .utils import default_candidates
+from .utils import default_candidates, gbs_default_candidates


 class AutoTuner:
@@ -29,13 +29,18 @@ def __init__(self, tuner_cfg):
         self.cur_task_id = 1
         self.task_limit = tuner_cfg.get("task_limit", 100)

-        tuner_cfg["candidates"] = default_candidates(tuner_cfg)
         search_algo = tuner_cfg.get("search_algo", "grid")

         if search_algo == "grid":
             from .search import GridSearch

+            tuner_cfg["candidates"] = default_candidates(tuner_cfg)
             self.algo = GridSearch(tuner_cfg)
+        elif search_algo == "gbs":
+            from .search import GBSSearch
+
+            tuner_cfg["candidates"] = gbs_default_candidates(tuner_cfg)
+            self.algo = GBSSearch(tuner_cfg)
         else:
             raise NotImplementedError()
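Selecting the new algorithm comes down to a single key; the other keys below are the ones this diff reads from tuner_cfg (values are illustrative):

    tuner_cfg = {
        "search_algo": "gbs",              # "grid" remains the default
        "num_gpus": 8,
        "nodes": 1,
        "task_limit": 100,
        "model_cfg": {"global_batch_size": "auto"},
    }
    tuner = AutoTuner(tuner_cfg)           # builds GBS candidates and a GBSSearch instance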
208 changes: 199 additions & 9 deletions python/paddle/distributed/auto_tuner/utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import copy
+import csv
 import itertools
 import os
 import re
@@ -320,38 +321,227 @@ def gen_new_args(raw_args, cfg, tuner_cfg):
     return res_args


-def read_log(
+def read_metric_log(
     path, file="workerlog.0", target_metric='step/s'
-) -> Tuple[float, bool]:
-    """For extracting metric from log file."""
+) -> Tuple[float, int]:
+    """
+    return:
+        metric: average metric of last 10 steps
+        err_code:
+            00: no error
+            01: no metric
+            10: out of memory
+    """
+    err_code = 0
     target_file = path + "/" + file
     if not os.path.exists(target_file):
-        return (0.0, True)
+        return (0.0, 1)
     with open(target_file, "r") as f:
         # read file
         re_metric_pattern = (
             target_metric + r":* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *" + target_metric
         )

+        re_out_of_memory_pattern = r"Out of memory"
+        out_of_memory_flag = 0
         metric_list = []
         lines = f.readlines()
         for line in lines:
             metric = re.findall(re_metric_pattern, line)
+            out_of_memory = re.findall(
+                re_out_of_memory_pattern, line, re.IGNORECASE
+            )
             if metric:
                 metric_list.append(float(metric[0][0]))
+            if out_of_memory:
+                out_of_memory_flag = 1

+        if out_of_memory_flag:
+            metric_ave = 0.0
+            err_code = err_code | (out_of_memory_flag << 1)
         if not metric_list:
             metric_ave = 0.0
-            flag = True
+            err_code = err_code | 1
         elif len(metric_list) < 10:
             metric_ave = metric_list[-1]
-            flag = False
         elif len(metric_list) < 20:
             metric_ave = sum(metric_list[9:]) / (len(metric_list[9:]))
-            flag = False
         else:
             metric_ave = sum(metric_list[-10:]) / 10
-            flag = False
         # round to 5 decimal places
         metric_ave = round(metric_ave, 5)
-        res = metric_ave, flag
+        res = metric_ave, err_code
     return res
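One subtlety worth noting: the pattern accepts both "step/s: 2.5" and "2.5 step/s", but float(metric[0][0]) only reads the first capture group, so it assumes the metric-then-value form; in the value-then-metric form the number lands in the third group. A quick check:

    import re
    pat = r"step/s:* *(\d+(\.\d*)?)|(\d+(\.\d*)?) *step/s"
    re.findall(pat, "speed: step/s: 2.5")[0]   # ('2.5', '.5', '', '') -> value at index 0
    re.findall(pat, "speed: 2.5 step/s")[0]    # ('', '', '2.5', '.5') -> value at index 2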


+def read_memory_log(path, file) -> Tuple[float, bool]:
+    log_path = os.path.join(path, file)
+    if not os.path.exists(log_path):
+        return (0.0, True)
+    memory_used = []
+    utilization_gpu = []
+    indexs = []
+
+    with open(log_path, 'r') as f:
+        reader = csv.reader(f)
+        flag = False
+        # skip headers
+        while not flag:
+            # advance one row until the header row is found
+            row = next(reader)
+            if len(row) == 6 and 'memory_used' in row:
+                flag = True
+        for row in reader:
+            # a six-column row is a utilization data row
+            if len(row) == 6:
+                index, util_gpu, _, mem_used, _, _ = row
+                indexs.append(int(index))
+                memory_used.append(int(mem_used))
+                utilization_gpu.append(int(util_gpu))
+    return max(memory_used), False
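read_memory_log assumes a six-column CSV whose header row contains a 'memory_used' cell, with the GPU index in column 0, utilization in column 1 and memory in column 3; the writer side (presumably an nvidia-smi query loop producing 0.gpu.log) is not part of this diff. A fragment the parser would accept, with placeholder names for the unused columns:

    index,utilization_gpu,c2,memory_used,c4,c5
    0,87,0,30536,0,0
    0,88,0,30720,0,0

The function returns the peak of the memory column, 30720 here.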


+def read_log(
+    path,
+    metric_file="workerlog.0",
+    target_metric='step/s',
+    memory_file="0.gpu.log",
+) -> Tuple[float, float, int]:
+    """
+    extract metric and max memory usage from log file
+    return:
+        metric: average metric of last 10 steps
+        memory: max memory used
+        err_code: 00: no error, 01: no metric, 10: out of memory, 100: no memory log
+    """
+    err_code = 0
+    # check out of memory
+    for root, dirs, files in os.walk(path):
+        for file in files:
+            if not file.startswith("workerlog"):
+                continue
+            metric, metric_flag = read_metric_log(path, file, target_metric)
+            if metric_flag:
+                err_code = (metric_flag & 2) | err_code
+
+    # read metric
+    res_metric, metric_flag = read_metric_log(path, metric_file, target_metric)
+    err_code = metric_flag | err_code
+    # check max memory usage
+    try:
+        res_memory, memory_flag = read_memory_log(path, memory_file)
+        err_code = (memory_flag << 2) | err_code
+    except:
+        res_memory = 0.0
+        err_code = (1 << 2) | err_code
+    return res_metric, res_memory, err_code
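A consumer sketch tying the three err_code bits together (the paths and the memory unit are assumptions, since the log writer is not shown here):

    metric, peak_mem, err = read_log("./log", metric_file="workerlog.0", memory_file="0.gpu.log")
    if err & 0b010:
        print("out of memory: prune larger batch sizes")
    elif err & 0b001:
        print("no metric found in the worker log")
    else:
        print(f"{metric} step/s at a peak of {peak_mem} (memory units as logged)")
    if err & 0b100:
        print("memory log missing or unreadable; peak defaults to 0.0")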


+def three_mul_combinations(target):
+    """Return the combinations of three numbers whose product is target."""
+    results = []
+    for i in range(1, target // 3 + 1):
+        if target % i == 0:
+            for j in range(i, target // 2 + 1):
+                if (target // i) % j == 0:
+                    results.append((i, j, target // i // j))
+    return results
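For example, the factor triples for 8 (i never exceeds j, while the last factor is simply the remainder):

    >>> three_mul_combinations(8)
    [(1, 1, 8), (1, 2, 4), (1, 4, 2), (2, 2, 2), (2, 4, 1)]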


+def gbs_dp_mp_pp_candidates(tuner_cfg, num_gpus, num_nodes):
+    """Return roughly balanced candidates for dp, mp and pp degrees."""
+
+    start = round(num_gpus ** (1 / 3))
+
+    # find factors that can be evenly distributed
+    for i in range(start, 0, -1):
+        if num_gpus % i == 0:
+            remaining = num_gpus // i
+            # find the square root as a factor for the remaining part
+            j = round(remaining**0.5)
+            while remaining % j != 0:
+                j -= 1
+            return i, j, remaining // j
+
+    raise ValueError("Cannot distribute GPUs equally")
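Worked through for common GPU counts (num_nodes and tuner_cfg are unused in the body shown here):

    >>> gbs_dp_mp_pp_candidates({}, num_gpus=8, num_nodes=1)
    (2, 2, 2)
    >>> gbs_dp_mp_pp_candidates({}, num_gpus=16, num_nodes=2)
    (2, 2, 4)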


+def gbs_default_candidates(tuner_cfg):
+    """Return the default candidates for every hyperparameter the user set to 'auto'."""
+    candidates = {}
+    num_gpus = tuner_cfg["num_gpus"]
+    num_nodes = tuner_cfg["nodes"]
+    assert num_gpus > 0
+    global_batch_size = tuner_cfg.get("model_cfg", {}).get(
+        "global_batch_size", "auto"
+    )
+    if global_batch_size == "auto":
+        dp_candidate, mp_candidate, pp_candidate = gbs_dp_mp_pp_candidates(
+            tuner_cfg, num_gpus, num_nodes
+        )
+        sharding_dgree_candidate = dp_candidate
+        candidates["dp_degree"] = [1]
+        candidates["mp_degree"] = [mp_candidate]
+        candidates["pp_degree"] = [pp_candidate]
+        candidates["sharding_degree"] = [sharding_dgree_candidate]
+        candidates["sharding_stage"] = [1]
+        candidates["use_recompute"] = [False]
+        candidates["recompute_granularity"] = [None]
+        candidates["micro_batch_size"] = [2**i for i in range(0, 10)]
+        candidates["global_batch_size"] = [
+            pp_candidate * dp_candidate * e
+            for e in candidates["micro_batch_size"]
+        ]
+    return candidates
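On 8 GPUs this pins dp_degree to [1], moves the data-parallel factor into sharding_degree, and derives one global batch size per micro batch size:

    cfg = {"num_gpus": 8, "nodes": 1, "model_cfg": {"global_batch_size": "auto"}}
    candidates = gbs_default_candidates(cfg)
    candidates["mp_degree"], candidates["pp_degree"], candidates["sharding_degree"]
    # ([2], [2], [2])
    candidates["micro_batch_size"]     # [1, 2, 4, ..., 512]
    candidates["global_batch_size"]    # [4, 8, 16, ..., 2048], i.e. pp * dp * mbs with dp = 2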


+def gbs_search_all(tuner_cfg):
+    """Permute the candidates of all hyper params."""
+    candidates = tuner_cfg["candidates"]
+    # Order: dp -> mp -> pp -> mbs -> sharding -> recompute
+    dp_degree_candidates = candidates["dp_degree"]
+    mp_degree_candidates = candidates["mp_degree"]
+    pp_degree_candidates = candidates["pp_degree"]
+    mbs_candidates = candidates["micro_batch_size"]
+    sharding_stage_candidates = candidates["sharding_stage"]
+    sharding_degree_candidates = candidates["sharding_degree"]
+    use_recompute_candidates = candidates["use_recompute"]
+    recompute_granularity_candidates = candidates["recompute_granularity"]
+    # gbs_candidates = candidates["global_batch_size"]
+    all_cfgs = list(
+        itertools.product(
+            dp_degree_candidates,
+            mp_degree_candidates,
+            pp_degree_candidates,
+            mbs_candidates,
+            sharding_degree_candidates,
+            sharding_stage_candidates,
+            use_recompute_candidates,
+            recompute_granularity_candidates,
+            # gbs_candidates,
+        )
+    )
+    mapping = {
+        0: "dp_degree",
+        1: "mp_degree",
+        2: "pp_degree",
+        3: "micro_batch_size",
+        5: "sharding_stage",
+        4: "sharding_degree",
+        6: "use_recompute",
+        7: "recompute_granularity",
+        # 8: "global_batch_size",
+    }
+    new_all_cfgs = []
+    for cfg in all_cfgs:
+        new_cfg = {}
+        for idx, val in enumerate(cfg):
+            new_cfg[mapping[idx]] = val
+        new_cfg["global_batch_size"] = (
+            new_cfg["pp_degree"]
+            * new_cfg["dp_degree"]
+            * new_cfg["micro_batch_size"]
+        )
+        new_all_cfgs.append(new_cfg)
+    return new_all_cfgs
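With the 8-GPU candidates above, the product collapses to ten configs, one per micro batch size; note that the attached global_batch_size is recomputed here as pp_degree * dp_degree * micro_batch_size, with dp_degree = 1:

    tuner_cfg = {"candidates": candidates}      # candidates from the sketch above
    cfgs = gbs_search_all(tuner_cfg)
    len(cfgs)    # 10
    cfgs[0]
    # {'dp_degree': 1, 'mp_degree': 2, 'pp_degree': 2, 'micro_batch_size': 1,
    #  'sharding_degree': 2, 'sharding_stage': 1, 'use_recompute': False,
    #  'recompute_granularity': None, 'global_batch_size': 2}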
12 changes: 11 additions & 1 deletion python/paddle/distributed/launch/controllers/master.py
@@ -260,7 +260,17 @@ def register_heartbeat(self, job_id, pod_id, ttl=10):
                 delete_success = True
             except:
                 time.sleep(1)
-        lease = self.client.lease(ttl)
+
+        if self.ctx.is_auto_tuner_mode():
+            lease_success = False
+            while not lease_success:
+                try:
+                    lease = self.client.lease(ttl)
+                    lease_success = True
+                except:
+                    time.sleep(1)
+        else:
+            lease = self.client.lease(ttl)

         # self.client.delete_prefix(self.job_prefix)
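Context for the change: in auto-tuner mode a single launcher drives many trials back to back, so a transient etcd failure while (re)acquiring the lease should not abort the whole sweep; the lease call is therefore retried until it succeeds, mirroring the delete loop above it. The same pattern as a reusable helper, purely a sketch and not part of this PR:

    import time

    def retry_forever(op, delay=1):
        # call op() until it stops raising, sleeping between attempts
        while True:
            try:
                return op()
            except Exception:
                time.sleep(delay)

    # usage: lease = retry_forever(lambda: self.client.lease(ttl))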
