3.0 (#135)
* Corrected permissions

* Bugfix

* Added GPU support at runtime

* Wrong config package

* Refactoring

* refactoring

* add lightning to dependencies

* Dummy test

* Dummy test

* Tweak

* Tweak

* Update test

* Test

* Finished loading for UD CoNLL-U format

* Working on tagger

* Work on tagger

* tagger training

* tagger training

* tagger training

* Sync

* Sync

* Sync

* Sync

* Tagger working

* Better weight for aux loss

* Better weight for aux loss

* Added save and printing for tagger and shared options class

* Multilanguage evaluation

* Saving multiple models

* Updated ignore list

* Added XLM-Roberta support

* Using custom ro model

* Score update

* Bugfixing

* Code refactor

* Refactor

* Added option to load external config

* Added option to select LM-model from CLI or config

* added option to overwrite config lm from CLI

* Bugfix

* Working on parser

* Sync work on parser

* Parser working

* Removed load limit

* Bugfix in evaluation

* Added bi-affine attention

* Added experimental ChuLiuEdmonds tree decoding

* Better config for parser and bugfix

* Added residuals to tagging

* Model update

* Switched to AdamW optimizer

* Working on tokenizer

* Working on tokenizer

* Training working - validation to do

* Bugfix in language id

* Working on tokenization validation

* Tokenizer working

* YAML update

* Bug in LMHelper

* Tagger is working

* Tokenizer is working

* bfix

* bfix

* Bugfix for bugfix :)

* Sync

* Tokenizer working

* Tagger working

* Trainer updates

* Trainer process now working

* Added .DS_Store

* Added datasets for Compound Word Expander and Lemmatizer

* Added collate function for lemma+compound

* Added training and validation step

* Updated config for Lemmatizer

* Minor fixes

* Removed duplicate entries from lemma and cwe

* Added training support for lemmatizer

* Removed debug directives

* Lemmatizer in testing phase

* removed unused line

* Bugfix in Lemma dataset

* Corrected validation issue with gs labels being sent to the forward method and removed loss computation during testing

* Lemmatizer training done

* Compound word expander ready

* Sync

* Added support for FastText, Transformers and Languasito LM models

* Added multi-lm support for tokenizer

* Added support for multiword tokens

* Sync

* Bugfix in evaluation

* Added Languasito as a subpackage

* Added path to local Languasito

* Bugfixing all around

* Removed debug printing

* Bugfix for no-space languages that actually contain spaces :)

* Bugfix for no-space languages that actually contain spaces :)

* Fixed GPU support

* Biaffine transform for LAS and relative head location (RHL) for UAS

* Bugfix

* Tweaks

* moved rhl to lower layer

* Added configurable option for RHL

* Safety net for spaces in languages that should use no spaces

* Better defaults

* Sync

* Cleanup parser

* Bilinear xpos and attrs

* Added Biaffine module from Stanza

* Tagger with reduced number of parameters

* Parser with conditional attrs

* Working on tokenizer runtime

* Tokenizer process 90% done

* Added runtime for parser, tokenizer and tagger

* Added quick test for runtime

* Test for e2e

* Added support for multiple word embeddings at the same time

* Bugfix

* Added multiple word representations for tokenizer

* moved mask_concat to utils.py

* Added XPOS prediction to pipeline

* Bugfix in tokenizer shifted word embeddings

* Using Languasito tokenizer for HF tokenization

* Bugfix

* Bugfixing

* Bugfixing

* Bugfix

* Runtime fixing

* Sync

* Added spa for FT and Languasito

* Added spa for FT and Languasito

* Minor tweaks

* Added configuration for RNN layers

* Bugfix for spa

* HF runtime fix

* Mixed test fasttext+transformer

* Added word reconstruction and MHA

* Sync

* Bugfix

* bugfix

* Added masked attention

* Sync

* Added test for runtime

* Bugfix in mask values

* Updated test

* Added full mask dropout

* Added resume option

* Removed useless printouts

* Removed useless printouts

* Switched to eval at runtime

* multiprocessing added

* Added full mask dropout for word decoder

* Bugfix

* Residual

* Added lexical-contextual cosine loss

* Removed full mask dropout from WordDecoder

* Bugfix

* Training script generation update

* Added residual

* Updated languasito to pickle tokenized lines

* Updated languasito to pickle tokenized lines

* Updated languasito to pickle tokenized lines

* Not training for seq len > max_seq_len

* Added seq limits for collates

* Passing seq limits from collate to tokenizer

* Skipping complex parsing

* Working on word decomposer

* Model update

* Sync

* Bugfix

* Bugfix

* Bugfix

* Using all reprs

* Dropped immediate context

* Multi train script added

* Changed gpu parameter type to string, since int failed for multiple GPUs

* Updated pytorch_lightning callback method to work with newer version

* Updated pytorch_lightning callback method to work with newer version

* Transparently pass PL args from the command line; skip over empty compound word datasets

* Fix typo

* Refactoring and on the way to working API

* API load working

* Partial __call__ working

* Partial __call__ working

* Added partly working api and refactored everything back to cube/. Compound not working yet and tokenizer needs retraining.

* api is working

* Fixing api

* Updated readme

* Update Readme to include flavours

* Device support

* api update

* Updated package

* Tweak + results

* Clarification

* Test update

* Update

* Sync

* Update README

* Bugfixing

* Bugfix and api update

* Fixed compound

* Evaluation update

* Bugfix

* Package update

* Bugfix for large sentences

* Pip package update

* Corrected Spanish evaluation

* Package version update

* Fixed tokenization issues on transformers

* Removed pinned memory

* Bugfix for GPU tensors

* Update package version

* Automatically detecting hidden state size

* Automatically detecting hidden state size

* Automatically detecting hidden state size

* Sync

* Evaluation update

* Package update

* Bugfix

* Bugfixing

* Package version update

* Bugfix

* Package version update

* Update evaluation for Italian

* tentative support torchtext>=0.9.0 (#127)

as mentioned in Lightning-AI/pytorch-lightning#6211 and #100

* Update package dependencies

* Dummy word embeddings

* Update params

* Better dropout values

* Skipping long words

* Skipping long words

* dummy we -> float

* Added gradient clipping

* Update tokenizer

* Update tokenizer

* Sync

* DCWE

* Working on DCWE

---------

Co-authored-by: Stefan Dumitrescu <sdumitre@adobe.com>
Co-authored-by: Tiberiu Boros <boros@adobe.com>
Co-authored-by: Koichi Yasuoka <yasuoka@kanji.zinbun.kyoto-u.ac.jp>
4 people committed Feb 17, 2023
1 parent 0bb4fa2 commit cc0c34c
Showing 14 changed files with 293 additions and 31 deletions.
5 changes: 1 addition & 4 deletions cube/api.py
@@ -126,10 +126,7 @@ def __call__(self, text: Union[str, Document], flavour: Optional[str] = None):
self._lm_helper.apply(doc)
self._parser.process(doc, self._parser_collate, num_workers=0)
self._lemmatizer.process(doc, self._lemmatizer_collate, num_workers=0)
for seq in doc.sentences:
for w in seq.words:
if w.upos =='PUNCT':
w.lemma = w.word

return doc


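A minimal usage sketch of the 3.0 API that this commit finishes. The class name and the constructor/load() arguments are assumptions based on earlier NLP-Cube releases, not on this diff; only word.word, word.upos and word.lemma appear in the code above.

# Hedged sketch, not taken verbatim from the repository.
from cube.api import Cube

cube = Cube(verbose=True)
cube.load('en')                          # assumed: a language code (optionally a flavour)
doc = cube('This is a test sentence.')   # __call__ accepts a str or a Document
for sentence in doc.sentences:
    for word in sentence.words:
        print(word.word, word.upos, word.lemma)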
30 changes: 27 additions & 3 deletions cube/io_utils/config.py
@@ -87,7 +87,10 @@ def __init__(self, filename=None, verbose=False):
self.cnn_filter = 512
self.lang_emb_size = 100
self.cnn_layers = 5
self.external_proj_size = 300
self.rnn_size = 50
self.rnn_layers = 2
self.external_proj_size = 2

self.no_space_lang = False

if filename is None:
@@ -139,9 +142,10 @@ def __init__(self, filename=None, verbose=False):
self.head_size = 100
self.label_size = 200
self.lm_model = 'xlm-roberta-base'
self.external_proj_size = 300
self.external_proj_size = 2
self.rhl_win_size = 2
self.rnn_size = 50
self.rnn_size = 200

self.rnn_layers = 3

self._valid = True
@@ -275,6 +279,26 @@ def __init__(self, filename=None, verbose=False):
self.load(filename)


class DCWEConfig(Config):
def __init__(self, filename=None, verbose=False):
super().__init__()
self.char_emb_size = 256
self.case_emb_size = 32
self.num_filters = 512
self.kernel_size = 5
self.lang_emb_size = 32
self.num_layers = 8
self.output_size = 300 # this will be automatically updated at training time, so do not change

if filename is None:
if verbose:
sys.stdout.write("No configuration file supplied. Using default values.\n")
else:
if verbose:
sys.stdout.write("Reading configuration file " + filename + " \n")
self.load(filename)


class GDBConfig(Config):
def __init__(self, filename=None, verbose=False):
super().__init__()
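The new DCWEConfig follows the same pattern as the other *Config classes: defaults set in __init__, optionally overridden by an external file passed to load(). A hedged sketch of how it would be used (the external file name is hypothetical):

from cube.io_utils.config import DCWEConfig

config = DCWEConfig(verbose=True)   # defaults: char_emb_size=256, num_filters=512, num_layers=8
print(config.output_size)           # 300; updated automatically at training time
# config = DCWEConfig(filename='dcwe.conf', verbose=True)   # hypothetical external config file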
83 changes: 83 additions & 0 deletions cube/networks/dcwe.py
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import pytorch_lightning as pl
from typing import *
import sys

sys.path.append('')
from cube.networks.modules import WordGram, LinearNorm
from cube.io_utils.encodings import Encodings
from cube.io_utils.config import DCWEConfig


class DCWE(pl.LightningModule):
encodings: Encodings
config: DCWEConfig

def __init__(self, config: DCWEConfig, encodings: Encodings):
super(DCWE, self).__init__()
self._config = config
self._encodings = encodings
self._wg = WordGram(num_chars=len(encodings.char2int),
num_langs=encodings.num_langs,
num_layers=config.num_layers,
num_filters=config.num_filters,
char_emb_size=config.lang_emb_size,
case_emb_size=config.case_emb_size,
lang_emb_size=config.lang_emb_size
)
self._output_proj = LinearNorm(config.num_filters // 2, config.output_size, w_init_gain='linear')
self._improve = 0
self._best_loss = 9999

def forward(self, x_char, x_case, x_lang, x_mask, x_word_len):
pre_proj = self._wg(x_char, x_case, x_lang, x_mask, x_word_len)
proj = self._output_proj(pre_proj)
return proj

def _get_device(self):
if self._output_proj.linear_layer.weight.device.type == 'cpu':
return 'cpu'
return '{0}:{1}'.format(self._output_proj.linear_layer.weight.device.type,
str(self._output_proj.linear_layer.weight.device.index))

def configure_optimizers(self):
return torch.optim.AdamW(self.parameters())

def training_step(self, batch, batch_idx):
x_char = batch['x_char']
x_case = batch['x_case']
x_lang = batch['x_lang']
x_word_len = batch['x_word_len']
x_mask = batch['x_mask']
y_target = batch['y_target']
y_pred = self.forward(x_char, x_case, x_lang, x_mask, x_word_len)
loss = torch.mean((y_pred - y_target) ** 2)
return loss

def validation_step(self, batch, batch_idx):
x_char = batch['x_char']
x_case = batch['x_case']
x_lang = batch['x_lang']
x_word_len = batch['x_word_len']
x_mask = batch['x_mask']
y_target = batch['y_target']
y_pred = self.forward(x_char, x_case, x_lang, x_mask, x_word_len)
loss = torch.mean((y_pred - y_target) ** 2)
return {'loss': loss.detach().cpu().numpy()[0]}

def validation_epoch_end(self, outputs: List[Any]) -> None:
mean_loss = sum([output['loss'] for output in outputs])
mean_loss /= len(outputs)
self.log('val/loss', mean_loss)
self.log('val/early_meta', self._improve)

def save(self, path):
torch.save(self.state_dict(), path)

def load(self, model_path: str, device: str = 'cpu'):
self.load_state_dict(torch.load(model_path, map_location='cpu')['state_dict'])
self.to(device)
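The DCWE module above regresses a character-level WordGram representation onto a fixed target embedding with a plain mean-squared-error objective. A self-contained sketch of that loss with stand-in tensors (shapes are placeholders, not the module's real dimensions):

import torch

y_pred = torch.randn(8, 300, requires_grad=True)   # stand-in for the output of self._output_proj
y_target = torch.randn(8, 300)                      # stand-in for a precomputed target embedding
loss = torch.mean((y_pred - y_target) ** 2)         # same expression as in training_step
loss.backward()
print(float(loss))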



21 changes: 21 additions & 0 deletions cube/networks/lm.py
@@ -211,6 +211,27 @@ def apply_raw(self, batch):
pass


class LMHelperDummy(LMHelper):
def __init__(self, device: str = 'cpu', model: str = None):
pass

def get_embedding_size(self):
return [1]

def apply(self, document: Document):
for ii in tqdm.tqdm(range(len(document.sentences)), desc="Pre-computing embeddings", unit="sent"):
for jj in range(len(document.sentences[ii].words)):
document.sentences[ii].words[jj].emb = [[1.0]]

def apply_raw(self, batch):
embeddings = []
for ii in range(len(batch)):
c_emb = []
for jj in range(len(batch[ii])):
c_emb.append([1.0])
embeddings.append(c_emb)
return embeddings

if __name__ == "__main__":
from ipdb import set_trace

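A quick sketch of what the new LMHelperDummy produces: a constant one-dimensional embedding per word, which lets the rest of the pipeline run without loading a real language model. Assumes the repository (or package) is importable:

from cube.networks.lm import LMHelperDummy

helper = LMHelperDummy()
print(helper.get_embedding_size())                # [1]
print(helper.apply_raw([['A', 'test'], ['Ok']]))  # [[[1.0], [1.0]], [[1.0]]]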
7 changes: 4 additions & 3 deletions cube/networks/modules.py
@@ -427,9 +427,10 @@ def __init__(self, num_chars: int, num_langs: int, num_filters=512, char_emb_siz
super(WordGram, self).__init__()
NUM_FILTERS = num_filters
self._num_filters = NUM_FILTERS
self._lang_emb = nn.Embedding(num_langs + 1, lang_emb_size)
self._tok_emb = nn.Embedding(num_chars + 1, char_emb_size)
self._case_emb = nn.Embedding(4, case_emb_size)
self._lang_emb = nn.Embedding(num_langs + 1, lang_emb_size, padding_idx=0)
self._tok_emb = nn.Embedding(num_chars + 3, char_emb_size, padding_idx=0)
self._case_emb = nn.Embedding(4, case_emb_size, padding_idx=0)

self._num_layers = num_layers
convolutions_char = []
cs_inp = char_emb_size + lang_emb_size + case_emb_size
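The WordGram change pins row 0 of each embedding table to the padding index (and slightly enlarges the character vocabulary). With padding_idx=0, row 0 stays a zero vector and receives no gradient, so padded characters, cases and languages contribute nothing downstream. A minimal illustration:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=0)
out = emb(torch.tensor([0, 3]))
print(out[0])               # all zeros: the padding row
out.sum().backward()
print(emb.weight.grad[0])   # all zeros: no gradient flows into the padding row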
22 changes: 14 additions & 8 deletions cube/networks/parser.py
@@ -76,7 +76,8 @@ def __init__(self, config: ParserConfig, encodings: Encodings, language_codes: [
self._upos_emb = nn.Embedding(len(encodings.upos2int), 64)

self._rnn = nn.LSTM(NUM_FILTERS // 2 + config.lang_emb_size + config.external_proj_size, config.rnn_size,
num_layers=config.rnn_layers, batch_first=True, bidirectional=True, dropout=0.33)
num_layers=config.rnn_layers, batch_first=True, bidirectional=True, dropout=0.1)


self._pre_out = LinearNorm(config.rnn_size * 2 + config.lang_emb_size, config.pre_parser_size)
# self._head_r1 = LinearNorm(config.pre_parser_size, config.head_size)
@@ -137,9 +138,10 @@ def forward(self, X):
for ii in range(len(x_word_emb_packed)):
we = unpack(x_word_emb_packed[ii], sl, x_sents.shape[1], self._get_device())
if word_emb_ext is None:
word_emb_ext = self._ext_proj[ii](we.float())
word_emb_ext = self._ext_proj[ii](we)
else:
word_emb_ext = word_emb_ext + self._ext_proj[ii](we.float())
word_emb_ext = word_emb_ext + self._ext_proj[ii](we)


word_emb_ext = word_emb_ext / len(x_word_emb_packed)
word_emb_ext = torch.tanh(word_emb_ext)
@@ -153,7 +155,8 @@

word_emb = self._word_emb(x_sents)

x = mask_concat([word_emb, char_emb, word_emb_ext], 0.33, self.training, self._get_device())
x = mask_concat([word_emb, char_emb, word_emb_ext], 0.1, self.training, self._get_device())


x = torch.cat([x, lang_emb[:, 1:, :]], dim=-1)
# prepend root
@@ -172,7 +175,8 @@
res = tmp
else:
res = res + tmp
x = torch.dropout(tmp, 0.2, self.training)
x = torch.dropout(tmp, 0.1, self.training)

cnt += 1
if cnt == self._config.aux_softmax_location:
hidden = torch.cat([x + res, lang_emb], dim=1)
@@ -184,7 +188,8 @@
# aux tagging
lang_emb = lang_emb.permute(0, 2, 1)
hidden = hidden.permute(0, 2, 1)[:, 1:, :]
pre_morpho = torch.dropout(torch.tanh(self._pre_morpho(hidden)), 0.33, self.training)
pre_morpho = torch.dropout(torch.tanh(self._pre_morpho(hidden)), 0.1, self.training)

pre_morpho = torch.cat([pre_morpho, lang_emb[:, 1:, :]], dim=2)
upos = self._upos(pre_morpho)
if gs_upos is None:
@@ -200,11 +205,12 @@
word_emb_ext = torch.cat(
[torch.zeros((word_emb_ext.shape[0], 1, self._config.external_proj_size), device=self._get_device(),
dtype=torch.float), word_emb_ext], dim=1)
x = mask_concat([x_parse, word_emb_ext], 0.33, self.training, self._get_device())
x = torch.cat([x_parse, word_emb_ext], dim=-1) #mask_concat([x_parse, word_emb_ext], 0.1, self.training, self._get_device())
x = torch.cat([x, lang_emb], dim=-1)
output, _ = self._rnn(x)
output = torch.cat([output, lang_emb], dim=-1)
pre_parsing = torch.dropout(torch.tanh(self._pre_out(output)), 0.33, self.training)
pre_parsing = torch.dropout(torch.tanh(self._pre_out(output)), 0.1, self.training)

# h_r1 = torch.tanh(self._head_r1(pre_parsing))
# h_r2 = torch.tanh(self._head_r2(pre_parsing))
# l_r1 = torch.tanh(self._label_r1(pre_parsing))
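For context on the "Added bi-affine attention" and "Biaffine transform for LAS" commits: the parser scores head candidates with a biaffine form over dependent and head representations. Below is a generic sketch of biaffine arc scoring, not the repository's module (which is adapted from Stanza); all names and shapes are illustrative.

import torch
import torch.nn as nn

class BiaffineArcScorer(nn.Module):
    # score[b, i, j] = [h_dep[b, i]; 1]^T  U  h_head[b, j]
    def __init__(self, dim: int):
        super().__init__()
        self.U = nn.Parameter(torch.empty(dim + 1, dim))   # the extra row absorbs a bias term
        nn.init.xavier_uniform_(self.U)

    def forward(self, h_dep, h_head):
        ones = torch.ones(h_dep.shape[:-1] + (1,), device=h_dep.device)
        h_dep = torch.cat([h_dep, ones], dim=-1)            # (B, N, D+1)
        return h_dep @ self.U @ h_head.transpose(1, 2)      # (B, N, N): head score per dependent

scores = BiaffineArcScorer(8)(torch.randn(2, 5, 8), torch.randn(2, 5, 8))
print(scores.shape)   # torch.Size([2, 5, 5])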
7 changes: 6 additions & 1 deletion cube/networks/tagger.py
@@ -1,6 +1,9 @@
import sys

sys.path.append('')
import os, yaml


os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pytorch_lightning as pl
import torch.nn as nn
@@ -14,6 +17,7 @@
from cube.networks.utils import MorphoCollate, MorphoDataset, unpack, mask_concat
from cube.networks.modules import WordGram


class Tagger(pl.LightningModule):
def __init__(self, config: TaggerConfig, encodings: Encodings, language_codes: [] = None, ext_word_emb=0):
super().__init__()
@@ -276,7 +280,8 @@ def validation_epoch_end(self, outputs):
# print("\n\n\n", upos_ok / total, xpos_ok / total, attrs_ok / total,
# aupos_ok / total, axpos_ok / total, aattrs_ok / total, "\n\n\n")

def load(self, model_path:str, device: str = 'cpu'):
def load(self, model_path: str, device: str = 'cpu'):

self.load_state_dict(torch.load(model_path, map_location='cpu')['state_dict'])
self.to(device)

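The tagger's load() (like the new DCWE.load()) expects a PyTorch Lightning checkpoint: weights are read from the 'state_dict' key, mapped to CPU first, and only then moved to the requested device. A hedged sketch of the same pattern for any module; the path and device are hypothetical:

import torch

def load_weights(module: torch.nn.Module, model_path: str, device: str = 'cpu') -> None:
    # Lightning checkpoints wrap the model weights under 'state_dict'
    checkpoint = torch.load(model_path, map_location='cpu')
    module.load_state_dict(checkpoint['state_dict'])
    module.to(device)

# load_weights(tagger, 'models/en_tagger.ckpt', device='cuda:0')   # hypothetical usage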
27 changes: 22 additions & 5 deletions cube/networks/tokenizer.py
@@ -39,8 +39,8 @@ def __init__(self, config: TokenizerConfig, encodings: Encodings, language_codes
conv_layer = nn.Sequential(
ConvNorm(cs_inp,
NUM_FILTERS,
kernel_size=5, stride=1,
padding=2,
kernel_size=3, stride=1,
padding=1,
dilation=1, w_init_gain='tanh'),
nn.BatchNorm1d(NUM_FILTERS))
conv_layers.append(conv_layer)
@@ -49,7 +49,13 @@
self._wg = WordGram(len(encodings.char2int), num_langs=encodings.num_langs)
self._lang_emb = nn.Embedding(encodings.num_langs + 1, config.lang_emb_size, padding_idx=0)
self._spa_emb = nn.Embedding(3, 16, padding_idx=0)
self._output = LinearNorm(NUM_FILTERS // 2 + config.lang_emb_size, 5)
self._rnn = nn.LSTM(NUM_FILTERS // 2 + config.lang_emb_size,
config.rnn_size,
num_layers=config.rnn_layers,
bidirectional=True,
batch_first=True)
self._output = LinearNorm(config.rnn_size * 2, 5)


ext2int = []
for input_size in self._ext_word_emb:
@@ -103,20 +109,29 @@ def forward(self, batch):
half = self._config.cnn_filter // 2
res = None
cnt = 0

skip = None
for conv in self._convs:
conv_out = conv(x)
tmp = torch.tanh(conv_out[:, :half, :]) * torch.sigmoid((conv_out[:, half:, :]))
if res is None:
res = tmp
else:
res = res + tmp
x = torch.dropout(tmp, 0.2, self.training)
x = torch.dropout(tmp, 0.1, self.training)
cnt += 1
if cnt != self._config.cnn_layers:
if skip is not None:
x = x + skip
skip = x

x = torch.cat([x, x_lang], dim=1)
x = x + res
x = torch.cat([x, x_lang], dim=1)
x = x.permute(0, 2, 1)

x, _ = self._rnn(x)

return self._output(x)

def validation_step(self, batch, batch_idx):
@@ -297,7 +312,9 @@ def process(self, raw_text, collate: TokenCollate, batch_size=32, num_workers: i
return d

def configure_optimizers(self):
return torch.optim.AdamW(self.parameters())
optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-4)
return optimizer


def _compute_early_stop(self, res):
for lang in res:
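The tokenizer's convolution stack keeps its gated activation: each conv emits 2*half channels, split into a tanh "content" half and a sigmoid "gate" half, before the new BiLSTM and output layer. A stand-alone illustration with dummy shapes (cnn_filter = 512, so half = 256):

import torch

conv_out = torch.randn(2, 512, 50)   # (batch, channels, time)
half = conv_out.shape[1] // 2
gated = torch.tanh(conv_out[:, :half, :]) * torch.sigmoid(conv_out[:, half:, :])
print(gated.shape)                   # torch.Size([2, 256, 50])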
2 changes: 2 additions & 0 deletions cube/networks/utils.py
@@ -106,6 +106,8 @@ def __init__(self, document: Document, for_training=True):
word = w.word
lemma = w.lemma
upos = w.upos
if len(word) > 25:
continue

key = (word, lang_id, upos)
if key not in lookup or for_training is False: