1414import numpy as np
1515
1616
17- # artifact_dtype = [
18- # ("start_index", "int64"),
19- # ("stop_index", "int64"),
20- # ("segment_index", "int64"),
21- # ]
2217
2318artifact_dtype = base_period_dtype
2419
2520
21+ # this will be extend with channel boundaries if needed
2622# extended_artifact_dtype = artifact_dtype + [
2723# # TODO
2824# ]
2925
3026
3127
32-
3328def detect_artifact_periods (
3429 recording ,
3530 method = "envelope" ,
3631 method_kwargs = None ,
3732 job_kwargs = None ,
3833):
3934 """
40-
35+ Detect artifacts with several possible methods:
36+ * 'saturation' using detect_artifact_periods_by_envelope()
37+ * 'envelope' using detect_saturation_periods()
38+
39+ See sub methods for more information on parameters.
4140 """
4241
4342 if method_kwargs is None :
@@ -48,21 +47,18 @@ def detect_artifact_periods(
4847 elif method == "saturation" :
4948 artifact_periods = detect_saturation_periods (recording , ** method_kwargs , job_kwargs = job_kwargs )
5049 else :
51- raise ValueError (" " )
50+ raise ValueError (f"detect_artifact_periods() method=' { method } ' is not valid " )
5251
5352 return artifact_periods
5453
5554
5655
5756## detect_period_artifacts_saturation Zone
5857
59-
6058def _collapse_events (events ):
6159 """
6260 If events are detected at a chunk edge, they will be split in two.
63- This detects such cases and collapses them in a single record instead
64- :param events:
65- :return:
61+ This detects such cases and collapses them in a single record instead.
6662 """
6763 order = np .lexsort ((events ["start_sample_index" ], events ["segment_index" ]))
6864 events = events [order ]
@@ -87,21 +83,24 @@ class _DetectSaturation(PipelineNode):
8783 def __init__ (
8884 self ,
8985 recording ,
90- saturation_threshold_uV , # 1200 uV
91- voltage_per_sec_threshold , # 1e-8 V.s-1
86+ saturation_threshold_uV ,
87+ voltage_per_sec_threshold ,
9288 proportion ,
93- mute_window_samples ,
9489 ):
9590 PipelineNode .__init__ (self , recording , return_output = True )
9691
97- self .gains = recording .get_channel_gains ()
98- self .offsets = recording .get_channel_offsets ()
92+ gains = recording .get_channel_gains ()
93+ offsets = recording .get_channel_offsets ()
94+ num_chans = recording .get_num_channels ()
9995
10096 self .voltage_per_sec_threshold = voltage_per_sec_threshold
101- self .saturation_threshold_uV = saturation_threshold_uV
97+ thresh = np .full ((num_chans , ), saturation_threshold_uV )
98+ # 0.98 is empirically determined as the true saturating point is
99+ # slightly lower than the documented saturation point of the probe
100+ self .saturation_threshold_unscaled = (thresh - offsets ) / gains * 0.98
101+
102102 self .sampling_frequency = recording .get_sampling_frequency ()
103103 self .proportion = proportion
104- self .mute_window_samples = mute_window_samples
105104 self ._dtype = np .dtype (artifact_dtype )
106105 self .gain = recording .get_channel_gains ()
107106 self .offset = recording .get_channel_offsets ()
@@ -114,16 +113,7 @@ def get_dtype(self):
114113
115114 def compute (self , traces , start_frame , end_frame , segment_index , max_margin ):
116115
117- # @olivier @joe we can avoid this by making
118- traces = traces * self .gains [np .newaxis , :] + self .offsets [np .newaxis , :]
119-
120-
121- # first computes the saturated samples
122- max_voltage = np .atleast_1d (self .saturation_threshold_uV )[:, np .newaxis ]
123-
124- # 0.98 is empirically determined as the true saturating point is
125- # slightly lower than the documented saturation point of the probe
126- saturation = np .mean (np .abs (traces ) > max_voltage * 0.98 , axis = 1 )
116+ saturation = np .mean (np .abs (traces ) > self .saturation_threshold_unscaled , axis = 1 )
127117
128118 if self .voltage_per_sec_threshold is not None :
129119 fs = self .sampling_frequency
@@ -138,25 +128,23 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
138128 else :
139129 saturation = saturation > self .proportion
140130
141- intervals = np .where (np .diff (saturation , prepend = False , append = False ))[ 0 ]
131+ intervals = np .flatnonzero (np .diff (saturation , prepend = False , append = False ))
142132 n_events = len (intervals ) // 2 # Number of saturation periods
143133 events = np .zeros (n_events , dtype = artifact_dtype )
144134
145135 for i , (start , stop ) in enumerate (zip (intervals [::2 ], intervals [1 ::2 ])):
146136 events [i ]["start_sample_index" ] = start + start_frame
147137 events [i ]["end_sample_index" ] = stop + start_frame
148138 events [i ]["segment_index" ] = segment_index
149- # events[i]["method_id"] = "saturation_detection"
150139
151140 return (events , )
152141
153142
154143def detect_saturation_periods (
155144 recording ,
156145 saturation_threshold_uV , # 1200 uV
157- voltage_per_sec_threshold , # 1e-8 V.s-1
146+ voltage_per_sec_threshold = None , # 1e-8 V.s-1
158147 proportion = 0.5 ,
159- mute_window_samples = 7 ,
160148 job_kwargs = None ,
161149):
162150 """
@@ -174,7 +162,7 @@ def detect_saturation_periods(
174162 The recording on which to detect the saturation events.
175163 saturation_threshold_uV : float
176164 The voltage saturation threshold in volts. This will depend on the recording
177- probe and amplifier gain settings. For NP1 the value of 1200 * 1e-6 is recommended (IBL).
165+ probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL).
178166 Note that NP2 probes are more difficult to saturate than NP1.
179167 voltage_per_sec_threshold : None | float
180168 The first-derivative threshold in volts per second. Periods of the data over which the change
@@ -207,24 +195,17 @@ def detect_saturation_periods(
207195 saturation_threshold_uV = saturation_threshold_uV ,
208196 voltage_per_sec_threshold = voltage_per_sec_threshold ,
209197 proportion = proportion ,
210- mute_window_samples = mute_window_samples ,
211198 )
212199
213- saturation_periods = run_node_pipeline (recording , [node0 ], job_kwargs = job_kwargs , job_name = "detect saturation events " )
200+ saturation_periods = run_node_pipeline (recording , [node0 ], job_kwargs = job_kwargs , job_name = "detect saturation artifacts " )
214201
215202 return _collapse_events (saturation_periods )
216203
217204
218205
219206## detect_artifact_periods_by_envelope Zone
220207
221- # _internal_dtype = [
222- # ("sample_index", "int64"),
223- # ("segment_index", "int64"),
224- # ("front", "bool")
225- # ]
226-
227- class DetectThresholdCrossing (PeakDetector ):
208+ class _DetectThresholdCrossing (PeakDetector ):
228209
229210 name = "threshold_crossings"
230211 preferred_mp_context = None
@@ -243,6 +224,7 @@ def __init__(
243224 random_slices_kwargs ["seed" ] = seed
244225 noise_levels = get_noise_levels (recording , return_in_uV = False , random_slices_kwargs = random_slices_kwargs )
245226 self .abs_thresholds = noise_levels * detect_threshold
227+ # internal dtype
246228 self ._dtype = np .dtype ([
247229 ("sample_index" , "int64" ),
248230 ("segment_index" , "int64" ),
@@ -278,7 +260,7 @@ def detect_artifact_periods_by_envelope(
278260 random_slices_kwargs = None ,
279261):
280262 """
281- Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of
263+ Function to detect putative artifact periods as threshold crossings of
282264 a global envelope of the channels.
283265
284266 Parameters
@@ -300,8 +282,6 @@ def detect_artifact_periods_by_envelope(
300282 envelope = GaussianFilterRecording (envelope , freq_min = None , freq_max = freq_max )
301283 envelope = CommonReferenceRecording (envelope )
302284
303-
304- # _, job_kwargs = split_job_kwargs(noise_levels_kwargs)
305285 job_kwargs = fix_job_kwargs (job_kwargs )
306286 if random_slices_kwargs is None :
307287 random_slices_kwargs = {}
@@ -310,28 +290,25 @@ def detect_artifact_periods_by_envelope(
310290 random_slices_kwargs ["seed" ] = seed
311291 noise_levels = get_noise_levels (envelope , return_in_uV = False , random_slices_kwargs = random_slices_kwargs )
312292
313- node0 = DetectThresholdCrossing (
293+ node0 = _DetectThresholdCrossing (
314294 recording , detect_threshold = detect_threshold , noise_levels = noise_levels , seed = seed ,
315295 )
316296
317297 threshold_crossings = run_node_pipeline (
318298 envelope ,
319299 [node0 ],
320300 job_kwargs ,
321- job_name = "detect threshold crossings " ,
301+ job_name = "detect artifact on envelope " ,
322302 )
323303
324304 order = np .lexsort ((threshold_crossings ["sample_index" ], threshold_crossings ["segment_index" ]))
325305 threshold_crossings = threshold_crossings [order ]
326306
327307 artifacts = _transform_internal_dtype_to_artifact_dtype (threshold_crossings , recording )
328308
329-
330309 return artifacts , envelope
331310
332311
333- # tools
334-
335312def _transform_internal_dtype_to_artifact_dtype (artifacts , recording ):
336313
337314 num_seg = recording .get_num_segments ()
0 commit comments