shard_optimizer and ShardOptimizer API
FeixLiu committed Nov 29, 2023
1 parent a4a66c8 commit 2c2ac66
Showing 5 changed files with 344 additions and 8 deletions.
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -84,6 +84,7 @@
dtensor_from_fn,
reshard,
shard_layer,
ShardOptimizer,
)

from .fleet import BoxPSDataset # noqa: F401
@@ -157,4 +158,5 @@
"Shard",
"Replicate",
"Partial",
"ShardOptimizer",
]
152 changes: 151 additions & 1 deletion python/paddle/distributed/auto_parallel/api.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Callable

import paddle
@@ -406,3 +406,153 @@ def replicate_layer_params_and_buffers(
"`paddle.distributed.shard_layer` only supports dynamic graph mode "
"now. It will be supported for static graph mode later."
)


class ShardOptimizer:
"""
Wrap the global-view optimizer into a distributed view.
The `shard_fn` should have the following signature:
def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator
Args:
optimizer (paddle.optimizer.Optimizer): The optimizer to be sharded.
shard_fn (Callable, optional): The function used to shard accumulators. If not specified,
the dist attr of each parameter is simply passed down to its accumulators.
Method:
step(): same as optimizer.step()
Examples:
.. code-block:: python
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.ShardOptimizer(opt)
>>> for _ in range(5):
...     loss = layer(batch)
...     loss.backward()
...     opt.step()
...     opt.clear_grad()
>>> # This case needs to be executed in a multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""

def __init__(self, optimizer, shard_fn=None):
assert (
paddle.in_dynamic_mode()
), "`paddle.distributed.ShardOptimizer` only supports dynamic mode for now."
assert (
optimizer is not None
), "The argument `optimizer` cannot be empty."
assert isinstance(
optimizer, paddle.optimizer.AdamW
), "`paddle.distributed.ShardOptimizer` only supports AdamW optimizer for now."

self.target_block = (
paddle.base.framework.default_main_program().global_block()
)
optimizer.helper = paddle.base.layer_helper.LayerHelper(
optimizer.__class__.__name__
)
self._inner_opt = optimizer
self._shard_fn = shard_fn

def _shard_accumulator(self, param):
# create the accumulators
self._inner_opt._create_accumulators(self.target_block, [param])

target_name = param.name
if param.name in self._inner_opt._master_weights.keys():
target_name = self._inner_opt._master_weights[param.name].name

# shard the accumulators
for key in self._inner_opt._accumulators.keys():
accumulator = self._inner_opt._accumulators[key][target_name]
if accumulator.is_dist():
continue
if self._shard_fn is not None:
self._inner_opt._accumulators[key][
target_name
] = self._shard_fn(key, param, accumulator)
else:
if param.is_dist():
if 'beta' not in key:
# If param is a dist tensor, keep its shard info
# for all accumulators except beta.
placements = param.placements
else:
# The beta accumulators should be replicated across the param's mesh.
placements = [
dist.Replicate()
for _ in range(len(param.process_mesh.shape))
]
self._inner_opt._accumulators[key][
target_name
] = shard_tensor(
accumulator,
mesh=param.process_mesh,
placements=placements,
)

def _shard_optimizer(self, parameter_list):
if parameter_list is not None:
for p in parameter_list:
self._shard_accumulator(p)
else:
if not isinstance(self._inner_opt._parameter_list[0], dict):
for p in self._inner_opt._parameter_list:
self._shard_accumulator(p)
else:
for param_group in self._inner_opt._param_groups:
params = param_group['params']
for p in params:
self._shard_accumulator(p)

def step(self):
if not isinstance(self._inner_opt._parameter_list[0], dict):
params_grads = []
parameter_list = []
for param in self._inner_opt._parameter_list:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
parameter_list.append(param)
params_grads.append((param, grad_var))
self._shard_optimizer(parameter_list=parameter_list)
self._inner_opt._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)
else:
for param_group in self._inner_opt._param_groups:
params_grads = defaultdict(lambda: [])
parameter_list = []
for param in param_group['params']:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
parameter_list.append(param)
params_grads['params'].append((param, grad_var))
params_grads.update(
{k: v for k, v in param_group.items() if k != 'params'}
)
self._shard_optimizer(parameter_list=parameter_list)
self._inner_opt._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)

def __getattr__(self, item):
return getattr(self._inner_opt, item)
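
For illustration, here is a minimal sketch of a custom shard_fn matching the signature documented above. It simply reproduces the default accumulator sharding (moment accumulators follow the parameter's placements, beta accumulators are replicated); the function name my_shard_fn is illustrative, while dist.shard_tensor and dist.Replicate are the same APIs used in this commit.

import paddle.distributed as dist


def my_shard_fn(accumulator_name, param, accumulator):
    # leave accumulators of non-dist parameters untouched
    if not param.is_dist():
        return accumulator
    if 'beta' in accumulator_name:
        # beta pow accumulators are replicated across the param's mesh
        placements = [
            dist.Replicate() for _ in range(len(param.process_mesh.shape))
        ]
    else:
        # moment accumulators inherit the parameter's placements
        placements = param.placements
    return dist.shard_tensor(accumulator, param.process_mesh, placements)


# usage sketch:
# opt = paddle.optimizer.AdamW(parameters=layer.parameters())
# opt = dist.ShardOptimizer(opt, my_shard_fn)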
12 changes: 5 additions & 7 deletions test/auto_parallel/semi_auto_parallel_shard_optimizer.py
@@ -79,14 +79,12 @@ def test_adamw_mp(self):
opt.clear_grad()
for key in opt._accumulators.keys():
for k, v in opt._accumulators[key].items():
if 'momentum' in key:
if 'moment' in key:
assert opt._accumulators[key][k].is_dist()
if 'w' in k:
assert opt._accumulators[key][k].shape == [10, 10]
assert opt._accumulators[key][k]._local_shape == [10, 5]
else:
assert opt._accumulators[key][k].shape == [10]
assert opt._accumulators[key][k]._local_shape == [5]
assert (
opt._accumulators[key][k].shape[-1]
== opt._accumulators[key][k]._local_shape[-1] * 2
)
self.check_tensor_eq(self.weight, linear.weight.numpy())
self.check_tensor_eq(self.bias, linear.bias.numpy())

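For context, a small sketch of why the updated assertion holds on a 2-card mesh when a tensor is sharded along its last dimension (the mesh and placements mirror the tests in this commit; it needs a multi-card launch to actually run):

import paddle
import paddle.distributed as dist

# 1-D mesh over two ranks, as used by the tests in this commit
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
w = paddle.rand(shape=[10, 10])
# shard the last dimension across the two ranks
w = dist.shard_tensor(w, mesh, [dist.Shard(1)])
# the global shape stays [10, 10]; each rank holds a [10, 5] local shard,
# so w.shape[-1] == w._local_shape[-1] * 2
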
176 changes: 176 additions & 0 deletions test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py
@@ -0,0 +1,176 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist


class TestSemiAutoParallelShardOptimizerAPI:
def __init__(self):
self._backend = os.getenv("backend")
self._seed = eval(os.getenv("seed"))
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

def check_tensor_eq(self, a, b, rtol=1e-05, atol=0, verbose=True):
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, verbose=verbose)

def get_single_card_rst(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
for _ in range(5):
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.weight = linear.weight.numpy()
self.bias = linear.bias.numpy()

def shard_layer_fn(self, layer_name, layer, process_mesh):
layer.weight = dist.shard_tensor(
layer.weight, process_mesh, [dist.Shard(1)]
)
layer.bias = dist.shard_tensor(
layer.bias, process_mesh, [dist.Shard(0)]
)

def test_opt(self, opt):
for key in opt._accumulators.keys():
for k, v in opt._accumulators[key].items():
assert opt._accumulators[key][k].is_dist()
if 'moment' in key:
assert (
opt._accumulators[key][k].shape[-1]
== opt._accumulators[key][k]._local_shape[-1] * 2
)
else:
assert opt._accumulators[key][k].shape == [1]
assert opt._accumulators[key][k]._local_shape == [1]

def test_shard_optimizer_mp(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.ShardOptimizer(opt)
for _ in range(5):
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.test_opt(opt)
self.check_tensor_eq(self.weight, linear.weight.numpy())
self.check_tensor_eq(self.bias, linear.bias.numpy())

def test_shard_optimizer_from_non_shard_layer(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.ShardOptimizer(opt)
for _ in range(5):
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.check_tensor_eq(self.weight, linear.weight.numpy())
self.check_tensor_eq(self.bias, linear.bias.numpy())

def shard_opt_fn(self, accumulator_name, param, accumulator):
if param.is_dist():
if 'beta' not in accumulator_name:
placements = param.placements
else:
placements = [
dist.Replicate()
for _ in range(len(param.process_mesh.shape))
]
return dist.shard_tensor(
accumulator, param.process_mesh, placements
)
return accumulator

def test_shard_optimizer_shard_fn(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.ShardOptimizer(opt, self.shard_opt_fn)
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.test_opt(opt)

def test_shard_optimizer_master_params(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10])
linear = paddle.amp.decorate(linear, level="O2", dtype="float16")
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
opt = paddle.optimizer.AdamW(
parameters=linear.parameters(), multi_precision=True
)
opt = dist.ShardOptimizer(opt)
loss = linear(batch)
loss.backward()
opt.step()
self.test_opt(opt)
for k, v in opt._master_weights.items():
assert v.dtype == paddle.float32
assert v.is_dist()
assert v.shape[-1] == v._local_shape[-1] * 2

def test_shard_optimizer_params_group(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
linear.weight.optimize_attr = {'lr': 1}
linear.bias.optimize_attr = {'lr': 1}
params_group = [{'params': linear.weight}, {'params': linear.bias}]
opt = paddle.optimizer.AdamW(parameters=params_group)
opt = dist.ShardOptimizer(opt)
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.test_opt(opt)

def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
elif self._backend == "gpu":
paddle.set_device("gpu:" + str(dist.get_rank()))
else:
raise ValueError("Only support cpu or gpu backend.")

self.get_single_card_rst()
self.test_shard_optimizer_params_group()
self.test_shard_optimizer_master_params()
self.test_shard_optimizer_shard_fn()
if self._backend == "gpu":
self.test_shard_optimizer_mp()
self.test_shard_optimizer_from_non_shard_layer()


if __name__ == '__main__':
TestSemiAutoParallelShardOptimizerAPI().run_test_case()
10 changes: 10 additions & 0 deletions test/auto_parallel/test_semi_auto_parallel_single_strategy.py
@@ -84,6 +84,16 @@ def test_shard_optimizer(self):
user_defined_envs=envs,
)

def test_shard_optimizer_api(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_auto_parallel_shard_optimizer_api.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()
