diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 0d0724379..af7255d8d 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -303,91 +303,95 @@ impl<'a, 'b> ValidationContext<'a, 'b> { let op_type = &node_type.op; let flags = op_type.validity_flags(); - if self.hugr.hierarchy.child_count(node.pg_index()) > 0 { - if flags.allowed_children.is_empty() { - return Err(ValidationError::NonContainerWithChildren { + if self.hugr.hierarchy.child_count(node.pg_index()) == 0 { + return if flags.requires_children { + Err(ValidationError::ContainerWithoutChildren { node, optype: op_type.clone(), - }); - } + }) + } else { + Ok(()) + }; + } - let all_children = self.hugr.children(node); - let mut first_two_children = all_children.clone().take(2); - let first_child = self.hugr.get_optype(first_two_children.next().unwrap()); - if !flags.allowed_first_child.is_superset(first_child.tag()) { - return Err(ValidationError::InvalidInitialChild { - parent: node, - parent_optype: op_type.clone(), - optype: first_child.clone(), - expected: flags.allowed_first_child, - position: "first", - }); - } + if flags.allowed_children.is_empty() { + return Err(ValidationError::NonContainerWithChildren { + node, + optype: op_type.clone(), + }); + } - if let Some(second_child) = first_two_children - .next() - .map(|child| self.hugr.get_optype(child)) - { - if !flags.allowed_second_child.is_superset(second_child.tag()) { - return Err(ValidationError::InvalidInitialChild { - parent: node, - parent_optype: op_type.clone(), - optype: second_child.clone(), - expected: flags.allowed_second_child, - position: "second", - }); - } - } - // Additional validations running over the full list of children optypes - let children_optypes = all_children.map(|c| (c.pg_index(), self.hugr.get_optype(c))); - if let Err(source) = op_type.validate_op_children(children_optypes) { - return Err(ValidationError::InvalidChildren { + let all_children = self.hugr.children(node); + let mut first_two_children = all_children.clone().take(2); + let first_child = self.hugr.get_optype(first_two_children.next().unwrap()); + if !flags.allowed_first_child.is_superset(first_child.tag()) { + return Err(ValidationError::InvalidInitialChild { + parent: node, + parent_optype: op_type.clone(), + optype: first_child.clone(), + expected: flags.allowed_first_child, + position: "first", + }); + } + + if let Some(second_child) = first_two_children + .next() + .map(|child| self.hugr.get_optype(child)) + { + if !flags.allowed_second_child.is_superset(second_child.tag()) { + return Err(ValidationError::InvalidInitialChild { parent: node, parent_optype: op_type.clone(), - source, + optype: second_child.clone(), + expected: flags.allowed_second_child, + position: "second", }); } + } + // Additional validations running over the full list of children optypes + let children_optypes = all_children.map(|c| (c.pg_index(), self.hugr.get_optype(c))); + if let Err(source) = op_type.validate_op_children(children_optypes) { + return Err(ValidationError::InvalidChildren { + parent: node, + parent_optype: op_type.clone(), + source, + }); + } - // Additional validations running over the edges of the contained graph - if let Some(edge_check) = flags.edge_check { - for source in self.hugr.hierarchy.children(node.pg_index()) { - for target in self.hugr.graph.output_neighbours(source) { - if self.hugr.hierarchy.parent(target) != Some(node.pg_index()) { - continue; - } - let source_op = self.hugr.get_optype(source.into()); - let target_op = self.hugr.get_optype(target.into()); - for (source_port, target_port) in - self.hugr.graph.get_connections(source, target) - { - let edge_data = ChildrenEdgeData { + // Additional validations running over the edges of the contained graph + if let Some(edge_check) = flags.edge_check { + for source in self.hugr.hierarchy.children(node.pg_index()) { + for target in self.hugr.graph.output_neighbours(source) { + if self.hugr.hierarchy.parent(target) != Some(node.pg_index()) { + continue; + } + let source_op = self.hugr.get_optype(source.into()); + let target_op = self.hugr.get_optype(target.into()); + for (source_port, target_port) in + self.hugr.graph.get_connections(source, target) + { + let edge_data = ChildrenEdgeData { + source, + target, + source_port: self.hugr.graph.port_offset(source_port).unwrap(), + target_port: self.hugr.graph.port_offset(target_port).unwrap(), + source_op: source_op.clone(), + target_op: target_op.clone(), + }; + if let Err(source) = edge_check(edge_data) { + return Err(ValidationError::InvalidEdges { + parent: node, + parent_optype: op_type.clone(), source, - target, - source_port: self.hugr.graph.port_offset(source_port).unwrap(), - target_port: self.hugr.graph.port_offset(target_port).unwrap(), - source_op: source_op.clone(), - target_op: target_op.clone(), - }; - if let Err(source) = edge_check(edge_data) { - return Err(ValidationError::InvalidEdges { - parent: node, - parent_optype: op_type.clone(), - source, - }); - } + }); } } } } + } - if flags.requires_dag { - self.validate_children_dag(node, op_type)?; - } - } else if flags.requires_children { - return Err(ValidationError::ContainerWithoutChildren { - node, - optype: op_type.clone(), - }); + if flags.requires_dag { + self.validate_children_dag(node, op_type)?; } Ok(())