diff --git a/src/synthetic_dataset_generator/pipelines/textcat.py b/src/synthetic_dataset_generator/pipelines/textcat.py
index deefe78..4f5f11d 100644
--- a/src/synthetic_dataset_generator/pipelines/textcat.py
+++ b/src/synthetic_dataset_generator/pipelines/textcat.py
@@ -171,6 +171,8 @@ def generate_pipeline_code(
     temperature: float = 0.9,
 ) -> str:
     labels = get_preprocess_labels(labels)
+    MODEL_ARG = "model_id" if BASE_URL else "model"
+    MODEL_CLASS = "InferenceEndpointsLLM" if BASE_URL else "OpenAILLM"
     base_code = f"""
 # Requirements: `pip install distilabel[hf-inference-endpoints]`
 import os
@@ -192,15 +194,13 @@ def generate_pipeline_code(
 task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
 
 textcat_generation = GenerateTextClassificationData(
-    llm=InferenceEndpointsLLM(
-        model_id=MODEL,
+    llm={MODEL_CLASS}(
+        {MODEL_ARG}=MODEL,
         base_url=BASE_URL,
         api_key=os.environ["API_KEY"],
         generation_kwargs={{
             "temperature": {temperature},
             "max_new_tokens": {MAX_NUM_TOKENS},
-            "do_sample": True,
-            "top_k": 50,
             "top_p": 0.95,
         }},
     ),
@@ -236,8 +236,8 @@ def generate_pipeline_code(
 )
 
 textcat_labeller = TextClassification(
-    llm=InferenceEndpointsLLM(
-        model_id=MODEL,
+    llm={MODEL_CLASS}(
+        {MODEL_ARG}=MODEL,
         base_url=BASE_URL,
         api_key=os.environ["API_KEY"],
         generation_kwargs={{