Skip to content

Commit

Permalink
Params in file
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jun 12, 2024
1 parent bf9b097 commit 63699e9
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
19 changes: 17 additions & 2 deletions dialectid/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ def test_load_bow():
from microtc.utils import Counter

c = utils.load_bow()
assert isinstance(c, Counter)
assert isinstance(c['counter'], Counter)
c2 = utils.load_bow(loc='mx')
assert c.most_common(n=1)[0][1] != c2.most_common(n=1)[0][1]
assert c['counter'].most_common(n=1)[0][1] != c2['counter'].most_common(n=1)[0][1]


def test_BOW():
"""Test BOW"""
import importlib

BOW = utils.BOW
for lang in ['ar', 'de', 'en',
'es', 'fr', 'nl',
'pt', 'ru', 'tr', 'zh']:
assert lang in BOW
path = BOW[lang].split('.')
module = '.'.join(path[:-1])
text_repr = importlib.import_module(module)
instance = getattr(text_repr, path[-1])
10 changes: 5 additions & 5 deletions dialectid/text_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ def bow(self):

if self._bow is not None:
return self._bow
freq = load_bow(lang=self.lang,
data = load_bow(lang=self.lang,
d=self.voc_size_exponent,
func=self.voc_selection,
loc=self._loc)
params = b4msa_params(lang=self.lang,
dim=self._voc_size_exponent)
params = data['params']
counter = data['counter']
params.update(self.b4msa_kwargs)
bow = TextModel(**params)
tfidf = TFIDF()
tfidf.N = freq.update_calls
tfidf.word2id, tfidf.wordWeight = tfidf.counter2weight(freq)
tfidf.N = counter.update_calls
tfidf.word2id, tfidf.wordWeight = tfidf.counter2weight(counter)
bow.model = tfidf
self._bow = bow
return bow
22 changes: 19 additions & 3 deletions dialectid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@
]
}

BOW = {'es': 'dialectid.text_repr.BoW',
'en': 'dialectid.text_repr.BoW',
'ar': 'dialectid.text_repr.BoW',
'de': 'EvoMSA.text_repr.BoW',
'fr': 'dialectid.text_repr.BoW',
'nl': 'EvoMSA.text_repr.BoW',
'pt': 'dialectid.text_repr.BoW',
'ru': 'dialectid.text_repr.BoW',
'tr': 'EvoMSA.text_repr.BoW',
'zh': 'EvoMSA.text_repr.BoW'
}


def load_bow(lang: str='es',
d: int=17,
Expand All @@ -108,9 +120,10 @@ def load_bow(lang: str='es',
"""Load BoW model from dialectid"""

def load(filename):
from microtc.utils import tweet_iterator

try:
with gzip.open(filename, 'rb') as fpt:
return str(fpt.read(), encoding='utf-8')
return next(tweet_iterator(filename))
except Exception:
os.unlink(filename)
raise Exception(filename)
Expand All @@ -127,5 +140,8 @@ def load(filename):
output = join(diroutput, filename)
if not isfile(output):
Download(url, output)
return Counter.fromjson(load(output))
data = load(output)
_ = data['counter']
data['counter'] = Counter(_["dict"], _["update_calls"])
return data

0 comments on commit 63699e9

Please sign in to comment.