Skip to content

Commit 132ae8b

Browse files
committed
memory optimizations and more accurate dry run
1 parent cb68cc9 commit 132ae8b

File tree

2 files changed

+67
-19
lines changed

2 files changed

+67
-19
lines changed

kompot/differential/differential_expression.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,7 @@ def compute_mahalanobis_distances(
512512

513513
# Average the covariance matrices
514514
combined_cov = (cov1 + cov2) / 2
515+
del cov1, cov2
515516

516517
# For sample variance, use diag=False to get full covariance matrices
517518
# Initialize variable to store gene-specific covariance matrices if needed
@@ -529,6 +530,7 @@ def compute_mahalanobis_distances(
529530
variance2 = self.variance_predictor2(variance_points, diag=False, progress=progress)
530531
# Add the covariance matrices for complete variance representation
531532
combined_variance = variance1 + variance2
533+
del variance1, variance2
532534

533535
# Check if we have gene-specific covariance matrices (shape has 3 dimensions)
534536
if len(combined_variance.shape) == 3:
@@ -568,6 +570,7 @@ def compute_mahalanobis_distances(
568570
else:
569571
combined_cov += variance1
570572
logger.debug("Added variance1 covariance matrix to function predictor covariance")
573+
del variance1
571574
except Exception as e:
572575
error_msg = f"Error computing sample variance from variance_predictor1: {e}."
573576
logger.error(error_msg)
@@ -595,6 +598,7 @@ def compute_mahalanobis_distances(
595598
# Add variance2 to the combined covariance
596599
combined_cov += variance2
597600
logger.debug("Added variance2 covariance matrix to function predictor covariance")
601+
del variance2
598602
except Exception as e:
599603
error_msg = f"Error computing sample variance from variance_predictor2: {e}."
600604
logger.error(error_msg)
@@ -786,49 +790,56 @@ def get_variance2(X_batch):
786790
desc="Computing sample variance (condition 2)" if progress else None
787791
)
788792
else:
789-
# Initialize with zeros if not using sample variance
790-
condition1_sample_variance = np.zeros_like(condition1_imputed)
791-
condition2_sample_variance = np.zeros_like(condition2_imputed)
792-
793+
# OPTIMIZATION 1: Use scalar 0 instead of zeros_like array (saves 6.8 GB at 1000 genes)
794+
# In the no-sample-variance (No SV) case the sample variance is identically zero, so full arrays are unnecessary
795+
condition1_sample_variance = 0
796+
condition2_sample_variance = 0
797+
793798
# Compute fold change
794799
fold_change = condition2_imputed - condition1_imputed
795-
800+
796801
# Ensure uncertainties have the right shape for broadcasting
797802
if len(condition1_uncertainty.shape) == 1:
798803
# Reshape to (n_samples, 1) for broadcasting with fold_change
799804
condition1_uncertainty = condition1_uncertainty[:, np.newaxis]
800805
if len(condition2_uncertainty.shape) == 1:
801806
# Reshape to (n_samples, 1) for broadcasting with fold_change
802807
condition2_uncertainty = condition2_uncertainty[:, np.newaxis]
803-
808+
804809
# Convert uncertainties to numpy arrays if needed
805810
condition1_uncertainty = np.asarray(condition1_uncertainty)
806811
condition2_uncertainty = np.asarray(condition2_uncertainty)
807-
condition1_sample_variance = np.asarray(condition1_sample_variance)
808-
condition2_sample_variance = np.asarray(condition2_sample_variance)
809-
812+
813+
# OPTIMIZATION 1 continued: Only convert sample variance if it's an array
814+
if isinstance(condition1_sample_variance, np.ndarray):
815+
condition1_sample_variance = np.asarray(condition1_sample_variance)
816+
condition2_sample_variance = np.asarray(condition2_sample_variance)
817+
# else: remains scalar 0, which numpy handles naturally in operations
818+
810819
# Combined uncertainty - base function predictor uncertainties
811-
function_variance = condition1_uncertainty + condition2_uncertainty
812-
820+
total_variance = condition1_uncertainty + condition2_uncertainty
821+
813822
# Total variance is the sum of function predictor variance and sample variance
814823
total_variance1 = condition1_uncertainty + condition1_sample_variance
815824
total_variance2 = condition2_uncertainty + condition2_sample_variance
816-
825+
del condition1_uncertainty, condition2_uncertainty
826+
817827
# Compute posterior standard deviations by taking square root of total variance
818828
condition1_std = np.sqrt(total_variance1 + self.eps)
819829
condition2_std = np.sqrt(total_variance2 + self.eps)
820-
830+
del total_variance1, total_variance2
831+
821832
# Combined variance for fold changes
822-
total_variance = function_variance
823-
if self.use_sample_variance:
833+
if self.use_sample_variance and isinstance(condition1_sample_variance, np.ndarray):
824834
total_variance = total_variance + condition1_sample_variance + condition2_sample_variance
825-
835+
del condition1_sample_variance, condition2_sample_variance
836+
826837
# Compute mean log fold change
827838
mean_log_fold_change = np.mean(fold_change, axis=0)
828839

829840
# Compute z-scores using the total variance (function + sample)
830-
stds = np.sqrt(total_variance + self.eps)
831-
fold_change_zscores = fold_change / stds
841+
fold_change_zscores = fold_change / np.sqrt(total_variance + self.eps)
842+
del total_variance
832843

833844
# Add the imputed expression values and their std to the results
834845
result = {
@@ -840,7 +851,7 @@ def get_variance2(X_batch):
840851
'fold_change_zscores': fold_change_zscores,
841852
'mean_log_fold_change': mean_log_fold_change,
842853
}
843-
854+
844855
# Compute Mahalanobis distances if requested
845856
if compute_mahalanobis:
846857
logger.debug("Computing Mahalanobis distances...")

kompot/resource_estimation.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,43 @@ def estimate_differential_expression_resources(
677677
f"{human_readable_size(total_temp_per_op - smaller_temp)}."
678678
)
679679

680+
# Intermediate arrays during predict() - CRITICAL FOR PEAK MEMORY
681+
# Even with cell batching, apply_batched() pre-allocates full output arrays (n_cells, n_genes)
682+
# During the predict() method in differential_expression.py, intermediate arrays coexist.
683+
#
684+
# Memory optimization history:
685+
# - Original (2025-10-12): ~30 arrays identified via SLURM MaxRSS
686+
# - zeros_like optimization (2025-10-13): Reduced to ~28 arrays
687+
# In the no-sample-variance (No SV) case: condition1/2_sample_variance use scalar 0 instead of full arrays
688+
# - Manual optimizations (2025-10-13): Reduced to ~25 arrays
689+
# 1. Eliminated 'stds' intermediate array (inlined computation)
690+
# 2. Strategic del statements improve temporal locality (lines 825, 830, 835, 842)
691+
# 3. Early cleanup of uncertainties and total_variance arrays
692+
#
693+
# Remaining arrays include:
694+
# - 6 primary arrays from apply_batched (condition1/2_imputed, uncertainties)
695+
# - fold_change and derived quantities (z-scores, condition1/2_std, total_variance)
696+
# - Temporaries during numpy operations (addition, sqrt, division)
697+
# - Python/numpy internal buffers and copies
698+
#
699+
# These are created during computation but freed before final result is returned.
700+
# SLURM MaxRSS captures this peak; discrete memory measurements miss it due to GC.
701+
n_intermediate_arrays = 25 # Reduced from 28 via manual optimizations (2025-10-13)
702+
intermediate_array_size = estimate_array_size((n_cells, n_total_genes))
703+
total_intermediate_memory = n_intermediate_arrays * intermediate_array_size
704+
705+
plan.add_requirement(
706+
f"Peak intermediate arrays during predictions (~{n_intermediate_arrays} arrays)",
707+
total_intermediate_memory,
708+
'memory',
709+
shape=f"{n_intermediate_arrays}×({n_cells}, {n_total_genes})"
710+
)
711+
712+
plan.info.append(
713+
f"Prediction creates ~{n_intermediate_arrays} intermediate arrays of shape ({n_cells:,}, {n_total_genes}). "
714+
f"These coexist at peak memory ({human_readable_size(total_intermediate_memory)}) but are freed before completion."
715+
)
716+
680717
# 2. Function predictor covariance matrices (ALWAYS created for Mahalanobis distance)
681718
# These are created by function_predictor.covariance(X, diag=False)
682719
cov_matrix_shape = (n_landmarks, n_landmarks)

0 commit comments

Comments
 (0)