Skip to content

Commit

Permalink
Merge pull request #59 from aopy/master
Browse files Browse the repository at this point in the history
Feature specific Information Transfer (FIT - function + tests + doc + example)
  • Loading branch information
EtienneCmb authored Jun 22, 2023
2 parents 20acf95 + eecc7f0 commit b0c7e40
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/api/api_connectivity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Connectivity metrics
conn_te
conn_ii
conn_pid
conn_fit

Utility functions
+++++++++++++++++
Expand Down
1 change: 1 addition & 0 deletions docs/source/overview/ovw_whatsnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ New Features
* New function :func:`frites.core.ent_nd_g` to compute entropy on tensors (:commit:`17587a15`)
* New function :func:`frites.conn.conn_ii` to estimate the interaction information (:commit:`10938b46`)
* New function :func:`frites.conn.conn_pid` to estimate the partial information decomposition (:commit:`ac9798dd`)
- New function :func:`frites.conn.conn_fit` to estimate the feature specific information transfer (:PR:`59`) - :author:`aopy`

Bug fixes
+++++++++
Expand Down
56 changes: 56 additions & 0 deletions examples/conn/plot_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
FIT: Feature specific information transfer
==========================================
This example illustrates how to discover a feature (e.g. sensory stimuli) in
information transfer between brain regions.
"""

import numpy as np
import xarray as xr

from frites.simulations import StimSpecAR
from frites.conn import conn_fit

from frites import set_mpl_style

import matplotlib.pyplot as plt
set_mpl_style()

###############################################################################
# Data simulation
# ---------------
#
# Here, we use an auto-regressive simulating a gamma increase.

net = False
avg_delay = False
ar_type = 'hga'
n_stim = 3
n_epochs = 400

ss = StimSpecAR()
ar = ss.fit(ar_type=ar_type, n_epochs=n_epochs, n_stim=n_stim, random_state=0)

print(ar)

plt.figure(figsize=(7, 8))
ss.plot(cmap='bwr')
plt.tight_layout()
plt.show()

###############################################################################
# Compute Feature specific information transfer
# -----------------------------------------
#
# Now we can use the simulated data to estimate the FIT.

# Compute the FIT

fit = conn_fit(ar, y='trials', roi='roi', times='times', mi_type='cd',
max_delay=.3, net=net, verbose=False, avg_delay=avg_delay)

# Plot the results
fit.plot(x='times', col='roi') # net = False

plt.show()
1 change: 1 addition & 0 deletions frites/conn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
from .conn_te import conn_te
from .conn_ii import conn_ii # noqa
from .conn_pid import conn_pid # noqa
from .conn_fit import conn_fit
257 changes: 257 additions & 0 deletions frites/conn/conn_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
"""Feature specific information transfer (Numba compliant)."""
import numpy as np
import xarray as xr

from frites.conn import conn_io, _conn_mi
from frites.core import mi_nd_gg, copnorm_nd
from frites.io import logger, check_attrs
from frites.config import CONFIG


def conn_fit(data, y, roi=None, times=None, mi_type='cc', gcrn=True,
max_delay=.3, avg_delay=False, net=False, sfreq=None,
verbose=None, **kw_links):
"""Feature-specific information transfer.
Parameters
----------
data : array_like
Electrophysiological data. Several input types are supported :
* Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
* mne.Epochs
* xarray.DataArray of shape (n_epochs, n_roi, n_times)
y : array_like
The feature of shape (n_trials,). This feature vector can either be
categorical and in that case, the mutual information type has to 'cd'
or y can also be a continuous regressor and in that case the mutual
information type has to be 'cc'
roi : array_like | None
Array of region of interest name of shape (n_roi,)
times : array_like | None
Array of time points of shape (n_times,)
mi_type : {'cc', 'cd'}
Mutual information type. Switch between :
* 'cc' : if the y input is a continuous regressor
* 'cd' : if the y input is a discret vector with categorical
integers inside
gcrn : bool | True
Specify if the Gaussian Copula Rank Normalization should be applied.
Default is True.
max_delay : float | .3
Maximum delay for past conditioning
avg_delay : bool | False
If False (default) the returned FIT is aggregated across delays. If
True, the returned FIT is going to contained the additional dimension
corresponding to the number of delays used.
sfreq : float | None
The sampling frequency.
kw_links : dict | {}
Additional arguments for selecting links to compute are passed to the
function :func:`frites.conn.conn_links`
Returns
-------
fit : array_like
The feature specific information transfer of shape (n_pairs, n_times)
if avg_delay is False or (n_pairs, n_delays, n_times) if avg_delay is
True.
See also
--------
conn_links
"""
# _________________________________ INPUTS ________________________________
# inputs conversion
kw_links.update({'directed': True, 'net': False})
data, cfg = conn_io(
data, y=y, times=times, roi=roi, agg_ch=False, win_sample=None,
name='FIT', verbose=verbose, sfreq=sfreq, kw_links=kw_links
)

# extract variables
x, attrs = data.data, cfg['attrs']
y, roi, times = data['y'].data, data['roi'].data, data['times'].data

# indices for the souces and targets
i_s, i_t = cfg['x_s'], cfg['x_t']
roi_p, n_pairs = cfg['roi_p'], len(i_s)

# delay conversion
if isinstance(max_delay, int):
max_delay = max_delay / cfg['sfreq']
n_delays = int(np.round(max_delay * cfg['sfreq']))

# build the indices when using multi-variate mi
n_roi, n_times = len(roi), len(times)

logger.info(f"Compute FIT on {n_pairs} connectivity pairs "
f"(max_delay={max_delay})")
# gcrn
if gcrn:
logger.info(" Apply the Gaussian Copula Rank Normalization")
x = copnorm_nd(x, axis=0)
if mi_type == 'cc':
y = copnorm_nd(y, axis=0)

# transpose the data to be (n_roi, n_times, 1, n_trials)
x = np.transpose(x, (1, 2, 0))[..., np.newaxis, :]

# __________________________ MUTUAL INFORMATION ___________________________

# compute mi between each node x (brain data) and y (task-related var)
mi_xy = np.zeros((n_roi, n_times), dtype=float)
for n_r in range(n_roi):
mi_xy[n_r, :] = _conn_mi(x[n_r, :, :], y, mi_type)
mi_xy_s = mi_xy[i_s, :]
mi_xy_t = mi_xy[i_t, :]

# compute mi between past and present of sources and targets
cfg_mi = CONFIG["KW_GCMI"]
mi_x_sptf = np.zeros((n_pairs, n_delays, n_times), dtype=float)
mi_x_tptf = np.zeros((n_pairs, n_delays, n_times), dtype=float)

for n_d in range(n_delays):
# define indices
idx_past = slice(n_d, n_d + n_times - n_delays - 1)
idx_pres = slice(n_delays + 1, n_times)

for n_l in range(n_pairs):
# source past; target past; target present
_sp = x[i_s[n_l], idx_past, :]
_tp = x[i_t[n_l], idx_past, :]
_tf = x[i_t[n_l], idx_pres, :]

# I(source_{past}; target_{pres})
mi_x_sptf[n_l, n_d, idx_pres] = mi_nd_gg(
_sp, _tf, **cfg_mi
).squeeze()

# I(target_{past}; target_{pres})
mi_x_tptf[n_l, n_d, idx_pres] = mi_nd_gg(
_tp, _tf, **cfg_mi
).squeeze()

# __________________________________ FIT __________________________________
# time indices for target roi
t_start = list(range(n_delays, n_times))

# Compute FIT on original MI values
if avg_delay:
fit_sh = (n_pairs, n_delays, n_times - n_delays)
else:
fit_sh = (n_pairs, n_times - n_delays)
fit = np.zeros(fit_sh, dtype=np.float32)

# I(target_pres; cue)
mi_xy_t_pres = mi_xy_t[..., t_start]

# I(source_past; target_pres)
mi_x_st_pres = mi_x_sptf[..., t_start]

# I(target_past; target_pres) = mi_x_t
mi_x_t_pres = mi_x_tptf[..., t_start]

# Loop over delays for past of target and sources
for n_d in range(n_delays):

# Delay indices
delays = list(range(n_d, n_times - n_delays + n_d))

# PID with cue as target var
# I(target_{past}; cue)
mi_xy_t_past = mi_xy_t[..., delays]
# I(source_{past}; cue)
mi_xy_s_past = mi_xy_s[..., delays]

# redundancy between sources and target about S (MMI-based)
red_s_t = np.minimum(mi_xy_s_past, mi_xy_t_pres)
# redundancy between sources, target present and target past about S
red_all = np.minimum(red_s_t, mi_xy_t_past)
# first term of FIT with the cue as target var
fit_cue = red_s_t - red_all

# PID with target pres as target var
# redundancy between sources and target about target pres (MMI-based)
red_x_t = np.minimum(mi_xy_t_pres, mi_x_st_pres[:, n_d, :])
# redundancy between sources, target present and target past about S
red_all = np.minimum(red_x_t, mi_x_t_pres[:, n_d, :])
# second term of FIT with x pres as target var
fit_t_pres = red_x_t - red_all

if avg_delay:
fit[:, n_d, :] = np.minimum(fit_cue, fit_t_pres)
else:
fit += np.minimum(fit_cue, fit_t_pres)

# ________________________________ OUTPUTS ________________________________
# rebuild time vector
times = times[n_delays:]
delay = np.arange(1, n_delays + 1)[::-1] / cfg['sfreq']

# net transfer
if net:
roi_po = roi_p.copy()
done, roi_p, i_st = [], [], []
for n_s, (s, t) in enumerate(zip(i_s, i_t)):
# ignore if indices have already been stored
if ([s, t] in done) or ([t, s] in done): continue # noqa
# find [source, target] and [target, source]
result = np.where((i_s == t) & (i_t == s))[0]
assert len(result) == 1
n_t = result[0]
# decide if (x->y - y->x) or (y->x - x->y)
r_s, r_t = roi_po[n_s], roi_po[n_t]
if not np.all(np.array([r_s, r_t]) == np.sort([r_s, r_t])):
n_s, n_t = n_t, n_s
r_s, r_t = roi_po[n_s].split('->')[0], roi_po[n_t].split('->')[0]
# store results
i_st.append([n_s, n_t])
roi_p.append(f"{r_s}-{r_t}")
done.append([s, t])
# computes net transfer
i_s, i_t = np.array(i_st).T
fit = fit[i_s, :] - fit[i_t, :]

# xarray conversion
if avg_delay:
dims, coords = ('roi', 'delays', 'times'), (roi_p, delay, times)
else:
dims, coords = ('roi', 'times'), (roi_p, times)
attrs['max_delay'] = max_delay
attrs['mi_type'] = mi_type
attrs['net'] = net
attrs['avg_delay'] = avg_delay
attrs['unit'] = 'Bits'
fit = xr.DataArray(
fit, name='FIT', dims=dims, coords=coords, attrs=check_attrs(attrs)
)

return fit


if __name__ == '__main__':
import matplotlib.pyplot as plt

net = False
avg_delay = False

from frites.simulations import StimSpecAR

ar_type = 'hga'
n_stim = 3
n_epochs = 400

ss = StimSpecAR()
x = ss.fit(ar_type=ar_type, n_epochs=n_epochs, n_stim=n_stim,
random_state=0)

fit = conn_fit(x, y='trials', roi='roi', times='times', mi_type='cd',
max_delay=.3, net=net, verbose=False, avg_delay=avg_delay)
if net:
fit.plot(x='times')
else:
fit.plot(x='times', col='roi')

plt.show()
29 changes: 27 additions & 2 deletions frites/conn/tests/test_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from frites.simulations import StimSpecAR

from frites.conn import (conn_covgc, conn_te, conn_dfc, conn_ccf, conn_ii,
conn_pid)
conn_pid, conn_fit)


class TestConn(object):
Expand Down Expand Up @@ -155,6 +155,31 @@ def test_conn_pid(self):
verbose=False, gcrn=True
)

def test_conn_fit(self):
"""Test function conn_fit."""
ar_type = 'hga'
n_stim = 2
n_epochs = 100
ss = StimSpecAR()
x = ss.fit(ar_type=ar_type, n_epochs=n_epochs, n_stim=n_stim)

for m in ['cc', 'cd']: # mi_type
for n in [True, False]: # net
for ad in [True, False]: # avg_delay
fit = conn_fit(x, y='trials', roi='roi', times='times',
mi_type=m, max_delay=.3, net=n,
verbose=False, avg_delay=ad)
print(fit.shape)
if n:
assert len(fit['roi']) == 1
else:
assert len(fit['roi']) == 2

if ad:
assert fit.ndim == 3
else:
assert fit.ndim == 2


if __name__ == '__main__':
TestConn().test_conn_pid()
TestConn().test_conn_fit()

0 comments on commit b0c7e40

Please sign in to comment.