diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index f0bab9d6..d6d6b089 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -29,6 +29,7 @@ jobs: - "framework" - "accelerated-peft" - "fused-ops-and-kernels" + - "instruct-lab" steps: - uses: actions/checkout@v4 diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py index 48ad8201..3cdef252 100644 --- a/plugins/framework/src/fms_acceleration/constants.py +++ b/plugins/framework/src/fms_acceleration/constants.py @@ -21,4 +21,4 @@ # and activated. # - hence the plugins that have model loaders should be on top of this list -PLUGINS = ["peft", "foak"] +PLUGINS = ["peft", "foak", "ilab"] diff --git a/plugins/framework/src/fms_acceleration/framework.py b/plugins/framework/src/fms_acceleration/framework.py index c8de939c..1e6ecb44 100644 --- a/plugins/framework/src/fms_acceleration/framework.py +++ b/plugins/framework/src/fms_acceleration/framework.py @@ -193,7 +193,9 @@ def augmentation( train_args: TrainingArguments, modifiable_args: Tuple, ): - model_archs = set(model.config.architectures) # get the config + # get the config + archs = model.config.architectures + model_archs = set(archs if archs is not None else []) # NOTE: this assumes that augmentation order does not matter for plugin_name, plugin in self.active_plugins: diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index 0db569c6..0f24597e 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -68,7 +68,8 @@ def get_relevant_configuration_sections(configuration: Dict) -> Dict: _cfg = relevant_config while n > 1: p = path.pop(0) - _cfg[p] = {} + if p not in _cfg: + _cfg[p] = {} _cfg = _cfg[p] n -= 1 diff --git a/plugins/instruct-lab/.isort.cfg b/plugins/instruct-lab/.isort.cfg new file mode 100644 index 00000000..4aa62fac --- /dev/null +++ b/plugins/instruct-lab/.isort.cfg @@ -0,0 +1,13 @@ +[settings] +profile=black +from_first=true +import_heading_future=Future +import_heading_stdlib=Standard +import_heading_thirdparty=Third Party +import_heading_firstparty=First Party +import_heading_localfolder=Local +known_firstparty= +known_localfolder=tuning + +# skip code imported from unsloth +skip_glob=**/unsloth*/** diff --git a/plugins/instruct-lab/.pylintrc b/plugins/instruct-lab/.pylintrc new file mode 100644 index 00000000..31cb902c --- /dev/null +++ b/plugins/instruct-lab/.pylintrc @@ -0,0 +1,650 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +# NOTE: do not lint code imported from unsloth +ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth* + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/plugins/instruct-lab/README.md b/plugins/instruct-lab/README.md new file mode 100644 index 00000000..ca1ea246 --- /dev/null +++ b/plugins/instruct-lab/README.md @@ -0,0 +1,34 @@ +# FMS Acceleration for Instruct Lab + +This library contains plugins to accelerate finetuning with the following optimizations: + +1. Padding-Free Flash Attention Computation + + +## Plugins + +Plugin | Description | Depends | Loading | Augmentation | Callbacks +--|--|--|--|--|-- +[padding_free](./src/fms_acceleration_ilab/framework_plugin_padding_free.py) | Padding-Free Flash Attention Computation | flash_attn | ✅ | ✅ + + +## Native Transformers Support from V4.44.0 +Transformers natively supports padding-free from v4.44.0. The padding-free plugin will use the transformers library if compatible, +otherwise if `transformers < V4.44.0` the plugin will use an internal implementation instead. + +## Known Issues + +### Currently Only Supports Pre-Tokenized Dataset + +The padding-free plugin currently only works with pre-tokenized datasets, this is because it is currently designed to replace +the data collator from `SFTTrainer` with a custom data collator to manipulate the input to the modified flash attention forward. + +There are some cases, the data collator for SFTTrainer will handle the formatting and tokenization from raw text datasets. The plugin +is currently unable to both handle the original data collation and apply its custom data collator over it as the same time. This issue +will be addressed in a future commit to support this case. + +In the meantime, the plugin expects the user to provide a pretokenized dataset that +- is formatted with a template for instruct-tuning cases +- is tokenized +- has template labels that are masked to exclude from loss computation +- has eos token appended diff --git a/plugins/instruct-lab/configs/instruct_lab.yaml b/plugins/instruct-lab/configs/instruct_lab.yaml new file mode 100644 index 00000000..0ae81ea0 --- /dev/null +++ b/plugins/instruct-lab/configs/instruct_lab.yaml @@ -0,0 +1,10 @@ +# Configurations to accelerate data packing/padding in training +training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: "huggingface" diff --git a/plugins/instruct-lab/pyproject.toml b/plugins/instruct-lab/pyproject.toml new file mode 100644 index 00000000..e6e4adb1 --- /dev/null +++ b/plugins/instruct-lab/pyproject.toml @@ -0,0 +1,30 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fms-acceleration-ilab" +version = '0.0.1' +description = "FMS Acceleration Plugin for Functionalities Used in Instruct Lab Training" +authors = [ + {name = "Fabian Lim", email = "flim@sg.ibm.com"}, + {name = "Aaron Chew", email = "aaron.chew1@ibm.com"}, +] +license = {text = "Apache-2.0"} +readme = "README.md" +requires-python = "~=3.9" +keywords = ['fms-hf-tuning', 'acceleration', 'padding-free'] +classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] + +[tool.hatch.build.targets.wheel] +only-include = ["src/fms_acceleration_ilab"] + +[tool.hatch.build.targets.wheel.sources] +"src" = "" diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/__init__.py b/plugins/instruct-lab/src/fms_acceleration_ilab/__init__.py new file mode 100644 index 00000000..12e86c4a --- /dev/null +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/__init__.py @@ -0,0 +1,16 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .framework_plugin_padding_free import PaddingFreeAccelerationPlugin diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py b/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py new file mode 100644 index 00000000..ce471510 --- /dev/null +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py @@ -0,0 +1,108 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers.utils import is_flash_attn_2_available +from types import MethodType + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func # pylint: disable=import-error + +def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length): + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + cu_seq_lens = torch.cat(( + indices_q[position_ids==0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32) + )) + max_length = position_ids.max()+1 + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + +def build_fa_forward( + attention: torch.nn.Module, causal: bool = True, dropout: float = None +): + # assert not hasattr(self, '_position_ids'), "cannot patch fa attention" + + position_ids: torch.Tensor = None + old_forward = attention.forward + if dropout is not None: + attention.dropout = torch.nn.Dropout(p=dropout) + + def forward(self, *args, **kwargs): + nonlocal position_ids + position_ids = kwargs['position_ids'] + out, *others = old_forward(*args, **kwargs) + if dropout is not None: + out = self.dropout(out) + return out, *others + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + **kwargs, + ): + # if not self._flash_attn_uses_top_left_mask: + # causal = self.is_causal + # else: + # # TODO: Remove the `query_length != 1` + # # check once Flash Attention for RoCm is bumped to 2.1. + # # For details, please see the comment in LlamaFlashAttention2 __init__. + # causal = self.is_causal and query_length != 1 + + assert attention_mask is None, "should not be using attention mask" + assert position_ids is not None, "should be expecting position ids" + batch_size = query_states.size(0) + ( + query_states, + key_states, + value_states, + _, + cu_seq_lens, + max_seq_lens, + ) = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + # do this replace + attention._flash_attention_forward = MethodType(_flash_attention_forward, attention) + + # return the forward + return forward diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py new file mode 100644 index 00000000..6486791d --- /dev/null +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py @@ -0,0 +1,157 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import Dict, Tuple +from packaging import version +import warnings + +# Third Party +from fms_acceleration import AccelerationPlugin +from peft import LoraConfig +from transformers import ( + TrainingArguments, + __version__ as transformers_version, + DataCollatorForSeq2Seq, +) +from accelerate import Accelerator +import torch +from types import MethodType +from torch.utils.data import DataLoader + +# This is the version where padding-free was merged into transformers +TRANSFORMERS_VERSION = "4.44" + +class PaddingFreeAccelerationPlugin(AccelerationPlugin): + + require_packages = ["flash_attn"] + + def __init__(self, configurations: Dict[str, Dict]): + super().__init__(configurations) + + # the fast attention requires knowledge about the + # data collator. + # - currently we do not have a data collator specific plugin + # - so it requires knowledge about the dataloader + self._method = self._check_config_and_maybe_check_values( + key="training.attention.padding_free.method", + values=["huggingface"], + ) + + @property + def requires_agumentation(self): + return True + + def augmentation( + self, + model, + train_args: TrainingArguments, + modifiable_args: Tuple[LoraConfig], + ): + + # This check is done here to only patch the attention forward + # if below a specific transformer version (4.43.3) that already + # addresses padding free + # https://github.com/huggingface/transformers/pull/31629 + # Subsequently, when additional features are added to the patch + # such as attention dropout, the version check should be shifted + # into `build_fa_forward` to manage the forward replacement inside + # the function. + if version.parse(transformers_version) < version.parse(TRANSFORMERS_VERSION): + # guarded + from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + ) + from .flash_attn import build_fa_forward # pylint: disable=import-outside-toplevel + from functools import partial # pylint: disable=import-outside-toplevel + + # TODO: have a generic version of this rule + # - do regex on RMSNorm class name + # - check on the tensors required for fast_rms_layernorm + model_type = model.config.model_type + def is_flash_attn_2(module): + if ( + module.__class__.__name__.endswith("FlashAttention2") + ): + return True + return False + + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{model_type}-pad-free", + trigger=ModelPatcherTrigger(check=is_flash_attn_2), + forward_builder=partial( + build_fa_forward, + causal=True, + ), + ), + ) + else: + warnings.warn(f"transformers version is equal or later \ + than {TRANSFORMERS_VERSION}, attention forward will not be replaced.") + + return model, modifiable_args + + def get_callbacks_and_ready_for_train( + self, model: torch.nn.Module = None, accelerator: Accelerator = None + ): + # patch the dataloader on the accelerator + self._patch_dataloader(accelerator) + return [] + + def _patch_dataloader( + self, + accelerator: Accelerator, + ): + """ + Hijacks the accelorator prepare inside `Trainer.train` + - If it is a single argument. it is assumed to be the prepare call on the dataloader + - we replace the collate function in the dataloader to flatten the batch into a long + sequence with special tokens to define the attention computation boundaries + """ + # Check if transformers already supports a collator that flattens the batch + # Otherwise, use the locally implemented DataCollatorWithFlattening + if version.parse(transformers_version) < version.parse(TRANSFORMERS_VERSION): + from .ilab_utils import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel + else: + from transformers import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel,no-name-in-module + + # hijack the dataloader in accelerator.prepare to replace the collate_fn + _old_prepare = accelerator.prepare + def prepare(self, *args, device_placement=None): + if len(args) > 1 or not isinstance(args[0], DataLoader): + return _old_prepare(*args, device_placement=device_placement) + dataloader = args[0] + + if not isinstance(dataloader.collate_fn, DataCollatorForSeq2Seq): + raise TypeError("The padding-free plugin currently only works with a \ + `DataCollatorForSeq2Seq` collate_fn, \ + otherwise the collation can be unreliable") + + # Replace the collate_fn in dataloader + dataloader.collate_fn = DataCollatorWithFlattening() + + return dataloader + + accelerator.prepare = MethodType(prepare, accelerator) + +# register +AccelerationPlugin.register_plugin( + PaddingFreeAccelerationPlugin, + configuration_and_paths=[ + "training.attention.padding_free", + ], +) diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py b/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py new file mode 100644 index 00000000..c8529669 --- /dev/null +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py @@ -0,0 +1,54 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +import warnings +from transformers import DefaultDataCollator, default_data_collator + +@dataclass +class DataCollatorWithFlattening(DefaultDataCollator): + """ + Data collator used for padding free approach. Does the following: + - concatate the entire mini batch into single long sequence [1, total_tokens] + - no padding will be added, returns `input_ids`, `labels` and `position_ids` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + warnings.warn( + "Using `DataCollatorWithFlattening` will flatten the entire mini batch " + "into single long sequence." + "Make sure your attention computation is able to handle it!" + ) + + def __call__(self, features, return_tensors=None): + """ + This implementation assumes that only 3 arguments, input_ids, position_ids and labels + are needed by the model, anything else is dropped by the collator + """ + if return_tensors is None: + return_tensors = self.return_tensors + + # Preserve the the original collate behaviour to cater to all use cases + is_labels_provided = "labels" in features[0] + ret = {"input_ids": [], "labels": [], "position_ids": []} + for feature in features: + ret["input_ids"] += feature["input_ids"] + ret["position_ids"] += list(range(len(feature["input_ids"]))) + if is_labels_provided: + ret["labels"] += [-100] + feature["labels"][1:] + else: + ret["labels"] += [-100] + feature["input_ids"][1:] + return default_data_collator([ret], return_tensors) + \ No newline at end of file diff --git a/plugins/instruct-lab/tests/__init__.py b/plugins/instruct-lab/tests/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/instruct-lab/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/instruct-lab/tests/test_ilab_plugin.py b/plugins/instruct-lab/tests/test_ilab_plugin.py new file mode 100644 index 00000000..c3185d83 --- /dev/null +++ b/plugins/instruct-lab/tests/test_ilab_plugin.py @@ -0,0 +1,31 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from fms_acceleration.utils import ( + instantiate_framework, + read_configuration, +) +from fms_acceleration_ilab import PaddingFreeAccelerationPlugin + +# configuration +DIRNAME = os.path.dirname(__file__) +CONFIG_PATH_ILAB = os.path.join(DIRNAME, "../configs/instruct_lab.yaml") + +def test_framework_installs_ilab_padding_free_plugin(): + with instantiate_framework( + read_configuration(CONFIG_PATH_ILAB), require_packages_check=False + ) as framework: + for plugin in framework.active_plugins: + assert isinstance(plugin[1], PaddingFreeAccelerationPlugin) diff --git a/plugins/instruct-lab/tox.ini b/plugins/instruct-lab/tox.ini new file mode 100644 index 00000000..7dfd370e --- /dev/null +++ b/plugins/instruct-lab/tox.ini @@ -0,0 +1,48 @@ +[tox] +envlist = py, lint + +[testenv] +deps = + pytest>=7 + -e {toxinidir} +skip_install = true +commands = + + # install the dependencies here to ensure + # the order + pip install -e {toxinidir}/../framework + pytest {posargs:tests} + +[testenv:lint] +description = run linters +skip_install = false +deps = + -e {toxinidir}/../framework + pylint>=2.16.2,<=3.1.0 +commands = + pylint src tests +allowlist_externals = pylint + +[testenv:fmt] +description = format +skip_install = true +deps = + black>=22.12 + isort>=5.11 +commands = + black {posargs:.} + isort {posargs:.} + +[testenv:build] +description = build wheel +deps = + build +commands = python -m build -w +skip_install = True + +[testenv:twinecheck] +description = check wheel +deps = + twine +commands = twine check dist/* +skip_install = True diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index 75f7279b..f5dc6819 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -31,4 +31,9 @@ framework_configs: plugins: - accelerated-peft - fused-ops-and-kernels - filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml \ No newline at end of file + filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml + + - shortname: ilab-padding-free + plugins: + - instruct-lab + filename: ilab-padding-free-sample-configuration.yaml \ No newline at end of file diff --git a/sample-configurations/ilab-padding-free-sample-configuration.yaml b/sample-configurations/ilab-padding-free-sample-configuration.yaml new file mode 100644 index 00000000..6df59a59 --- /dev/null +++ b/sample-configurations/ilab-padding-free-sample-configuration.yaml @@ -0,0 +1,15 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: huggingface diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index b3485e3c..c147df6a 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -144,6 +144,7 @@ def read_configuration(path: str) -> Dict: KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4" KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak" KEY_BNB_NF4_FOAK = "bnb-nf4-foak" +KEY_ILAB_PADDING_FREE = "ilab-padding-free" CONFIGURATIONS = { KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -166,6 +167,7 @@ def read_configuration(path: str) -> Dict: "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", [("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")], ), + KEY_ILAB_PADDING_FREE: "plugins/instruct-lab/configs/instruct_lab.yaml", } # list of (tag, combi) tuples @@ -179,6 +181,7 @@ def read_configuration(path: str) -> Dict: ("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)), ("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), ("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), + ("ilab-padding-free", (KEY_ILAB_PADDING_FREE,)), ]