
Commit

add sharding tensor fusion
FeixLiu committed Sep 22, 2022
1 parent: 1e07bdf · commit: 40c55a8
Showing 2 changed files with 9 additions and 2 deletions.
ppfleetx/models/language_model/utils.py: 3 additions & 1 deletion
@@ -121,8 +121,10 @@ def process_optim_configs(config):

     nranks = dist.get_world_size()
     dp_degree = config['Distributed']['dp_degree']
+    sharding_degree = config['Distributed']['sharding']['sharding_degree']
     if config['Optimizer']['tensor_fusion']:
-        assert nranks == dp_degree, "tensor_fusion only support single card train or data parallel train"
+        assert nranks == dp_degree * sharding_degree, \
+            "tensor_fusion only support single card train or data/sharding parallel train"
 
 
 def process_data_configs(config):
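For illustration, a minimal sketch of a config that satisfies the relaxed check in process_optim_configs: only the keys come from the diff above; the degree values and world size are assumed for this example and are not part of the commit.

# Hypothetical config sketch (not part of the commit); values are assumed.
config = {
    'Distributed': {
        'dp_degree': 2,                      # assumed: 2-way data parallel
        'sharding': {'sharding_degree': 4},  # assumed: 4-way sharding
    },
    'Optimizer': {'tensor_fusion': True},
}

nranks = 8  # assumed world size; the old check (nranks == dp_degree) would reject this
assert nranks == config['Distributed']['dp_degree'] * \
    config['Distributed']['sharding']['sharding_degree'], \
    "tensor_fusion only support single card train or data/sharding parallel train"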
ppfleetx/optims/optimizer.py: 6 additions & 1 deletion
@@ -14,6 +14,7 @@

 import sys
 import paddle
+import paddle.distributed.fleet as fleet
 
 from ppfleetx.utils.tensor_fusion_helper import fused_parameters
 from paddle.optimizer import Adam, AdamW, Momentum
@@ -30,9 +31,13 @@ class FusedAdamW(paddle.optimizer.AdamW):
     def __init__(self, learning_rate, parameters, grad_clip, **config):
         tensor_fusion = config.pop("tensor_fusion", False)
 
+        if paddle.distributed.get_world_size() > 1:
+            hcg = fleet.get_hybrid_communicate_group()
+            sharding_size = hcg.get_sharding_parallel_world_size()
+
         if tensor_fusion:
             self.decay_fused_tensors, self.all_fused_tensors = fused_parameters(
-                parameters)
+                parameters, sharding_size > 1)
             decay_params = [p.name for p in self.decay_fused_tensors]
         else:
             decay_params = [
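For context, a sketch of how the sharding-aware fused_parameters call could be wrapped; this is not the commit's code. The single-card default of sharding_size = 1 is an assumption added here (the commit only assigns sharding_size inside the distributed branch), while the fleet calls and the second argument to fused_parameters follow the diff above.

# Illustrative sketch only; the sharding_size fallback is an assumption.
import paddle
import paddle.distributed.fleet as fleet

from ppfleetx.utils.tensor_fusion_helper import fused_parameters


def build_fused_tensors(parameters):
    sharding_size = 1  # assumed default for non-distributed runs
    if paddle.distributed.get_world_size() > 1:
        hcg = fleet.get_hybrid_communicate_group()
        sharding_size = hcg.get_sharding_parallel_world_size()

    # Per the diff, the second argument tells fused_parameters whether
    # sharding parallelism is active.
    decay_fused, all_fused = fused_parameters(parameters, sharding_size > 1)
    return decay_fused, all_fused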
