#' Quick group-aware zero-inflation check (negative-binomial baseline via edgeR)
#'
#' @description
#' Computes a **sample group-aware Zero-Inflation (ZI) index** for each gene
#' using a negative-binomial (NB) baseline fitted with **edgeR**. For the
#' selected groups (e.g., drug conditions), the function:
#' 1) estimates gene-wise tagwise dispersions with edgeR (using all selected
#'    groups),
#' 2) builds NB-expected zero probabilities from TMMwsp-scaled means, and
#' 3) returns per-gene ZI (observed zero fraction minus NB-expected zero
#'    fraction) and per-group summaries (e.g., the proportion of genes with
#'    ZI > 0.1). ZI cutoffs are user-defined.
#'
#' This is intended as a **fast screening diagnostic** to decide whether
#' standard NB GLM methods (edgeR/DESeq2) are adequate or whether a
#' zero-aware workflow (e.g., ZINB-WaVE) might be warranted.
#'
#' This function **relies on edgeR** to estimate dispersion. The current
#' implementation requires **>= 2 groups** in the design so that edgeR can
#' stabilize gene-wise dispersions across groups. If you only have a single
#' group and still want a design-aware baseline for expected zeros, fit a
#' Gamma-Poisson/NB GLM and compute the expected zero probabilities from its
#' fitted means and over-dispersion.
#'
#' @param data Seurat object. Counts are read from the `counts` layer of the
#'   `RNA` assay.
#' @param group_by Character scalar, column in `data@meta.data` that defines
#'   groups (default: `"combined_id"`).
#' @param samples Character vector of group labels (exact matches against the
#'   `group_by` column) to include. If `NULL`, all groups in `group_by` are
#'   used. Labels that do not occur in `group_by` are dropped with a warning;
#'   at least two groups must remain, otherwise an error is raised.
#' @param batch Optional batch variable. If its length is 0 or 1 (the default
#'   `1`), no batch term is used and the design is intercept-free with group
#'   dummies (`~ 0 + group`). Otherwise it must hold one value per cell of
#'   `data` (it is then subset to the retained cells) or one value per
#'   retained cell, and the design becomes `~ 0 + group + batch`.
#' @param cutoffs Numeric vector of user-supplied ZI thresholds, each strictly
#'   between 0 and 1, used for the summary statistics.
#'
#' @returns A list with:
#' * `gene_metrics_by_group`: long data frame (group x gene) with
#'   `mean_count_group`, `dispersion`, `p0_obs`, `obs_zeros_num`, `p0_nb`,
#'   `expected_zeros_num` and `ZI`.
#' * `summary_by_group`: one row per group with medians, observed/expected
#'   zero **counts** for the group, and one `pct_ZI_gt_<cutoff>` column per
#'   cutoff holding the proportion (0-1) of genes whose ZI exceeds it.
#'
#' @note
#' - This is a **screening** tool; it is not a replacement for fitting a full
#'   GLM with your actual design. If strong covariates exist, a GLM baseline
#'   (e.g., `glmGamPoi::glm_gp`) will yield more faithful expected-zero rates.
#' - For single-group experiments, consider either adding a reference group
#'   or switching to a GLM-based baseline that does not require multiple
#'   groups.
#' @export
#' @examples
#' data(mini_mac)
#' check_zeroinflation(mini_mac, group_by = "combined_id",
#'                     samples = c("DMSO_0", "Staurosporine_10"))
check_zeroinflation <- function(data = NULL,
                                group_by = NULL,
                                samples = NULL,
                                batch = 1,
                                cutoffs = c(0.1, 0.20)) {
  validated <- .zi_validate_inputs(data, group_by, samples, batch, cutoffs)
  group_by <- validated$group_by
  samples <- validated$samples
  batch_term <- validated$batch

  # Subset by cell name so that any metadata column (not only `combined_id`)
  # can drive the grouping.
  mac_data <- subset(data, cells = validated$cells)
  count_matrix <- GetAssayData(mac_data, assay = "RNA", layer = "counts")
  count_matrix <- Matrix::Matrix(count_matrix, sparse = TRUE)

  # Align group labels (and batch) with the columns of the count matrix by
  # cell name rather than by position.
  cell_ids <- colnames(count_matrix)
  group_labels <- as.character(mac_data@meta.data[cell_ids, group_by])
  # Build the factor from the selected labels only, so unused levels of a
  # factor metadata column cannot create all-zero design columns.
  group_factor <- droplevels(factor(group_labels, levels = samples))
  samples <- levels(group_factor)

  # NOTE(review): a numeric `batch` enters the design as a continuous
  # covariate; pass a factor/character vector for categorical batches.
  design <- if (is.null(batch_term)) {
    stats::model.matrix(~ 0 + group_factor)
  } else {
    batch_aligned <- batch_term[cell_ids]
    stats::model.matrix(~ 0 + group_factor + batch_aligned)
  }

  # Tagwise NB dispersions from edgeR; NB variance is mu + phi * mu^2.
  dge <- edgeR::DGEList(counts = count_matrix)
  dge <- edgeR::calcNormFactors(dge, method = "TMMwsp")
  dge <- edgeR::estimateDisp(dge, design = design)
  phi <- dge$tagwise.dispersion
  # Effective (TMMwsp-scaled) library size per well.
  eff_lib <- dge$samples$lib.size * dge$samples$norm.factors

  per_group <- lapply(samples, function(g) {
    .zi_group_metrics(g, count_matrix, group_labels, phi, eff_lib)
  })
  gene_metrics_by_group <- do.call(rbind, per_group)

  # One summary row per group (row names are the group labels).
  by_group <- split(gene_metrics_by_group, gene_metrics_by_group$group)
  summary_rows <- lapply(names(by_group), function(g) {
    .zi_group_summary(by_group[[g]], sum(group_labels == g), cutoffs)
  })
  names(summary_rows) <- names(by_group)
  summary_by_group <- do.call(rbind, summary_rows)

  list(
    gene_metrics_by_group = gene_metrics_by_group, # long format: group x gene
    summary_by_group = summary_by_group
  )
}

# Validate the arguments of check_zeroinflation() and resolve defaults.
# Returns a list with the resolved `group_by`, the matched `samples`, the
# names of the retained cells (`cells`) and `batch` (NULL when no batch term
# is requested, otherwise a vector named by retained cell).
.zi_validate_inputs <- function(data, group_by, samples, batch, cutoffs) {
  if (!inherits(data, "Seurat")) {
    stop("argument 'data' must be a Seurat or TidySeurat object.",
         call. = FALSE)
  }
  if (is.null(group_by)) {
    group_by <- "combined_id"
  }
  if (!is.character(group_by) || length(group_by) != 1L || is.na(group_by)) {
    stop("'group_by' must be a single column name.", call. = FALSE)
  }
  meta <- data@meta.data
  if (!group_by %in% colnames(meta)) {
    stop("column '", group_by, "' not found in data@meta.data.",
         call. = FALSE)
  }

  cell_groups <- as.character(meta[[group_by]])
  available <- unique(cell_groups[!is.na(cell_groups)])
  if (is.null(samples)) {
    samples <- available
    message("All groups in '", group_by, "' will be included.")
  } else {
    samples <- unique(as.character(samples))
    unmatched <- setdiff(samples, available)
    if (length(unmatched) > 0L) {
      warning("samples not found in '", group_by, "' and ignored: ",
              paste(unmatched, collapse = ", "), call. = FALSE)
    }
    # Keep only labels that actually occur; a zero-well group would yield
    # divisions by zero downstream.
    samples <- intersect(samples, available)
  }
  # edgeR needs at least two groups for the dispersion estimation.
  if (length(samples) < 2L) {
    stop("Two treatment groups are needed to calculate dispersion using edgeR.",
         call. = FALSE)
  }

  # An empty `cutoffs` is tolerated (no pct_ZI_gt_* columns are produced).
  if (length(cutoffs) > 0L &&
      (!is.numeric(cutoffs) || anyNA(cutoffs) ||
       any(cutoffs <= 0 | cutoffs >= 1))) {
    stop("cutoffs must be between 0 and 1.", call. = FALSE)
  }

  keep <- cell_groups %in% samples
  cells <- rownames(meta)[keep]
  if (length(batch) <= 1L) {
    batch <- NULL
  } else {
    if (length(batch) == nrow(meta)) {
      batch <- batch[keep]
    } else if (length(batch) != sum(keep)) {
      stop("'batch' must have length 1, one value per cell in 'data', ",
           "or one value per retained cell.", call. = FALSE)
    }
    names(batch) <- cells
  }

  list(group_by = group_by, samples = samples, cells = cells, batch = batch)
}

# Per-gene zero-inflation metrics for the wells of a single group `g`.
# `phi` holds the gene-wise dispersions and `eff_lib` the effective library
# size of every well (column) of `count_matrix`.
.zi_group_metrics <- function(g, count_matrix, group_labels, phi, eff_lib) {
  idx <- which(group_labels == g)
  n_wells <- length(idx)
  counts_g <- count_matrix[, idx, drop = FALSE]

  # Count zeros through the non-zero pattern: `!= 0` stays sparse, whereas
  # `== 0` would materialise a (mostly TRUE) dense logical matrix.
  obs_zeros_num <- n_wells - as.numeric(Matrix::rowSums(counts_g != 0))
  p0_obs <- obs_zeros_num / n_wells

  # Group-specific per-gene rate: total counts over total effective library.
  eff_lib_g <- eff_lib[idx]
  gene_totals <- as.numeric(Matrix::rowSums(counts_g))
  rate <- gene_totals / sum(eff_lib_g)

  # Genes with (near-)zero dispersion use the Poisson limit exp(-mu).
  is_pois <- !is.na(phi) & phi < 1e-8
  phi_safe <- pmax(phi, 1e-12)

  # Accumulate the expected zero probability well by well so that no
  # genes x wells matrix is materialised.
  p0_sum <- numeric(nrow(count_matrix))
  for (lib_j in eff_lib_g) {
    mu <- rate * lib_j
    # (1 + phi * mu)^(-1 / phi), written with log1p for numerical stability.
    p0_j <- exp(-log1p(phi_safe * mu) / phi_safe)
    p0_j[is_pois] <- exp(-mu[is_pois])
    p0_sum <- p0_sum + p0_j
  }
  p0_nb <- p0_sum / n_wells

  data.frame(
    group = g,
    gene = rownames(count_matrix),
    mean_count_group = gene_totals / n_wells,
    dispersion = phi,
    p0_obs = p0_obs,
    obs_zeros_num = obs_zeros_num,
    p0_nb = p0_nb,
    expected_zeros_num = p0_nb * n_wells,
    ZI = p0_obs - p0_nb,
    row.names = NULL,
    stringsAsFactors = FALSE
  )
}

# One-row summary of the per-gene metrics `df` of a single group, including
# the proportion of genes whose ZI exceeds each value in `cutoffs`.
.zi_group_summary <- function(df, n_wells, cutoffs) {
  core <- list(
    group = df$group[1],
    n_genes = nrow(df),
    n_wells = n_wells,
    median_p0_obs = stats::median(df$p0_obs),
    median_p0_nb = stats::median(df$p0_nb),
    median_ZI = stats::median(df$ZI),
    observed_zeros_num = sum(df$obs_zeros_num),
    expected_zeros_num = sum(df$expected_zeros_num)
  )
  pct <- lapply(cutoffs, function(cutoff) mean(df$ZI > cutoff))
  names(pct) <- if (length(cutoffs) > 0L) {
    paste0("pct_ZI_gt_", cutoffs)
  } else {
    character(0)
  }
  as.data.frame(c(core, pct), stringsAsFactors = FALSE)
}
0 commit comments