Commit d890b26: Some linting
jezsadler committed Nov 3, 2023
1 parent 797b452 commit d890b26
Showing 3 changed files with 145 additions and 73 deletions.
10 changes: 5 additions & 5 deletions src/omlt/neuralnet/nn_formulation.py
@@ -308,10 +308,10 @@ def __init__(self, network_structure, activation_functions=None):
)
if activation_functions is not None:
self._activation_functions.update(activation_functions)

# If we want to do network input/output validation at initialize time instead
# of build time, as it is for FullSpaceNNFormulation:
#
#
# network_inputs = list(self.__network_definition.input_nodes)
# if len(network_inputs) != 1:
# raise ValueError("Multiple input layers are not currently supported.")
@@ -513,10 +513,10 @@ def layer(b, layer_id):
else:
raise ValueError("ReluPartitionFormulation supports only Dense layers")

# This check is never hit. The formulation._build_formulation() function is
# only ever called by an OmltBlock.build_formulation(), and that runs the
# This check is never hit. The formulation._build_formulation() function is
# only ever called by an OmltBlock.build_formulation(), and that runs the
# input_indexes and output_indexes first, which will catch any formulations
# with multiple input or output layers.
# with multiple input or output layers.

# setup input variables constraints
# currently only support a single input layer
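The comments in the hunk above describe a call order rather than showing it: OmltBlock.build_formulation() evaluates the formulation's input_indexes and output_indexes before it ever calls formulation._build_formulation(), so a network with multiple input or output layers is rejected earlier. A rough sketch of that order, with simplified stand-in names and bodies that are not copied from the OMLT source:

# Illustrative sketch of the call order described in the comments above.
# Class and helper names are stand-ins, not the real OMLT implementation.
class OmltBlockSketch:
    def _setup_inputs_outputs(self, input_indexes, output_indexes):
        # Stand-in for the block's real input/output variable setup.
        self._input_indexes = input_indexes
        self._output_indexes = output_indexes

    def build_formulation(self, formulation):
        # The index properties are evaluated first, so a formulation with
        # multiple input or output layers fails here ...
        self._setup_inputs_outputs(
            list(formulation.input_indexes), list(formulation.output_indexes)
        )
        formulation._set_block(self)
        # ... and _build_formulation() only runs after that validation,
        # which is why the check mentioned in the comment is never hit.
        formulation._build_formulation()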
92 changes: 61 additions & 31 deletions tests/io/test_onnx_parser.py
@@ -107,120 +107,150 @@ def test_maxpool(datadir):
for layer in layers[1:]:
assert layer.kernel_depth == 3


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_input_tensor_invalid_dims(datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
model.graph.input[0].type.tensor_type.shape.dim[1].dim_value = 0
parser = NetworkParser()
with pytest.raises(ValueError) as excinfo:
parser.parse_network(model.graph,None,None)
parser.parse_network(model.graph, None, None)
expected_msg = "All dimensions in graph \"tf2onnx\" input tensor have 0 value."
assert str(excinfo.value) == expected_msg


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_no_input_layers(datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
model.graph.input.remove(model.graph.input[0])
parser = NetworkParser()
with pytest.raises(ValueError) as excinfo:
parser.parse_network(model.graph,None,None)
parser.parse_network(model.graph, None, None)
expected_msg = "No valid input layer found in graph \"tf2onnx\"."
assert str(excinfo.value) == expected_msg


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_node_no_inputs(datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
while (len(model.graph.node[0].input) > 0):
model.graph.node[0].input.pop()
parser = NetworkParser()
with pytest.raises(ValueError) as excinfo:
parser.parse_network(model.graph,None,None)
expected_msg = "Nodes must have inputs or have op_type \"Constant\". Node \"StatefulPartitionedCall/keras_linear_131/dense/MatMul\" has no inputs and op_type \"MatMul\"."
parser.parse_network(model.graph, None, None)
expected_msg = """Nodes must have inputs or have op_type \"Constant\".
Node \"StatefulPartitionedCall/keras_linear_131/dense/MatMul\" has
no inputs and op_type \"MatMul\"."""
assert str(excinfo.value) == expected_msg


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_wrong_node_type(datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph,None,None)
parser.parse_network(model.graph, None, None)

with pytest.raises(ValueError) as excinfo:
parser._consume_dense_nodes(parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
parser._consume_dense_nodes(parser._nodes[
'StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],
parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
expected_msg_dense = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only MatMul nodes can be used as starting points for consumption."
assert str(excinfo.value) == expected_msg_dense

with pytest.raises(ValueError) as excinfo:
parser._consume_gemm_dense_nodes(parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
parser._consume_gemm_dense_nodes(parser._nodes[
'StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],
parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
expected_msg_gemm = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only Gemm nodes can be used as starting points for consumption."
assert str(excinfo.value) == expected_msg_gemm

with pytest.raises(ValueError) as excinfo:
parser._consume_conv_nodes(parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
parser._consume_conv_nodes(parser._nodes[
'StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],
parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
expected_msg_conv = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only Conv nodes can be used as starting points for consumption."
assert str(excinfo.value) == expected_msg_conv

with pytest.raises(ValueError) as excinfo:
parser._consume_reshape_nodes(parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
parser._consume_reshape_nodes(parser._nodes[
'StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],
parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
expected_msg_reshape = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only Reshape nodes can be used as starting points for consumption."
assert str(excinfo.value) == expected_msg_reshape

with pytest.raises(ValueError) as excinfo:
parser._consume_pool_nodes(parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
expected_msg_pool = "StatefulPartitionedCall/keras_linear_131/dense/BiasAdd is a Add node, only MaxPool nodes can be used as starting points for consumption."
parser._consume_pool_nodes(parser._nodes[
'StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][1],
parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/BiasAdd'][2])
expected_msg_pool = """StatefulPartitionedCall/keras_linear_131/dense/BiasAdd
is a Add node, only MaxPool nodes can be used as starting points
for consumption."""
assert str(excinfo.value) == expected_msg_pool


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_dense_wrong_dims(datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph,None,None)
parser.parse_network(model.graph, None, None)

parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/MatMul'][1].input.append('abcd')
with pytest.raises(ValueError) as excinfo:
parser._consume_dense_nodes(parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/MatMul'][1],parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/MatMul'][2])
parser._consume_dense_nodes(parser._nodes[
'StatefulPartitionedCall/keras_linear_131/dense/MatMul'][1],
parser._nodes['StatefulPartitionedCall/keras_linear_131/dense/MatMul'][2])
expected_msg_dense = "StatefulPartitionedCall/keras_linear_131/dense/MatMul input has 3 dimensions, only nodes with 2 input dimensions can be used as starting points for consumption."
assert str(excinfo.value) == expected_msg_dense


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_gemm_wrong_dims(datadir):
model = onnx.load(datadir.file("gemm.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph,None,None)
parser.parse_network(model.graph, None, None)
parser._nodes['Gemm_0'][1].input.append('abcd')
with pytest.raises(ValueError) as excinfo:
parser._consume_gemm_dense_nodes(parser._nodes['Gemm_0'][1],parser._nodes['Gemm_0'][2])
parser._consume_gemm_dense_nodes(parser._nodes['Gemm_0'][1],
parser._nodes['Gemm_0'][2])
expected_msg_gemm = "Gemm_0 input has 4 dimensions, only nodes with 3 input dimensions can be used as starting points for consumption."
assert str(excinfo.value) == expected_msg_gemm


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_conv_wrong_dims(datadir):
model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph,None,None)
parser.parse_network(model.graph, None, None)
parser._nodes['Conv_0'][1].input.append('abcd')
with pytest.raises(ValueError) as excinfo:
parser._consume_conv_nodes(parser._nodes['Conv_0'][1],parser._nodes['Conv_0'][2])
parser._consume_conv_nodes(parser._nodes['Conv_0'][1],
parser._nodes['Conv_0'][2])
expected_msg_conv = "Conv_0 input has 4 dimensions, only nodes with 2 or 3 input dimensions can be used as starting points for consumption."
assert str(excinfo.value) == expected_msg_conv


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_reshape_wrong_dims(datadir):
model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph,None,None)
parser.parse_network(model.graph, None, None)
parser._nodes['Reshape_2'][1].input.append('abcd')
with pytest.raises(ValueError) as excinfo:
parser._consume_reshape_nodes(parser._nodes['Reshape_2'][1],parser._nodes['Reshape_2'][2])
expected_msg_reshape = "Reshape_2 input has 3 dimensions, only nodes with 2 input dimensions can be used as starting points for consumption."
parser._consume_reshape_nodes(parser._nodes['Reshape_2'][1],
parser._nodes['Reshape_2'][2])
expected_msg_reshape = """Reshape_2 input has 3 dimensions, only nodes with 2 input
dimensions can be used as starting points for consumption."""
assert str(excinfo.value) == expected_msg_reshape


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_maxpool_wrong_dims(datadir):
model = onnx.load(datadir.file("maxpool_2d.onnx"))
model = onnx.load(datadir.file("maxpool_2d.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph,None,None)
parser.parse_network(model.graph, None, None)
parser._nodes['node1'][1].input.append('abcd')
with pytest.raises(ValueError) as excinfo:
parser._consume_pool_nodes(parser._nodes['node1'][1],parser._nodes['node1'][2])
expected_msg_maxpool = "node1 input has 2 dimensions, only nodes with 1 input dimension can be used as starting points for consumption."
assert str(excinfo.value) == expected_msg_maxpool
parser._consume_pool_nodes(parser._nodes['node1'][1], parser._nodes['node1'][2])
expected_msg_maxpool = """node1 input has 2 dimensions, only nodes with 1 input
dimension can be used as starting points for consumption."""
assert str(excinfo.value) == expected_msg_maxpool
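
Each of the new tests above follows the same shape: load an ONNX model fixture, mutate its graph so that it becomes invalid, and assert that NetworkParser.parse_network() raises a ValueError with the expected message. A minimal sketch of that shape, assuming the import path omlt.io.onnx_parser and the datadir fixture used above (the real tests also guard on onnx_available via pytest.mark.skipif):

import onnx
import pytest

# Assumed import path for the parser exercised by the tests above.
from omlt.io.onnx_parser import NetworkParser


def test_no_input_layer_sketch(datadir):
    # Load a known-good fixture, then break it on purpose.
    model = onnx.load(datadir.file("keras_linear_131.onnx"))
    model.graph.input.remove(model.graph.input[0])

    parser = NetworkParser()
    with pytest.raises(ValueError) as excinfo:
        parser.parse_network(model.graph, None, None)

    # Substring match for brevity; the real tests compare the full message.
    assert "No valid input layer" in str(excinfo.value)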