diff --git a/machamp/data/machamp_dataset.py b/machamp/data/machamp_dataset.py index 7bf8596..7099f2a 100644 --- a/machamp/data/machamp_dataset.py +++ b/machamp/data/machamp_dataset.py @@ -87,7 +87,7 @@ def __init__(self, num_s2s = 0 for task in self.datasets[dataset]['tasks']: task_config = self.datasets[dataset]['tasks'][task] - is_clas = task_config['task_type'] in ['classification', 'probdistr', 'regression'] + is_clas = task_config['task_type'] in ['classification', 'probdistr', 'regression', 'multiclas'] read_seq = task_config['column_idx'] == -1 if 'column_idx' in task_config else None if is_clas and not read_seq: @@ -133,6 +133,28 @@ def __init__(self, is_train, max_sents, max_words, max_input_length): self.data[dataset].append(instance) + def task_to_tasktype(self, task: str): + """ + Converts a task-name (str) to its type (str) + + Parameters + ---------- + task: str + The name of the task + + Returns + ------- + task_type: str + The task type of the given task + """ + task_trimmed = task.replace('-heads', '').replace('-rels', '') + if task_trimmed in self.tasks: + index = self.tasks.index(task_trimmed) + else: + logger.error(task + ' not found in ' + str(self.tasks)) + exit(1) + return self.task_types[index] + def __len__(self): """ The length is defined as the combined number of batches diff --git a/machamp/data/machamp_vocabulary.py b/machamp/data/machamp_vocabulary.py index 633ba1d..14d1f44 100644 --- a/machamp/data/machamp_vocabulary.py +++ b/machamp/data/machamp_vocabulary.py @@ -7,10 +7,6 @@ import os -UNK_ID = 0 -UNK = '@@unkORpad@@' - - class MachampVocabulary(): def __init__(self): """ @@ -24,6 +20,8 @@ def __init__(self): self.namespaces = {} self.inverse_namespaces = {} self.hasUnk = {} + self.UNK_ID = 0 + self.UNK = '@@unkORpad@@' def load_vocab(self, vocab_path: str, name: str): """ @@ -70,7 +68,7 @@ def get_unk(self, name: str): name in the namespace. """ if self.hasUnk[name]: - return UNK + return self.UNK def get_unk_id(self, name: str): """ @@ -82,7 +80,7 @@ def get_unk_id(self, name: str): name in the namespace. """ if self.hasUnk[name]: - return UNK_ID + return self.UNK_ID def get_vocab(self, name: str): """ @@ -126,9 +124,9 @@ def token2id(self, token: str, namespace: str, add_if_not_present: bool): self.inverse_namespaces[namespace].append(token) return len(self.inverse_namespaces[namespace]) - 1 else: - return UNK_ID if self.hasUnk[namespace] else None + return self.UNK_ID if self.hasUnk[namespace] else None if self.hasUnk[namespace]: - return self.namespaces[namespace].get(token, UNK_ID) + return self.namespaces[namespace].get(token, self.UNK_ID) else: return self.namespaces[namespace].get(token, None) @@ -162,8 +160,8 @@ def create_vocab(self, name: str, has_unk: bool): Whether this vocabulary should have an unknown/padding token. 
""" if name not in self.namespaces: - self.namespaces[name] = {UNK: UNK_ID} if has_unk else {} - self.inverse_namespaces[name] = [UNK] if has_unk else [] + self.namespaces[name] = {self.UNK: self.UNK_ID} if has_unk else {} + self.inverse_namespaces[name] = [self.UNK] if has_unk else [] self.hasUnk[name] = has_unk def save_vocabs(self, out_dir: str): diff --git a/machamp/metrics/metric.py b/machamp/metrics/metric.py index c5da3ba..5f95659 100644 --- a/machamp/metrics/metric.py +++ b/machamp/metrics/metric.py @@ -1,5 +1,6 @@ import logging +from machamp.metrics.multi_accuracy import MultiAccuracy from machamp.metrics.accuracy import Accuracy from machamp.metrics.avg_dist import AvgDist from machamp.metrics.f1 import F1 @@ -25,6 +26,8 @@ def __init__(self, metric_name: str): self.metrics = {} if metric_name == 'accuracy': self.metrics[metric_name] = Accuracy() + elif metric_name == 'multi_acc': + self.metrics[metric_name] = MultiAccuracy() elif metric_name == 'las': self.metrics[metric_name] = LAS() elif metric_name == 'avg_dist': diff --git a/machamp/metrics/multi_accuracy.py b/machamp/metrics/multi_accuracy.py new file mode 100644 index 0000000..5467ffc --- /dev/null +++ b/machamp/metrics/multi_accuracy.py @@ -0,0 +1,32 @@ +import torch + + +class MultiAccuracy(): + def __init__(self): + self.cor = 0 + self.total = 0 + self.str = 'multi-acc.' + + def score(self, preds, golds, mask, vocabulary): + # TODO: can this be done more efficient? + if len(preds.shape) == 3: + for sent_idx in range(len(mask)): + for word_idx in range(len(mask[sent_idx])): + if mask[sent_idx][word_idx]: + if torch.all(preds[sent_idx][word_idx] == golds[sent_idx][word_idx]): + self.cor += 1 + self.total += 1 + if len(preds.shape) == 2: + for sent_idx in range(len(preds)): + if torch.all(preds[sent_idx] == golds[sent_idx]): + self.cor += 1 + self.total += 1 + + def reset(self): + self.cor = 0 + self.total = 0 + + def get_score(self): + if self.total == 0: + return self.str, 0.0 + return self.str, self.cor / self.total diff --git a/machamp/model/classification_decoder.py b/machamp/model/classification_decoder.py index 36d3360..6e894eb 100644 --- a/machamp/model/classification_decoder.py +++ b/machamp/model/classification_decoder.py @@ -7,7 +7,7 @@ class MachampClassificationDecoder(MachampDecoder, torch.nn.Module): def __init__(self, task, vocabulary, input_dim, device, loss_weight: float = 1.0, topn: int = 1, metric: str = 'accuracy', **kwargs): - super().__init__(task, vocabulary, loss_weight, metric) + super().__init__(task, vocabulary, loss_weight, metric, device) nlabels = len(self.vocabulary.get_vocab(task)) self.hidden_to_label = torch.nn.Linear(input_dim, nlabels) diff --git a/machamp/model/crf_label_decoder.py b/machamp/model/crf_label_decoder.py index 74d3efe..3fbe5ad 100644 --- a/machamp/model/crf_label_decoder.py +++ b/machamp/model/crf_label_decoder.py @@ -22,7 +22,7 @@ def __init__( topn: int = 1, **kwargs ) -> None: - super().__init__(task, vocabulary, loss_weight, metric) + super().__init__(task, vocabulary, loss_weight, metric, device) nlabels = len(self.vocabulary.get_vocab(task)) self.input_dim = input_dim # + dec_dataset_embeds_dim diff --git a/machamp/model/dependency_decoder.py b/machamp/model/dependency_decoder.py index 66c251c..df83b54 100644 --- a/machamp/model/dependency_decoder.py +++ b/machamp/model/dependency_decoder.py @@ -110,9 +110,8 @@ def __init__( arc_representation_dim: int = 768, **kwargs, ) -> None: - super().__init__(task, vocabulary, loss_weight, metric) + super().__init__(task, 
vocabulary, loss_weight, metric, device) - self.device = device self.input_dim = input_dim # + dec_dataset_embeds_dim arc_representation_dim = arc_representation_dim # + dec_dataset_embeds_dim diff --git a/machamp/model/encoder.py b/machamp/model/encoder.py index 280289d..5b7556a 100644 --- a/machamp/model/encoder.py +++ b/machamp/model/encoder.py @@ -13,8 +13,8 @@ class MachampEncoder(): def __init__(self, mlm: AutoModel, max_input_length: int, - padding_token_id: int, - cls_token_id: int): + end_token_id: int, + start_token_id: int): """ The main (shared) encoder of a MachampModel. This class is mainly handling the formatting of the input/output to @@ -29,16 +29,17 @@ def __init__(self, max_input_length: int The maximum input length to the encoder, most of the code in this class is actually to handle this correctly. - padding_token_id: int + end_token_id: int The token id used for padding (behind the input) - cls_token_id: int + start_token_id: int The token id used for the start-of-sentence token (also called the cls token since BERT) """ self.mlm = mlm self.max_input_length = max_input_length - self.padding_token_id = padding_token_id - self.cls_token_id = cls_token_id + self.end_token_id = end_token_id + self.start_token_id = start_token_id + self.num_extra_tokens = 2 - [start_token_id, end_token_id].count(None) def get_size(self, own_size: int, max_size: int): """ @@ -63,7 +64,7 @@ def get_size(self, own_size: int, max_size: int): """ # max(1, ..) is necessary for empty inputs, we do not want # to have 0 splits! - return max(1, math.ceil((own_size - 2) / (max_size - 2))) + return max(1, math.ceil((own_size - self.num_extra_tokens) / (max_size - self.num_extra_tokens))) def run_mlm(self, input_token_ids: torch.tensor, @@ -97,6 +98,8 @@ def run_mlm, args = {'input_ids': input_token_ids, 'attention_mask': subword_mask, 'output_hidden_states': True} if 'token_type_ids' in argspec[0]: args['token_type_ids'] = seg_ids + if 'decoder_input_ids' in argspec[0]: + args['decoder_input_ids'] = seg_ids output = self.mlm.forward(**args) @@ -137,7 +140,9 @@ def embed(self, of memory in the transformers library, for the decoders this matters a lot less, so we can already merge here. For the descriptions of the parameter below, note that max_sent_len_wordpieces is a variable, - depending on the batch. + depending on the batch. We do not use a sliding window at the moment + for readability's sake (we still failed to make the code readable, + unfortunately ;( ). Parameters ---------- @@ -166,52 +171,91 @@ def embed(self, return self.run_mlm(input_token_ids, seg_ids, subword_mask) else: # input is too long, handle: if dont_split: # truncate + # Shall we add the special last token and lose one subword instead?
return self.run_mlm(input_token_ids[:, :self.max_input_length], seg_ids[:, :self.max_input_length], subword_mask[:, :self.max_input_length]) else: # split, embed, merge batch_size = input_token_ids.size(0) - - lengths = [(torch.nonzero(input_token_ids[sent_idx] == self.padding_token_id)[0]).item() + 1 for + if self.end_token_id != None: + lengths = [(torch.nonzero(input_token_ids[sent_idx] == self.end_token_id)[0]).item() + 1 for sent_idx in range(batch_size)] + else: + lengths = [] + for sent_idx in range(batch_size): + if 0 in input_token_ids[sent_idx]: + lengths.append((torch.nonzero(input_token_ids[sent_idx] == 0)[0]).item() + 1) + else: + lengths.append(len(input_token_ids[sent_idx])) + amount_of_splits = [self.get_size(length, self.max_input_length) for length in lengths] new_batch_size = sum(amount_of_splits) - new_input_tokens = torch.full((new_batch_size, self.max_input_length), self.padding_token_id, - device=input_token_ids.device, dtype=torch.int64) + if self.end_token_id != None: + new_input_tokens = torch.full((new_batch_size, self.max_input_length), self.end_token_id, + device=input_token_ids.device, dtype=torch.int64) + else: + new_input_tokens = torch.full((new_batch_size, self.max_input_length), 0, + device=input_token_ids.device, dtype=torch.int64) new_seg_ids = torch.full((new_batch_size, self.max_input_length), 0, device=input_token_ids.device, dtype=torch.int64) - new_subword_mask = torch.full((new_batch_size, self.max_input_length), 0, device=input_token_ids.device, dtype=torch.int64) + if type(subword_mask) != type(None): + new_subword_mask = torch.full((new_batch_size, self.max_input_length), 0, device=input_token_ids.device, dtype=torch.int64) curBatchIdx = 0 for sentIdx in range(batch_size): + # if current sentence < max_len, just copy it if lengths[sentIdx] <= self.max_input_length: new_input_tokens[curBatchIdx][:lengths[sentIdx]] = input_token_ids[sentIdx][:lengths[sentIdx]] new_seg_ids[curBatchIdx][:lengths[sentIdx]] = seg_ids[sentIdx][:lengths[sentIdx]] new_subword_mask[curBatchIdx][:lengths[sentIdx]] = subword_mask[sentIdx][:lengths[sentIdx]] curBatchIdx += 1 else: - # remove special tokens for simplicity, then we can just take max_input_length-2 elements - # for each split (except the last) - token_ids_sent = input_token_ids[sentIdx][1:-1] - seg_ids_sent = seg_ids[sentIdx][1:-1] + # remove special tokens for simplicity, we will add them in each split manually + token_ids_sent = input_token_ids[sentIdx] + seg_ids_sent = seg_ids[sentIdx] if type(subword_mask) != type(None): - subword_mask_sent = subword_mask[sentIdx][1:-1] + subword_mask_sent = subword_mask[sentIdx] + + if self.start_token_id != None: + token_ids_sent = token_ids_sent[1:] + seg_ids_sent = seg_ids_sent[1:] + if type(subword_mask) != type(None): + subword_mask_sent = subword_mask_sent[1:] + if self.end_token_id != None: + token_ids_sent = token_ids_sent[:-1] + seg_ids_sent = seg_ids_sent[:-1] + if type(subword_mask) != type(None): + subword_mask_sent = subword_mask_sent[:-1] + for split in range(amount_of_splits[sentIdx]): - beg = (self.max_input_length - 2) * split + beg = (self.max_input_length - self.num_extra_tokens) * split if split + 1 == amount_of_splits[sentIdx]: - end = lengths[sentIdx]-2 + end = lengths[sentIdx]-self.num_extra_tokens + else: + end = (self.max_input_length - self.num_extra_tokens) * (split + 1) + if self.start_token_id != None: + new_input_tokens[curBatchIdx][1:end - beg + 1] = token_ids_sent[beg:end] + new_input_tokens[curBatchIdx][0] = self.start_token_id + 
new_seg_ids[curBatchIdx][1:end - beg + 1] = seg_ids_sent[beg:end] + new_seg_ids[curBatchIdx][0] = new_seg_ids[curBatchIdx][1] + new_subword_mask[curBatchIdx][0] = 1 + new_subword_mask[curBatchIdx][1:end - beg + 1] = subword_mask_sent[beg:end] else: - end = (self.max_input_length - 2) * (split + 1) - new_input_tokens[curBatchIdx][1:end - beg + 1] = token_ids_sent[beg:end] - new_input_tokens[curBatchIdx][0] = self.cls_token_id - new_seg_ids[curBatchIdx][1:end - beg + 1] = seg_ids_sent[beg:end] - new_seg_ids[curBatchIdx][0] = new_seg_ids[curBatchIdx][1] - new_subword_mask[curBatchIdx][0] = 1 - new_subword_mask[curBatchIdx][1:end - beg + 1] = subword_mask_sent[beg:end] + new_input_tokens[curBatchIdx][:end - beg] = token_ids_sent[beg:end] + new_seg_ids[curBatchIdx][:end - beg] = seg_ids_sent[beg:end] + new_subword_mask[curBatchIdx][:end - beg] = subword_mask_sent[beg:end] + curBatchIdx += 1 - # would it make sense to split it first?, instead of 35*max_len, have 32*max_len and 3*max_len - # and then run the mlm twice? - # AllenNLP doesn't to do this, and its much easier without, so for now we leave it + # We make the batches longer, but this has no (or only a little) + # effect on memory usage, as a maximum number of words per + # batch is used mlm_out_split, mlm_preds = self.run_mlm(new_input_tokens, new_seg_ids, new_subword_mask) - mlm_out_merged = torch.zeros(batch_size, input_token_ids.size(1), mlm_out_split.size(-1), + if self.end_token_id != None: + mlm_out_merged = torch.full((batch_size, input_token_ids.size(1), mlm_out_split.size(-1)), self.end_token_id, + device=input_token_ids.device) + else: + mlm_out_merged = torch.zeros(batch_size, input_token_ids.size(1), mlm_out_split.size(-1), device=input_token_ids.device) splitted_idx = 0 for sent_idx in range(batch_size): @@ -219,23 +263,32 @@ def embed(self, mlm_out_merged[sent_idx][0:lengths[sent_idx]] = mlm_out_split[splitted_idx][0:lengths[sent_idx]] splitted_idx += 1 else: - # first of the splits, keep the CLS - mlm_out_merged[sent_idx][0:self.max_input_length - 1] = mlm_out_split[splitted_idx][ - 0:self.max_input_length - 1] + # first of the splits, keep as is + num_subwords = self.max_input_length + if self.end_token_id != None: + num_subwords -= 1 + mlm_out_merged[sent_idx][0:num_subwords] = mlm_out_split[splitted_idx][ + 0:num_subwords] + splitted_idx += 1 - # all except first and last, only keep the body (not CLS, not SEP) + # all except first and last, has no CLS/SEP for i in range(1, amount_of_splits[sent_idx] - 1): - beg = i * ( - self.max_input_length - 2) - 1 # -1 because the first line doesnt have a SEP, -2 because we do not need CLS and SEP from each split - end = beg + self.max_input_length - 2 - mlm_out_merged[sent_idx][beg:end] = mlm_out_split[splitted_idx][1:-1] + beg = num_subwords + (i-1) * (self.max_input_length) + end = beg + self.max_input_length - self.num_extra_tokens + mlm_out_cursplit = mlm_out_split[splitted_idx] + if self.end_token_id != None: + mlm_out_cursplit = mlm_out_cursplit[:-1] + if self.start_token_id != None: + mlm_out_cursplit = mlm_out_cursplit[1:] + + mlm_out_merged[sent_idx][beg:end] = mlm_out_cursplit splitted_idx += 1 # last of the splits, keep the SEP - beg = (amount_of_splits[sent_idx] - 1) * (self.max_input_length - 2) - 1 - end = lengths[sent_idx]-1 + beg = num_subwords + (amount_of_splits[sent_idx] - 2) * 
(self.max_input_length - self.num_extra_tokens) + end = lengths[sent_idx] mlm_out_merged[sent_idx][beg:end] = mlm_out_split[splitted_idx][0:end - beg] splitted_idx += 1 # Note that mlm_preds is not split. This is an error/bug, but we hardcoded that for the MLM - # task, splitting shouldn't happen, so it will never occur in practice + # task splitting shouldn't happen, so it will never occur in practice return mlm_out_merged, mlm_preds diff --git a/machamp/model/machamp.py b/machamp/model/machamp.py index 3f25630..50242d9 100644 --- a/machamp/model/machamp.py +++ b/machamp/model/machamp.py @@ -14,9 +14,12 @@ from machamp.model.classification_decoder import MachampClassificationDecoder from machamp.model.regression_decoder import MachampRegressionDecoder from machamp.model.seq_label_decoder import MachampSeqDecoder +from machamp.model.multiseq_decoder import MachampMultiseqDecoder from machamp.model.crf_label_decoder import MachampCRFDecoder from machamp.model.dependency_decoder import MachampDepDecoder from machamp.model.mlm_decoder import MachampLMDecoder +from machamp.model.multiclas_decoder import MachampMulticlasDecoder + from machamp.model.encoder import MachampEncoder from machamp.metrics.avg_dist import AvgDist from machamp.metrics.perplexity import Perplexity @@ -117,7 +120,11 @@ def __init__(self, else: self.dropout = torch.nn.Dropout(dropout) - self.encoder = MachampEncoder(self.mlm, max_input_length, tokenizer.sep_token_id, tokenizer.cls_token_id) + tokenizer_out = tokenizer.prepare_for_model([])['input_ids'] + # we assume that if there is only one special token that it is the end token + self.end_token = None if len(tokenizer_out) == 0 else tokenizer_out[-1] + self.start_token = None if len(tokenizer_out) <= 1 else tokenizer_out[0] + self.encoder = MachampEncoder(self.mlm, max_input_length, self.end_token, self.start_token) self.decoders = torch.nn.ModuleDict() for task, task_type in zip(self.tasks, self.task_types): @@ -136,6 +143,10 @@ def __init__(self, decoder_type = MachampRegressionDecoder elif task_type == 'mlm': decoder_type = MachampLMDecoder + elif task_type == 'multiseq': + decoder_type = MachampMultiseqDecoder + elif task_type == 'multiclas': + decoder_type = MachampMulticlasDecoder else: logger.error('Error, task_type ' + task_type + ' not implemented') exit(1) @@ -205,9 +216,8 @@ def forward(self, cur_task_types = self.task_types is_only_mlm = sum([task_type != 'mlm' for task_type in cur_task_types]) == 0 is_only_classification = sum( - [task_type not in ['classification', 'regression'] for task_type in cur_task_types]) == 0 + [task_type not in ['classification', 'regression', 'multiclas'] for task_type in cur_task_types]) == 0 dont_split = is_only_mlm or is_only_classification - # Run transformer model on input mlm_out, mlm_preds = self.encoder.embed(input_token_ids, seg_ids, dont_split, subword_mask) @@ -216,8 +226,8 @@ def forward(self, mlm_out_tok = None - if 'classification' in self.task_types or 'regression' in self.task_types: - mlm_out_sent = mlm_out[:, :1, :].squeeze() + if 'classification' in self.task_types or 'regression' in self.task_types or 'multiclas' in self.task_types: + mlm_out_sent = mlm_out[:, :1, :].squeeze() # always take first token, even if it is not a special token if self.dropout != None: mlm_out_sent = self.dropout(mlm_out_sent) @@ -229,16 +239,21 @@ def forward(self, mlm_out_token = self.dropout(mlm_out_token) if 'tok' in self.task_types: - mlm_out_tok = self.dropout(mlm_out[:, 1:-1, :]) + mlm_out_tok = mlm_out + if self.start_token != 
None: + mlm_out_tok = mlm_out_tok[:,1:,:] + if self.end_token != None: + mlm_out_tok = mlm_out_tok[:,:-1,:] if self.dropout != None: mlm_out_tok = self.dropout(mlm_out_tok) # get loss from all decoders that have annotations loss = 0.0 + loss_dict = {} if golds != {}: for task, task_type in zip(self.tasks, self.task_types): if task in golds or task + '-rels' in golds: - if task_type in ['classification', 'regression']: + if task_type in ['classification', 'regression', 'multiclas']: out_dict = self.decoders[task].forward(mlm_out_sent, eval_mask, golds[task]) elif task_type == 'dependency': out_dict = self.decoders[task].forward(mlm_out_token, eval_mask, golds[task + '-heads'], @@ -253,7 +268,8 @@ def forward(self, else: out_dict = self.decoders[task].forward(mlm_out_token, eval_mask, golds[task]) loss += out_dict['loss'] - return loss, mlm_out_token, mlm_out_sent, mlm_out_tok + loss_dict[task] = out_dict['loss'].item() + return loss, mlm_out_token, mlm_out_sent, mlm_out_tok, mlm_preds, loss_dict def get_output_labels(self, input_token_ids: torch.tensor, @@ -300,7 +316,7 @@ def get_output_labels(self, (lists of) the outputs for this task. """ # Run transformer model on input - _, mlm_out_token, mlm_out_sent, mlm_out_tok = self.forward(input_token_ids, {}, seg_ids, eval_mask, offsets, + _, mlm_out_token, mlm_out_sent, mlm_out_tok, mlm_preds, _ = self.forward(input_token_ids, {}, seg_ids, eval_mask, offsets, subword_mask, True) out_dict = {} has_tok = 'tok' in self.task_types @@ -328,7 +344,7 @@ def get_output_labels(self, for task, task_type in zip(self.tasks, self.task_types): - if task_type in ['classification', 'regression']: + if task_type in ['classification', 'regression', 'multiclas']: out_dict[task] = self.decoders[task].get_output_labels(mlm_out_sent, eval_mask, golds[task]) elif self.task_types[self.tasks.index(task)] == 'dependency': if has_tok: diff --git a/machamp/model/machamp_decoder.py b/machamp/model/machamp_decoder.py index 2b2866a..4a6baf6 100644 --- a/machamp/model/machamp_decoder.py +++ b/machamp/model/machamp_decoder.py @@ -4,13 +4,14 @@ class MachampDecoder(torch.nn.Module): - def __init__(self, task, vocabulary, loss_weight: float = 1.0, metric: str = 'avg_dist'): + def __init__(self, task, vocabulary, loss_weight: float = 1.0, metric: str = 'avg_dist', device: str = 'cpu'): super().__init__() self.task = task self.vocabulary = vocabulary self.metric = Metric(metric) self.loss_weight = loss_weight + self.device = device def reset_metrics(self): self.metric.reset() diff --git a/machamp/model/mlm_decoder.py b/machamp/model/mlm_decoder.py index 36c1289..8ee54c5 100644 --- a/machamp/model/mlm_decoder.py +++ b/machamp/model/mlm_decoder.py @@ -15,7 +15,7 @@ def __init__( topn: int = 1, **kwargs ) -> None: - super().__init__(task, vocabulary, loss_weight, metric) + super().__init__(task, vocabulary, loss_weight, metric, device) self.input_dim = input_dim # + dec_dataset_embeds_dim self.loss_function = torch.nn.CrossEntropyLoss() @@ -31,5 +31,5 @@ def forward(self, mlm_preds, gold, mask=None): return {'loss': self.loss_weight * lm_loss} def get_output_labels(self, mlm_out, gold, mask=None): - forward(mlm_out, gold) + self.forward(mlm_out, gold) return {'word_labels': [], 'probs': []} diff --git a/machamp/model/multiclas_decoder.py b/machamp/model/multiclas_decoder.py new file mode 100644 index 0000000..fc3e7ee --- /dev/null +++ b/machamp/model/multiclas_decoder.py @@ -0,0 +1,49 @@ +import torch +import torch.nn.functional as F + +from machamp.model.machamp_decoder import 
MachampDecoder + + +class MachampMulticlasDecoder(MachampDecoder, torch.nn.Module): + def __init__(self, task, vocabulary, input_dim, device, loss_weight: float = 1.0, topn: int = 1, + metric: str = 'accuracy', threshold: float = .0, **kwargs): + super().__init__(task, vocabulary, loss_weight, metric, device) + + nlabels = len(self.vocabulary.get_vocab(task)) + self.hidden_to_label = torch.nn.Linear(input_dim, nlabels) + self.hidden_to_label.to(device) + self.loss_function = torch.nn.BCEWithLogitsLoss() + self.topn = topn + self.threshold = threshold + + def forward(self, mlm_out, mask, gold=None): + logits = self.hidden_to_label(mlm_out) + out_dict = {'logits': logits} + if type(gold) != type(None): + preds = logits > self.threshold + self.metric.score(preds[:,1:], gold.eq(torch.tensor(1.0, device=self.device))[:,1:], mask, None) + out_dict['loss'] = self.loss_weight * self.loss_function(logits[:,1:], gold.to(torch.float32)[:,1:]) + return out_dict + + def get_output_labels(self, mlm_out, mask, gold=None): + logits = self.forward(mlm_out, mask, gold)['logits'] + if self.topn == 1: + all_labels = [] + preds = logits > self.threshold + for sent_idx in range(len(preds)): + sent_labels = [] + for label_idx in range(1,len(preds[sent_idx])): + if preds[sent_idx][label_idx]: + sent_labels.append(self.vocabulary.id2token(label_idx, self.task)) + all_labels.append('|'.join(sent_labels)) + return {'sent_labels': all_labels} + + else: # TODO implement top-n + labels = [] + probs = [] + class_probs = F.softmax(logits, -1) + for sent_scores in class_probs: + topk = torch.topk(sent_scores[1:], self.topn) + labels.append([self.vocabulary.id2token(label_id + 1, self.task) for label_id in topk.indices]) + probs.append([score.item() for score in topk.values]) + return {'sent_labels': labels, 'probs': probs} diff --git a/machamp/model/multiseq_decoder.py b/machamp/model/multiseq_decoder.py new file mode 100644 index 0000000..807c7f9 --- /dev/null +++ b/machamp/model/multiseq_decoder.py @@ -0,0 +1,75 @@ +import torch +import torch.nn.functional as F + +from machamp.model.machamp_decoder import MachampDecoder + + +class MachampMultiseqDecoder(MachampDecoder, torch.nn.Module): + def __init__( + self, + task: str, + vocabulary, + input_dim: int, + device: str, + loss_weight: float = 1.0, + metric: str = 'accuracy', + topn: int = 1, + threshold: float = .0, + **kwargs + ) -> None: + super().__init__(task, vocabulary, loss_weight, metric, device) + + nlabels = len(self.vocabulary.get_vocab(task)) + self.input_dim = input_dim # + dec_dataset_embeds_dim + self.hidden_to_label = torch.nn.Linear(input_dim, nlabels) + self.hidden_to_label.to(device) + self.loss_function = torch.nn.BCEWithLogitsLoss() + self.threshold = threshold + self.topn = topn + + def forward(self, mlm_out, mask, gold=None): + logits = self.hidden_to_label(mlm_out) + out_dict = {'logits': logits} + if type(gold) != type(None): + # convert scores to binary: + preds = logits > self.threshold + self.metric.score(preds[:,:,1:], gold.eq(torch.tensor(1.0, device=self.device))[:,:,1:], mask, self.vocabulary.inverse_namespaces[self.task]) + loss = self.loss_weight * self.loss_function(logits[:,:,1:], gold.to(torch.float32)[:,:,1:]) + out_dict['loss'] = loss + return out_dict + + def get_output_labels(self, mlm_out, mask, gold=None): + """ + logits = batch_size*sent_len*num_labels + thresholding converts this to a list of batch_size*sent_len, + we add 1 because we leave out the padding/unk + token in position 0 (that's what [:,:,1:] does) + """ + + logits = 
self.forward(mlm_out, mask, gold)['logits'] + if self.topn == 1: + all_labels = [] + preds = logits > self.threshold + all_labels = [] + for sent_idx in range(len(preds)): + sent_labels = [] + for word_idx in range(len(preds[sent_idx])): + word_labels = [] + for label_idx in range(1, len(preds[sent_idx][word_idx])): + if preds[sent_idx][word_idx][label_idx]: + word_labels.append(self.vocabulary.id2token(label_idx, self.task)) + sent_labels.append('|'.join(word_labels)) + all_labels.append(sent_labels) + return {'word_labels': all_labels} + else: # TODO implement topn? + tags = [] + probs = [] + class_probs = F.softmax(logits, -1) + for sent_scores in class_probs: + tags.append([]) + probs.append([]) + for word_scores in sent_scores: + topk = torch.topk(word_scores[1:], self.topn) + tags[-1].append([self.vocabulary.id2token(label_id + 1, self.task) for label_id in topk.indices]) + probs[-1].append([score.item() for score in topk.values]) + return {'word_labels': tags, 'probs': probs} diff --git a/machamp/model/regression_decoder.py b/machamp/model/regression_decoder.py index c8857e9..3d75acb 100644 --- a/machamp/model/regression_decoder.py +++ b/machamp/model/regression_decoder.py @@ -6,7 +6,7 @@ class MachampRegressionDecoder(MachampDecoder, torch.nn.Module): def __init__(self, task, vocabulary, input_dim, device, loss_weight: float = 1.0, topn: int = 1, metric: str = 'avg_dist', **kwargs): - super().__init__(task, vocabulary, loss_weight, metric) + super().__init__(task, vocabulary, loss_weight, metric, device) self.hidden_to_label = torch.nn.Linear(input_dim, 1) self.hidden_to_label.to(device) diff --git a/machamp/model/seq_label_decoder.py b/machamp/model/seq_label_decoder.py index 5cb3942..c94554e 100644 --- a/machamp/model/seq_label_decoder.py +++ b/machamp/model/seq_label_decoder.py @@ -16,7 +16,7 @@ def __init__( topn: int = 1, **kwargs ) -> None: - super().__init__(task, vocabulary, loss_weight, metric) + super().__init__(task, vocabulary, loss_weight, metric, device) nlabels = len(self.vocabulary.get_vocab(task)) self.input_dim = input_dim # + dec_dataset_embeds_dim diff --git a/machamp/model/trainer.py b/machamp/model/trainer.py index df16a92..15d6a2f 100644 --- a/machamp/model/trainer.py +++ b/machamp/model/trainer.py @@ -148,6 +148,7 @@ def train( model.train() model.reset_metrics() epoch_loss = 0.0 + total_train_losses = {} for train_batch_idx, batch in enumerate(tqdm(train_dataloader, file=sys.stdout)): optimizer.zero_grad() @@ -155,8 +156,12 @@ def train( # gpu ram, it is quite fast anyways batch = myutils.prep_batch(batch, device, train_dataset) - loss, _, _, _ = model.forward(batch['token_ids'], batch['golds'], batch['seg_ids'], batch['eval_mask'], + loss, _, _, _, _, loss_dict = model.forward(batch['token_ids'], batch['golds'], batch['seg_ids'], batch['eval_mask'], batch['offsets'], batch['subword_mask']) + for task in loss_dict: + if task not in total_train_losses: + total_train_losses[task] = 0.0 + total_train_losses[task] += loss_dict[task] loss.backward() optimizer.step() epoch_loss += loss.item() @@ -169,12 +174,17 @@ def train( dev_loss = 0.0 dev_metrics = {} dev_batch_idx = 1 + total_dev_losses = {} if len(dev_dataset) > 0: for dev_batch_idx, batch in enumerate(tqdm(dev_dataloader, file=sys.stdout)): batch = myutils.prep_batch(batch, device, train_dataset) - loss, _, _, _ = model.forward(batch['token_ids'], batch['golds'], batch['seg_ids'], batch['eval_mask'], + loss, _, _, _, _, loss_dict = model.forward(batch['token_ids'], batch['golds'], batch['seg_ids'], 
batch['eval_mask'], batch['offsets'], batch['subword_mask']) + for task in loss_dict: + if task not in total_dev_losses: + total_dev_losses[task] = 0.0 + total_dev_losses[task] += loss_dict[task] dev_loss += loss.item() dev_metrics = model.get_metrics() @@ -188,7 +198,7 @@ def train( if train_batch_idx == 0: train_batch_idx = 1 info_dict = myutils.report_epoch(epoch_loss / train_batch_idx, dev_loss / dev_batch_idx, epoch, train_metrics, - dev_metrics, epoch_start_time, start_training_time) + dev_metrics, epoch_start_time, start_training_time, device, total_train_losses, total_dev_losses) json.dump(info_dict, open(os.path.join(serialization_dir, 'metrics_epoch_' + str(epoch) + '.json'), 'w'), indent=4) diff --git a/machamp/modules/allennlp/chu_liu_edmonds.py b/machamp/modules/allennlp/chu_liu_edmonds.py index 4f42514..5f16914 100644 --- a/machamp/modules/allennlp/chu_liu_edmonds.py +++ b/machamp/modules/allennlp/chu_liu_edmonds.py @@ -151,6 +151,8 @@ def chu_liu_edmonds( has_cycle, cycle = _find_cycle(parents, length, current_nodes) # If there are no cycles, find all edges and return. if not has_cycle: + foundRoot = False + root = -1 final_edges[0] = -1 for node in range(1, length): if not current_nodes[node]: @@ -158,6 +160,16 @@ def chu_liu_edmonds( parent = old_input[parents[node], node] child = old_output[parents[node], node] + # Rob: Added this to fix double roots + # However, it should be fixed better, this just + # connects all roots to the first found root + if parent == 0: + if foundRoot == True: + final_edges[child] = root + continue + else: + foundRoot = True + root = child final_edges[child] = parent return diff --git a/machamp/predictor/predict.py b/machamp/predictor/predict.py index d3c807e..90493f7 100644 --- a/machamp/predictor/predict.py +++ b/machamp/predictor/predict.py @@ -81,7 +81,7 @@ def to_string(full_data: List[Any], # For word level annotation tasks, we have a different handling # so first detect whether we only have sentence level tasks task_types = [config['tasks'][task]['task_type'] for task in config['tasks']] - only_sent = sum([task_type in ['classification', 'regression'] for task_type in task_types]) == len(config['tasks']) + only_sent = sum([task_type in ['classification', 'regression', 'multiclas'] for task_type in task_types]) == len(config['tasks']) # from transformers import AutoTokenizer # tokzr = AutoTokenizer.from_pretrained('bert-base-multilingual-cased') if only_sent: @@ -92,6 +92,7 @@ def to_string(full_data: List[Any], else: full_data[task_idx] = preds[task]['sent_labels'] return '\t'.join(full_data) + else: # word level annotation has_tok = 'tok' in task_types if has_tok: diff --git a/machamp/readers/read_classification.py b/machamp/readers/read_classification.py index f58adb2..75793e5 100644 --- a/machamp/readers/read_classification.py +++ b/machamp/readers/read_classification.py @@ -92,6 +92,9 @@ def read_classification( subword_counter = 0 unk_counter = 0 test_tok = tokenizer.encode_plus('a', 'b') + has_start_token = len(tokenizer.prepare_for_model([])['input_ids']) == 2 + has_end_token = len(tokenizer.prepare_for_model([])['input_ids']) >= 1 + has_unk_token = tokenizer.unk_token != None has_seg_ids = 'token_type_ids' in test_tok and 1 in test_tok['token_type_ids'] if 'skip_first_line' not in config: config['skip_first_line'] = False @@ -101,7 +104,7 @@ def read_classification( # We use the following format # input: sent1 sent2 sent3 ... # type_ids: 0 0 .. 1 1 .. 0 0 .. 1 1 .. 
- full_input = [tokenizer.cls_token_id] + full_input = [] seg_ids = [0] for counter, sent_idx in enumerate(sent_idxs): if sent_idx >= len(data_instance): @@ -109,17 +112,27 @@ 'line ' + dataset + ':' + str(sent_idx) + ' doesnt\'t contain enough columns, column ' + str( sent_idx) + ' is missing, should contain input.') exit(1) - new_sent = tokenizer.encode(data_instance[sent_idx].strip())[1:-1] + [copy.deepcopy(tokenizer.sep_token_id)] - subword_counter += len(new_sent) - 1 - if len(new_sent) == 1: + encoding = tokenizer.encode(data_instance[sent_idx].strip()) + if has_start_token: + encoding = encoding[1:] + if has_end_token: + encoding = encoding[:-1] + if tokenizer.sep_token_id != None: + encoding = encoding + [copy.deepcopy(tokenizer.sep_token_id)] + subword_counter += len(encoding) + if len(encoding) == 0: logger.warning("empty sentence found in line " + str( sent_counter) + ', column ' + sent_idx + ' replaced with UNK token') - new_sent.insert(0, tokenizer.unk_token_id) + if has_unk_token: + encoding.insert(0, tokenizer.unk_token_id) if has_seg_ids: - seg_ids.extend([counter % 2] * len(new_sent)) - full_input.extend(new_sent) + seg_ids.extend([counter % 2] * len(encoding)) + full_input.extend(encoding) unk_counter += full_input.count(tokenizer.unk_token_id) + if has_end_token: + full_input = full_input[:-1] + full_input = tokenizer.prepare_for_model(full_input)['input_ids'] full_input = torch.tensor(full_input, dtype=torch.long) seg_ids = torch.tensor(seg_ids, dtype=torch.long) @@ -153,6 +166,8 @@ logger.error('Column ' + str(col_idxs[task]) + ' in ' + dataset + ':' + str( sent_idx) + " should have a float (for regression task)") exit(1) + elif task_type == 'multiclas': + gold = torch.tensor([vocabulary.token2id(label, task, is_train) for label in gold.split('|')], dtype=torch.long) else: gold = vocabulary.token2id(gold, task, is_train) col_idxs[task] = task_idx diff --git a/machamp/readers/read_mlm.py b/machamp/readers/read_mlm.py index 088e481..62087af 100644 --- a/machamp/readers/read_mlm.py +++ b/machamp/readers/read_mlm.py @@ -60,6 +60,7 @@ def read_mlm( sent_counter = 0 unk_counter = 0 subword_counter = 0 + has_unk = tokenizer.unk_token != None masker = DataCollatorForLanguageModeling(tokenizer) if len(config['tasks']) > 1: @@ -76,14 +77,15 @@ def read_mlm( # truncate too long sentences if len(token_ids) >= max_input_length: - token_ids = token_ids[list(range(127)) + [len(token_ids) - 1]] + token_ids = token_ids[list(range(max_input_length-1)) + [len(token_ids) - 1]] # skip empty lines if len(token_ids) <= 2: continue sent_counter += 1 - unk_counter += sum(token_ids == tokenizer.unk_token_id) + if has_unk: + unk_counter += sum(token_ids == tokenizer.unk_token_id) subword_counter += len(token_ids) - 2 # if index = -1, the dataset name is used, and this is handled in the superclass diff --git a/machamp/readers/read_sequence.py b/machamp/readers/read_sequence.py index 33a3a71..b5f6e26 100644 --- a/machamp/readers/read_sequence.py +++ b/machamp/readers/read_sequence.py @@ -75,7 +75,7 @@ def seqs2data(tabular_file: str, skip_first_line: bool = False): yield sent[beg_idx:], sent -def tokenize_simple(tokenizer: AutoTokenizer, sent: List[List[str]], word_col_idx: int): +def tokenize_simple(tokenizer: AutoTokenizer, sent: List[List[str]], word_col_idx: int, num_special_tokens: int, has_unk: bool): """ A tokenizer that tokenizes each token separately (over gold tokenization).
We found that this is the most robust method to tokenize overall (handling @@ -88,8 +88,13 @@ def tokenize_simple(tokenizer: AutoTokenizer, sent: List[List[str]], word_col_id sent: List[List[str]]: Contains all information of the tokens (also annotation), hence a list of lists. - word_col_idx: int): + word_col_idx: int: The column index that contains the input words. + num_special_toks: int + Number of special tokens, here assumed to be 2 (start/end token) or 1 + (only end token) + has_unk: bool + Does the tokenizer have an unk token Returns ------- @@ -105,11 +110,20 @@ def tokenize_simple(tokenizer: AutoTokenizer, sent: List[List[str]], word_col_id for token_idx in range(len(sent)): # TODO remove hardcoded special start-end token, which some don't have (i.e. google/mt5-base) # we do not use return_tensors='pt' because we do not know the length beforehand - tokked = tokenizer.encode(sent[token_idx][word_col_idx])[1:-1] - if len(tokked) == 0: + if num_special_tokens == 2: + tokked = tokenizer.encode(sent[token_idx][word_col_idx])[1:-1] + elif num_special_tokens == 1: + # We assume that if there is only one special token, it is the end token + tokked = tokenizer.encode(sent[token_idx][word_col_idx])[:-1] + elif num_special_tokens == 0: + tokked = tokenizer.encode(sent[token_idx][word_col_idx]) + else: + logger.error('Number of special tokens is currently not handled: ' + str(num_special_tokens)) + exit(1) + if len(tokked) == 0 and has_unk: tokked = [tokenizer.unk_token_id] token_ids.extend(tokked) - offsets.append(len(token_ids)) + offsets.append(len(token_ids)-1) offsets = torch.tensor(offsets, dtype=torch.long) return token_ids, offsets @@ -153,15 +167,11 @@ def get_offsets(gold_tok: List[str], subwords: List[str], norm: bool): tok_labels = [] if norm: gold_tok = [unicodedata.normalize('NFC', unicodedata.normalize('NFKD', myutils.clean_text(tok))) for tok in gold_tok] - #print(gold_tok) - #print(subwords) for word in gold_tok: gold_char_idx += len(word) - #print(word, gold_char_idx) while subword_char_idx < gold_char_idx: # links to the last subword if there is no exact match subword_char_idx += len(subwords[subword_idx].replace(' ', '')) - #print('-', subwords[subword_idx], subword_char_idx) subword_idx += 1 if subword_char_idx < gold_char_idx: tok_labels.append('merge') @@ -380,7 +390,9 @@ def read_sequence( word_counter = 0 unk_counter = 0 subword_counter = 0 + has_unk = tokenizer.unk_token_id != None has_tok_task = 'tok' in [config['tasks'][task]['task_type'] for task in config['tasks']] + num_special_tokens = len(tokenizer.prepare_for_model([])['input_ids']) if has_tok_task: pre_tokenizer = BasicTokenizer(strip_accents=False, do_lower_case=False, tokenize_chinese_chars=True) @@ -401,7 +413,7 @@ def read_sequence( myutils.clean_text(line[word_col_idx]) for line in sent], pre_tokenizer, tokenizer) else: - token_ids, offsets = tokenize_simple(tokenizer, sent, word_col_idx) + token_ids, offsets = tokenize_simple(tokenizer, sent, word_col_idx, num_special_tokens, has_unk) no_unk_subwords = None token_ids = tokenizer.prepare_for_model(token_ids, return_tensors='pt')['input_ids'] @@ -445,13 +457,14 @@ def read_sequence( # Special handling for multiseq, as it required a different labelfield - # if task_type == 'multiseq': - # label_sequence = [] - # # For each token label, check if it is a multilabel and handle it - # for raw_label in labels: - # label_list = raw_label.split("|") - # label_sequence.append(label_list) - # instance.add_field(task, SequenceMultiLabelField(label_sequence, 
input_field, label_namespace=task)) + elif task_type == 'multiseq': + label_sequence = [] + for token_info in sent: + label_list = token_info[task_idx].split("|") + label_sequence.append([vocabulary.token2id(label, task, is_train) for label in label_list]) + max_labels = max([len(label) for label in label_sequence]) + padded_label_sequence = [labels + [vocabulary.UNK_ID] * (max_labels-len(labels)) for labels in label_sequence] + golds[task] = torch.tensor(padded_label_sequence, dtype=torch.long) else: golds[task] = torch.tensor( [vocabulary.token2id(token_info[task_idx], task, is_train) for token_info in sent], @@ -522,16 +535,13 @@ def read_sequence( # other tasks. no_mapping = False for task in golds: - if len(token_ids) - 2 < len(golds[task]): + if len(token_ids) - num_special_tokens < len(golds[task]): no_mapping = True if no_mapping: print('skip') # TODO - # print(len(golds['upos']), len(offsets)) - # print(offsets) - # print(token_ids) continue - - unk_counter += sum(token_ids == tokenizer.unk_token_id) + if has_unk: + unk_counter += sum(token_ids == tokenizer.unk_token_id) subword_counter += len(token_ids) - 2 word_counter += len(offsets) if max_words != -1 and word_counter > max_words and is_train: diff --git a/machamp/utils/myutils.py b/machamp/utils/myutils.py index f8ab37d..aec5837 100644 --- a/machamp/utils/myutils.py +++ b/machamp/utils/myutils.py @@ -110,11 +110,13 @@ def prep_batch( batch_offsets = None batch_eval_mask = None - # Instead of focusing on task_types, we use this dimension - # check to decide the dimensions. This should be more - # robust/easier to maintain. (Word-level tasks have lists - # of annotations) - has_word_level = True in [type(batch[0].golds[task]) in [torch.Tensor, torch.tensor, list] for task in batch[0].golds] + # Assuming here that batches are homogeneous, only checking + # the first element. 
+ has_word_level = False + for task in batch[0].golds: + task_type = dataset.task_to_tasktype(task) + if task_type in ['seq', 'multiseq', 'seq_bio', 'tok', 'dependency', 'string2string', 'mlm']: + has_word_level = True if has_word_level: max_token_len = max([len(instance.offsets) for instance in batch]) @@ -123,12 +125,22 @@ def prep_batch( batch_subword_mask = torch.zeros((batch_size, max_subword_len), dtype=torch.bool, device=device) for task in batch[0].golds: + task_type = dataset.task_to_tasktype(task) + is_word_level = task_type in ['seq', 'multiseq', 'seq_bio', 'tok', 'dependency', 'string2string', 'mlm'] + if dataset.task2type[task.replace('-heads', '').replace('-rels', '')] == 'tok': golds[task] = torch.zeros((batch_size, max_subword_len - 2), dtype=torch.long, device=device) elif dataset.task2type[task.replace('-heads', '').replace('-rels', '')] == 'regression': golds[task] = torch.zeros(batch_size, dtype=torch.float, device=device) - elif type(batch[0].golds[task]) in [torch.Tensor, torch.tensor, list]: - golds[task] = torch.zeros((batch_size, max_token_len), dtype=torch.long, device=device) + elif is_word_level: + if len(batch[0].golds[task].shape) == 1: + golds[task] = torch.zeros((batch_size, max_token_len), dtype=torch.long, device=device) + else: # multiple annotations per token + num_labels = len(dataset.vocabulary.get_vocab(task)) + golds[task] = torch.zeros((batch_size, max_token_len, num_labels), dtype=torch.long, device=device) + elif task_type == 'multiclas': + num_labels = len(dataset.vocabulary.get_vocab(task)) + golds[task] = torch.zeros(batch_size, num_labels, dtype=torch.long, device=device) else: golds[task] = torch.zeros(batch_size, dtype=torch.long, device=device) @@ -136,16 +148,27 @@ def prep_batch( batch_tokens[instanceIdx][0:len(instance.token_ids)] = instance.token_ids batch_seg_ids[instanceIdx][0:len(instance.seg_ids)] = instance.seg_ids for task in instance.golds: - if type(batch[0].golds[task]) in [torch.Tensor, torch.tensor, list]: - golds[task][instanceIdx][0:len(instance.golds[task])] = instance.golds[task] + task_type = dataset.task_to_tasktype(task) + is_word_level = task_type in ['seq', 'multiseq', 'seq_bio', 'tok', 'dependency', 'string2string', 'mlm'] + + if is_word_level: + if len(batch[0].golds[task].shape) == 1: + golds[task][instanceIdx][0:len(instance.golds[task])] = instance.golds[task] + else: + for token_idx, token_labels in enumerate(instance.golds[task]): + for token_label in token_labels: + golds[task][instanceIdx][token_idx][token_label] = 1 + elif task_type == 'multiclas': + for sent_label in instance.golds[task]: + golds[task][instanceIdx][sent_label] = 1 else: golds[task][instanceIdx] = instance.golds[task] + if has_word_level and type(batch[0].offsets) != type(None): batch_offsets[instanceIdx][:len(instance.offsets)] = instance.offsets batch_eval_mask[instanceIdx][:len(instance.offsets)] = 1 batch_subword_mask[instanceIdx][:len(instance.token_ids)] = 1 - return {'token_ids': batch_tokens, 'seg_ids': batch_seg_ids, 'golds': golds, 'offsets': batch_offsets, 'eval_mask': batch_eval_mask, 'subword_mask': batch_subword_mask} @@ -157,7 +180,10 @@ def report_epoch( train_metrics: Dict[str, float], dev_metrics: Dict[str, float], epoch_start_time: datetime.datetime, - start_training_time: datetime.datetime): + start_training_time: datetime.datetime, + device: str, + train_loss_dict: Dict[str, float], + dev_loss_dict: Dict[str, float]): """ Reports a variety of interesting and less interesting metrics that can be tracked across epochs. 
These are both logged and returned. @@ -178,21 +204,35 @@ def report_epoch( The time this epoch started. start_training_time: datetime.datetime The time the training procedure started. + device: str + Used to decide whether to print GPU ram + train_loss_dict: Dict[str, float] + training losses + dev_loss_dict: Dict[str, float] + dev losses Returns ------- info: Dict[str, float] A dictionary containing all information that has just been logged """ - info = {'epoch': epoch, 'max_gpu_mem': torch.cuda.max_memory_allocated() * 1e-09} + info = {'epoch': epoch} + if 'cuda' in device: + info['max_gpu_mem'] = torch.cuda.max_memory_allocated() * 1e-09 _proc_status = '/proc/%d/status' % os.getpid() data = open(_proc_status).read() i = data.index('VmRSS:') info['cur_ram'] = int(data[i:].split(None, 3)[1]) * 1e-06 + + # Might be nice to turn into a table? + for task in train_loss_dict: + info['train_' + task + '_loss'] = train_loss_dict[task] info['train_batch_loss'] = epoch_loss for metric in train_metrics: info['train_' + metric] = train_metrics[metric] + for task in dev_loss_dict: + info['dev_' + task + '_loss'] = dev_loss_dict[task] info['dev_batch_loss'] = dev_loss for metric in dev_metrics: info['dev_' + metric] = dev_metrics[metric]
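
Illustrative note (not part of the patch): the new 'multiclas' and 'multiseq' decoders treat '|'-separated gold labels as a multi-hot vector, train with BCEWithLogitsLoss over all label indices except 0 (the unk/padding entry), and decode by thresholding the logits; the multi_acc metric counts an instance as correct only when every label matches. Below is a minimal, self-contained sketch of that behaviour; the toy label vocabulary and the default threshold of .0 are assumptions made purely for the example.

    import torch

    # Toy label vocabulary; index 0 mirrors MachampVocabulary's '@@unkORpad@@' entry.
    labels = ['@@unkORpad@@', 'A', 'B', 'C']

    def to_multihot(gold: str) -> torch.Tensor:
        # '|'-separated gold labels -> multi-hot vector over the label vocabulary
        vec = torch.zeros(len(labels))
        for lab in gold.split('|'):
            vec[labels.index(lab)] = 1.0
        return vec

    gold = torch.stack([to_multihot('A|C'), to_multihot('B')])  # (batch, nlabels)
    logits = torch.randn(2, len(labels))                        # stand-in for hidden_to_label output

    # Training loss and predictions skip index 0, as in the new decoders.
    loss = torch.nn.BCEWithLogitsLoss()(logits[:, 1:], gold[:, 1:])
    preds = logits > 0.0  # default threshold in the patch is .0

    # Decoding: every label above the threshold, joined with '|'.
    decoded = ['|'.join(labels[i] for i in range(1, len(labels)) if preds[row, i])
               for row in range(len(preds))]

    # multi_acc counts an instance as correct only if all labels match.
    correct = (preds[:, 1:] == gold[:, 1:].bool()).all(dim=-1).sum().item()
    print(decoded, loss.item(), correct / len(preds))

Predicting each label independently (rather than a softmax over mutually exclusive classes) is what distinguishes these decoders from the existing classification and sequence-labeling decoders.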