Skip to content

Commit

Permalink
Merge pull request #336 from pynapple-org/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
gviejo authored Sep 26, 2024
2 parents deca7e2 + 53c3221 commit 20f2254
Show file tree
Hide file tree
Showing 12 changed files with 1,260 additions and 850 deletions.
6 changes: 6 additions & 0 deletions docs/HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ and [Edoardo Balzani](https://www.simonsfoundation.org/people/edoardo-balzani/)
of the Flatiron institute.


0.7.1 (2024-09-24)
------------------

- Fixing nan issue when computing 1d tuning curve (See issue #334).
- Refactor tuning curves and correlogram tests.
- Adding validators decorators for tuning curves and correlogram modules.

0.7.0 (2024-09-16)
------------------
Expand Down
4 changes: 1 addition & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ PYthon Neural Analysis Package.
pynapple is a light-weight python library for neurophysiological data analysis. The goal is to offer a versatile set of tools to study typical data in the field, i.e. time series (spike times, behavioral events, etc.) and time intervals (trials, brain states, etc.). It also provides users with generic functions for neuroscience such as tuning curves and cross-correlograms.

- Free software: MIT License
- __Documentation__: <https://pynapple-org.github.io/pynapple>
- __Notebooks and tutorials__ : <https://pynapple-org.github.io/pynapple/generated/gallery/>
<!-- - __Collaborative repository__: <https://github.com/PeyracheLab/pynacollada> -->
- __Documentation__: <https://pynapple.org>


> **Note**
Expand Down
2 changes: 1 addition & 1 deletion pynapple/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.7.0"
__version__ = "0.7.1"
from .core import (
IntervalSet,
Ts,
Expand Down
2 changes: 1 addition & 1 deletion pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def size(self):
return self.values.size

def __array__(self, dtype=None):
return self.values.astype(dtype)
return np.asarray(self.values, dtype=dtype)

def __array_ufunc__(self, ufunc, method, *args, **kwargs):
# print("In __array_ufunc__")
Expand Down
126 changes: 86 additions & 40 deletions pynapple/process/correlograms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
"""Cross-correlograms """
"""
This module holds the functions to compute discrete cross-correlogram
for timestamps data (i.e. spike times).
| Function | Description |
|------|------|
| `nap.compute_autocorrelogram` | Autocorrelograms from a TsGroup object |
| `nap.compute_crosscorrelogram` | Crosscorrelogram from a TsGroup object |
| `nap.compute_eventcorrelogram` | Crosscorrelogram between a TsGroup object and a Ts object |
"""

import inspect
from functools import wraps
from itertools import combinations, product
from numbers import Number

import numpy as np
import pandas as pd
Expand All @@ -9,9 +22,53 @@
from .. import core as nap


#########################################################
# CORRELATION
#########################################################
def _validate_correlograms_inputs(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Validate each positional argument
sig = inspect.signature(func)
kwargs = sig.bind_partial(*args, **kwargs).arguments

# Only TypeError here
if getattr(func, "__name__") == "compute_crosscorrelogram" and isinstance(
kwargs["group"], (tuple, list)
):
if (
not all([isinstance(g, nap.TsGroup) for g in kwargs["group"]])
or len(kwargs["group"]) != 2
):
raise TypeError(
"Invalid type. Parameter group must be of type TsGroup or a tuple/list of (TsGroup, TsGroup)."
)
else:
if not isinstance(kwargs["group"], nap.TsGroup):
msg = "Invalid type. Parameter group must be of type TsGroup"
if getattr(func, "__name__") == "compute_crosscorrelogram":
msg = msg + " or a tuple/list of (TsGroup, TsGroup)."
raise TypeError(msg)

parameters_type = {
"binsize": Number,
"windowsize": Number,
"ep": nap.IntervalSet,
"norm": bool,
"time_units": str,
"reverse": bool,
"event": (nap.Ts, nap.Tsd),
}
for param, param_type in parameters_type.items():
if param in kwargs:
if not isinstance(kwargs[param], param_type):
raise TypeError(
f"Invalid type. Parameter {param} must be of type {param_type}."
)

# Call the original function with validated inputs
return func(**kwargs)

return wrapper


@jit(nopython=True)
def _cross_correlogram(t1, t2, binsize, windowsize):
"""
Expand Down Expand Up @@ -81,6 +138,7 @@ def _cross_correlogram(t1, t2, binsize, windowsize):
return C, B


@_validate_correlograms_inputs
def compute_autocorrelogram(
group, binsize, windowsize, ep=None, norm=True, time_units="s"
):
Expand Down Expand Up @@ -118,13 +176,10 @@ def compute_autocorrelogram(
RuntimeError
group must be TsGroup
"""
if type(group) is nap.TsGroup:
if isinstance(ep, nap.IntervalSet):
newgroup = group.restrict(ep)
else:
newgroup = group
if isinstance(ep, nap.IntervalSet):
newgroup = group.restrict(ep)
else:
raise RuntimeError("Unknown format for group")
newgroup = group

autocorrs = {}

Expand Down Expand Up @@ -152,6 +207,7 @@ def compute_autocorrelogram(
return autocorrs.astype("float")


@_validate_correlograms_inputs
def compute_crosscorrelogram(
group, binsize, windowsize, ep=None, norm=True, time_units="s", reverse=False
):
Expand Down Expand Up @@ -207,7 +263,24 @@ def compute_crosscorrelogram(
np.array([windowsize], dtype=np.float64), time_units
)[0]

if isinstance(group, nap.TsGroup):
if isinstance(group, tuple):
if isinstance(ep, nap.IntervalSet):
newgroup = [group[i].restrict(ep) for i in range(2)]
else:
newgroup = group

pairs = product(list(newgroup[0].keys()), list(newgroup[1].keys()))

for i, j in pairs:
spk1 = newgroup[0][i].index
spk2 = newgroup[1][j].index
auc, times = _cross_correlogram(spk1, spk2, binsize, windowsize)
if norm:
auc /= newgroup[1][j].rate
crosscorrs[(i, j)] = pd.Series(index=times, data=auc, dtype="float")

crosscorrs = pd.DataFrame.from_dict(crosscorrs)
else:
if isinstance(ep, nap.IntervalSet):
newgroup = group.restrict(ep)
else:
Expand All @@ -232,34 +305,10 @@ def compute_crosscorrelogram(
)
crosscorrs = crosscorrs / freq2

elif (
isinstance(group, (tuple, list))
and len(group) == 2
and all(map(lambda g: isinstance(g, nap.TsGroup), group))
):
if isinstance(ep, nap.IntervalSet):
newgroup = [group[i].restrict(ep) for i in range(2)]
else:
newgroup = group

pairs = product(list(newgroup[0].keys()), list(newgroup[1].keys()))

for i, j in pairs:
spk1 = newgroup[0][i].index
spk2 = newgroup[1][j].index
auc, times = _cross_correlogram(spk1, spk2, binsize, windowsize)
if norm:
auc /= newgroup[1][j].rate
crosscorrs[(i, j)] = pd.Series(index=times, data=auc, dtype="float")

crosscorrs = pd.DataFrame.from_dict(crosscorrs)

else:
raise RuntimeError("Unknown format for group")

return crosscorrs.astype("float")


@_validate_correlograms_inputs
def compute_eventcorrelogram(
group, event, binsize, windowsize, ep=None, norm=True, time_units="s"
):
Expand Down Expand Up @@ -306,10 +355,7 @@ def compute_eventcorrelogram(
else:
tsd1 = event.restrict(ep).index

if type(group) is nap.TsGroup:
newgroup = group.restrict(ep)
else:
raise RuntimeError("Unknown format for group")
newgroup = group.restrict(ep)

crosscorrs = {}

Expand Down
Loading

0 comments on commit 20f2254

Please sign in to comment.