diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9dc270d38d..8d149a7c49 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from pathlib import Path import shutil from typing import Any @@ -87,6 +85,8 @@ def __init__(self, main_ids: Sequence) -> None: self._main_ids.dtype.kind in "uiSU" ), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}" + self._segments: "list[BaseSegment]" = [] + # dict at object level self._annotations = {} @@ -142,11 +142,18 @@ def name(self, value): # we remove the annotation if it exists _ = self._annotations.pop("name", None) + @property + def segments(self) -> "list[BaseSegment]": + return self._segments + + def add_segment(self, segment: "BaseSegment") -> None: + self._segments.append(segment) + segment.set_parent_extractor(self) + def get_num_segments(self) -> int: - # This is implemented in BaseRecording or BaseSorting - raise NotImplementedError + return len(self._segments) - def get_parent(self) -> BaseExtractor | None: + def get_parent(self) -> "BaseExtractor | None": """Returns parent object if it exists, otherwise None""" return getattr(self, "_parent", None) @@ -381,7 +388,7 @@ def delete_property(self, key) -> None: def copy_metadata( self, - other: BaseExtractor, + other: "BaseExtractor", only_main: bool = False, ids: Iterable | slice | None = None, skip_properties: Iterable[str] | None = None, @@ -570,7 +577,7 @@ def to_dict( return dump_dict @staticmethod - def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> BaseExtractor: + def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> "BaseExtractor": """ Instantiate extractor from dictionary @@ -624,7 +631,7 @@ def save_metadata_to_folder(self, folder_metadata): values = self.get_property(key) np.save(prop_folder / (key + ".npy"), values) - def clone(self) -> BaseExtractor: + def clone(self) -> "BaseExtractor": """ Clones an existing extractor into a new instance. """ @@ -816,7 +823,7 @@ def dump_to_pickle( file_path.write_bytes(pickle.dumps(dump_dict)) @staticmethod - def load(file_or_folder_path: str | Path, base_folder: Path | str | bool | None = None) -> BaseExtractor: + def load(file_or_folder_path: str | Path, base_folder: Path | str | bool | None = None) -> "BaseExtractor": """ Load extractor from file path (.json or .pkl) @@ -839,7 +846,7 @@ def __reduce__(self): return (instance_constructor, intialization_args) @staticmethod - def load_from_folder(folder) -> BaseExtractor: + def load_from_folder(folder) -> "BaseExtractor": return BaseExtractor.load(folder) def _save(self, folder, **save_kwargs): @@ -855,7 +862,7 @@ def _extra_metadata_to_folder(self, folder): # This implemented in BaseRecording for probe pass - def save(self, **kwargs) -> BaseExtractor: + def save(self, **kwargs) -> "BaseExtractor": """ Save a SpikeInterface object. 
@@ -891,7 +898,7 @@ def save(self, **kwargs) -> BaseExtractor: save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc) - def save_to_memory(self, sharedmem=True, **save_kwargs) -> BaseExtractor: + def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor": save_kwargs.pop("format", None) cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs) @@ -1092,7 +1099,7 @@ def save_to_zarr( return cached -def _load_extractor_from_dict(dic) -> BaseExtractor: +def _load_extractor_from_dict(dic) -> "BaseExtractor": """ Convert a dictionary into an instance of BaseExtractor or its subclass. diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 75bd47597b..f23b524271 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -1,5 +1,5 @@ -from __future__ import annotations import warnings +from typing import Literal from pathlib import Path import numpy as np @@ -43,9 +43,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype): BaseRecordingSnippets.__init__( self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype ) - - self._recording_segments: list[BaseRecordingSegment] = [] - # initialize main annotation and properties self.annotate(is_filtered=False) @@ -171,18 +168,12 @@ def __sub__(self, other): return SubtractRecordings(self, other) - def get_num_segments(self) -> int: - """ - Returns the number of segments. - - Returns - ------- - int - Number of segments in the recording - """ - return len(self._recording_segments) + @property + def segments(self) -> list["BaseRecordingSegment"]: + """List of recording segments.""" + return self._segments - def add_recording_segment(self, recording_segment): + def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> None: """Adds a recording segment. 
Parameters @@ -190,9 +181,7 @@ def add_recording_segment(self, recording_segment): recording_segment : BaseRecordingSegment The recording segment to add """ - # todo: check channel count and sampling frequency - self._recording_segments.append(recording_segment) - recording_segment.set_parent_extractor(self) + super().add_segment(recording_segment) def get_num_samples(self, segment_index: int | None = None) -> int: """ @@ -211,7 +200,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int: The number of samples """ segment_index = self._check_segment_index(segment_index) - return int(self._recording_segments[segment_index].get_num_samples()) + return int(self.segments[segment_index].get_num_samples()) get_num_frames = get_num_samples @@ -305,7 +294,7 @@ def get_traces( start_frame: int | None = None, end_frame: int | None = None, channel_ids: list | np.ndarray | tuple | None = None, - order: "C" | "F" | None = None, + order: Literal["C", "F"] | None = None, return_scaled: bool | None = None, return_in_uV: bool = False, ) -> np.ndarray: @@ -343,7 +332,7 @@ def get_traces( """ segment_index = self._check_segment_index(segment_index) channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] start_frame = int(start_frame) if start_frame is not None else 0 num_samples = rs.get_num_samples() end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples @@ -401,7 +390,7 @@ def get_time_info(self, segment_index=None) -> dict: """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] time_kwargs = rs.get_times_kwargs() return time_kwargs @@ -425,7 +414,7 @@ def get_times(self, segment_index=None) -> np.ndarray: The 1d times array """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] times = rs.get_times() return times @@ -443,7 +432,7 @@ def get_start_time(self, segment_index=None) -> float: The start time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.get_start_time() def get_end_time(self, segment_index=None) -> float: @@ -460,7 +449,7 @@ def get_end_time(self, segment_index=None) -> float: The stop time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.get_end_time() def has_time_vector(self, segment_index: int | None = None): @@ -477,7 +466,7 @@ def has_time_vector(self, segment_index: int | None = None): True if the recording has time vectors, False otherwise """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] d = rs.get_times_kwargs() return d["time_vector"] is not None @@ -494,7 +483,7 @@ def set_times(self, times, segment_index=None, with_warning=True): If True, a warning is printed """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] assert times.ndim == 1, "Time must have ndim=1" assert rs.get_num_samples() == times.shape[0], "times have wrong shape" @@ -517,7 +506,7 @@ def reset_times(self): segment's sampling frequency is set to the recording's sampling frequency. 
""" for segment_index in range(self.get_num_segments()): - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] if self.has_time_vector(segment_index): rs.time_vector = None rs.t_start = None @@ -545,7 +534,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N segments_to_shift = (segment_index,) for segment_index in segments_to_shift: - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] if self.has_time_vector(segment_index=segment_index): rs.time_vector += shift @@ -558,19 +547,19 @@ def sample_index_to_time(self, sample_ind, segment_index=None): Transform sample index into time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.sample_index_to_time(sample_ind) def time_to_sample_index(self, time_s, segment_index=None): segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.time_to_sample_index(time_s) def _get_t_starts(self): # handle t_starts t_starts = [] has_time_vectors = [] - for rs in self._recording_segments: + for rs in self.segments: d = rs.get_times_kwargs() t_starts.append(d["t_start"]) @@ -580,7 +569,7 @@ def _get_t_starts(self): def _get_time_vectors(self): time_vectors = [] - for rs in self._recording_segments: + for rs in self.segments: d = rs.get_times_kwargs() time_vectors.append(d["time_vector"]) if all(time_vector is None for time_vector in time_vectors): @@ -668,7 +657,7 @@ def _extra_metadata_from_folder(self, folder): self.set_probegroup(probegroup, in_place=True) # load time vector if any - for segment_index, rs in enumerate(self._recording_segments): + for segment_index, rs in enumerate(self.segments): time_file = folder / f"times_cached_seg{segment_index}.npy" if time_file.is_file(): time_vector = np.load(time_file) @@ -681,7 +670,7 @@ def _extra_metadata_to_folder(self, folder): write_probeinterface(folder / "probe.json", probegroup) # save time vector if any - for segment_index, rs in enumerate(self._recording_segments): + for segment_index, rs in enumerate(self.segments): d = rs.get_times_kwargs() time_vector = d["time_vector"] if time_vector is not None: @@ -735,7 +724,7 @@ def _remove_channels(self, remove_channel_ids): sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording - def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRecording: + def frame_slice(self, start_frame: int | None, end_frame: int | None) -> "BaseRecording": """ Returns a new recording with sliced frames. Note that this operation is not in place. @@ -757,7 +746,7 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRec sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame) return sub_recording - def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRecording: + def time_slice(self, start_time: float | None, end_time: float | None) -> "BaseRecording": """ Returns a new recording object, restricted to the time interval [start_time, end_time]. @@ -815,7 +804,7 @@ def _select_segments(self, segment_indices): def get_channel_locations( self, channel_ids: list | np.ndarray | tuple | None = None, - axes: "xy" | "yz" | "xz" | "xyz" = "xy", + axes: Literal["xy", "yz", "xz", "xyz"] = "xy", ) -> np.ndarray: """ Get the physical locations of specified channels. 
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index d49065e28d..cb68f3d455 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -1,4 +1,3 @@ -from __future__ import annotations import warnings from copy import deepcopy @@ -17,7 +16,6 @@ class BaseSorting(BaseExtractor): def __init__(self, sampling_frequency: float, unit_ids: list): BaseExtractor.__init__(self, unit_ids) self._sampling_frequency = float(sampling_frequency) -        self._sorting_segments: list[BaseSortingSegment] = [] # this weak link is to handle times from a recording object self._recording = None self._sorting_info = None @@ -62,6 +60,11 @@ def _repr_html_(self, display_name=True): html_repr = html_header + html_unit_ids + html_extra return html_repr +    @property +    def segments(self) -> list["BaseSortingSegment"]: +        """List of sorting segments.""" +        return self._segments + @property def unit_ids(self): return self._main_ids @@ -76,16 +79,12 @@ def get_unit_ids(self) -> list: def get_num_units(self) -> int: return len(self.get_unit_ids()) -    def add_sorting_segment(self, sorting_segment): -        self._sorting_segments.append(sorting_segment) -        sorting_segment.set_parent_extractor(self) +    def add_sorting_segment(self, sorting_segment: "BaseSortingSegment") -> None: +        super().add_segment(sorting_segment) def get_sampling_frequency(self) -> float: return self._sampling_frequency -    def get_num_segments(self) -> int: -        return len(self._sorting_segments) - def get_num_samples(self, segment_index=None) -> int: """Returns the number of samples of the associated recording for a segment. @@ -200,7 +199,7 @@ def get_unit_spike_train( end = np.searchsorted(spike_frames, end_frame) spike_frames = spike_frames[:end] else: -            segment = self._sorting_segments[segment_index] +            segment = self.segments[segment_index] spike_frames = segment.get_unit_spike_train( unit_id=unit_id, start_frame=start_frame, end_frame=end_frame ).astype("int64") @@ -244,7 +243,7 @@ def get_unit_spike_train_in_seconds( Spike times in seconds """ segment_index = self._check_segment_index(segment_index) -        segment = self._sorting_segments[segment_index] +        segment = self.segments[segment_index] # If sorting has a registered recording, get the frames and get the times from the recording # Note that this takes into account the segment start time of the recording @@ -497,7 +496,7 @@ def count_total_num_spikes(self) -> int: """ return self.to_spike_vector().size -    def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting: +    def select_units(self, unit_ids, renamed_unit_ids=None) -> "BaseSorting": """ Returns a new sorting object which contains only a selected subset of units. @@ -519,7 +518,7 @@ def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting: sub_sorting = UnitsSelectionSorting(self, unit_ids, renamed_unit_ids=renamed_unit_ids) return sub_sorting -    def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting: +    def rename_units(self, new_unit_ids: np.ndarray | list) -> "BaseSorting": """ Returns a new sorting object with renamed units. @@ -540,7 +539,7 @@ def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting: sub_sorting = UnitsSelectionSorting(self, renamed_unit_ids=new_unit_ids) return sub_sorting -    def remove_units(self, remove_unit_ids) -> BaseSorting: +    def remove_units(self, remove_unit_ids) -> "BaseSorting": """ Returns a new sorting object which contains only a selected subset of units. 
@@ -613,7 +612,7 @@ def frame_slice(self, start_frame, end_frame, check_spike_frames=True): ) return sub_sorting - def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSorting: + def time_slice(self, start_time: float | None, end_time: float | None) -> "BaseSorting": """ Returns a new sorting object, restricted to the time interval [start_time, end_time]. @@ -705,7 +704,7 @@ def time_to_sample_index(self, time, segment_index=0): if self.has_recording(): sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index) else: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] t_start = segment._t_start if segment._t_start is not None else 0 sample_index = round((time - t_start) * self.get_sampling_frequency()) @@ -721,7 +720,7 @@ def sample_index_to_time( if self.has_recording(): return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) else: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] t_start = segment._t_start if segment._t_start is not None else 0 return (sample_index / self.get_sampling_frequency()) + t_start @@ -754,7 +753,7 @@ def _compute_and_cache_spike_vector(self) -> None: sample_indices = [] unit_indices = [] for u, unit_id in enumerate(self.unit_ids): - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] spike_frames = segment.get_unit_spike_train(unit_id=unit_id, start_frame=None, end_frame=None).astype( "int64" ) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 62de7e8fde..b3eaa099ed 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -159,11 +159,10 @@ def __del__(self): Closes any open file handles in the recording segments. 
""" # Close all recording segments - if hasattr(self, "_recording_segments"): - for segment in self._recording_segments: - # This will trigger the __del__ method of the BinaryRecordingSegment - # which will close the file handle - del segment + for segment in self.segments: + # This will trigger the __del__ method of the BinaryRecordingSegment + # which will close the file handle + del segment BinaryRecordingExtractor.write_recording.__doc__ = BinaryRecordingExtractor.write_recording.__doc__.format( diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 0da4797440..697aab875e 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -127,7 +127,7 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record ch_id += 1 for i_seg in range(num_segments): - parent_segments = [rec._recording_segments[i_seg] for rec in recording_list] + parent_segments = [rec.segments[i_seg] for rec in recording_list] sub_segment = ChannelsAggregationRecordingSegment(channel_map, parent_segments) self.add_recording_segment(sub_segment) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 67d25b2925..de693d5c26 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -53,7 +53,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) self._parent_channel_indices = parent_recording.ids_to_indices(self._channel_ids) # link recording segment - for parent_segment in parent_recording._recording_segments: + for parent_segment in parent_recording.segments: sub_segment = ChannelSliceRecordingSegment(parent_segment, self._parent_channel_indices) self.add_recording_segment(sub_segment) diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 5cc4daa7ed..513c8b3dfb 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -46,7 +46,7 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None): ) # link recording segment - parent_segment = parent_recording._recording_segments[0] + parent_segment = parent_recording.segments[0] sub_segment = FrameSliceRecordingSegment(parent_segment, start_frame=int(start_frame), end_frame=int(end_frame)) self.add_recording_segment(sub_segment) diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 0d1f307d2e..a337e83707 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -75,7 +75,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike BaseSorting.__init__(self, sampling_frequency=parent_sorting.get_sampling_frequency(), unit_ids=unit_ids) # link sorting segment - parent_segment = parent_sorting._sorting_segments[0] + parent_segment = parent_sorting.segments[0] sub_segment = FrameSliceSortingSegment(parent_segment, start_frame, end_frame) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9b40a23dbd..35116a9e4c 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1993,9 +1993,7 @@ def __init__( amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = 
upsample_vector[start:end] if upsample_vector is not None else None - parent_recording_segment = ( - None if parent_recording is None else parent_recording._recording_segments[segment_index] - ) + parent_recording_segment = None if parent_recording is None else parent_recording.segments[segment_index] recording_segment = InjectTemplatesRecordingSegment( self.sampling_frequency, self.dtype, diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 2c38248c1a..43cdd30c87 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -657,7 +657,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c nodes = worker_ctx["nodes"] skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] - recording_segment = recording._recording_segments[segment_index] + recording_segment = recording.segments[segment_index] retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever)) # get peak slices once for all retrievers peak_slice_by_retriever = {} diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 1200612864..31a3a8831d 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -198,7 +198,7 @@ def __init__( } def __del__(self): - self._recording_segments = [] + self._segments = [] for shm in self.shms: shm.close() if self.main_shm_owner: diff --git a/src/spikeinterface/core/operatorrecordings.py b/src/spikeinterface/core/operatorrecordings.py index 6ffb7d9fa3..63332bffa1 100644 --- a/src/spikeinterface/core/operatorrecordings.py +++ b/src/spikeinterface/core/operatorrecordings.py @@ -25,7 +25,7 @@ def __init__(self, recording1, recording2, operator: str): BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) - for segment1, segment2 in zip(recording1._recording_segments, recording2._recording_segments): + for segment1, segment2 in zip(recording1.segments, recording2.segments): add_segment = OperatorRecordingSegment(segment1, segment2, operator) self.add_recording_segment(add_segment) @@ -35,8 +35,8 @@ def are_times_kwargs_compatible(self, recording1, recording2) -> bool: import numpy as np for segment_index in range(recording1.get_num_segments()): - time_kwargs1 = recording1._recording_segments[segment_index].get_times_kwargs() - time_kwargs2 = recording2._recording_segments[segment_index].get_times_kwargs() + time_kwargs1 = recording1.segments[segment_index].get_times_kwargs() + time_kwargs2 = recording2.segments[segment_index].get_times_kwargs() for key in time_kwargs1.keys(): val1 = time_kwargs1[key] val2 = time_kwargs2[key] diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 6b563ff1d7..3d99fd23c4 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -63,7 +63,7 @@ def __init__(self, recording_list, sampling_frequency_max_diff=0): rec0.copy_metadata(self) for rec in recording_list: - for parent_segment in rec._recording_segments: + for parent_segment in rec.segments: rec_seg = ProxyAppendRecordingSegment(parent_segment) self.add_recording_segment(rec_seg) @@ -119,7 +119,7 @@ def __init__(self, recording_list, ignore_times=True, sampling_frequency_max_dif parent_segments = [] for rec in recording_list: - for parent_segment in rec._recording_segments: + for parent_segment in rec.segments: time_kwargs = parent_segment.get_times_kwargs() if not 
ignore_times: assert time_kwargs["time_vector"] is None, ( @@ -240,7 +240,7 @@ def __init__(self, recording: BaseRecording, segment_indices: int | list[int]): ), f"'segment_index' must be between 0 and {num_segments - 1}" for segment_index in segment_indices: - rec_seg = recording._recording_segments[segment_index] + rec_seg = recording.segments[segment_index] self.add_recording_segment(rec_seg) self._parent = recording @@ -302,7 +302,7 @@ def __init__(self, sorting_list, sampling_frequency_max_diff=0): sorting0.copy_metadata(self) for sorting in sorting_list: - for parent_segment in sorting._sorting_segments: + for parent_segment in sorting.segments: sorting_seg = ProxyAppendSortingSegment(parent_segment) self.add_sorting_segment(sorting_seg) @@ -384,7 +384,7 @@ def __init__(self, sorting_list, total_samples_list=None, ignore_times=True, sam parent_segments = [] parent_num_samples = [] for sorting_i, sorting in enumerate(sorting_list): - for segment_i, parent_segment in enumerate(sorting._sorting_segments): + for segment_i, parent_segment in enumerate(sorting.segments): # Check t_start is not assigned segment_t_start = parent_segment._t_start if not ignore_times: @@ -438,7 +438,7 @@ def __init__(self, sorting_list, total_samples_list=None, ignore_times=True, sam def get_num_samples(self, segment_index=None): """Overrides the BaseSorting method, which requires a recording.""" segment_index = self._check_segment_index(segment_index) - n_samples = self._sorting_segments[segment_index].get_num_samples() + n_samples = self.segments[segment_index].get_num_samples() if self.has_recording(): # Sanity check assert n_samples == self._recording.get_num_samples(segment_index) return n_samples @@ -554,7 +554,7 @@ def __init__(self, parent_sorting: BaseSorting, recording_or_recording_list=None num_samples = [0] for recording in recording_list: - for recording_segment in recording._recording_segments: + for recording_segment in recording.segments: num_samples.append(recording_segment.get_num_samples()) cumsum_num_samples = np.cumsum(num_samples) @@ -562,7 +562,7 @@ def __init__(self, parent_sorting: BaseSorting, recording_or_recording_list=None sliced_parent_sorting = parent_sorting.frame_slice( start_frame=cumsum_num_samples[idx], end_frame=cumsum_num_samples[idx + 1] ) - sliced_segment = sliced_parent_sorting._sorting_segments[0] + sliced_segment = sliced_parent_sorting.segments[0] self.add_sorting_segment(sliced_segment) self._parent = parent_sorting @@ -597,7 +597,7 @@ def __init__(self, sorting: BaseSorting, segment_indices: int | list[int]): ), f"'segment_index' must be between 0 and {num_segments - 1}" for segment_index in segment_indices: - sort_seg = sorting._sorting_segments[segment_index] + sort_seg = sorting.segments[segment_index] self.add_sorting_segment(sort_seg) self._kwargs = {"sorting": sorting, "segment_indices": [int(s) for s in segment_indices]} diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 91eb7df864..05963520cd 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Literal import numpy as np @@ -619,13 +619,15 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, - method: "radius" | "best_channels" | "closest_channels" | "snr" | "amplitude" | "energy" | "by_property" = "radius", - peak_sign: "neg" | "pos" | "both" = 
"neg", + method: Literal[ + "radius", "best_channels", "closest_channels", "snr", "amplitude", "energy", "by_property" + ] = "radius", + peak_sign: Literal["neg", "pos", "both"] = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, - amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + amplitude_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", ) -> ChannelSparsity: """ Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. @@ -718,12 +720,12 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "closest_channels" | "amplitude" | "snr" | "by_property" = "radius", - peak_sign: "neg" | "pos" | "both" = "neg", + method: Literal["radius", "best_channels", "closest_channels", "amplitude", "snr", "by_property"] = "radius", + peak_sign: Literal["neg", "pos", "both"] = "neg", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, - amplitude_mode: "extremum" | "peak_to_peak" = "extremum", + amplitude_mode: Literal["extremum", "peak_to_peak"] = "extremum", by_property: str | None = None, noise_levels: np.ndarray | list | None = None, **job_kwargs, diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 3e5a517b0a..67ba1179b0 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -1,4 +1,3 @@ -from __future__ import annotations import numpy as np import json from dataclasses import dataclass, field, astuple, replace @@ -140,7 +139,7 @@ def __repr__(self): return repr_str - def select_units(self, unit_ids) -> Templates: + def select_units(self, unit_ids) -> "Templates": """ Return a new Templates object with only the selected units. @@ -162,7 +161,7 @@ def select_units(self, unit_ids) -> Templates: check_for_consistent_sparsity=False, ) - def select_channels(self, channel_ids) -> Templates: + def select_channels(self, channel_ids) -> "Templates": """ Return a new Templates object with only the selected channels. 
This operation can be useful to remove bad channels for hybrid recording diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 0293c23876..b6a6b90bb2 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Literal import numpy as np @@ -62,8 +62,8 @@ def _get_nbefore(one_object): def get_template_amplitudes( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_sign: Literal["neg", "pos", "both"] = "neg", + mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", return_in_uV: bool = True, abs_value: bool = True, operator: str = "average", @@ -135,9 +135,9 @@ def get_template_amplitudes( def get_template_extremum_channel( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", - outputs: "id" | "index" = "id", + peak_sign: Literal["neg", "pos", "both"] = "neg", + mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + outputs: Literal["id", "index"] = "id", operator: str = "average", ): """ @@ -202,7 +202,9 @@ def get_template_extremum_channel( def get_template_extremum_channel_peak_shift( - templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", operator: str = "average" + templates_or_sorting_analyzer, + peak_sign: Literal["neg", "pos", "both"] = "neg", + operator: Literal["average", "median"] = "average", ): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. @@ -228,7 +230,9 @@ def get_template_extremum_channel_peak_shift( channel_ids = templates_or_sorting_analyzer.channel_ids nbefore = _get_nbefore(templates_or_sorting_analyzer) - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign) + extremum_channels_ids = get_template_extremum_channel( + templates_or_sorting_analyzer, peak_sign=peak_sign, operator=operator + ) shifts = {} @@ -265,10 +269,10 @@ def get_template_extremum_channel_peak_shift( def get_template_extremum_amplitude( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "at_index", + peak_sign: Literal["neg", "pos", "both"] = "neg", + mode: Literal["extremum", "at_index", "peak_to_peak"] = "at_index", abs_value: bool = True, - operator: str = "average", + operator: Literal["average", "median"] = "average", ): """ Computes amplitudes on the best channel. 
diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 02798099ec..405a2ecccf 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -251,7 +251,7 @@ def test_get_noise_levels_output(): def test_get_chunk_with_margin(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0]) -    rec_seg = rec._recording_segments[0] +    rec_seg = rec.segments[0] length = rec_seg.get_num_samples() #  rec_segment, start_frame, end_frame, channel_indices, sample_margin diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index f22939c33c..e03096ce14 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -64,7 +64,7 @@ def _get_time_vector_recording(self, raw_recording): times_recording.set_times(times=time_vector, segment_index=segment_index) assert np.array_equal( -            times_recording._recording_segments[segment_index].time_vector, +            times_recording.segments[segment_index].time_vector, time_vector, ), "time_vector was not properly set during test setup" @@ -84,7 +84,7 @@ def _get_t_start_recording(self, raw_recording): t_start = (segment_index + 1) * 100 all_t_starts.append(t_start + t_start_recording.get_times(segment_index)) -        t_start_recording._recording_segments[segment_index].t_start = t_start +        t_start_recording.segments[segment_index].t_start = t_start return (raw_recording, t_start_recording, all_t_starts) @@ -442,6 +442,6 @@ def test_shift_times_with_None_as_t_start(): """Ensures we can shift times even when t_start is None which is interpreted as zero""" recording = generate_recording(num_channels=4, durations=[10]) -    assert recording._recording_segments[0].t_start is None +    assert recording.segments[0].t_start is None recording.shift_times(shift=1.0) # Shift by one second should not generate an error assert recording.get_start_time() == 1.0 diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 84d2c06e59..32040f8f61 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -134,7 +134,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None, sampling_frequency_max_d # add segments for i_seg in range(num_segments): -        parent_segments = [sort._sorting_segments[i_seg] for sort in sorting_list] +        parent_segments = [sort.segments[i_seg] for sort in sorting_list] sub_segment = UnitsAggregationSortingSegment(unit_map, parent_segments) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index b0f3b19472..59356db976 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -33,7 +33,7 @@ def __init__(self, parent_sorting, unit_ids=None, renamed_unit_ids=None): BaseSorting.__init__(self, sampling_frequency, self._renamed_unit_ids) -    for parent_segment in self._parent_sorting._sorting_segments: +    for parent_segment in self._parent_sorting.segments: sub_segment = UnitsSelectionSortingSegment(parent_segment, ids_conversion) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index e58ef4ee68..8f266e0123 --- 
a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -500,7 +500,7 @@ def add_recording_to_zarr_group( # save time vector if any t_starts = np.zeros(recording.get_num_segments(), dtype="float64") * np.nan - for segment_index, rs in enumerate(recording._recording_segments): + for segment_index, rs in enumerate(recording.segments): d = rs.get_times_kwargs() time_vector = d["time_vector"] diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index 888f5964ca..9e0b17632e 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -8,8 +8,6 @@ non_soma: Non-somatic units (axonal) """ -from __future__ import annotations - import operator from pathlib import Path import json @@ -87,7 +85,7 @@ def bombcell_label_units( thresholds: dict | str | Path | None = None, label_non_somatic: bool = True, split_non_somatic_good_mua: bool = False, - external_metrics: "pd.DataFrame | list[pd.DataFrame]" | None = None, + external_metrics: "pd.DataFrame | list[pd.DataFrame] | None" = None, ) -> "pd.DataFrame": """ Label units based on quality metrics and template metrics using Bombcell logic: diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index 31ce825c7f..ff2c32d07f 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Literal import numpy as np @@ -61,7 +61,7 @@ def _find_duplicated_spikes_numpy( spike_train: np.ndarray, censored_period: int, seed: int | None = None, - method: "keep_first" | "random" | "keep_last" = "keep_first", + method: Literal["keep_first", "random", "keep_last"] = "keep_first", ) -> np.ndarray: (indices_of_duplicates,) = np.where(np.diff(spike_train) <= censored_period) @@ -138,7 +138,7 @@ def _find_duplicated_spikes_keep_last_iterative(spike_train, censored_period): def find_duplicated_spikes( spike_train, censored_period: int, - method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random" = "random", + method: Literal["keep_first", "keep_last", "keep_first_iterative", "keep_last_iterative", "random"] = "random", seed: int | None = None, ) -> np.ndarray: """ diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index ff83edaca2..9d9e10e75f 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -68,7 +68,7 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy rm_dup_delta = None else: rm_dup_delta = int(delta_time_ms / 1000 * sampling_frequency) - for parent_segment in self._parent_sorting._sorting_segments: + for parent_segment in self._parent_sorting.segments: sub_segment = MergeUnitsSortingSegment(parent_segment, units_to_merge, new_unit_ids, rm_dup_delta) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index 2ff3456822..33d342ff14 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -37,7 +37,7 @@ def __init__(self, sorting: BaseSorting, censored_period_ms: float = 0.3, method censored_period = int(round(censored_period_ms * 1e-3 * 
sorting.get_sampling_frequency())) seed = np.random.randint(low=0, high=np.iinfo(np.int32).max) - for segment in sorting._sorting_segments: + for segment in sorting.segments: self.add_sorting_segment( RemoveDuplicatedSpikesSortingSegment(segment, censored_period, sorting.unit_ids, method, seed) ) diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 020037b2b7..04169808f5 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -32,7 +32,7 @@ def __init__(self, sorting: BaseSorting, recording: BaseRecording) -> None: self._parent_sorting = sorting self._num_samples = np.empty(sorting.get_num_segments(), dtype=np.int64) for segment_index in range(sorting.get_num_segments()): - sorting_segment = sorting._sorting_segments[segment_index] + sorting_segment = sorting.segments[segment_index] self._num_samples[segment_index] = recording.get_num_samples(segment_index=segment_index) self.add_sorting_segment( RemoveExcessSpikesSortingSegment(sorting_segment, self._num_samples[segment_index]) diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index f5a548113d..c09f57df5a 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -78,7 +78,7 @@ def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, prop np.isin(unchanged_units, self.unit_ids) ), "new_unit_ids should have a compatible format with the parent ids" - for si, parent_segment in enumerate(self._parent_sorting._sorting_segments): + for si, parent_segment in enumerate(self._parent_sorting.segments): sub_segment = SplitSortingUnitSegment(parent_segment, split_unit_id, indices_zero_based[si], new_unit_ids) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 676b2bceac..7a7bdd45a6 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -210,7 +210,7 @@ def __init__(self, file_path, sampling_frequency): # Every spike assigned to a unit (label) has the same max channel # ref: https://github.com/SpikeInterface/spikeinterface/issues/3695#issuecomment-2663329006 max_channels = [] - segment = self._sorting_segments[0] + segment = self.segments[0] for unit_id in self.unit_ids: label_mask = segment._labels == unit_id # since all max channels are the same, we can just grab the first occurrence for the unit diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index dec65404e9..0e5dd2694d 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -355,7 +355,7 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse # kilosort occasionally contains a few spikes just beyond the recording end point, which can lead # to errors later. To avoid this, we pad the recording with an extra second of blank time. 
- duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1 + duration = sorting.segments[0]._all_spikes[-1] / sampling_frequency + 1 if (phy_path / "probe.prb").is_file(): probegroup = read_prb(phy_path / "probe.prb") diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 97ac263cc5..1800138dae 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -463,9 +463,7 @@ def __init__( amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None # upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None - parent_recording_segment = ( - None if parent_recording is None else parent_recording._recording_segments[segment_index] - ) + parent_recording_segment = None if parent_recording is None else parent_recording.segments[segment_index] recording_segment = InjectDriftingTemplatesRecordingSegment( self.dtype, self.spike_vector[start:end], diff --git a/src/spikeinterface/postprocessing/alignsorting.py b/src/spikeinterface/postprocessing/alignsorting.py index c2b23ba83e..cf4189a3c7 100644 --- a/src/spikeinterface/postprocessing/alignsorting.py +++ b/src/spikeinterface/postprocessing/alignsorting.py @@ -25,7 +25,7 @@ class AlignSortingExtractor(BaseSorting): def __init__(self, sorting, unit_peak_shifts): super().__init__(sorting.get_sampling_frequency(), sorting.unit_ids) - for segment in sorting._sorting_segments: + for segment in sorting.segments: self.add_sorting_segment(AlignSortingSegment(segment, unit_peak_shifts)) sorting.copy_metadata(self, only_main=False) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 448be8d055..f8b5e16dc0 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Literal import warnings import importlib.util @@ -661,8 +661,9 @@ def get_convolution_weights( def compute_location_max_channel( templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_sign: Literal["neg", "pos", "both"] = "neg", + mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + operator: Literal["average", "median"] = "average", ) -> np.ndarray: """ Localize a unit using max channel. 
@@ -690,7 +691,7 @@ def compute_location_max_channel( 2d """ extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index", operator=operator ) contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index 41f88ce858..26d45dd711 100644 --- a/src/spikeinterface/preprocessing/astype.py +++ b/src/spikeinterface/preprocessing/astype.py @@ -42,7 +42,7 @@ def __init__( if round is None: round = np.issubdtype(dtype, np.integer) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = AstypeRecordingSegment( parent_segment, dtype, diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 113d1e22f1..8d1c4475cd 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -75,7 +75,7 @@ def __init__( self.parent_recording = parent_recording self.num_channels = n_pos_unique - for segment in parent_recording._recording_segments: + for segment in parent_recording.segments: recording_segment = AverageAcrossDirectionRecordingSegment( segment, self.num_channels, diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 30ed2a7b5e..dd7676fd23 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -33,7 +33,7 @@ def __init__(self, recording, a_min=None, a_max=None): value_max = a_max BasePreprocessor.__init__(self, recording) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max) self.add_recording_segment(rec_segment) @@ -130,7 +130,7 @@ def __init__( value_max = fill_value BasePreprocessor.__init__(self, recording) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max) self.add_recording_segment(rec_segment) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index b1469a0250..5a3a9b0043 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -154,7 +154,7 @@ def __init__( else: ref_channel_indices = None - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = CommonReferenceRecordingSegment( parent_segment, reference, diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 1c7566ab20..da66cd9c3f 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -65,7 +65,7 @@ def __init__( BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: self.add_recording_segment( DecimateRecordingSegment( parent_segment, diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py 
b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index 90863a8df7..07be76d47d 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -89,7 +89,7 @@ def __init__( self.model = model # add segment - for segment in recording._recording_segments: + for segment in recording.segments: recording_segment = DeepInterpolatedRecordingSegment( segment, self.model, diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index adcd1d80f8..92b07b8f35 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Literal import numpy as np diff --git a/src/spikeinterface/preprocessing/directional_derivative.py b/src/spikeinterface/preprocessing/directional_derivative.py index 124a6e2744..f302708055 100644 --- a/src/spikeinterface/preprocessing/directional_derivative.py +++ b/src/spikeinterface/preprocessing/directional_derivative.py @@ -50,7 +50,7 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype_) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = DirectionalDerivativeRecordingSegment( parent_segment, parent_channel_locations, diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 2eb8d7cdf8..1fc289f937 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -125,7 +125,7 @@ def __init__( f"chunking. Consider increasing the chunk_size or chunk_duration to minimize margin overhead." 
         )
         self.margin_samples = margin
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             self.add_recording_segment(
                 FilterRecordingSegment(
                     parent_segment,
diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py
index b51e9603f5..1cf6873a7a 100644
--- a/src/spikeinterface/preprocessing/filter_gaussian.py
+++ b/src/spikeinterface/preprocessing/filter_gaussian.py
@@ -47,7 +47,7 @@ def __init__(
         if freq_min is None and freq_max is None:
             raise ValueError("At least one of `freq_min`,`freq_max` should be specified.")
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             # Sampling frequency is taken from recording since segments may not have it set (in case of time_vector)
             self.add_recording_segment(
                 GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max, margin_sd, self.sampling_frequency)
diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py
index 6db0c3d642..d0bffda0e0 100644
--- a/src/spikeinterface/preprocessing/filter_opencl.py
+++ b/src/spikeinterface/preprocessing/filter_opencl.py
@@ -72,7 +72,7 @@ def __init__(
             dtype = "float32"
 
         executor = OpenCLFilterExecutor(coefficients, num_channels, dtype, margin)
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             self.add_recording_segment(FilterOpenCLRecordingSegment(parent_segment, executor, margin))
 
         self._kwargs = dict(
diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py
index f64e553980..497bbdd482 100644
--- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py
+++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py
@@ -139,7 +139,7 @@ def __init__(
 
         dtype = fix_dtype(recording, dtype)
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = HighPassSpatialFilterSegment(
                 parent_segment,
                 n_channel_pad,
diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py
index 17275f7a23..0e7e5f9950 100644
--- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py
+++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py
@@ -65,7 +65,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non
         locations_bad = locations[self._bad_channel_idxs]
         weights = preprocessing_tools.get_kriging_channel_weights(locations_good, locations_bad, sigma_um, p)
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = InterpolateBadChannelsSegment(
                 parent_segment, self._good_channel_idxs, self._bad_channel_idxs, weights
             )
diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py
index 7319e2994e..641d6af0d9 100644
--- a/src/spikeinterface/preprocessing/normalize_scale.py
+++ b/src/spikeinterface/preprocessing/normalize_scale.py
@@ -103,7 +103,7 @@ def __init__(
 
         BasePreprocessor.__init__(self, recording, dtype=dtype)
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype)
             self.add_recording_segment(rec_segment)
 
@@ -166,7 +166,7 @@ def __init__(self, recording, gain=1.0, offset=0.0, dtype="float32"):
 
         BasePreprocessor.__init__(self, recording, dtype=dtype)
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, self._dtype)
             self.add_recording_segment(rec_segment)
 
@@ -211,7 +211,7 @@ def __init__(self, recording, mode="median", dtype="float32", **random_chunk_kwa
 
         BasePreprocessor.__init__(self, recording, dtype=dtype)
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype)
             self.add_recording_segment(rec_segment)
 
@@ -313,7 +313,7 @@ def __init__(
         self.set_property(key="gain_to_uV", values=np.ones(num_chans, dtype="float32"))
         self.set_property(key="offset_to_uV", values=np.zeros(num_chans, dtype="float32"))
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype)
             self.add_recording_segment(rec_segment)
 
diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py
index 4131f912f3..5648d689dd 100644
--- a/src/spikeinterface/preprocessing/phase_shift.py
+++ b/src/spikeinterface/preprocessing/phase_shift.py
@@ -61,7 +61,7 @@ def __init__(self, recording, margin_ms=40.0, inter_sample_shift=None, dtype=Non
             tmp_dtype = None
 
         BasePreprocessor.__init__(self, recording, dtype=dtype)
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = PhaseShiftRecordingSegment(parent_segment, sample_shifts, margin, dtype, tmp_dtype)
             self.add_recording_segment(rec_segment)
 
diff --git a/src/spikeinterface/preprocessing/rectify.py b/src/spikeinterface/preprocessing/rectify.py
index 3b622149d1..7bd91a16d9 100644
--- a/src/spikeinterface/preprocessing/rectify.py
+++ b/src/spikeinterface/preprocessing/rectify.py
@@ -9,7 +9,7 @@ class RectifyRecording(BasePreprocessor):
     def __init__(self, recording):
         BasePreprocessor.__init__(self, recording)
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = RectifyRecordingSegment(parent_segment)
             self.add_recording_segment(rec_segment)
 
         self._kwargs = dict(recording=recording)
diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py
index 0863522fd8..3fc5449ff2 100644
--- a/src/spikeinterface/preprocessing/remove_artifacts.py
+++ b/src/spikeinterface/preprocessing/remove_artifacts.py
@@ -197,7 +197,7 @@ def __init__(
             time_pad = None
 
         BasePreprocessor.__init__(self, recording)
-        for seg_index, parent_segment in enumerate(recording._recording_segments):
+        for seg_index, parent_segment in enumerate(recording.segments):
             triggers = list_triggers[seg_index]
             labels = list_labels[seg_index]
             rec_segment = RemoveArtifactsRecordingSegment(
diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py
index 773b68b977..902bd6d176 100644
--- a/src/spikeinterface/preprocessing/resample.py
+++ b/src/spikeinterface/preprocessing/resample.py
@@ -65,7 +65,7 @@ def __init__(
         margin = int(margin_ms * recording.get_sampling_frequency() / 1000)
 
         BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate, dtype=dtype)
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             self.add_recording_segment(
                 ResampleRecordingSegment(
                     parent_segment,
diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py
index 189b97ec87..393c712919 100644
--- a/src/spikeinterface/preprocessing/silence_periods.py
+++ b/src/spikeinterface/preprocessing/silence_periods.py
@@ -103,7 +103,7 @@ def __init__(
         BasePreprocessor.__init__(self, recording)
         seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1))
 
-        for seg_index, parent_segment in enumerate(recording._recording_segments):
+        for seg_index, parent_segment in enumerate(recording.segments):
             i0 = seg_limits[seg_index]
             i1 = seg_limits[seg_index + 1]
             periods_in_seg = periods[i0:i1]
diff --git a/src/spikeinterface/preprocessing/tests/test_clip.py b/src/spikeinterface/preprocessing/tests/test_clip.py
index cea15722a0..96020692a1 100644
--- a/src/spikeinterface/preprocessing/tests/test_clip.py
+++ b/src/spikeinterface/preprocessing/tests/test_clip.py
@@ -41,7 +41,7 @@ def test_blank_saturation():
     traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"])
     assert traces1.shape[1] == 1
 
     # use a smaller value to be sure
-    a_min = rec1._recording_segments[0].a_min
+    a_min = rec1.segments[0].a_min
     assert np.all(traces1 >= a_min)
 
diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py
index 3fbc260b5f..e19cad59ba 100644
--- a/src/spikeinterface/preprocessing/tests/test_common_reference.py
+++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py
@@ -96,7 +96,7 @@ def test_common_reference_channel_slicing(recording):
     start_frame = 0
     end_frame = 10
 
-    recording_segment_cmr = recording_cmr._recording_segments[0]
+    recording_segment_cmr = recording_cmr.segments[0]
     traces_cmr_all = recording_segment_cmr.get_traces(
         start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
     )
@@ -106,7 +106,7 @@ def test_common_reference_channel_slicing(recording):
 
     assert np.all(traces_cmr_all[:, indices] == traces_cmr_sub)
 
-    recording_segment_car = recording_car._recording_segments[0]
+    recording_segment_car = recording_car.segments[0]
     traces_car_all = recording_segment_car.get_traces(
         start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
     )
@@ -116,7 +116,7 @@ def test_common_reference_channel_slicing(recording):
 
     assert np.all(traces_car_all[:, indices] == traces_car_sub)
 
-    recording_segment_local = recording_local_car._recording_segments[0]
+    recording_segment_local = recording_local_car.segments[0]
     traces_local_all = recording_segment_local.get_traces(
         start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
     )
diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py
index e9493145a6..141345ca46 100644
--- a/src/spikeinterface/preprocessing/tests/test_decimate.py
+++ b/src/spikeinterface/preprocessing/tests/test_decimate.py
@@ -66,7 +66,7 @@ def test_decimate_with_times():
     # test with t_start
     rec = generate_recording(durations=[5, 10])
    t_starts = [10, 20]
-    for t_start, rec_segment in zip(t_starts, rec._recording_segments):
+    for t_start, rec_segment in zip(t_starts, rec.segments):
         rec_segment.t_start = t_start
     decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
     for segment_index in range(rec.get_num_segments()):
diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
index 75279bcae0..35f398f985 100644
--- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
+++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py
@@ -262,7 +262,7 @@ def reduce_high_freq_power_in_non_noisy_channels(recording, is_noisy, not_noisy)
     """
     from scipy.signal import welch
 
-    for iseg, __ in enumerate(recording._recording_segments):
+    for iseg, __ in enumerate(recording.segments):
         data = recording.get_traces(iseg).T
         num_samples = recording.get_num_samples(iseg)
 
@@ -291,7 +291,7 @@ def add_dead_channels(recording, is_dead):
         data[:, is_dead] = np.random.normal(
             mean, std * 0.1, size=(is_dead.size, recording.get_num_samples(segment_index))
         ).T
-        recording._recording_segments[segment_index]._traces = data
+        recording.segments[segment_index]._traces = data
 
 
 if __name__ == "__main__":
diff --git a/src/spikeinterface/preprocessing/tests/test_resample.py b/src/spikeinterface/preprocessing/tests/test_resample.py
index 7e1d173fdb..c53b7b42bd 100644
--- a/src/spikeinterface/preprocessing/tests/test_resample.py
+++ b/src/spikeinterface/preprocessing/tests/test_resample.py
@@ -216,10 +216,10 @@ def test_resample_preserves_t_start():
     t_start = 100.5
     traces = np.random.randn(sampling_frequency * 2, 2).astype(np.float32)
     parent_rec = NumpyRecording(traces, sampling_frequency)
-    parent_rec._recording_segments[0].t_start = t_start
+    parent_rec.segments[0].t_start = t_start
 
     resampled = resample(parent_rec, 500)
-    assert resampled._recording_segments[0].t_start == t_start
+    assert resampled.segments[0].t_start == t_start
     assert not resampled.has_time_vector()
     assert np.isclose(resampled.get_times()[0], t_start)
 
diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py
index 7c414df738..f4c0e4d166 100644
--- a/src/spikeinterface/preprocessing/tests/test_whiten.py
+++ b/src/spikeinterface/preprocessing/tests/test_whiten.py
@@ -366,8 +366,8 @@ def test_passed_W_and_M(self):
         whitened_recording = whiten(recording, W=test_W, M=test_M)
 
         for seg_idx in [0, 1]:
-            assert np.array_equal(whitened_recording._recording_segments[seg_idx].W, test_W)
-            assert np.array_equal(whitened_recording._recording_segments[seg_idx].M, test_M)
+            assert np.array_equal(whitened_recording.segments[seg_idx].W, test_W)
+            assert np.array_equal(whitened_recording.segments[seg_idx].M, test_M)
 
         assert whitened_recording._kwargs["W"] == test_W.tolist()
         assert whitened_recording._kwargs["M"] == test_M.tolist()
diff --git a/src/spikeinterface/preprocessing/unsigned_to_signed.py b/src/spikeinterface/preprocessing/unsigned_to_signed.py
index 62107155ee..ae1ce12281 100644
--- a/src/spikeinterface/preprocessing/unsigned_to_signed.py
+++ b/src/spikeinterface/preprocessing/unsigned_to_signed.py
@@ -31,7 +31,7 @@ def __init__(
 
         BasePreprocessor.__init__(self, recording, dtype=dtype_signed)
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = UnsignedToSignedRecordingSegment(parent_segment, dtype_signed, bit_depth)
             self.add_recording_segment(rec_segment)
 
diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py
index d5f26d9b01..1d723b63a0 100644
--- a/src/spikeinterface/preprocessing/whiten.py
+++ b/src/spikeinterface/preprocessing/whiten.py
@@ -101,7 +101,7 @@ def __init__(
 
         BasePreprocessor.__init__(self, recording, dtype=dtype_)
 
-        for parent_segment in recording._recording_segments:
+        for parent_segment in recording.segments:
             rec_segment = WhitenRecordingSegment(parent_segment, W, M, dtype_, int_scale)
             self.add_recording_segment(rec_segment)
 
diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py
index 35b984449d..45d4809cd8 100644
--- a/src/spikeinterface/preprocessing/zero_channel_pad.py
+++ b/src/spikeinterface/preprocessing/zero_channel_pad.py
@@ -32,7 +32,7 @@ def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end
         self.padding_start = padding_start
         self.padding_end = padding_end
         self.fill_value = fill_value
-        for segment in recording._recording_segments:
+        for segment in recording.segments:
             recording_segment = TracePaddedRecordingSegment(
                 segment,
                 recording.get_num_channels(),
@@ -164,7 +164,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping:
 
         self.parent_recording = recording
         self.num_channels = num_channels
-        for segment in recording._recording_segments:
+        for segment in recording.segments:
             recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping)
             self.add_recording_segment(recording_segment)
 
diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py
index 37c13b9395..947eaf391f 100644
--- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py
+++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py
@@ -202,7 +202,7 @@ def __init__(
             # interpolation bins edges
             self.interpolation_time_bins_s = []
             self.interpolation_time_bin_edges_s = []
-            for segment_index, parent_segment in enumerate(recording._recording_segments):
+            for segment_index, parent_segment in enumerate(recording.segments):
                 # in this case, interpolation_time_bin_size_s is set.
                 s_end = parent_segment.get_num_samples()
                 t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end]))
diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py
index a50b9609b9..7c4c4b166e 100644
--- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py
+++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py
@@ -422,7 +422,7 @@ def __init__(
             interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
         )
 
-        for segment_index, parent_segment in enumerate(recording._recording_segments):
+        for segment_index, parent_segment in enumerate(recording.segments):
            # finish the per-segment part of the time bin logic
             if interpolation_time_bin_centers_s is None:
                 # in this case, interpolation_time_bin_size_s is set.
diff --git a/src/spikeinterface/widgets/bombcell_curation.py b/src/spikeinterface/widgets/bombcell_curation.py
index 1a1212ba5c..8c7c275ad1 100644
--- a/src/spikeinterface/widgets/bombcell_curation.py
+++ b/src/spikeinterface/widgets/bombcell_curation.py
@@ -1,7 +1,5 @@
 """Widgets for visualizing unit labeling results."""
 
-from __future__ import annotations
-
 import warnings
 
 import numpy as np
diff --git a/src/spikeinterface/widgets/unit_labels.py b/src/spikeinterface/widgets/unit_labels.py
index c5b55041c1..348f0e3b8d 100644
--- a/src/spikeinterface/widgets/unit_labels.py
+++ b/src/spikeinterface/widgets/unit_labels.py
@@ -1,7 +1,5 @@
 """Widgets for visualizing unit labeling results."""
 
-from __future__ import annotations
-
 import numpy as np
 
 from spikeinterface.curation.curation_tools import is_threshold_disabled
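
Note (illustration, not part of the patch): a minimal sketch of the usage pattern this refactor standardizes, iterating the public `segments` property instead of reaching into the private `_recording_segments` attribute. It assumes a build with this change applied; `generate_recording` is the standard helper used elsewhere in the test suite, and the segment lengths printed are illustrative only.

    from spikeinterface.core import generate_recording

    # toy two-segment recording, as used throughout the tests in this patch
    recording = generate_recording(durations=[2.0, 3.0], num_channels=4)

    # with this change, get_num_segments() is simply len(self._segments),
    # so the public property and the segment count stay consistent
    assert recording.get_num_segments() == len(recording.segments)

    # downstream code (preprocessors, tests) now iterates the public property
    for segment_index, segment in enumerate(recording.segments):
        print(segment_index, segment.get_num_samples())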