Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[online-dpo] allow parse-args as list of floats #2108

Merged
merged 16 commits into from
Sep 24, 2024
Merged

[online-dpo] allow parse-args as list of floats #2108

merged 16 commits into from
Sep 24, 2024

Conversation

kashif
Copy link
Collaborator

@kashif kashif commented Sep 24, 2024

What does this PR do?

fixes #2106

@kashif kashif requested a review from qgallouedec September 24, 2024 12:32
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif kashif changed the title use a seperate argument for list of floats [online-dpo] use a seperate argument for list of floats Sep 24, 2024
@qgallouedec
Copy link
Member

qgallouedec commented Sep 24, 2024

What about this instead?

from transformers import HfArgumentParser, TrainingArguments
from dataclasses import dataclass, field
from typing import List

@dataclass
class MyArgs(TrainingArguments):
    beta: List[float] = field(default_factory=lambda: [0.1])

    def __post_init__(self):
        super().__post_init__()
        if len(self.beta) == 1:
            self.beta = self.beta[0]

parser = HfArgumentParser(MyArgs)
args = parser.parse_args_into_dataclasses()[0]
print(args.beta)
$ python 2106.py --output_dir tmp --beta 0.2
0.2
$ python 2106.py --output_dir tmp --beta 0.1 0.2
[0.1, 0.2]
$ python 2106.py --output_dir tmp
0.1

@kashif
Copy link
Collaborator Author

kashif commented Sep 24, 2024

ah yes good idea! fixing

@qgallouedec
Copy link
Member

To allow using args without the parser, we should check if the object has is sizeable as well:

from transformers import TrainingArguments
from dataclasses import dataclass, field

from typing import List


@dataclass
class MyArgs(TrainingArguments):
    beta: List[float] = field(default_factory=lambda: [0.1])

    def __post_init__(self):
        super().__post_init__()
        if hasattr(self.beta, "__len__") and len(self.beta) == 1:
            self.beta = self.beta[0]


args = MyArgs(output_dir="tmp", beta=0.2)
print(args.beta) # 0.2
args = MyArgs(output_dir="tmp", beta=[0.1, 0.2])
print(args.beta) # [0.1, 0.2]

kashif and others added 2 commits September 24, 2024 15:23
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@qgallouedec
Copy link
Member

I'll just add a few tests

@qgallouedec
Copy link
Member

LGTM thanks, just make sure to update the title

@kashif kashif changed the title [online-dpo] use a seperate argument for list of floats [online-dpo] allow parse-args as list of floats Sep 24, 2024
@kashif kashif merged commit 80038a5 into main Sep 24, 2024
3 of 10 checks passed
@kashif kashif deleted the issue-2106 branch September 24, 2024 14:56
qgallouedec added a commit that referenced this pull request Sep 24, 2024
* use a seperate argument for list of floats

* do super first

* fix docstrings

* typos

* use list of floats only

* check if it has len

* fix docstring

* fix suggestion

* fix default

* Update trl/trainer/online_dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/xpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

* additional tests

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

xpo can not work
3 participants