Skip to content

Commit 3a6e954

Browse files
committed
WIP
1 parent 8fdf96d commit 3a6e954

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/spikeinterface/postprocessing/template_similarity.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def _handle_backward_compatibility_on_load(self):
5555
# make compatible analyzer created between february 24 and july 24
5656
self.params["max_lag_ms"] = 0.0
5757
self.params["support"] = "union"
58+
if "lags" not in self.data:
59+
self.data["lags"] = np.zeros_like(self.data["similarity"], dtype=np.int32)
5860

5961
def _set_params(self, method="cosine", max_lag_ms=0, support="union"):
6062
if method == "cosine_similarity":
@@ -180,7 +182,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
180182
s = self.data["similarity"][old_ind1, old_units_inds]
181183
similarity[unit_ind1, sub_units_inds] = s
182184
similarity[sub_units_inds, unit_ind1] = s
183-
185+
184186
l = self.data["lags"][old_ind1, old_units_inds]
185187
lags[unit_ind1, sub_units_inds] = l
186188
lags[sub_units_inds, unit_ind1] = l
@@ -258,8 +260,8 @@ def _compute_similarity_matrix_numpy(
258260
tgt_templates = tgt_sliced_templates[overlapping_templates]
259261
for gcount, j in enumerate(overlapping_templates):
260262
# symmetric values are handled later
261-
# if same_array and j < i:
262-
# no need exhaustive looping when same template
263+
#if same_array and j < i:
264+
# no need exhaustive looping when same template
263265
# continue
264266
src = src_template[:, local_mask[j]].reshape(1, -1)
265267
tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1)
@@ -283,7 +285,7 @@ def _compute_similarity_matrix_numpy(
283285

284286
if same_array:
285287
distances[num_shifts_both_sides - count - 1, j, i] = distances[count, i, j]
286-
288+
287289
return distances
288290

289291

@@ -353,8 +355,8 @@ def _compute_similarity_matrix_numba(
353355

354356
j = overlapping_templates[gcount]
355357
# symmetric values are handled later
356-
# if same_array and j < i:
357-
# no need exhaustive looping when same template
358+
#if same_array and j < i:
359+
# no need exhaustive looping when same template
358360
# continue
359361
src = src_template[:, local_mask[j]].flatten()
360362
tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten()
@@ -393,7 +395,7 @@ def _compute_similarity_matrix_numba(
393395
if same_array:
394396
distances[num_shifts_both_sides - count - 1, j, i] = distances[count, i, j]
395397

396-
# if same_array and num_shifts != 0:
398+
#if same_array and num_shifts != 0:
397399
# distances[num_shifts_both_sides - count - 1] = distances[count].T
398400

399401
return distances

0 commit comments

Comments
 (0)