|
28 | 28 |
|
29 | 29 | import os |
30 | 30 | import subprocess |
31 | | -from pathlib import Path |
32 | 31 |
|
33 | 32 | import pytest |
34 | 33 | import torch |
35 | 34 |
|
36 | 35 | from bionemo.evo2.models.evo2_provider import HyenaInferenceContext |
37 | 36 |
|
38 | 37 |
|
| 38 | +def find_free_network_port(address: str = "localhost") -> int: |
| 39 | + """Find a free port on localhost for distributed testing.""" |
| 40 | + import socket |
| 41 | + |
| 42 | + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
| 43 | + s.bind((address, 0)) |
| 44 | + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| 45 | + return s.getsockname()[1] |
| 46 | + |
| 47 | + |
39 | 48 | @pytest.fixture(scope="module") |
40 | 49 | def mbridge_checkpoint_path(tmp_path_factory): |
41 | 50 | """Create or use an MBridge checkpoint for testing. |
@@ -101,16 +110,15 @@ def test_infer_runs(mbridge_checkpoint_path, tmp_path): |
101 | 110 |
|
102 | 111 | env = os.environ.copy() |
103 | 112 | env["MASTER_ADDR"] = "localhost" |
104 | | - env["MASTER_PORT"] = "29501" |
| 113 | + env["MASTER_PORT"] = str(find_free_network_port()) |
105 | 114 |
|
106 | 115 | result = subprocess.run( |
107 | 116 | cmd, |
108 | 117 | check=False, |
109 | 118 | capture_output=True, |
110 | 119 | text=True, |
111 | | - timeout=300, |
| 120 | + timeout=300, # 5 minutes |
112 | 121 | env=env, |
113 | | - cwd=str(Path(__file__).parent.parent.parent.parent.parent), |
114 | 122 | ) |
115 | 123 |
|
116 | 124 | assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" |
@@ -148,16 +156,15 @@ def test_infer_temperature(mbridge_checkpoint_path, tmp_path, temperature): |
148 | 156 |
|
149 | 157 | env = os.environ.copy() |
150 | 158 | env["MASTER_ADDR"] = "localhost" |
151 | | - env["MASTER_PORT"] = "29502" |
| 159 | + env["MASTER_PORT"] = str(find_free_network_port()) |
152 | 160 |
|
153 | 161 | result = subprocess.run( |
154 | 162 | cmd, |
155 | 163 | check=False, |
156 | 164 | capture_output=True, |
157 | 165 | text=True, |
158 | | - timeout=300, |
| 166 | + timeout=300, # 5 minutes |
159 | 167 | env=env, |
160 | | - cwd=str(Path(__file__).parent.parent.parent.parent.parent), |
161 | 168 | ) |
162 | 169 |
|
163 | 170 | assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" |
@@ -189,16 +196,15 @@ def test_infer_top_k(mbridge_checkpoint_path, tmp_path): |
189 | 196 |
|
190 | 197 | env = os.environ.copy() |
191 | 198 | env["MASTER_ADDR"] = "localhost" |
192 | | - env["MASTER_PORT"] = "29503" |
| 199 | + env["MASTER_PORT"] = str(find_free_network_port()) |
193 | 200 |
|
194 | 201 | result = subprocess.run( |
195 | 202 | cmd, |
196 | 203 | check=False, |
197 | 204 | capture_output=True, |
198 | 205 | text=True, |
199 | | - timeout=300, |
| 206 | + timeout=300, # 5 minutes |
200 | 207 | env=env, |
201 | | - cwd=str(Path(__file__).parent.parent.parent.parent.parent), |
202 | 208 | ) |
203 | 209 |
|
204 | 210 | assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" |
@@ -245,16 +251,15 @@ def test_infer_phylogenetic_prompt(mbridge_checkpoint_path, tmp_path): |
245 | 251 |
|
246 | 252 | env = os.environ.copy() |
247 | 253 | env["MASTER_ADDR"] = "localhost" |
248 | | - env["MASTER_PORT"] = "29504" |
| 254 | + env["MASTER_PORT"] = str(find_free_network_port()) |
249 | 255 |
|
250 | 256 | result = subprocess.run( |
251 | 257 | cmd, |
252 | 258 | check=False, |
253 | 259 | capture_output=True, |
254 | 260 | text=True, |
255 | | - timeout=300, |
| 261 | + timeout=300, # 5 minutes |
256 | 262 | env=env, |
257 | | - cwd=str(Path(__file__).parent.parent.parent.parent.parent), |
258 | 263 | ) |
259 | 264 |
|
260 | 265 | assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" |
|
0 commit comments