Skip to content

Commit

Permalink
Merge pull request #257 from NatLibFi/issue225-feedback-online-learning
Browse files Browse the repository at this point in the history
Support for online learning
  • Loading branch information
osma authored Feb 27, 2019
2 parents 68e0f43 + 7624d21 commit a3641d5
Show file tree
Hide file tree
Showing 18 changed files with 316 additions and 34 deletions.
9 changes: 9 additions & 0 deletions annif/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ def info(self, message):
def warning(self, message):
"""Log a warning message from this backend"""
logger.warning("Backend {}: {}".format(self.backend_id, message))


class AnnifLearningBackend(AnnifBackend):
    """Base class for Annif backends that can perform online learning,
    i.e. have an already trained model further updated with additional
    documents after the initial training."""

    @abc.abstractmethod
    def learn(self, corpus, project):
        """further train the model on the given document or subject corpus

        Implementations are expected to update the existing model in place;
        the corpus may be as small as a single document."""
        pass  # pragma: no cover
19 changes: 16 additions & 3 deletions annif/backend/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,28 @@
from . import backend


class DummyBackend(backend.AnnifBackend):
class DummyBackend(backend.AnnifLearningBackend):
    """A trivial test backend that always suggests a single subject.

    The suggested subject starts out as a hard-coded dummy URI/label pair;
    learn() replaces it with the first subject found in the learning corpus,
    so subsequent analyses reflect what was "learned".
    """

    name = "dummy"
    initialized = False
    # class-level defaults; learn() shadows these with instance attributes
    uri = 'http://example.org/dummy'
    label = 'dummy'

    def initialize(self):
        self.initialized = True

    def _analyze(self, text, project, params):
        # the optional 'score' backend parameter lets tests control the
        # confidence score of the single returned hit (default 1.0)
        score = float(params.get('score', 1.0))
        return ListAnalysisResult([AnalysisHit(uri=self.uri,
                                               label=self.label,
                                               score=score)],
                                  project.subjects)

    def learn(self, corpus, project):
        # in this dummy backend we "learn" by picking up the URI and label
        # of the first subject of the first document in the learning set
        # and using that in subsequent analysis results
        for doc in corpus.documents:
            if doc.uris and doc.labels:
                self.uri = list(doc.uris)[0]
                self.label = list(doc.labels)[0]
                break
16 changes: 13 additions & 3 deletions annif/backend/vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import mixins


class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifLearningBackend):
"""Vorpal Wabbit multiclass/multilabel backend for Annif"""

name = "vw_multi"
Expand Down Expand Up @@ -129,15 +129,19 @@ def _inputs_to_exampletext(self, project, text):
return ' '.join(['|{} {}'.format(namespace, featurestr)
for namespace, featurestr in namespaces.items()])

def _create_train_file(self, corpus, project):
self.info('creating VW train file')
def _create_examples(self, corpus, project):
examples = []
for doc in corpus.documents:
text = self._inputs_to_exampletext(project, doc.text)
if not text:
continue
examples.extend(self._format_examples(project, text, doc.uris))
random.shuffle(examples)
return examples

def _create_train_file(self, corpus, project):
self.info('creating VW train file')
examples = self._create_examples(corpus, project)
annif.util.atomic_save(examples,
self.datadir,
self.TRAIN_FILE,
Expand Down Expand Up @@ -184,6 +188,12 @@ def train(self, corpus, project):
self._create_train_file(corpus, project)
self._create_model(project)

def learn(self, corpus, project):
for example in self._create_examples(corpus, project):
self._model.learn(example)
modelpath = os.path.join(self.datadir, self.MODEL_FILE)
self._model.save(modelpath)

def _convert_result(self, result, project):
if self.algorithm == 'multilabel_oaa':
# result is a list of subject IDs - need to vectorize
Expand Down
13 changes: 13 additions & 0 deletions annif/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,19 @@ def run_train(project_id, paths):
proj.train(documents)


@cli.command('learn')
@click_log.simple_verbosity_option(logger)
@click.argument('project_id')
@click.argument('paths', type=click.Path(), nargs=-1)
def run_learn(project_id, paths):
    """
    Further train an existing project on a collection of documents.
    """
    project = get_project(project_id)
    docs = open_documents(paths)
    project.learn(docs)


@cli.command('analyze')
@click_log.simple_verbosity_option(logger)
@click.argument('project_id')
Expand Down
3 changes: 2 additions & 1 deletion annif/corpus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Annif corpus operations"""


from .document import DocumentDirectory, DocumentFile
from .document import DocumentDirectory, DocumentFile, DocumentList
from .subject import Subject, SubjectDirectory, SubjectFileTSV
from .subject import SubjectIndex, SubjectSet
from .skos import SubjectFileSKOS
from .types import Document
from .combine import CombinedCorpus
14 changes: 13 additions & 1 deletion annif/corpus/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def documents(self):
with open(docfilename, errors='replace') as docfile:
text = docfile.read()
with open(keyfilename) as keyfile:
subjects = SubjectSet(keyfile.read())
subjects = SubjectSet.from_string(keyfile.read())
yield Document(text=text, uris=subjects.subject_uris,
labels=subjects.subject_labels)

Expand All @@ -66,3 +66,15 @@ def opener(path):
subjects = [annif.util.cleanup_uri(uri)
for uri in uris.split()]
yield Document(text=text, uris=subjects, labels=[])


class DocumentList(DocumentCorpus, DocumentToSubjectCorpusMixin):
    """A document corpus backed by a list (or other iterable) of
    Document objects"""

    def __init__(self, documents):
        self._documents = documents

    @property
    def documents(self):
        # lazily re-expose the wrapped iterable as a generator
        for document in self._documents:
            yield document
26 changes: 12 additions & 14 deletions annif/corpus/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,20 @@ def load(cls, path):
class SubjectSet:
"""Represents a set of subjects for a document."""

def __init__(self, subj_data):
"""initialize a SubjectSet from either a string representation or a
tuple (URIs, labels)"""

if isinstance(subj_data, str):
self.subject_uris = set()
self.subject_labels = set()
self._parse(subj_data)
else:
uris, labels = subj_data
self.subject_uris = set(uris)
self.subject_labels = set(labels)
def __init__(self, subj_data=None):
"""Create a SubjectSet and optionally initialize it from a tuple
(URIs, labels)"""

uris, labels = subj_data or ([], [])
self.subject_uris = set(uris)
self.subject_labels = set(labels)

def _parse(self, subj_data):
@classmethod
def from_string(cls, subj_data):
sset = cls()
for line in subj_data.splitlines():
self._parse_line(line)
sset._parse_line(line)
return sset

def _parse_line(self, line):
vals = line.split("\t")
Expand Down
7 changes: 7 additions & 0 deletions annif/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,10 @@ class ConfigurationException(AnnifException):
"""Exception raised when a project or backend is misconfigured."""

prefix = "Misconfigured"


class NotSupportedException(AnnifException):
    """Exception raised when an operation is not supported by a project or
    backend (e.g. online learning on a backend without learn() support)."""

    # prefix used when formatting the user-facing error message
    prefix = "Not supported"
14 changes: 13 additions & 1 deletion annif/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import annif.vocab
from annif.datadir import DatadirMixin
from annif.exception import AnnifException, ConfigurationException, \
NotInitializedException
NotInitializedException, NotSupportedException

logger = annif.logger

Expand Down Expand Up @@ -196,6 +196,18 @@ def train(self, corpus):
self._create_vectorizer(corpus)
self.backend.train(corpus, project=self)

def learn(self, corpus):
"""further train the project using documents from a metadata source"""

corpus.set_subject_index(self.subjects)
if isinstance(
self.backend,
annif.backend.backend.AnnifLearningBackend):
self.backend.learn(corpus, project=self)
else:
raise NotSupportedException("Learning not supported by backend",
project_id=self.project_id)

def dump(self):
"""return this project as a dict"""
return {'project_id': self.project_id,
Expand Down
29 changes: 29 additions & 0 deletions annif/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import connexion
import annif.project
from annif.corpus import Document, DocumentList
from annif.hit import HitFilter
from annif.exception import AnnifException
from annif.project import Access
Expand Down Expand Up @@ -64,3 +65,31 @@ def analyze(project_id, text, limit, threshold):
return server_error(err)
hits = hit_filter(result)
return {'results': [hit._asdict() for hit in hits]}


def _documents_to_corpus(documents):
    """Convert a sequence of REST API document dicts into a DocumentList
    corpus, silently skipping entries that lack 'text' or 'subjects'."""
    docs = []
    for doc in documents:
        if 'text' not in doc or 'subjects' not in doc:
            continue
        uris = [subj['uri'] for subj in doc['subjects']]
        labels = [subj['label'] for subj in doc['subjects']]
        docs.append(Document(text=doc['text'], uris=uris, labels=labels))
    return DocumentList(docs)


def learn(project_id, documents):
    """learn from documents and return an empty 204 response if successful

    Returns a problem response when the project does not exist or the
    backend does not support learning."""

    try:
        project = annif.project.get_project(
            project_id, min_access=Access.hidden)
    except ValueError:
        return project_not_found_error(project_id)

    corpus = _documents_to_corpus(documents)

    try:
        project.learn(corpus)
    except AnnifException as err:
        # e.g. NotSupportedException from a non-learning backend
        return server_error(err)

    # connexion turns (None, 204) into an empty 204 No Content response
    return None, 204
63 changes: 58 additions & 5 deletions annif/swagger/annif.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,38 @@ paths:
$ref: '#/definitions/Problem'
tags:
- Automatic subject indexing
'/projects/{project_id}/learn':
post:
summary: learn from manually indexed documents
operationId: annif.rest.learn
consumes:
- application/json
produces:
- application/json
- application/problem+json
parameters:
- $ref: '#/parameters/project_id'
- name: documents
in: body
description: documents to learn from
required: true
schema:
type: array
items:
$ref: '#/definitions/IndexedDocument'
responses:
'204':
description: successful operation
'404':
description: Project not found
schema:
$ref: '#/definitions/Problem'
'503':
description: Service Unavailable
schema:
$ref: '#/definitions/Problem'
tags:
- Learning from feedback
definitions:
ProjectBackend:
description: A backend of a project
Expand Down Expand Up @@ -133,7 +165,7 @@ definitions:
example: 'http://example.org/subject1'
label:
type: string
example: 'Archaeology'
example: Archaeology
score:
type: number
example: 0.85
Expand All @@ -148,6 +180,23 @@ definitions:
type: array
items:
$ref: '#/definitions/AnalysisResult'
IndexedDocument:
description: A document with attached, known good subjects
properties:
text:
type: string
example: "A quick brown fox jumped over the lazy dog."
subjects:
type: array
items:
type: object
properties:
uri:
type: string
example: 'http://example.org/subject1'
label:
type: string
example: 'Vulpes vulpes'
Problem:
type: object
properties:
Expand All @@ -169,8 +218,10 @@ definitions:
status:
type: integer
format: int32
description: |
The HTTP status code generated by the origin server for this occurrence
description: >
The HTTP status code generated by the origin server for this
occurrence
of the problem.
minimum: 100
maximum: 600
Expand All @@ -185,6 +236,8 @@ definitions:
instance:
type: string
format: uri
description: |
An absolute URI that identifies the specific occurrence of the problem.
description: >
An absolute URI that identifies the specific occurrence of the
problem.
It may or may not yield further information if dereferenced.
21 changes: 21 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import annif
import annif.backend
import annif.corpus


def test_get_backend_nonexistent():
Expand All @@ -20,3 +21,23 @@ def test_get_backend_dummy(app, project):
assert result[0].uri == 'http://example.org/dummy'
assert result[0].label == 'dummy'
assert result[0].score == 1.0


def test_learn_dummy(app, project, tmpdir):
    """learning on the dummy backend should make it suggest the first
    subject of the first learning document"""
    backend_type = annif.backend.get_backend("dummy")
    dummy_backend = backend_type(backend_id='dummy', params={},
                                 datadir=app.config['DATADIR'])

    # build a tiny two-document corpus on disk
    tmpdir.join('doc1.txt').write('doc1')
    tmpdir.join('doc1.tsv').write('<http://example.org/key1>\tkey1')
    tmpdir.join('doc2.txt').write('doc2')
    tmpdir.join('doc2.tsv').write('<http://example.org/key2>\tkey2')
    corpus = annif.corpus.DocumentDirectory(str(tmpdir))

    dummy_backend.learn(corpus, project)

    result = dummy_backend.analyze(text='this is some text', project=project)
    assert len(result) == 1
    assert result[0].uri == 'http://example.org/key1'
    assert result[0].label == 'key1'
    assert result[0].score == 1.0
Loading

0 comments on commit a3641d5

Please sign in to comment.