diff --git a/.gitignore b/.gitignore
index ca109865..c4b401c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 *.py[co]
 *.pyd
 *.so
+.DS_Store
 
 # Packages
 *.egg
@@ -24,6 +25,7 @@ pip-log.txt
 .wpr
 *.wpu
 *.wpr
+*-checkpoint.ipynb
 .pypirc
 
 # Project working files
diff --git a/pyedflib/highlevel.py b/pyedflib/highlevel.py
index e71457ce..b28ccaf8 100644
--- a/pyedflib/highlevel.py
+++ b/pyedflib/highlevel.py
@@ -15,6 +15,7 @@
     - Comparing EDFs
     - Renaming Channels from EDF files
     - Dropping Channels from EDF files
+    - Cropping EDFs
 
 @author: skjerns
 """
@@ -24,7 +25,7 @@
 import warnings
 import pyedflib
 from copy import deepcopy
-from datetime import datetime
+from datetime import datetime, timedelta
 # from . import EdfWriter
 # from . import EdfReader
 
@@ -766,6 +767,130 @@ def anonymize_edf(edf_file, new_file=None,
     return True
 
 
+def crop_edf(
+    edf_file,
+    *,
+    new_file=None,
+    start=None,
+    stop=None,
+    start_format="datetime",
+    stop_format="datetime",
+    verbose=True,
+):
+    """Crop an EDF file to desired start/stop times.
+
+    The new start/end times can be either specified as a datetime.datetime or
+    as seconds from the beginning of the recording.
+    For example, using `crop_edf(..., start=10, start_format="seconds")` will
+    remove the first 10-seconds of the recording.
+
+    Parameters
+    ----------
+    edf_file : str
+        The path to the EDF file.
+    new_file : str | None
+        The path to the new cropped file. If None (default), the input
+        filename appended with '_cropped' is used.
+    start : datetime.datetime | int | float | None
+        The new start. Can be None to keep the original start time of
+        the recording.
+    stop : datetime.datetime | int | float | None
+        The new stop. Can be None to keep the original end time of the
+        recording.
+    start_format : str
+        The format of ``start``: "datetime" (default) or "seconds".
+    stop_format : str
+        The format of ``stop``: "datetime" (default) or "seconds".
+    verbose : bool
+        If True (default), print some details about the original and cropped
+        file.
+    """
+    # Check input
+    assert start_format in ["datetime", "seconds"]
+    assert stop_format in ["datetime", "seconds"]
+    if start_format == "datetime":
+        assert isinstance(start, (datetime, type(None)))
+    else:
+        assert isinstance(start, (int, float, type(None)))
+    if stop_format == "datetime":
+        assert isinstance(stop, (datetime, type(None)))
+    else:
+        assert isinstance(stop, (int, float, type(None)))
+
+    # Open the original EDF file. The context manager closes the reader even
+    # if one of the checks below fails, so the file is never left open.
+    with pyedflib.EdfReader(edf_file) as edf:
+        signals_headers = edf.getSignalHeaders()
+        header = edf.getHeader()
+
+        # Define new start time
+        current_start = edf.getStartdatetime()
+        if start is None:
+            start = current_start
+        elif start_format == "seconds":
+            start = current_start + timedelta(seconds=start)
+        assert current_start <= start, (
+            'start must not be before current start of recording'
+        )
+        start_diff_from_start = (start - current_start).total_seconds()
+
+        # Define new stop time
+        current_stop = current_start + timedelta(seconds=edf.getFileDuration())
+        current_duration = current_stop - current_start
+        if stop is None:
+            stop = current_stop
+        elif stop_format == "seconds":
+            stop = current_start + timedelta(seconds=stop)
+        assert stop <= current_stop, (
+            'new stop value must not be after current end of recording'
+        )
+
+        assert start < current_stop, (
+            'new start value must not be after current end of recording'
+        )
+        assert stop > current_start, (
+            'new stop value must not be before current start of recording'
+        )
+        # Without this check a reversed window would silently yield a
+        # negative sample count further down.
+        assert start < stop, 'new start value must be before new stop value'
+        stop_diff_from_start = (stop - current_start).total_seconds()
+
+        # Crop each signal
+        signals = []
+        for i in range(len(signals_headers)):
+            sf = edf.getSampleFrequency(i)
+            # Convert from seconds to samples
+            start_idx = int(np.round(start_diff_from_start * sf))
+            stop_idx = int(np.round(stop_diff_from_start * sf))
+            # digital=True in reading and writing avoids precision loss
+            signals.append(
+                edf.readSignal(
+                    i, start=start_idx, n=stop_idx - start_idx, digital=True
+                )
+            )
+
+    # Update header startdate and save file
+    header["startdate"] = start
+    if new_file is None:
+        file, ext = os.path.splitext(edf_file)
+        new_file = file + "_cropped" + ext
+    write_edf(new_file, signals, signals_headers, header, digital=True)
+
+    # Safety check: are we able to load the new EDF file?
+    # Get new EDF start, stop and duration
+    with pyedflib.EdfReader(new_file) as edf:
+        start = edf.getStartdatetime()
+        stop = start + timedelta(seconds=edf.getFileDuration())
+        duration = stop - start
+
+    # Verbose
+    if verbose:
+        print(f"Original: {current_start} to {current_stop} ({current_duration})")
+        print(f"Truncated: {start} to {stop} ({duration})")
+        print(f"Succesfully written file: {new_file}")
+
+
 def rename_channels(edf_file, mapping, new_file=None, verbose=False):
     """
     A convenience function to rename channels in an EDF file.
diff --git a/pyedflib/tests/test_highlevel.py b/pyedflib/tests/test_highlevel.py
index 739f5699..e2985a5e 100644
--- a/pyedflib/tests/test_highlevel.py
+++ b/pyedflib/tests/test_highlevel.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2019 - 2020 Simon Kern
 # Copyright (c) 2015 Holger Nahrstaedt
 
-import os, sys
+import os
 import shutil
 import gc
 import numpy as np
@@ -11,7 +11,30 @@
 from pyedflib import highlevel
 from pyedflib.edfwriter import EdfWriter
 from pyedflib.edfreader import EdfReader
-from datetime import datetime, date
+from datetime import datetime, timedelta
+
+
+def _compare_cropped_edf(path_orig_edf, path_cropped_edf):
+    # Load original EDF
+    orig_signals, orig_signal_headers, orig_header = highlevel.read_edf(path_orig_edf)  # noqa: E501
+    orig_start = orig_header["startdate"]
+
+    # Load cropped EDF
+    signals, signal_headers, header = highlevel.read_edf(path_cropped_edf)  # noqa: E501
+    start = header["startdate"]
+    duration = signals[0].size / signal_headers[0]["sample_frequency"]
+    stop = start + timedelta(seconds=duration)
+
+    # Compare signal headers
+    assert signal_headers == orig_signal_headers
+
+    # Compare signal values
+    for i in range(signals.shape[0]):
+        sf_sig = signal_headers[i]["sample_frequency"]
+        idx_start = int(np.round((start - orig_start).total_seconds() * sf_sig))  # noqa: E501
+        idx_stop = int(np.round((stop - orig_start).total_seconds() * sf_sig))
+        assert (signals[i] == orig_signals[i, idx_start:idx_stop]).all()
+
 
 class TestHighLevel(unittest.TestCase):
 
@@ -300,6 +323,58 @@ def test_anonymize(self):
                                 new_values=['x', '', 'xx', 'xxx'],
                                 verify=True)
 
+    def test_crop_edf(self):
+        data_dir = os.path.join(os.path.dirname(__file__), 'data')
+        edf_file = os.path.join(data_dir, 'test_generator.edf')
+        outfile = os.path.join(data_dir, 'tmp_test_generator_cropped.edf')
+        orig_header = highlevel.read_edf_header(edf_file)  # noqa: E501
+        orig_start = orig_header["startdate"]
+        new_start = datetime(2011, 4, 4, 12, 58, 0)
+        new_stop = datetime(2011, 4, 4, 13, 0, 0)
+
+        # Test 1: no cropping
+        # The output file should be the same as input.
+        highlevel.crop_edf(
+            edf_file, new_file=outfile, start=None, stop=None)
+        assert highlevel.compare_edf(edf_file, outfile)
+
+        # Test 2: crop using datetimes (default)
+        # .. both start and stop
+        highlevel.crop_edf(
+            edf_file, new_file=outfile, start=new_start,
+            stop=new_stop
+        )
+        # Test that the signal values are correctly cropped
+        _compare_cropped_edf(edf_file, outfile)
+        # .. only start
+        highlevel.crop_edf(edf_file, new_file=outfile, start=new_start)
+        _compare_cropped_edf(edf_file, outfile)
+        # .. only stop
+        highlevel.crop_edf(edf_file, new_file=outfile, stop=new_stop)
+        _compare_cropped_edf(edf_file, outfile)
+
+        # Test 3: crop using seconds
+        new_start_sec = (new_start - orig_start).total_seconds()
+        new_stop_sec = (new_stop - orig_start).total_seconds()
+        # .. both start and stop
+        highlevel.crop_edf(
+            edf_file, new_file=outfile, start=new_start_sec,
+            stop=new_stop_sec, start_format="seconds", stop_format="seconds"
+        )
+        _compare_cropped_edf(edf_file, outfile)
+        # .. only start
+        highlevel.crop_edf(
+            edf_file, new_file=outfile,
+            start=new_start_sec, start_format="seconds"
+        )
+        _compare_cropped_edf(edf_file, outfile)
+        # .. only stop
+        highlevel.crop_edf(
+            edf_file, new_file=outfile, stop=new_stop_sec,
+            stop_format="seconds"
+        )
+        _compare_cropped_edf(edf_file, outfile)
+
     def test_drop_channel(self):
         signal_headers = highlevel.make_signal_headers(['ch'+str(i) for i in range(5)])
         signals = np.random.rand(5, 256*300)*200 #5 minutes of eeg