move some auto_parallel args into class AutoTrainingArguments #9155

Merged · 10 commits · Sep 30, 2024
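In short, the diff below replaces the per-script auto-parallel pass flags with a shared AutoTrainingArguments base class that the pretraining scripts now inherit from. A minimal sketch of how a script picks this up after the change; class, parser, and flag names are taken from the diff, while the example command and the min_learning_rate field are only illustrative:

from dataclasses import dataclass, field

from paddlenlp.trainer import AutoTrainingArguments, PdArgumentParser


@dataclass
class PreTrainingArguments(AutoTrainingArguments):
    # Script-specific fields stay local; the shared auto-parallel pass flags
    # (fused_linear_param_grad_add, fuse_allreduce_split_to_reducescatter,
    # eliminate_transpose) are inherited from AutoTrainingArguments.
    min_learning_rate: float = field(
        default=1e-5,
        metadata={"help": "Minimum learning rate decayed to."},
    )


# The inherited flags parse like any other TrainingArguments field, e.g. (illustrative):
#   python run_pretrain_auto.py --output_dir ./ckpt --enable_auto_parallel 1 \
#       --fused_linear_param_grad_add 1
parser = PdArgumentParser(PreTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()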
21 changes: 7 additions & 14 deletions llm/auto_parallel/gpt-3/run_pretrain_auto.py
@@ -26,7 +26,11 @@
import paddle.distributed as dist

from paddlenlp.ops import Topology
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.trainer import (
AutoTrainingArguments,
PdArgumentParser,
get_last_checkpoint,
)
from paddlenlp.trainer.auto_trainer import AutoTrainer
from paddlenlp.trainer.trainer_utils import IntervalStrategy, _get_distributed_seeds
from paddlenlp.transformers import (
@@ -59,8 +63,8 @@ def docstring_decorator(fn):


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class PreTrainingArguments(TrainingArguments):
@add_start_docstrings(AutoTrainingArguments.__doc__)
class PreTrainingArguments(AutoTrainingArguments):
min_learning_rate: float = field(
default=1e-5,
metadata={"help": "Minimum learning rate deacyed to."},
@@ -77,12 +81,6 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
},
)
fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -124,11 +122,6 @@ def __post_init__(self):
self.save_strategy = IntervalStrategy.NO
self.evaluation_strategy = IntervalStrategy.NO

if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

logger.info(self.strategy)


41 changes: 7 additions & 34 deletions llm/auto_parallel/llama/run_pretrain_auto.py
@@ -28,7 +28,11 @@
from paddle.distributed import fleet

from paddlenlp.ops import Topology
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.trainer import (
AutoTrainingArguments,
PdArgumentParser,
get_last_checkpoint,
)
from paddlenlp.trainer.auto_trainer import AutoTrainer
from paddlenlp.trainer.trainer_utils import IntervalStrategy, _get_distributed_seeds
from paddlenlp.transformers import (
@@ -62,8 +66,8 @@ def docstring_decorator(fn):


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class PreTrainingArguments(TrainingArguments):
@add_start_docstrings(AutoTrainingArguments.__doc__)
class PreTrainingArguments(AutoTrainingArguments):
min_learning_rate: float = field(
default=1e-5,
metadata={"help": "Minimum learning rate deacyed to."},
@@ -80,22 +84,6 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
},
)
fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
fuse_allreduce_split_to_reducescatter: bool = field(
default=False,
metadata={"help": "Enable fuse_allreduce_split_to_reducescatter pass."},
)
eliminate_transpose: bool = field(
default=False,
metadata={
"help": "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -137,21 +125,6 @@ def __post_init__(self):
self.save_strategy = IntervalStrategy.NO
self.evaluation_strategy = IntervalStrategy.NO

if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

if self.fuse_allreduce_split_to_reducescatter:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fuse_allreduce_split_to_reducescatter_pass")

if self.eliminate_transpose:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("eliminate_transpose")

logger.info(self.strategy)


14 changes: 3 additions & 11 deletions llm/auto_parallel/llama/run_pretrain_auto_static.py
@@ -32,6 +32,7 @@

from paddlenlp.ops import Topology
from paddlenlp.trainer import (
AutoTrainingArguments,
PdArgumentParser,
Trainer,
TrainingArguments,
@@ -88,7 +89,7 @@ def exec_mode_guard():

@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class PreTrainingArguments(TrainingArguments):
class PreTrainingArguments(AutoTrainingArguments):
min_learning_rate: float = field(
default=1e-5,
metadata={"help": "Minimum learning rate deacyed to."},
@@ -99,12 +100,6 @@ class PreTrainingArguments(TrainingArguments):
"help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate."
},
)
fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
Expand All @@ -127,10 +122,7 @@ class PreTrainingArguments(TrainingArguments):
def __post_init__(self):
super().__post_init__()
assert self.enable_auto_parallel
if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

logger.info(self.strategy)


21 changes: 7 additions & 14 deletions llm/auto_parallel/qwen/run_pretrain_3D_auto.py
@@ -27,7 +27,11 @@
import paddle.distributed as dist
from paddle.distributed import fleet

from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.trainer import (
AutoTrainingArguments,
PdArgumentParser,
get_last_checkpoint,
)
from paddlenlp.trainer.auto_trainer import AutoTrainer
from paddlenlp.trainer.trainer_utils import IntervalStrategy
from paddlenlp.transformers import (
@@ -60,8 +64,8 @@ def docstring_decorator(fn):


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class PreTrainingArguments(TrainingArguments):
@add_start_docstrings(AutoTrainingArguments.__doc__)
class PreTrainingArguments(AutoTrainingArguments):
min_learning_rate: float = field(
default=1e-5,
metadata={"help": "Minimum learning rate deacyed to."},
@@ -78,12 +82,6 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ."
},
)
fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
job_schedule_profiler_start: int = field(
default=-1,
metadata={"help": "The step to start job_schedule_profiler."},
@@ -133,11 +131,6 @@ def __post_init__(self):
self.save_strategy = IntervalStrategy.NO
self.evaluation_strategy = IntervalStrategy.NO

if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

logger.info(self.strategy)


1 change: 1 addition & 0 deletions paddlenlp/trainer/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from .argparser import *
from .auto_training_args import *
from .compression_args import *
from .plugins.timer import *
from .trainer import *
62 changes: 62 additions & 0 deletions paddlenlp/trainer/auto_training_args.py
@@ -0,0 +1,62 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from .training_args import TrainingArguments
from .utils import add_start_docstrings


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class AutoTrainingArguments(TrainingArguments):
"""
Training Arguments for auto_parallel.
"""
Review comment (Collaborator): Code in the main repository should use English comments.

Reply (Contributor, author): Done


fused_linear_param_grad_add: bool = field(
default=False,
metadata={
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
fuse_allreduce_split_to_reducescatter: bool = field(
default=False,
metadata={"help": "Enable fuse_allreduce_split_to_reducescatter pass."},
)
eliminate_transpose: bool = field(
default=False,
metadata={
"help": "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
},
)

def __post_init__(self):
super().__post_init__()
assert self.enable_auto_parallel

if self.fused_linear_param_grad_add:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")

if self.fuse_allreduce_split_to_reducescatter:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("fuse_allreduce_split_to_reducescatter_pass")

if self.eliminate_transpose:
fused_passes = self.strategy.fused_passes
fused_passes.enable = True
fused_passes.fused_passes_list.append("eliminate_transpose")