Skip to content

Commit a71dade

Browse files
committed
Progagate new periods dtype to SilencedPeriodsRecording with backward compatibility
1 parent ef3e820 commit a71dade

File tree

7 files changed

+166
-203
lines changed

7 files changed

+166
-203
lines changed

src/spikeinterface/core/node_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,9 @@ def check_graph(nodes):
489489
Check that node list is orderd in a good (parents are before children)
490490
"""
491491

492-
node0 = nodes[0]
492+
# Do not remove this, this is to remenber that in previous version the first node needed to be
493+
# a detectot but not anymore
494+
# node0 = nodes[0]
493495
# if not isinstance(node0, PeakSource):
494496
# raise ValueError(
495497
# "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever"

src/spikeinterface/preprocessing/detect_artifacts.py

Lines changed: 28 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,29 @@
1414
import numpy as np
1515

1616

17-
# artifact_dtype = [
18-
# ("start_index", "int64"),
19-
# ("stop_index", "int64"),
20-
# ("segment_index", "int64"),
21-
# ]
2217

2318
artifact_dtype = base_period_dtype
2419

2520

21+
# this will be extend with channel boundaries if needed
2622
# extended_artifact_dtype = artifact_dtype + [
2723
# # TODO
2824
# ]
2925

3026

3127

32-
3328
def detect_artifact_periods(
3429
recording,
3530
method="envelope",
3631
method_kwargs=None,
3732
job_kwargs=None,
3833
):
3934
"""
40-
35+
Detect artifacts with several possible methods:
36+
* 'saturation' using detect_artifact_periods_by_envelope()
37+
* 'envelope' using detect_saturation_periods()
38+
39+
See sub methods for more information on parameters.
4140
"""
4241

4342
if method_kwargs is None:
@@ -48,21 +47,18 @@ def detect_artifact_periods(
4847
elif method == "saturation":
4948
artifact_periods = detect_saturation_periods(recording, **method_kwargs, job_kwargs=job_kwargs)
5049
else:
51-
raise ValueError("")
50+
raise ValueError(f"detect_artifact_periods() method='{method}' is not valid")
5251

5352
return artifact_periods
5453

5554

5655

5756
## detect_period_artifacts_saturation Zone
5857

59-
6058
def _collapse_events(events):
6159
"""
6260
If events are detected at a chunk edge, they will be split in two.
63-
This detects such cases and collapses them in a single record instead
64-
:param events:
65-
:return:
61+
This detects such cases and collapses them in a single record instead.
6662
"""
6763
order = np.lexsort((events["start_sample_index"], events["segment_index"]))
6864
events = events[order]
@@ -87,21 +83,24 @@ class _DetectSaturation(PipelineNode):
8783
def __init__(
8884
self,
8985
recording,
90-
saturation_threshold_uV, # 1200 uV
91-
voltage_per_sec_threshold, # 1e-8 V.s-1
86+
saturation_threshold_uV,
87+
voltage_per_sec_threshold,
9288
proportion,
93-
mute_window_samples,
9489
):
9590
PipelineNode.__init__(self, recording, return_output=True)
9691

97-
self.gains = recording.get_channel_gains()
98-
self.offsets = recording.get_channel_offsets()
92+
gains = recording.get_channel_gains()
93+
offsets = recording.get_channel_offsets()
94+
num_chans = recording.get_num_channels()
9995

10096
self.voltage_per_sec_threshold = voltage_per_sec_threshold
101-
self.saturation_threshold_uV = saturation_threshold_uV
97+
thresh = np.full((num_chans, ), saturation_threshold_uV)
98+
# 0.98 is empirically determined as the true saturating point is
99+
# slightly lower than the documented saturation point of the probe
100+
self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98
101+
102102
self.sampling_frequency = recording.get_sampling_frequency()
103103
self.proportion = proportion
104-
self.mute_window_samples = mute_window_samples
105104
self._dtype = np.dtype(artifact_dtype)
106105
self.gain = recording.get_channel_gains()
107106
self.offset = recording.get_channel_offsets()
@@ -114,16 +113,7 @@ def get_dtype(self):
114113

115114
def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
116115

117-
# @olivier @joe we can avoid this by making
118-
traces = traces * self.gains[np.newaxis, :] + self.offsets[np.newaxis, :]
119-
120-
121-
# first computes the saturated samples
122-
max_voltage = np.atleast_1d(self.saturation_threshold_uV)[:, np.newaxis]
123-
124-
# 0.98 is empirically determined as the true saturating point is
125-
# slightly lower than the documented saturation point of the probe
126-
saturation = np.mean(np.abs(traces) > max_voltage * 0.98, axis=1)
116+
saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1)
127117

128118
if self.voltage_per_sec_threshold is not None:
129119
fs = self.sampling_frequency
@@ -138,25 +128,23 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
138128
else:
139129
saturation = saturation > self.proportion
140130

141-
intervals = np.where(np.diff(saturation, prepend=False, append=False))[0]
131+
intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False))
142132
n_events = len(intervals) // 2 # Number of saturation periods
143133
events = np.zeros(n_events, dtype=artifact_dtype)
144134

145135
for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])):
146136
events[i]["start_sample_index"] = start + start_frame
147137
events[i]["end_sample_index"] = stop + start_frame
148138
events[i]["segment_index"] = segment_index
149-
# events[i]["method_id"] = "saturation_detection"
150139

151140
return (events, )
152141

153142

154143
def detect_saturation_periods(
155144
recording,
156145
saturation_threshold_uV, # 1200 uV
157-
voltage_per_sec_threshold, # 1e-8 V.s-1
146+
voltage_per_sec_threshold=None, # 1e-8 V.s-1
158147
proportion=0.5,
159-
mute_window_samples=7,
160148
job_kwargs=None,
161149
):
162150
"""
@@ -174,7 +162,7 @@ def detect_saturation_periods(
174162
The recording on which to detect the saturation events.
175163
saturation_threshold_uV : float
176164
The voltage saturation threshold in volts. This will depend on the recording
177-
probe and amplifier gain settings. For NP1 the value of 1200 * 1e-6 is recommended (IBL).
165+
probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL).
178166
Note that NP2 probes are more difficult to saturate than NP1.
179167
voltage_per_sec_threshold : None | float
180168
The first-derivative threshold in volts per second. Periods of the data over which the change
@@ -207,24 +195,17 @@ def detect_saturation_periods(
207195
saturation_threshold_uV=saturation_threshold_uV,
208196
voltage_per_sec_threshold=voltage_per_sec_threshold,
209197
proportion=proportion,
210-
mute_window_samples=mute_window_samples,
211198
)
212199

213-
saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation events")
200+
saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts")
214201

215202
return _collapse_events(saturation_periods)
216203

217204

218205

219206
## detect_artifact_periods_by_envelope Zone
220207

221-
# _internal_dtype = [
222-
# ("sample_index", "int64"),
223-
# ("segment_index", "int64"),
224-
# ("front", "bool")
225-
# ]
226-
227-
class DetectThresholdCrossing(PeakDetector):
208+
class _DetectThresholdCrossing(PeakDetector):
228209

229210
name = "threshold_crossings"
230211
preferred_mp_context = None
@@ -243,6 +224,7 @@ def __init__(
243224
random_slices_kwargs["seed"] = seed
244225
noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs)
245226
self.abs_thresholds = noise_levels * detect_threshold
227+
# internal dtype
246228
self._dtype = np.dtype([
247229
("sample_index", "int64"),
248230
("segment_index", "int64"),
@@ -278,7 +260,7 @@ def detect_artifact_periods_by_envelope(
278260
random_slices_kwargs=None,
279261
):
280262
"""
281-
Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of
263+
Function to detect putative artifact periods as threshold crossings of
282264
a global envelope of the channels.
283265
284266
Parameters
@@ -300,8 +282,6 @@ def detect_artifact_periods_by_envelope(
300282
envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max)
301283
envelope = CommonReferenceRecording(envelope)
302284

303-
304-
# _, job_kwargs = split_job_kwargs(noise_levels_kwargs)
305285
job_kwargs = fix_job_kwargs(job_kwargs)
306286
if random_slices_kwargs is None:
307287
random_slices_kwargs = {}
@@ -310,28 +290,25 @@ def detect_artifact_periods_by_envelope(
310290
random_slices_kwargs["seed"] = seed
311291
noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs)
312292

313-
node0 = DetectThresholdCrossing(
293+
node0 = _DetectThresholdCrossing(
314294
recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed,
315295
)
316296

317297
threshold_crossings = run_node_pipeline(
318298
envelope,
319299
[node0],
320300
job_kwargs,
321-
job_name="detect threshold crossings",
301+
job_name="detect artifact on envelope",
322302
)
323303

324304
order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"]))
325305
threshold_crossings = threshold_crossings[order]
326306

327307
artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording)
328308

329-
330309
return artifacts, envelope
331310

332311

333-
# tools
334-
335312
def _transform_internal_dtype_to_artifact_dtype(artifacts, recording):
336313

337314
num_seg = recording.get_num_segments()

src/spikeinterface/preprocessing/silence_artifacts.py

Lines changed: 0 additions & 96 deletions
This file was deleted.

0 commit comments

Comments
 (0)