Skip to content

Commit

Permalink
Updating when the label is present
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Jul 21, 2023
1 parent c98c238 commit 266f3a4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
6 changes: 3 additions & 3 deletions IngeoDash/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def has_label(mem: Config, x):
return False


def model(mem: Config, data: dict, select: bool=True):
def model(mem: Config, data: dict):
lang = mem[mem.lang]
if lang not in CONFIG.denseBoW:
dense = DenseBoW(lang=lang, voc_size_exponent=mem.voc_size_exponent,
Expand All @@ -41,7 +41,7 @@ def model(mem: Config, data: dict, select: bool=True):
n_jobs=mem.n_jobs,
dataset=False, emoji=False, keyword=False)
dense.text_representations_extend(CONFIG.denseBoW[lang])
if select:
if mem.dense_select:
dense.select(D=data)
_ = np.unique([x[mem.label_header] for x in data],
return_counts=True)[1]
Expand Down Expand Up @@ -78,7 +78,7 @@ def active_learning_selection(mem: Config):
data = []
for cnt, i in enumerate(index):
ele = D.pop(i - cnt)
ele[mem.label_header] = klasses[cnt]
ele[mem.label_header] = ele.get(mem.label_header, klasses[cnt])
data.append(ele)
db[mem.original] = D
db[mem.data] = data
Expand Down
4 changes: 3 additions & 1 deletion IngeoDash/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Config:
voc_selection: str='most_common_by_type'
estimator_class: object=LinearSVC
decision_function_name: str='decision_function'
dense_select: bool=True


def __getitem__(self, key):
Expand All @@ -67,7 +68,8 @@ def __call__(self, value):
cls.mem = json.loads(value) if isinstance(value, str) else value
for key in ['label_header', 'text', 'n_value',
'voc_size_exponent', 'voc_selection',
'estimator_class', 'decision_function_name']:
'estimator_class', 'decision_function_name',
'dense_select']:
if key in cls.mem:
setattr(cls, key, cls.mem[key])
return cls
Expand Down
22 changes: 12 additions & 10 deletions IngeoDash/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def test_Config():
voc_size_exponent=15,
voc_selection='most_common_by_type',
estimator_class=LinearSVC,
decision_function_name='decision_function')
decision_function_name='decision_function',
dense_select=True)
for k, v in default.items():
assert v == getattr(conf, k)

Expand All @@ -75,15 +76,16 @@ def test_Config_call():


def test_Config_call2():
mem = CONFIG(dict(label_header='label',
text='texto', n_value=12,
voc_size_exponent=15,
voc_selection='most_common_by_type',
estimator_class=LinearSVC,
decision_function_name='decision_function'))
assert mem.label_header == 'label'
assert mem.text == 'texto'
assert mem.n_value == 12
kwargs = dict(label_header='label',
text='texto', n_value=12,
voc_size_exponent=15,
voc_selection='most_common_by_type',
estimator_class=LinearSVC,
decision_function_name='decision_function',
dense_select=True)
mem = CONFIG(kwargs)
for k, v in kwargs.items():
assert getattr(mem, k) == v


def test_CONFIG():
Expand Down

0 comments on commit 266f3a4

Please sign in to comment.