@@ -512,6 +512,7 @@ def compute_mahalanobis_distances(
512512
513513 # Average the covariance matrices
514514 combined_cov = (cov1 + cov2 ) / 2
515+ del cov1 , cov2
515516
516517 # For sample variance, use diag=False to get full covariance matrices
517518 # Initialize variable to store gene-specific covariance matrices if needed
@@ -529,6 +530,7 @@ def compute_mahalanobis_distances(
529530 variance2 = self .variance_predictor2 (variance_points , diag = False , progress = progress )
530531 # Add the covariance matrices for complete variance representation
531532 combined_variance = variance1 + variance2
533+ del variance1 , variance2
532534
533535 # Check if we have gene-specific covariance matrices (shape has 3 dimensions)
534536 if len (combined_variance .shape ) == 3 :
@@ -568,6 +570,7 @@ def compute_mahalanobis_distances(
568570 else :
569571 combined_cov += variance1
570572 logger .debug ("Added variance1 covariance matrix to function predictor covariance" )
573+ del variance1
571574 except Exception as e :
572575 error_msg = f"Error computing sample variance from variance_predictor1: { e } ."
573576 logger .error (error_msg )
@@ -595,6 +598,7 @@ def compute_mahalanobis_distances(
595598 # Add variance2 to the combined covariance
596599 combined_cov += variance2
597600 logger .debug ("Added variance2 covariance matrix to function predictor covariance" )
601+ del variance2
598602 except Exception as e :
599603 error_msg = f"Error computing sample variance from variance_predictor2: { e } ."
600604 logger .error (error_msg )
@@ -786,49 +790,56 @@ def get_variance2(X_batch):
786790 desc = "Computing sample variance (condition 2)" if progress else None
787791 )
788792 else :
789- # Initialize with zeros if not using sample variance
790- condition1_sample_variance = np .zeros_like (condition1_imputed )
791- condition2_sample_variance = np .zeros_like (condition2_imputed )
792-
793+ # OPTIMIZATION 1: Use scalar 0 instead of zeros_like array (saves 6.8 GB at 1000 genes)
794+ # For No SV case, sample variance is zero, so we don't need full arrays
795+ condition1_sample_variance = 0
796+ condition2_sample_variance = 0
797+
793798 # Compute fold change
794799 fold_change = condition2_imputed - condition1_imputed
795-
800+
796801 # Ensure uncertainties have the right shape for broadcasting
797802 if len (condition1_uncertainty .shape ) == 1 :
798803 # Reshape to (n_samples, 1) for broadcasting with fold_change
799804 condition1_uncertainty = condition1_uncertainty [:, np .newaxis ]
800805 if len (condition2_uncertainty .shape ) == 1 :
801806 # Reshape to (n_samples, 1) for broadcasting with fold_change
802807 condition2_uncertainty = condition2_uncertainty [:, np .newaxis ]
803-
808+
804809 # Convert uncertainties to numpy arrays if needed
805810 condition1_uncertainty = np .asarray (condition1_uncertainty )
806811 condition2_uncertainty = np .asarray (condition2_uncertainty )
807- condition1_sample_variance = np .asarray (condition1_sample_variance )
808- condition2_sample_variance = np .asarray (condition2_sample_variance )
809-
812+
813+ # OPTIMIZATION 1 continued: Only convert sample variance if it's an array
814+ if isinstance (condition1_sample_variance , np .ndarray ):
815+ condition1_sample_variance = np .asarray (condition1_sample_variance )
816+ condition2_sample_variance = np .asarray (condition2_sample_variance )
817+ # else: remains scalar 0, which numpy handles naturally in operations
818+
810819 # Combined uncertainty - base function predictor uncertainties
811- function_variance = condition1_uncertainty + condition2_uncertainty
812-
820+ total_variance = condition1_uncertainty + condition2_uncertainty
821+
813822 # Total variance is the sum of function predictor variance and sample variance
814823 total_variance1 = condition1_uncertainty + condition1_sample_variance
815824 total_variance2 = condition2_uncertainty + condition2_sample_variance
816-
825+ del condition1_uncertainty , condition2_uncertainty
826+
817827 # Compute posterior standard deviations by taking square root of total variance
818828 condition1_std = np .sqrt (total_variance1 + self .eps )
819829 condition2_std = np .sqrt (total_variance2 + self .eps )
820-
830+ del total_variance1 , total_variance2
831+
821832 # Combined variance for fold changes
822- total_variance = function_variance
823- if self .use_sample_variance :
833+ if self .use_sample_variance and isinstance (condition1_sample_variance , np .ndarray ):
824834 total_variance = total_variance + condition1_sample_variance + condition2_sample_variance
825-
835+ del condition1_sample_variance , condition2_sample_variance
836+
826837 # Compute mean log fold change
827838 mean_log_fold_change = np .mean (fold_change , axis = 0 )
828839
829840 # Compute z-scores using the total variance (function + sample)
830- stds = np .sqrt (total_variance + self .eps )
831- fold_change_zscores = fold_change / stds
841+ fold_change_zscores = fold_change / np .sqrt (total_variance + self .eps )
842+ del total_variance
832843
833844 # Add the imputed expression values and their std to the results
834845 result = {
@@ -840,7 +851,7 @@ def get_variance2(X_batch):
840851 'fold_change_zscores' : fold_change_zscores ,
841852 'mean_log_fold_change' : mean_log_fold_change ,
842853 }
843-
854+
844855 # Compute Mahalanobis distances if requested
845856 if compute_mahalanobis :
846857 logger .debug ("Computing Mahalanobis distances..." )
0 commit comments