1313from fast_array_utils import stats
1414
1515from .. import logging as logg
16- from .._compat import CSBase , CSRBase , DaskArray , old_positionals , warn
16+ from .._compat import CSBase , CSRBase , DaskArray , njit , old_positionals , warn
1717from .._settings import Verbosity , settings
1818from .._utils import (
1919 check_nonnegative_integers ,
@@ -98,8 +98,7 @@ def _(data_batch: CSBase, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]
9898 )
9999
100100
101- # parallel=False needed for accuracy
102- @numba .njit (cache = True , parallel = False ) # noqa: TID251
101+ @njit
103102def _sum_and_sum_squares_clipped (
104103 indices : NDArray [np .integer ],
105104 data : NDArray [np .floating ],
@@ -108,13 +107,34 @@ def _sum_and_sum_squares_clipped(
108107 clip_val : NDArray [np .float64 ],
109108 nnz : int ,
110109) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
111- squared_batch_counts_sum = np .zeros (n_cols , dtype = np .float64 )
112- batch_counts_sum = np .zeros (n_cols , dtype = np .float64 )
110+ """
111+ Parallel implementation using thread-local buffers to avoid race conditions.
112+
113+ The previous implementation ran with parallel=False because of a race condition on
114+ the shared output arrays. This version accumulates into per-thread buffers and sums
115+ them afterwards, so the loop can run in parallel without giving up correctness.
116+ """
117+ # Thread-local accumulators for parallel reduction
118+ n_threads = numba .get_num_threads ()
119+ squared_local = np .zeros ((n_threads , n_cols ), dtype = np .float64 )
120+ sum_local = np .zeros ((n_threads , n_cols ), dtype = np .float64 )
121+
122+ # Parallel accumulation into thread-local buffers (no race condition)
113123 for i in numba .prange (nnz ):
124+ tid = numba .get_thread_id ()
114125 idx = indices [i ]
115126 element = min (np .float64 (data [i ]), clip_val [idx ])
116- squared_batch_counts_sum [idx ] += element ** 2
117- batch_counts_sum [idx ] += element
127+ squared_local [tid , idx ] += element ** 2
128+ sum_local [tid , idx ] += element
129+
130+ # Reduction phase: combine thread-local results
131+ squared_batch_counts_sum = np .zeros (n_cols , dtype = np .float64 )
132+ batch_counts_sum = np .zeros (n_cols , dtype = np .float64 )
133+
134+ for t in range (n_threads ):
135+ for j in range (n_cols ):
136+ squared_batch_counts_sum [j ] += squared_local [t , j ]
137+ batch_counts_sum [j ] += sum_local [t , j ]
118138
119139 return squared_batch_counts_sum , batch_counts_sum
120140
0 commit comments