Skip to content

Commit

Permalink
Merge pull request #250 from h-mayorquin/add_save_to_zarr
Browse files Browse the repository at this point in the history
Add save to zarr
  • Loading branch information
samuelgarcia authored Feb 5, 2024
2 parents 4d15eda + fc09d62 commit d5fe93b
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 11 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ test = [
"scipy",
"pandas",
"h5py",
]
"zarr>=2.16.0"
]

docs = [
"pillow",
Expand Down
184 changes: 176 additions & 8 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations
import numpy as np
from typing import Optional

from pathlib import Path
import json

from .shank import Shank

Expand Down Expand Up @@ -197,16 +198,24 @@ def __repr__(self):
return self.get_title()

def annotate(self, **kwargs):
"""Annotates the probe object.
"""
Annotates the probe object.
Parameter
---------
**kwargs : list of keyword arguments to add to the annotations
Parameters
----------
**kwargs : list of keyword arguments to add to the annotations (e.g., brain_area="CA1")
"""
self.annotations.update(kwargs)
self.check_annotations()

def annotate_contacts(self, **kwargs):
"""
Annotates the contacts of the probe.
Parameters
----------
**kwargs : list of keyword arguments to add to the annotations (e.g., quality=["good", "bad", ...])
"""
n = self.get_contact_count()
for k, values in kwargs.items():
assert len(values) == n, (
Expand Down Expand Up @@ -506,6 +515,13 @@ def __eq__(self, other):
if not np.array_equal(self.contact_annotations[key], other.contact_annotations[key]):
return False

# planar contour
if self.probe_planar_contour is not None:
if other.probe_planar_contour is None:
return False
if not np.array_equal(self.probe_planar_contour, other.probe_planar_contour):
return False

return True

def copy(self):
Expand Down Expand Up @@ -862,7 +878,7 @@ def to_numpy(self, complete: bool = False) -> np.array:
dtype += [(f"plane_axis_{dim}_0", "float64")]
dtype += [(f"plane_axis_{dim}_1", "float64")]
for k, v in self.contact_annotations.items():
dtype += [(f"{k}", np.dtype(v[0]))]
dtype += [(f"{k}", np.array(v, copy=False).dtype)]

arr = np.zeros(self.get_contact_count(), dtype=dtype)
arr["x"] = self.contact_positions[:, 0]
Expand Down Expand Up @@ -916,8 +932,28 @@ def from_numpy(arr: np.ndarray) -> "Probe":
probe : Probe
The instantiated Probe object
"""

fields = list(arr.dtype.fields)
main_fields = [
"x",
"y",
"z",
"contact_shapes",
"shank_ids",
"contact_ids",
"device_channel_indices",
"radius",
"width",
"height",
"plane_axis_x_0",
"plane_axis_x_1",
"plane_axis_y_0",
"plane_axis_y_1",
"plane_axis_z_0",
"plane_axis_z_1",
"probe_index",
"si_units",
]
contact_annotation_fields = [f for f in fields if f not in main_fields]

if "z" in fields:
ndim = 3
Expand Down Expand Up @@ -964,14 +1000,146 @@ def from_numpy(arr: np.ndarray) -> "Probe":

if "device_channel_indices" in fields:
dev_channel_indices = arr["device_channel_indices"]
probe.set_device_channel_indices(dev_channel_indices)
if not np.all(dev_channel_indices == -1):
probe.set_device_channel_indices(dev_channel_indices)
if "shank_ids" in fields:
probe.set_shank_ids(arr["shank_ids"])
if "contact_ids" in fields:
probe.set_contact_ids(arr["contact_ids"])

# contact annotations
for k in contact_annotation_fields:
probe.annotate_contacts(**{k: arr[k]})
return probe

def add_probe_to_zarr_group(self, group: "zarr.Group") -> None:
"""
Serialize the probe's data and structure to a specified Zarr group.
This method is used to save the probe's attributes, annotations, and other
related data into a Zarr group, facilitating integration into larger Zarr
structures.
Parameters
----------
group : zarr.Group
The target Zarr group where the probe's data will be stored.
"""
probe_arr = self.to_numpy(complete=True)

# add fields and contact annotations
for field_name, (dtype, offset) in probe_arr.dtype.fields.items():
data = probe_arr[field_name]
group.create_dataset(name=field_name, data=data, dtype=dtype, chunks=False)

# Annotations as a group (special attibutes are stored as annotations)
annotations_group = group.create_group("annotations")
for key, value in self.annotations.items():
annotations_group.attrs[key] = value

# Add planar contour
if self.probe_planar_contour is not None:
group.create_dataset(
name="probe_planar_contour", data=self.probe_planar_contour, dtype="float64", chunks=False
)

def to_zarr(self, folder_path: str | Path) -> None:
"""
Serialize the Probe object to a Zarr file located at the specified folder path.
This method initializes a new Zarr group at the given folder path and calls
`add_probe_to_zarr_group` to serialize the Probe's data into this group, effectively
storing the entire Probe's state in a Zarr archive.
Parameters
----------
folder_path : str | Path
The path to the folder where the Zarr data structure will be created and
where the serialized data will be stored. If the folder does not exist,
it will be created.
"""
import zarr

# Create or open a Zarr group for writing
zarr_group = zarr.open_group(folder_path, mode="w")

# Serialize this Probe object into the Zarr group
self.add_probe_to_zarr_group(zarr_group)

@staticmethod
def from_zarr_group(group: zarr.Group) -> "Probe":
"""
Load a probe instance from a given Zarr group.
Parameters
----------
group : zarr.Group
The Zarr group from which to load the probe.
Returns
-------
Probe
An instance of the Probe class initialized with data from the Zarr group.
"""
import zarr

dtype = []
# load all datasets
num_contacts = None
probe_arr_keys = []
for key in group.keys():
if key == "probe_planar_contour":
continue
if key == "annotations":
continue
dset = group[key]
if isinstance(dset, zarr.Array):
probe_arr_keys.append(key)
dtype.append((key, dset.dtype))
if num_contacts is None:
num_contacts = len(dset)

# Create a structured array from the datasets
probe_arr = np.zeros(num_contacts, dtype=dtype)

for probe_key in probe_arr_keys:
probe_arr[probe_key] = group[probe_key][:]

# Create a Probe instance from the structured array
probe = Probe.from_numpy(probe_arr)

# Load annotations
annotations_group = group.get("annotations", None)
for key in annotations_group.attrs.keys():
# Use the annotate method for each key-value pair
probe.annotate(**{key: annotations_group.attrs[key]})

if "probe_planar_contour" in group:
# Directly assign since there's no specific setter for probe_planar_contour
probe.probe_planar_contour = group["probe_planar_contour"][:]

return probe

@staticmethod
def from_zarr(folder_path: str | Path) -> "Probe":
"""
Deserialize the Probe object from a Zarr file located at the given folder path.
Parameters
----------
folder_path : str | Path
The path to the folder where the Zarr file is located.
Returns
-------
Probe
An instance of the Probe class initialized with data from the Zarr file.
"""
import zarr

zarr_group = zarr.open(folder_path, mode="r")
return Probe.from_zarr_group(zarr_group)

def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame":
"""
Export the probe to a pandas dataframe
Expand Down
24 changes: 22 additions & 2 deletions tests/test_probe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from probeinterface import Probe
from probeinterface.generator import generate_dummy_probe
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -148,7 +149,7 @@ def test_probe_equality_dunder():

# Modify probe2
probe2.move([1, 1])
assert probe2 != probe1
assert probe1 != probe2


def test_set_shanks():
Expand All @@ -162,7 +163,26 @@ def test_set_shanks():
assert all(probe.shank_ids == shank_ids.astype(str))


def test_save_to_zarr(tmp_path):
# Generate a dummy probe instance
probe = generate_dummy_probe()

# Define file path in the temporary directory
folder_path = Path(tmp_path) / "probe.zarr"

# Save the probe object to Zarr format
probe.to_zarr(folder_path=folder_path)

# Reload the probe object from the saved Zarr file
reloaded_probe = Probe.from_zarr(folder_path=folder_path)

# Assert that the reloaded probe is equal to the original
assert probe == reloaded_probe, "Reloaded Probe object does not match the original"


if __name__ == "__main__":
test_probe()

test_set_shanks()
tmp_path = Path("tmp")
tmp_path.mkdir(exist_ok=True)
test_save_to_zarr(tmp_path)

0 comments on commit d5fe93b

Please sign in to comment.