From 7d44bc3d71bd24eaabae6fb6838002b92f92c16e Mon Sep 17 00:00:00 2001 From: Mario Graff Date: Mon, 9 Sep 2024 03:38:35 +0000 Subject: [PATCH] SeqTM and DenseBoW --- dialectid/model.py | 68 ++++++++++++++++++++++++++++++- dialectid/tests/test_model.py | 18 ++++++++ dialectid/tests/test_text_repr.py | 23 +++++++++++ dialectid/text_repr.py | 48 +++++++++++++++++++--- dialectid/utils.py | 13 +++++- 5 files changed, 162 insertions(+), 8 deletions(-) diff --git a/dialectid/model.py b/dialectid/model.py index 774b4be..ba5684c 100644 --- a/dialectid/model.py +++ b/dialectid/model.py @@ -25,7 +25,7 @@ from dataclasses import dataclass import importlib import numpy as np -from dialectid.utils import BOW, load_dialectid +from dialectid.utils import BOW, load_dialectid, load_seqtm @dataclass class DialectId: @@ -87,3 +87,69 @@ def predict(self, D: List[Union[dict, list, str]]) -> np.ndarray: hy = self.decision_function(D) return self.countries[hy.argmax(axis=1)] + + +@dataclass +class DenseBoW: + """DenseBoW""" + + lang: str='es' + voc_size_exponent: int=13 + precision: int=32 + + def estimator(self, **kwargs): + """Estimator""" + + from sklearn.svm import LinearSVC + return LinearSVC(class_weight='balanced') + + @property + def bow(self): + """BoW""" + + try: + return self._bow + except AttributeError: + from dialectid.text_repr import SeqTM + self._bow = SeqTM(language=self.lang, + voc_size_exponent=self.voc_size_exponent) + return self._bow + + @property + def weights(self): + """Weights""" + try: + return self._weights + except AttributeError: + iterator = load_seqtm(self.lang, + self.voc_size_exponent, + self.precision) + precision = getattr(np, f'float{self.precision}') + weights = [] + names = [] + for data in iterator: + _ = np.frombuffer(bytes.fromhex(data['coef']), dtype=precision) + weights.append(_) + names.append(data['labels'][-1]) + self._weights = np.vstack(weights) + self._names = np.array(names) + return self._weights + + @property + def names(self): + """Vector space components""" + + return self._names + + def encode(self, text): + """Encode utterace into a matrix""" + + token2id = self.bow.token2id + seq = [] + for token in self.bow.tokenize(text): + try: + seq.append(token2id[token]) + except KeyError: + continue + W = self.weights + return np.vstack([W[:, x] for x in seq]).T diff --git a/dialectid/tests/test_model.py b/dialectid/tests/test_model.py index e2acb29..5f8a4e8 100644 --- a/dialectid/tests/test_model.py +++ b/dialectid/tests/test_model.py @@ -74,3 +74,21 @@ def test_DialectId_subwords(): dialectid = DialectId(voc_size_exponent=15) countries = dialectid.predict('comiendo tacos') assert countries[0] == 'mx' + + +def test_DenseBoW(): + """Test DenseBoW based on SeqTM""" + + from dialectid.model import DenseBoW + + dense = DenseBoW(precision=16) + assert dense.weights.shape[0] == dense.names.shape[0] + dense.weights[0, 0] > 25 + + +def test_DenseBoW_encode(): + """Test DenseBoW sentence repr""" + + from dialectid.model import DenseBoW + dense = DenseBoW(precision=16) + assert dense.encode('buenos días').shape[1] == 2 \ No newline at end of file diff --git a/dialectid/tests/test_text_repr.py b/dialectid/tests/test_text_repr.py index 55cbe6b..6f2bfe6 100644 --- a/dialectid/tests/test_text_repr.py +++ b/dialectid/tests/test_text_repr.py @@ -83,3 +83,26 @@ def test_SeqTM_seq(): res1 = seq.tokenize('mira pinche a') res2 = seq.tokenize('a pinche a') assert res1[1:] == res2[1:] + + +def test_SeqTM_seq_bug(): + """Test SeqTM seq option""" + + seq = SeqTM(language='es', sequence=True, + voc_selection='most_common', + voc_size_exponent=13) + assert seq.del_dup == False + + +def test_SeqTM_names(): + seq = SeqTM(language='es', sequence=True, + voc_selection='most_common', + voc_size_exponent=13) + assert len(seq.names) == len(seq.model.word2id) + + +def test_SeqTM_weights(): + seq = SeqTM(language='es', sequence=True, + voc_selection='most_common', + voc_size_exponent=13) + assert len(seq.weights) == len(seq.names) \ No newline at end of file diff --git a/dialectid/text_repr.py b/dialectid/text_repr.py index 1b4631e..c6174c8 100644 --- a/dialectid/text_repr.py +++ b/dialectid/text_repr.py @@ -28,6 +28,7 @@ from microtc import emoticons from microtc.utils import tweet_iterator from dialectid.utils import load_bow +import numpy as np class BoW(EvoMSABoW): @@ -103,12 +104,9 @@ class SeqTM(TextModel): def __init__(self, language='es', voc_size_exponent: int=17, - voc_selection: str='most_common_by_type', - loc: str=None, - subwords: bool=True, - sequence: bool=True, - lang=None, - **kwargs): + voc_selection: str='most_common', + loc: str=None, subwords: bool=True, + sequence: bool=True, lang=None, **kwargs): assert lang is None if sequence and subwords: loc = 'seq' @@ -129,6 +127,7 @@ def __init__(self, language='es', self.voc_selection = voc_selection self.loc = loc self.subwords = subwords + self.sequence = sequence self.__vocabulary(counter) def __vocabulary(self, counter): @@ -182,6 +181,7 @@ def voc_selection(self, value): @property def voc_size_exponent(self): """Vocabulary size :math:`2^v`; where :math:`v` is :py:attr:`voc_size_exponent` """ + return self._voc_size_exponent @voc_size_exponent.setter @@ -208,6 +208,42 @@ def subwords(self): def subwords(self, value): self._subwords = value + @property + def sequence(self): + """Vocabulary compute on sequence text-transformation""" + + return self._sequence + + @sequence.setter + def sequence(self, value): + self._sequence = value + + @property + def names(self): + """Vector space components""" + + try: + return self._names + except AttributeError: + _names = [None] * len(self.id2token) + for k, v in self.id2token.items(): + _names[k] = v + self._names = np.array(_names) + return self._names + + @property + def weights(self): + """Vector space weights""" + + try: + return self._weights + except AttributeError: + w = [None] * len(self.token_weight) + for k, v in self.token_weight.items(): + w[k] = v + self._weights = np.array(w) + return self._weights + @property def tokens(self): """Tokens""" diff --git a/dialectid/utils.py b/dialectid/utils.py index d8a9013..6d54f3f 100644 --- a/dialectid/utils.py +++ b/dialectid/utils.py @@ -160,4 +160,15 @@ def load_dialectid(lang, dim, subwords=False): if not isfile(output): Download(f'{BASEURL}/{filename}', output) _ = [Linear(**params) for params in tweet_iterator(output)] - return _ \ No newline at end of file + return _ + + +def load_seqtm(lang, dim, precision): + diroutput = join(dirname(__file__), 'models') + if not isdir(diroutput): + os.mkdir(diroutput) + filename = f'seqtm_{lang}_{precision}_{dim}.json.gz' + output = join(diroutput, filename) + if not isfile(output): + Download(f'{BASEURL}/{filename}', output) + return tweet_iterator(output) \ No newline at end of file