Skip to content

Commit

Permalink
Upgrade PyO3 and rust-numpy to 0.21.x and Bound API (#1152)
Browse files Browse the repository at this point in the history
* Upgrade PyO3 and rust-numpy to 0.21.0 and Bound API

This commit updates rustworkx to the latest release 0.21.0 and updates
the API usage for the new Bound api. [1] Luckily in rustworkx the usage
of py references was minimal so the migration effort wasn't too complex,
it's mostly the iterators.rs module's custom traits that needed to updated
to handle the new API.

Fixes #1150

[1] https://pyo3.rs/v0.21.0/migration#from-020-to-021

* Bump to pyo3 0.21.1

---------

Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com>
  • Loading branch information
mtreinish and IvanIsCoding authored Apr 2, 2024
1 parent 1eb6508 commit 9f0646e
Show file tree
Hide file tree
Showing 15 changed files with 153 additions and 157 deletions.
24 changes: 12 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fixedbitset = "0.4.2"
hashbrown = { version = ">=0.13, <0.15", features = ["rayon"] }
indexmap = { version = ">=1.9, <3", features = ["rayon"] }
num-traits = "0.2"
numpy = "0.20.0"
numpy = "0.21.0"
petgraph = "0.6.4"
rand = "0.8"
rand_pcg = "0.3"
Expand Down Expand Up @@ -59,7 +59,7 @@ serde_json = "1.0"
rustworkx-core = { path = "rustworkx-core", version = "=0.15.0" }

[dependencies.pyo3]
version = "0.20.3"
version = "0.21.1"
features = ["extension-module", "hashbrown", "num-bigint", "num-complex", "indexmap"]

[dependencies.ndarray]
Expand Down
12 changes: 6 additions & 6 deletions src/coloring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub fn graph_greedy_color(
}
None => greedy_node_color(&graph.graph),
};
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (node, color) in colors {
out_dict.set_item(node.index(), color)?;
}
Expand Down Expand Up @@ -108,7 +108,7 @@ pub fn graph_greedy_color(
#[pyo3(text_signature = "(graph, /)")]
pub fn graph_greedy_edge_color(py: Python, graph: &graph::PyGraph) -> PyResult<PyObject> {
let colors = greedy_edge_color(&graph.graph);
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (node, color) in colors {
out_dict.set_item(node.index(), color)?;
}
Expand Down Expand Up @@ -142,7 +142,7 @@ pub fn graph_greedy_edge_color(py: Python, graph: &graph::PyGraph) -> PyResult<P
#[pyo3(text_signature = "(graph, /)")]
pub fn graph_misra_gries_edge_color(py: Python, graph: &graph::PyGraph) -> PyResult<PyObject> {
let colors = misra_gries_edge_color(&graph.graph);
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (node, color) in colors {
out_dict.set_item(node.index(), color)?;
}
Expand All @@ -163,7 +163,7 @@ pub fn graph_misra_gries_edge_color(py: Python, graph: &graph::PyGraph) -> PyRes
pub fn graph_two_color(py: Python, graph: &graph::PyGraph) -> PyResult<Option<PyObject>> {
match two_color(&graph.graph) {
Some(colors) => {
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (node, color) in colors {
out_dict.set_item(node.index(), color)?;
}
Expand All @@ -187,7 +187,7 @@ pub fn graph_two_color(py: Python, graph: &graph::PyGraph) -> PyResult<Option<Py
pub fn digraph_two_color(py: Python, graph: &digraph::PyDiGraph) -> PyResult<Option<PyObject>> {
match two_color(&graph.graph) {
Some(colors) => {
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (node, color) in colors {
out_dict.set_item(node.index(), color)?;
}
Expand Down Expand Up @@ -223,7 +223,7 @@ pub fn graph_bipartite_edge_color(py: Python, graph: &graph::PyGraph) -> PyResul
Ok(colors) => colors,
Err(_) => return Err(GraphNotBipartite::new_err("Graph is not bipartite")),
};
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (node, color) in colors {
out_dict.set_item(node.index(), color)?;
}
Expand Down
15 changes: 7 additions & 8 deletions src/connectivity/johnson_simple_cycles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use petgraph::visit::IntoNodeReferences;
use petgraph::visit::NodeFiltered;
use petgraph::Directed;

use pyo3::iter::IterNextOutput;
use pyo3::prelude::*;

use crate::iterators::NodeIndices;
Expand Down Expand Up @@ -149,7 +148,7 @@ fn process_stack(
block: &mut HashMap<NodeIndex, HashSet<NodeIndex>>,
subgraph: &StableDiGraph<(), ()>,
reverse_node_map: &HashMap<NodeIndex, NodeIndex>,
) -> Option<IterNextOutput<NodeIndices, &'static str>> {
) -> Option<NodeIndices> {
while let Some((this_node, neighbors)) = stack.last_mut() {
if let Some(next_node) = neighbors.pop() {
if next_node == start_node {
Expand All @@ -159,7 +158,7 @@ fn process_stack(
out_path.push(reverse_node_map[n].index());
closed.insert(*n);
}
return Some(IterNextOutput::Yield(NodeIndices { nodes: out_path }));
return Some(NodeIndices { nodes: out_path });
} else if blocked.insert(next_node) {
path.push(next_node);
stack.push((
Expand Down Expand Up @@ -195,14 +194,14 @@ impl SimpleCycleIter {
slf.into()
}

fn __next__(mut slf: PyRefMut<Self>) -> PyResult<IterNextOutput<NodeIndices, &'static str>> {
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<NodeIndices>> {
if slf.self_cycles.is_some() {
let self_cycles = slf.self_cycles.as_mut().unwrap();
let cycle_node = self_cycles.pop().unwrap();
if self_cycles.is_empty() {
slf.self_cycles = None;
}
return Ok(IterNextOutput::Yield(NodeIndices {
return Ok(Some(NodeIndices {
nodes: vec![cycle_node.index()],
}));
}
Expand Down Expand Up @@ -237,7 +236,7 @@ impl SimpleCycleIter {
slf.subgraph = subgraph;
slf.reverse_node_map = reverse_node_map;
slf.node_map = node_map;
return Ok(res);
return Ok(Some(res));
} else {
subgraph.remove_node(slf.start_node);
slf.scc
Expand Down Expand Up @@ -290,7 +289,7 @@ impl SimpleCycleIter {
slf.subgraph = subgraph;
slf.reverse_node_map = reverse_node_map;
slf.node_map = node_map;
return Ok(res);
return Ok(Some(res));
}
subgraph.remove_node(slf.start_node);
slf.scc
Expand All @@ -306,6 +305,6 @@ impl SimpleCycleIter {
}
}));
}
Ok(IterNextOutput::Return("Ended"))
Ok(None)
}
}
8 changes: 4 additions & 4 deletions src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ pub fn digraph_adjacency_matrix(
}
}
}
Ok(matrix.into_pyarray(py).into())
Ok(matrix.into_pyarray_bound(py).into())
}

/// Return the adjacency matrix for a PyGraph class
Expand Down Expand Up @@ -442,7 +442,7 @@ pub fn graph_adjacency_matrix(
}
}
}
Ok(matrix.into_pyarray(py).into())
Ok(matrix.into_pyarray_bound(py).into())
}

/// Compute the complement of an undirected graph.
Expand Down Expand Up @@ -839,7 +839,7 @@ pub fn graph_longest_simple_path(graph: &graph::PyGraph) -> Option<NodeIndices>
#[pyo3(text_signature = "(graph, /)")]
pub fn graph_core_number(py: Python, graph: &graph::PyGraph) -> PyResult<PyObject> {
let cores = connectivity::core_number(&graph.graph);
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (k, v) in cores {
out_dict.set_item(k.index(), v)?;
}
Expand All @@ -865,7 +865,7 @@ pub fn graph_core_number(py: Python, graph: &graph::PyGraph) -> PyResult<PyObjec
#[pyo3(text_signature = "(graph, /)")]
pub fn digraph_core_number(py: Python, graph: &digraph::PyDiGraph) -> PyResult<PyObject> {
let cores = connectivity::core_number(&graph.graph);
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (k, v) in cores {
out_dict.set_item(k.index(), v)?;
}
Expand Down
6 changes: 3 additions & 3 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@ pub fn layers(
next_layer = Vec::new();
}
if !index_output {
Ok(PyList::new(py, output).into())
Ok(PyList::new_bound(py, output).into())
} else {
Ok(PyList::new(py, output_indices).into())
Ok(PyList::new_bound(py, output_indices).into())
}
}

Expand Down Expand Up @@ -445,7 +445,7 @@ pub fn lexicographical_topological_sort(
}
out_list.push(&dag.graph[node])
}
Ok(PyList::new(py, out_list).into())
Ok(PyList::new_bound(py, out_list).into())
}

/// Return the topological generations of a DAG
Expand Down
44 changes: 17 additions & 27 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ impl PyDiGraph {
};
edges.push(edge);
}
let out_dict = PyDict::new(py);
let nodes_lst: PyObject = PyList::new(py, nodes).into();
let edges_lst: PyObject = PyList::new(py, edges).into();
let out_dict = PyDict::new_bound(py);
let nodes_lst: PyObject = PyList::new_bound(py, nodes).into();
let edges_lst: PyObject = PyList::new_bound(py, edges).into();
out_dict.set_item("nodes", nodes_lst)?;
out_dict.set_item("edges", edges_lst)?;
out_dict.set_item("nodes_removed", self.node_removed)?;
Expand All @@ -331,17 +331,13 @@ impl PyDiGraph {
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let dict_state = state.downcast::<PyDict>(py)?;
let nodes_lst = dict_state
.get_item("nodes")?
.unwrap()
.downcast::<PyList>()?;
let edges_lst = dict_state
.get_item("edges")?
.unwrap()
.downcast::<PyList>()?;
let dict_state = state.downcast_bound::<PyDict>(py)?;
let binding = dict_state.get_item("nodes")?.unwrap();
let nodes_lst = binding.downcast::<PyList>()?;
let binding = dict_state.get_item("edges")?.unwrap();
let edges_lst = binding.downcast::<PyList>()?;
self.graph = StablePyGraph::<Directed>::new();
let dict_state = state.downcast::<PyDict>(py)?;
let dict_state = state.downcast_bound::<PyDict>(py)?;
self.multigraph = dict_state
.get_item("multigraph")?
.unwrap()
Expand Down Expand Up @@ -381,11 +377,8 @@ impl PyDiGraph {
}
} else if nodes_lst.len() == 1 {
// graph has only one node, handle logic here to save one if in the loop later
let item = nodes_lst
.get_item(0)
.unwrap()
.downcast::<PyTuple>()
.unwrap();
let binding = nodes_lst.get_item(0).unwrap();
let item = binding.downcast::<PyTuple>().unwrap();
let node_idx: usize = item.get_item(0).unwrap().extract().unwrap();
let node_w = item.get_item(1).unwrap().extract().unwrap();

Expand All @@ -397,11 +390,8 @@ impl PyDiGraph {
self.graph.remove_node(NodeIndex::new(i));
}
} else {
let last_item = nodes_lst
.get_item(nodes_lst.len() - 1)
.unwrap()
.downcast::<PyTuple>()
.unwrap();
let binding = nodes_lst.get_item(nodes_lst.len() - 1).unwrap();
let last_item = binding.downcast::<PyTuple>().unwrap();

// list of temporary nodes that will be removed later to re-create holes
let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap();
Expand Down Expand Up @@ -1384,7 +1374,7 @@ impl PyDiGraph {
};

let have_same_weights =
source_weight.as_ref(py).compare(target_weight.as_ref(py))? == Ordering::Equal;
source_weight.bind(py).compare(target_weight.bind(py))? == Ordering::Equal;

if have_same_weights {
const DIRECTIONS: [petgraph::Direction; 2] =
Expand Down Expand Up @@ -1930,7 +1920,7 @@ impl PyDiGraph {
let mut file = Vec::<u8>::new();
build_dot(py, &self.graph, &mut file, graph_attr, node_attr, edge_attr)?;
Ok(Some(
PyString::new(py, str::from_utf8(&file)?).to_object(py),
PyString::new_bound(py, str::from_utf8(&file)?).to_object(py),
))
}
}
Expand Down Expand Up @@ -2044,7 +2034,7 @@ impl PyDiGraph {
Some(del) => pieces[2..].join(del),
None => pieces[2..].join(&' '.to_string()),
};
PyString::new(py, &weight_str).into()
PyString::new_bound(py, &weight_str).into()
} else {
py.None()
};
Expand Down Expand Up @@ -2293,7 +2283,7 @@ impl PyDiGraph {
weight.clone_ref(py),
)?;
}
let out_dict = PyDict::new(py);
let out_dict = PyDict::new_bound(py);
for (orig_node, new_node) in new_node_map.iter() {
out_dict.set_item(orig_node.index(), new_node.index())?;
}
Expand Down
2 changes: 1 addition & 1 deletion src/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,7 @@ pub fn directed_complete_graph(
}

#[pymodule]
pub fn generators(_py: Python, m: &PyModule) -> PyResult<()> {
pub fn generators(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(cycle_graph))?;
m.add_wrapped(wrap_pyfunction!(directed_cycle_graph))?;
m.add_wrapped(wrap_pyfunction!(path_graph))?;
Expand Down
Loading

0 comments on commit 9f0646e

Please sign in to comment.