diff --git a/IngeoDash/annotate.py b/IngeoDash/annotate.py index 96c797f..797fca6 100644 --- a/IngeoDash/annotate.py +++ b/IngeoDash/annotate.py @@ -27,7 +27,7 @@ def has_label(mem: Config, x): return False -def model(mem: Config, data: dict, select: bool=True): +def model(mem: Config, data: dict): lang = mem[mem.lang] if lang not in CONFIG.denseBoW: dense = DenseBoW(lang=lang, voc_size_exponent=mem.voc_size_exponent, @@ -41,7 +41,7 @@ def model(mem: Config, data: dict, select: bool=True): n_jobs=mem.n_jobs, dataset=False, emoji=False, keyword=False) dense.text_representations_extend(CONFIG.denseBoW[lang]) - if select: + if mem.dense_select: dense.select(D=data) _ = np.unique([x[mem.label_header] for x in data], return_counts=True)[1] @@ -78,7 +78,7 @@ def active_learning_selection(mem: Config): data = [] for cnt, i in enumerate(index): ele = D.pop(i - cnt) - ele[mem.label_header] = klasses[cnt] + ele[mem.label_header] = ele.get(mem.label_header, klasses[cnt]) data.append(ele) db[mem.original] = D db[mem.data] = data diff --git a/IngeoDash/config.py b/IngeoDash/config.py index 5a00f6f..b47c2e5 100644 --- a/IngeoDash/config.py +++ b/IngeoDash/config.py @@ -53,6 +53,7 @@ class Config: voc_selection: str='most_common_by_type' estimator_class: object=LinearSVC decision_function_name: str='decision_function' + dense_select: bool=True def __getitem__(self, key): @@ -67,7 +68,8 @@ def __call__(self, value): cls.mem = json.loads(value) if isinstance(value, str) else value for key in ['label_header', 'text', 'n_value', 'voc_size_exponent', 'voc_selection', - 'estimator_class', 'decision_function_name']: + 'estimator_class', 'decision_function_name', + 'dense_select']: if key in cls.mem: setattr(cls, key, cls.mem[key]) return cls diff --git a/IngeoDash/tests/test_config.py b/IngeoDash/tests/test_config.py index 31f5b87..b0210b9 100644 --- a/IngeoDash/tests/test_config.py +++ b/IngeoDash/tests/test_config.py @@ -50,7 +50,8 @@ def test_Config(): voc_size_exponent=15, voc_selection='most_common_by_type', estimator_class=LinearSVC, - decision_function_name='decision_function') + decision_function_name='decision_function', + dense_select=True) for k, v in default.items(): assert v == getattr(conf, k) @@ -75,15 +76,16 @@ def test_Config_call(): def test_Config_call2(): - mem = CONFIG(dict(label_header='label', - text='texto', n_value=12, - voc_size_exponent=15, - voc_selection='most_common_by_type', - estimator_class=LinearSVC, - decision_function_name='decision_function')) - assert mem.label_header == 'label' - assert mem.text == 'texto' - assert mem.n_value == 12 + kwargs = dict(label_header='label', + text='texto', n_value=12, + voc_size_exponent=15, + voc_selection='most_common_by_type', + estimator_class=LinearSVC, + decision_function_name='decision_function', + dense_select=True) + mem = CONFIG(kwargs) + for k, v in kwargs.items(): + assert getattr(mem, k) == v def test_CONFIG():