custom prompt engineering for various models #2130

Merged
merged 24 commits into from
Dec 12, 2023

Conversation

percyliang
Contributor

  • Created separate RunExpanders for Google, OpenAI, and Anthropic models. These models are instruction-following, so they need more explicit instructions to follow the expected output format and not ramble.
  • Also changed the default metric for GSM8K to test whether the final number in the response matches, rather than the final word.
  • HumanEval is a completion task (rather than in-context learning), so it needs a different prompt. That prompt only sort of works: there are some annoyances, such as GPT-4 working better without instructions but GPT-4 Turbo working better with them.

@@ -22,12 +22,14 @@
CHATML_MODEL_TAG: str = "CHATML_MODEL_TAG"

# OpenAI Chat format
-OPENAI_CHATGPT_MODEL_TAG: str = "openai_chatgpt"
+OPENAI_CHATGPT_MODEL_TAG: str = "OPENAI_CHATGPT_MODEL_TAG"
Contributor Author

bug fix

metric_specs=get_metric_specs(big_bench_task["metrics"]),
-groups=["BIG-bench"],
+groups=[f"big_bench_" + task],
Contributor Author

Drive by fix when I was trying out some BIG-bench scenarios

@@ -1317,7 +1317,12 @@ metrics:
description: Fraction of model outputs that are mathematically equivalent to the correct reference when using chain-of-thought prompting.
lower_is_better: false
- name: exact_match_indicator
-display_name: Exact match (up to specified indicator)
+display_name: Exact match (final)
Contributor Author

"indicator" is confusing

@@ -884,14 +889,14 @@ run_groups:
- efficiency
- general_information
environment:
-main_name: exact_match_indicator
+main_name: final_number_exact_match
Contributor Author

Note we change the main metric for GSM8K

Comment on lines 216 to 218
# Note: case that's not handled is "2,300" is parsed as "300"
x = re.sub(",", "", x) # To handle numbers like "2,300"
x = re.sub(r"[^0-9\.]", " ", x) # Replace non-digit, non-'.'
Collaborator

FYI: the "." character will have to change to support math scenarios in non-English European languages, where "," is commonly the decimal separator.
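
One way to address the decimal-separator concern is to make both separators parameters. The following is a minimal sketch only: `normalize_number` and its `decimal_sep` and `group_sep` parameters are hypothetical names, not code from this PR.

```python
import re


def normalize_number(x: str, decimal_sep: str = ".", group_sep: str = ",") -> str:
    """Return the last number in x with "." as the decimal point and no grouping.

    Hypothetical helper, not taken from the PR. Both separators are parameters,
    so "2.300,5" can be read with decimal_sep="," and group_sep=".".
    """
    # Drop grouping separators first, e.g. "2,300" -> "2300".
    x = x.replace(group_sep, "")
    # Optional sign, digits, then an optional decimal part using decimal_sep.
    pattern = r"-?\d+(?:" + re.escape(decimal_sep) + r"\d+)?"
    numbers = re.findall(pattern, x)
    if not numbers:
        return ""
    # Canonicalize the decimal separator so results compare as plain strings.
    return numbers[-1].replace(decimal_sep, ".")
```

With the defaults, `normalize_number("The total is 2,300.5 dollars")` returns `"2300.5"`, and so does `normalize_number("Das Ergebnis ist 2.300,5", decimal_sep=",", group_sep=".")`. Because the grouping separator is removed everywhere, an unspaced list such as "1,2" would be read as "12", so this is only suitable when the output format is fairly constrained.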

- Returns 1
"""

def get_final_number(x: str) -> str:
Collaborator

@yifanmai yifanmai Dec 11, 2023

I think a regex is more straightforward here.

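import re

def get_final_number(x: str) -> str:
    # The decimal point is escaped so that "." does not match arbitrary characters.
    match = re.search(r"(-?[\d,]+(?:\.\d+)?)\D*$", x)
    if match is None:
        return ""
    # Strip thousands separators, e.g. "2,300" -> "2300".
    return match.group(1).replace(",", "")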

This doesn't work for fractions (the original code does not either), but I checked the GSM8K test set and it looks like the answers are non-negative integers only.

Collaborator

Explanation of what this does: https://regexr.com/7ou5g

Collaborator

@yifanmai yifanmai Dec 12, 2023

Ended up going with a slightly different and simpler regex.

Also noting that my regex has slightly different behavior from this code when the last numeric digits are part of a malformed number.

assert final_number_exact_match("33", "33") == 1
assert final_number_exact_match("33", "33 eggs.") == 1
assert final_number_exact_match("The answer is 33", "\\boxed{33}") == 1
assert final_number_exact_match("The answer is 33", "\\boxed{33} and 34") == 0
Collaborator

Optional: add a test for negative numbers.
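
A sketch of what such negative-number tests might look like. The merged implementation is not shown in this thread, so the regex, the `(gold, pred)` argument order, and the comparison by string equality below are all assumptions made for illustration.

```python
import re


def get_final_number(x: str) -> str:
    """Return the last number in x with thousands separators removed, or "" if none."""
    # Assumed pattern (not necessarily the merged one): optional sign, digits and
    # commas, optional decimal part, then only non-digits until the end of the string.
    match = re.search(r"(-?[\d,]+(?:\.\d+)?)\D*$", x)
    if match is None:
        return ""
    return match.group(1).replace(",", "")


def final_number_exact_match(gold: str, pred: str) -> int:
    """Return 1 if the final numbers in gold and pred are the same string, else 0."""
    return 1 if get_final_number(gold) == get_final_number(pred) else 0


# Negative-number cases, in the style of the existing assertions.
assert final_number_exact_match("-5", "The answer is -5.") == 1
assert final_number_exact_match("-5", "The answer is 5.") == 0
assert final_number_exact_match("The answer is -2,300", "\\boxed{-2300}") == 1
```

Two caveats of this sketch: a hyphen directly before the final number is always read as a minus sign (so "10-5" yields "-5"), and two strings that contain no number at all compare as equal.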

@@ -291,6 +307,65 @@ def expand(self, run_spec: RunSpec) -> List[RunSpec]:
]


class OpenAIRunExpander(RunExpander):
Collaborator

How about just one InContextLearningInstructionsRunExpander since the OpenAI and Google run expanders are identical? That class name would also be more descriptive.

Contributor Author

I originally had separate OpenAI and Google run expanders because I had different prompts at some point. How about having an InContextLearningInstructionsRunExpander class and then just having OpenAI and Google inherit from it? In the future, we might tweak things.

Collaborator

Added a todo to deal with this later.
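
The inheritance layout proposed in this thread could be sketched roughly as follows. This is illustrative only and does not use the project's real classes: `SimpleRunSpec` is a hypothetical stand-in for `RunSpec`, the `GoogleRunExpander` name is assumed, and the instruction wording is a placeholder, since the actual prompts are not shown here.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class SimpleRunSpec:
    """Hypothetical stand-in for a run spec; the real RunSpec has many more fields."""

    name: str
    instructions: str


class InContextLearningInstructionsRunExpander:
    """Prepends format instructions for instruction-following models (illustrative)."""

    # Placeholder wording; the prompts used in the PR are not shown in this thread.
    extra_instructions: str = "Follow the format of the examples and give only the answer."

    def expand(self, run_spec: SimpleRunSpec) -> List[SimpleRunSpec]:
        # Put the shared instructions ahead of whatever the scenario already says.
        combined = f"{self.extra_instructions}\n\n{run_spec.instructions}".strip()
        return [replace(run_spec, instructions=combined)]


class OpenAIRunExpander(InContextLearningInstructionsRunExpander):
    """Inherits the shared behavior; override extra_instructions if prompts diverge."""


class GoogleRunExpander(InContextLearningInstructionsRunExpander):
    """Inherits the shared behavior; override extra_instructions if prompts diverge."""
```

With this layout the two provider-specific classes stay identical until one of them needs a different prompt, at which point only its `extra_instructions` attribute has to change.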

@yifanmai yifanmai merged commit 6dd3a5b into main Dec 12, 2023
6 checks passed
@yifanmai yifanmai deleted the pliang-prompt-eng branch December 12, 2023 05:46