BUG: Fix bugs with split files (#597)
* BUG: Fix bugs with split files

* FIX: Should be working

* FIX: Could be file or dir
larsoner authored Sep 8, 2022
1 parent 63cadf1 commit 6c350e1
Showing 14 changed files with 94 additions and 48 deletions.
30 changes: 30 additions & 0 deletions config.py
@@ -2020,6 +2020,14 @@ def gen_log_kwargs(
return kwargs


###############################################################################
# Private config vars (not to be set by user)
# -------------------------------------------

_raw_split_size = '2GB'
_epochs_split_size = '2GB'


###############################################################################
# Retrieve custom configuration options
# -------------------------------------
@@ -3824,3 +3832,25 @@ def save_logs(logs):
'configuration. Currently the `conditions` parameter is empty. '
'This is only allowed for resting-state analysis.')
raise ValueError(msg)


def _update_for_splits(files_dict, key, *, single=False):
if not isinstance(files_dict, dict): # fake it
assert key is None
files_dict, key = dict(x=files_dict), 'x'
bids_path = files_dict[key]
if bids_path.fpath.exists():
return bids_path # no modifications needed
bids_path = bids_path.copy().update(split='01')
assert bids_path.fpath.exists(), f'Missing file: {bids_path.fpath}'
files_dict[key] = bids_path
# if we only need the first file (i.e., when reading), quit now
if single:
return bids_path
for split in range(2, 100):
split_key = f'{split:02d}'
bids_path_next = bids_path.copy().update(split=split_key)
if not bids_path_next.fpath.exists():
break
files_dict[f'{key}_split-{split_key}'] = bids_path_next
return bids_path
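
For readers unfamiliar with the new helper, here is a hedged usage sketch (not part of the commit). The subject, task, and root below are made up, and the calls assume the corresponding derivative files already exist on disk, since `_update_for_splits` asserts that at least the `split-01` file is present whenever the unsplit filename is missing.

```python
from mne_bids import BIDSPath
from config import _update_for_splits  # assumes the pipeline's config.py is importable

# Hypothetical derivative path; all entities are illustrative only.
bids_path = BIDSPath(subject='01', task='rest', processing='filt',
                     suffix='raw', extension='.fif', datatype='meg',
                     root='/data/derivatives/mne-bids-pipeline', check=False)

# Write side: after saving, register every split that ended up on disk.
# If no split occurred, the dict is left unchanged; otherwise 'raw_filt'
# is repointed at split-01 and 'raw_filt_split-02', ... are added.
out_files = dict(raw_filt=bids_path)
_update_for_splits(out_files, 'raw_filt')

# Read side: a bare BIDSPath (key=None) with single=True returns either
# the original path or its split-01 variant -- whichever exists on disk.
raw_fname_in = _update_for_splits(bids_path.copy(), None, single=True)
```
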
2 changes: 2 additions & 0 deletions docs/source/changes.md
@@ -352,3 +352,5 @@ authors:
- EEG channels couldn't be used as "virtual" EOG channels during ICA artifact
detection. Reported by "fraenni" on the forum. Thank you! 🌻
({{ gh(572) }} by {{ authors.hoechenberger }})
- Fix bug with handling of split files during preprocessing
({{ gh(597) }} by {{ authors.larsoner }})
4 changes: 4 additions & 0 deletions run.py
@@ -15,6 +15,10 @@
import coloredlogs


# Ensure that the "scripts" that we import from is the correct one
sys.path.insert(0, str(pathlib.Path(__file__).parent))


logger = logging.getLogger(__name__)

log_level_styles = {
17 changes: 11 additions & 6 deletions scripts/preprocessing/_01_maxfilter.py
@@ -28,7 +28,8 @@

import config
from config import (gen_log_kwargs, on_error, failsafe_run,
import_experimental_data, import_er_data, import_rest_data)
import_experimental_data, import_er_data, import_rest_data,
_update_for_splits)
from config import parallel_func

logger = logging.getLogger('mne-bids-pipeline')
@@ -81,7 +82,6 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None):
raise ValueError(f'You cannot set use_maxwell_filter to True '
f'if data have already processed with Maxwell-filter.'
f' Got proc={config.proc}.')

bids_path_in = in_files[f"raw_run-{run}"]
bids_path_out = bids_path_in.copy().update(
processing="sss",
@@ -154,13 +154,16 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None):
logger.info(**gen_log_kwargs(
message=msg, subject=subject, session=session, run=run))
raw_sss.save(out_files['sss_raw'], picks=picks, split_naming='bids',
overwrite=True)
overwrite=True, split_size=cfg._raw_split_size)
# we need to be careful about split files
_update_for_splits(out_files, 'sss_raw')
del raw, raw_sss

if cfg.interactive:
# Load the data we have just written, because it contains only
# the relevant channels.
raw_sss = mne.io.read_raw_fif(bids_path_out, allow_maxshield=True)
raw_sss = mne.io.read_raw_fif(
out_files['sss_raw'], allow_maxshield=True)
raw_sss.plot(n_channels=50, butterfly=True, block=True)
del raw_sss

@@ -211,7 +214,7 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None):
# copy the bad channel selection from the reference run over to
# the resting-state recording.

raw_sss = mne.io.read_raw_fif(bids_path_out)
raw_sss = mne.io.read_raw_fif(out_files['sss_raw'])
rank_exp = mne.compute_rank(raw_sss, rank='info')['meg']
rank_noise = mne.compute_rank(raw_noise_sss, rank='info')['meg']

@@ -248,8 +251,9 @@ def run_maxwell_filter(*, cfg, subject, session=None, run=None, in_files=None):
message=msg, subject=subject, session=session, run=run))
raw_noise_sss.save(
out_files['sss_noise'], picks=picks, overwrite=True,
split_naming='bids'
split_naming='bids', split_size=cfg._raw_split_size,
)
_update_for_splits(out_files, 'sss_noise')
del raw_noise_sss
return {key: pth.fpath for key, pth in out_files.items()}

@@ -290,6 +294,7 @@ def get_config(
min_break_duration=config.min_break_duration,
t_break_annot_start_after_previous_event=config.t_break_annot_start_after_previous_event, # noqa:E501
t_break_annot_stop_before_next_event=config.t_break_annot_stop_before_next_event, # noqa:E501
_raw_split_size=config._raw_split_size,
)
return cfg

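The Maxwell-filtered outputs above are now written with an explicit `split_size=cfg._raw_split_size`, so very large runs end up as BIDS-style split files. As a hedged illustration of the MNE behaviour the pipeline relies on (file names below are hypothetical):

```python
import mne

# Hypothetical input file, for illustration only.
raw = mne.io.read_raw_fif('sub-01_ses-01_task-rest_meg.fif',
                          allow_maxshield=True)

# With split_naming='bids', MNE writes a single file if it fits within
# split_size, and otherwise numbered parts such as
#   sub-01_ses-01_task-rest_proc-sss_split-01_raw.fif
#   sub-01_ses-01_task-rest_proc-sss_split-02_raw.fif
raw.save('sub-01_ses-01_task-rest_proc-sss_raw.fif',
         split_naming='bids', split_size='2GB', overwrite=True)

# FIF split files reference their successors, so reading the first part
# loads the whole recording -- which is why readers only need the
# split-01 path (single=True) from _update_for_splits.
raw_sss = mne.io.read_raw_fif(
    'sub-01_ses-01_task-rest_proc-sss_split-01_raw.fif')
```
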
22 changes: 13 additions & 9 deletions scripts/preprocessing/_02_frequency_filter.py
@@ -35,7 +35,8 @@

import config
from config import (gen_log_kwargs, on_error, failsafe_run,
import_experimental_data, import_er_data, import_rest_data)
import_experimental_data, import_er_data, import_rest_data,
_update_for_splits)
from config import parallel_func


@@ -75,21 +76,19 @@ def get_input_fnames_frequency_filter(**kwargs):

if cfg.use_maxwell_filter:
bids_path_in.update(processing="sss")
if bids_path_in.copy().update(split='01').fpath.exists():
bids_path_in = bids_path_in.update(split='01')

in_files = dict()
in_files[f'raw_run-{run}'] = bids_path_in
_update_for_splits(in_files, f'raw_run-{run}', single=True)

if (cfg.process_er or config.noise_cov == 'rest') and run == cfg.runs[0]:
noise_task = "rest" if config.noise_cov == "rest" else "noise"
if cfg.use_maxwell_filter:
raw_noise_fname_in = bids_path_in.copy().update(
run=None, task=noise_task
)
if raw_noise_fname_in.copy().update(split='01').fpath.exists():
raw_noise_fname_in.update(split='01')
in_files["raw_noise"] = raw_noise_fname_in
_update_for_splits(in_files, "raw_noise", single=True)
else:
if config.noise_cov == 'rest':
in_files["raw_rest"] = bids_path_in.copy().update(
@@ -189,7 +188,7 @@ def filter_data(

out_files['raw_filt'] = bids_path.copy().update(
root=cfg.deriv_root, processing='filt', extension='.fif',
suffix='raw')
suffix='raw', split=None)
raw.load_data()
filter(
raw=raw, subject=subject, session=session, run=run,
@@ -201,7 +200,9 @@ def filter_data(
resample(raw=raw, subject=subject, session=session, run=run,
sfreq=cfg.resample_sfreq, data_type='experimental')

raw.save(out_files['raw_filt'], overwrite=True, split_naming='bids')
raw.save(out_files['raw_filt'], overwrite=True, split_naming='bids',
split_size=cfg._raw_split_size)
_update_for_splits(out_files, 'raw_filt')
if cfg.interactive:
# Plot raw data and power spectral density.
raw.plot(n_channels=50, butterfly=True)
@@ -238,7 +239,7 @@ def filter_data(
out_files['raw_noise_filt'] = \
bids_path_noise.copy().update(
root=cfg.deriv_root, processing='filt', extension='.fif',
suffix='raw')
suffix='raw', split=None)

raw_noise.load_data()
filter(
@@ -252,8 +253,10 @@ def filter_data(
sfreq=cfg.resample_sfreq, data_type=data_type)

raw_noise.save(
out_files['raw_noise_filt'], overwrite=True, split_naming='bids'
out_files['raw_noise_filt'], overwrite=True, split_naming='bids',
split_size=cfg._raw_split_size,
)
_update_for_splits(out_files, 'raw_noise_filt')
if cfg.interactive:
# Plot raw data and power spectral density.
raw_noise.plot(n_channels=50, butterfly=True)
@@ -301,6 +304,7 @@ def get_config(
min_break_duration=config.min_break_duration,
t_break_annot_start_after_previous_event=config.t_break_annot_start_after_previous_event, # noqa:E501
t_break_annot_stop_before_next_event=config.t_break_annot_stop_before_next_event, # noqa:E501
_raw_split_size=config._raw_split_size,
)
return cfg

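One detail worth noting about the filter outputs above: the output path now explicitly resets `split=None` before saving, presumably because the path it is derived from may already carry `split-01` after `_update_for_splits(..., single=True)` on the input side. The filtered derivative thus gets a clean name and is only re-split by `raw.save()` if it is itself too large. A hedged sketch of that sequence, reusing names from the diff:

```python
# Sketch only: cfg, run, in_files, out_files and raw come from the script.
bids_path_in = in_files[f'raw_run-{run}']  # may end in ..._split-01_raw.fif

# Drop any split entity inherited from the input before writing.
out_files['raw_filt'] = bids_path_in.copy().update(
    root=cfg.deriv_root, processing='filt', extension='.fif',
    suffix='raw', split=None)

raw.save(out_files['raw_filt'], overwrite=True, split_naming='bids',
         split_size=cfg._raw_split_size)
_update_for_splits(out_files, 'raw_filt')  # register any new splits
```
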
15 changes: 8 additions & 7 deletions scripts/preprocessing/_03_make_epochs.py
@@ -20,7 +20,7 @@

import config
from config import make_epochs, gen_log_kwargs, on_error, failsafe_run
from config import parallel_func
from config import parallel_func, _update_for_splits

logger = logging.getLogger('mne-bids-pipeline')

@@ -44,10 +44,7 @@ def run_epochs(*, cfg, subject, session=None):
for run in cfg.runs:
raw_fname_in = bids_path.copy().update(run=run, processing='filt',
suffix='raw', check=False)

if raw_fname_in.copy().update(split='01').fpath.exists():
raw_fname_in.update(split='01')

raw_fname_in = _update_for_splits(raw_fname_in, None, single=True)
raw_fnames.append(raw_fname_in)

# Generate a unique event name -> event code mapping that can be used
@@ -191,7 +188,10 @@ def run_epochs(*, cfg, subject, session=None):
logger.info(**gen_log_kwargs(message=msg, subject=subject,
session=session))
epochs_fname = bids_path.copy().update(suffix='epo', check=False)
epochs.save(epochs_fname, overwrite=True, split_naming='bids')
epochs.save(
epochs_fname, overwrite=True, split_naming='bids',
split_size=cfg._epochs_split_size)
# _update_for_splits(out_files, 'epochs')

if cfg.interactive:
epochs.plot()
@@ -228,7 +228,8 @@ def get_config(
event_repeated=config.event_repeated,
decim=config.decim,
ch_types=config.ch_types,
eeg_reference=config.get_eeg_reference()
eeg_reference=config.get_eeg_reference(),
_epochs_split_size=config._epochs_split_size,
)
return cfg

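The raw-file lookup earlier in this file passes a bare `BIDSPath` instead of a dict; per the helper's "fake it" branch in `config.py`, that form is shorthand for wrapping the path in a one-entry dict. A small equivalence sketch (variable names are hypothetical):

```python
# Direct form used in this script: key must be None when no dict is given.
raw_fname_in = _update_for_splits(raw_fname_in, None, single=True)

# What the helper effectively does internally:
tmp = {'x': raw_fname_in}
raw_fname_in = _update_for_splits(tmp, 'x', single=True)
```
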
6 changes: 2 additions & 4 deletions scripts/preprocessing/_04a_run_ica.py
@@ -33,7 +33,7 @@

import config
from config import (make_epochs, gen_log_kwargs, on_error, failsafe_run,
annotations_to_events)
annotations_to_events, _update_for_splits)
from config import parallel_func


@@ -259,9 +259,7 @@ def run_ica(*, cfg, subject, session=None):
raw_fnames = []
for run in cfg.runs:
raw_fname.update(run=run)
if raw_fname.copy().update(split='01').fpath.exists():
raw_fname.update(split='01')

raw_fname = _update_for_splits(raw_fname, None, single=True)
raw_fnames.append(raw_fname.copy())

# Generate a unique event name -> event code mapping that can be used
6 changes: 2 additions & 4 deletions scripts/preprocessing/_04b_run_ssp.py
@@ -18,7 +18,7 @@

import config
from config import gen_log_kwargs, on_error, failsafe_run
from config import parallel_func
from config import parallel_func, _update_for_splits


logger = logging.getLogger('mne-bids-pipeline')
@@ -49,9 +49,7 @@ def run_ssp(*, cfg, subject, session=None):
msg = f'Input: {raw_fname_in.basename}, Output: {proj_fname_out.basename}'
logger.info(**gen_log_kwargs(message=msg, subject=subject,
session=session))

if raw_fname_in.copy().update(split='01').fpath.exists():
raw_fname_in.update(split='01')
raw_fname_in = _update_for_splits(raw_fname_in, None, single=True)

raw = mne.io.read_raw_fif(raw_fname_in)
msg = 'Computing SSPs for ECG'
8 changes: 6 additions & 2 deletions scripts/preprocessing/_05a_apply_ica.py
@@ -93,7 +93,10 @@ def apply_ica(*, cfg, subject, session):
msg = 'Saving reconstructed epochs after ICA.'
logger.info(**gen_log_kwargs(message=msg, subject=subject,
session=session))
epochs_cleaned.save(fname_epo_out, overwrite=True, split_naming='bids')
epochs_cleaned.save(
fname_epo_out, overwrite=True, split_naming='bids',
split_size=cfg._epochs_split_size)
# _update_for_splits(out_files, 'epochs_cleaned')

# Compare ERP/ERF before and after ICA artifact rejection. The evoked
# response is calculated across ALL epochs, just like ICA was run on
@@ -127,7 +130,8 @@ def get_config(
deriv_root=config.get_deriv_root(),
interactive=config.interactive,
baseline=config.baseline,
ica_reject=config.get_ica_reject()
ica_reject=config.get_ica_reject(),
_epochs_split_size=config._epochs_split_size,
)
return cfg

6 changes: 5 additions & 1 deletion scripts/preprocessing/_05b_apply_ssp.py
@@ -62,7 +62,10 @@ def apply_ssp(*, cfg, subject, session=None):
msg = 'Saving epochs with projectors.'
logger.info(**gen_log_kwargs(message=msg, subject=subject,
session=session))
epochs_cleaned.save(fname_out, overwrite=True, split_naming='bids')
epochs_cleaned.save(
fname_out, overwrite=True, split_naming='bids',
split_size=cfg._epochs_split_size)
# _update_for_splits(out_files, 'epochs_cleaned')


def get_config(
@@ -76,6 +79,7 @@
rec=config.rec,
space=config.space,
deriv_root=config.get_deriv_root(),
_epochs_split_size=config._epochs_split_size,
)
return cfg

7 changes: 5 additions & 2 deletions scripts/preprocessing/_06_ptp_reject.py
@@ -89,7 +89,9 @@ def drop_ptp(*, cfg, subject, session=None):
msg = 'Saving cleaned, baseline-corrected epochs …'

epochs.apply_baseline(cfg.baseline)
epochs.save(fname_out, overwrite=True, split_naming='bids')
epochs.save(
fname_out, overwrite=True, split_naming='bids',
split_size=cfg._epochs_split_size)


def get_config(
@@ -108,7 +110,8 @@ def get_config(
spatial_filter=config.spatial_filter,
ica_reject=config.get_ica_reject(),
deriv_root=config.get_deriv_root(),
decim=config.decim
decim=config.decim,
_epochs_split_size=config._epochs_split_size,
)
return cfg

16 changes: 4 additions & 12 deletions scripts/report/_01_make_reports.py
@@ -26,7 +26,7 @@
import config
from config import (
gen_log_kwargs, on_error, failsafe_run, parallel_func,
get_noise_cov_bids_path
get_noise_cov_bids_path, _update_for_splits,
)


@@ -52,10 +52,7 @@ def get_events(cfg, subject, session):

for run in cfg.runs:
this_raw_fname = raw_fname.copy().update(run=run)

if this_raw_fname.copy().update(split='01').fpath.exists():
this_raw_fname.update(split='01')

this_raw_fname = _update_for_splits(this_raw_fname, None, single=True)
raw_filt = mne.io.read_raw_fif(this_raw_fname)
raws_filt.append(raw_filt)
del this_raw_fname
@@ -81,10 +78,7 @@ def get_er_path(cfg, subject, session):
datatype=cfg.datatype,
root=cfg.deriv_root,
check=False)

if raw_fname.copy().update(split='01').fpath.exists():
raw_fname.update(split='01')

raw_fname = _update_for_splits(raw_fname, None, single=True)
return raw_fname


@@ -482,9 +476,7 @@ def run_report_preprocessing(
run=run, processing='filt',
suffix='raw', check=False
)
if fname.copy().update(split='01').fpath.exists():
fname.update(split='01')

fname = _update_for_splits(fname, None, single=True)
fnames_raw_filt.append(fname)

fname_epo_not_clean = bids_path.copy().update(suffix='epo')