Skip to content

Commit

Permalink
andrew's improvements
Browse files Browse the repository at this point in the history
Implement support for saving/loading Groups in NeoMatlabIO.
  • Loading branch information
zm711 authored Jan 2, 2024
2 parents 8e7e67c + c51b2ce commit 5ef4f01
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 44 deletions.
5 changes: 5 additions & 0 deletions neo/core/spiketrainlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def __getitem__(self, i):
else:
return SpikeTrainList(items=items)

def __setitem__(self, i, value):
if self._items is None:
self._spiketrains_from_array()
self._items[i] = value

def __str__(self):
"""Return str(self)"""
if self._items is None:
Expand Down
92 changes: 50 additions & 42 deletions neo/io/neomatlabio.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def read_block(self, lazy=False):
bl_struct = d['block']
bl = self.create_ob_from_struct(
bl_struct, 'Block')
self._resolve_references(bl)
bl.check_relationships()
return bl

Expand All @@ -242,38 +243,23 @@ def write_block(self, bl, **kargs):
seg_struct = self.create_struct_from_obj(seg)
bl_struct['segments'].append(seg_struct)

for anasig in seg.analogsignals:
anasig_struct = self.create_struct_from_obj(anasig)
seg_struct['analogsignals'].append(anasig_struct)

for irrsig in seg.irregularlysampledsignals:
irrsig_struct = self.create_struct_from_obj(irrsig)
seg_struct['irregularlysampledsignals'].append(irrsig_struct)

for ea in seg.events:
ea_struct = self.create_struct_from_obj(ea)
seg_struct['events'].append(ea_struct)

for ea in seg.epochs:
ea_struct = self.create_struct_from_obj(ea)
seg_struct['epochs'].append(ea_struct)

for sptr in seg.spiketrains:
sptr_struct = self.create_struct_from_obj(sptr)
seg_struct['spiketrains'].append(sptr_struct)

for image_sq in seg.imagesequences:
image_sq_structure = self.create_struct_from_obj(image_sq)
seg_struct['imagesequences'].append(image_sq_structure)
for container_name in seg._child_containers:
for child_obj in getattr(seg, container_name):
child_struct = self.create_struct_from_obj(child_obj)
seg_struct[container_name].append(child_struct)

for group in bl.groups:
group_structure = self.create_struct_from_obj(group)
bl_struct['groups'].append(group_structure)

for container_name in group._child_containers:
for child_obj in getattr(group, container_name):
group_structure[container_name].append(id(child_obj))

scipy.io.savemat(self.filename, {'block': bl_struct}, oned_as='row')

def create_struct_from_obj(self, ob):
struct = {}
struct = {"neo_id": id(ob)}

# relationship
for childname in getattr(ob, '_child_containers', []):
Expand All @@ -290,11 +276,6 @@ def create_struct_from_obj(self, ob):
for i, attr in enumerate(all_attrs):
attrname, attrtype = attr[0], attr[1]

# ~ if attrname =='':
# ~ struct['array'] = ob.magnitude
# ~ struct['units'] = ob.dimensionality.string
# ~ continue

if (hasattr(ob, '_quantity_attr') and
ob._quantity_attr == attrname):
struct[attrname] = ob.magnitude
Expand All @@ -320,13 +301,6 @@ def create_struct_from_obj(self, ob):

def create_ob_from_struct(self, struct, classname):
cl = class_by_name[classname]
# check if inherits Quantity
# ~ is_quantity = False
# ~ for attr in cl._necessary_attrs:
# ~ if attr[0] == '' and attr[1] == pq.Quantity:
# ~ is_quantity = True
# ~ break
# ~ is_quantiy = hasattr(cl, '_quantity_attr')

# ~ if is_quantity:
if hasattr(cl, '_quantity_attr'):
Expand Down Expand Up @@ -374,20 +348,27 @@ def create_ob_from_struct(self, struct, classname):
# check children
if attrname in getattr(ob, '_child_containers', []):
child_struct = getattr(struct, attrname)
child_class_name = classname_lower_to_upper[attrname[:-1]]
try:
# try must only surround len() or other errors are captured
child_len = len(child_struct)
except TypeError:
# strange scipy.io behavior: if len is 1 there is no len()
child = self.create_ob_from_struct(
child_struct,
classname_lower_to_upper[attrname[:-1]])
if classname == "Group":
child = _Ref(child_struct, child_class_name)
else:
child = self.create_ob_from_struct(
child_struct,
child_class_name)
getattr(ob, attrname.lower()).append(child)
else:
for c in range(child_len):
child = self.create_ob_from_struct(
child_struct[c],
classname_lower_to_upper[attrname[:-1]])
if classname == "Group":
child = _Ref(child_struct[c], child_class_name)
else:
child = self.create_ob_from_struct(
child_struct[c],
child_class_name)
getattr(ob, attrname.lower()).append(child)
continue

Expand Down Expand Up @@ -432,4 +413,31 @@ def create_ob_from_struct(self, struct, classname):

setattr(ob, attrname, item)

neo_id = getattr(struct, "neo_id", None)
if neo_id:
setattr(ob, "_id", neo_id)
return ob

def _resolve_references(self, bl):
if bl.groups:
obj_lookup = {}
for ob in bl.children_recur:
if hasattr(ob, "_id"):
obj_lookup[ob._id] = ob
for grp in bl.groups:
for container_name in grp._child_containers:
container = getattr(grp, container_name)
for i, ref in enumerate(container):
assert isinstance(ref, _Ref)
container[i] = obj_lookup[ref.identifier]


class _Ref:

def __init__(self, identifier, target_class_name):
self.identifier = identifier
self.target_cls = class_by_name[target_class_name]

@property
def proxy_for(self):
return self.target_cls
10 changes: 8 additions & 2 deletions neo/test/iotest/test_neomatlabio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from neo.core.analogsignal import AnalogSignal
from neo.core.irregularlysampledsignal import IrregularlySampledSignal
from neo import Block, Segment, SpikeTrain, ImageSequence
from neo import Block, Segment, SpikeTrain, ImageSequence, Group
from neo.test.iotest.common_io_test import BaseTestIO
from neo.io.neomatlabio import NeoMatlabIO

Expand All @@ -26,7 +26,7 @@ class TestNeoMatlabIO(BaseTestIO, unittest.TestCase):
files_to_download = []

def test_write_read_single_spike(self):
block1 = Block()
block1 = Block(name="test_neomatlabio")
seg = Segment('segment1')
spiketrain1 = SpikeTrain([1] * pq.s, t_stop=10 * pq.s, sampling_rate=1 * pq.Hz)
spiketrain1.annotate(yep='yop')
Expand All @@ -43,6 +43,8 @@ def test_write_read_single_spike(self):
seg.irregularlysampledsignals.append(irrsig1)
seg.imagesequences.append(image_sequence)

group1 = Group([spiketrain1, sig1])
block1.groups.append(group1)

# write block
filename = self.get_local_path('matlabiotestfile.mat')
Expand Down Expand Up @@ -72,6 +74,10 @@ def test_write_read_single_spike(self):
assert 'yep' in spiketrain2.annotations
assert spiketrain2.annotations['yep'] == 'yop'

# test group retrieval
group2 = block2.groups[0]
assert_array_equal(group1.analogsignals[0], group2.analogsignals[0])


if __name__ == "__main__":
unittest.main()

0 comments on commit 5ef4f01

Please sign in to comment.