-
Notifications
You must be signed in to change notification settings - Fork 0
/
extractor.py
148 lines (113 loc) · 5.37 KB
/
extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import logging
from typing import Dict, List, Tuple

import gspread
import tiktoken
from omegaconf import DictConfig
class GSpreadClient:
    """Thin wrapper around gspread service-account authentication."""

    def __init__(self):
        # No state is held; authentication happens per-call in Authenticate().
        pass

    def Authenticate(self, config_path: str):
        '''
        Authenticate with Google Sheets using a service-account credentials file.

        Args:
            config_path: Path to the service-account JSON credentials file.

        Returns:
            The authenticated Google Sheets (gspread) client.

        Raises:
            RuntimeError: If the credentials file is missing, or if
                authentication fails for any other reason.
        '''
        try:
            return gspread.service_account(filename=config_path)
        except FileNotFoundError as e:
            # Chain the original exception so the traceback keeps the root cause.
            raise RuntimeError(
                f"Failed to authenticate with Google Sheets due to missing credentials file: {e}"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Failed to authenticate with Google Sheets due to unexpected error: {e}"
            ) from e
class DataHandler:
    """Reads prompt/completion pairs from a Google Sheet and normalizes them
    into the OpenAI fine-tuning format, accumulating word counts, token
    counts, and the estimated fine-tuning cost."""

    def __init__(self, cfg: DictConfig, gspread_client: gspread.client, google_sheet_url: str):
        '''
        Args:
            cfg: OmegaConf config providing openAI.model, openAI.token_limit,
                and openAI.fine_tuning_cost_per_1000_tokens.
            gspread_client: An authenticated gspread client.
            google_sheet_url: URL of the Google Sheet to read from.
        '''
        self.openai_model = cfg.openAI.model
        self.token_limit = cfg.openAI.token_limit
        self.fine_tuning_cost_per_1000_tokens = cfg.openAI.fine_tuning_cost_per_1000_tokens
        self.gspread_client = gspread_client
        self.google_sheet_url = google_sheet_url

    def extract_data(self, sheet_index: int) -> List[Dict[str, str]]:
        '''
        Extract prompt/completion rows from one worksheet of the Google Sheet.

        Args:
            sheet_index: Zero-based index of the worksheet to read.

        Returns:
            The processed prompt/completion records. Returns an empty list
            (matching the declared return type, instead of silently returning
            None) if the sheet data is malformed.

        Raises:
            RuntimeError: For any other failure while reading the sheet.
        '''
        logger = logging.getLogger("DataHandler")
        try:
            sh = self.gspread_client.open_by_url(self.google_sheet_url)
            worksheet = sh.get_worksheet(sheet_index)
            rows = worksheet.get_all_values()

            processed_data: List[Dict[str, str]] = []
            total_word_counts = 0
            total_token_counts = 0
            total_estimated_costs = 0.0

            # Start from the second row, skipping the header row.
            for row in rows[1:]:
                prompt, completion = row[0], row[1]
                # TEMP - only process rows where both cells are non-empty.
                if prompt and completion:
                    items, words, tokens, cost = self.filter_data(prompt, completion)
                    # Update the running totals.
                    total_word_counts += words
                    total_token_counts += tokens
                    total_estimated_costs += cost
                    processed_data.extend(items)

            logger.info(f"Total word count: {total_word_counts}")
            logger.info(f"Estimated tokens: {total_token_counts}")
            logger.info(f"Estimated cost of fine-tuning: ${total_estimated_costs}\n")
            return processed_data
        except (AttributeError, KeyError) as e:
            # Log with traceback instead of print(); return an empty list so
            # callers that iterate the result do not crash on None.
            logger.exception(f"An AttributeError or KeyError occurred: {e}")
            return []
        except Exception as e:
            raise RuntimeError(f"Failed to extract data from Google Sheets: {e}") from e

    def filter_data(self, prompt: str, completion: str) -> Tuple[List[Dict[str, str]], int, int, float]:
        '''
        Normalize one prompt/completion pair into the fine-tuning format.

        Appends the "\\n\\n###\\n\\n" separator to the prompt, ensures the
        completion carries a leading space and the trailing "." / "END"
        markers, and validates the combined token count.

        Args:
            prompt: Raw prompt cell text.
            completion: Raw completion cell text.

        Returns:
            A tuple of (records, total word count, total token count,
            estimated cost).

        Raises:
            ValueError: Propagated from check_tokens when over the limit.
        '''
        filtered_data: List[Dict[str, str]] = []
        # Remove leading and trailing whitespace.
        prompt = prompt.strip()
        completion = completion.strip()
        # Ensure the prompt ends with the fine-tuning separator.
        if not prompt.endswith("\n\n###\n\n"):
            prompt += " \n\n###\n\n"
        # Ensure the completion ends with "." and then the "END" marker.
        # NOTE(review): if the completion already ends with "END", the "."
        # branch fires first and the result becomes "... END. END" — confirm
        # this ordering is intended before changing it.
        if not completion.endswith("."):
            completion += "."
        if not completion.endswith("END"):
            completion += " END"
        # Completions must start with a single leading space.
        if not completion.startswith(" "):
            completion = " " + completion
        # Validate token count and compute the stats for this pair.
        total_word_count, total_token_count, estimated_cost = self.check_tokens(prompt, completion)
        filtered_data.append({"prompt": prompt, "completion": completion})
        return filtered_data, total_word_count, total_token_count, estimated_cost

    def check_tokens(self, prompt: str, completion: str) -> Tuple[int, int, float]:
        '''
        Count words and tokens for a prompt/completion pair and estimate cost.

        Args:
            prompt: The prompt text.
            completion: The completion text.

        Returns:
            (total word count, total token count, estimated fine-tuning cost).

        Raises:
            ValueError: If the token count exceeds the configured limit.
        '''
        # Concatenate prompt and completion for a single tokenization pass.
        combined_text = prompt + " " + completion
        # NOTE(review): tiktoken.get_encoding() expects an encoding name such
        # as "cl100k_base"; if openAI.model holds a model name, this should be
        # tiktoken.encoding_for_model() — confirm against the config.
        encoder = tiktoken.get_encoding(self.openai_model)
        total_token_count = len(encoder.encode(combined_text))
        # Enforce the configured limit; report the actual limit rather than a
        # hard-coded "2048" that may disagree with the config.
        if total_token_count > self.token_limit:
            raise ValueError(
                f"Error: The total number of tokens is greater than {self.token_limit}."
            )
        # Cost is priced per 1000 tokens.
        estimated_cost = total_token_count * self.fine_tuning_cost_per_1000_tokens / 1000
        total_word_count = len(combined_text.split())
        return total_word_count, total_token_count, estimated_cost