Skip to content

Commit

Permalink
Fixed bugs in type specifications from Numpy 1.25 release (#1047)
Browse files Browse the repository at this point in the history
* attempted fix for matrix type erroring

* fix spec of csr_array

* add a random seed

* remove isspmatrix

* fix a type check

* black
  • Loading branch information
bdpedigo authored Jul 31, 2023
1 parent 4499f7c commit 4f91f1f
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
4 changes: 2 additions & 2 deletions graspologic/embed/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import scipy
import scipy.sparse as sp
import sklearn
from scipy.sparse import csr_array
from scipy.stats import norm
from typing_extensions import Literal

from graspologic.types import List, Tuple
from graspologic.utils import is_almost_symmetric

SvdAlgorithmType = Literal["full", "truncated", "randomized", "eigsh"]

Expand Down Expand Up @@ -107,7 +107,7 @@ def select_dimension(
pp.918-930.
"""
# Handle input data
if not isinstance(X, np.ndarray) and not sp.isspmatrix_csr(X):
if not isinstance(X, (np.ndarray, csr_array)):
msg = "X must be a numpy array or scipy.sparse.csr_array, not {}.".format(
type(X)
)
Expand Down
10 changes: 5 additions & 5 deletions graspologic/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import scipy.sparse
from beartype import beartype
from scipy.optimize import linear_sum_assignment
from scipy.sparse import csgraph, csr_array, diags, isspmatrix_csr
from scipy.sparse import csgraph, csr_array, diags
from scipy.sparse.csgraph import connected_components
from sklearn.metrics import confusion_matrix
from sklearn.utils import check_array, check_consistent_length, column_or_1d
Expand Down Expand Up @@ -43,7 +43,7 @@ def average_matrices(
"""
if isinstance(matrices[0], np.ndarray):
return np.mean(matrices, axis=0) # type: ignore
elif isspmatrix_csr(matrices[0]):
elif isinstance(matrices[0], csr_array):
return np.sum(matrices) / len(matrices)

raise TypeError(f"Unexpected type {matrices}")
Expand Down Expand Up @@ -275,11 +275,11 @@ def is_almost_symmetric(
Raises
------
TypeError
If the provided graph is not a numpy.ndarray or scipy.sparse.spmatrix
If the provided graph is not a numpy.ndarray or scipy.sparse.csr_array
"""
if (x.ndim != 2) or (x.shape[0] != x.shape[1]):
return False
if isinstance(x, (np.ndarray, scipy.sparse.spmatrix)):
if isinstance(x, (np.ndarray, csr_array)):
return abs(x - x.T).max() <= atol
else:
raise TypeError("input a correct matrix type.")
Expand Down Expand Up @@ -836,7 +836,7 @@ def augment_diagonal(
degrees = (in_degrees + out_degrees) / 2
diag = weight * degrees / divisor

graph += diags(diag) if isspmatrix_csr(graph) else np.diag(diag)
graph += diags(diag) if isinstance(graph, csr_array) else np.diag(diag)

return graph

Expand Down
6 changes: 4 additions & 2 deletions tests/partition/test_leiden.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,10 @@ def test_correct_types(self):
node_ids = partitions.keys()
for node_id in node_ids:
self.assertTrue(
isinstance(node_id, (np.int32, np.intc)),
f"{node_id} has {type(node_id)} should be an np.int32/np.intc",
isinstance(
node_id, np.integer
), # this is the preferred numpy typecheck
f"{node_id} has {type(node_id)} should be an int",
)

def test_hierarchical(self):
Expand Down
1 change: 1 addition & 0 deletions tests/test_spectral_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def test_directed_vertex_direction(self):
def test_directed_correct_latent_positions(self):
# setup
ase = AdjacencySpectralEmbed(n_components=3)
np.random.seed(8888)
P = np.array([[0.9, 0.1, 0.1], [0.3, 0.6, 0.1], [0.1, 0.5, 0.6]])
M, labels = sbm([200, 200, 200], P, directed=True, return_labels=True)

Expand Down

0 comments on commit 4f91f1f

Please sign in to comment.