-
Notifications
You must be signed in to change notification settings - Fork 0
/
modeling.py
295 lines (243 loc) · 15 KB
/
modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the core classes used by DINO.
"""
import math
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any, Union
import openai
import torch
from tqdm import tqdm
from transformers import GPT2Tokenizer, PreTrainedTokenizer, PreTrainedModel
from generation import SelfDebiasingGPT2LMHeadModel
from utils import DatasetEntry
# Marker inside an instruction template that DinoGenerator._build_instruction replaces with the input text.
PLACEHOLDER_STR = "<X1>"
# Characters after which DinoGenerator._process_output may cut an output when the EOS marker is ignored.
SENT_END = {'.', '?', '!', ']', ')'}
class DinoGenerator:
    """
    This class represents a generative language model which can be used to generate datasets from instructions.
    """

    def __init__(self, task_spec: Dict[str, Any], model: Union['str', 'ModelWrapper'] = None, openai_api_key: Optional[str] = None,
                 max_output_length: int = 40, decay_constant: float = 100, top_p: float = 0.9, top_k: int = 5,
                 remove_duplicates: bool = True, remove_identical_pairs: bool = False, min_num_words: int = -1, min_num_tokens: int = -1,
                 keep_outputs_without_eos: bool = False, ignore_eos: bool = False, allow_newlines_in_outputs: bool = False):
        """
        :param task_spec: the task specification; its ``'labels'`` entry maps each label to a dict with an ``'instruction'``
               template and an optional ``'counter_labels'`` list
        :param model: a wrapper around the underlying language model.
               If GPT-3 is used, this should instead be the name of the GPT-3 model (e.g., "davinci")
        :param openai_api_key: an optional API key for GPT-3. If given, GPT-3 is used as a language model
        :param max_output_length: the maximum output length for each generated text
        :param decay_constant: the decay constant for self-debiasing
        :param top_p: p value for top-p sampling (set to 0 to perform no top-p sampling)
        :param top_k: k value for top-k sampling (set to 0 to perform no top-k sampling)
        :param remove_duplicates: whether duplicates should be removed from the generated dataset
        :param remove_identical_pairs: whether text pairs with identical texts should be removed (only for text pair datasets)
        :param min_num_words: the minimum number of (whitespace-separated) words for each dataset entry
        :param min_num_tokens: the minimum number of tokens for each dataset entry (requires ``model`` to be a model wrapper
               with a tokenizer, i.e. it cannot be used with GPT-3)
        :param keep_outputs_without_eos: if set to true, examples where the language model does not output a quotation mark (which is
               interpreted as a signal that it has completed its output) are not removed from the dataset.
        :param ignore_eos: if set to true, the first quotation mark in an output is deleted instead of being used to cut the
               output, and the output is then truncated after its last sentence-ending character (see ``SENT_END``)
        :param allow_newlines_in_outputs: if set to true, model outputs that contain a newline character before the end-of-sequence token
               (a quotation mark) are not removed from the dataset
        """
        self.model = model
        self.openai_api_key = openai_api_key
        self.max_output_length = max_output_length
        self.decay_constant = decay_constant
        self.top_p = top_p
        self.top_k = top_k
        self.remove_duplicates = remove_duplicates
        self.remove_identical_pairs = remove_identical_pairs
        self.min_num_words = min_num_words
        self.min_num_tokens = min_num_tokens
        self.keep_outputs_without_eos = keep_outputs_without_eos
        self.allow_newlines_in_outputs = allow_newlines_in_outputs
        self.ignore_eos = ignore_eos

        self.labels = list(task_spec['labels'].keys())
        self.instructions = {label: task_spec['labels'][label]['instruction'] for label in self.labels}
        self.counter_labels = {label: task_spec['labels'][label].get('counter_labels', []) for label in self.labels}

    def generate_dataset(self, input_texts: Optional[List[str]], num_entries_per_input_and_label: Optional[int] = None,
                         num_entries_per_label: Optional[int] = None, batch_size: Optional[int] = None) -> List[DatasetEntry]:
        """
        Generate a new dataset.

        :param input_texts: an optional list of raw texts; this is required for generating text pair datasets
        :param num_entries_per_input_and_label: the number of entries to generate for each pair of input text and label
               (required if ``input_texts`` is given)
        :param num_entries_per_label: the number of entries to generate for each label (required if ``input_texts`` is None)
        :param batch_size: the number of entries to generate simultaneously (required if ``input_texts`` is None)
        :return: the generated dataset
        :raises ValueError: if the counts required for the selected generation mode are missing or invalid
        """
        generate_with_inputs = input_texts is not None

        if generate_with_inputs:
            if num_entries_per_input_and_label is None:
                raise ValueError("num_entries_per_input_and_label is required when input_texts are given")
        else:
            if num_entries_per_label is None or batch_size is None or batch_size < 1:
                raise ValueError("num_entries_per_label and a positive batch_size are required when no input_texts are given")
            # Without inputs, each "input" is just a batch id; enough batches are generated to reach num_entries_per_label
            # (the last batch may overshoot it).
            input_texts = list(range(math.ceil(num_entries_per_label / batch_size)))
            num_entries_per_input_and_label = batch_size

        input_iterator = tqdm(input_texts, desc="Dataset Entries")
        dataset = []

        for input_text_or_id in input_iterator:
            for label in self.labels:
                dataset += self._generate_dataset_entries(input_text_or_id, label=label, num_entries=num_entries_per_input_and_label,
                                                          generate_with_inputs=generate_with_inputs)

        dataset = self._postprocess_dataset(dataset, generate_with_inputs)
        return dataset

    def _generate_dataset_entries(self, input_text_or_id: Union[str, int], label: str, num_entries: int,
                                  generate_with_inputs: bool) -> List[DatasetEntry]:
        """Generate up to ``num_entries`` entries for one (input text or batch id, label) combination."""
        instruction = self._build_instruction(label, input_text_or_id, generate_with_inputs)

        if self.openai_api_key is not None:
            try:
                # One API request per entry; stop=['"'] makes the API stop at the closing quotation mark.
                model_responses = [openai.Completion.create(
                    engine=self.model, prompt=instruction, max_tokens=self.max_output_length, top_p=self.top_p, stop=['"']
                ) for _ in range(num_entries)]
                model_outputs = [model_response["choices"][0]["text"] for model_response in model_responses]
            except openai.error.RateLimitError as e:
                # Best effort: skip this combination rather than abort the whole dataset.
                print(e)
                return []
        else:
            # Instructions for the counter labels act as debiasing texts for self-debiasing.
            counter_instructions = [
                self._build_instruction(other_label, input_text_or_id, generate_with_inputs) for other_label in self.counter_labels[label]
            ]
            # min_length == max_length requests fixed-length outputs; _process_output cuts them afterwards.
            model_outputs = self.model.generate_self_debiasing(
                input_text=instruction, debiasing_texts=counter_instructions, num_samples=num_entries, decay_constant=self.decay_constant,
                do_sample=True, min_length=self.max_output_length, max_length=self.max_output_length, top_k=self.top_k, top_p=self.top_p
            )

        model_outputs = [
            self._process_output(input_text=input_text_or_id, output_text=output, label=label, generate_with_inputs=generate_with_inputs)
            for output in model_outputs
        ]

        model_outputs = [output for output in model_outputs if output is not None]
        return model_outputs

    def _build_instruction(self, label: str, text: Union[str, int], generate_with_inputs: bool) -> str:
        """
        Return the instruction for ``label``, with the placeholder replaced by ``text`` when generating with inputs.

        :raises ValueError: if the instruction's placeholder count does not match the generation mode
        """
        instruction_template = self.instructions[label]

        if generate_with_inputs:
            # Explicit raise instead of assert so the check survives `python -O`.
            if instruction_template.count(PLACEHOLDER_STR) != 1:
                raise ValueError(
                    f"An input text was provided, but the instruction for label '{label}' does not contain exactly one placeholder")
            return instruction_template.replace(PLACEHOLDER_STR, text)

        if instruction_template.count(PLACEHOLDER_STR) != 0:
            raise ValueError(f"No input text was provided, but the instruction for label '{label}' contains a placeholder")
        return instruction_template

    def _process_output(self, input_text: Union[str, int], output_text: str, label: str, generate_with_inputs: bool) \
            -> Optional[DatasetEntry]:
        """
        Turn a raw model output into a dataset entry, or return None if the output should be discarded.

        Without ``ignore_eos``, the output is cut at its first quotation mark; outputs without one are dropped unless
        ``keep_outputs_without_eos`` is set. With ``ignore_eos``, the first quotation mark is deleted and the text is truncated
        after its last character from ``SENT_END`` (position 0 is never considered). Empty outputs, and outputs containing a
        newline (unless ``allow_newlines_in_outputs`` is set), are dropped.
        """
        if not self.ignore_eos:
            if '"' in output_text:
                output_text = output_text.split('"')[0]
            elif not self.keep_outputs_without_eos:
                return None
        else:
            eos_index = output_text.find('"')
            if eos_index >= 0:  # remove the artificial eos symbol '"'
                output_text = output_text[:eos_index] + output_text[eos_index + 1:]
            last_punct_index = next(
                (i for i in range(len(output_text) - 1, 0, -1) if output_text[i] in SENT_END), -1
            )
            if last_punct_index > 0:
                output_text = output_text[:last_punct_index + 1]

        if output_text and ('\n' not in output_text or self.allow_newlines_in_outputs):
            text_a = input_text if generate_with_inputs else output_text
            text_b = output_text if generate_with_inputs else None
            return DatasetEntry(text_a=text_a, text_b=text_b, label=label)
        return None

    def _postprocess_dataset(self, dataset: List[DatasetEntry], generate_with_inputs: bool) -> List[DatasetEntry]:
        """
        Apply deduplication and the length/identity filters configured in the constructor.

        :raises ValueError: if ``min_num_tokens`` is set but ``self.model`` has no ``_tokenizer`` (e.g. a GPT-3 model name)
        """
        def _generated_text(entry: DatasetEntry) -> str:
            # The model-generated text is text_b for text pair datasets and text_a otherwise.
            return entry.text_b if generate_with_inputs else entry.text_a

        if self.remove_duplicates:
            # dict.fromkeys keeps the first occurrence in generation order. set() iteration order depends on hash values
            # (and presumably DatasetEntry hashes its strings), so it is not reproducible across runs.
            dataset = list(dict.fromkeys(dataset))

        if self.min_num_words > 0:
            dataset = [entry for entry in dataset if len(_generated_text(entry).split()) >= self.min_num_words]

        if self.min_num_tokens > 0:
            tokenizer = getattr(self.model, '_tokenizer', None)
            if tokenizer is None:
                raise ValueError("min_num_tokens requires the model to be a ModelWrapper with a tokenizer")
            dataset = [entry for entry in dataset if len(tokenizer.tokenize(_generated_text(entry))) >= self.min_num_tokens]

        if generate_with_inputs and self.remove_identical_pairs:
            dataset = [entry for entry in dataset if entry.text_a != entry.text_b]

        return dataset
class ModelWrapper(ABC):
    """
    This class represents a wrapper for a pretrained language model that provides high-level functions for the generation of texts with
    the self-debiasing method described in https://arxiv.org/abs/2103.00453.

    Subclasses set ``self._tokenizer`` and ``self._model`` and implement the abstract methods below.
    """

    def __init__(self, use_cuda: bool = True):
        """
        :param use_cuda: whether to use CUDA; if CUDA is not available, the CPU is used regardless of this flag
        """
        self._device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self._tokenizer = None  # type: Optional[PreTrainedTokenizer]
        self._model = None  # type: Optional[PreTrainedModel]

    def query_model(self, input_text: str) -> torch.FloatTensor:
        """For a given input text, returns the probability distribution over possible next tokens."""
        return self.query_model_batch([input_text])[0]

    @abstractmethod
    def query_model_batch(self, input_texts: List[str]) -> torch.FloatTensor:
        """
        For a batch of input texts, returns the probability distribution over possible next tokens.

        One row per input text. Implementations may return unnormalized scores (GPT2Wrapper returns raw logits).
        """
        pass

    @abstractmethod
    def generate(self, input_text: str, **kwargs) -> str:
        """Generates a continuation for a given input text."""
        pass

    @abstractmethod
    def generate_self_debiasing(self, input_text: str, debiasing_texts: List[str], num_samples: int = 1, decay_constant: float = 100,
                                epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
        """
        Generates continuations for the given input text with self-debiasing.

        :param input_text: the input text to generate continuations for
        :param debiasing_texts: the debiasing texts whose continuations are used to debias those of ``input_text``
               (DinoGenerator passes the instructions of the counter labels here)
        :param num_samples: the number of continuations to generate for ``input_text``
        :param decay_constant: the decay constant (lambda in the paper)
        :param epsilon: the minimum factor by which each probability is multiplied
        :param debug: whether to print additional debugging output
        :param kwargs: further arguments are passed on to the original generate function
        :return: the list of generated continuations
        """
        pass
class GPT2Wrapper(ModelWrapper):
    """ModelWrapper around a pretrained GPT-2 model that supports self-debiasing generation."""

    def __init__(self, model_name: str = "gpt2-xl", use_cuda: bool = True):
        """
        :param model_name: the name of the pretrained GPT2 model (default: "gpt2-xl")
        :param use_cuda: whether to use CUDA (falls back to the CPU if CUDA is not available)
        """
        super().__init__(use_cuda=use_cuda)
        self._tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self._model = SelfDebiasingGPT2LMHeadModel.from_pretrained(model_name)  # type: SelfDebiasingGPT2LMHeadModel
        # parallelize() places the model's layers on CUDA devices, so only call it when the base class actually
        # selected CUDA. With use_cuda=True on a machine without CUDA, _device is "cpu" and the model must stay there.
        if self._device == "cuda":
            self._model.parallelize()
        # GPT-2 has no dedicated padding token; the end-of-text token is reused for padding.
        self._tokenizer.pad_token = self._tokenizer.eos_token
        self._model.config.pad_token_id = self._tokenizer.eos_token_id

    def query_model_batch(self, input_texts: List[str]):
        """
        Return, for each input text, the model's raw next-token logits at its last non-padding position.

        The scores are not softmax-normalized.
        """
        # NOTE(review): with padding=True and no explicit truncation flag, max_length=512 may not truncate — confirm
        # against the installed transformers version.
        inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, max_length=512, return_tensors='pt')
        inputs = {key: val.to(self._device) for key, val in inputs.items()}
        # Index of the last real (attention_mask == 1) token of each row; assumes right padding.
        output_indices = inputs['attention_mask'].sum(dim=1) - 1
        output = self._model(**inputs)['logits']
        return torch.stack([output[example_idx, last_word_idx, :] for example_idx, last_word_idx in enumerate(output_indices)])

    def generate(self, input_text: str, **kwargs):
        """Generate a continuation of ``input_text``; the decoded string includes the input tokens."""
        input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
        output_ids = self._model.generate(input_ids, **kwargs)[0]
        return self._tokenizer.decode(output_ids)

    def generate_self_debiasing(self, input_text: str, debiasing_texts: List[str], num_samples: int = 1, decay_constant: float = 100,
                                epsilon: float = 0.01, debug: bool = False, min_length: int = None, max_length: int = None,
                                **kwargs) -> List[str]:
        """
        Generate ``num_samples`` self-debiased continuations of ``input_text``.

        ``min_length``/``max_length`` count only newly generated tokens; they are offset by the (padded) input length here,
        and ``max_length`` is additionally capped at the model's ``max_position_embeddings``. Only the continuations of
        ``input_text`` are returned, without the input tokens.
        """
        self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_texts), decay_constant=decay_constant, epsilon=epsilon,
                                          debug=debug, tokenizer=self._tokenizer)
        # Batch layout: num_samples copies of the input text, followed by num_samples copies of each debiasing text.
        inputs = [input_text] * num_samples
        for debiasing_text in debiasing_texts:
            inputs += [debiasing_text] * num_samples

        inputs = self._tokenizer.batch_encode_plus(inputs, padding=True, return_tensors='pt')
        # Convert right padding into left padding: flip the mask and rotate each row's ids by its padding amount, so
        # all prompts end at the same position and generation continues directly after the real tokens.
        inputs['attention_mask'] = torch.flip(inputs['attention_mask'], dims=[1])
        shifts = inputs['attention_mask'].shape[-1] - inputs['attention_mask'].sum(dim=-1)
        for batch_idx in range(inputs['input_ids'].shape[0]):
            inputs['input_ids'][batch_idx] = inputs['input_ids'][batch_idx].roll(shifts[batch_idx].item())

        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        input_length = inputs['input_ids'].shape[1]
        if min_length is not None:
            min_length = min_length + input_length
        if max_length is not None:
            max_length = min(self._model.config.max_position_embeddings, max_length + input_length)

        output_ids = self._model.generate(**inputs, min_length=min_length, max_length=max_length, no_repeat_ngram_size=3, **kwargs)

        # Keep only the rows for input_text (the first num_samples rows) and strip the prompt tokens.
        batch_size = output_ids.shape[0] // (1 + len(debiasing_texts))
        output_ids = output_ids[:batch_size, inputs['input_ids'].shape[1]:]
        return self._tokenizer.batch_decode(output_ids)