
Commit

fix unit tests
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
achew010 committed Aug 29, 2024
1 parent 0e9a550 commit 9cb999e
Showing 4 changed files with 46 additions and 6 deletions.
10 changes: 10 additions & 0 deletions tests/acceleration/test_acceleration_dataclasses.py
@@ -26,6 +26,7 @@
 from tuning.config.acceleration_configs.attention_and_distributed_packing import (
     AttentionAndDistributedPackingConfig,
     PaddingFree,
+    MultiPack,
 )
 from tuning.config.acceleration_configs.fused_ops_and_kernels import (
     FastKernelsConfig,
@@ -78,6 +79,15 @@ def test_dataclass_parse_successfully():
     )
     assert isinstance(cfg.padding_free, PaddingFree)
 
+    # 4. Specifying "--multipack" will parse a MultiPack class
+    parser = transformers.HfArgumentParser(
+        dataclass_types=AttentionAndDistributedPackingConfig
+    )
+    (cfg,) = parser.parse_args_into_dataclasses(
+        ["--multipack", "16"],
+    )
+    assert isinstance(cfg.multipack, MultiPack)
+
 
 def test_two_dataclasses_parse_successfully_together():
     """Ensure that the two dataclasses can parse arguments successfully
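(For context on the pattern the new test exercises: HfArgumentParser derives CLI flags from dataclass fields. Below is a minimal, runnable sketch of the flat case; the stand-in dataclass is illustrative only, since the real MultiPack is parsed through the repo's own nested-dataclass machinery behind a single "--multipack <num_processes>" flag.)

from dataclasses import dataclass

import transformers


@dataclass
class MultiPackStandIn:
    # Stand-in for the real MultiPack dataclass: a single int field.
    num_processes: int = 16


# HfArgumentParser derives "--num_processes" (type int, default 16)
# from the field definition above.
parser = transformers.HfArgumentParser(dataclass_types=MultiPackStandIn)
(mp,) = parser.parse_args_into_dataclasses(["--num_processes", "8"])
assert mp.num_processes == 8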
29 changes: 24 additions & 5 deletions tests/acceleration/test_acceleration_framework.py
@@ -41,6 +41,7 @@
 from tuning.config.acceleration_configs.attention_and_distributed_packing import (
     AttentionAndDistributedPackingConfig,
     PaddingFree,
+    MultiPack,
 )
 from tuning.config.acceleration_configs.fused_ops_and_kernels import (
     FastKernelsConfig,
@@ -72,9 +73,9 @@
     # Third Party
     from fms_acceleration_foak import FastQuantizedPeftAccelerationPlugin
 
-if is_fms_accelerate_available(plugins="ilab"):
+if is_fms_accelerate_available(plugins="aadp"):
     # Third Party
-    from fms_acceleration_ilab import PaddingFreeAccelerationPlugin
+    from fms_acceleration_aadp import PaddingFreeAccelerationPlugin
 
 
 # There are more extensive unit tests in the
@@ -471,7 +472,7 @@ def test_framework_intialized_properly_foak():
     reason="Only runs if fms-accelerate is installed along with \
         attention_and_distributed_packing plugin",
 )
-def test_framework_initialize_and_trains_with_ilab():
+def test_framework_initialize_and_trains_with_aadp():
     """
     Ensure that a properly configured ilab dataclass is
     correctly activated in train.
@@ -580,8 +581,8 @@ def test_error_raised_with_paddingfree_and_flash_attn_disabled():
"""Ensure error raised when padding-free is not used with flash attention"""
with pytest.raises(
ValueError,
match="`--padding_free` argument was called without enabling flash attention, \
ensure `use_flash_attn = True` to use padding-free flash attention",
match="`--padding_free` argument was called without enabling \
flash attention, ensure `use_flash_attn = True` to use padding-free flash attention",
):
attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig(
padding_free=PaddingFree(method="huggingface")
@@ -594,3 +595,21 @@ def test_error_raised_with_paddingfree_and_flash_attn_disabled():
             TRAIN_ARGS,
             attention_and_distributed_packing_config=attention_and_distributed_packing_config,
         )
+
+def test_error_raised_with_multipack_and_paddingfree_disabled():
+    """Ensure error raised when multipack is used without padding-free"""
+    with pytest.raises(
+        AssertionError,
+        match="`--multipack` is currently only supported with `--padding_free`",
+    ):
+        attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig(
+            multipack=MultiPack(num_processes=16),
+            padding_free=None,
+        )
+        model_args = copy.deepcopy(MODEL_ARGS)
+        sft_trainer.train(
+            model_args,
+            DATA_ARGS,
+            TRAIN_ARGS,
+            attention_and_distributed_packing_config=attention_and_distributed_packing_config,
+        )
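(A note on the match arguments used in these tests: pytest.raises applies match as a regular expression via re.search against the string form of the raised exception, so a plain substring works as long as it contains no regex metacharacters. A small self-contained sketch:)

import pytest


def _raise():
    raise AssertionError(
        "`--multipack` is currently only supported with `--padding_free`"
    )


# match is re.search()ed against str(exc); a distinctive substring suffices.
with pytest.raises(AssertionError, match="only supported with"):
    _raise()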
10 changes: 10 additions & 0 deletions tuning/config/acceleration_configs/acceleration_framework_config.py
@@ -117,6 +117,14 @@ class AccelerationFrameworkConfig:
         ),
     ] = None
 
+    def _verify_configured_dataclasses(self):
+        if self.multipack is not None:
+            # ensure that if multipack is set, padding free is also turned on;
+            # this also ensures that the attention implementation for multipack
+            # will be flash attention, as sfttrainer will enforce flash attn to be
+            # set for padding free
+            assert self.padding_free is not None, "`--multipack` is currently only supported with `--padding_free`"
+
     @staticmethod
     def from_dataclasses(*dataclasses: Type):
         "Convert one or many FMS config dataclasses to a monolithic AccelerationConfig"
@@ -179,6 +187,8 @@ def from_dataclasses(*dataclasses: Type):
             setattr(config, fi.name, dc)
             del rem_fields[fi.name]  # remove the field
 
+        # perform some checks on the configured dataclasses
+        config._verify_configured_dataclasses()
         return config
 
     def get_framework(self):
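(To illustrate the behaviour the new check enforces, here is a minimal runnable sketch with stand-in dataclasses; the names mirror the diff above, but this is not the real AccelerationFrameworkConfig.)

from dataclasses import dataclass
from typing import Optional


@dataclass
class MultiPack:
    num_processes: int = 16


@dataclass
class PaddingFree:
    method: str = "huggingface"


@dataclass
class FrameworkConfigSketch:
    multipack: Optional[MultiPack] = None
    padding_free: Optional[PaddingFree] = None

    def _verify_configured_dataclasses(self):
        # multipack piggybacks on padding free's flash-attention requirement
        if self.multipack is not None:
            assert (
                self.padding_free is not None
            ), "`--multipack` is currently only supported with `--padding_free`"


# multipack alone trips the assertion ...
try:
    FrameworkConfigSketch(multipack=MultiPack())._verify_configured_dataclasses()
except AssertionError as err:
    print(err)

# ... while multipack plus padding free passes.
FrameworkConfigSketch(
    multipack=MultiPack(), padding_free=PaddingFree()
)._verify_configured_dataclasses()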
3 changes: 2 additions & 1 deletion tuning/sft_trainer.py
@@ -129,7 +129,8 @@ def train(
         raise ValueError("gradient_accumulation_steps has to be an integer >= 1")
 
     if (
-        attention_and_distributed_packing_config.padding_free is not None
+        attention_and_distributed_packing_config is not None
+        and attention_and_distributed_packing_config.padding_free is not None
         and model_args.use_flash_attn is False
     ):
         raise ValueError(
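(The added None guard relies on Python's short-circuiting "and": when the config is None, the attribute access to its right is never evaluated. A minimal sketch of the idiom:)

cfg = None  # e.g. no attention_and_distributed_packing flags were passed

# Without the leading None check, cfg.padding_free would raise AttributeError.
if cfg is not None and cfg.padding_free is not None:
    print("padding free is enabled")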
