Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from probeinterface import ProbeGroup

from .base import minimum_spike_dtype
from .base import minimum_spike_dtype, _get_class_from_string
from .baserecording import BaseRecording, BaseRecordingSegment
from .basesorting import BaseSorting, SpikeVectorSortingSegment
from .core_tools import define_function_from_class, check_json
from .core_tools import define_function_from_class, check_json, retrieve_importing_provenance
from .job_tools import split_job_kwargs
from .core_tools import is_path_remote

Expand Down Expand Up @@ -212,6 +212,7 @@ def write_recording(
recording: BaseRecording, folder_path: str | Path, storage_options: dict | None = None, **kwargs
):
zarr_root = zarr.open(str(folder_path), mode="w", storage_options=storage_options)
zarr_root.attrs["zarr_class_info"] = retrieve_importing_provenance(ZarrRecordingExtractor)
add_recording_to_zarr_group(recording, zarr_root, **kwargs)


Expand Down Expand Up @@ -320,6 +321,7 @@ def write_sorting(sorting: BaseSorting, folder_path: str | Path, storage_options
Write a sorting extractor to zarr format.
"""
zarr_root = zarr.open(str(folder_path), mode="w", storage_options=storage_options)
zarr_root.attrs["zarr_class_info"] = retrieve_importing_provenance(ZarrSortingExtractor)
add_sorting_to_zarr_group(sorting, zarr_root, **kwargs)


Expand All @@ -345,15 +347,22 @@ def read_zarr(
extractor : ZarrExtractor
The loaded extractor
"""
# TODO @alessio : we should have something more explicit in our zarr format to tell which object it is.
# for the futur SortingAnalyzer we will have this 2 fields!!!
root = super_zarr_open(folder_path, mode="r", storage_options=storage_options)
if "channel_ids" in root.keys():
return read_zarr_recording(folder_path, storage_options=storage_options)
elif "unit_ids" in root.keys():
return read_zarr_sorting(folder_path, storage_options=storage_options)
zarr_class_info = root.attrs.get("zarr_class_info", None)
if zarr_class_info is not None:
class_name = zarr_class_info["class"]
extractor_class = _get_class_from_string(class_name)
return extractor_class(folder_path, storage_options=storage_options)
else:
raise ValueError("Cannot find 'channel_ids' or 'unit_ids' in zarr root. Not a valid SpikeInterface zarr format")
# For version<0.105.0 zarr files, revert to old way of loading based on the presence of "channel_ids"/"unit_ids"
if "channel_ids" in root.keys():
return read_zarr_recording(folder_path, storage_options=storage_options)
elif "unit_ids" in root.keys():
return read_zarr_sorting(folder_path, storage_options=storage_options)
else:
raise ValueError(
"Cannot find 'channel_ids' or 'unit_ids' in zarr root. Not a valid SpikeInterface zarr format"
)


### UTILITY FUNCTIONS ###
Expand Down
Loading