11from typing import List , Union
22
33import numpy as np
4+ import scanpy as sc
45from anndata import AnnData
6+ from sklearn .neighbors import KNeighborsClassifier
57
68
79def _split_train_test (adata : AnnData , annotation_col : str = "leiden" , test_frac : float = 0.2 ):
@@ -23,10 +25,12 @@ def simulate_query_reference(
2325 ctrl_batch : Union [List [str ], None ] = None ,
2426 annotation_col : str = "leiden" ,
2527 query_annotation : Union [List [str ], None ] = None ,
26- perturbation_type : str = "remove" ,
28+ perturbation_type : Union [ str , List [ str ]] = "remove" ,
2729 test_frac : float = 0.2 ,
28- DA_frac : float = 0.2 ,
30+ # DA_frac: float = 0.2,
31+ split_pc : int = 0 ,
2932 seed = 42 ,
33+ use_rep_shift : str = "X_scVI" ,
3034):
3135 """
3236 Split single-cell dataset in a atlas, control and query dataset.
@@ -56,12 +60,18 @@ def simulate_query_reference(
5660 will be removed from the samples in ctrl_batch (the fraction specified by DA_test)
5761 if equal to 'depletion' a fraction of the cells in population specified in query_annotation
5862 will be removed from the samples in query_batch (the fraction specified by DA_test)
63+ if equal to shift, the query population will be shifted along a principal component
64+
5965 test_frac:
6066 fraction of cells in each population to be included in the query group (only used if batch_col is None)
6167 DA_frac:
6268 the fraction of cells of query_annotation to keep in control if perturbation_type is 'expansion', or in query if perturbation_type is 'depletion'
69+ split_pc:
70+ index of PC to use for splitting (default: 0, using PC1) (only used if perturbation_type=shift)
6371 seed:
6472 random seed for sampling
73+ use_rep_shift:
74+ representation to use to find neighbors in atlas dataset for shift perturbation (default: 'X_scVI')
6575
6676 Returns:
6777 --------
@@ -99,29 +109,63 @@ def simulate_query_reference(
99109 query_annotation = np .random .choice (adata .obs [annotation_col ].unique (), size = 1 )
100110
101111 # Apply perturbation
102- if perturbation_type == "remove" :
103- adata .obs .loc [(adata .obs [annotation_col ].isin (query_annotation )), "is_train" ] = 0
104- if ctrl_batch is not None :
105- adata .obs .loc [(adata .obs [annotation_col ].isin (query_annotation )), "is_ctrl" ] = 0
106-
107- elif perturbation_type == "expansion" :
108- for b in ctrl_batch :
109- query_pop_cells = adata .obs_names [
110- (adata .obs [batch_col ] == b ) & (adata .obs [annotation_col ].isin (query_annotation ))
112+ if isinstance (perturbation_type , str ):
113+ perturb_types = [perturbation_type ] * len (query_annotation )
114+ elif isinstance (perturbation_type , list ):
115+ assert len (perturbation_type ) == len (
116+ query_annotation
117+ ), "If perturbation_type is a list, it should be the same length as query_annotation"
118+ perturb_types = perturbation_type .copy ()
119+ else :
120+ raise TypeError (
121+ "perturbation_type should be a string or a list of strings of the same length as query_annotation"
122+ )
123+
124+ perturb_annotations = query_annotation .copy ()
125+ oor_cells = []
126+
127+ for query_annotation , perturbation_type in zip (perturb_annotations , perturb_types ):
128+ if perturbation_type == "remove" :
129+ adata .obs .loc [(adata .obs [annotation_col ] == query_annotation ), "is_train" ] = 0
130+ if ctrl_batch is not None :
131+ adata .obs .loc [(adata .obs [annotation_col ] == query_annotation ), "is_ctrl" ] = 0
132+ oor_cells_p = adata .obs_names [adata .obs [annotation_col ] == query_annotation ].tolist ()
133+ oor_cells .extend (oor_cells_p )
134+
135+ elif perturbation_type == "shift" :
136+ split_pop_cells = adata .obs_names [
137+ (adata .obs [annotation_col ] == query_annotation ) & (adata .obs ["is_train" ] == 0 )
111138 ]
112- cells2remove = np .random .choice (query_pop_cells , size = int (np .round (len (query_pop_cells ) * (1 - DA_frac ))))
113- adata .obs .loc [cells2remove , "is_ctrl" ] = 0
114-
115- elif perturbation_type == "depletion" :
116- for b in query_batch :
117- query_pop_cells = adata .obs_names [
118- (adata .obs [batch_col ] == b ) & (adata .obs [annotation_col ].isin (query_annotation ))
139+ # Run PCA on perturbation population (just query dataset to avoid batch effects)
140+ split_pop_adata = adata [adata .obs_names .isin (split_pop_cells )].copy ()
141+ sc .pp .normalize_per_cell (split_pop_adata )
142+ sc .pp .log1p (split_pop_adata )
143+ sc .pp .pca (split_pop_adata )
144+ pc2split = split_pop_adata .obsm ["X_pca" ][:, split_pc ]
145+ test_size = int (np .round (len (split_pop_cells ) * 0.5 ))
146+ idx = np .argpartition (pc2split , test_size )
147+ cells2remove = split_pop_cells [idx [:test_size ]].values
148+
149+ # Find neighbors in atlas cells
150+ split_pop_adata .obs ["remove" ] = split_pop_adata .obs_names .isin (cells2remove ).astype (int )
151+ split_pop_cells_atlas = adata .obs_names [
152+ (adata .obs [annotation_col ] == query_annotation ) & (adata .obs ["is_train" ] == 1 )
119153 ]
120- cells2remove = np .random .choice (query_pop_cells , size = int (np .round (len (query_pop_cells ) * (1 - DA_frac ))))
121- adata .obs .loc [cells2remove , "is_query" ] = 0
154+ X_train = adata [split_pop_cells ].obsm [use_rep_shift ]
155+ Y_train = split_pop_adata .obs ["remove" ]
156+ X_atlas = adata [split_pop_cells_atlas ].obsm [use_rep_shift ]
122157
123- else :
124- raise ValueError ("perturbation type should be one of 'remove' or 'perturb_pc'" )
158+ neigh = KNeighborsClassifier (n_neighbors = 10 )
159+ neigh = neigh .fit (X_train , Y_train )
160+ atlas_cells2remove = split_pop_cells_atlas [neigh .predict (X_atlas ) == 1 ]
161+
162+ adata .obs .loc [cells2remove , "is_ctrl" ] = 0
163+ adata .obs .loc [atlas_cells2remove , "is_train" ] = 0
164+ oor_cells_p = adata .obs_names [(adata .obs ["is_test" ] == 1 ) & (adata .obs_names .isin (cells2remove ))].tolist ()
165+ oor_cells .extend (oor_cells_p )
166+
167+ else :
168+ raise ValueError ("perturbation type should be one of 'remove' or 'shift'" )
125169 adata .uns ["perturbation" ] = {
126170 "annotation_col" : annotation_col ,
127171 "batch_col" : batch_col ,
@@ -137,7 +181,10 @@ def simulate_query_reference(
137181 adata .obs ["dataset_group" ] = np .where (adata .obs ["is_train" ] == 1 , "atlas" , adata .obs ["dataset_group" ])
138182 adata = adata [adata .obs ["dataset_group" ] != "exclude" ].copy () # remove cells that are not in any group
139183
140- adata .obs ["OOR_state" ] = (adata .obs [annotation_col ].isin (query_annotation )).astype (int )
184+ # if perturbation_type == "remove":
185+ # adata.obs["OOR_state"] = (adata.obs[annotation_col].isin(query_annotation)).astype(int)
186+ # elif perturbation_type == "shift":
187+ adata .obs ["OOR_state" ] = (adata .obs_names .isin (oor_cells )).astype (int )
141188
142189 adata .obs ["cell_annotation" ] = adata .obs [annotation_col ].copy ()
143190 adata .obs ["sample_id" ] = adata .obs [batch_col ].copy ()
0 commit comments