Skip to content

Commit

Permalink
SeqTM and DenseBoW
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Sep 9, 2024
1 parent de656a0 commit 7d44bc3
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 8 deletions.
68 changes: 67 additions & 1 deletion dialectid/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dataclasses import dataclass
import importlib
import numpy as np
from dialectid.utils import BOW, load_dialectid
from dialectid.utils import BOW, load_dialectid, load_seqtm

@dataclass
class DialectId:
Expand Down Expand Up @@ -87,3 +87,69 @@ def predict(self, D: List[Union[dict, list, str]]) -> np.ndarray:

hy = self.decision_function(D)
return self.countries[hy.argmax(axis=1)]


@dataclass
class DenseBoW:
"""DenseBoW"""

lang: str='es'
voc_size_exponent: int=13
precision: int=32

def estimator(self, **kwargs):
"""Estimator"""

from sklearn.svm import LinearSVC
return LinearSVC(class_weight='balanced')

@property
def bow(self):
"""BoW"""

try:
return self._bow
except AttributeError:
from dialectid.text_repr import SeqTM
self._bow = SeqTM(language=self.lang,
voc_size_exponent=self.voc_size_exponent)
return self._bow

@property
def weights(self):
"""Weights"""
try:
return self._weights
except AttributeError:
iterator = load_seqtm(self.lang,
self.voc_size_exponent,
self.precision)
precision = getattr(np, f'float{self.precision}')
weights = []
names = []
for data in iterator:
_ = np.frombuffer(bytes.fromhex(data['coef']), dtype=precision)
weights.append(_)
names.append(data['labels'][-1])
self._weights = np.vstack(weights)
self._names = np.array(names)
return self._weights

@property
def names(self):
"""Vector space components"""

return self._names

def encode(self, text):
"""Encode utterace into a matrix"""

token2id = self.bow.token2id
seq = []
for token in self.bow.tokenize(text):
try:
seq.append(token2id[token])
except KeyError:
continue
W = self.weights
return np.vstack([W[:, x] for x in seq]).T
18 changes: 18 additions & 0 deletions dialectid/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,21 @@ def test_DialectId_subwords():
dialectid = DialectId(voc_size_exponent=15)
countries = dialectid.predict('comiendo tacos')
assert countries[0] == 'mx'


def test_DenseBoW():
"""Test DenseBoW based on SeqTM"""

from dialectid.model import DenseBoW

dense = DenseBoW(precision=16)
assert dense.weights.shape[0] == dense.names.shape[0]
dense.weights[0, 0] > 25


def test_DenseBoW_encode():
"""Test DenseBoW sentence repr"""

from dialectid.model import DenseBoW
dense = DenseBoW(precision=16)
assert dense.encode('buenos días').shape[1] == 2
23 changes: 23 additions & 0 deletions dialectid/tests/test_text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,26 @@ def test_SeqTM_seq():
res1 = seq.tokenize('mira pinche a')
res2 = seq.tokenize('a pinche a')
assert res1[1:] == res2[1:]


def test_SeqTM_seq_bug():
"""Test SeqTM seq option"""

seq = SeqTM(language='es', sequence=True,
voc_selection='most_common',
voc_size_exponent=13)
assert seq.del_dup == False


def test_SeqTM_names():
seq = SeqTM(language='es', sequence=True,
voc_selection='most_common',
voc_size_exponent=13)
assert len(seq.names) == len(seq.model.word2id)


def test_SeqTM_weights():
seq = SeqTM(language='es', sequence=True,
voc_selection='most_common',
voc_size_exponent=13)
assert len(seq.weights) == len(seq.names)
48 changes: 42 additions & 6 deletions dialectid/text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from microtc import emoticons
from microtc.utils import tweet_iterator
from dialectid.utils import load_bow
import numpy as np


class BoW(EvoMSABoW):
Expand Down Expand Up @@ -103,12 +104,9 @@ class SeqTM(TextModel):

def __init__(self, language='es',
voc_size_exponent: int=17,
voc_selection: str='most_common_by_type',
loc: str=None,
subwords: bool=True,
sequence: bool=True,
lang=None,
**kwargs):
voc_selection: str='most_common',
loc: str=None, subwords: bool=True,
sequence: bool=True, lang=None, **kwargs):
assert lang is None
if sequence and subwords:
loc = 'seq'
Expand All @@ -129,6 +127,7 @@ def __init__(self, language='es',
self.voc_selection = voc_selection
self.loc = loc
self.subwords = subwords
self.sequence = sequence
self.__vocabulary(counter)

def __vocabulary(self, counter):
Expand Down Expand Up @@ -182,6 +181,7 @@ def voc_selection(self, value):
@property
def voc_size_exponent(self):
"""Vocabulary size :math:`2^v`; where :math:`v` is :py:attr:`voc_size_exponent` """

return self._voc_size_exponent

@voc_size_exponent.setter
Expand All @@ -208,6 +208,42 @@ def subwords(self):
def subwords(self, value):
self._subwords = value

@property
def sequence(self):
"""Vocabulary compute on sequence text-transformation"""

return self._sequence

@sequence.setter
def sequence(self, value):
self._sequence = value

@property
def names(self):
"""Vector space components"""

try:
return self._names
except AttributeError:
_names = [None] * len(self.id2token)
for k, v in self.id2token.items():
_names[k] = v
self._names = np.array(_names)
return self._names

@property
def weights(self):
"""Vector space weights"""

try:
return self._weights
except AttributeError:
w = [None] * len(self.token_weight)
for k, v in self.token_weight.items():
w[k] = v
self._weights = np.array(w)
return self._weights

@property
def tokens(self):
"""Tokens"""
Expand Down
13 changes: 12 additions & 1 deletion dialectid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,15 @@ def load_dialectid(lang, dim, subwords=False):
if not isfile(output):
Download(f'{BASEURL}/{filename}', output)
_ = [Linear(**params) for params in tweet_iterator(output)]
return _
return _


def load_seqtm(lang, dim, precision):
diroutput = join(dirname(__file__), 'models')
if not isdir(diroutput):
os.mkdir(diroutput)
filename = f'seqtm_{lang}_{precision}_{dim}.json.gz'
output = join(diroutput, filename)
if not isfile(output):
Download(f'{BASEURL}/{filename}', output)
return tweet_iterator(output)

0 comments on commit 7d44bc3

Please sign in to comment.