Skip to content

Commit f0d0ba7

Browse files
committed
Speed up function which was already fast but Sam didn't like it
1 parent c541ba0 commit f0d0ba7

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

src/spikeinterface/metrics/utils.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,12 @@ def compute_total_samples_per_unit(sorting_analyzer, periods=None):
9393
Total number of samples for each unit.
9494
"""
9595
if periods is not None:
96-
total_samples = {}
96+
total_samples_array = np.zeros(len(sorting_analyzer.unit_ids), dtype="int64")
9797
sorting = sorting_analyzer.sorting
98-
for unit_id in sorting.unit_ids:
99-
unit_index = sorting.id_to_index(unit_id)
100-
periods_unit = periods[periods["unit_index"] == unit_index]
101-
num_samples_in_period = 0
102-
for period in periods_unit:
103-
num_samples_in_period += period["end_sample_index"] - period["start_sample_index"]
104-
total_samples[unit_id] = num_samples_in_period
98+
for period in periods:
99+
unit_index = period["unit_index"]
100+
total_samples_array[unit_index] += period["end_sample_index"] - period["start_sample_index"]
101+
total_samples = dict(zip(sorting.unit_ids, total_samples_array))
105102
else:
106103
total = sorting_analyzer.get_total_samples()
107104
total_samples = {unit_id: total for unit_id in sorting_analyzer.unit_ids}

0 commit comments

Comments
 (0)