Features: spectral clustering #518

Merged
merged 63 commits into master from features/spectral on Apr 7, 2020

Commits (63)
7cf400f
Implement Utilities
Cdebus Jan 24, 2020
477d840
Run Hooks
Cdebus Jan 24, 2020
3670728
Implementation of metrics class and spectral clustering
Cdebus Jan 27, 2020
5411b45
Rework draft:
Cdebus Jan 28, 2020
9537f65
Merge branch 'features/469-KmeansRework' into features/spectral
Cdebus Jan 30, 2020
440fc81
Implementation of heat cdist(X,y,metric)
Cdebus Jan 31, 2020
560fbba
Removed diag function, was already available under ht.manipulations
Cdebus Feb 3, 2020
6165b3c
Some cosmetics and input sanitation
Cdebus Feb 3, 2020
a3f82b8
Bugfixes in the Linalg Utilites
Cdebus Feb 3, 2020
0c7cc4a
Renaming of distance module; implementation of rbf kernel
Cdebus Feb 5, 2020
bb2588a
Merged changes from 469-KmeansRework for refactorizing spatial.distan…
Cdebus Feb 5, 2020
b4cddd9
Unit tests for norm and projection
Cdebus Feb 5, 2020
ecfc9c1
Testing for CG Algorithm
Cdebus Feb 5, 2020
9af8bed
Extended Unit tests for spatial/distance.py module
Cdebus Feb 6, 2020
b32e852
Bugfixes after first full testrun
Cdebus Feb 6, 2020
2caf243
Bugfix
Cdebus Feb 7, 2020
ffe60b6
Add support of float64 and typecasting to distance calculation
Cdebus Feb 7, 2020
0d3762b
Add support of float64 and typecasting to distance calculation
Cdebus Feb 7, 2020
c7c8c2e
Support of int types and typecasting
Cdebus Feb 7, 2020
cb5bdc6
Cosmetics and initial unit tests
Cdebus Feb 7, 2020
d360ebe
Fixed typecasting in ht.array, and adjusted Unit tests
Cdebus Feb 7, 2020
c9530d5
Added testing for different types in cdist
Cdebus Feb 7, 2020
9286d44
Removed old spatial module
Cdebus Feb 7, 2020
affb2ce
Merged latest changes from master
Cdebus Feb 7, 2020
a3312c3
Cleanup
Cdebus Feb 7, 2020
12e0027
merge master
Cdebus Feb 12, 2020
338401b
Finalize spectral clustering API
Cdebus Feb 14, 2020
67f45f5
Bugfixes and Unit tests
Cdebus Feb 14, 2020
dd04279
Fixes in Lanczos Algorithm: Reorthogonalization to make Eigenvalues c…
Cdebus Feb 17, 2020
6adf52e
Testing on HDFML
Feb 19, 2020
ec131d9
Merged device fix in distance and others from master
Feb 19, 2020
b413b5f
Benchmarking Spectral Clustering. Cleanup needed
Feb 19, 2020
1b2e7a7
Enhanced Lanczos-algorithm by speeding up reorthogonalization
Mar 4, 2020
4454bb9
Tagged #494 for rethinking Reorthogonalization
Mar 4, 2020
8f5163e
Finale changes from benchmarks
Apr 1, 2020
4dbe9bb
Merge current master
Apr 1, 2020
dff6c19
Merge branch 'master' into features/spectral
Apr 1, 2020
0dec50d
prep for master merge
coquelin77 Apr 1, 2020
4d8aea4
Merge branch 'master' into features/spectral
coquelin77 Apr 1, 2020
9ce6218
Running hooks
Apr 1, 2020
e1b44cd
Adjusted Tests to new Spectral Design
Apr 1, 2020
c69da43
Re-design Laplacian API for intrinsic similarity matrix calculation
Apr 2, 2020
cdb82aa
Implemented 'fill_diagonal' function for dndarrays
Apr 2, 2020
6de377e
Refactorization of linalg.solver and implementation of graph.laplacia…
Apr 3, 2020
17635c6
Final Changes from PR Review and Bugfixes
Apr 3, 2020
d10a432
Bugfixes and testing for spectral clustering
Apr 6, 2020
bad9a8f
bugfix in CG algorithm
Apr 6, 2020
1327921
added to resplit_, removed need for this fix in mm split=none
coquelin77 Apr 6, 2020
20deaba
Temporary Fixe for dndarray fill_diagonal
Apr 6, 2020
f5c84a5
minor comment change to restart tests
coquelin77 Apr 6, 2020
29ea72e
updated Changelog
Apr 6, 2020
73b562e
Merge branch 'features/spectral' of https://github.com/Cdebus/heat in…
Apr 6, 2020
f609302
Appeasing Markus about the order of changelog
Apr 6, 2020
270dfbb
Merge remote-tracking branch 'upstream/master' into features/spectral
Apr 6, 2020
b4b0dce
Merge branch 'master' into features/spectral
coquelin77 Apr 6, 2020
a3bf6d3
Merge branch 'features/spectral' of https://github.com/Cdebus/heat in…
Apr 6, 2020
f5d4012
Update CHANGELOG.md
Markus-Goetz Apr 6, 2020
28b4f38
PR review changes
Apr 6, 2020
02ef024
Merge branch 'features/spectral' of https://github.com/Cdebus/heat in…
Apr 6, 2020
4829c03
Documentation
Apr 6, 2020
8ffb794
Merge 'master' into features/spectral
Apr 6, 2020
77ea885
Cosmetics
Apr 7, 2020
a862198
Bugfix tests cdist
Apr 7, 2020
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,4 @@
# Pending Additions

- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Create submodule for Linear Algebra functions
- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implementated QR
- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implementated a tiling class to create Square tiles along the diagonal of a 2D matrix
@@ -12,6 +11,8 @@
- [#499](https://github.com/helmholtz-analytics/heat/pull/499) Bugfix: MPI datatype mapping: `torch.int16` now maps to `MPI.SHORT` instead of `MPI.SHORT_INT`
- [#507](https://github.com/helmholtz-analytics/heat/pull/507) Bugfix: sanitize_axis changes axis of 0-dim scalars to None
- [#515](https://github.com/helmholtz-analytics/heat/pull/515) ht.var() now returns the unadjusted sample variance by default, Bessel's correction can be applied by setting ddof=1.
- [#518](https://github.com/helmholtz-analytics/heat/pull/518) Implemenation of Spectral Clustering
- [#518](https://github.com/helmholtz-analytics/heat/pull/518) Implemenation of Spectral Clustering.
- [#519](https://github.com/helmholtz-analytics/heat/pull/519) Bugfix: distributed slicing with empty list or scalar as input; distributed nonzero() of empty (local) tensor.
- [#521](https://github.com/helmholtz-analytics/heat/pull/521) Add documentation for the generic reduce_op in Heat's core
- [#522](https://github.com/helmholtz-analytics/heat/pull/522) Added CUDA-aware MPI detection for MVAPICH, MPICH and ParaStation.
1 change: 1 addition & 0 deletions heat/__init__.py
@@ -1,5 +1,6 @@
from . import core
from . import cluster
from . import graph
from . import naive_bayes
from . import regression
from . import spatial
1 change: 1 addition & 0 deletions heat/cluster/__init__.py
@@ -1 +1,2 @@
from .kmeans import *
from .spectral import *
14 changes: 7 additions & 7 deletions heat/cluster/kmeans.py
@@ -267,7 +267,7 @@ def fit(self, X):
if self.tol is not None and self._inertia <= self.tol:
break

self._labels = matching_centroids.squeeze()
self._labels = matching_centroids

return self

@@ -338,7 +338,7 @@ def predict(self, X):
raise ValueError("input needs to be a ht.DNDarray, but was {}".format(type(X)))

# determine the centroids
return self._fit_to_cluster(X.expand_dims(axis=2)).squeeze()
return self._fit_to_cluster(X)

def set_params(self, **params):
"""
@@ -354,10 +354,10 @@ def set_params(self, **params):
self : ht.ml.KMeans
This estimator instance for chaining.
"""
self.init = params.get(params["init"], self.init)
self.max_iter = params.get(params["max_iter"], self.max_iter)
self.n_clusters = params.get(params["n_clusters"], self.n_clusters)
self.random_state = params.get(params["random_state"], self.random_state)
self.tol = params.get(params["tol"], self.tol)
self.init = params.get("init", self.init)
self.max_iter = params.get("max_iter", self.max_iter)
self.n_clusters = params.get("n_clusters", self.n_clusters)
self.random_state = params.get("random_state", self.random_state)
self.tol = params.get("tol", self.tol)

return self
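The set_params hunk above replaces lookups of the form params.get(params["init"], ...), which use the parameter value as the dictionary key and raise a KeyError whenever that key is absent, with plain key lookups that fall back to the current attribute. A minimal sketch of the corrected pattern, using a hypothetical params dict:

# Illustration of the dict.get pattern used by the fixed set_params; values are hypothetical.
current_init, current_max_iter = "kmeans++", 300
params = {"max_iter": 100}  # only max_iter is overridden

init = params.get("init", current_init)               # "kmeans++" -> default kept
max_iter = params.get("max_iter", current_max_iter)   # 100 -> override applied

# The old pattern, params.get(params["init"], current_init), would raise a KeyError
# here because "init" is not a key of params.
print(init, max_iter)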
250 changes: 250 additions & 0 deletions heat/cluster/spectral.py
@@ -0,0 +1,250 @@
import torch
import math
import heat as ht


class Spectral:
def __init__(
self,
n_clusters=None,
gamma=1.0,
metric="rbf",
laplacian="fully_connected",
threshold=1.0,
boundary="upper",
n_lanczos=300,
assign_labels="kmeans",
**params
):
"""
Spectral clustering

Parameters
----------
n_clusters : int, optional
Number of clusters; if None, it is estimated from the largest gap in the eigenvalue spectrum during fit
gamma : float, default=1.0
Kernel coefficient for metric='rbf', ignored for metric='euclidean'
metric : string
How to construct the similarity matrix.
'rbf' : construct the similarity matrix using a radial basis function (RBF) kernel.
'euclidean' : construct the similarity matrix from pairwise Euclidean distances.
laplacian : string
How to calculate the graph Laplacian (affinity matrix)
Currently supported : 'fully_connected', 'eNeighbour'
threshold : float
Threshold for the affinity matrix if laplacian='eNeighbour'
Ignored for laplacian='fully_connected'
boundary : string
How to interpret the threshold: 'upper', 'lower'
Ignored for laplacian='fully_connected'
n_lanczos : int
Number of Lanczos iterations for the eigenvalue decomposition
assign_labels: str, default = 'kmeans'
The strategy to use to assign labels in the embedding space.
'kmeans'
**params: dict
Parameter dictionary for the assign_labels estimator
"""
self.n_clusters = n_clusters
self.n_lanczos = n_lanczos
if metric == "rbf":
sig = math.sqrt(1 / (2 * gamma))
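# Assuming the standard RBF parameterization k(x, y) = exp(-gamma * ||x - y||^2)
# = exp(-||x - y||^2 / (2 * sigma^2)), the bandwidth passed to ht.spatial.rbf
# follows as sigma = sqrt(1 / (2 * gamma)).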
self.laplacian = ht.graph.Laplacian(
lambda x: ht.spatial.rbf(x, sigma=sig, quadratic_expansion=True),
definition="norm_sym",
mode=laplacian,
threshold_key=boundary,
threshold_value=threshold,
)

elif metric == "euclidean":
self.laplacian = ht.graph.Laplacian(
lambda x: ht.spatial.cdist(x, quadratic_expansion=True),
definition="norm_sym",
mode=laplacian,
threshold_key=boundary,
threshold_value=threshold,
)
else:
raise NotImplementedError("Other kernels currently not supported")

if assign_labels == "kmeans":
self._cluster = ht.cluster.KMeans(params)
else:
raise NotImplementedError(
"Other Label Assignment Algorithms are currently not available"
)

# in-place properties
self._labels = None

@property
def labels_(self):
"""
Returns
-------
ht.DNDarray, shape=(n_points):
Labels of each point.
"""
return self._labels

def _spectral_embedding(self, X):
"""
Helper function to embed the dataset X into the eigenvectors of the graph Laplacian matrix.

Returns
-------
ht.DNDarray, shape=(m_lanczos):
Eigenvalues of the graph's Laplacian matrix.
ht.DNDarray, shape=(n, m_lanczos):
Eigenvectors of the graph's Laplacian matrix.
"""
L = self.laplacian.construct(X)
# 3. Eigenvalue and -vector calculation via Lanczos Algorithm
v0 = ht.ones((L.shape[0],), dtype=L.dtype, split=0, device=L.device) / math.sqrt(L.shape[0])
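# v0 is the Lanczos start vector: an all-ones vector scaled to unit Euclidean norm.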
V, T = ht.lanczos(L, self.n_lanczos, v0)

# 4. Calculate and Sort Eigenvalues and Eigenvectors of tridiagonal matrix T
eval, evec = torch.eig(T._DNDarray__array, eigenvectors=True)
# If x is an Eigenvector of T, then y = V@x is the corresponding Eigenvector of L
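# torch.eig returns the eigenvalues as an (n, 2) tensor of real and imaginary parts;
# eval[:, 0] keeps only the real parts, which suffices because the Lanczos matrix T is symmetric tridiagonal.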
eval, idx = torch.sort(eval[:, 0], dim=0)
eigenvalues = ht.array(eval)
eigenvectors = ht.matmul(V, ht.array(evec))[:, idx]

return eigenvalues, eigenvectors

def fit(self, X):
"""
Computes the low-dimensional representation of X by calculating the eigenspectrum (eigenvalues and eigenvectors) of the graph Laplacian derived from the similarity matrix, and fits the eigenvectors that correspond to the k lowest eigenvalues with a separate clustering algorithm (currently only k-means is supported).
Similarity metrics for the adjacency calculation are provided via spatial.distance. The eigenvalues and eigenvectors are computed by reducing the Laplacian to a tridiagonal matrix via Lanczos iterations and applying the torch eigenvalue solver to this smaller matrix. If other eigenvalue decomposition methods become available, this will be expanded.

Parameters
----------
X : ht.DNDarray, shape=(n_samples, n_features)
Training instances to cluster.
"""
# 1. input sanitation
if not isinstance(X, ht.DNDarray):
raise ValueError("input needs to be a ht.DNDarray, but was {}".format(type(X)))
if X.split is not None and X.split != 0:
raise NotImplementedError("Not implemented for other splitting-axes")
# 2. Embed Dataset into lower-dimensional Eigenvector space
eigenvalues, eigenvectors = self._spectral_embedding(X)

# 3. Find the spectral gap, if number of clusters is not defined from the outside
if self.n_clusters is None:
diff = eigenvalues[1:] - eigenvalues[:-1]
tmp = ht.where(diff == diff.max()).item()
self.n_clusters = tmp + 1
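# Example: for sorted eigenvalues [0.0, 0.01, 0.02, 0.9, ...] the largest gap lies between
# indices 2 and 3, so tmp = 2 and n_clusters = 3.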

components = eigenvectors[:, : self.n_clusters].copy()

params = self._cluster.get_params()
params["n_clusters"] = self.n_clusters
self._cluster.set_params(**params)
self._cluster.fit(components)
self._labels = self._cluster.labels_
self._cluster_centers = self._cluster.cluster_centers_

return self

def predict(self, X):
"""
Predict the closest cluster each sample in X belongs to.

X is transformed into the low-dimensional representation by calculating the eigenspectrum (eigenvalues and eigenvectors) of the graph Laplacian derived from the similarity matrix.
Labels are inferred by assigning each embedded sample to the closest centroid of the previously fitted clustering algorithm (k-means).
Caution: calculating the low-dimensional representation requires significant time!

Parameters
----------
X : ht.DNDarray, shape = [n_samples, n_features]
New data to predict.

Returns
-------
labels : ht.DNDarray, shape = [n_samples,]
Index of the cluster each sample belongs to.
"""
# input sanitation
if not isinstance(X, ht.DNDarray):
raise ValueError("input needs to be a ht.DNDarray, but was {}".format(type(X)))
if X.split is not None and X.split != 0:
raise NotImplementedError("Not implemented for other splitting-axes")

_, eigenvectors = self._spectral_embedding(X)

components = eigenvectors[:, : self.n_clusters].copy()

return self._cluster.predict(components)

def fit_predict(self, X):
"""
Compute cluster centers and predict cluster index for each sample.

This method should be preferred to calling fit(X) followed by predict(X), since predict(X) requires recomputing the low-dimensional eigenspectrum representation of X.

Parameters
----------
X : ht.DNDarray, shape = [n_samples, n_features]
Input data to be clustered.

Returns
-------
labels : ht.DNDarray, shape [n_samples,]
Index of the cluster each sample belongs to.
"""
self.fit(X)
return self._labels

def get_params(self, deep=True):
"""
Get parameters for this estimator.

Parameters
----------
deep : boolean, optional
If True, will return the parameters for this estimator and contained sub-objects that are estimators.
Defaults to true.

Returns
-------
params : dict of string to any
Parameter names mapped to their values.
"""
# unused
_ = deep

return {
"n_clusters": self.n_clusters,
"gamma": self.gamma,
"metric": self.metric,
"laplacian": self.laplacian,
"threshold": self.threshold,
"boundary": self.boundary,
"n_lanczos": self.n_lanczos,
"assign_labels": self.assign_labels,
}

def set_params(self, **params):
"""
Set the parameters of this estimator.

Parameters
----------
params : dict
The parameters of the estimator to be modified.

Returns
-------
self : ht.cluster.Spectral
This estimator instance for chaining.
"""
self.n_clusters = params.get("n_clusters", self.n_clusters)
self.gamma = params.get("gamma", self.gamma)
self.metric = params.get("metric", self.metric)
self.laplacian = params.get("laplacian", self.laplacian)
self.threshold = params.get("threshold", self.threshold)
self.boundary = params.get("boundary", self.boundary)
self.n_lanczos = params.get("n_lanczos", self.n_lanczos)
self.assign_labels = params.get("assign_labels", self.assign_labels)

return self
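For orientation, a minimal usage sketch of the new estimator; the parameter values are illustrative and mirror the unit test below, and the dataset path is the bundled iris sample:

import heat as ht

# Load a demo dataset, distributed along the sample axis (split=0).
data = ht.load("heat/datasets/data/iris.csv", sep=";", split=0)

# RBF similarity with a fully connected Laplacian; passing n_clusters=None instead
# would estimate the cluster count from the spectral gap during fit().
spectral = ht.cluster.Spectral(
    n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=20
)
labels = spectral.fit_predict(data)  # preferred over fit(data) followed by predict(data)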
63 changes: 63 additions & 0 deletions heat/cluster/tests/test_spectral.py
@@ -0,0 +1,63 @@
import os
import unittest

import heat as ht

if os.environ.get("DEVICE") == "gpu" and ht.torch.cuda.is_available():
ht.use_device("gpu")
ht.torch.cuda.set_device(ht.torch.device(ht.get_device().torch_device))
else:
ht.use_device("cpu")
device = ht.get_device().torch_device
ht_device = None
if os.environ.get("DEVICE") == "lgpu" and ht.torch.cuda.is_available():
device = ht.gpu.torch_device
ht_device = ht.gpu
ht.torch.cuda.set_device(device)


class TestSpectral(unittest.TestCase):
def test_fit_iris(self):
# get some test data
iris = ht.load("heat/datasets/data/iris.csv", sep=";", split=0)
m = 10

# fit the clusters
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=m
)
spectral.fit(iris)
self.assertIsInstance(spectral.labels_, ht.DNDarray)

spectral = ht.cluster.Spectral(
metric="euclidean", laplacian="eNeighbour", theshold=0.5, boundary="upper", n_lanczos=m
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

spectral = ht.cluster.Spectral(
gamma=0.1,
metric="rbf",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

kmeans = {"kmeans++": "kmeans++", "max_iter": 30, "tol": -1}
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, normalize=True, n_lanczos=m, params=kmeans
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

# Errors
with self.assertRaises(NotImplementedError):
spectral = ht.cluster.Spectral(metric="ahalanobis", n_lanczos=m)

iris_split = ht.load("heat/datasets/data/iris.csv", sep=";", split=1)
spectral = ht.cluster.Spectral(n_lanczos=20)
with self.assertRaises(NotImplementedError):
spectral.fit(iris_split)