diff --git a/annif/backend/xtransformer.py b/annif/backend/xtransformer.py
index 1d92a7f25..983403fd6 100644
--- a/annif/backend/xtransformer.py
+++ b/annif/backend/xtransformer.py
@@ -80,8 +80,7 @@ class XTransformerBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
         'imbalanced_depth': 100,
         'max_match_clusters': 32768,
         'do_fine_tune': True,
-        # 'model_shortcut': 'distilbert-base-multilingual-cased',
-        'model_shortcut': 'bert-base-multilingual-uncased',
+        'model_shortcut': 'distilbert-base-multilingual-cased',
         'beam_size': 20,
         'limit': 100,
         'post_processor': 'sigmoid',
@@ -135,22 +134,18 @@ def _create_train_files(self, veccorpus, corpus):
         txt_pth = osp.join(self.datadir, self.train_txt_file)
         with open(txt_pth, 'w', encoding='utf-8') as txt_file:
             for doc, vector in zip(corpus.documents, veccorpus):
-                subject_ids = [
-                    self.project.subjects.by_uri(uri)
-                    for uri
-                    in doc.uris]
-                subject_ids = [s_id for s_id in subject_ids if s_id]
-                if not (subject_ids and doc.text):
+                subject_set = doc.subject_set
+                if not (subject_set and doc.text):
                     continue  # noqa
                 print(' '.join(doc.text.split()), file=txt_file)
                 Xs.append(
                     sp.csr_matrix(vector, dtype=np.float32).sorted_indices())
                 ys.append(
                     sp.csr_matrix((
-                        np.ones(len(subject_ids)),
+                        np.ones(len(subject_set)),
                         (
-                            np.zeros(len(subject_ids)),
-                            subject_ids)),
+                            np.zeros(len(subject_set)),
+                            [s for s in subject_set])),
                         shape=(1, len(self.project.subjects)),
                         dtype=np.float32
                     ).sorted_indices())
@@ -239,11 +234,8 @@ def _suggest(self, text, params):
             post_processor=new_params['post_processor'])
         results = []
         for idx, score in zip(prediction.indices, prediction.data):
-            subject = self.project.subjects[idx]
             results.append(SubjectSuggestion(
-                uri=subject[0],
-                label=subject[1],
-                notation=subject[2],
+                subject_id=idx,
                 score=score
             ))
         return ListSuggestionResult(results)
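
For reviewers unfamiliar with the `(data, (row, col))` construction used in `_create_train_files`: below is a minimal standalone sketch of the one-row label matrix the patched loop builds per document. The subject IDs and vocabulary size are made-up illustration values, not taken from Annif.

    import numpy as np
    import scipy.sparse as sp

    subject_set = {3, 17, 42}   # hypothetical subject IDs for one document
    vocab_size = 100            # stand-in for len(self.project.subjects)

    # One row; a 1.0 in every column whose index is in subject_set.
    # data = ones, row indices = zeros (single row), column indices = the IDs.
    y = sp.csr_matrix(
        (np.ones(len(subject_set)),
         (np.zeros(len(subject_set)), list(subject_set))),
        shape=(1, vocab_size),
        dtype=np.float32,
    ).sorted_indices()

    print(y.indices)  # [ 3 17 42]

Stacking one such row per training document yields the sparse label matrix that the backend passes to XTransformer.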
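
The `_suggest` change relies on suggestions carrying only the integer subject ID, with URI, label and notation resolved from the project's subject index when actually needed. A sketch of that lookup pattern, using a hypothetical namedtuple and subject list rather than the real Annif classes:

    from collections import namedtuple

    # Illustrative stand-in for the suggestion type used in this diff.
    SubjectSuggestion = namedtuple('SubjectSuggestion', 'subject_id score')

    # Hypothetical subject index: list position doubles as the subject ID.
    subjects = [
        ('http://example.org/subj0', 'cats'),
        ('http://example.org/subj1', 'dogs'),
    ]

    suggestion = SubjectSuggestion(subject_id=1, score=0.87)
    uri, label = subjects[suggestion.subject_id]  # resolved lazily, on demand
    print(uri, label, suggestion.score)

Presumably the motivation is to avoid copying URI/label/notation strings into every result and to decouple suggestions from how the vocabulary stores its metadata.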