diff --git a/neo/core/spiketrainlist.py b/neo/core/spiketrainlist.py index 04ebfd1f8..b6c0940d8 100644 --- a/neo/core/spiketrainlist.py +++ b/neo/core/spiketrainlist.py @@ -115,6 +115,11 @@ def __getitem__(self, i): else: return SpikeTrainList(items=items) + def __setitem__(self, i, value): + if self._items is None: + self._spiketrains_from_array() + self._items[i] = value + def __str__(self): """Return str(self)""" if self._items is None: diff --git a/neo/io/neomatlabio.py b/neo/io/neomatlabio.py index f2cc00af4..756869efc 100644 --- a/neo/io/neomatlabio.py +++ b/neo/io/neomatlabio.py @@ -227,6 +227,7 @@ def read_block(self, lazy=False): bl_struct = d['block'] bl = self.create_ob_from_struct( bl_struct, 'Block') + self._resolve_references(bl) bl.check_relationships() return bl @@ -242,38 +243,23 @@ def write_block(self, bl, **kargs): seg_struct = self.create_struct_from_obj(seg) bl_struct['segments'].append(seg_struct) - for anasig in seg.analogsignals: - anasig_struct = self.create_struct_from_obj(anasig) - seg_struct['analogsignals'].append(anasig_struct) - - for irrsig in seg.irregularlysampledsignals: - irrsig_struct = self.create_struct_from_obj(irrsig) - seg_struct['irregularlysampledsignals'].append(irrsig_struct) - - for ea in seg.events: - ea_struct = self.create_struct_from_obj(ea) - seg_struct['events'].append(ea_struct) - - for ea in seg.epochs: - ea_struct = self.create_struct_from_obj(ea) - seg_struct['epochs'].append(ea_struct) - - for sptr in seg.spiketrains: - sptr_struct = self.create_struct_from_obj(sptr) - seg_struct['spiketrains'].append(sptr_struct) - - for image_sq in seg.imagesequences: - image_sq_structure = self.create_struct_from_obj(image_sq) - seg_struct['imagesequences'].append(image_sq_structure) + for container_name in seg._child_containers: + for child_obj in getattr(seg, container_name): + child_struct = self.create_struct_from_obj(child_obj) + seg_struct[container_name].append(child_struct) for group in bl.groups: group_structure = self.create_struct_from_obj(group) bl_struct['groups'].append(group_structure) + for container_name in group._child_containers: + for child_obj in getattr(group, container_name): + group_structure[container_name].append(id(child_obj)) + scipy.io.savemat(self.filename, {'block': bl_struct}, oned_as='row') def create_struct_from_obj(self, ob): - struct = {} + struct = {"neo_id": id(ob)} # relationship for childname in getattr(ob, '_child_containers', []): @@ -290,11 +276,6 @@ def create_struct_from_obj(self, ob): for i, attr in enumerate(all_attrs): attrname, attrtype = attr[0], attr[1] - # ~ if attrname =='': - # ~ struct['array'] = ob.magnitude - # ~ struct['units'] = ob.dimensionality.string - # ~ continue - if (hasattr(ob, '_quantity_attr') and ob._quantity_attr == attrname): struct[attrname] = ob.magnitude @@ -320,13 +301,6 @@ def create_struct_from_obj(self, ob): def create_ob_from_struct(self, struct, classname): cl = class_by_name[classname] - # check if inherits Quantity - # ~ is_quantity = False - # ~ for attr in cl._necessary_attrs: - # ~ if attr[0] == '' and attr[1] == pq.Quantity: - # ~ is_quantity = True - # ~ break - # ~ is_quantiy = hasattr(cl, '_quantity_attr') # ~ if is_quantity: if hasattr(cl, '_quantity_attr'): @@ -374,20 +348,27 @@ def create_ob_from_struct(self, struct, classname): # check children if attrname in getattr(ob, '_child_containers', []): child_struct = getattr(struct, attrname) + child_class_name = classname_lower_to_upper[attrname[:-1]] try: # try must only surround len() or other errors are captured child_len = len(child_struct) except TypeError: # strange scipy.io behavior: if len is 1 there is no len() - child = self.create_ob_from_struct( - child_struct, - classname_lower_to_upper[attrname[:-1]]) + if classname == "Group": + child = _Ref(child_struct, child_class_name) + else: + child = self.create_ob_from_struct( + child_struct, + child_class_name) getattr(ob, attrname.lower()).append(child) else: for c in range(child_len): - child = self.create_ob_from_struct( - child_struct[c], - classname_lower_to_upper[attrname[:-1]]) + if classname == "Group": + child = _Ref(child_struct[c], child_class_name) + else: + child = self.create_ob_from_struct( + child_struct[c], + child_class_name) getattr(ob, attrname.lower()).append(child) continue @@ -432,4 +413,31 @@ def create_ob_from_struct(self, struct, classname): setattr(ob, attrname, item) + neo_id = getattr(struct, "neo_id", None) + if neo_id: + setattr(ob, "_id", neo_id) return ob + + def _resolve_references(self, bl): + if bl.groups: + obj_lookup = {} + for ob in bl.children_recur: + if hasattr(ob, "_id"): + obj_lookup[ob._id] = ob + for grp in bl.groups: + for container_name in grp._child_containers: + container = getattr(grp, container_name) + for i, ref in enumerate(container): + assert isinstance(ref, _Ref) + container[i] = obj_lookup[ref.identifier] + + +class _Ref: + + def __init__(self, identifier, target_class_name): + self.identifier = identifier + self.target_cls = class_by_name[target_class_name] + + @property + def proxy_for(self): + return self.target_cls diff --git a/neo/test/iotest/test_neomatlabio.py b/neo/test/iotest/test_neomatlabio.py index 1c39a1c76..175755ab2 100644 --- a/neo/test/iotest/test_neomatlabio.py +++ b/neo/test/iotest/test_neomatlabio.py @@ -8,7 +8,7 @@ from neo.core.analogsignal import AnalogSignal from neo.core.irregularlysampledsignal import IrregularlySampledSignal -from neo import Block, Segment, SpikeTrain, ImageSequence +from neo import Block, Segment, SpikeTrain, ImageSequence, Group from neo.test.iotest.common_io_test import BaseTestIO from neo.io.neomatlabio import NeoMatlabIO @@ -26,7 +26,7 @@ class TestNeoMatlabIO(BaseTestIO, unittest.TestCase): files_to_download = [] def test_write_read_single_spike(self): - block1 = Block() + block1 = Block(name="test_neomatlabio") seg = Segment('segment1') spiketrain1 = SpikeTrain([1] * pq.s, t_stop=10 * pq.s, sampling_rate=1 * pq.Hz) spiketrain1.annotate(yep='yop') @@ -43,6 +43,8 @@ def test_write_read_single_spike(self): seg.irregularlysampledsignals.append(irrsig1) seg.imagesequences.append(image_sequence) + group1 = Group([spiketrain1, sig1]) + block1.groups.append(group1) # write block filename = self.get_local_path('matlabiotestfile.mat') @@ -72,6 +74,10 @@ def test_write_read_single_spike(self): assert 'yep' in spiketrain2.annotations assert spiketrain2.annotations['yep'] == 'yop' + # test group retrieval + group2 = block2.groups[0] + assert_array_equal(group1.analogsignals[0], group2.analogsignals[0]) + if __name__ == "__main__": unittest.main()