Skip to content

Commit

Permalink
Adapt xtransformer backend to new vocab model.
Browse files Browse the repository at this point in the history
  • Loading branch information
mo-fu committed Sep 2, 2022
1 parent 4a82ea2 commit 367e493
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions annif/backend/xtransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ class XTransformerBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
'imbalanced_depth': 100,
'max_match_clusters': 32768,
'do_fine_tune': True,
# 'model_shortcut': 'distilbert-base-multilingual-cased',
'model_shortcut': 'bert-base-multilingual-uncased',
'model_shortcut': 'distilbert-base-multilingual-cased',
'beam_size': 20,
'limit': 100,
'post_processor': 'sigmoid',
Expand Down Expand Up @@ -135,22 +134,18 @@ def _create_train_files(self, veccorpus, corpus):
txt_pth = osp.join(self.datadir, self.train_txt_file)
with open(txt_pth, 'w', encoding='utf-8') as txt_file:
for doc, vector in zip(corpus.documents, veccorpus):
subject_ids = [
self.project.subjects.by_uri(uri)
for uri
in doc.uris]
subject_ids = [s_id for s_id in subject_ids if s_id]
if not (subject_ids and doc.text):
subject_set = doc.subject_set
if not (subject_set and doc.text):
continue # noqa
print(' '.join(doc.text.split()), file=txt_file)
Xs.append(
sp.csr_matrix(vector, dtype=np.float32).sorted_indices())
ys.append(
sp.csr_matrix((
np.ones(len(subject_ids)),
np.ones(len(subject_set)),
(
np.zeros(len(subject_ids)),
subject_ids)),
np.zeros(len(subject_set)),
[s for s in subject_set])),
shape=(1, len(self.project.subjects)),
dtype=np.float32
).sorted_indices())
Expand Down Expand Up @@ -239,11 +234,8 @@ def _suggest(self, text, params):
post_processor=new_params['post_processor'])
results = []
for idx, score in zip(prediction.indices, prediction.data):
subject = self.project.subjects[idx]
results.append(SubjectSuggestion(
uri=subject[0],
label=subject[1],
notation=subject[2],
subject_id=idx,
score=score
))
return ListSuggestionResult(results)

0 comments on commit 367e493

Please sign in to comment.