
[llm]add adam-mini (PaddlePaddle#9542)
* add adam-mini

* fix following comments
lugimzzz authored Dec 16, 2024
1 parent f3ba5b3 commit da41c4f
Showing 10 changed files with 253 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/trainer.md
@@ -691,6 +691,9 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
--optim
Optimizer name, defaults to adamw. (`str`, optional, defaults to `adamw`)
The optimizer to use. (default: adamw)
Possible values (a short usage sketch follows this hunk):
- `"adamw"`
- `"adamw_mini"`
--report_to
Log visualization backend; VisualDL is used by default. (optional, defaults to None, which reports to all)
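Both documented values can also be set programmatically. A minimal sketch, assuming nothing beyond the fields shown in this commit (`output_dir` here is a placeholder):

```python
from paddlenlp.trainer import TrainingArguments

# "adamw_mini" is accepted alongside the default "adamw"; __post_init__
# converts the string into the OptimizerNames enum extended further below.
args = TrainingArguments(output_dir="./checkpoints", optim="adamw_mini")
print(args.optim)  # -> OptimizerNames.ADAMW_MINI (exact repr may vary)
```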
1 change: 1 addition & 0 deletions llm/docs/dpo.md
@@ -119,6 +119,7 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo
- `unified_checkpoint`: whether to use unified checkpoints, defaults to `True`.
- `autotuner_benchmark`: whether to enable the autotuner benchmark, defaults to `False`.
- `benchmark`: whether to enable benchmarking, defaults to `False`.
- `optim`: defaults to `adamw`; supports `adamw` and `adamw_mini`.
### DPO Arguments (DPOArguments)
- `beta`: the beta parameter of the DPO loss, defaults to 0.1.
- `simpo_gamma`: the gamma parameter of the SimPO loss, defaults to 0.5.
1 change: 1 addition & 0 deletions llm/docs/finetune.md
@@ -184,6 +184,7 @@ python merge_lora_params.py \
- `pipeline_parallel_degree`: the pipeline parallel size. (If this is 4 and the model has 12 layers, each pp stage holds 3 layers.) Defaults to -1, which disables pipeline parallelism.
- `sharding_parallel_degree`: the data parallel size for grouped parameter sharding. Defaults to 1, which disables sharded data parallelism.
- `sharding`: whether to use Paddle's Sharding data parallelism; a user-configured option. Supports sharding `stage1`, `stage2` or `stage3`; `stage2` and `stage3` can be combined with `offload`.
- `optim`: defaults to `adamw`; supports `adamw` and `adamw_mini`.
</div>


5 changes: 5 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -1915,6 +1915,11 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:

        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim == OptimizerNames.ADAMW_MINI:
        from ..utils import AdamWMini

        optimizer_cls = AdamWMini
        optimizer_kwargs.update(adam_kwargs)
    else:
        raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
    return optimizer_cls, optimizer_kwargs
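A hedged sketch of how this new branch is reached, assuming `get_optimizer_cls_and_kwargs` is exposed as a static method on `Trainer` as the hunk header suggests:

```python
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.utils import AdamWMini

args = TrainingArguments(output_dir="./out", optim="adamw_mini")
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
assert optimizer_cls is AdamWMini  # the elif branch above was taken
# optimizer_kwargs carries the shared Adam settings (adam_kwargs above).
```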
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -317,6 +317,7 @@ class OptimizerNames(ExplicitEnum):

    ADAMW = "adamw"
    ADAFACTOR = "adafactor"
    ADAMW_MINI = "adamw_mini"


class ShardingOption(ExplicitEnum):
2 changes: 2 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -1018,6 +1018,8 @@ def __post_init__(self):
raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")

self.optim = OptimizerNames(self.optim)
if self.optim == OptimizerNames.ADAMW_MINI and self.tensor_parallel_degree > 1:
raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")

self.use_hybrid_parallel = False

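A small illustration of the new guard, assuming construction reaches `__post_init__` without tripping earlier validation in your environment (values are placeholders):

```python
from paddlenlp.trainer import TrainingArguments

try:
    TrainingArguments(output_dir="./out", optim="adamw_mini", tensor_parallel_degree=2)
except ValueError as err:
    print(err)  # AdamW Mini currently doesn't support tensor parallelism.
```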
1 change: 1 addition & 0 deletions paddlenlp/utils/__init__.py
@@ -21,6 +21,7 @@
from .import_utils import *
from .infohub import infohub
from .initializer import to
from .optimizer import *
from .serialization import load_torch

# hack impl for EagerParamBase to function
151 changes: 151 additions & 0 deletions paddlenlp/utils/optimizer.py
@@ -0,0 +1,151 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import pir
from paddle.base import core, framework
from paddle.base.framework import Variable, in_dynamic_or_pir_mode, in_pir_mode
from paddle.base.libpaddle import DataType
from paddle.optimizer.adamw import AdamW
from paddle.pir import Value


class AdamWMini(AdamW):
    def _add_moments_pows(self, p):
        acc_dtype = p.dtype
        if self._is_dtype_fp16_or_bf16(acc_dtype):
            acc_dtype = DataType.FLOAT32 if in_pir_mode() else paddle.float32

        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        # Adam-mini: the second-moment accumulator is a single scalar per parameter
        # (shape [1]) rather than one entry per element.
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype, shape=[1])
        try:
            type = core.VarDesc.VarType.DENSE_TENSOR
        except AttributeError:
            type = core.VarDesc.VarType.LOD_TENSOR
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9 if isinstance(self._beta1, (Variable, Value)) else self._beta1,
            shape=[1],
            type=type,
            device="cpu",
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999 if isinstance(self._beta2, (Variable, Value)) else self._beta2,
            shape=[1],
            type=type,
            device="cpu",
        )

    def _append_optimize_op(self, block, param_and_grad):
        assert isinstance(block, (framework.Block, pir.Block))
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)
        param = param_and_grad[0]

        # Whether we should do weight decay for the parameter.
        with_decay = True
        if self._apply_decay_param_fun is not None and not self._apply_decay_param_fun(param.name):
            with_decay = False

        moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0])
        moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0])
        beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0])
        beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param_and_grad[0])
        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
        master_weight = self._master_weights[param_and_grad[0].name] if find_master else None
        lr = self._create_param_lr(param_and_grad)
        # create the adamw optimize op
        if in_dynamic_or_pir_mode():
            lr_ratio_ = 1.0 if self._lr_ratio is None else self._lr_ratio(param_and_grad[0])

            _beta1 = self._beta1 if not isinstance(self._beta1, Variable) else self._beta1.item(0)
            _beta2 = self._beta2 if not isinstance(self._beta2, Variable) else self._beta2.item(0)

            found_inf = self._get_auxiliary_var("found_inf") if in_pir_mode() else None
            self.adamw_python(
                param_and_grad[0],
                param_and_grad[1],
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                found_inf,
                _beta1,
                _beta2,
                self._epsilon,
                lr_ratio_,
                self._weight_decay,
                with_decay,
                find_master,
            )
            return None
        else:
            raise NotImplementedError("Not implemented yet.")

    def adamw_python(
        self,
        param,
        grad,
        learning_rate,
        moment1,
        moment2,
        beta1_pow,
        beta2_pow,
        master_weight,
        skip_update,
        beta1,
        beta2,
        epsilon,
        lr_ratio,
        coeff,
        with_decay,
        multi_precision,
    ):
        if skip_update:
            return
        if not with_decay:
            coeff = 0.0
        if not multi_precision:
            master_weight = None
        lr = learning_rate * lr_ratio
        if master_weight is not None:
            p = master_weight
        else:
            p = param
        # Decoupled weight decay, applied directly to the parameter.
        p *= 1.0 - lr * coeff
        mom1 = moment1
        mom2 = moment2

        mom1 = beta1 * mom1 + (1.0 - beta1) * grad
        # Adam-mini: the second moment is the scalar mean of the squared gradients.
        mom2 = beta2 * mom2 + (1.0 - beta2) * (grad * grad).mean()
        denom = mom2.sqrt() / (1.0 - beta2_pow).sqrt() + epsilon
        # Bias-corrected step; note this uses the freshly updated first moment.
        p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow)))
        if master_weight is not None:
            master_weight[:] = p
            param[:] = p.astype(param.dtype)
        else:
            param[:] = p
        # Write the updated state back into the accumulators in place.
        moment1[:] = mom1
        moment2[:] = mom2
        beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
        # TODO: check how this update should be handled.
        return
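For orientation, a minimal usage sketch of the class defined above, assuming it is imported via `paddlenlp.utils` as wired up in `__init__.py` earlier in this commit; the toy layer and hyper-parameters are illustrative only:

```python
import paddle
from paddlenlp.utils import AdamWMini

layer = paddle.nn.Linear(16, 16)  # stand-in for a real model
opt = AdamWMini(learning_rate=1e-3, weight_decay=0.01, parameters=layer.parameters())

loss = layer(paddle.randn([4, 16])).mean()
loss.backward()
opt.step()  # dispatches to adamw_python above in dynamic mode
opt.clear_grad()
```

Because `moment2` is registered with `shape=[1]`, the optimizer keeps a single scalar second moment per parameter, which is where Adam-mini's optimizer-state memory savings come from.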
35 changes: 35 additions & 0 deletions tests/fixtures/llm/adamw_mini.yaml
@@ -0,0 +1,35 @@
finetune:
  base:
    dataset_name_or_path: "./data"
    per_device_train_batch_size: 4
    gradient_accumulation_steps: 4
    per_device_eval_batch_size: 8
    eval_accumulation_steps: 16
    num_train_epochs: 3
    learning_rate: 3e-05
    warmup_steps: 30
    logging_steps: 1
    evaluation_strategy: "epoch"
    save_strategy: "epoch"
    src_length: 1024
    max_length: 2048
    fp16: true
    fp16_opt_level: "O2"
    do_train: true
    do_eval: true
    use_flash_attention: true
    disable_tqdm: true
    load_best_model_at_end: true
    eval_with_do_generation: false
    metric_for_best_model: "accuracy"
    recompute: true
    refined_recompute: "flash_attn:-1"
    save_total_limit: 1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    ignore_save_lr_and_optim: 1
    optim: "adamw_mini"

  default:
    llama:
      model_name_or_path: __internal_testing__/tiny-random-llama
53 changes: 53 additions & 0 deletions tests/llm/test_adamw_mini.py
@@ -0,0 +1,53 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import sys
import unittest

from parameterized import parameterized_class

from tests.testing_utils import argv_context_guard, load_test_config

from .testing_utils import LLMTest


@parameterized_class(
    ["model_dir"],
    [
        ["llama"],
    ],
)
class FinetuneTest(LLMTest, unittest.TestCase):
    config_path: str = "./tests/fixtures/llm/adamw_mini.yaml"
    model_dir: str = None

    def setUp(self) -> None:
        LLMTest.setUp(self)

        sys.path.insert(0, self.model_dir)

    def tearDown(self) -> None:
        LLMTest.tearDown(self)

    def test_finetune(self):
        finetune_config = load_test_config(self.config_path, "finetune", self.model_dir)

        finetune_config["dataset_name_or_path"] = self.data_dir
        finetune_config["output_dir"] = self.output_dir

        with argv_context_guard(finetune_config):
            from run_finetune import main

            main()
