1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@
- Support `reference=None` in `fine_tune` (random initialization), now used as the default
- Added a `fast` argument in `novae.plot.domains` for a quick (but less accurate) rendering of domains.
- Added `novae.settings.scale_to_microns` if the coordinates are not in microns
- Use [`fast-array-utils`](https://github.com/scverse/fast-array-utils) to support multiple backends in `adata.X`

### Breaking changes
- Remove support for `python==3.10`
11 changes: 10 additions & 1 deletion docs/advice.md
@@ -27,7 +27,16 @@ If you have a rare tissue or a tissue that was not used in our large dataset, yo
For the zero-shot and fine-tuning modes, you can provide a `reference` slide (or multiple slides). This allows recomputing the model prototypes (i.e., the centroids of the spatial domains) based on the chosen slides.

- For [zero-shot](../api/Novae/#novae.Novae.compute_representations), we use `reference="all"` by default, meaning we use all slides to recompute the prototypes. Depending on your use case, you may consider specifying one or multiple **representative** slides.
- For [fine-tuning](../api/Novae/#novae.Novae.fine_tune), we use `reference=None` by default, meaning we will initialize the prototypes randomly, and re-train them. If you have only one slide, it may be worth trying `reference="all"`.
- For [fine-tuning](../api/Novae/#novae.Novae.fine_tune), we use `reference=None` by default, meaning we will initialize the prototypes randomly, and re-train them. **If you have only one slide**, it may be worth trying `reference="all"`.
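
For the fine-tuning case above, a minimal usage sketch (assuming a pre-trained `novae.Novae` instance `model` and a single-slide `AnnData` object `adata` are already loaded; see the linked API reference for the full signature):

```python
# Single slide: recompute the prototypes from the data
# instead of the random default initialization
model.fine_tune(adata, reference="all")
```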

### Handling large datasets

Novae uses lazy-loading during model training (i.e., you don't need a lot of GPU memory), but you still need to be able to load your dataset into CPU memory. We recommend using a sparse `csr_matrix` in `adata.X` by default, but if your dataset becomes too large, even a sparse matrix may not fit anymore.

In that case, you can use other backends, such as Dask (see the [AnnData tutorials](https://anndata.readthedocs.io/en/stable/tutorials/index.html)). You don't need to change anything in your code: `novae` will handle the Dask backend!

!!! info "Chunk sizes"
The chunk size will influence how fast the mini-batches are created. We will soon perform some benchmarks to see how best to choose the chunk size.
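
A minimal sketch of switching an in-memory `adata.X` to a Dask-backed array (assuming `adata.X` is currently a dense NumPy array; the 10,000-cell chunk size is purely illustrative, given the benchmarks above are pending):

```python
import dask.array as da

# Wrap the existing matrix in a lazily-evaluated Dask array, chunked along
# the cell axis only (each chunk keeps all genes, matching how rows are
# sliced into mini-batches).
adata.X = da.from_array(adata.X, chunks=(10_000, adata.n_vars))
```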

### Hyperparameters
We recommend using the default Novae hyperparameters, which should work great in most cases. Yet, if you are comfortable with Novae, you might consider updating them. In that case, here are some of the most important hyperparameters in [`fit`](../api/Novae/#novae.Novae.fit) or [`fine_tune`](../api/Novae/#novae.Novae.fine_tune):
22 changes: 7 additions & 15 deletions novae/data/convert.py
@@ -1,7 +1,8 @@
import numpy as np
import torch
from anndata import AnnData
from scipy.sparse import csr_matrix
from fast_array_utils import stats
from fast_array_utils.conv import to_dense
from sklearn.preprocessing import LabelEncoder
from torch import Tensor

@@ -47,14 +48,11 @@ def _compute_means_stds(self) -> tuple[Tensor, Tensor, LabelEncoder]:
for slide_id in slide_ids.cat.categories:
adata_slide = adata[adata.obs[Keys.SLIDE_ID] == slide_id, self._keep_var(adata)]

mean = adata_slide.X.mean(0)
mean = mean.A1 if isinstance(mean, np.matrix) else mean
means[slide_id] = mean.astype(np.float32)
mean, var = stats.mean_var(adata_slide.X, axis=0)
mean, var = to_dense(mean, to_cpu_memory=True), to_dense(var, to_cpu_memory=True)

std = (
adata_slide.X.std(0) if isinstance(adata_slide.X, np.ndarray) else _sparse_std(adata_slide.X, 0).A1
)
stds[slide_id] = std.astype(np.float32)
means[slide_id] = mean.astype(np.float32)
stds[slide_id] = np.sqrt(var).astype(np.float32)

label_encoder = LabelEncoder()
label_encoder.fit(list(means.keys()))
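
For context, `stats.mean_var` from `fast-array-utils` computes the per-column mean and variance in a single pass and accepts NumPy, SciPy sparse, or Dask inputs, which is what lets the two previous code paths (dense vs. sparse) collapse into one. A standalone sketch of the equivalence (not part of the diff):

```python
import numpy as np
from fast_array_utils import stats
from scipy.sparse import csr_matrix

x = csr_matrix(np.array([[1.0, 0.0], [0.0, 5.0], [7.0, 0.0]]))
mean, var = stats.mean_var(x, axis=0)

# Matches the dense NumPy results (population variance, i.e. ddof=0)
assert np.allclose(mean, x.toarray().mean(axis=0))
assert np.allclose(np.sqrt(var), x.toarray().std(axis=0))
```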
@@ -84,8 +82,7 @@ def to_tensor(self, adata: AnnData) -> Tensor:
mean = torch.stack([self.means[i] for i in slide_id_indices]) # TODO: avoid stack (only if not fast enough)
std = torch.stack([self.stds[i] for i in slide_id_indices])

X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()
X = torch.tensor(X, dtype=torch.float32)
X = torch.tensor(to_dense(adata.X, to_cpu_memory=True), dtype=torch.float32)
X = (X - mean) / (std + Nums.EPS)

return X
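
Here `to_dense(..., to_cpu_memory=True)` densifies whatever backend `adata.X` uses into a NumPy array in CPU memory (materializing Dask arrays), so the `torch.tensor` call no longer needs backend-specific branches. A standalone sketch:

```python
import dask.array as da
import numpy as np
import torch
from fast_array_utils.conv import to_dense

# A Dask-backed matrix is computed and densified on the CPU,
# then converted to a tensor exactly like a plain NumPy array.
x = da.from_array(np.eye(3, dtype=np.float32))
t = torch.tensor(to_dense(x, to_cpu_memory=True), dtype=torch.float32)
assert isinstance(t, torch.Tensor) and t.shape == (3, 3)
```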
@@ -108,8 +105,3 @@ def __getitem__(self, item: tuple[int, slice]) -> tuple[Tensor, Tensor]:
adata_view = adata[obs_indices]

return self.to_tensor(adata_view), self.genes_indices_list[adata_index]


def _sparse_std(a: csr_matrix, axis=None) -> np.matrix:
a_squared = a.multiply(a)
return np.sqrt(a_squared.mean(axis) - np.square(a.mean(axis)))
6 changes: 4 additions & 2 deletions novae/plot/_spatial.py
@@ -6,6 +6,8 @@
import scanpy as sc
import seaborn as sns
from anndata import AnnData
from fast_array_utils import stats
from fast_array_utils.conv import to_dense
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
from scanpy._utils import sanitize_anndata
@@ -192,8 +194,8 @@ def spatially_variable_genes(
axis=1,
)

where = (adata.X > 0).mean(0) > min_positive_ratio
valid_vars = adata.var_names[where.A1 if isinstance(where, np.matrix) else where]
positive_ratio = to_dense(stats.mean(adata.X > 0, axis=0), to_cpu_memory=True)
valid_vars = adata.var_names[positive_ratio > min_positive_ratio]
assert len(valid_vars) >= top_k, (
f"Only {len(valid_vars)} genes are available. Please decrease `top_k` or `min_positive_ratio`."
)
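
To sanity-check the new filter (a standalone sketch using the same matrix as the test file below): only the first gene is positive in more than half of the cells, so it is the only one kept at `min_positive_ratio=0.5`.

```python
import numpy as np
from fast_array_utils import stats
from fast_array_utils.conv import to_dense

x = np.array([[1, 0, 0, 3], [0, 5, 6, 0], [7, 0, -1, 0]])
positive_ratio = to_dense(stats.mean(x > 0, axis=0), to_cpu_memory=True)
print(positive_ratio)        # [0.667 0.333 0.333 0.333]
print(positive_ratio > 0.5)  # [ True False False False] -> one gene kept
```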
1 change: 1 addition & 0 deletions pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
"safetensors>=0.4.3",
"pandas>=2.0.0",
"igraph>=0.11.8",
"fast-array-utils>=1.3.1",
]

[project.optional-dependencies]
53 changes: 53 additions & 0 deletions tests/test_backend.py
@@ -0,0 +1,53 @@
import dask.array as da
import numpy as np
from fast_array_utils import stats
from fast_array_utils.conv import to_dense
from scipy.sparse import csr_matrix

x_np = np.array([
[1, 0, 0, 3],
[0, 5, 6, 0],
[7, 0, -1, 0],
])

x_sparse = csr_matrix(x_np)

x_dask = da.from_array(x_np)


def test_backends_mean_var():
mean_np, std_np = np.mean(x_np, axis=0), np.std(x_np, axis=0)

fau_mean_np, fau_var_np = stats.mean_var(x_np, axis=0)
fau_mean_sparse, fau_var_sparse = stats.mean_var(x_sparse, axis=0)
fau_mean_dask, fau_var_dask = stats.mean_var(x_dask, axis=0)

assert np.allclose(mean_np, fau_mean_np)
assert np.allclose(std_np, fau_var_np**0.5)
assert np.allclose(mean_np, fau_mean_sparse)
assert np.allclose(std_np, fau_var_sparse**0.5)
assert np.allclose(mean_np, fau_mean_dask)
assert np.allclose(std_np, fau_var_dask**0.5)


def test_backends_min_max():
min_np, max_np = np.min(x_np, axis=0), np.max(x_np, axis=0)

fau_min_np = stats.min(x_np, axis=0)
fau_min_sparse = stats.min(x_sparse, axis=0)
fau_min_dask = stats.min(x_dask, axis=0)
fau_max_np = stats.max(x_np, axis=0)
fau_max_sparse = stats.max(x_sparse, axis=0)
fau_max_dask = stats.max(x_dask, axis=0)

assert np.allclose(min_np, fau_min_np)
assert np.allclose(max_np, fau_max_np)
assert np.allclose(min_np, fau_min_sparse)
assert np.allclose(max_np, fau_max_sparse)
assert np.allclose(min_np, fau_min_dask)
assert np.allclose(max_np, fau_max_dask)


def test_spatially_variable_genes_backend():
assert (to_dense(stats.mean(x_np > 0, axis=0), to_cpu_memory=True) > 0.5).sum() == 1
assert (to_dense(stats.mean(x_sparse > 0, axis=0), to_cpu_memory=True) > 0.5).sum() == 1