Merge pull request ME-ICA#1 from awstanton/basic.simpleoptcom
Basic.simpleoptcom
awstanton authored Nov 21, 2022
2 parents 3c152a4 + c4020ac commit 6957ee5
Showing 13 changed files with 376 additions and 338 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -1,4 +1,11 @@
 .gitignore
 tedana/resources/gendata/*.gz
+Notebooks/
+*.nii.gz
+*.tsv
+*.txt
+*.json
+figures/
+
 .DS_Store
 docs/generated/
131 changes: 68 additions & 63 deletions tedana/combine.py
@@ -2,6 +2,7 @@
 Functions to optimally combine data across echoes.
 """
 import logging
+import copy
 
 import numpy as np
 
@@ -231,7 +232,8 @@ def make_optcom(data, tes, adaptive_mask, t2s=None, combmode="t2s", verbose=True
     return combined
 
 
-def make_optcom_sage(data, tes, t2star_map, s0_I_map, t2_map, s0_II_map):
+def make_optcom_sage(data, tes, t2star_map, s0_I_map, t2_map, s0_II_map, mask):
+
     if data.ndim != 3:
         raise ValueError("Input data must be 3D (S x E x T)")
 
@@ -242,95 +244,98 @@ def make_optcom_sage(data, tes, t2star_map, s0_I_map, t2_map, s0_II_map)
             "{1}".format(len(tes), data.shape[1])
         )
 
-    # t2star_map = t2star_map[..., np.newaxis]  # add singleton
-    # t2_map = t2_map[..., np.newaxis]  # add singleton
+    data = data[mask]
 
+    alpha_t2star_I, alpha_t2_I = weights_sage_I(tes, t2star_map, s0_I_map)
+    alpha_t2star_II, alpha_t2_II = weights_sage_II(tes, t2star_map, t2_map, s0_II_map)
+
+    alpha_t2star_I = np.expand_dims(alpha_t2star_I, axis=2)
+    alpha_t2star_II = np.expand_dims(alpha_t2star_II, axis=2)
+    alpha_t2_I = np.expand_dims(alpha_t2_I, axis=2)
+    alpha_t2_II = np.expand_dims(alpha_t2_II, axis=2)
+
+    idx_I = tes < (tes[-1] / 2)
+    idx_II = tes >= (tes[-1] / 2)
+
+    com1 = copy.deepcopy(data)
+    com2 = copy.deepcopy(data)
+
+    com1[:, idx_I, :] = (alpha_t2star_I / 2) * data[:, idx_I, :]
+    com1[:, idx_II, :] = (alpha_t2star_II / 2) * data[:, idx_II, :]
 
-    combined_t2star_I = combine_sage_I(
-        data,
-        tes,
-        t2star_map,
-        s0_I_map,
-        report=True,
-    )
+    com2[:, idx_I, :] = (alpha_t2_I / 2) * data[:, idx_I, :]
+    com2[:, idx_II, :] = (alpha_t2_II / 2) * data[:, idx_II, :]
 
-    combined_t2star_II, combined_t2_II = combine_sage_II(
-        data,
-        tes,
-        t2star_map,
-        t2_map,
-        s0_II_map,
-        report=True,
-    )
+    optcom_t2star = np.sum(com1, axis=1)
+    optcom_t2 = np.sum(com2, axis=1)
 
-    return combined_t2star_I, combined_t2star_II, combined_t2_II
+    return optcom_t2star, optcom_t2


-def combine_sage_I(data, echo_times, t2star_map, s0_I_map, report=True):
-    echo_times = np.expand_dims(echo_times, axis=0)
-    idx_I = echo_times < echo_times[0, -1]
+def weights_sage_I(tes, t2star_map, s0_I_map):
+    tese = tes[-1]
+    idx_I = tes < tese / 2
+    tes = np.expand_dims(tes, axis=0)
     s0_I_map = np.expand_dims(s0_I_map, axis=1)
     t2star_map = np.expand_dims(t2star_map, axis=1)
 
-    alpha_I = (s0_I_map * (-1 * echo_times[idx_I])) * np.exp(
-        (-1) * echo_times[idx_I] * (1 / t2star_map)
-    )
+    alpha_I = (s0_I_map * (-1 * tes[0, idx_I])) * np.exp((1 / t2star_map) * (-1 * tes[0, idx_I]))
 
     # If all values across echos are 0, set to 1 to avoid
     # divide-by-zero errors
-    ax0_idx = np.where(np.all(alpha_I == 0, axis=1))
-    alpha_I[ax0_idx, :] = 1
+    alpha_I[np.where(np.all(alpha_I == 0, axis=1)), :] = 1
 
     # normalize
-    alpha_I = alpha_I / np.nansum(alpha_I)
+    alpha_I = alpha_I / np.expand_dims(np.sum(alpha_I, axis=1), axis=1)
 
-    combined_t2star_I = np.zeros((data.shape[0], data.shape[2]))
+    # combined_t2star_I = np.zeros((data.shape[0], data.shape[2]))
 
-    for samp_idx in range(data.shape[0]):
-        combined_t2star_I[samp_idx, :] = np.average(
-            data[samp_idx, idx_I[0, :], :], axis=0, weights=alpha_I[samp_idx, :]
-        )
+    # for samp_idx in range(data.shape[0]):
+    #     combined_t2star_I[samp_idx, :] = np.average(
+    #         data[samp_idx, idx_I[0, :], :], axis=0, weights=alpha_I[samp_idx, :]
+    #     )
 
-    # derivative with respect to t2 is 0, so corresponding weight is always all zeros
+    # # derivative with respect to t2 is 0, so corresponding weight is always all zeros
     # combined_t2_I = np.tile([0], combined_t2star_I.shape)
 
-    return combined_t2star_I
+    return alpha_I, np.zeros(alpha_I.shape)


-def combine_sage_II(data, echo_times, t2star_map, t2_map, s0_II_map, report=True):
-    mid_echo_time = echo_times[-1] / 2
-    echo_times = np.expand_dims(echo_times, axis=0)
+def weights_sage_II(tes, t2star_map, t2_map, s0_II_map):
+    tese = tes[-1]
+    idx_II = tes >= (tese / 2)
 
-    idx_II = echo_times > mid_echo_time
+    tes = np.expand_dims(tes, axis=0)
     s0_II_map = np.expand_dims(s0_II_map, axis=1)
     t2_map = np.expand_dims(t2_map, axis=1)
     t2star_map = np.expand_dims(t2star_map, axis=1)
 
-    alpha_t2star_II = (s0_II_map * ((-1 * echo_times[0, -1]) + echo_times[idx_II])) * np.exp(
-        ((-1) * echo_times[0, -1] * ((1 / t2star_map) - (1 / t2_map)))
-        - ((echo_times[idx_II]) * ((2 * (1 / t2_map)) - (1 / t2star_map)))
-    )
-    alpha_t2_II = (s0_II_map * ((echo_times[0, -1] - (2 * echo_times[idx_II])))) * np.exp(
-        ((-1 * echo_times[0, -1]) * ((1 / t2star_map) - (1 / t2_map)))
-        - (echo_times[idx_II] * ((2 * (1 / t2_map)) - (1 / t2star_map)))
-    )
+    const1 = s0_II_map * ((-1 * tese) + tes[0, idx_II])
+    const2 = s0_II_map * ((tese - (2 * tes[0, idx_II])))
+    exp1 = ((1 / t2star_map) - (1 / t2_map)) * (-1 * tese)
+    exp2 = ((2 * (1 / t2_map)) - (1 / t2star_map)) * (tes[0, idx_II])
+
+    alpha_t2star_II = const1 * np.exp(exp1 - exp2)
+    alpha_t2_II = const2 * np.exp(exp1 - exp2)
 
-    ax0_idx = np.where(np.all(alpha_t2star_II == 0, axis=1))
-    alpha_t2star_II[ax0_idx, :] = 1
-    ax0_idx = np.where(np.all(alpha_t2_II == 0, axis=1))
-    alpha_t2_II[ax0_idx, :] = 1
+    # If all values across echos are 0, set to 1 to avoid
+    # divide-by-zero errors
+    alpha_t2star_II[np.where(np.all(alpha_t2star_II == 0, axis=1)), :] = 1
+    alpha_t2_II[np.where(np.all(alpha_t2_II == 0, axis=1)), :] = 1
 
     # normalize
-    alpha_t2star_II = alpha_t2star_II / np.sum(alpha_t2star_II)
-    alpha_t2_II = alpha_t2_II / np.sum(alpha_t2_II)
+    alpha_t2star_II = alpha_t2star_II / np.expand_dims(np.sum(alpha_t2star_II, axis=1), axis=1)
+    alpha_t2_II = alpha_t2_II / np.expand_dims(np.sum(alpha_t2_II, axis=1), axis=1)
 
-    combined_t2star_II = np.zeros((data.shape[0], data.shape[2]))
-    combined_t2_II = np.zeros((data.shape[0], data.shape[2]))
+    # combined_t2star_II = np.zeros((data.shape[0], data.shape[2]))
+    # combined_t2_II = np.zeros((data.shape[0], data.shape[2]))
 
-    for samp_idx in range(data.shape[0]):
-        combined_t2star_II[samp_idx, :] = np.average(
-            data[samp_idx, idx_II[0, :], :], axis=0, weights=alpha_t2star_II[samp_idx, :]
-        )
-        combined_t2_II[samp_idx, :] = np.average(
-            data[samp_idx, idx_II[0, :], :], axis=0, weights=alpha_t2_II[samp_idx, :]
-        )
+    # for samp_idx in range(data.shape[0]):
+    #     combined_t2star_II[samp_idx, :] = np.average(
+    #         data[samp_idx, idx_II[0, :], :], axis=0, weights=alpha_t2star_II[samp_idx, :]
+    #     )
+    #     combined_t2_II[samp_idx, :] = np.average(
+    #         data[samp_idx, idx_II[0, :], :], axis=0, weights=alpha_t2_II[samp_idx, :]
+    #     )
 
-    return combined_t2star_II, combined_t2_II
+    return alpha_t2star_II, alpha_t2_II
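
Taken together, the new helpers compute per-echo weights from the SAGE signal model and the optimal combination becomes a weighted sum over echoes. For orientation, a minimal sketch of how the refactored function might be driven; the array sizes, echo times, and parameter-map values below are illustrative assumptions, not part of this commit:

```python
import numpy as np

from tedana.combine import make_optcom_sage

# Toy dimensions: S voxels, E=5 echoes (SAGE), T volumes (illustrative only).
n_samp, n_echos, n_vols = 100, 5, 10
tes = np.array([7.0, 27.0, 58.0, 77.0, 97.0])  # hypothetical echo times in ms

data = np.random.rand(n_samp, n_echos, n_vols)  # must be 3D (S x E x T)
mask = np.ones(n_samp, dtype=bool)  # keep every voxel in this sketch

# Per-voxel parameter maps, e.g. from tedana.decay.fit_decay_sage
t2star_map = np.full(n_samp, 30.0)
t2_map = np.full(n_samp, 80.0)
s0_I_map = np.full(n_samp, 1000.0)
s0_II_map = np.full(n_samp, 1000.0)

# Returns two (S x T) time series: one combined with the T2*-sensitivity
# weights of the SAGE model, one with the T2-sensitivity weights.
optcom_t2star, optcom_t2 = make_optcom_sage(
    data, tes, t2star_map, s0_I_map, t2_map, s0_II_map, mask
)
print(optcom_t2star.shape, optcom_t2.shape)  # (100, 10) (100, 10)
```

Note that `make_optcom_sage` applies `mask` to `data` internally, so, as the hunks suggest, the parameter maps are expected to already cover only the masked voxels (which is what the masked `fit_loglinear_sage` below returns).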
96 changes: 48 additions & 48 deletions tedana/decay.py
@@ -477,61 +477,59 @@ def fit_decay_ts(data, tes, mask, adaptive_mask, fittype):
 ######################################################################################
 
 
-def fit_decay_sage(data, tes, fittype, report=True):
+def fit_decay_sage(data, tes, mask, fittype, report=True):
+
     if data.shape[1] != len(tes):
         raise ValueError(
             "Second dimension of data ({0}) does not match number "
             "of echoes provided (tes; {1})".format(data.shape[1], len(tes))
         )
 
     if len(tes) != 5:
         raise ValueError("Five echos are required for computing SAGE T2*, T2, and S0 maps")
 
     data = data.copy()
     if data.ndim == 2:
         data = data[:, :, None]
 
     if fittype == "loglin":
-        t2star_map, s0_I_map, t2_map, s0_II_map = fit_loglinear_sage(data, tes, report=report)
+        t2star_map, s0_I_map, t2_map, s0_II_map = fit_loglinear_sage(
+            data, tes, mask, report=report
+        )
     elif fittype == "curvefit":
         t2star_map, s0_I_map, t2_map, s0_II_map = fit_monoexponential_sage(
-            data, tes, report=report
+            data, tes, mask, report=report
         )
     else:
         raise ValueError("Unknown fittype option: {}".format(fittype))
 
-    # t2s_limited[np.isinf(t2s_limited)] = 500.0  # why 500?
-    # # let's get rid of negative values, but keep zeros where limited != full
-    # t2s_limited[(adaptive_mask_masked > 1) & (t2s_limited <= 0)] = 1.0
-    # t2s_limited = _apply_t2s_floor(t2s_limited, tes)
-    # s0_limited[np.isnan(s0_limited)] = 0.0  # why 0?
-    # t2s_full[np.isinf(t2s_full)] = 500.0  # why 500?
-    # t2s_full[t2s_full <= 0] = 1.0  # let's get rid of negative values!
-    # t2s_full = _apply_t2s_floor(t2s_full, tes)
-    # s0_full[np.isnan(s0_full)] = 0.0  # why 0?
-
-    # t2s_limited = utils.unmask(t2s_limited, mask)
-    # s0_limited = utils.unmask(s0_limited, mask)
-    # t2s_full = utils.unmask(t2s_full, mask)
-    # s0_full = utils.unmask(s0_full, mask)
+    t2star_map[np.isinf(t2star_map)] = 500.0  # why 500?
+    t2_map[np.isinf(t2_map)] = 500.0  # why 500?
+    t2star_map = _apply_t2s_floor(t2star_map, tes)
+    t2_map = _apply_t2s_floor(t2_map, tes)
+    s0_I_map[np.isnan(s0_I_map)] = 0.0  # why 0?
+    s0_II_map[np.isnan(s0_II_map)] = 0.0  # why 0?
 
-    # set a hard cap for the T2* map
+    # set a hard cap for the T2* and T2 maps
     # anything that is 10x higher than the 99.5 %ile will be reset to 99.5 %ile
-    # cap_t2s = stats.scoreatpercentile(t2s_limited.flatten(), 99.5, interpolation_method="lower")
-    # LGR.debug("Setting cap on T2* map at {:.5f}".format(cap_t2s * 10))
-    # t2s_limited[t2s_limited > cap_t2s * 10] = cap_t2s
+    cap_t2star = stats.scoreatpercentile(t2star_map.flatten(), 99.5, interpolation_method="lower")
+    t2star_map[t2star_map > cap_t2star * 10] = cap_t2star
+    cap_t2 = stats.scoreatpercentile(t2_map.flatten(), 99.5, interpolation_method="lower")
+    t2_map[t2_map > cap_t2 * 10] = cap_t2
 
     return t2star_map, s0_I_map, t2_map, s0_II_map
 
 
-def fit_monoexponential_sage(data_cat, echo_times, report=True):
+def fit_monoexponential_sage(data_cat, echo_times, mask, report=True):
     n_samp, _, n_vols = data_cat.shape
+    tese = echo_times[-1]
 
     t2star_map, s0_I_map, t2_map, s0_II_map = fit_loglinear_sage(
-        data_cat, echo_times, report=False
+        data_cat, echo_times, mask, report=False
     )
 
-    mid_echo_time = echo_times[-1] / 2
-    echo_times_idx_t2star = echo_times[echo_times > mid_echo_time]
-    echo_times_idx_t2 = echo_times[echo_times < echo_times[-1]]
+    echo_times_idx_t2star = echo_times[echo_times > tese / 2]
+    echo_times_idx_t2 = echo_times[echo_times < tese]
 
     data_2d_t2star = data_cat[:, echo_times_idx_t2star, :].reshape(n_samp, -1).T
     data_2d_t2 = data_cat[:, echo_times_idx_t2, :].reshape(n_samp, -1).T
@@ -584,49 +582,51 @@ def fit_monoexponential_sage(data_cat, echo_times, report=True):
     return t2star_map, s0_I_map, t2_map, s0_II_map
 
 
-def fit_loglinear_sage(data_cat, echo_times, report=True):
+def fit_loglinear_sage(data_cat, echo_times, mask, report=True):
+    # exclude samples that have no nonzero values
+    data_cat = data_cat[mask, :, :]
     n_samp, _, n_vols = data_cat.shape
+    tese = echo_times[-1]
 
-    mid_echo_time = echo_times[-1] / 2
-    te_idx_t2star = echo_times < echo_times[-1]
-    te_idx_t2 = echo_times > mid_echo_time
+    te_idx_I = echo_times < tese
+    te_idx_II = echo_times > tese / 2
 
-    # 1) Find T2* and S0_I values for all voxels
-    # data_2d: ((E x T) x S)
-    data_2d = data_cat[:, te_idx_t2star, :].reshape(n_samp, -1).T
-    log_data = np.log(np.abs(data_2d) + 1)
+    # 1) Find T2* and S0_I values across te and volume for each voxel independently
+    Y = data_cat[:, te_idx_I, :].reshape(n_samp, -1).T
+    Y = np.log(np.abs(Y) + 1)  # why take absolute value and why add 1?
 
-    x = np.column_stack([np.ones(np.sum(te_idx_t2star)), -1 * echo_times[te_idx_t2star]])
+    x = np.column_stack([np.ones(np.sum(te_idx_I)), -1 * echo_times[te_idx_I]])
     X = np.repeat(x, n_vols, axis=0)
 
     # Log-linear fit
-    betas = np.linalg.lstsq(X, log_data, rcond=None)[0]
+    betas = np.linalg.lstsq(X, Y, rcond=None)[0]
     t2star_map = 1 / betas[1, :].T
     s0_I_map = np.exp(betas[0, :]).T
 
-    # 2) Find T2 values and S0_II values for all voxels using T2* values
-    data_2d = data_cat[:, te_idx_t2, :].reshape(n_samp, -1).T
+    # 2) Find T2 values and S0_II values for all voxels using computed T2* values
+    Y = data_cat[:, te_idx_II, :].reshape(n_samp, -1).T
     constant = np.repeat(
-        (np.expand_dims(t2star_map, axis=1) * ((2 * echo_times[te_idx_t2]) - echo_times[-1])).T,
+        (
+            np.expand_dims(1 / t2star_map, axis=1)
+            * ((2 * np.expand_dims(echo_times[te_idx_II], axis=0)) - tese)
+        ),
         n_vols,
-        axis=0,
-    )
-    log_data = np.log(np.abs(data_2d) + 1) - constant
+        axis=1,
+    ).T
+    Y = np.log(np.abs(Y) + 1) - constant
 
-    x = np.column_stack(
-        [np.ones(np.sum(te_idx_t2)), (-2 * echo_times[te_idx_t2]) + echo_times[-1]]
-    )
+    x = np.column_stack([np.ones(np.sum(te_idx_II)), (-2 * echo_times[te_idx_II]) + tese])
     X = np.repeat(x, n_vols, axis=0)
 
-    betas = np.linalg.lstsq(X, log_data, rcond=None)[0]
+    betas = np.linalg.lstsq(X, Y, rcond=None)[0]
 
     t2_map = 1 / betas[1, :].T
     s0_II_map = np.exp(betas[0, :]).T
 
     return t2star_map, s0_I_map, t2_map, s0_II_map
 
 
-def fit_decay_ts_sage(data, tes, fittype):
+def fit_decay_ts_sage(data, tes, mask, fittype):
     n_samples, _, n_vols = data.shape
     tes = np.array(tes)
 
@@ -638,7 +638,7 @@ def fit_decay_ts_sage(data, tes, fittype):
     report = True
     for vol in range(n_vols):
         t2star_map, s0_I_map, t2_map, s0_II_map = fit_decay_sage(
-            data[:, :, vol][:, :, None], tes, fittype, report=report
+            data[:, :, vol][:, :, None], tes, fittype, mask, report=report
         )
         t2star_map_vols[:, vol] = t2star_map
         s0_I_map_vols[:, vol] = s0_I_map
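
The core of the refactored `fit_loglinear_sage` is a two-stage ordinary least-squares fit in log space. A self-contained sketch of stage 1 (the T2*/S0_I fit) on synthetic data; the echo times and ground-truth values are invented for illustration and are not from this commit:

```python
import numpy as np

# Stage 1 of the log-linear SAGE fit: for echoes before the last one,
# log(S) ~= log(S0_I) - te / T2*, so regressing log(S) on [1, -te]
# recovers log(S0_I) as the intercept and 1/T2* as the slope.
tes = np.array([7.0, 27.0, 58.0, 77.0])  # hypothetical early echo times (ms)
t2star_true, s0_true = 30.0, 1000.0
signal = s0_true * np.exp(-tes / t2star_true)  # one voxel, one volume

Y = np.log(np.abs(signal) + 1)  # same +1 guard against log(0) as the diff
X = np.column_stack([np.ones(tes.size), -1 * tes])

betas = np.linalg.lstsq(X, Y, rcond=None)[0]
print("S0_I ~", np.exp(betas[0]))  # approximately 1000
print("T2* ~", 1 / betas[1])       # approximately 30 ms
```

Stage 2 reuses the fitted 1/T2* as a fixed offset (the `constant` term) so that the second regression over the late echoes only has to estimate log(S0_II) and 1/T2.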
20 changes: 20 additions & 0 deletions tedana/resources/config/outputs.json
@@ -7,14 +7,34 @@
     "orig": "t2svG",
     "bidsv1.5.0": "T2starmap"
   },
+  "t2 img": {
+    "orig": "t2vG",
+    "bidsv1.5.0": "T2map"
+  },
   "s0 img": {
     "orig": "s0vG",
     "bidsv1.5.0": "S0map"
   },
+  "s0_I img": {
+    "orig": "s0IvG",
+    "bidsv1.5.0": "S0map"
+  },
+  "s0_II img": {
+    "orig": "s0IIvG",
+    "bidsv1.5.0": "S0map"
+  },
   "combined img": {
     "orig": "ts_OC",
     "bidsv1.5.0": "desc-optcom_bold"
   },
+  "combined T2* img": {
+    "orig": "t2star_OC",
+    "bidsv1.5.0": "desc-optcom_bold"
+  },
+  "combined T2 img": {
+    "orig": "t2_OC",
+    "bidsv1.5.0": "desc-optcom"
+  },
   "ICA components img": {
     "orig": "ica_components",
     "bidsv1.5.0": "desc-ICA_components"
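
These entries extend the output-name mapping used when writing derivatives. A quick sketch of a lookup, assuming the file is the flat key-to-convention mapping the hunk shows and is read from the repository path in the file header:

```python
import json

# Hypothetical lookup against the mapping added above: resolve an internal
# output key to a filename stem under either naming convention.
with open("tedana/resources/config/outputs.json") as f:
    outputs = json.load(f)

print(outputs["combined T2* img"]["orig"])        # t2star_OC
print(outputs["combined T2* img"]["bidsv1.5.0"])  # desc-optcom_bold
```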
