@@ -55,6 +55,8 @@ def _handle_backward_compatibility_on_load(self):
5555 # make compatible analyzer created between february 24 and july 24
5656 self .params ["max_lag_ms" ] = 0.0
5757 self .params ["support" ] = "union"
58+ if "lags" not in self .data :
59+ self .data ["lags" ] = np .zeros_like (self .data ["similarity" ], dtype = np .int32 )
5860
5961 def _set_params (self , method = "cosine" , max_lag_ms = 0 , support = "union" ):
6062 if method == "cosine_similarity" :
@@ -180,7 +182,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
180182 s = self .data ["similarity" ][old_ind1 , old_units_inds ]
181183 similarity [unit_ind1 , sub_units_inds ] = s
182184 similarity [sub_units_inds , unit_ind1 ] = s
183-
185+
184186 l = self .data ["lags" ][old_ind1 , old_units_inds ]
185187 lags [unit_ind1 , sub_units_inds ] = l
186188 lags [sub_units_inds , unit_ind1 ] = l
@@ -258,8 +260,8 @@ def _compute_similarity_matrix_numpy(
258260 tgt_templates = tgt_sliced_templates [overlapping_templates ]
259261 for gcount , j in enumerate (overlapping_templates ):
260262 # symmetric values are handled later
261- # if same_array and j < i:
262- # no need exhaustive looping when same template
263+ #if same_array and j < i:
264+ # no need exhaustive looping when same template
263265 # continue
264266 src = src_template [:, local_mask [j ]].reshape (1 , - 1 )
265267 tgt = (tgt_templates [gcount ][:, local_mask [j ]]).reshape (1 , - 1 )
@@ -283,7 +285,7 @@ def _compute_similarity_matrix_numpy(
283285
284286 if same_array :
285287 distances [num_shifts_both_sides - count - 1 , j , i ] = distances [count , i , j ]
286-
288+
287289 return distances
288290
289291
@@ -353,8 +355,8 @@ def _compute_similarity_matrix_numba(
353355
354356 j = overlapping_templates [gcount ]
355357 # symmetric values are handled later
356- # if same_array and j < i:
357- # no need exhaustive looping when same template
358+ #if same_array and j < i:
359+ # no need exhaustive looping when same template
358360 # continue
359361 src = src_template [:, local_mask [j ]].flatten ()
360362 tgt = (tgt_templates [gcount ][:, local_mask [j ]]).flatten ()
@@ -393,7 +395,7 @@ def _compute_similarity_matrix_numba(
393395 if same_array :
394396 distances [num_shifts_both_sides - count - 1 , j , i ] = distances [count , i , j ]
395397
396- # if same_array and num_shifts != 0:
398+ #if same_array and num_shifts != 0:
397399 # distances[num_shifts_both_sides - count - 1] = distances[count].T
398400
399401 return distances
0 commit comments