
Commit b56f1e2

rjogrady authored and copybara-github committed
Minor test cleanup
PiperOrigin-RevId: 828561786
Change-Id: I23835bed10ef9869ee4be8fae3e6c45a2be41f3e
1 parent 89cf9e9 commit b56f1e2

File tree

3 files changed: +41 −47 lines changed


fleetbench/parallel/BUILD

Lines changed: 1 addition & 0 deletions

@@ -123,6 +123,7 @@ py_test(
         ":result",
         ":run",
         ":weights",
+        "@com_google_absl_py//absl/flags",
         "@com_google_absl_py//absl/testing:absltest",
         "@com_google_absl_py//absl/testing:flagsaver",
         "@com_google_absl_py//absl/testing:parameterized",

fleetbench/parallel/parallel_bench_lib.py

Lines changed: 20 additions & 20 deletions

@@ -35,24 +35,6 @@
 from fleetbench.parallel import worker
 
 
-def _SetExtraBenchmarkFlags(
-    benchmark_perf_counters: list[str],
-    benchmark_repetitions: int,
-    benchmark_min_time: str,
-) -> list[str]:
-  """Set extra benchmark flags."""
-  benchmark_flags = []
-  if benchmark_perf_counters:
-    perf_counters_str = ",".join(benchmark_perf_counters)
-    benchmark_flags.append(f"--benchmark_perf_counters={perf_counters_str}")
-  if benchmark_min_time:
-    benchmark_flags.append(f"--benchmark_min_time={benchmark_min_time}")
-  if benchmark_repetitions:
-    benchmark_flags.append(f"--benchmark_repetitions={benchmark_repetitions}")
-
-  return benchmark_flags
-
-
 @dataclasses.dataclass
 class BenchmarkMetrics:
   # per benchmark run total duration
@@ -177,8 +159,9 @@ def _PreRun(
 
     logging.info("Initializing benchmarks and worker threads...")
 
-    benchmark_flags = _SetExtraBenchmarkFlags(
-        self.perf_counters, benchmark_repetitions, benchmark_min_time
+    benchmark_flags = self._SetExtraBenchmarkFlags(
+        benchmark_repetitions,
+        benchmark_min_time,
     )
 
     if benchmark_flags:
@@ -304,6 +287,23 @@ def _ComputeBenchmarkWeights(self) -> np.ndarray:
 
     return probabilities
 
+  def _SetExtraBenchmarkFlags(
+      self,
+      benchmark_repetitions: int,
+      benchmark_min_time: str,
+  ) -> list[str]:
+    """Set extra benchmark flags."""
+    benchmark_flags = []
+    if self.perf_counters:
+      perf_counters_str = ",".join(self.perf_counters)
+      benchmark_flags.append(f"--benchmark_perf_counters={perf_counters_str}")
+    if benchmark_min_time:
+      benchmark_flags.append(f"--benchmark_min_time={benchmark_min_time}")
+    if benchmark_repetitions:
+      benchmark_flags.append(f"--benchmark_repetitions={benchmark_repetitions}")
+
+    return benchmark_flags
+
   def _SelectNextBenchmarks(self, count: int) -> list[bm.Benchmark]:
     """Randomly choose some benchmarks to run."""
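For reference, the helper's logic is unchanged by the move from module level onto the class; only the perf counters now come from self.perf_counters instead of a parameter. A standalone sketch of the same flag-building logic as a free function, with its expected output (illustration only, not part of the commit):

# Standalone sketch of the logic _SetExtraBenchmarkFlags implements,
# written as a free function so the output is easy to verify in isolation.
def set_extra_benchmark_flags(
    perf_counters: list[str],
    benchmark_repetitions: int,
    benchmark_min_time: str,
) -> list[str]:
  benchmark_flags = []
  if perf_counters:  # the method reads these from self.perf_counters
    benchmark_flags.append(
        f"--benchmark_perf_counters={','.join(perf_counters)}"
    )
  if benchmark_min_time:
    benchmark_flags.append(f"--benchmark_min_time={benchmark_min_time}")
  if benchmark_repetitions:
    benchmark_flags.append(f"--benchmark_repetitions={benchmark_repetitions}")
  return benchmark_flags


# Matches the first case in test_set_extra_benchmark_flags below.
assert set_extra_benchmark_flags(["instructions"], 10, "10s") == [
    "--benchmark_perf_counters=instructions",
    "--benchmark_min_time=10s",
    "--benchmark_repetitions=10",
]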

fleetbench/parallel/parallel_bench_lib_test.py

Lines changed: 20 additions & 27 deletions

@@ -17,6 +17,7 @@
 import shutil
 from unittest import mock
 
+from absl import flags
 from absl.testing import absltest
 from absl.testing import flagsaver
 from absl.testing import parameterized
@@ -30,41 +31,34 @@
 from fleetbench.parallel import run
 from fleetbench.parallel import weights
 
+FLAGS = flags.FLAGS
+
 
 class ParallelBenchTest(parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
+    self.temp_dir = self.create_tempdir()
     self.pb = parallel_bench_lib.ParallelBench(
         cpus=[0, 1],
         cpu_affinity=False,
         utilization=0.5,
         duration=0.1,
         repetitions=1,
-        temp_parent_root=absltest.get_default_test_tmpdir(),
+        temp_parent_root=self.temp_dir.full_path,
         keep_raw_data=True,
         benchmark_perf_counters="",
         benchmark_threads={},
     )
 
-  def tearDown(self):
-    super().tearDown()
-    for name in os.listdir(self.pb.temp_parent_root):
-      if name.startswith("run_"):
-        shutil.rmtree(os.path.join(self.pb.temp_parent_root, name))
-      else:
-        os.remove(os.path.join(self.pb.temp_parent_root, name))
-
   @mock.patch.object(bm, "GetSubBenchmarks", autospec=True)
   @mock.patch.object(run.Run, "Execute", autospec=True)
   @mock.patch.object(cpu, "Utilization", autospec=True)
   @mock.patch.object(reporter, "GenerateBenchmarkReport", autospec=True)
   @mock.patch.object(
       reporter, "SaveBenchmarkResults", autospec=True, return_value=(None, None)
   )
-  @flagsaver.flagsaver(
-      benchmark_dir=absltest.get_default_test_tmpdir(),
-  )
+  @flagsaver.flagsaver
   def testRun(
       self,
       mock_save_benchmark_results,
@@ -73,6 +67,7 @@ def testRun(
       mock_execute,
       mock_get_subbenchmarks,
   ):
+    FLAGS.benchmark_dir = self.temp_dir.full_path
     mock_get_subbenchmarks.return_value = ["BM_Test1", "BM_Test2"]
     mock_execute.return_value = result.Result(
         benchmark="fake_bench (BM_Test1)",
@@ -83,9 +78,7 @@ def testRun(
         bm_cpu_time=0.01,
         result="fake_result",
     )
-    self.create_tempfile(
-        os.path.join(absltest.get_default_test_tmpdir(), "fake_bench")
-    )
+    self.create_tempfile(os.path.join(self.temp_dir.full_path, "fake_bench"))
 
     def fake_utilization(unused_cpus):
       # Return 0% for the first call, then 55% for the rest.
@@ -102,7 +95,7 @@ def fake_utilization(unused_cpus):
         utilization=0.5,
         duration=0.1,
         repetitions=1,
-        temp_parent_root=absltest.get_default_test_tmpdir(),
+        temp_parent_root=self.temp_dir.full_path,
         keep_raw_data=True,
         benchmark_perf_counters="",
         benchmark_threads={},
@@ -116,6 +109,7 @@ def fake_utilization(unused_cpus):
     )
     self.pb.Run()
     mock_execute.assert_called_once()
+    mock_save_benchmark_results.assert_called_once()
 
   @mock.patch.object(parallel_bench_lib.ParallelBench, "_PreRun", autospec=True)
   @mock.patch.object(
@@ -166,9 +160,9 @@ def test_run_multiple_repetitions(
     self.assertEqual(args[3], 2)
 
   def test_set_extra_benchmark_flags(self):
+    self.pb.perf_counters = ["instructions"]
     self.assertEqual(
-        parallel_bench_lib._SetExtraBenchmarkFlags(
-            benchmark_perf_counters=["instructions"],
+        self.pb._SetExtraBenchmarkFlags(
             benchmark_repetitions=10,
             benchmark_min_time="10s",
         ),
@@ -179,9 +173,10 @@ def test_set_extra_benchmark_flags(self):
         ],
     )
 
+    self.pb.perf_counters = ["instructions", "cycles"]
+
     self.assertEqual(
-        parallel_bench_lib._SetExtraBenchmarkFlags(
-            benchmark_perf_counters=["instructions", "cycles"],
+        self.pb._SetExtraBenchmarkFlags(
             benchmark_repetitions=0,
             benchmark_min_time="",
         ),
@@ -197,9 +192,7 @@ def test_set_extra_benchmark_flags(self):
   @mock.patch.object(
       reporter, "SaveBenchmarkResults", autospec=True, return_value=(None, None)
   )
-  @flagsaver.flagsaver(
-      benchmark_dir=absltest.get_default_test_tmpdir(),
-  )
+  @flagsaver.flagsaver
   def testRunThreads(
       self,
       mock_save_benchmark_results,
@@ -208,6 +201,7 @@ def testRunThreads(
      mock_execute,
      mock_get_subbenchmarks,
  ):
+    FLAGS.benchmark_dir = self.temp_dir.full_path
     mock_get_subbenchmarks.return_value = ["BM_Test1"]
     mock_execute.return_value = result.Result(
         benchmark="fake_bench (BM_Test1)",
@@ -218,9 +212,7 @@ def testRunThreads(
         bm_cpu_time=0.01,
         result="fake_result",
     )
-    self.create_tempfile(
-        os.path.join(absltest.get_default_test_tmpdir(), "fake_bench")
-    )
+    self.create_tempfile(os.path.join(self.temp_dir.full_path, "fake_bench"))
 
     def fake_utilization(unused_cpus):
       # Return 0% for the first call, then 55% for the rest.
@@ -237,7 +229,7 @@ def fake_utilization(unused_cpus):
         utilization=0.5,
         duration=0.1,
         repetitions=1,
-        temp_parent_root=absltest.get_default_test_tmpdir(),
+        temp_parent_root=self.temp_dir.full_path,
         keep_raw_data=True,
         benchmark_perf_counters="",
         benchmark_threads={"BM_Test1": 2},
@@ -251,6 +243,7 @@ def fake_utilization(unused_cpus):
     )
     self.pb.Run()
     mock_execute.assert_called_once()
+    mock_save_benchmark_results.assert_called_once()
 
   def test_convert_to_dataframe(self):
     # First entries are fake durations, the second entries are real durations.
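Two cleanup patterns drive this file's changes: create_tempdir() replaces get_default_test_tmpdir() plus the hand-rolled tearDown, since absltest registers automatic cleanup for directories it creates; and the bare @flagsaver.flagsaver decorator replaces the parameterized form, with each test assigning FLAGS.benchmark_dir itself. A minimal standalone sketch of both patterns (the flag definition here is a hypothetical stand-in for fleetbench's real benchmark_dir flag):

import os

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver

# Hypothetical flag, standing in for the benchmark_dir flag that
# fleetbench defines elsewhere.
flags.DEFINE_string("benchmark_dir", "", "Where benchmark binaries live.")
FLAGS = flags.FLAGS


class CleanupPatternsTest(absltest.TestCase):

  def test_tempdir_cleans_itself(self):
    # create_tempdir() registers its own cleanup, which is what made the
    # hand-rolled tearDown above removable.
    temp_dir = self.create_tempdir()
    self.assertTrue(os.path.isdir(temp_dir.full_path))

  @flagsaver.flagsaver
  def test_flag_restored_after_test(self):
    # The bare decorator snapshots every flag before the test and restores
    # it afterwards, so assigning FLAGS.benchmark_dir in the body is safe.
    FLAGS.benchmark_dir = self.create_tempdir().full_path
    self.assertNotEqual(FLAGS.benchmark_dir, "")


if __name__ == "__main__":
  absltest.main()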
