Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
3dc5729
Test IBL extractors tests failing for PI update
alejoe91 Dec 29, 2025
d1a0532
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 6, 2026
33c6769
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 16, 2026
2c94bac
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 20, 2026
a412bd8
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Feb 2, 2026
504e19d
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Feb 12, 2026
cd09c19
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Feb 19, 2026
a40d073
Merge branch 'main' of github.com:alejoe91/spikeinterface
alejoe91 Feb 24, 2026
a1da327
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 2, 2026
ef19a8e
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 3, 2026
a098b51
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 6, 2026
61c317a
Fix OpenEphys tests
alejoe91 Mar 6, 2026
c9ff247
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 9, 2026
3520138
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 16, 2026
f61329d
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 16, 2026
d64ae6a
Merge branch 'main' of github.com:alejoe91/spikeinterface
alejoe91 Mar 16, 2026
aef197d
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 17, 2026
e82331b
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 20, 2026
710cb6f
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 23, 2026
c2f8db1
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Mar 23, 2026
f9de051
Centralize segment handling to BaseExtractors
alejoe91 Mar 23, 2026
34f6bab
add segments to base recording and sorting for typing
alejoe91 Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

from pathlib import Path
import shutil
from typing import Any
Expand Down Expand Up @@ -87,6 +85,8 @@ def __init__(self, main_ids: Sequence) -> None:
self._main_ids.dtype.kind in "uiSU"
), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}"

self._segments: "list[BaseSegment]" = []

# dict at object level
self._annotations = {}

Expand Down Expand Up @@ -142,11 +142,18 @@ def name(self, value):
# we remove the annotation if it exists
_ = self._annotations.pop("name", None)

@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe for typing we could in BaseRecording for example:

  @property
  def segments(self) -> list[BaseRecordingSegment]:
      return self._segments  # type: ignore[return-value]

As that will enable the analysis of base recording sgements methods that we call on introspection with vscode and ohter tools. Same for bas sorting.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, can we make this private? what is the reason for making this public?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding it! I think it's handy to have it public. The segments are a big part of the API, so exposing them (as a property) doesn't hurt IMO

def segments(self) -> "list[BaseSegment]":
return self._segments

def add_segment(self, segment: "BaseSegment") -> None:
self._segments.append(segment)
segment.set_parent_extractor(self)

def get_num_segments(self) -> int:
# This is implemented in BaseRecording or BaseSorting
raise NotImplementedError
return len(self._segments)

def get_parent(self) -> BaseExtractor | None:
def get_parent(self) -> "BaseExtractor | None":
"""Returns parent object if it exists, otherwise None"""
return getattr(self, "_parent", None)

Expand Down Expand Up @@ -381,7 +388,7 @@ def delete_property(self, key) -> None:

def copy_metadata(
self,
other: BaseExtractor,
other: "BaseExtractor",
only_main: bool = False,
ids: Iterable | slice | None = None,
skip_properties: Iterable[str] | None = None,
Expand Down Expand Up @@ -570,7 +577,7 @@ def to_dict(
return dump_dict

@staticmethod
def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> BaseExtractor:
def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> "BaseExtractor":
"""
Instantiate extractor from dictionary

Expand Down Expand Up @@ -624,7 +631,7 @@ def save_metadata_to_folder(self, folder_metadata):
values = self.get_property(key)
np.save(prop_folder / (key + ".npy"), values)

def clone(self) -> BaseExtractor:
def clone(self) -> "BaseExtractor":
"""
Clones an existing extractor into a new instance.
"""
Expand Down Expand Up @@ -816,7 +823,7 @@ def dump_to_pickle(
file_path.write_bytes(pickle.dumps(dump_dict))

@staticmethod
def load(file_or_folder_path: str | Path, base_folder: Path | str | bool | None = None) -> BaseExtractor:
def load(file_or_folder_path: str | Path, base_folder: Path | str | bool | None = None) -> "BaseExtractor":
"""
Load extractor from file path (.json or .pkl)

Expand All @@ -839,7 +846,7 @@ def __reduce__(self):
return (instance_constructor, intialization_args)

@staticmethod
def load_from_folder(folder) -> BaseExtractor:
def load_from_folder(folder) -> "BaseExtractor":
return BaseExtractor.load(folder)

def _save(self, folder, **save_kwargs):
Expand All @@ -855,7 +862,7 @@ def _extra_metadata_to_folder(self, folder):
# This implemented in BaseRecording for probe
pass

def save(self, **kwargs) -> BaseExtractor:
def save(self, **kwargs) -> "BaseExtractor":
"""
Save a SpikeInterface object.

Expand Down Expand Up @@ -891,7 +898,7 @@ def save(self, **kwargs) -> BaseExtractor:

save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc)

def save_to_memory(self, sharedmem=True, **save_kwargs) -> BaseExtractor:
def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor":
save_kwargs.pop("format", None)

cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs)
Expand Down Expand Up @@ -1092,7 +1099,7 @@ def save_to_zarr(
return cached


def _load_extractor_from_dict(dic) -> BaseExtractor:
def _load_extractor_from_dict(dic) -> "BaseExtractor":
"""
Convert a dictionary into an instance of BaseExtractor or its subclass.

Expand Down
65 changes: 27 additions & 38 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
import warnings
from typing import Literal
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -43,9 +43,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
BaseRecordingSnippets.__init__(
self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype
)

self._recording_segments: list[BaseRecordingSegment] = []

# initialize main annotation and properties
self.annotate(is_filtered=False)

Expand Down Expand Up @@ -171,28 +168,20 @@ def __sub__(self, other):

return SubtractRecordings(self, other)

def get_num_segments(self) -> int:
"""
Returns the number of segments.

Returns
-------
int
Number of segments in the recording
"""
return len(self._recording_segments)
@property
def segments(self) -> list["BaseRecordingSegment"]:
"""List of recording segments."""
return self._segments

def add_recording_segment(self, recording_segment):
def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> None:
"""Adds a recording segment.

Parameters
----------
recording_segment : BaseRecordingSegment
The recording segment to add
"""
# todo: check channel count and sampling frequency
self._recording_segments.append(recording_segment)
recording_segment.set_parent_extractor(self)
super().add_segment(recording_segment)

def get_num_samples(self, segment_index: int | None = None) -> int:
"""
Expand All @@ -211,7 +200,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int:
The number of samples
"""
segment_index = self._check_segment_index(segment_index)
return int(self._recording_segments[segment_index].get_num_samples())
return int(self.segments[segment_index].get_num_samples())

get_num_frames = get_num_samples

Expand Down Expand Up @@ -305,7 +294,7 @@ def get_traces(
start_frame: int | None = None,
end_frame: int | None = None,
channel_ids: list | np.ndarray | tuple | None = None,
order: "C" | "F" | None = None,
order: Literal["C", "F"] | None = None,
return_scaled: bool | None = None,
return_in_uV: bool = False,
) -> np.ndarray:
Expand Down Expand Up @@ -343,7 +332,7 @@ def get_traces(
"""
segment_index = self._check_segment_index(segment_index)
channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
start_frame = int(start_frame) if start_frame is not None else 0
num_samples = rs.get_num_samples()
end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples
Expand Down Expand Up @@ -401,7 +390,7 @@ def get_time_info(self, segment_index=None) -> dict:
"""

segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
time_kwargs = rs.get_times_kwargs()

return time_kwargs
Expand All @@ -425,7 +414,7 @@ def get_times(self, segment_index=None) -> np.ndarray:
The 1d times array
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
times = rs.get_times()
return times

Expand All @@ -443,7 +432,7 @@ def get_start_time(self, segment_index=None) -> float:
The start time in seconds
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
return rs.get_start_time()

def get_end_time(self, segment_index=None) -> float:
Expand All @@ -460,7 +449,7 @@ def get_end_time(self, segment_index=None) -> float:
The stop time in seconds
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
return rs.get_end_time()

def has_time_vector(self, segment_index: int | None = None):
Expand All @@ -477,7 +466,7 @@ def has_time_vector(self, segment_index: int | None = None):
True if the recording has time vectors, False otherwise
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
d = rs.get_times_kwargs()
return d["time_vector"] is not None

Expand All @@ -494,7 +483,7 @@ def set_times(self, times, segment_index=None, with_warning=True):
If True, a warning is printed
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]

assert times.ndim == 1, "Time must have ndim=1"
assert rs.get_num_samples() == times.shape[0], "times have wrong shape"
Expand All @@ -517,7 +506,7 @@ def reset_times(self):
segment's sampling frequency is set to the recording's sampling frequency.
"""
for segment_index in range(self.get_num_segments()):
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
if self.has_time_vector(segment_index):
rs.time_vector = None
rs.t_start = None
Expand Down Expand Up @@ -545,7 +534,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N
segments_to_shift = (segment_index,)

for segment_index in segments_to_shift:
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]

if self.has_time_vector(segment_index=segment_index):
rs.time_vector += shift
Expand All @@ -558,19 +547,19 @@ def sample_index_to_time(self, sample_ind, segment_index=None):
Transform sample index into time in seconds
"""
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
return rs.sample_index_to_time(sample_ind)

def time_to_sample_index(self, time_s, segment_index=None):
segment_index = self._check_segment_index(segment_index)
rs = self._recording_segments[segment_index]
rs = self.segments[segment_index]
return rs.time_to_sample_index(time_s)

def _get_t_starts(self):
# handle t_starts
t_starts = []
has_time_vectors = []
for rs in self._recording_segments:
for rs in self.segments:
d = rs.get_times_kwargs()
t_starts.append(d["t_start"])

Expand All @@ -580,7 +569,7 @@ def _get_t_starts(self):

def _get_time_vectors(self):
time_vectors = []
for rs in self._recording_segments:
for rs in self.segments:
d = rs.get_times_kwargs()
time_vectors.append(d["time_vector"])
if all(time_vector is None for time_vector in time_vectors):
Expand Down Expand Up @@ -668,7 +657,7 @@ def _extra_metadata_from_folder(self, folder):
self.set_probegroup(probegroup, in_place=True)

# load time vector if any
for segment_index, rs in enumerate(self._recording_segments):
for segment_index, rs in enumerate(self.segments):
time_file = folder / f"times_cached_seg{segment_index}.npy"
if time_file.is_file():
time_vector = np.load(time_file)
Expand All @@ -681,7 +670,7 @@ def _extra_metadata_to_folder(self, folder):
write_probeinterface(folder / "probe.json", probegroup)

# save time vector if any
for segment_index, rs in enumerate(self._recording_segments):
for segment_index, rs in enumerate(self.segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]
if time_vector is not None:
Expand Down Expand Up @@ -735,7 +724,7 @@ def _remove_channels(self, remove_channel_ids):
sub_recording = ChannelSliceRecording(self, new_channel_ids)
return sub_recording

def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRecording:
def frame_slice(self, start_frame: int | None, end_frame: int | None) -> "BaseRecording":
"""
Returns a new recording with sliced frames. Note that this operation is not in place.

Expand All @@ -757,7 +746,7 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRec
sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame)
return sub_recording

def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRecording:
def time_slice(self, start_time: float | None, end_time: float | None) -> "BaseRecording":
"""
Returns a new recording object, restricted to the time interval [start_time, end_time].

Expand Down Expand Up @@ -815,7 +804,7 @@ def _select_segments(self, segment_indices):
def get_channel_locations(
self,
channel_ids: list | np.ndarray | tuple | None = None,
axes: "xy" | "yz" | "xz" | "xyz" = "xy",
axes: Literal["xy", "yz", "xz", "xyz"] = "xy",
) -> np.ndarray:
"""
Get the physical locations of specified channels.
Expand Down
Loading
Loading