FEAT: Add CLIs in TRL ! #1419
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Does the lazy loading have any downsides? It looks like a pretty dramatic change.
In any case we should test this extensively
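To make the trade-off in this thread concrete, here is a minimal sketch of what lazy loading does. It uses a toy module class, not the `_LazyModule` from the diff; the names `LazyModule`, `toy_pkg`, and `missing_module_for_demo` are made up for illustration. The point it shows: a broken entry in the import structure only fails when the attribute is first accessed, not when the package is imported, which is one reason extensive testing is a good idea.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Toy lazy module: resolves attributes to real imports on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute name to the module that defines it.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }
        self.loaded = []  # modules imported so far, kept only for illustration

    def __getattr__(self, name):
        # Python calls this only when normal attribute lookup fails.
        if name not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        module_name = self._attr_to_module[name]
        module = importlib.import_module(module_name)  # the deferred import
        self.loaded.append(module_name)
        value = getattr(module, name)
        setattr(self, name, value)  # cache so later lookups skip __getattr__
        return value


# Hypothetical import structure: one valid entry, one deliberately broken.
lazy = LazyModule("toy_pkg", {"json": ["dumps"], "missing_module_for_demo": ["Thing"]})

loaded_before = list(lazy.loaded)      # nothing has been imported yet
encoded = lazy.dumps({"a": 1})         # first access triggers `import json`

try:
    lazy.Thing                         # the broken entry fails only here
    late_failure = False
except ModuleNotFoundError:
    late_failure = True
```

In this sketch, constructing `lazy` succeeds even though one entry is broken, so a test suite that only imports the package would not catch the problem. Tests would need to touch every exported name.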
trl/commands/sft.py
Outdated
@@ -0,0 +1,148 @@
# flake8: noqa
So I guess this means we can't use trl/examples/scripts/sft.py directly for this?
Yeah, for now, but I think we can move the example folders there. Let me think a bit.
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looks great, mainly left a comment on where we want to handle the config parsing :)
@@ -0,0 +1,20 @@
# This is an example configuration file of TRL CLI, you can use it for
What do you think about adding it in the cli or examples folder?
setup.cfg
Outdated
[options.packages.find]
include = examples/scripts/*.py
can also be removed, no?
trl/commands/cli.py
Outdated
parser = HfArgumentParser((SftScriptArguments, TrainingArguments, ModelConfig))

(args, training_args, model_config, _) = parser.parse_args_into_dataclasses(return_remaining_strings=True)

if command_name not in SUPPORTED_COMMANDS:
    raise ValueError(
        f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}"
    )

# Get the required args
config = args.config

# if the configuration is None, create a new `output_dir` variable
config_parser = YamlConfigParser(config, [args, training_args, model_config])
trl_examples_dir = os.path.dirname(__file__)

model_name = model_config.model_name_or_path
I wonder if it would be cleaner to just pass all the args as they are to the downstream script rather than parsing them here and then passing them as a string. We could add the logic to update the config with passed args inside the dpo.py and sft.py so they would also immediately profit from being able to be called with a config. wdyt?
Indeed the approach you suggested is much cleaner!
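A rough sketch of the suggestion above, where the script itself merges a config with the arguments it receives. This is not the PR's implementation: `ScriptArgs` and `parse_with_config` are hypothetical names, a plain dict stands in for the loaded YAML file, and the precedence (explicit command-line flags override config values, which override dataclass defaults) is an assumption.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class ScriptArgs:  # hypothetical stand-in for a script's argument dataclass
    model_name: str = "base-model"
    learning_rate: float = 1e-4
    max_steps: int = 100


def parse_with_config(argv, config):
    """Merge dataclass defaults, config values, and explicit CLI flags (in that order)."""
    parser = argparse.ArgumentParser()
    for f in fields(ScriptArgs):
        # SUPPRESS keeps unset flags out of the namespace, so "not passed"
        # can be told apart from "passed with the default value".
        parser.add_argument(f"--{f.name}", type=f.type, default=argparse.SUPPRESS)
    cli_values = vars(parser.parse_args(argv))

    unknown = set(config) - {f.name for f in fields(ScriptArgs)}
    if unknown:
        raise ValueError(f"Unknown config keys: {sorted(unknown)}")

    # Later dicts win: CLI flags override config entries.
    return ScriptArgs(**{**config, **cli_values})


# `config` plays the role of a parsed YAML file.
config = {"model_name": "my-model", "learning_rate": 1e-3}
merged = parse_with_config(["--learning_rate", "3e-5"], config)
```

Here `merged.model_name` comes from the config, `merged.learning_rate` from the command line, and `merged.max_steps` from the dataclass default. Because the merge lives in the script, calling the script directly or through a CLI wrapper would behave the same way.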
trl/commands/cli_utils.py
Outdated
class TrlParser(HfArgumentParser):
    def __init__(self, args, training_args, model_config):
        super().__init__((args, training_args, model_config))

    def parse_args_and_config(self):
        parsed_args, parsed_training_args, parsed_model_config, _ = self.parse_args_into_dataclasses(
            return_remaining_strings=True
        )

        self.config_parser = YamlConfigParser(parsed_args.config)
        args, training_args, model_config = self.config_parser.merge_dataclasses(
            ((parsed_args, parsed_training_args, parsed_model_config))
        )

        training_args.gradient_checkpointing_kwargs = dict(use_reentrant=args.gradient_checkpointing_use_reentrant)
        return args, training_args, model_config
Can we make it a bit more agnostic such that we don't hardcode the order and the kind of args, e.g. like the HfArgumentParser does it? The chat interface, for example, will only have one dataclass to pass, and maybe a future method might require 3 or 4.
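One way to read this request, sketched with the standard library only. It is not `HfArgumentParser` or the PR's `TrlParser`: the function `parse_into_dataclasses` and the dataclasses `ChatArgs`, `TrainArgs`, and `ModelArgs` are hypothetical. The sketch assumes every field has a default and that field names are unique across the dataclasses passed in.

```python
import argparse
from dataclasses import dataclass, fields


def parse_into_dataclasses(dataclass_types, argv):
    """Parse argv into one instance per dataclass type, in the order given."""
    parser = argparse.ArgumentParser()
    for dc_type in dataclass_types:
        for f in fields(dc_type):
            # Assumes each field has a default and a callable type (str, int, float).
            parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    namespace = vars(parser.parse_args(argv))
    # Route each parsed value back to the dataclass that declared it.
    return tuple(
        dc_type(**{f.name: namespace[f.name] for f in fields(dc_type)})
        for dc_type in dataclass_types
    )


@dataclass
class ChatArgs:  # hypothetical: a command that needs a single dataclass
    system_prompt: str = ""
    max_new_tokens: int = 256


@dataclass
class TrainArgs:  # hypothetical
    epochs: int = 1
    lr: float = 1e-4


@dataclass
class ModelArgs:  # hypothetical
    model_name: str = "base-model"


# Works with one dataclass...
(chat_args,) = parse_into_dataclasses([ChatArgs], ["--system_prompt", "hi"])
# ...or with several, without hardcoding their number or order.
train_args, model_args = parse_into_dataclasses([TrainArgs, ModelArgs], ["--epochs", "3"])
```

Nothing in the function refers to a specific dataclass, so a future command needing three or four argument groups would only change the list passed in.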
trl/commands/config_parser.py
Outdated
import yaml


class YamlConfigParser:
I would move this into cli_utils.py
trl/commands/cli_utils.py
Outdated
for parser_dataclass in dataclasses:
    if hasattr(parser_dataclass, "config"):
        self.config_parser = YamlConfigParser(parser_dataclass.config)
We should check if we already parsed a config once and throw a warning/error if more than one dataclass has a config. Otherwise there will be weird behaviour (e.g. only the last config is applied).
Makes total sense - fixed it!
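The check discussed in this thread could look roughly like the sketch below. It is an illustration, not the fix that landed in the PR: `find_single_config` and the two dataclasses are hypothetical, and raising an error (rather than emitting a warning) is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


def find_single_config(parsed_dataclasses):
    """Return the one config path set across the dataclasses, or None if unset."""
    config_paths = [
        dc.config for dc in parsed_dataclasses if getattr(dc, "config", None) is not None
    ]
    if len(config_paths) > 1:
        # Without this check, a loop over the dataclasses would silently
        # keep only the last config it sees.
        raise ValueError(f"Expected at most one config, got {len(config_paths)}: {config_paths}")
    return config_paths[0] if config_paths else None


@dataclass
class ScriptArgs:  # hypothetical
    config: Optional[str] = None


@dataclass
class OtherArgs:  # hypothetical second dataclass that also carries a config
    config: Optional[str] = None


single = find_single_config([ScriptArgs(config="a.yaml"), OtherArgs()])

try:
    find_single_config([ScriptArgs(config="a.yaml"), OtherArgs(config="b.yaml")])
    conflict_detected = False
except ValueError:
    conflict_detected = True
```

Collecting all candidate paths before constructing any parser keeps the failure explicit instead of depending on iteration order.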
* CLI V1
* v1 CLI
* add rich enhancmeents
* revert unindented change
* some comments
* cleaner CLI
* fix
* fix
* remove print callback
* move to cli instead of trl_cli
* revert unneeded changes
* fix test
* Update trl/commands/sft.py Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* remove redundant strings
* fix import issue
* fix other issues
* add packing
* add config parser
* some refactor
* cleaner
* add example config yaml file
* small refactor
* change a bit the logic
* fix issues here and there
* add CLI in docs
* move to examples/sft
* remove redundant licenses
* make it work on dpo
* set to None
* switch to accelerate and fix many things
* add docs
* more docs
* added tests
* doc clarification
* more docs
* fix CI for windows and python 3.8
* fix
* attempt to fix CI
* fix?
* test
* fix
* tweak?
* fix
* test
* another test
* fix
* test
* fix
* fix
* fix
* skip tests for windows
* test @lvwerra approach
* make dev
* revert unneeded changes
* fix sft dpo
* optimize a bit
* address final comments
* update docs
* final comment
---------
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
What does the PR do?
This PR introduces a new feature in TRL - CLIs for DPO and SFTTrainer!
All arguments that are supported in ModelConfig should be supported by the CLI together with arguments from TrainingArguments or transformers. Users will need to first call accelerate config before running the CLI to use custom accelerate configs.