Skip to content

Commit

Permalink
Merge pull request #181 from holgern/fix_testrename
Browse files Browse the repository at this point in the history
Fix rename channel function
  • Loading branch information
skjerns authored Jun 24, 2022
2 parents 40a5aa6 + 11586dc commit a34382e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
31 changes: 12 additions & 19 deletions pyedflib/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ def anonymize_edf(edf_file, new_file=None,
write_edf(new_file, signals, signal_headers, header, digital=True)
if verify:
compare_edf(edf_file, new_file, verbose=verbose)
return True
return new_file


def rename_channels(edf_file, mapping, new_file=None, verbose=False):
Expand All @@ -804,28 +804,21 @@ def rename_channels(edf_file, mapping, new_file=None, verbose=False):
True if successful, False if failed.
"""
header = read_edf_header(edf_file)
channels = header['channels']
if new_file is None:
file, ext = os.path.splitext(edf_file)
new_file = file + '_renamed' + ext

signals, signal_headers, header = read_edf(edf_file, digital=True)
channels = [shead['label'] for shead in signal_headers]

signal_headers = []
signals = []
for ch_nr in tqdm(range(len(channels)), disable=not verbose):
signal, signal_header, _ = read_edf(edf_file, digital=True,
ch_nrs=ch_nr, verbose=verbose)
ch = signal_header[0]['label']
if ch in mapping :
if verbose: print('{} to {}'.format(ch, mapping[ch]))
ch = mapping[ch]
signal_header[0]['label']=ch
else:
if verbose: print('no mapping for {}, leave as it is'.format(ch))
signal_headers.append(signal_header[0])
signals.append(signal.squeeze())
for ch in mapping:
assert ch in channels, f'{ch} was not found in channels: {channels}'

return write_edf(new_file, signals, signal_headers, header, digital=True)
for shead in signal_headers:
if shead['label'] in mapping:
shead['label'] = mapping[shead['label']]
write_edf(new_file, signals, signal_headers, header, digital=True)
return new_file


def change_polarity(edf_file, channels, new_file=None, verify=True,
Expand Down Expand Up @@ -869,4 +862,4 @@ def change_polarity(edf_file, channels, new_file=None, verify=True,
write_edf(new_file, signals, signal_headers, header,
digital=True, correct=False, verbose=verbose)
if verify: compare_edf(edf_file, new_file)
return True
return new_file
28 changes: 28 additions & 0 deletions pyedflib/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def setUpClass(cls):
cls.anonymized = os.path.join(data_dir, "tmp_anonymized.edf")
cls.personalized = os.path.join(data_dir, "tmp_personalized.edf")
cls.drop_from = os.path.join(data_dir, 'tmp_drop_from.edf')
cls.renamed = os.path.join(data_dir, 'tmp_renamed.edf')
cls.tmp_testfile = os.path.join(data_dir, 'tmp')

@classmethod
Expand Down Expand Up @@ -372,6 +373,33 @@ def test_annotation_bytestring(self):
highlevel.write_edf(self.edfplus_data_file, signals, signal_headers, header)
_,_,header3 = highlevel.read_edf(self.edfplus_data_file)
self.assertEqual(header2['annotations'], header3['annotations'])

def test_rename_channel(self):
signal_headers = highlevel.make_signal_headers(['ch'+str(i) for i in range(5)])
signals = np.random.rand(5, 256*300)*200 #5 minutes of eeg
signals = (signals - signals.min()) / (signals.max() - signals.min())
highlevel.write_edf(self.renamed, signals, signal_headers)

mapping = {'ch1':'channel1', 'ch2':'channel2'}
renamed = highlevel.rename_channels(self.renamed, mapping=mapping,
verbose=True)
signals2, signal_headers, header = highlevel.read_edf(renamed)
chs = highlevel.read_edf_header(renamed)['channels']
self.assertSetEqual(set(chs), set(['channel1', 'channel2', 'ch3', 'ch4', 'ch0']))

np.testing.assert_allclose(signals, signals2, atol=0.01)

other_edf = self.renamed[:-4]+'2.edf'
highlevel.rename_channels(self.renamed, new_file=other_edf, mapping=mapping)
signals2, signal_headers, header = highlevel.read_edf(other_edf)
chs = highlevel.read_edf_header(other_edf)['channels']
self.assertSetEqual(set(chs), set(['channel1', 'channel2', 'ch3', 'ch4', 'ch0']))

np.testing.assert_allclose(signals, signals2, atol=0.01)

with self.assertRaises(AssertionError):
highlevel.rename_channels(self.renamed, mapping={'doesnotexist':'test'})



if __name__ == '__main__':
Expand Down

0 comments on commit a34382e

Please sign in to comment.