def get_keras_model_metadata(model_file_path: str) -> dict:
    """Read metadata from a Keras model file.

    The file is opened as a zip archive and its ``metadata.json`` member is
    decoded as UTF-8 and parsed as JSON.

    This is a best-effort helper: it never raises. Any failure (missing or
    unreadable file, not a zip archive, no ``metadata.json`` member, invalid
    UTF-8 or JSON) yields an empty dict.

    Args:
        model_file_path: Path of the model file to inspect.

    Returns:
        The parsed metadata, or an empty dict when the metadata could not be
        read or is not a JSON object.
    """
    try:
        # Named "archive" rather than "zip" so the builtin zip() is not
        # shadowed inside this function.
        with zipfile.ZipFile(model_file_path, "r") as archive:
            with archive.open("metadata.json") as metadata_file:
                metadata = json.loads(metadata_file.read().decode("utf-8"))
    except Exception:
        # Deliberately broad: the NN ensemble backend calls this while
        # building the error message for a failed model load, so a failure
        # here must not mask that original problem.
        return {}
    # Valid JSON is not necessarily an object; honour the declared return type.
    return metadata if isinstance(metadata, dict) else {}
def test_get_keras_model_metadata():
    """Metadata read from the dummy Keras model fixture equals the expected dict."""
    metadata = annif.util.get_keras_model_metadata("tests/dummy-nn-model.keras")
    assert metadata == {
        "keras_version": "xx.yy.zz",
        "date_saved": "2024-04-11@01:01:01",
    }