Skip to content

Commit

Permalink
figure for submission
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaitanya CHINTALURI committed Aug 4, 2023
1 parent de7fd0c commit 0a88f76
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 66 deletions.
3 changes: 2 additions & 1 deletion figures/kCSD_properties/L_curve_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def main_loop(src_width, total_ele, inpos, lpos, nm, noise=0, srcs=1):
t_csd_x, t_csd_y, true_csd = generate_csd_1D(src_width, nm, srcs=srcs,
start_x=0, end_x=1.,
start_y=0, end_y=1,
res_x=100, res_y=100)
res_x=101, res_y=101)
if type(noise) == float: n_spec = [noise]
else: n_spec = noise
for i, noise in enumerate(n_spec):
Expand Down Expand Up @@ -142,6 +142,7 @@ def main_loop(src_width, total_ele, inpos, lpos, nm, noise=0, srcs=1):
'pots':pots, 'estm_x':k.estm_x, 'est_pot':est_pot,
'est_csd':est_csd, 'noreg_csd':noreg_csd, 'errsy':errsy}
np.savez('data_fig4_and_fig13_'+save_as, **vals_to_save)
print(true_csd.shape, est_csd[:,0].shape)
RMS_wek[0, i] = np.linalg.norm(true_csd/np.linalg.norm(true_csd) - est_csd[:,0]/np.linalg.norm(est_csd[:,0]))
RMS_wek[1, i] = np.linalg.norm(true_csd/np.linalg.norm(true_csd) - est_csd_cv[:,0]/np.linalg.norm(est_csd_cv[:,0]))

Expand Down
10 changes: 5 additions & 5 deletions figures/kCSD_properties/figure_LC.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ def make_plots(title, m_norm, m_resi, true_csd, curveseq, ele_y,
# os.chdir("./LCurve/LC2")
noises = 3
noise_lvl = np.linspace(0, 0.5, noises)
# df = np.load('data_fig4_and_fig13_lc_noise25.0.npz')
#df = np.load('data_fig4_and_fig13_lc_noise25.0.npz')
Rs = np.linspace(0.025, 8*0.025, 8)
title = ['nazwa_pliku']
save_as = 'noise'
# make_plots(title, df['m_norm'], df['m_resi'], df['true_csd'],
# df['curve_surf'], df['ele_y'], df['pots_n'],
# df['pots'], df['estm_x'], df['est_pot'], df['est_csd'],
# df['noreg_csd'], save_as)
make_plots(title, df['m_norm'], df['m_resi'], df['true_csd'],
df['curve_surf'], df['ele_y'], df['pots_n'],
df['pots'], df['estm_x'], df['est_pot'], df['est_csd'],
df['noreg_csd'], save_as)
45 changes: 26 additions & 19 deletions figures/kCSD_properties/figure_LCandCV.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,58 @@ def set_axis(ax, x, y, letter=None):
def plot_surface(curve_surf, errsy, save_as):
fsize = 18
lambdas = np.logspace(-7, -3, 50)
fig = plt.figure(figsize = (20,9), dpi = 300)
gs = gridspec.GridSpec(16, 12, hspace=2, wspace=2)
fig = plt.figure(figsize = (15, 6), dpi = 300)
gs = gridspec.GridSpec(16, 12, hspace=1, wspace=1)
ax = plt.subplot(gs[0:16, 0:6])
set_axis(ax, -0.05, 1.05, letter='A')
plt.pcolormesh(lambdas, np.arange(9), curve_surf,
cmap = 'BrBG', vmin = -2, vmax=2)

plt.pcolormesh(lambdas, np.arange(8), curve_surf,
cmap = 'BrBG', vmin = -3, vmax=3)
plt.colorbar()
for i,m in enumerate(curve_surf.argmax(axis=1)):
plt.scatter([lambdas[m]], [i+0.5], s=50, color='red', alpha = 0.7)
plt.scatter([lambdas[m]], [i], s=50, color='red', alpha = 0.7)
if i==7:
plt.scatter([lambdas[m]], [i+0.5], s=50, color='red',
label = 'Maximum Curvature', alpha = 0.7)
plt.scatter([lambdas[m]], [i], s=50, color='red',
label = 'Maximum \nCurvature', alpha = 0.7)
plt.xlim(lambdas[1],lambdas[-1])
plt.title('L-curve regularization', fontsize = fsize)
plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
# plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
# frameon = False, fontsize = fsize)
plt.legend(loc='upper left', ncol=1,
frameon = False, fontsize = fsize)
plt.yticks(np.arange(8)+0.5, [str(x)+'x' for x in range(1,9)])
plt.yticks(np.arange(8), [str(x)+'x' for x in range(1,9)])
plt.xscale('log')
plt.ylabel('Parameter $R$ in electrode distance', fontsize=fsize, labelpad = 15)
plt.xlabel('$\lambda$',fontsize=fsize)
ax = plt.subplot(gs[0:16, 6:12])
set_axis(ax, -0.05, 1.05, letter='B')
plt.pcolormesh(lambdas, np.arange(9), errsy, cmap = 'Greys')
plt.pcolormesh(lambdas, np.arange(8), errsy, cmap='Greys', vmin=0.01, vmax=0.02)
plt.colorbar()
for i,m in enumerate(errsy.argmin(axis=1)):
plt.scatter([lambdas[m]], [i+0.5], s=50, color='red', alpha = 0.7)
plt.scatter([lambdas[m]], [i], s=50, color='red', alpha = 0.7)
if i==7:
plt.scatter([lambdas[m]], [i+0.5], s=50, color='red',
label = 'Minimum Error', alpha = 0.7)
plt.scatter([lambdas[m]], [i], s=50, color='red',
label = 'Minimum \nError', alpha = 0.7)
plt.xlim(lambdas[1],lambdas[-1])
plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
# plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
# frameon = False, fontsize = fsize)
plt.legend(loc='upper left', ncol=1,
frameon = False, fontsize = fsize)
plt.title('Cross-validation regularization', fontsize = fsize)
plt.yticks(np.arange(8)+0.5, [str(x)+'x' for x in range(1,9)])
plt.yticks(np.arange(8), [str(x)+'x' for x in range(1,9)])
plt.xscale('log')
plt.xlabel('$\lambda$', fontsize=fsize)
fig.savefig(save_as+'.png')

if __name__=='__main__':
# os.chdir("./LCurve/")
os.chdir("./LCurve/")
noises = 3
noise_lvl = np.linspace(0, 0.5, noises)
# df = np.load('LC2/data_fig4_and_fig13_lc_noise25.0.npz')
print(os.getcwd())
df = np.load(os.path.join('LC2', 'data_fig4_and_fig13_LC_noise25.0.npz'))

Rs = np.linspace(0.025, 8*0.025, 8)
title = ['nazwa_pliku']
save_as = 'noise'
# plot_surface(df['curve_surf'], df['errsy'], save_as+'surf')
plt.close('all')
plot_surface(df['curve_surf'], df['errsy'], save_as+'surf')
plt.close('all')
20 changes: 11 additions & 9 deletions figures/kCSD_properties/figure_LCandCVperformance.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def make_plot_perf(sim_results):
lam_lc = sim_results[0, 0]
rms_cv = sim_results[1, 2]
lam_cv = sim_results[1, 0]
fig = plt.figure(figsize = (9,12), dpi = 300)
widths = [10]
heights = [1, 1]
gs = gridspec.GridSpec(2, 1, height_ratios=heights, width_ratios=widths,
fig = plt.figure(figsize = (12,7), dpi = 300)
widths = [1, 1]
heights = [1]
gs = gridspec.GridSpec(1, 2, height_ratios=heights, width_ratios=widths,
hspace=0.45, wspace=0.3)
ax1 = plt.subplot(gs[0])
if np.min(rms_cv) < np.min(rms_lc):
Expand All @@ -50,7 +50,8 @@ def make_plot_perf(sim_results):
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)
set_axis(ax1, -0.05, 1.05, letter='A')
plt.title('Performance of regularization methods')
ax1.legend(loc='upper left', frameon=False)
# plt.title('Performance of regularization methods')

'''second plot'''
ax2 = plt.subplot(gs[1])
Expand All @@ -69,14 +70,15 @@ def make_plot_perf(sim_results):
plt.xlabel('Relative Noise Level', labelpad = 15)
set_axis(ax2, -0.05, 1.05, letter='B')
ht, lh = ax2.get_legend_handles_labels()
fig.legend(ht, lh, loc='lower center', ncol=2, frameon=False)
#fig.legend(ht, lh, loc='upper center', ncol=2, frameon=False)
ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.legend(loc='upper left', frameon=False)
fig.savefig('stats.png')

if __name__=='__main__':
# os.chdir("./LCurve/")
os.chdir("./LCurve/")
noises = 9
noise_lvl = np.linspace(0, 0.5, noises)
# sim_results = np.load('sim_results.npy')
# make_plot_perf(sim_results)
sim_results = np.load('sim_results.npy')
make_plot_perf(sim_results)
2 changes: 2 additions & 0 deletions figures/kCSD_properties/figure_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
})


def cm_to_inches(vals):
return [0.393701*ii for ii in vals]
Binary file modified figures/kCSD_properties/sources_electrodes.odg
Binary file not shown.
12 changes: 7 additions & 5 deletions figures/kCSD_properties/tutorial_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def grid(x, y, z):
x = x.flatten()
y = y.flatten()
z = z.flatten()
xi, yi = np.mgrid[min(x):max(x):np.complex(0, 100),
min(y):max(y):np.complex(0, 100)]
xi, yi = np.mgrid[min(x):max(x):complex(0, 100),
min(y):max(y):complex(0, 100)]
zi = griddata((x, y), z, (xi, yi), method='linear')
return xi, yi, zi

def set_axis(ax, letter=None):
ax.text(
-0.05,
1.05,
1.10,
letter,
fontsize=20,
weight='bold',
Expand All @@ -63,7 +63,7 @@ def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None, xl
if ylabel:
ax.set_ylabel('Y (mm)')
if title is not None:
ax.set_title(title)
ax.set_title(title, pad=10)
ax.set_xticks([0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ticks = np.linspace(-1 * t_max, t_max, 3, endpoint=True)
Expand Down Expand Up @@ -128,6 +128,7 @@ def generate_figure(small_seed, large_seed):
cax = plt.subplot(gs[1, 0])
t_max_1 = 0.50
make_subplot(ax, 'csd', csd_x, csd_y, true_csd, cax, 'True CSD', xlabel=True, ylabel=True, letter='A', t_max=t_max_1)
ax.text(-0.4, 0.5, 'Small sources', fontsize=20, rotation=90, va='center')
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'pot', pot_X, pot_Y, pot_Z, cax, 'Interpolated potentials', xlabel=True, ele_pos=ele_pos, letter='B')
Expand All @@ -145,9 +146,10 @@ def generate_figure(small_seed, large_seed):
cax = plt.subplot(gs[1, 0])
t_max_2 = 0.52
make_subplot(ax, 'csd', csd_x, csd_y, true_csd, cax, ylabel=True, xlabel=True, letter='E', t_max=t_max_2)
ax.text(-0.4, 0.5, 'Large sources', fontsize=20, rotation=90, va='center')
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'pot', pot_X, pot_Y, pot_Z, cax, xlabel=True, ele_pos=ele_pos, letter='F')
make_subplot(ax, 'pot', pot_X, pot_Y, pot_Z, cax, xlabel=True, ele_pos=ele_pos, letter='F', t_max=1)
ax = plt.subplot(gs[0, 2])
cax = plt.subplot(gs[1, 2])
make_subplot(ax, 'csd', k.estm_x, k.estm_y, est_csd_pre_cv[:, :, 0], cax, xlabel=True, letter='G', t_max=t_max_2)
Expand Down
6 changes: 4 additions & 2 deletions figures/kCSD_properties/tutorial_broken_electrodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def load_files(folderpaths, seeds):
def set_axis(ax, letter=None):
ax.text(
-0.05,
1.05,
1.1,
letter,
fontsize=20,
weight='bold',
Expand Down Expand Up @@ -121,7 +121,7 @@ def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None, xl
if ylabel:
ax.set_ylabel('Y (mm)')
if title is not None:
ax.set_title(title)
ax.set_title(title, pad=10)
ax.set_xticks([0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ticks = np.linspace(0, t_max, 3, endpoint=True)
Expand Down Expand Up @@ -154,6 +154,7 @@ def generate_figure():
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
cax=cax, title='Error CSD', xlabel=True, ylabel=True, letter='A',
t_max=err_max)
ax.text(-0.4, 0.5, 'Small sources', fontsize=20, rotation=90, va='center')
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'err', csd_x, csd_y, errs[1], ele_pos=electrode_positions(missing_ele=5),
Expand All @@ -179,6 +180,7 @@ def generate_figure():
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
cax=cax, xlabel=True, ylabel=True, letter='E',
t_max=err_max)
ax.text(-0.4, 0.5, 'Large sources', fontsize=20, rotation=90, va='center')
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'err', csd_x, csd_y, errs[1], ele_pos=electrode_positions(missing_ele=5),
Expand Down
14 changes: 9 additions & 5 deletions figures/kCSD_properties/tutorial_broken_electrodes_diff_err.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def load_files(folderpaths, seeds):
def set_axis(ax, letter=None):
ax.text(
-0.05,
1.05,
1.10,
letter,
fontsize=20,
weight='bold',
Expand Down Expand Up @@ -145,7 +145,7 @@ def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None,
if ylabel:
ax.set_ylabel('Y (mm)')
if title is not None:
ax.set_title(title)
ax.set_title(title, pad=10)
ax.set_xticks([0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ticks = np.linspace(0, t_max, 3, endpoint=True)
Expand Down Expand Up @@ -277,20 +277,22 @@ def generate_figure2():
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
cax=cax, title='Error CSD', xlabel=True, ylabel=True, letter='A',
t_max=.2)
ax.text(-0.4, 0.5, 'Small+Large sources', fontsize=20, rotation=90, va='center')

ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[1] - errs[0]), ele_pos=electrode_positions(missing_ele=5),
cax=cax, title='Error Diff CSD 5 broken', xlabel=True, letter='B',
cax=cax, title='5 broken - Error CSD ', xlabel=True, letter='B',
t_max=err_max)
ax = plt.subplot(gs[0, 2])
cax = plt.subplot(gs[1, 2])
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[2] - errs[0]), ele_pos=electrode_positions(missing_ele=10),
cax=cax, title='Error Diff CSD 10 broken', xlabel=True, letter='C',
cax=cax, title='10 broken - Error CSD', xlabel=True, letter='C',
t_max=err_max)
ax = plt.subplot(gs[0, 3])
cax = plt.subplot(gs[1, 3])
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[3] - errs[0]), ele_pos=electrode_positions(missing_ele=20),
cax=cax, title='Error Diff CSD 20 broken', xlabel=True, letter='D',
cax=cax, title='20 broken - Error CSD', xlabel=True, letter='D',
t_max=err_max)

errs = fetch_values('small')
Expand All @@ -302,6 +304,7 @@ def generate_figure2():
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
cax=cax, xlabel=True, ylabel=True, letter='E',
t_max=.2)
ax.text(-0.4, 0.5, 'Small sources', fontsize=20, rotation=90, va='center')
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[1] - errs[0]), ele_pos=electrode_positions(missing_ele=5),
Expand All @@ -327,6 +330,7 @@ def generate_figure2():
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
cax=cax, xlabel=True, ylabel=True, letter='I',
t_max=.2)
ax.text(-0.4, 0.5, 'Large sources', fontsize=20, rotation=90, va='center')
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[1] - errs[0]), ele_pos=electrode_positions(missing_ele=5),
Expand Down
21 changes: 10 additions & 11 deletions figures/kCSD_properties/tutorial_noisy_electrodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def grid(x, y, z):
x = x.flatten()
y = y.flatten()
z = z.flatten()
xi, yi = np.mgrid[min(x):max(x):np.complex(0, 100),
min(y):max(y):np.complex(0, 100)]
xi, yi = np.mgrid[min(x):max(x):complex(0, 100),
min(y):max(y):complex(0, 100)]
zi = griddata((x, y), z, (xi, yi), method='linear')
return xi, yi, zi

def set_axis(ax, letter=None):
ax.text(
-0.05,
1.05,
1.1,
letter,
fontsize=20,
weight='bold',
Expand Down Expand Up @@ -77,7 +77,7 @@ def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None, xl
if ylabel:
ax.set_ylabel('Y (mm)')
if title is not None:
ax.set_title(title)
ax.set_title(title, pad=10)
ax.set_xticks([0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ticks = np.linspace(0, t_max, 3, endpoint=True)
Expand All @@ -97,12 +97,11 @@ def do_kcsd(CSD_PROFILE, csd_seed, noise_level):
# R_final = np.linspace(0.1, 1.5, 15)
R_final = np.linspace(0.05, 1., 20)
# True CSD_PROFILE
csd_at = np.mgrid[0.:1.:100j,
0.:1.:100j]
csd_at = np.mgrid[0.:1.:101j,
0.:1.:101j]
csd_x, csd_y = csd_at
# Small source
true_csd = CSD_PROFILE(csd_at, seed=csd_seed)

# Electrode positions
ele_x, ele_y = np.mgrid[0.05: 0.95: 10j,
0.05: 0.95: 10j]
Expand Down Expand Up @@ -153,12 +152,12 @@ def generate_figure(small_seed, large_seed):
cax = plt.subplot(gs[1, 0])
make_subplot(ax, 'err', csd_x, csd_y, err, ele_pos=ele_pos,
cax=cax, title='Error CSD', xlabel=True, ylabel=True, letter='A', t_max=t_max_1)

ax.text(-0.4, 0.5, 'Small sources', fontsize=20, rotation=90, va='center')
csd_x, csd_y, err, ele_pos = do_kcsd(CSD.gauss_2d_small, csd_seed=small_seed, noise_level=5)
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'err', csd_x, csd_y, err, ele_pos=ele_pos,
cax=cax, title='Error CSD 5% noise', xlabel=True, ylabel=True, letter='B', t_max=t_max_1)
cax=cax, title='Error CSD 5% noise', xlabel=True, letter='B', t_max=t_max_1)

csd_x, csd_y, err, ele_pos = do_kcsd(CSD.gauss_2d_small, csd_seed=small_seed, noise_level=10)
ax = plt.subplot(gs[0, 2])
Expand All @@ -180,7 +179,7 @@ def generate_figure(small_seed, large_seed):
cax = plt.subplot(gs[1, 0])
make_subplot(ax, 'err', csd_x, csd_y, err, ele_pos=ele_pos,
cax=cax, xlabel=True, ylabel=True, letter='E', t_max=0.55)

ax.text(-0.4, 0.5, 'Large sources', fontsize=20, rotation=90, va='center')
csd_x, csd_y, err, ele_pos = do_kcsd(CSD.gauss_2d_large, csd_seed=large_seed, noise_level=5)
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
Expand All @@ -199,7 +198,7 @@ def generate_figure(small_seed, large_seed):
make_subplot(ax, 'err', csd_x, csd_y, err, ele_pos=ele_pos,
cax=cax, xlabel=True, letter='H', t_max=0.55)
plt.savefig('tutorial_noise.png', dpi=300)
plt.show()
# plt.show()

if __name__ == '__main__':
small_seed = 15
Expand Down
Loading

0 comments on commit 0a88f76

Please sign in to comment.