Skip to content

Commit

Permalink
Merge pull request #171 from geometric-intelligence/scikitapi
Browse files Browse the repository at this point in the history
Scikitapi
  • Loading branch information
franciscoeacosta authored Aug 29, 2024
2 parents 1a7e96b + 3da50c6 commit ed70902
Show file tree
Hide file tree
Showing 61 changed files with 30,030 additions and 71 deletions.
6 changes: 3 additions & 3 deletions neurometry/datasets/piRNNs/saliency/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from sklearn.decomposition import PCA

from neurometry.datasets.piRNNs.load_rnn_grid_cells import get_scores, umap_dbscan
from neurometry.geometry.dimension.dim_reduction import (
from neurometry.estimators.dimension.dim_reduction import (
plot_2d_manifold_projections,
plot_pca_projections,
)
from neurometry.geometry.topology.persistent_homology import compute_diagrams_shuffle
from neurometry.geometry.topology.plotting import plot_all_barcodes_with_null
from neurometry.estimators.topology.persistent_homology import compute_diagrams_shuffle
from neurometry.estimators.topology.plotting import plot_all_barcodes_with_null

pretrained_run_id = "20240418-180712"
pretrained_run_dir = os.path.join(
Expand Down
4 changes: 2 additions & 2 deletions neurometry/datasets/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def synthetic_neural_manifold(
print("WARNING! Poisson spikes not generated: mean must be non-negative")
noisy_points = None

noise_level = gs.sqrt(1 / (ref_frequency * poisson_multiplier))
print(f"noise level: {100*noise_level:.2f}%")
#noise_level = gs.sqrt(1 / (ref_frequency * poisson_multiplier))
#print(f"noise level: {100*noise_level:.2f}%")

return noisy_points, manifold_points

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

from neurometry.geometry.topology.persistent_homology import (
from neurometry.estimators.topology.persistent_homology import (
cohomological_circular_coordinates,
cohomological_toroidal_coordinates,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import torch
from scipy.signal import savgol_filter

from neurometry.geometry.curvature.datasets.experimental import load_neural_activity
from neurometry.geometry.curvature.datasets.gridcells import load_grid_cells_synthetic
from neurometry.geometry.curvature.datasets.synthetic import (
from neurometry.estimators.curvature.datasets.experimental import load_neural_activity
from neurometry.estimators.curvature.datasets.gridcells import load_grid_cells_synthetic
from neurometry.estimators.curvature.datasets.synthetic import (
load_images,
load_place_cells,
load_points,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import torch

from neurometry.geometry.curvature.datasets.synthetic import (
from neurometry.estimators.curvature.datasets.synthetic import (
get_s1_synthetic_immersion,
get_s2_synthetic_immersion,
get_t2_synthetic_immersion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch
from torch.distributions.kl import register_kl

from neurometry.geometry.curvature.hyperspherical.distributions.hyperspherical_uniform import (
from neurometry.estimators.curvature.hyperspherical.distributions.hyperspherical_uniform import (
HypersphericalUniform,
)
from neurometry.geometry.curvature.hyperspherical.ops.ive import ive
from neurometry.estimators.curvature.hyperspherical.ops.ive import ive


class VonMisesFisher(torch.distributions.Distribution):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from neurometry.geometry.curvature.hyperspherical.distributions import (
from neurometry.estimators.curvature.hyperspherical.distributions import (
hyperspherical_uniform,
von_mises_fisher,
)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from torch.nn import functional as F

from neurometry.geometry.curvature.hyperspherical.distributions.von_mises_fisher import (
from neurometry.estimators.curvature.hyperspherical.distributions.von_mises_fisher import (
VonMisesFisher,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.distributions.normal import Normal
from torch.nn import functional as F

from neurometry.geometry.curvature.hyperspherical.distributions.von_mises_fisher import (
from neurometry.estimators.curvature.hyperspherical.distributions.von_mises_fisher import (
VonMisesFisher,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from torch.nn import functional as F

from neurometry.geometry.curvature.hyperspherical.distributions.von_mises_fisher import (
from neurometry.estimators.curvature.hyperspherical.distributions.von_mises_fisher import (
VonMisesFisher,
)

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
def _plot_bars_from_diagrams(ax, diagrams, **kwargs):
birth = diagrams[:, 0]
death = diagrams[:, 1]
inf_value = 3 * np.max(death[death != np.inf]) if np.isfinite(death).any() else 1000
death[death == np.inf] = inf_value
lifespan = death - birth
indices = np.argsort(-lifespan)[:20]

birth = birth[indices]
death = death[indices]
finite_bars = death[death != np.inf]
inf_end = 2 * max(finite_bars) if len(finite_bars) > 0 else 2
death[death == np.inf] = inf_end

offset = kwargs.get("bar_offset", 1.0)
linewidth = kwargs.get("linewidth", 5)
Expand Down
171 changes: 171 additions & 0 deletions neurometry/estimators/topology/topology_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import os

import numpy as np
from dreimac import CircularCoords, ToroidalCoords
from gtda.diagrams import PersistenceEntropy
from gtda.homology import VietorisRipsPersistence, WeightedRipsPersistence
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

import neurometry.datasets.synthetic as synthetic

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs


class TopologicalClassifier(ClassifierMixin, BaseEstimator):
    """Classify the topology of a neural point cloud.

    Synthetic reference point clouds are generated from a circle
    (label 0), a 2-sphere (label 1) and a 2-torus (label 2), embedded in
    the same ambient dimension as the input data and corrupted with
    Poisson noise. Persistence-entropy features of their Vietoris-Rips
    diagrams train a random forest, which then predicts which reference
    topology the input point cloud most resembles.

    Parameters
    ----------
    num_samples : int
        Number of noisy reference point clouds generated per manifold.
    poisson_multiplier : float
        Noise parameter forwarded to
        ``synthetic.synthetic_neural_manifold``.
    homology_dimensions : tuple of int, default=(0, 1, 2)
        Homology dimensions of the persistence diagrams.
    reduce_dim : bool, default=False
        If True, each point cloud is independently reduced to 10
        dimensions with PCA before computing diagrams.
    """

    def __init__(
        self,
        num_samples,
        poisson_multiplier,
        homology_dimensions=(0, 1, 2),
        reduce_dim=False,
    ):
        self.num_samples = num_samples
        self.poisson_multiplier = poisson_multiplier
        self.homology_dimensions = homology_dimensions
        self.reduce_dim = reduce_dim
        self.classifier = RandomForestClassifier()

    def _sample_noisy_clouds(self, task_points, encoding_dim):
        """Generate ``num_samples`` noisy embeddings of one task manifold."""
        clouds = []
        for _ in range(self.num_samples):
            noisy_points, _ = synthetic.synthetic_neural_manifold(
                points=task_points,
                encoding_dim=encoding_dim,
                nonlinearity="sigmoid",
                scales=gs.ones(encoding_dim),
                poisson_multiplier=self.poisson_multiplier,
            )
            clouds.append(noisy_points)
        return clouds

    def _generate_ref_data(self, input_data):
        """Build labeled reference point clouds shaped like ``input_data``.

        Returns
        -------
        ref_point_clouds : list
            ``3 * num_samples`` clouds: circle first, then sphere, then
            torus.
        ref_labels : ndarray
            0 for circle clouds, 1 for sphere clouds, 2 for torus clouds.
        """
        num_points = input_data.shape[0]
        encoding_dim = input_data.shape[1]

        # Reference manifolds, sampled with as many points as the input.
        labeled_manifolds = [
            (0, synthetic.hypersphere(1, num_points)),  # circle
            (1, synthetic.hypersphere(2, num_points)),  # sphere
            (2, synthetic.hypertorus(2, num_points)),  # torus
        ]

        ref_point_clouds = []
        ref_labels = []
        for label, task_points in labeled_manifolds:
            ref_point_clouds.extend(
                self._sample_noisy_clouds(task_points, encoding_dim)
            )
            ref_labels.append(label * np.ones(self.num_samples))

        return ref_point_clouds, np.concatenate(ref_labels)

    def _compute_topo_features(self, diagrams):
        """Summarize each persistence diagram by its persistence entropy."""
        return PersistenceEntropy().fit_transform(diagrams)

    def fit(self, X, y=None):
        """Fit the internal random forest on synthetic reference data.

        ``y`` is ignored: the training labels come from the generated
        reference manifolds. It is accepted only for scikit-learn API
        compatibility.
        """
        ref_point_clouds, ref_labels = self._generate_ref_data(X)
        if self.reduce_dim:
            # NOTE(review): each cloud is PCA-projected independently, so
            # the projections do not share a basis -- presumably acceptable
            # because only topological features are extracted downstream;
            # confirm this is intentional.
            pca = PCA(n_components=10)
            ref_point_clouds = [
                pca.fit_transform(point_cloud)
                for point_cloud in ref_point_clouds
            ]
        ref_diagrams = compute_persistence_diagrams(
            ref_point_clouds, homology_dimensions=self.homology_dimensions
        )
        ref_features = self._compute_topo_features(ref_diagrams)
        X_ref_train, X_ref_valid, y_ref_train, y_ref_valid = train_test_split(
            ref_features, ref_labels
        )
        self.classifier.fit(X_ref_train, y_ref_train)
        print(f"Classifier score: {self.classifier.score(X_ref_valid, y_ref_valid)}")
        return self

    def predict(self, X):
        """Predict the topology label (0=circle, 1=sphere, 2=torus) of X."""
        if self.reduce_dim:
            # Mirrors fit: PCA is refit on the new cloud alone.
            X = PCA(n_components=10).fit_transform(X)
        diagram = compute_persistence_diagrams(
            [X], homology_dimensions=self.homology_dimensions
        )
        features = self._compute_topo_features(diagram)
        return self.classifier.predict(features)





def compute_persistence_diagrams(
    representations,
    homology_dimensions=(0, 1, 2),
    coeff=2,
    metric="euclidean",
    weighted=False,
    n_jobs=-1,
):
    """Compute Vietoris-Rips persistence diagrams for point clouds.

    Parameters
    ----------
    representations : list of array-like
        One point cloud per entry.
    homology_dimensions : tuple of int, default=(0, 1, 2)
        Homology dimensions to compute.
    coeff : int, default=2
        Prime field of coefficients for homology.
    metric : str, default="euclidean"
        Distance metric between points.
    weighted : bool, default=False
        If True, use a weighted Rips filtration instead of plain
        Vietoris-Rips.
    n_jobs : int, default=-1
        Number of parallel jobs; -1 uses all processors.

    Returns
    -------
    ndarray
        One persistence diagram per input point cloud.
    """
    if weighted:
        # Fix: n_jobs was previously ignored on the weighted branch.
        persistence = WeightedRipsPersistence(
            metric=metric,
            homology_dimensions=homology_dimensions,
            coeff=coeff,
            n_jobs=n_jobs,
        )
    else:
        persistence = VietorisRipsPersistence(
            metric=metric,
            homology_dimensions=homology_dimensions,
            coeff=coeff,
            reduced_homology=False,
            n_jobs=n_jobs,
        )
    return persistence.fit_transform(representations)


def _shuffle_entries(data, rng):
return np.array([rng.permutation(row) for row in data])


def compute_diagrams_shuffle(X, num_shuffles, seed=0, homology_dimensions=(0, 1)):
    """Persistence diagrams of ``X`` together with shuffled null models.

    The first diagram is for ``X`` itself; the remaining ``num_shuffles``
    diagrams are for copies of ``X`` whose rows were independently
    permuted with a generator seeded by ``seed``.
    """
    rng = np.random.default_rng(seed)
    point_clouds = [X]
    for _ in range(num_shuffles):
        point_clouds.append(_shuffle_entries(X, rng))
    return compute_persistence_diagrams(
        point_clouds, homology_dimensions=homology_dimensions
    )

def cohomological_toroidal_coordinates(data):
    """Toroidal (two-angle) coordinates for ``data`` via persistent cohomology.

    Uses the first two cocycles (indices 0 and 1) and every point as a
    landmark. Returns an array of shape (n_points, 2).
    """
    num_landmarks = data.shape[0]
    toroidal = ToroidalCoords(data, n_landmarks=num_landmarks)
    coords = toroidal.get_coordinates(
        cocycle_idxs=[0, 1], standard_range=False
    )
    return coords.T


def cohomological_circular_coordinates(data):
    """Circular (single-angle) coordinates for ``data`` via persistent cohomology.

    Every point is used as a landmark. Returns the transposed coordinate
    array so points index the first axis.
    """
    num_landmarks = data.shape[0]
    circular = CircularCoords(data, n_landmarks=num_landmarks)
    coords = circular.get_coordinates(standard_range=False)
    return coords.T


File renamed without changes.
File renamed without changes.
51 changes: 0 additions & 51 deletions neurometry/geometry/topology/persistent_homology.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_curvature.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch

from neurometry.geometry.curvature.losses import latent_regularization_loss
from neurometry.estimators.curvature.losses import latent_regularization_loss


class AttrDict(dict):
Expand Down
Loading

0 comments on commit ed70902

Please sign in to comment.