Skip to content

Commit

Permalink
Added multiseq and multiclas task-types and several bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
robvanderg committed Sep 22, 2022
1 parent 51165e8 commit 9874a41
Show file tree
Hide file tree
Showing 22 changed files with 458 additions and 120 deletions.
24 changes: 23 additions & 1 deletion machamp/data/machamp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self,
num_s2s = 0
for task in self.datasets[dataset]['tasks']:
task_config = self.datasets[dataset]['tasks'][task]
is_clas = task_config['task_type'] in ['classification', 'probdistr', 'regression']
is_clas = task_config['task_type'] in ['classification', 'probdistr', 'regression', 'multiclas']
read_seq = task_config['column_idx'] == -1 if 'column_idx' in task_config else None

if is_clas and not read_seq:
Expand Down Expand Up @@ -133,6 +133,28 @@ def __init__(self,
is_train, max_sents, max_words, max_input_length):
self.data[dataset].append(instance)

def task_to_tasktype(self, task: str):
    """
    Converts a task-name (str) to its type (str).

    Parameters
    ----------
    task: str
        The name of the task. Dependency sub-tasks may carry a
        '-heads' or '-rels' suffix, which is stripped before lookup.

    Returns
    -------
    task_type: str
        The task type of the given task.
    """
    # '-heads'/'-rels' name the two columns of one dependency task;
    # both map back to the same base task entry in self.tasks.
    task_trimmed = task.replace('-heads', '').replace('-rels', '')
    try:
        # Single O(n) scan instead of a membership test plus .index().
        index = self.tasks.index(task_trimmed)
    except ValueError:
        # Preserve the original hard-failure behavior on unknown tasks.
        logger.error(task + ' not found in ' + str(self.tasks))
        exit(1)
    return self.task_types[index]

def __len__(self):
"""
The length is defined as the combined number of batches
Expand Down
18 changes: 8 additions & 10 deletions machamp/data/machamp_vocabulary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

import os

UNK_ID = 0
UNK = '@@unkORpad@@'


class MachampVocabulary():
def __init__(self):
"""
Expand All @@ -24,6 +20,8 @@ def __init__(self):
self.namespaces = {}
self.inverse_namespaces = {}
self.hasUnk = {}
self.UNK_ID = 0
self.UNK = '@@unkORpad@@'

def load_vocab(self, vocab_path: str, name: str):
"""
Expand Down Expand Up @@ -70,7 +68,7 @@ def get_unk(self, name: str):
name in the namespace.
"""
if self.hasUnk[name]:
return UNK
return self.UNK

def get_unk_id(self, name: str):
"""
Expand All @@ -82,7 +80,7 @@ def get_unk_id(self, name: str):
name in the namespace.
"""
if self.hasUnk[name]:
return UNK_ID
return self.UNK_ID

def get_vocab(self, name: str):
"""
Expand Down Expand Up @@ -126,9 +124,9 @@ def token2id(self, token: str, namespace: str, add_if_not_present: bool):
self.inverse_namespaces[namespace].append(token)
return len(self.inverse_namespaces[namespace]) - 1
else:
return UNK_ID if self.hasUnk[namespace] else None
return self.UNK_ID if self.hasUnk[namespace] else None
if self.hasUnk[namespace]:
return self.namespaces[namespace].get(token, UNK_ID)
return self.namespaces[namespace].get(token, self.UNK_ID)
else:
return self.namespaces[namespace].get(token, None)

Expand Down Expand Up @@ -162,8 +160,8 @@ def create_vocab(self, name: str, has_unk: bool):
Whether this vocabulary should have an unknown/padding token.
"""
if name not in self.namespaces:
self.namespaces[name] = {UNK: UNK_ID} if has_unk else {}
self.inverse_namespaces[name] = [UNK] if has_unk else []
self.namespaces[name] = {self.UNK: self.UNK_ID} if has_unk else {}
self.inverse_namespaces[name] = [self.UNK] if has_unk else []
self.hasUnk[name] = has_unk

def save_vocabs(self, out_dir: str):
Expand Down
3 changes: 3 additions & 0 deletions machamp/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from machamp.metrics.multi_accuracy import MultiAccuracy
from machamp.metrics.accuracy import Accuracy
from machamp.metrics.avg_dist import AvgDist
from machamp.metrics.f1 import F1
Expand All @@ -25,6 +26,8 @@ def __init__(self, metric_name: str):
self.metrics = {}
if metric_name == 'accuracy':
self.metrics[metric_name] = Accuracy()
elif metric_name == 'multi_acc':
self.metrics[metric_name] = MultiAccuracy()
elif metric_name == 'las':
self.metrics[metric_name] = LAS()
elif metric_name == 'avg_dist':
Expand Down
32 changes: 32 additions & 0 deletions machamp/metrics/multi_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch


class MultiAccuracy():
    """
    Accuracy over multi-label predictions: an instance counts as correct
    only when *all* of its predicted labels match the gold labels.
    """

    def __init__(self):
        # Number of fully-correct instances seen so far.
        self.cor = 0
        # Total number of (unmasked) instances seen so far.
        self.total = 0
        # Display name used when reporting the metric.
        self.str = 'multi-acc.'

    def score(self, preds, golds, mask, vocabulary):
        """
        Accumulate counts for one batch.

        Vectorized replacement for the per-element Python loops
        (resolves the old "can this be done more efficient?" TODO):
        a single tensor comparison instead of O(sents*words) iterations.

        Parameters
        ----------
        preds: torch.Tensor
            Predictions; (sents, words, labels) for word-level tasks or
            (sents, labels) for sentence-level tasks.
        golds: torch.Tensor
            Gold labels, same shape as preds.
        mask: torch.Tensor
            Per-word mask of shape (sents, words); only consulted in the
            3d (word-level) case.
        vocabulary
            Unused; present for API compatibility with other metrics.
        """
        if preds.dim() == 3:
            # A word is correct only if every label dimension matches.
            correct = (preds == golds).all(dim=-1)
            keep = torch.as_tensor(mask, dtype=torch.bool, device=correct.device)
            self.cor += int((correct & keep).sum().item())
            self.total += int(keep.sum().item())
        elif preds.dim() == 2:
            # Sentence-level: no mask, every instance counts.
            correct = (preds == golds).all(dim=-1)
            self.cor += int(correct.sum().item())
            self.total += len(preds)

    def reset(self):
        """Clear all accumulated counts."""
        self.cor = 0
        self.total = 0

    def get_score(self):
        """
        Return (metric name, accuracy); 0.0 when nothing has been
        scored yet, guarding against division by zero.
        """
        if self.total == 0:
            return self.str, 0.0
        return self.str, self.cor / self.total
2 changes: 1 addition & 1 deletion machamp/model/classification_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class MachampClassificationDecoder(MachampDecoder, torch.nn.Module):
def __init__(self, task, vocabulary, input_dim, device, loss_weight: float = 1.0, topn: int = 1,
metric: str = 'accuracy', **kwargs):
super().__init__(task, vocabulary, loss_weight, metric)
super().__init__(task, vocabulary, loss_weight, metric, device)

nlabels = len(self.vocabulary.get_vocab(task))
self.hidden_to_label = torch.nn.Linear(input_dim, nlabels)
Expand Down
2 changes: 1 addition & 1 deletion machamp/model/crf_label_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
topn: int = 1,
**kwargs
) -> None:
super().__init__(task, vocabulary, loss_weight, metric)
super().__init__(task, vocabulary, loss_weight, metric, device)

nlabels = len(self.vocabulary.get_vocab(task))
self.input_dim = input_dim # + dec_dataset_embeds_dim
Expand Down
3 changes: 1 addition & 2 deletions machamp/model/dependency_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ def __init__(
arc_representation_dim: int = 768,
**kwargs,
) -> None:
super().__init__(task, vocabulary, loss_weight, metric)
super().__init__(task, vocabulary, loss_weight, metric, device)

self.device = device
self.input_dim = input_dim # + dec_dataset_embeds_dim
arc_representation_dim = arc_representation_dim # + dec_dataset_embeds_dim

Expand Down
Loading

0 comments on commit 9874a41

Please sign in to comment.