Skip to content

Commit

Permalink
fix pyproto deprecation warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
georgios-ts committed Mar 6, 2022
1 parent 946c47c commit 58ab94e
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 204 deletions.
67 changes: 31 additions & 36 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ use indexmap::IndexSet;

use retworkx_core::dictmap::*;

use pyo3::class::PyMappingProtocol;
use pyo3::exceptions::PyIndexError;
use pyo3::gc::{PyGCProtocol, PyVisit};
use pyo3::gc::PyVisit;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple};
use pyo3::PyTraverseError;
Expand Down Expand Up @@ -151,7 +150,7 @@ use super::dag_algo::is_directed_acyclic_graph;
/// ``PyDiGraph`` object will not be a multigraph. When ``False`` if a
/// method call is made that would add parallel edges the the weight/weight
/// from that method call will be used to update the existing edge in place.
#[pyclass(module = "retworkx", subclass, gc)]
#[pyclass(module = "retworkx", subclass)]
#[pyo3(text_signature = "(/, check_cycle=False, multigraph=True)")]
#[derive(Clone)]
pub struct PyDiGraph {
Expand Down Expand Up @@ -2638,22 +2637,20 @@ impl PyDiGraph {
pub fn copy(&self) -> PyDiGraph {
self.clone()
}
}

#[pyproto]
impl PyMappingProtocol for PyDiGraph {
/// Return the number of nodes in the graph
fn __len__(&self) -> PyResult<usize> {
Ok(self.graph.node_count())
}
fn __getitem__(&'p self, idx: usize) -> PyResult<&'p PyObject> {

fn __getitem__(&self, idx: usize) -> PyResult<&PyObject> {
match self.graph.node_weight(NodeIndex::new(idx as usize)) {
Some(data) => Ok(data),
None => Err(PyIndexError::new_err("No node found for index")),
}
}

fn __setitem__(&'p mut self, idx: usize, value: PyObject) -> PyResult<()> {
fn __setitem__(&mut self, idx: usize, value: PyObject) -> PyResult<()> {
let data = match self.graph.node_weight_mut(NodeIndex::new(idx as usize)) {
Some(node_data) => node_data,
None => return Err(PyIndexError::new_err("No node found for index")),
Expand All @@ -2662,41 +2659,15 @@ impl PyMappingProtocol for PyDiGraph {
Ok(())
}

fn __delitem__(&'p mut self, idx: usize) -> PyResult<()> {
fn __delitem__(&mut self, idx: usize) -> PyResult<()> {
match self.graph.remove_node(NodeIndex::new(idx as usize)) {
Some(_) => Ok(()),
None => Err(PyIndexError::new_err("No node found for index")),
}
}
}

fn is_cycle_check_required(dag: &PyDiGraph, a: NodeIndex, b: NodeIndex) -> bool {
let mut parents_a = dag
.graph
.neighbors_directed(a, petgraph::Direction::Incoming);
let mut children_b = dag
.graph
.neighbors_directed(b, petgraph::Direction::Outgoing);
parents_a.next().is_some() && children_b.next().is_some() && dag.graph.find_edge(a, b).is_none()
}

fn weight_transform_callable(
py: Python,
map_fn: &Option<PyObject>,
value: &PyObject,
) -> PyResult<PyObject> {
match map_fn {
Some(map_fn) => {
let res = map_fn.call1(py, (value,))?;
Ok(res.to_object(py))
}
None => Ok(value.clone_ref(py)),
}
}
// Functions to enable Python Garbage Collection

// Functions to enable Python Garbage Collection
#[pyproto]
impl PyGCProtocol for PyDiGraph {
// Function for PyTypeObject.tp_traverse [1][2] used to tell Python what
// objects the PyDiGraph has strong references to.
//
Expand Down Expand Up @@ -2732,6 +2703,30 @@ impl PyGCProtocol for PyDiGraph {
}
}

fn is_cycle_check_required(dag: &PyDiGraph, a: NodeIndex, b: NodeIndex) -> bool {
let mut parents_a = dag
.graph
.neighbors_directed(a, petgraph::Direction::Incoming);
let mut children_b = dag
.graph
.neighbors_directed(b, petgraph::Direction::Outgoing);
parents_a.next().is_some() && children_b.next().is_some() && dag.graph.find_edge(a, b).is_none()
}

fn weight_transform_callable(
py: Python,
map_fn: &Option<PyObject>,
value: &PyObject,
) -> PyResult<PyObject> {
match map_fn {
Some(map_fn) => {
let res = map_fn.call1(py, (value,))?;
Ok(res.to_object(py))
}
None => Ok(value.clone_ref(py)),
}
}

fn _from_adjacency_matrix<'p, T>(
py: Python<'p>,
matrix: PyReadonlyArray2<'p, T>,
Expand Down
21 changes: 8 additions & 13 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
use retworkx_core::dictmap::*;

use pyo3::class::PyMappingProtocol;
use pyo3::exceptions::PyIndexError;
use pyo3::gc::{PyGCProtocol, PyVisit};
use pyo3::gc::PyVisit;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple};
use pyo3::PyTraverseError;
Expand Down Expand Up @@ -114,7 +113,7 @@ use petgraph::visit::{
/// object will not be a multigraph. When ``False`` if a method call is
/// made that would add parallel edges the the weight/weight from that
/// method call will be used to update the existing edge in place.
#[pyclass(module = "retworkx", subclass, gc)]
#[pyclass(module = "retworkx", subclass)]
#[pyo3(text_signature = "(/, multigraph=True)")]
#[derive(Clone)]
pub struct PyGraph {
Expand Down Expand Up @@ -1713,22 +1712,20 @@ impl PyGraph {
pub fn copy(&self) -> PyGraph {
self.clone()
}
}

#[pyproto]
impl PyMappingProtocol for PyGraph {
/// Return the nmber of nodes in the graph
fn __len__(&self) -> PyResult<usize> {
Ok(self.graph.node_count())
}
fn __getitem__(&'p self, idx: usize) -> PyResult<&'p PyObject> {

fn __getitem__(&self, idx: usize) -> PyResult<&PyObject> {
match self.graph.node_weight(NodeIndex::new(idx)) {
Some(data) => Ok(data),
None => Err(PyIndexError::new_err("No node found for index")),
}
}

fn __setitem__(&'p mut self, idx: usize, value: PyObject) -> PyResult<()> {
fn __setitem__(&mut self, idx: usize, value: PyObject) -> PyResult<()> {
let data = match self.graph.node_weight_mut(NodeIndex::new(idx)) {
Some(node_data) => node_data,
None => return Err(PyIndexError::new_err("No node found for index")),
Expand All @@ -1737,17 +1734,15 @@ impl PyMappingProtocol for PyGraph {
Ok(())
}

fn __delitem__(&'p mut self, idx: usize) -> PyResult<()> {
fn __delitem__(&mut self, idx: usize) -> PyResult<()> {
match self.graph.remove_node(NodeIndex::new(idx as usize)) {
Some(_) => Ok(()),
None => Err(PyIndexError::new_err("No node found for index")),
}
}
}

// Functions to enable Python Garbage Collection
#[pyproto]
impl PyGCProtocol for PyGraph {
// Functions to enable Python Garbage Collection

// Function for PyTypeObject.tp_traverse [1][2] used to tell Python what
// objects the PyGraph has strong references to.
//
Expand Down
13 changes: 5 additions & 8 deletions src/isomorphism/vf2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use std::marker;
use hashbrown::HashMap;
use retworkx_core::dictmap::*;

use pyo3::class::iter::{IterNextOutput, PyIterProtocol};
use pyo3::gc::{PyGCProtocol, PyVisit};
use pyo3::class::iter::IterNextOutput;
use pyo3::gc::PyVisit;
use pyo3::prelude::*;
use pyo3::PyTraverseError;

Expand Down Expand Up @@ -977,7 +977,7 @@ where

macro_rules! vf2_mapping_impl {
($name:ident, $Ty:ty) => {
#[pyclass(module = "retworkx", gc)]
#[pyclass(module = "retworkx")]
pub struct $name {
vf2: Vf2Algorithm<$Ty, Option<PyObject>, Option<PyObject>>,
}
Expand All @@ -1001,8 +1001,8 @@ macro_rules! vf2_mapping_impl {
}
}

#[pyproto]
impl PyIterProtocol for $name {
#[pymethods]
impl $name {
fn __iter__(slf: PyRef<Self>) -> Py<$name> {
slf.into()
}
Expand All @@ -1015,10 +1015,7 @@ macro_rules! vf2_mapping_impl {
None => Ok(IterNextOutput::Return("Ended")),
})
}
}

#[pyproto]
impl PyGCProtocol for $name {
fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> {
for j in 0..2 {
for node in self.vf2.st[j].graph.node_weights() {
Expand Down
Loading

0 comments on commit 58ab94e

Please sign in to comment.