Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bugs in type specifications from Numpy 1.25 release #1047

Merged
merged 6 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions graspologic/embed/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import scipy
import scipy.sparse as sp
import sklearn
from scipy.sparse import csr_array
from scipy.stats import norm
from typing_extensions import Literal

from graspologic.types import List, Tuple
from graspologic.utils import is_almost_symmetric

SvdAlgorithmType = Literal["full", "truncated", "randomized", "eigsh"]

Expand Down Expand Up @@ -107,7 +107,7 @@ def select_dimension(
pp.918-930.
"""
# Handle input data
if not isinstance(X, np.ndarray) and not sp.isspmatrix_csr(X):
if not isinstance(X, (np.ndarray, csr_array)):
msg = "X must be a numpy array or scipy.sparse.csr_array, not {}.".format(
type(X)
)
Expand Down
10 changes: 5 additions & 5 deletions graspologic/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import scipy.sparse
from beartype import beartype
from scipy.optimize import linear_sum_assignment
from scipy.sparse import csgraph, csr_array, diags, isspmatrix_csr
from scipy.sparse import csgraph, csr_array, diags
from scipy.sparse.csgraph import connected_components
from sklearn.metrics import confusion_matrix
from sklearn.utils import check_array, check_consistent_length, column_or_1d
Expand Down Expand Up @@ -43,7 +43,7 @@ def average_matrices(
"""
if isinstance(matrices[0], np.ndarray):
return np.mean(matrices, axis=0) # type: ignore
elif isspmatrix_csr(matrices[0]):
elif isinstance(matrices[0], csr_array):
return np.sum(matrices) / len(matrices)

raise TypeError(f"Unexpected type {matrices}")
Expand Down Expand Up @@ -275,11 +275,11 @@ def is_almost_symmetric(
Raises
------
TypeError
If the provided graph is not a numpy.ndarray or scipy.sparse.spmatrix
If the provided graph is not a numpy.ndarray or scipy.sparse.csr_array
"""
if (x.ndim != 2) or (x.shape[0] != x.shape[1]):
return False
if isinstance(x, (np.ndarray, scipy.sparse.spmatrix)):
if isinstance(x, (np.ndarray, csr_array)):
return abs(x - x.T).max() <= atol
else:
raise TypeError("input a correct matrix type.")
Expand Down Expand Up @@ -836,7 +836,7 @@ def augment_diagonal(
degrees = (in_degrees + out_degrees) / 2
diag = weight * degrees / divisor

graph += diags(diag) if isspmatrix_csr(graph) else np.diag(diag)
graph += diags(diag) if isinstance(graph, csr_array) else np.diag(diag)

return graph

Expand Down
6 changes: 4 additions & 2 deletions tests/partition/test_leiden.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,10 @@ def test_correct_types(self):
node_ids = partitions.keys()
for node_id in node_ids:
self.assertTrue(
isinstance(node_id, (np.int32, np.intc)),
f"{node_id} has {type(node_id)} should be an np.int32/np.intc",
isinstance(
node_id, np.integer
), # this is the preferred numpy typecheck
f"{node_id} has {type(node_id)} should be an int",
)

def test_hierarchical(self):
Expand Down
1 change: 1 addition & 0 deletions tests/test_spectral_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def test_directed_vertex_direction(self):
def test_directed_correct_latent_positions(self):
# setup
ase = AdjacencySpectralEmbed(n_components=3)
np.random.seed(8888)
P = np.array([[0.9, 0.1, 0.1], [0.3, 0.6, 0.1], [0.1, 0.5, 0.6]])
M, labels = sbm([200, 200, 200], P, directed=True, return_labels=True)

Expand Down