
Commit bd5a121

perf: try out new clipping numba kernel
1 parent f70e681

File tree

3 files changed: +50 -19 lines

benchmarks/benchmarks/preprocessing_log.py

Lines changed: 21 additions & 12 deletions
@@ -34,7 +34,7 @@ class PreprocessingSuite:  # noqa: D101

     def setup_cache(self) -> None:
         """Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
-        for dataset, layer in product(*self.params):
+        for dataset, layer in product(*self.params[:2]):
             adata, _ = get_dataset(dataset, layer=layer)
             adata.write_h5ad(f"{dataset}_{layer}.h5ad")

@@ -47,17 +47,6 @@ def time_pca(self, *_) -> None:
     def peakmem_pca(self, *_) -> None:
         sc.pp.pca(self.adata, svd_solver="arpack")

-    def time_highly_variable_genes(self, *_) -> None:
-        # the default flavor runs on log-transformed data
-        sc.pp.highly_variable_genes(
-            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
-        )
-
-    def peakmem_highly_variable_genes(self, *_) -> None:
-        sc.pp.highly_variable_genes(
-            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
-        )
-
     # regress_out is very slow for this dataset
     @skip_when(dataset={"pbmc3k"})
     def time_regress_out(self, *_) -> None:

@@ -72,3 +61,23 @@ def time_scale(self, *_) -> None:

     def peakmem_scale(self, *_) -> None:
         sc.pp.scale(self.adata, max_value=10)
+
+
+class HVGSuite(PreprocessingSuite):  # noqa: D101
+    params = (*params, ["seurat_v3", "cell_ranger", "seurat"])
+    param_names = (*param_names, "flavor")
+
+    def setup(self, dataset, layer, flavor) -> None:
+        self.adata = ad.read_h5ad(f"{dataset}_{layer}.h5ad")
+        self.flavor = flavor
+
+    def time_highly_variable_genes(self, *_) -> None:
+        # the default flavor runs on log-transformed data
+        sc.pp.highly_variable_genes(
+            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor
+        )
+
+    def peakmem_highly_variable_genes(self, *_) -> None:
+        sc.pp.highly_variable_genes(
+            self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor
+        )
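For context, asv builds one benchmark case per element of the Cartesian product of `params` and passes that combination to `setup` and to every `time_*`/`peakmem_*` method; that is also why `setup_cache` now iterates `product(*self.params[:2])`: the cached `.h5ad` files depend only on dataset and layer, not on the new flavor axis. A rough sketch of how the `HVGSuite` cases are enumerated; the dataset and layer values below are placeholders for illustration, only the flavor list comes from this commit:

from itertools import product

datasets = ["pbmc3k"]      # placeholder; the real values live in the benchmark module
layers = [None, "counts"]  # placeholder
flavors = ["seurat_v3", "cell_ranger", "seurat"]  # added by this commit

for dataset, layer, flavor in product(datasets, layers, flavors):
    # For each combination, asv calls HVGSuite.setup(dataset, layer, flavor),
    # then runs time_highly_variable_genes / peakmem_highly_variable_genes.
    print(dataset, layer, flavor)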

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -198,6 +198,8 @@ filterwarnings = [
     "ignore:.*'(parseAll)'.*'(parse_all)':DeprecationWarning",
     # igraph vs leidenalg warning
     "ignore:The `igraph` implementation of leiden clustering:UserWarning",
+    "ignore:Detected unsupported threading environment:UserWarning",
+    "ignore:Cannot cache compiled function",
 ]

 [tool.coverage.run]

src/scanpy/preprocessing/_highly_variable_genes.py

Lines changed: 27 additions & 7 deletions
@@ -13,7 +13,7 @@
 from fast_array_utils import stats

 from .. import logging as logg
-from .._compat import CSBase, CSRBase, DaskArray, old_positionals, warn
+from .._compat import CSBase, CSRBase, DaskArray, njit, old_positionals, warn
 from .._settings import Verbosity, settings
 from .._utils import (
     check_nonnegative_integers,

@@ -98,8 +98,7 @@ def _(data_batch: CSBase, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]
     )


-# parallel=False needed for accuracy
-@numba.njit(cache=True, parallel=False)  # noqa: TID251
+@njit
 def _sum_and_sum_squares_clipped(
     indices: NDArray[np.integer],
     data: NDArray[np.floating],

@@ -108,13 +107,34 @@ def _sum_and_sum_squares_clipped(
     clip_val: NDArray[np.float64],
     nnz: int,
 ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
-    squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
-    batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
+    """
+    Parallel implementation using thread-local buffers to avoid race conditions.
+
+    The previous implementation used parallel=False due to a race condition on the
+    shared accumulator arrays. This version uses an explicit thread-local reduction
+    to restore both correctness and parallelism.
+    """
+    # Thread-local accumulators for the parallel reduction
+    n_threads = numba.get_num_threads()
+    squared_local = np.zeros((n_threads, n_cols), dtype=np.float64)
+    sum_local = np.zeros((n_threads, n_cols), dtype=np.float64)
+
+    # Parallel accumulation into thread-local buffers (no race condition)
     for i in numba.prange(nnz):
+        tid = numba.get_thread_id()
         idx = indices[i]
         element = min(np.float64(data[i]), clip_val[idx])
-        squared_batch_counts_sum[idx] += element**2
-        batch_counts_sum[idx] += element
+        squared_local[tid, idx] += element**2
+        sum_local[tid, idx] += element
+
+    # Reduction phase: combine the thread-local results
+    squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
+    batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
+
+    for t in range(n_threads):
+        for j in range(n_cols):
+            squared_batch_counts_sum[j] += squared_local[t, j]
+            batch_counts_sum[j] += sum_local[t, j]

     return squared_batch_counts_sum, batch_counts_sum
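Stripped of the surrounding scanpy machinery, the new kernel's pattern is: give each thread its own accumulator row, index it with `numba.get_thread_id()` inside the `prange` loop, then do a serial reduction over the thread axis. Below is a minimal, self-contained sketch of that technique with a dense NumPy cross-check. It is an illustration rather than the scanpy implementation; the function and variable names are made up here, and it assumes a numba version that provides `numba.get_thread_id` (as the diff itself does).

import numba
import numpy as np
from scipy import sparse


@numba.njit(parallel=True)
def clipped_column_sums(indices, data, n_cols, clip_val, nnz):
    # Illustrative sketch, not scanpy's _sum_and_sum_squares_clipped.
    n_threads = numba.get_num_threads()
    # One accumulator row per thread: threads never write to the same row,
    # so the prange loop is race-free without atomics.
    sq_local = np.zeros((n_threads, n_cols), dtype=np.float64)
    sum_local = np.zeros((n_threads, n_cols), dtype=np.float64)
    for i in numba.prange(nnz):
        tid = numba.get_thread_id()
        col = indices[i]
        x = min(np.float64(data[i]), clip_val[col])
        sq_local[tid, col] += x * x
        sum_local[tid, col] += x
    # Serial reduction over the (small) thread dimension.
    sq = np.zeros(n_cols, dtype=np.float64)
    s = np.zeros(n_cols, dtype=np.float64)
    for t in range(n_threads):
        for j in range(n_cols):
            sq[j] += sq_local[t, j]
            s[j] += sum_local[t, j]
    return sq, s


# Cross-check against a dense NumPy reference on random CSR data.
x = sparse.random(2000, 300, density=0.05, format="csr", random_state=0)
clip = np.full(x.shape[1], 0.5)
sq, s = clipped_column_sums(x.indices, x.data, x.shape[1], clip, x.nnz)
dense = np.minimum(x.toarray(), clip)  # zeros stay zero; stored values get clipped
np.testing.assert_allclose(s, dense.sum(axis=0))
np.testing.assert_allclose(sq, (dense**2).sum(axis=0))

The per-thread buffers trade a small amount of extra memory (n_threads x n_cols) for the ability to keep the scatter-add loop parallel, which is the point of the commit: the old kernel had to run with parallel=False to stay correct.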
