Skip to content

Commit

Permalink
[XPU] Support empty_cache on XPUs (#9789)
Browse files Browse the repository at this point in the history
* [XPU] Support empty_cache on XPUs

* warn if current device doesn't support
  • Loading branch information
will-jl944 authored Feb 8, 2025
1 parent a9d8648 commit eab22f2
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 20 deletions.
9 changes: 5 additions & 4 deletions llm/alignment/ppo/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
speed_metrics,
)
from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer
from paddlenlp.utils import empty_device_cache


class StepTrainer(Trainer):
Expand Down Expand Up @@ -1032,7 +1033,7 @@ def gen_epoch_data():
ptx_batches = [None for _ in range(len(rl_batches))]
self.timers and self.timers("ptx-batch").stop()

paddle.device.cuda.empty_cache()
empty_device_cache()

self.set_train()
for _ in range(self.args.update_iters):
Expand Down Expand Up @@ -1152,7 +1153,7 @@ def train(

# ##### model and optimizer related setting #####
policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint)
paddle.device.cuda.empty_cache()
empty_device_cache()

    # ##### training statistic logging #####
# Number of trainable parameters only account for policy_model
Expand Down Expand Up @@ -1208,7 +1209,7 @@ def train(
# with self.enable(self.value_trainer.optimizer):
with self.enable(): # put value optimizer guard in rl_step
rl_info = self.rl_step(rl_batch)
paddle.device.cuda.empty_cache()
empty_device_cache()
self.timers and self.timers("rl_step").stop()

if self.use_ptx:
Expand All @@ -1224,7 +1225,7 @@ def train(
ptx_info = self.ptx_step(ptx_batch)
rl_info.update(ptx_info)
self.timers and self.timers("ptx_step").stop()
paddle.device.cuda.empty_cache()
empty_device_cache()

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
Expand Down
3 changes: 2 additions & 1 deletion paddlenlp/quantization/quantization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle.nn.quant import weight_quantize

from ..utils.log import logger
from ..utils.memory_utils import empty_device_cache
from .quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
Expand Down Expand Up @@ -150,7 +151,7 @@ def convert_to_quantize_state_dict_without_check(state_dict, quantization_linear
state_dict.update(qlora_state_dict)
del target_weight
gc.collect()
paddle.device.cuda.empty_cache()
empty_device_cache()
return state_dict


Expand Down
20 changes: 10 additions & 10 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
unwrap_model,
)
from paddlenlp.transformers.utils import dtype_byte_size
from paddlenlp.utils import infohub
from paddlenlp.utils import empty_device_cache, infohub
from paddlenlp.utils.env import (
LORA_WEIGHTS_NAME,
MAX_QUANTIZATION_TIMES,
Expand Down Expand Up @@ -158,7 +158,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None)
if self.args.should_save:
save_model_config(model_to_save, save_directory)

paddle.device.cuda.empty_cache()
empty_device_cache()

if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save:
world_size = paddle.distributed.get_world_size()
Expand Down Expand Up @@ -195,7 +195,7 @@ def load_unified_checkpoint(self, model, resume_from_checkpoint: str):
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
paddle.device.cuda.empty_cache()
empty_device_cache()

# gather global master_weights status.
global_master_weights = reduce_master_weights_status(master_weights is not None)
Expand Down Expand Up @@ -375,7 +375,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
optim_state_dict, shard_optim_file, sharded_optim_index = results[0]
master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1]

paddle.device.cuda.empty_cache()
empty_device_cache()
save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
Expand Down Expand Up @@ -508,7 +508,7 @@ def unified_checkpoint_into_shards(
Returns:
tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name.
"""
paddle.device.cuda.empty_cache()
empty_device_cache()
assert hasattr(model_to_save, "config")

state_dict = get_expected_state_dict(model_to_save, concat_additional_adapter=True)
Expand Down Expand Up @@ -560,7 +560,7 @@ def unified_checkpoint_into_shards(
elif isinstance(model_to_save, PrefixModelForCausalLM):
sharded_index["type"] = "ptuning"

paddle.device.cuda.empty_cache()
empty_device_cache()

return state_dict, shard_file, sharded_index

Expand All @@ -578,7 +578,7 @@ def unified_optimizer_into_shards(
optimizer (Optimizer): optimizer to save.
safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
"""
paddle.device.cuda.empty_cache()
empty_device_cache()

# gather global master_weights status.
global_master_weights = reduce_master_weights_status(master_weights is not None)
Expand Down Expand Up @@ -645,7 +645,7 @@ def unified_optimizer_into_shards(
filter_optim_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()
empty_device_cache()

if master_weights is not None:
logger.info("Unified master weight tensor parallel in shards")
Expand All @@ -655,7 +655,7 @@ def unified_optimizer_into_shards(
filter_master_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()
empty_device_cache()

# build index json file
index_optimizer_file, index_master_weight_file = {}, {}
Expand Down Expand Up @@ -706,7 +706,7 @@ def unified_optimizer_into_shards(
else:
sharded_optim_index["master_weights"] = False

paddle.device.cuda.empty_cache()
empty_device_cache()
if master_weights is None:
return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)]
else:
Expand Down
3 changes: 2 additions & 1 deletion paddlenlp/trl/embedding_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SimpleInfclLoss,
)
from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient
from paddlenlp.utils import empty_device_cache

__all__ = ["EmbeddingTrainer"]

Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(self, model_args, **kwargs):
def clear_memory(self):
self.accum_q_features.clear()
self.accum_p_features.clear()
paddle.device.cuda.empty_cache()
empty_device_cache()

def clear_state(self):
self.accum_data.clear()
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .import_utils import *
from .infohub import infohub
from .initializer import to
from .memory_utils import empty_device_cache
from .optimizer import *
from .serialization import load_torch

Expand Down
39 changes: 39 additions & 0 deletions paddlenlp/utils/memory_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# coding:utf-8
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from .log import logger
from .tools import get_env_device

__all__ = [
"empty_device_cache",
]


def empty_device_cache():
    """Release cached memory held by the current device's allocator.

    Dispatches to the GPU or XPU ``empty_cache`` API based on the detected
    environment device. On any other device this is a no-op, and a warning
    is emitted once per process so repeated calls stay silent.
    """
    current = get_env_device()
    if current == "gpu":
        paddle.device.cuda.empty_cache()
    elif current == "xpu":
        paddle.device.xpu.empty_cache()
    elif not getattr(empty_device_cache, "has_warned", False):
        # Warn only on the first unsupported call; the flag is stashed on
        # the function object itself, so no module-level global is needed.
        logger.warning(
            "The current device ({}) does not support empty cache, calling empty_device_cache() will have no effect.".format(
                current
            )
        )
        empty_device_cache.has_warned = True
9 changes: 5 additions & 4 deletions slm/examples/RLHF/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
speed_metrics,
)
from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer
from paddlenlp.utils import empty_device_cache


class StepTrainer(Trainer):
Expand Down Expand Up @@ -1032,7 +1033,7 @@ def gen_epoch_data():
ptx_batches = [None for _ in range(len(rl_batches))]
self.timers and self.timers("ptx-batch").stop()

paddle.device.cuda.empty_cache()
empty_device_cache()

self.set_train()
for _ in range(self.args.update_iters):
Expand Down Expand Up @@ -1152,7 +1153,7 @@ def train(

# ##### model and optimizer related setting #####
policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint)
paddle.device.cuda.empty_cache()
empty_device_cache()

    # ##### training statistic logging #####
# Number of trainable parameters only account for policy_model
Expand Down Expand Up @@ -1208,7 +1209,7 @@ def train(
# with self.enable(self.value_trainer.optimizer):
with self.enable(): # put value optimizer guard in rl_step
rl_info = self.rl_step(rl_batch)
paddle.device.cuda.empty_cache()
empty_device_cache()
self.timers and self.timers("rl_step").stop()

if self.use_ptx:
Expand All @@ -1224,7 +1225,7 @@ def train(
ptx_info = self.ptx_step(ptx_batch)
rl_info.update(ptx_info)
self.timers and self.timers("ptx_step").stop()
paddle.device.cuda.empty_cache()
empty_device_cache()

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
Expand Down

0 comments on commit eab22f2

Please sign in to comment.