Skip to content

Commit

Permalink
Merge pull request #115 from HDI-Project/issue-114-load-from-json
Browse files Browse the repository at this point in the history
Issue 114 load from json
  • Loading branch information
csala authored Dec 24, 2019
2 parents 949c8b1 + be684dd commit a9f261e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 27 deletions.
44 changes: 17 additions & 27 deletions mlblocks/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def get_pipelines_paths():
return _PIPELINES_PATHS + _load_entry_points('pipelines')


def _load_json(json_path):
with open(json_path, 'r') as json_file:
LOGGER.debug('Loading %s', json_path)
return json.load(json_file)


def _load(name, paths):
"""Locate and load the JSON annotation in any of the given paths.
Expand All @@ -206,15 +212,17 @@ def _load(name, paths):
Args:
name (str):
name of the JSON to look for. The name should not contain the
``.json`` extension, as it will be added dynamically.
Path to a JSON file or name of the JSON to look for withouth the ``.json`` extension.
paths (list):
list of paths where the primitives will be looked for.
Returns:
dict:
The content of the JSON annotation file loaded into a dict.
"""
if os.path.isfile(name):
return _load_json(name)

for base_path in paths:
parts = name.split('.')
number_of_parts = len(parts)
Expand All @@ -225,12 +233,7 @@ def _load(name, paths):
json_path = os.path.join(folder, filename)

if os.path.isfile(json_path):
with open(json_path, 'r') as json_file:
LOGGER.debug('Loading %s from %s', name, json_path)
return json.load(json_file)


_PRIMITIVES = dict()
return _load_json(json_path)


def load_primitive(name):
Expand All @@ -241,8 +244,7 @@ def load_primitive(name):
Args:
name (str):
name of the JSON to look for. The name should not contain the
``.json`` extension, as it will be added dynamically.
Path to a JSON file or name of the JSON to look for withouth the ``.json`` extension.
Returns:
dict:
Expand All @@ -252,20 +254,13 @@ def load_primitive(name):
ValueError:
A ``ValueError`` will be raised if the primitive cannot be found.
"""
primitive = _PRIMITIVES.get(name)
primitive = _load(name, get_primitives_paths())
if primitive is None:
primitive = _load(name, get_primitives_paths())
if primitive is None:
raise ValueError("Unknown primitive: {}".format(name))

_PRIMITIVES[name] = primitive
raise ValueError("Unknown primitive: {}".format(name))

return primitive


_PIPELINES = dict()


def load_pipeline(name):
"""Locate and load the pipeline JSON annotation.
Expand All @@ -274,8 +269,7 @@ def load_pipeline(name):
Args:
name (str):
name of the JSON to look for. The name should not contain the
``.json`` extension, as it will be added dynamically.
Path to a JSON file or name of the JSON to look for withouth the ``.json`` extension.
Returns:
dict:
Expand All @@ -285,13 +279,9 @@ def load_pipeline(name):
ValueError:
A ``ValueError`` will be raised if the pipeline cannot be found.
"""
pipeline = _PIPELINES.get(name)
pipeline = _load(name, get_pipelines_paths())
if pipeline is None:
pipeline = _load(name, get_pipelines_paths())
if pipeline is None:
raise ValueError("Unknown pipeline: {}".format(name))

_PIPELINES[name] = pipeline
raise ValueError("Unknown pipeline: {}".format(name))

return pipeline

Expand Down
11 changes: 11 additions & 0 deletions mlblocks/mlpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import re
import warnings
from collections import Counter, OrderedDict, defaultdict
from copy import deepcopy

Expand Down Expand Up @@ -814,6 +815,11 @@ def from_dict(cls, metadata):
A new MLPipeline instance with the details found in the
given specification dictionary.
"""
warnings.warn(
'MLPipeline.form_dict(pipeline_dict) is deprecated and will be removed in a '
'later release. Please use MLPipeline(dict) instead,',
DeprecationWarning
)
return cls(metadata)

@classmethod
Expand All @@ -831,6 +837,11 @@ def load(cls, path):
A new MLPipeline instance with the specification found
in the JSON file.
"""
warnings.warn(
'MLPipeline.load(path) is deprecated and will be removed in a later release. '
'Please use MLPipeline(path) instead,',
DeprecationWarning
)
with open(path, 'r') as in_file:
metadata = json.load(in_file)

Expand Down
17 changes: 17 additions & 0 deletions tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,23 @@ def test__load_success():
assert primitive == loaded


def test__load_json_path():
primitive = {
'name': 'temp.primitive',
'primitive': 'temp.primitive'
}

with tempfile.TemporaryDirectory() as tempdir:
paths = [tempdir]
primitive_path = os.path.join(tempdir, 'temp.primitive.json')
with open(primitive_path, 'w') as primitive_file:
json.dump(primitive, primitive_file, indent=4)

loaded = discovery._load(primitive_path, paths)

assert primitive == loaded


@patch('mlblocks.discovery.get_primitives_paths')
@patch('mlblocks.discovery._load')
def test__load_primitive_value_error(load_mock, gpp_mock):
Expand Down

0 comments on commit a9f261e

Please sign in to comment.