# dataset_utils.py
import csv
import json
import logging
import os
import time

import anthropic
import openai
import requests
from datasets import Dataset, load_dataset
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def prepare_dataset(dataset_source, dataset_path, tokenizer, hf_token=None):
    """
    Prepare a dataset for fine-tuning, either from Hugging Face or a local file.

    Args:
        dataset_source (str): 'huggingface' or 'local'
        dataset_path (str): Path or identifier of the dataset
        tokenizer: The tokenizer associated with the model
        hf_token (str, optional): Hugging Face token for accessing gated or private datasets

    Returns:
        Dataset: Prepared dataset ready for fine-tuning
    """
    if dataset_source == 'huggingface':
        try:
            dataset = load_dataset(dataset_path, split="train", use_auth_token=hf_token)
        except (TypeError, ValueError):
            # If use_auth_token is not supported (newer `datasets` versions use `token=`), try without it
            dataset = load_dataset(dataset_path, split="train")
    elif dataset_source == 'local':
        if not os.path.exists(dataset_path):
            raise FileNotFoundError(f"File not found: {dataset_path}")
        if dataset_path.endswith('.json'):
            with open(dataset_path, 'r') as f:
                data = json.load(f)
            if isinstance(data, list):
                dataset = Dataset.from_list(data)
            elif isinstance(data, dict):
                dataset = Dataset.from_dict(data)
            else:
                raise ValueError("JSON file must contain either a list or a dictionary.")
        elif dataset_path.endswith('.csv'):
            with open(dataset_path, 'r', newline='') as f:
                reader = csv.DictReader(f)
                data = list(reader)
            dataset = Dataset.from_list(data)
        else:
            raise ValueError("Unsupported file format. Please use JSON or CSV.")
    else:
        raise ValueError("Invalid dataset source. Use 'huggingface' or 'local'.")

    # Check if the 'conversations' column exists; if not, try to create it from 'text'
    if 'conversations' not in dataset.column_names:
        if 'text' in dataset.column_names:
            dataset = dataset.map(lambda example: {'conversations': [{'human': example['text'], 'assistant': ''}]})
        else:
            raise ValueError("Dataset does not contain a 'conversations' or 'text' column. Please check your dataset structure.")

    # Only apply standardize_sharegpt if the 'conversations' column exists
    if 'conversations' in dataset.column_names:
        dataset = standardize_sharegpt(dataset)

    def formatting_prompts_func(examples):
        if tokenizer is None:
            raise ValueError("Tokenizer is not properly initialized. Please load the model and tokenizer before preparing the dataset.")
        convos = examples["conversations"]
        texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
        return {"text": texts}

    dataset = dataset.map(formatting_prompts_func, batched=True)

    # Fallback: build a plain-text column if the chat template did not produce one
    if 'text' not in dataset.column_names:
        def format_conversation(example):
            formatted_text = ""
            for turn in example['conversations']:
                formatted_text += f"{turn['role']}: {turn['content']}\n"
            return {"text": formatted_text.strip()}
        dataset = dataset.map(format_conversation)

    return dataset
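

# --- Usage sketch (not called anywhere in this module) ---
# A minimal example of how prepare_dataset might be invoked. It assumes the
# `transformers` package is available, that "data/train.json" is a hypothetical
# local file in the list-of-records form shown below, and that the placeholder
# model name ships a chat template; all three are assumptions, not part of this module.
def _example_prepare_dataset():
    from transformers import AutoTokenizer  # assumed dependency

    # Placeholder model name; any tokenizer that defines a chat template works,
    # since prepare_dataset calls tokenizer.apply_chat_template.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    # Hypothetical local file, expected to look like:
    # [{"conversations": [{"human": "Hi", "assistant": "Hello!"}]}]
    return prepare_dataset("local", "data/train.json", tokenizer)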


def standardize_sharegpt(dataset):
    """
    Convert 'human'/'assistant' turn dicts into role/content chat messages.

    This is a simplified version; you might need to adjust it based on your specific needs.
    """
    def process_conversation(conversation):
        standardized = []
        for turn in conversation:
            if 'human' in turn:
                standardized.append({'role': 'user', 'content': turn['human']})
            if 'assistant' in turn:
                standardized.append({'role': 'assistant', 'content': turn['assistant']})
        return standardized

    return dataset.map(lambda x: {'conversations': process_conversation(x['conversations'])})
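

# --- Shape illustration (not called anywhere in this module) ---
# A minimal sketch of the transformation performed by standardize_sharegpt:
# each 'human'/'assistant' turn dict becomes a pair of role/content messages.
def _example_standardize_sharegpt():
    raw = Dataset.from_list([
        {"conversations": [{"human": "Hi", "assistant": "Hello!"}]}
    ])
    standardized = standardize_sharegpt(raw)
    # standardized[0]["conversations"] ==
    #   [{"role": "user", "content": "Hi"},
    #    {"role": "assistant", "content": "Hello!"}]
    return standardized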


def create_synthetic_dataset(examples, expected_structure, num_samples, ai_provider, api_key, model_name=None):
    """
    Create a synthetic dataset based on example conversations and an expected structure.

    Args:
        examples (str): Example conversations to base the synthetic data on.
        expected_structure (str): Description of the expected dataset structure.
        num_samples (int): Number of synthetic samples to generate.
        ai_provider (str): AI provider to use for generation ('OpenAI', 'Anthropic', or 'Ollama').
        api_key (str): API key for the chosen AI provider.
        model_name (str, optional): Model name for Ollama (if applicable).

    Returns:
        Dataset: Synthetic dataset ready for fine-tuning.
    """
    synthetic_data = []
    prompt = f"""
You are an AI assistant creating a training dataset for fine-tuning a model.
You are given a one-shot or few-shot example of the output that the application expects from the AI model, together with the structure that the model to be trained expects during the training process.

Examples:
{examples}

Expected structure:
{expected_structure}

Generate a new dataset in the same style and expected structure. Do not produce any output other than the dataset, in the structure required for training.
"""
    if ai_provider == "OpenAI":
        client = openai.OpenAI(api_key=api_key)
        for _ in tqdm(range(num_samples), desc="Generating samples"):
            try:
                response = client.chat.completions.create(
                    model="gpt-4-0125-preview",
                    messages=[{"role": "user", "content": prompt}],
                    timeout=30  # 30 seconds timeout
                )
                conversation = response.choices[0].message.content
                synthetic_data.append({"conversations": json.loads(conversation)})
            except json.JSONDecodeError:
                logger.warning(f"Failed to decode response as JSON: {response.choices[0].message.content}")
            except openai.APITimeoutError:
                logger.warning("OpenAI API request timed out")
            except Exception as e:
                logger.error(f"Unexpected error: {str(e)}")
            time.sleep(1)  # Rate limiting
    elif ai_provider == "Anthropic":
        client = anthropic.Anthropic(api_key=api_key)
        for _ in tqdm(range(num_samples), desc="Generating samples"):
            try:
                # Claude 3 models are served via the Messages API rather than the legacy Completions API
                response = client.messages.create(
                    model="claude-3-opus-20240229",
                    max_tokens=1000,
                    messages=[{"role": "user", "content": prompt}],
                    timeout=30  # 30 seconds timeout
                )
                conversation = response.content[0].text
                synthetic_data.append({"conversations": json.loads(conversation)})
            except json.JSONDecodeError:
                logger.warning(f"Failed to decode response as JSON: {response.content[0].text}")
            except anthropic.APITimeoutError:
                logger.warning("Anthropic API request timed out")
            except Exception as e:
                logger.error(f"Unexpected error: {str(e)}")
            time.sleep(1)  # Rate limiting
    elif ai_provider == "Ollama":
        for _ in tqdm(range(num_samples), desc="Generating samples"):
            try:
                response = requests.post(
                    'http://localhost:11434/api/generate',
                    json={
                        "model": model_name,
                        "prompt": prompt,
                        "stream": False
                    },
                    timeout=30  # 30 seconds timeout
                )
                response.raise_for_status()
                synthetic_data.append({"conversations": json.loads(response.json()["response"])})
            except json.JSONDecodeError:
                logger.warning(f"Failed to decode response as JSON: {response.json()['response']}")
            except requests.Timeout:
                logger.warning("Ollama API request timed out")
            except Exception as e:
                logger.error(f"Unexpected error: {str(e)}")
            time.sleep(1)  # Rate limiting
    dataset = Dataset.from_list(synthetic_data)
    dataset = standardize_sharegpt(dataset)

    if 'text' not in dataset.column_names:
        def format_conversation(example):
            formatted_text = ""
            for turn in example['conversations']:
                formatted_text += f"{turn['role']}: {turn['content']}\n"
            return {"text": formatted_text.strip()}
        dataset = dataset.map(format_conversation)

    return dataset
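

# --- Hedged smoke test ---
# Assumptions (not guaranteed by this module): a local Ollama server is running
# on the default port and a model named "llama3" has been pulled; the example
# strings below are placeholders rather than data from this repository.
if __name__ == "__main__":
    example_convos = json.dumps(
        [{"human": "What is 2 + 2?", "assistant": "4"}], indent=2
    )
    expected = "A JSON list of turns, each with 'human' and 'assistant' keys."
    demo = create_synthetic_dataset(
        examples=example_convos,
        expected_structure=expected,
        num_samples=1,
        ai_provider="Ollama",
        api_key="",           # unused for Ollama
        model_name="llama3",  # placeholder model name
    )
    logger.info("Generated %d synthetic sample(s)", len(demo))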