Skip to content

Commit d8d55e6

Browse files
committed
Cleanup better in tests so that we can run the full suite
Signed-off-by: John St. John <[email protected]>
1 parent 9e32dfd commit d8d55e6

File tree

2 files changed

+101
-12
lines changed

2 files changed

+101
-12
lines changed

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616

1717
# conftest.py
1818
import gc
19+
import os
20+
import random
21+
import signal
22+
import time
1923

24+
import numpy as np
2025
import pytest
2126
import torch
2227

@@ -60,13 +65,91 @@ def pytest_sessionfinish(session, exitstatus):
6065
)
6166

6267

68+
def _cleanup_child_processes():
    """Kill any orphaned child processes that might be holding GPU memory.

    This is particularly important for tests that spawn subprocesses via torchrun.
    """
    import subprocess

    parent = str(os.getpid())
    try:
        # Ask pgrep for every direct child of this process.
        listing = subprocess.run(
            ["pgrep", "-P", parent], check=False, capture_output=True, text=True, timeout=5
        )
        for line in listing.stdout.strip().split("\n"):
            if not line:
                continue
            try:
                # Politely ask each child to exit.
                os.kill(int(line), signal.SIGTERM)
            except (ValueError, ProcessLookupError, PermissionError):
                # Malformed pid, already-exited child, or insufficient rights:
                # best-effort cleanup, so ignore and keep going.
                pass
    except (subprocess.TimeoutExpired, FileNotFoundError):
        # pgrep missing or hung -- nothing safe to do, give up quietly.
        pass
91+
92+
93+
def _thorough_gpu_cleanup():
94+
"""Perform thorough GPU memory cleanup."""
95+
if not torch.cuda.is_available():
96+
return
97+
98+
# Synchronize all CUDA streams to ensure all operations are complete
99+
torch.cuda.synchronize()
100+
101+
# Clear all cached memory
102+
torch.cuda.empty_cache()
103+
104+
# Reset peak memory stats
105+
torch.cuda.reset_peak_memory_stats()
106+
107+
# Run garbage collection multiple times to ensure all objects are collected
108+
for _ in range(3):
109+
gc.collect()
110+
111+
# Another sync and cache clear after gc
112+
torch.cuda.synchronize()
113+
torch.cuda.empty_cache()
114+
115+
# Small sleep to allow GPU memory to be fully released
116+
time.sleep(0.1)
117+
118+
119+
def _reset_random_seeds():
120+
"""Reset random seeds to ensure reproducibility across tests.
121+
122+
Some tests may modify global random state, which can affect subsequent tests
123+
that depend on random splitting (like dataset preprocessing).
124+
"""
125+
# Reset Python's random module
126+
random.seed(None)
127+
128+
# Reset NumPy's random state (intentionally using legacy API to reset global state)
129+
np.random.seed(None) # noqa: NPY002
130+
131+
# Reset PyTorch's random state
132+
torch.seed()
133+
if torch.cuda.is_available():
134+
torch.cuda.seed_all()
135+
136+
63137
@pytest.fixture(autouse=True)
def cleanup_after_test():
    """Clean up GPU memory and reset state after each test."""
    # Seed reset happens before the test so tests that depend on random
    # splitting (e.g. dataset preprocessing) behave reproducibly.
    _reset_random_seeds()

    yield

    # Tear-down: release CUDA memory first, then reap any subprocesses the
    # test may have left behind (important for torchrun-based tests), and
    # finish with a final garbage-collection pass.
    _thorough_gpu_cleanup()
    _cleanup_child_processes()
    gc.collect()
70153

71154

72155
def pytest_addoption(parser: pytest.Parser):

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -546,11 +546,17 @@ def test_fallback_functions_import_error_messages(self):
546546

547547
def test_einops_import_error(self):
    """Test that the einops import error is raised with the correct message."""
    import bionemo.evo2.models.megatron.hyena.hyena_utils

    try:
        # With einops masked out of sys.modules, reloading the module must
        # surface the dedicated ImportError message.
        with patch.dict("sys.modules", {"einops": None}):
            with pytest.raises(ImportError, match="einops is required by the Hyena model but cannot be imported"):
                importlib.reload(bionemo.evo2.models.megatron.hyena.hyena_utils)
    finally:
        # CRITICAL: Always restore the module to its proper state after the test.
        # The reload above leaves the module in a corrupted state, which can cause
        # subsequent tests to fail (especially test_infer.py tests).
        importlib.reload(bionemo.evo2.models.megatron.hyena.hyena_utils)

0 commit comments

Comments
 (0)