Skip to content

Commit

Permalink
fix: ensures the system prompt is set when mixing datasets from SDG
Browse files Browse the repository at this point in the history
Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
  • Loading branch information
RobotSail committed Feb 12, 2025
1 parent 7176ffe commit 79d8261
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions docs/examples/mix_datasets/example_mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
output_dir = Path(__file__).parent.joinpath("output")
output_dir.mkdir(exist_ok=True)

system_prompt = "You are a helpful assistant."

concatenate_recipe_yaml = Path(__file__).parent.joinpath("concatenate_recipe.yaml")
concatenated_output_jsonl = output_dir.joinpath("concatenated.jsonl")
mix_datasets(concatenate_recipe_yaml, concatenated_output_jsonl)
mix_datasets(concatenate_recipe_yaml, concatenated_output_jsonl, system_prompt)

weighted_recipe_yaml = Path(__file__).parent.joinpath("weighted_recipe.yaml")
weighted_output_jsonl = output_dir.joinpath("weighted.jsonl")
mix_datasets(weighted_recipe_yaml, weighted_output_jsonl)
mix_datasets(weighted_recipe_yaml, weighted_output_jsonl, system_prompt)
5 changes: 4 additions & 1 deletion src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,9 @@ def mix_datasets(
recipe_file: str,
output_file: str,
num_proc: Optional[int] = 8,
system_prompt: Optional[str] = None,
):
recipe = Recipe(recipe_file)
recipe = Recipe(recipe_file, system_prompt)
if recipe.datasets:
recipe.save_mixed_dataset(output_file, num_proc)
else:
Expand Down Expand Up @@ -719,10 +720,12 @@ def generate_data(
mix_datasets(
recipe_file=f"{output_dir}/skills_recipe_{date_suffix}.yaml",
output_file=f"{output_dir}/skills_train_msgs_{date_suffix}.jsonl",
system_prompt=system_prompt,
)
mix_datasets(
recipe_file=f"{output_dir}/knowledge_recipe_{date_suffix}.yaml",
output_file=f"{output_dir}/knowledge_train_msgs_{date_suffix}.jsonl",
system_prompt=system_prompt,
)

generate_duration = time.time() - generate_start
Expand Down

0 comments on commit 79d8261

Please sign in to comment.