
Commit c541ba0

Merge branch 'select_sorting_periods' of github.com:alejoe91/spikeinterface into select_sorting_periods
2 parents: 1fd1fd4 + d0a1e66

4 files changed: +82 additions, −92 deletions


src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 11 additions & 0 deletions
@@ -901,6 +901,17 @@ class BaseMetricExtension(AnalyzerExtension):
     need_backward_compatibility_on_load = False
     metric_list: list[BaseMetric] = None  # list of BaseMetric
 
+    @classmethod
+    def get_available_metric_names(cls):
+        """Get the available metric names.
+
+        Returns
+        -------
+        available_metric_names : list[str]
+            List of available metric names.
+        """
+        return [m.metric_name for m in cls.metric_list]
+
     @classmethod
     def get_default_metric_params(cls):
         """Get the default metric parameters.
src/spikeinterface/metrics/quality/misc_metrics.py

Lines changed: 33 additions & 45 deletions
@@ -1129,13 +1129,20 @@ def compute_drift_metrics(
     unit_ids = sorting.unit_ids
 
     spike_locations_ext = sorting_analyzer.get_extension("spike_locations")
-    spike_locations_array = spike_locations_ext.get_data(periods=periods)
+    spike_locations_by_unit_and_segments = spike_locations_ext.get_data(
+        outputs="by_unit", concatenated=False, periods=periods
+    )
     spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods)
 
     segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())]
-    assert direction in spike_locations_array.dtype.names, (
-        f"Direction {direction} is invalid. Available directions: " f"{spike_locations_array.dtype.names}"
+    data = spike_locations_by_unit[unit_ids[0]]
+    assert direction in data.dtype.names, (
+        f"Direction {direction} is invalid. Available directions: " f"{data.dtype.names}"
+    )
+    bin_edges_for_units = compute_bin_edges_per_unit(
+        sorting, segment_samples=segment_samples, periods=periods, bin_duration_s=interval_s, concatenated=False
     )
+    failed_units = []
 
     # we need
     drift_ptps = {}
@@ -1144,62 +1151,43 @@ def compute_drift_metrics(
 
     # reference positions are the medians across segments
     reference_positions = {}
+    median_position_segments = {unit_id: np.array([]) for unit_id in unit_ids}
+
     for unit_id in unit_ids:
        reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction])
 
-    # now compute median positions and concatenate them over segments
-    spike_vector = sorting.to_spike_vector()
-    spike_sample_indices = spike_vector["sample_index"].copy()
-    # we need to add the cumulative sum of segment samples to have global sample indices
-    cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1])
     for segment_index in range(sorting_analyzer.get_num_segments()):
-        segment_slice = sorting._get_spike_vector_segment_slices()[segment_index]
-        spike_sample_indices[segment_slice[0] : segment_slice[1]] += cumulative_segment_samples[segment_index]
-
-    bin_edges_for_units = compute_bin_edges_per_unit(
-        sorting,
-        segment_samples=segment_samples,
-        periods=periods,
-        bin_duration_s=interval_s,
-    )
-
-    failed_units = []
-    median_positions_per_unit = {}
+        for unit_id in unit_ids:
+            bins = bin_edges_for_units[unit_id][segment_index]
+            num_bin_edges = len(bins)
+            if (num_bin_edges - 1) < min_num_bins:
+                failed_units.append(unit_id)
+                continue
+            median_positions = np.nan * np.zeros((num_bin_edges - 1))
+            spikes_in_segment_of_unit = sorting.get_unit_spike_train(unit_id, segment_index)
+            bounds = np.searchsorted(spikes_in_segment_of_unit, bins, side="left")
+            for bin_index, (i0, i1) in enumerate(zip(bounds[:-1], bounds[1:])):
+                spike_locations_in_bin = spike_locations_by_unit_and_segments[segment_index][unit_id][i0:i1][direction]
+                if (i1 - i0) >= min_spikes_per_interval:
+                    median_positions[bin_index] = np.median(spike_locations_in_bin)
+            median_position_segments[unit_id] = np.concatenate((median_position_segments[unit_id], median_positions))
+
+    # finally, compute deviations and drifts
     for unit_id in unit_ids:
-        bins = bin_edges_for_units[unit_id]
-        num_bins = len(bins) - 1
-        if num_bins < min_num_bins:
+        # Skip units that already failed because not enough bins in at least one segment
+        if unit_id in failed_units:
            drift_ptps[unit_id] = np.nan
            drift_stds[unit_id] = np.nan
            drift_mads[unit_id] = np.nan
-            failed_units.append(unit_id)
            continue
-
-        # bin_edges are global across segments, so we have to use spike_sample_indices,
-        # since we offseted them to be global
-        bin_spike_indices = np.searchsorted(spike_sample_indices, bins)
-        median_positions = np.nan * np.zeros(num_bins)
-        for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])):
-            spikes_in_bin = spike_vector[i0:i1]
-            spike_locations_in_bin = spike_locations_array[i0:i1][direction]
-
-            unit_index = sorting_analyzer.sorting.id_to_index(unit_id)
-            mask = spikes_in_bin["unit_index"] == unit_index
-            if np.sum(mask) >= min_spikes_per_interval:
-                median_positions[bin_index] = np.median(spike_locations_in_bin[mask])
-            else:
-                median_positions[bin_index] = np.nan
-        median_positions_per_unit[unit_id] = median_positions
-
-        # now compute deviations and drifts for this unit
-        position_diff = median_positions - reference_positions[unit_id]
+        position_diff = median_position_segments[unit_id] - reference_positions[unit_id]
        if np.any(np.isnan(position_diff)):
            # deal with nans: if more than 50% nans --> set to nan
            if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff):
-                failed_units.append(unit_id)
                ptp_drift = np.nan
                std_drift = np.nan
                mad_drift = np.nan
+                failed_units.append(unit_id)
            else:
                ptp_drift = np.nanmax(position_diff) - np.nanmin(position_diff)
                std_drift = np.nanstd(np.abs(position_diff))
@@ -1219,7 +1207,7 @@ def compute_drift_metrics(
     )
 
     if return_positions:
-        outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit
+        outs = res(drift_ptps, drift_stds, drift_mads), median_positions
     else:
        outs = res(drift_ptps, drift_stds, drift_mads)
    return outs
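The refactor replaces the global, offset-based binning with a per-segment pass: for each segment, each unit's spike train is cut into bins with np.searchsorted, and a median location is kept per bin when it holds at least min_spikes_per_interval spikes. Below is a minimal, self-contained numpy sketch of that binning step; the data is synthetic and spike times stand in for the spike locations used in the real code.

import numpy as np

# Synthetic data, not from the repository: spike sample indices of one unit in
# one segment, plus the per-segment bin edges that
# compute_bin_edges_per_unit(..., concatenated=False) would provide.
spike_samples = np.array([120, 450, 900, 1500, 2100, 2950])
bin_edges = np.array([0, 1000, 2000, 3000])
min_spikes_per_interval = 2

# For each bin edge, searchsorted returns the index of the first spike at or
# after that edge, so consecutive bounds delimit the spikes falling in each bin.
bounds = np.searchsorted(spike_samples, bin_edges, side="left")
median_positions = np.full(len(bin_edges) - 1, np.nan)
for bin_index, (i0, i1) in enumerate(zip(bounds[:-1], bounds[1:])):
    if (i1 - i0) >= min_spikes_per_interval:
        # the real code takes the median of spike *locations* in the bin
        median_positions[bin_index] = np.median(spike_samples[i0:i1])

print(bounds)            # [0 3 4 6]
print(median_positions)  # [ 450.   nan 2525.]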

src/spikeinterface/metrics/quality/tests/test_metrics_functions.py

Lines changed: 7 additions & 38 deletions
@@ -13,10 +13,7 @@
 
 from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, create_regular_periods
 
-from spikeinterface.metrics.quality import (
-    get_quality_metric_list,
-    compute_quality_metrics,
-)
+from spikeinterface.metrics.quality import get_quality_metric_list, compute_quality_metrics, ComputeQualityMetrics
 from spikeinterface.metrics.quality.misc_metrics import (
     misc_metrics_list,
     compute_amplitude_cutoffs,
@@ -657,37 +654,9 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
 
     # can't use _misc_metric_name_to_func as some functions compute several qms
     # e.g. isi_violation and synchrony
-    quality_metrics = [
-        "num_spikes",
-        "firing_rate",
-        "presence_ratio",
-        "snr",
-        "isi_violations_ratio",
-        "isi_violations_count",
-        "rp_contamination",
-        "rp_violations",
-        "sliding_rp_violation",
-        "amplitude_cutoff",
-        "amplitude_median",
-        "amplitude_cv_median",
-        "amplitude_cv_range",
-        "sync_spike_2",
-        "sync_spike_4",
-        "sync_spike_8",
-        "firing_range",
-        "drift_ptp",
-        "drift_std",
-        "drift_mad",
-        "sd_ratio",
-        "isolation_distance",
-        "l_ratio",
-        "d_prime",
-        "silhouette",
-        "nn_hit_rate",
-        "nn_miss_rate",
-    ]
-
-    small_sorting_analyzer.compute("quality_metrics")
+    quality_metric_columns = ComputeQualityMetrics.get_metric_columns()
+    all_metrics = ComputeQualityMetrics.get_available_metric_names()
+    small_sorting_analyzer.compute("quality_metrics", metric_names=all_metrics)
 
     cache_folder = create_cache_folder
     output_folder = cache_folder / "sorting_analyzer"
@@ -699,7 +668,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
     saved_metrics = csv.reader(metrics_file)
     metric_names = next(saved_metrics)
 
-    for metric_name in quality_metrics:
+    for metric_name in quality_metric_columns:
        assert metric_name in metric_names
 
     folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False)
@@ -708,7 +677,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
     saved_metrics = csv.reader(metrics_file)
     metric_names = next(saved_metrics)
 
-    for metric_name in quality_metrics:
+    for metric_name in quality_metric_columns:
        assert metric_name in metric_names
 
     folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True)
@@ -717,7 +686,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
     saved_metrics = csv.reader(metrics_file)
     metric_names = next(saved_metrics)
 
-    for metric_name in quality_metrics:
+    for metric_name in quality_metric_columns:
        if metric_name == "snr":
            assert metric_name in metric_names
        else:
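The test now derives both the expected CSV columns and the metrics to compute from classmethods instead of a hard-coded list. A rough sketch of the distinction between the two; the example names are illustrative, not guaranteed output.

from spikeinterface.metrics.quality import ComputeQualityMetrics

# Names accepted by compute(..., metric_names=...) — one entry per metric.
all_metrics = ComputeQualityMetrics.get_available_metric_names()

# Column headers found in the saved metrics table — a single metric such as
# "synchrony" can expand into several columns (sync_spike_2, sync_spike_4, ...).
quality_metric_columns = ComputeQualityMetrics.get_metric_columns()

print(all_metrics)             # e.g. [..., "synchrony", "sd_ratio", ...]
print(quality_metric_columns)  # e.g. [..., "sync_spike_2", "drift_ptp", "drift_std", ...]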

src/spikeinterface/metrics/utils.py

Lines changed: 31 additions & 9 deletions
@@ -4,7 +4,7 @@
 from spikeinterface.core.base import unit_period_dtype
 
 
-def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None):
+def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None, concatenated=True):
     """
     Compute bin edges for units, optionally taking into account periods.
 
@@ -18,6 +18,16 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per
         Duration of each bin in seconds
     periods : array of unit_period_dtype, default: None
         Periods to consider for each unit
+    concatenated : bool, default: True
+        Whether the bins are concatenated across segments or not.
+        If False, the bin edges are computed per segment and the first index of each segment is 0.
+        If True, the bin edges are computed on the concatenated segments, with the correct offsets.
+
+    Returns
+    -------
+    dict
+        Bin edges for each unit. If concatenated is True, the bin edges are a 1D array.
+        If False, the bin edges are a list of arrays, one per segment.
     """
     bin_edges_for_units = {}
     num_segments = len(segment_samples)
@@ -31,27 +41,38 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per
         for seg_index in range(num_segments):
            seg_periods = periods_unit[periods_unit["segment_index"] == seg_index]
            if len(seg_periods) == 0:
+                if not concatenated:
+                    bin_edges.append(np.array([]))
                continue
-            seg_start = np.sum(segment_samples[:seg_index])
+            seg_start = np.sum(segment_samples[:seg_index]) if concatenated else 0
+            bin_edges_segment = []
            for period in seg_periods:
                start_sample = seg_start + period["start_sample_index"]
                end_sample = seg_start + period["end_sample_index"]
                end_sample = end_sample // bin_duration_samples * bin_duration_samples + 1  # align to bin
-                bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples))
-            bin_edges_for_units[unit_id] = np.unique(np.array(bin_edges))
+                bin_edges_segment.extend(np.arange(start_sample, end_sample, bin_duration_samples))
+            bin_edges_segment = np.unique(np.array(bin_edges_segment))
+            if concatenated:
+                bin_edges.extend(bin_edges_segment)
+            else:
+                bin_edges.append(bin_edges_segment)
+        bin_edges_for_units[unit_id] = bin_edges
     else:
        for unit_id in sorting.unit_ids:
            bin_edges = []
            for seg_index in range(num_segments):
-                seg_start = np.sum(segment_samples[:seg_index])
+                seg_start = np.sum(segment_samples[:seg_index]) if concatenated else 0
                seg_end = seg_start + segment_samples[seg_index]
                # for segments which are not the last, we don't need to correct the end
                # since the first index of the next segment will be the end of the current segment
                if seg_index == num_segments - 1:
                    seg_end = seg_end // bin_duration_samples * bin_duration_samples + 1  # align to bin
-                bins = np.arange(seg_start, seg_end, bin_duration_samples)
-                bin_edges.extend(bins)
-                bin_edges_for_units[unit_id] = np.array(bin_edges)
+                bin_edges_segment = np.arange(seg_start, seg_end, bin_duration_samples)
+                if concatenated:
+                    bin_edges.extend(bin_edges_segment)
+                else:
+                    bin_edges.append(bin_edges_segment)
+            bin_edges_for_units[unit_id] = bin_edges
     return bin_edges_for_units
 
 
@@ -82,7 +103,8 @@ def compute_total_samples_per_unit(sorting_analyzer, periods=None):
                num_samples_in_period += period["end_sample_index"] - period["start_sample_index"]
            total_samples[unit_id] = num_samples_in_period
     else:
-        total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids}
+        total = sorting_analyzer.get_total_samples()
+        total_samples = {unit_id: total for unit_id in sorting_analyzer.unit_ids}
     return total_samples
 
 
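compute_bin_edges_per_unit now takes a concatenated flag. A rough usage sketch under assumptions not stated in the diff: a toy two-segment sorting from generate_sorting, 10 kHz sampling, 1 s bins; the import path follows the file touched in this commit, and the commented edges are what the logic above implies rather than verified output.

import numpy as np
from spikeinterface.core import generate_sorting
from spikeinterface.metrics.utils import compute_bin_edges_per_unit

# Hypothetical two-segment sorting; durations chosen only for illustration.
sorting = generate_sorting(num_units=2, sampling_frequency=10_000.0, durations=[2.0, 3.0])
segment_samples = [20_000, 30_000]

# concatenated=True (default): one array of global bin edges per unit.
edges_concat = compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0)

# concatenated=False: one array per segment, each restarting at sample 0.
edges_per_segment = compute_bin_edges_per_unit(
    sorting, segment_samples, bin_duration_s=1.0, concatenated=False
)

unit_id = sorting.unit_ids[0]
print(edges_concat[unit_id])       # e.g. [0, 10000, 20000, 30000, 40000, 50000]
print(edges_per_segment[unit_id])  # e.g. [array([0, 10000]), array([0, 10000, 20000, 30000])]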