Skip to content

Commit

Permalink
fix code generation for pipeline textcat
Browse files Browse the repository at this point in the history
  • Loading branch information
davidberenstein1957 committed Dec 17, 2024
1 parent d987e13 commit cb57cce
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/synthetic_dataset_generator/pipelines/textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def generate_pipeline_code(
temperature: float = 0.9,
) -> str:
labels = get_preprocess_labels(labels)
+ MODEL_ARG = "model_id" if BASE_URL else "model"
+ MODEL_CLASS = "InferenceEndpointsLLM" if BASE_URL else "OpenAILLM"
base_code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
import os
Expand All @@ -192,15 +194,13 @@ def generate_pipeline_code(
task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
textcat_generation = GenerateTextClassificationData(
- llm=InferenceEndpointsLLM(
- model_id=MODEL,
+ llm={MODEL_CLASS}(
+ {MODEL_ARG}=MODEL,
base_url=BASE_URL,
api_key=os.environ["API_KEY"],
generation_kwargs={{
"temperature": {temperature},
"max_new_tokens": {MAX_NUM_TOKENS},
"do_sample": True,
"top_k": 50,
"top_p": 0.95,
}},
),
Expand Down Expand Up @@ -236,8 +236,8 @@ def generate_pipeline_code(
)
textcat_labeller = TextClassification(
- llm=InferenceEndpointsLLM(
- model_id=MODEL,
+ llm={MODEL_CLASS}(
+ {MODEL_ARG}=MODEL,
base_url=BASE_URL,
api_key=os.environ["API_KEY"],
generation_kwargs={{
Expand Down

0 comments on commit cb57cce

Please sign in to comment.