Skip to content

Commit d18be07

Browse files
authored
Merge pull request #2 from MarioniLab/devel
2 parents 9117c35 + 646d9e8 commit d18be07

20 files changed

+1441
-60
lines changed

.github/workflows/test.yaml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
push:
55
branches: [master]
66
pull_request:
7-
branches: [master]
7+
branches: [master, devel]
88

99
jobs:
1010
test:
@@ -25,7 +25,7 @@ jobs:
2525

2626
steps:
2727
- uses: actions/checkout@v2
28-
- uses: r-lib/actions/setup-r@v1
28+
- uses: r-lib/actions/setup-r@v2
2929
- name: Set up Python ${{ matrix.python }}
3030
uses: actions/setup-python@v2
3131
with:
@@ -62,5 +62,8 @@ jobs:
6262
- name: Upload coverage
6363
env:
6464
CODECOV_NAME: ${{ matrix.python }}-${{ matrix.os }}
65-
run: |
66-
codecov --required --flags=unittests
65+
# run: |
66+
# codecov --required --flags=unittests
67+
uses: codecov/codecov-action@v3
68+
with:
69+
token: ${{secrets.CODECOV_TOKEN}}

.readthedocs.yaml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
# https://docs.readthedocs.io/en/stable/config-file/v2.html
22
version: 2
3-
4-
conda:
5-
environment: environment.yaml
6-
73
build:
84
os: ubuntu-20.04
95
tools:
106
python: "3.10"
11-
127
sphinx:
138
configuration: docs/conf.py
9+
# disable this for more lenient docs builds
1410
fail_on_warning: true
15-
1611
python:
1712
install:
18-
- requirements: docs/requirements.txt
13+
- method: pip
14+
path: .
15+
extra_requirements:
16+
- doc

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ dependencies = [
2626
"scanpy",
2727
"scvi-tools",
2828
"milopy @ git+https://github.com/emdann/milopy.git@master",
29-
"sklearn"
29+
"sklearn",
30+
"meld",
31+
"cna"
3032
]
3133

3234
[project.optional-dependencies]

src/oor_benchmark/api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def check_method(adata: AnnData):
2323
assert "OOR_score" in adata.uns["sample_adata"].var
2424
assert "OOR_signif" in adata.uns["sample_adata"].var
2525
assert all(adata.uns["sample_adata"].var["OOR_signif"].isin([0, 1]))
26-
assert "groups" in adata.uns["sample_adata"].varm
27-
assert isinstance(adata.uns["sample_adata"].varm["groups"], csc_matrix)
26+
if "groups" in adata.uns["sample_adata"].varm:
27+
assert isinstance(adata.uns["sample_adata"].varm["groups"], csc_matrix)
2828
return True
2929

3030

@@ -49,8 +49,8 @@ def sample_dataset():
4949
adata.obs.loc[adata.obs["sample_id"].isin([f"S{n}" for n in range(8)]), "dataset_group"] = "atlas"
5050
adata.obs.loc[adata.obs["sample_id"].isin([f"S{n}" for n in range(8, 12)]), "dataset_group"] = "ctrl"
5151
adata.obs.loc[adata.obs["sample_id"].isin([f"S{n}" for n in range(12, 16)]), "dataset_group"] = "query"
52-
# # Make out-of-reference cell state
53-
# adata.obs["OOR_state"] = np.where(adata.obs["louvain"] == "B cells", 1, 0)
54-
# remove_cells = adata.obs_names[(adata.obs["OOR_state"] == 1) & (adata.obs["dataset_group"] != "query")]
55-
# adata = adata[~adata.obs_names.isin(remove_cells)].copy()
52+
# Make out-of-reference cell state
53+
adata.obs["OOR_state"] = np.where(adata.obs["louvain"] == "B cells", 1, 0)
54+
remove_cells = adata.obs_names[(adata.obs["OOR_state"] == 1) & (adata.obs["dataset_group"] != "query")]
55+
adata = adata[~adata.obs_names.isin(remove_cells)].copy()
5656
return adata

src/oor_benchmark/datasets/simulation.py

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import List, Union
22

33
import numpy as np
4+
import scanpy as sc
45
from anndata import AnnData
6+
from sklearn.neighbors import KNeighborsClassifier
57

68

79
def _split_train_test(adata: AnnData, annotation_col: str = "leiden", test_frac: float = 0.2):
@@ -23,10 +25,12 @@ def simulate_query_reference(
2325
ctrl_batch: Union[List[str], None] = None,
2426
annotation_col: str = "leiden",
2527
query_annotation: Union[List[str], None] = None,
26-
perturbation_type: str = "remove",
28+
perturbation_type: Union[str, List[str]] = "remove",
2729
test_frac: float = 0.2,
28-
DA_frac: float = 0.2,
30+
# DA_frac: float = 0.2,
31+
split_pc: int = 0,
2932
seed=42,
33+
use_rep_shift: str = "X_scVI",
3034
):
3135
"""
3236
Split single-cell dataset in a atlas, control and query dataset.
@@ -56,12 +60,18 @@ def simulate_query_reference(
5660
will be removed from the samples in ctrl_batch (the fraction specified by DA_test)
5761
if equal to 'depletion' a fraction of the cells in population specified in query_annotation
5862
will be removed from the samples in query_batch (the fraction specified by DA_test)
63+
if equal to 'shift', the query population will be shifted along a principal component
64+
5965
test_frac:
6066
fraction of cells in each population to be included in the query group (only used if batch_col is None)
6167
DA_frac:
6268
the fraction of cells of query_annotation to keep in control if perturbation_type is 'expansion', or in query if perturbation_type is 'depletion'
69+
split_pc:
70+
index of PC to use for splitting (default: 0, using PC1) (only used if perturbation_type=shift)
6371
seed:
6472
random seed for sampling
73+
use_rep_shift:
74+
representation to use to find neighbors in atlas dataset for shift perturbation (default: 'X_scVI')
6575
6676
Returns:
6777
--------
@@ -99,29 +109,63 @@ def simulate_query_reference(
99109
query_annotation = np.random.choice(adata.obs[annotation_col].unique(), size=1)
100110

101111
#  Apply perturbation
102-
if perturbation_type == "remove":
103-
adata.obs.loc[(adata.obs[annotation_col].isin(query_annotation)), "is_train"] = 0
104-
if ctrl_batch is not None:
105-
adata.obs.loc[(adata.obs[annotation_col].isin(query_annotation)), "is_ctrl"] = 0
106-
107-
elif perturbation_type == "expansion":
108-
for b in ctrl_batch:
109-
query_pop_cells = adata.obs_names[
110-
(adata.obs[batch_col] == b) & (adata.obs[annotation_col].isin(query_annotation))
112+
if isinstance(perturbation_type, str):
113+
perturb_types = [perturbation_type] * len(query_annotation)
114+
elif isinstance(perturbation_type, list):
115+
assert len(perturbation_type) == len(
116+
query_annotation
117+
), "If perturbation_type is a list, it should be the same length as query_annotation"
118+
perturb_types = perturbation_type.copy()
119+
else:
120+
raise TypeError(
121+
"perturbation_type should be a string or a list of strings of the same length as query_annotation"
122+
)
123+
124+
perturb_annotations = query_annotation.copy()
125+
oor_cells = []
126+
127+
for query_annotation, perturbation_type in zip(perturb_annotations, perturb_types):
128+
if perturbation_type == "remove":
129+
adata.obs.loc[(adata.obs[annotation_col] == query_annotation), "is_train"] = 0
130+
if ctrl_batch is not None:
131+
adata.obs.loc[(adata.obs[annotation_col] == query_annotation), "is_ctrl"] = 0
132+
oor_cells_p = adata.obs_names[adata.obs[annotation_col] == query_annotation].tolist()
133+
oor_cells.extend(oor_cells_p)
134+
135+
elif perturbation_type == "shift":
136+
split_pop_cells = adata.obs_names[
137+
(adata.obs[annotation_col] == query_annotation) & (adata.obs["is_train"] == 0)
111138
]
112-
cells2remove = np.random.choice(query_pop_cells, size=int(np.round(len(query_pop_cells) * (1 - DA_frac))))
113-
adata.obs.loc[cells2remove, "is_ctrl"] = 0
114-
115-
elif perturbation_type == "depletion":
116-
for b in query_batch:
117-
query_pop_cells = adata.obs_names[
118-
(adata.obs[batch_col] == b) & (adata.obs[annotation_col].isin(query_annotation))
139+
# Run PCA on perturbation population (just query dataset to avoid batch effects)
140+
split_pop_adata = adata[adata.obs_names.isin(split_pop_cells)].copy()
141+
sc.pp.normalize_per_cell(split_pop_adata)
142+
sc.pp.log1p(split_pop_adata)
143+
sc.pp.pca(split_pop_adata)
144+
pc2split = split_pop_adata.obsm["X_pca"][:, split_pc]
145+
test_size = int(np.round(len(split_pop_cells) * 0.5))
146+
idx = np.argpartition(pc2split, test_size)
147+
cells2remove = split_pop_cells[idx[:test_size]].values
148+
149+
# Find neighbors in atlas cells
150+
split_pop_adata.obs["remove"] = split_pop_adata.obs_names.isin(cells2remove).astype(int)
151+
split_pop_cells_atlas = adata.obs_names[
152+
(adata.obs[annotation_col] == query_annotation) & (adata.obs["is_train"] == 1)
119153
]
120-
cells2remove = np.random.choice(query_pop_cells, size=int(np.round(len(query_pop_cells) * (1 - DA_frac))))
121-
adata.obs.loc[cells2remove, "is_query"] = 0
154+
X_train = adata[split_pop_cells].obsm[use_rep_shift]
155+
Y_train = split_pop_adata.obs["remove"]
156+
X_atlas = adata[split_pop_cells_atlas].obsm[use_rep_shift]
122157

123-
else:
124-
raise ValueError("perturbation type should be one of 'remove' or 'perturb_pc'")
158+
neigh = KNeighborsClassifier(n_neighbors=10)
159+
neigh = neigh.fit(X_train, Y_train)
160+
atlas_cells2remove = split_pop_cells_atlas[neigh.predict(X_atlas) == 1]
161+
162+
adata.obs.loc[cells2remove, "is_ctrl"] = 0
163+
adata.obs.loc[atlas_cells2remove, "is_train"] = 0
164+
oor_cells_p = adata.obs_names[(adata.obs["is_test"] == 1) & (adata.obs_names.isin(cells2remove))].tolist()
165+
oor_cells.extend(oor_cells_p)
166+
167+
else:
168+
raise ValueError("perturbation type should be one of 'remove' or 'shift'")
125169
adata.uns["perturbation"] = {
126170
"annotation_col": annotation_col,
127171
"batch_col": batch_col,
@@ -137,7 +181,10 @@ def simulate_query_reference(
137181
adata.obs["dataset_group"] = np.where(adata.obs["is_train"] == 1, "atlas", adata.obs["dataset_group"])
138182
adata = adata[adata.obs["dataset_group"] != "exclude"].copy() # remove cells that are not in any group
139183

140-
adata.obs["OOR_state"] = (adata.obs[annotation_col].isin(query_annotation)).astype(int)
184+
# if perturbation_type == "remove":
185+
# adata.obs["OOR_state"] = (adata.obs[annotation_col].isin(query_annotation)).astype(int)
186+
# elif perturbation_type == "shift":
187+
adata.obs["OOR_state"] = (adata.obs_names.isin(oor_cells)).astype(int)
141188

142189
adata.obs["cell_annotation"] = adata.obs[annotation_col].copy()
143190
adata.obs["sample_id"] = adata.obs[batch_col].copy()

src/oor_benchmark/methods/_cna.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import cna
2+
from anndata import AnnData
3+
from multianndata import MultiAnnData
4+
5+
6+
def run_cna(adata_design: AnnData, query_group: str, reference_group: str, sample_col: str = "sample_id"):
7+
"""
8+
Run CNA to compute neighbourhood-level association with condition.
9+
10+
Following tutorial in https://nbviewer.org/github/yakirr/cna/blob/master/demo/demo.ipynb
11+
12+
Parameters:
13+
------------
14+
adata_design : AnnData
15+
AnnData object of disease and reference cells to compare
16+
query_group : str
17+
Name of query group in adata_design.obs['dataset_group']
18+
reference_group : str
19+
Name of reference group in adata_design.obs['dataset_group']
20+
sample_col : str
21+
Name of column in adata_design.obs to use as sample ID
22+
"""
23+
adata_design = MultiAnnData(adata_design, sampleid=sample_col)
24+
adata_design.obs["dataset_group"] = adata_design.obs["dataset_group"].astype("category")
25+
adata_design.obs["dataset_group_code"] = (
26+
adata_design.obs["dataset_group"].cat.reorder_categories([reference_group, query_group]).cat.codes
27+
)
28+
adata_design.obs_to_sample(["dataset_group_code"])
29+
30+
res = cna.tl.association(adata_design, adata_design.samplem.dataset_group_code)
31+
32+
adata_design.obs["CNA_ncorrs"] = res.ncorrs
33+
34+
return None

src/oor_benchmark/methods/_latent_embedding.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,11 @@ def embedding_scvi(adata_merge: AnnData, n_hvgs: int = 5000, outdir: str = None,
2727
dataset_groups = adata_merge.obs["dataset_group"].unique().tolist()
2828
dataset_groups.sort()
2929
ref_dataset = "".join(dataset_groups)
30-
# adata_merge = anndata.concat([adata_query, adata_ref])
31-
# adata_merge.layers["counts"] = adata_merge.X.copy()
3230
adata_merge_train = adata_merge.copy()
3331

3432
# Filter genes
3533
adata_merge_train.layers["counts"] = adata_merge_train.X.copy()
36-
_filter_genes_scvi(adata_merge_train)
34+
_filter_genes_scvi(adata_merge_train, n_hvgs=n_hvgs)
3735

3836
# Train scVI model
3937
if outdir is not None:
@@ -42,7 +40,6 @@ def embedding_scvi(adata_merge: AnnData, n_hvgs: int = 5000, outdir: str = None,
4240

4341
# Get latent embeddings
4442
adata_merge.obsm["X_scVI"] = model_scvi.get_latent_representation()
45-
# return adata_merge
4643

4744

4845
def embedding_scArches(
@@ -79,7 +76,7 @@ def embedding_scArches(
7976
assert ref_dataset in adata_merge.obs["dataset_group"].unique().tolist()
8077
adata_merge.layers["counts"] = adata_merge.X.copy()
8178
adata_ref_train = adata_merge[adata_merge.obs["dataset_group"] == ref_dataset].copy()
82-
_filter_genes_scvi(adata_ref_train)
79+
_filter_genes_scvi(adata_ref_train, n_hvgs=n_hvgs)
8380

8481
# Train scVI model
8582
if outdir is not None:
@@ -122,6 +119,9 @@ def _train_scVI(train_adata: AnnData, train_params: dict = None, outfile: str =
122119
\**kwargs : dict, optional
123120
Extra arguments to `scvi.model.SCVI.setup_anndata` (specifying batch etc)
124121
"""
122+
if train_params is None:
123+
train_params = {}
124+
125125
scvi.model.SCVI.setup_anndata(train_adata, layer="counts", **kwargs)
126126

127127
arches_params = {
@@ -173,7 +173,7 @@ def _fit_scVI(
173173
# --- Latent embedding utils --- #
174174

175175

176-
def _filter_genes_scvi(adata: AnnData):
176+
def _filter_genes_scvi(adata: AnnData, n_hvgs: int = 5000) -> None:
177177
"""Filter genes for latent embedding."""
178178
# Filter genes not expressed anywhere
179179
sc.pp.filter_genes(adata, min_cells=1)
@@ -183,4 +183,4 @@ def _filter_genes_scvi(adata: AnnData):
183183
sc.pp.normalize_per_cell(adata)
184184
sc.pp.log1p(adata)
185185

186-
sc.pp.highly_variable_genes(adata, n_top_genes=5000, subset=True)
186+
sc.pp.highly_variable_genes(adata, n_top_genes=n_hvgs, subset=True)

src/oor_benchmark/methods/_meld.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import meld
2+
import numpy as np
3+
import pandas as pd
4+
from anndata import AnnData
5+
6+
7+
def run_meld(
8+
adata_design: AnnData, query_group: str, reference_group: str, sample_col: str = "sample_id", n_neighbors=10
9+
):
10+
"""
11+
Run MELD to compute probability estimate per condition.
12+
13+
Parameters:
14+
------------
15+
adata_design : AnnData
16+
AnnData object of disease and reference cells to compare
17+
query_group : str
18+
Name of query group in adata_design.obs['dataset_group']
19+
reference_group : str
20+
Name of reference group in adata_design.obs['dataset_group']
21+
sample_col : str
22+
Name of column in adata_design.obs to use as sample ID
23+
n_neighbors : int
24+
Number of neighbors to use for MELD KNN graph (default: 10)
25+
"""
26+
adata_design.obs["is_query"] = adata_design.obs["dataset_group"] == query_group
27+
adata_design.uns["n_conditions"] = 2
28+
29+
# Complete the result in-place
30+
meld_op = meld.MELD(knn=n_neighbors, verbose=True)
31+
adata_design.obsm["sample_densities"] = meld_op.fit_transform(
32+
adata_design.obsm["X_scVI"], sample_labels=adata_design.obs[sample_col]
33+
).set_index(adata_design.obs_names)
34+
35+
# Normalize the probability estimates for each condition per replicate
36+
adata_design.obsm["probability_estimate"] = pd.DataFrame(
37+
np.zeros(shape=(adata_design.n_obs, adata_design.uns["n_conditions"])),
38+
index=adata_design.obs_names,
39+
columns=["query", "reference"],
40+
)
41+
42+
query_samples = adata_design.obs["sample_id"][adata_design.obs["dataset_group"] == query_group].unique().tolist()
43+
reference_samples = (
44+
adata_design.obs["sample_id"][adata_design.obs["dataset_group"] == reference_group].unique().tolist()
45+
)
46+
47+
adata_design.obsm["probability_estimate"]["query"] = adata_design.obsm["sample_densities"][query_samples].mean(
48+
axis=1
49+
)
50+
adata_design.obsm["probability_estimate"]["reference"] = adata_design.obsm["sample_densities"][
51+
reference_samples
52+
].mean(axis=1)
53+
adata_design.obsm["probability_estimate"] = meld.utils.normalize_densities(
54+
adata_design.obsm["probability_estimate"]
55+
)
56+
57+
return None

0 commit comments

Comments
 (0)