Skip to content

Commit

Permalink
Merge pull request #117 from erica-chiu/pipeline-inputs
Browse files Browse the repository at this point in the history
Issue 112 Get Pipeline Inputs
  • Loading branch information
csala authored Jan 14, 2020
2 parents a9f261e + 1dd0f37 commit 390606a
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 7 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ Contributors
* William Xue <wgxue@mit.edu>
* Akshay Ravikumar <akshayr@mit.edu>
* Laura Gustafson <lgustaf@mit.edu>
* Erica Chiu <ejchiu@mit.edu>
83 changes: 76 additions & 7 deletions mlblocks/mlpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,37 @@ def _get_pipeline_dict(pipeline, primitives):

def _get_block_outputs(self, block_name):
    """Get the list of output variables for the given block.

    Each output specification is a copy of the one declared by the block,
    extended with a ``variable`` entry of the form
    ``{block_name}.{context_name}``, where ``context_name`` is the output
    name after applying the ``output_names`` translation of the block.

    Args:
        block_name (str):
            Name of the block for which to get the outputs.

    Returns:
        list:
            List of output specification dictionaries.
    """
    outputs = self._get_block_variables(
        block_name,
        'produce_output',
        self.output_names.get(block_name, dict())
    )
    for context_name, output in outputs.items():
        output['variable'] = '{}.{}'.format(block_name, context_name)

    return list(outputs.values())

def _get_block_variables(self, block_name, variables_attr, names):
    """Get a dictionary of variable names to the variables of a given block.

    The variables list is deep-copied before being returned, so the
    specifications in the result can be modified without altering the
    ones stored in the block.

    Args:
        block_name (str):
            Name of the block for which to get the specification.
        variables_attr (str):
            Name of the block attribute that has the variables list. It
            can be `fit_args`, `produce_args` or `produce_output`.
        names (dict):
            Dictionary used to translate the variable names. Names that
            are not found in it are used untranslated.

    Returns:
        dict:
            Dictionary mapping each translated variable name to its
            variable specification.
    """
    block = self.blocks[block_name]
    variables = deepcopy(getattr(block, variables_attr))

    # If two variables translate to the same name, the last one wins.
    return {
        names.get(variable['name'], variable['name']): variable
        for variable in variables
    }

def _get_outputs(self, pipeline, outputs):
"""Get the output definitions from the pipeline dictionary.
Expand Down Expand Up @@ -225,6 +247,53 @@ def _get_str_output(self, output):

raise ValueError('Invalid Output Specification: {}'.format(output))

def get_inputs(self, fit=True):
    """Get a relation of all the input variables required by this pipeline.

    The result is a dictionary that maps each variable name to its
    specified information.

    Optionally include the fit arguments.

    Args:
        fit (bool):
            Optional argument to include fit arguments or not. Defaults to ``True``.

    Returns:
        dictionary:
            A dictionary mapping every input variable's name to a dictionary
            specifying the information corresponding to that input variable.
            Each dictionary contains the entry ``name``, as
            well as any other metadata that may have been included in the
            pipeline inputs specification.
    """
    inputs = dict()

    # Iterate through the pipeline backwards. The keys are materialized
    # first so that ``reversed`` does not depend on the mapping offering
    # a reversible keys view.
    for block_name in reversed(list(self.blocks)):
        produce_outputs = self._get_block_variables(
            block_name,
            'produce_output',
            self.output_names.get(block_name, dict())
        )

        # Whatever this block produces is provided to the blocks after it,
        # so those variables are not required from outside the pipeline.
        for produce_output_name in produce_outputs:
            inputs.pop(produce_output_name, None)

        input_names = self.input_names.get(block_name, dict())

        produce_inputs = self._get_block_variables(
            block_name,
            'produce_args',
            input_names
        )
        inputs.update(produce_inputs)

        if fit:
            fit_inputs = self._get_block_variables(
                block_name,
                'fit_args',
                input_names
            )
            inputs.update(fit_inputs)

    return inputs

def get_outputs(self, outputs='default'):
"""Get the list of output variables that correspond to the specified outputs.
Expand Down
132 changes: 132 additions & 0 deletions tests/test_mlpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,138 @@ def test_get_output_variables(self):

assert names == ['a_variable']

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test__get_block_variables(self):
    # The variable must come back keyed by its translated name, while the
    # specification itself keeps the original ``name`` entry.
    expected = {
        'name_output': {
            'name': 'output',
            'type': 'whatever',
        }
    }

    pipeline = MLPipeline(['a_primitive'])

    # Use ``produce_output``, which is one of the attribute names that
    # ``_get_block_variables`` documents, rather than an ad-hoc attribute.
    pipeline.blocks['a_primitive#1'].produce_output = [
        {
            'name': 'output',
            'type': 'whatever'
        }
    ]

    outputs = pipeline._get_block_variables(
        'a_primitive#1',
        'produce_output',
        {'output': 'name_output'}
    )
    assert outputs == expected

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_get_inputs_fit(self):
    # setup: a two block pipeline in which the second block consumes the
    # output of the first one plus one extra variable.
    pipeline = MLPipeline(['a_primitive', 'another_primitive'])

    first_block = pipeline.blocks['a_primitive#1']
    first_block.produce_args = [{'name': 'input', 'type': 'whatever'}]
    first_block.fit_args = [{'name': 'fit_input', 'type': 'whatever'}]
    first_block.produce_output = [{'name': 'output', 'type': 'another_whatever'}]

    second_block = pipeline.blocks['another_primitive#1']
    second_block.produce_args = [
        {'name': 'output', 'type': 'another_whatever'},
        {'name': 'another_input', 'type': 'another_whatever'},
    ]

    # run
    inputs = pipeline.get_inputs()

    # assert: ``output`` is produced inside the pipeline, so only the
    # remaining produce and fit arguments are expected.
    expected = {
        'input': {'name': 'input', 'type': 'whatever'},
        'fit_input': {'name': 'fit_input', 'type': 'whatever'},
        'another_input': {'name': 'another_input', 'type': 'another_whatever'},
    }
    assert inputs == expected

@patch('mlblocks.mlpipeline.MLBlock', new=get_mlblock_mock)
def test_get_inputs_no_fit(self):
    # setup: a two block pipeline in which the second block consumes the
    # output of the first one plus one extra variable.
    pipeline = MLPipeline(['a_primitive', 'another_primitive'])

    first_block = pipeline.blocks['a_primitive#1']
    first_block.produce_args = [{'name': 'input', 'type': 'whatever'}]
    first_block.fit_args = [{'name': 'fit_input', 'type': 'whatever'}]
    first_block.produce_output = [{'name': 'output', 'type': 'another_whatever'}]

    second_block = pipeline.blocks['another_primitive#1']
    second_block.produce_args = [
        {'name': 'output', 'type': 'another_whatever'},
        {'name': 'another_input', 'type': 'another_whatever'},
    ]

    # run
    inputs = pipeline.get_inputs(fit=False)

    # assert: ``fit_input`` is left out because fit arguments were not
    # requested, and ``output`` is produced inside the pipeline.
    expected = {
        'input': {'name': 'input', 'type': 'whatever'},
        'another_input': {'name': 'another_input', 'type': 'another_whatever'},
    }
    assert inputs == expected

def test_fit(self):
pass

Expand Down

0 comments on commit 390606a

Please sign in to comment.