Skip to content

Commit

Permalink
Update to work with new networkx dispatching (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw authored Aug 25, 2023
1 parent 1f5ccb6 commit 3caced2
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Check with twine
run: python -m twine check --strict dist/*
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@v1.8.6
uses: pypa/gh-action-pypi-publish@v1.8.10
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN }}
5 changes: 3 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
activate-environment: testing
- name: Install dependencies
run: |
conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly
conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly pytest-mpl
# matplotlib lxml pygraphviz pydot sympy # Extra networkx deps we don't need yet
pip install git+https://github.com/networkx/networkx.git@main --no-deps
pip install -e . --no-deps
Expand All @@ -39,7 +39,8 @@ jobs:
python -c 'import sys, graphblas_algorithms; assert "networkx" not in sys.modules'
coverage run --branch -m pytest --color=yes -v --check-structure
coverage report
NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append
# NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append
./run_nx_tests.sh --color=yes --cov --cov-append
coverage report
coverage xml
- name: Coverage
Expand Down
30 changes: 18 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ci:
# See: https://pre-commit.ci/#configuration
autofix_prs: false
autoupdate_schedule: monthly
autoupdate_schedule: quarterly
skip: [no-commit-to-branch]
fail_fast: true
default_language_version:
Expand All @@ -17,21 +17,27 @@ repos:
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: check-ast
- id: check-toml
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
exclude_types: [svg]
- id: mixed-line-ending
- id: trailing-whitespace
- id: name-tests-test
args: ["--pytest-test-first"]
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.13
rev: v0.14
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
# I don't yet trust ruff to do what autoflake does
- repo: https://github.com/PyCQA/autoflake
rev: v2.1.1
rev: v2.2.0
hooks:
- id: autoflake
args: [--in-place]
Expand All @@ -40,7 +46,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v3.4.0
rev: v3.10.1
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -50,38 +56,38 @@ repos:
- id: auto-walrus
args: [--line-length, "100"]
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black
# - id: black-jupyter
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.270
rev: v0.0.285
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==6.0.0
- flake8-bugbear==23.5.9
- flake8==6.1.0
- flake8-bugbear==23.7.10
- flake8-simplify==0.20.0
- repo: https://github.com/asottile/yesqa
rev: v1.4.0
rev: v1.5.0
hooks:
- id: yesqa
additional_dependencies: *flake8_dependencies
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
rev: v2.2.5
hooks:
- id: codespell
types_or: [python, rst, markdown]
additional_dependencies: [tomli]
files: ^(graphblas_algorithms|docs)/
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.270
rev: v0.0.285
hooks:
- id: ruff
# `pyroma` may help keep our package standards up to date if best practices change.
Expand Down
69 changes: 60 additions & 9 deletions graphblas_algorithms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,71 @@ class Dispatcher:
# End auto-generated code: dispatch

@staticmethod
def convert_from_nx(graph, weight=None, *, name=None):
def convert_from_nx(
graph,
edge_attrs=None,
node_attrs=None,
preserve_edge_attrs=False,
preserve_node_attrs=False,
preserve_graph_attrs=False,
name=None,
graph_name=None,
*,
weight=None, # For nx.__version__ <= 3.1
):
import networkx as nx

from .classes import DiGraph, Graph, MultiDiGraph, MultiGraph

if preserve_edge_attrs:
if graph.is_multigraph():
attrs = set().union(
*(
datadict
for nbrs in graph._adj.values()
for keydict in nbrs.values()
for datadict in keydict.values()
)
)
else:
attrs = set().union(
*(datadict for nbrs in graph._adj.values() for datadict in nbrs.values())
)
if len(attrs) == 1:
[attr] = attrs
edge_attrs = {attr: None}
elif attrs:
raise NotImplementedError("`preserve_edge_attrs=True` is not fully implemented")
if node_attrs:
raise NotImplementedError("non-None `node_attrs` is not yet implemented")
if preserve_node_attrs:
attrs = set().union(*(datadict for node, datadict in graph.nodes(data=True)))
if attrs:
raise NotImplementedError("`preserve_node_attrs=True` is not implemented")
if edge_attrs:
if len(edge_attrs) > 1:
raise NotImplementedError(
"Multiple edge attributes is not implemented (bad value for edge_attrs)"
)
if weight is not None:
raise TypeError("edge_attrs and weight both given")
[[weight, default]] = edge_attrs.items()
if default is not None and default != 1:
raise NotImplementedError(f"edge default != 1 is not implemented; got {default}")

if isinstance(graph, nx.MultiDiGraph):
return MultiDiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.MultiGraph):
return MultiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.DiGraph):
return DiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.Graph):
return Graph.from_networkx(graph, weight=weight)
raise TypeError(f"Unsupported type of graph: {type(graph)}")
G = MultiDiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.MultiGraph):
G = MultiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.DiGraph):
G = DiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.Graph):
G = Graph.from_networkx(graph, weight=weight)
else:
raise TypeError(f"Unsupported type of graph: {type(graph)}")
if preserve_graph_attrs:
G.graph.update(graph.graph)
return G

@staticmethod
def convert_to_nx(obj, *, name=None):
Expand Down
24 changes: 21 additions & 3 deletions graphblas_algorithms/tests/test_match_nx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,29 @@
"Matching networkx namespace requires networkx to be installed", allow_module_level=True
)
else:
from networkx.classes import backends # noqa: F401
try:
from networkx.utils import backends

IS_NX_30_OR_31 = False
except ImportError: # pragma: no cover (import)
# This is the location in nx 3.1
from networkx.classes import backends # noqa: F401

IS_NX_30_OR_31 = True


def isdispatched(func):
"""Can this NetworkX function dispatch to other backends?"""
if IS_NX_30_OR_31:
return (
callable(func)
and hasattr(func, "dispatchname")
and func.__module__.startswith("networkx")
)
return (
callable(func) and hasattr(func, "dispatchname") and func.__module__.startswith("networkx")
callable(func)
and hasattr(func, "preserve_edge_attrs")
and func.__module__.startswith("networkx")
)


Expand All @@ -37,7 +53,9 @@ def dispatchname(func):
# Haha, there should be a better way to get this
if not isdispatched(func):
raise ValueError(f"Function is not dispatched in NetworkX: {func.__name__}")
return func.dispatchname
if IS_NX_30_OR_31:
return func.dispatchname
return func.name


def fullname(func):
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,14 @@ ignore = [
"RET502", # Do not implicitly `return None` in function able to return non-`None` value
"RET503", # Missing explicit `return` at the end of function able to return non-`None` value
"RET504", # Unnecessary variable assignment before `return` statement
"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` (Note: no annotations yet)
"S110", # `try`-`except`-`pass` detected, consider logging the exception (Note: good advice, but we don't log)
"S112", # `try`-`except`-`continue` detected, consider logging the exception (Note: good advice, but we don't log)
"SIM102", # Use a single `if` statement instead of nested `if` statements (Note: often necessary)
"SIM105", # Use contextlib.suppress(...) instead of try-except-pass (Note: try-except-pass is much faster)
"SIM108", # Use ternary operator ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer)
"TRY003", # Avoid specifying long messages outside the exception class (Note: why?)
"FIX001", "FIX002", "FIX003", "FIX004", # flake8-fixme (like flake8-todos)

# Ignored categories
"C90", # mccabe (Too strict, but maybe we should make things less complex)
Expand Down
7 changes: 5 additions & 2 deletions run_nx_tests.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#!/bin/bash
NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx "$@"
# NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx --cov --cov-report term-missing "$@"
NETWORKX_GRAPH_CONVERT=graphblas \
NETWORKX_TEST_BACKEND=graphblas \
NETWORKX_FALLBACK_TO_NX=True \
pytest --pyargs networkx "$@"
# pytest --pyargs networkx --cov --cov-report term-missing "$@"
4 changes: 2 additions & 2 deletions scripts/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

datapaths = [
Path(__file__).parent / ".." / "data",
Path("."),
Path(),
]


Expand All @@ -37,7 +37,7 @@ def find_data(dataname):
if dataname not in download_data.data_urls:
raise FileNotFoundError(f"Unable to find data file for {dataname}")
curpath = Path(download_data.main([dataname])[0])
return curpath.resolve().relative_to(Path(".").resolve())
return curpath.resolve().relative_to(Path().resolve())


def get_symmetry(file_or_mminfo):
Expand Down
2 changes: 1 addition & 1 deletion scripts/download_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main(datanames, overwrite=False):
for name in datanames:
target = datapath / f"{name}.mtx"
filenames.append(target)
relpath = target.resolve().relative_to(Path(".").resolve())
relpath = target.resolve().relative_to(Path().resolve())
if not overwrite and target.exists():
print(f"{relpath} already exists; skipping", file=sys.stderr)
continue
Expand Down

0 comments on commit 3caced2

Please sign in to comment.