Skip to content

Commit 8554880

Browse files
Allow to by-pass device section check and compare different devices
Fixes: #297
1 parent f651636 commit 8554880

File tree

1 file changed

+39
-11
lines changed

1 file changed

+39
-11
lines changed

scripts/nvbench_compare.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def version_tuple(v):
1818

1919
tabulate_version = version_tuple(tabulate.__version__)
2020

21-
all_devices = []
21+
all_ref_devices = []
22+
all_cmp_devices = []
2223
config_count = 0
2324
unknown_count = 0
2425
failure_count = 0
@@ -32,7 +33,7 @@ def find_matching_bench(needle, haystack):
3233
return None
3334

3435

35-
def find_device_by_id(device_id):
36+
def find_device_by_id(device_id, all_devices):
3637
for device in all_devices:
3738
if device["id"] == device_id:
3839
return device
@@ -113,7 +114,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
113114

114115
print("# %s\n" % (cmp_bench["name"]))
115116

116-
device_ids = cmp_bench["devices"]
117+
cmp_device_ids = cmp_bench["devices"]
117118
axes = cmp_bench["axes"]
118119
ref_states = ref_bench["states"]
119120
cmp_states = cmp_bench["states"]
@@ -138,7 +139,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
138139
headers.append("Status")
139140
colalign.append("center")
140141

141-
for device_id in device_ids:
142+
for cmp_device_id in cmp_device_ids:
142143
rows = []
143144
plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}}
144145

@@ -284,8 +285,21 @@ def extract_value(summary):
284285
if len(rows) == 0:
285286
continue
286287

287-
device = find_device_by_id(device_id)
288-
print("## [%d] %s\n" % (device["id"], device["name"]))
288+
cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices)
289+
ref_device = find_device_by_id(ref_state["device"], all_ref_devices)
290+
291+
if cmp_device == ref_device:
292+
print("## [%d] %s\n" % (cmp_device["id"], cmp_device["name"]))
293+
else:
294+
print(
295+
"## [%d] %s vs. [%d] %s\n"
296+
% (
297+
ref_device["id"],
298+
ref_device["name"],
299+
cmp_device["id"],
300+
cmp_device["name"],
301+
)
302+
)
289303
# colalign and github format require tabulate 0.8.3
290304
if tabulate_version >= (0, 8, 3):
291305
print(
@@ -303,7 +317,7 @@ def extract_value(summary):
303317
plt.yscale("log")
304318
plt.xlabel(plot)
305319
plt.ylabel("time [s]")
306-
plt.title(device["name"])
320+
plt.title(cmp_device["name"])
307321

308322
def plot_line(key, shape, label):
309323
x = [float(x) for x in plot_data[key][axis].keys()]
@@ -328,6 +342,13 @@ def plot_line(key, shape, label):
328342
def main():
329343
help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
330344
parser = argparse.ArgumentParser(prog="nvbench_compare", usage=help_text)
345+
parser.add_argument(
346+
"--ignore-devices",
347+
dest="ignore_devices",
348+
default=False,
349+
help="Ignore differences in the device sections and compare anyway",
350+
action="store_true",
351+
)
331352
parser.add_argument(
332353
"--threshold-diff",
333354
type=float,
@@ -369,17 +390,24 @@ def main():
369390
ref_root = reader.read_file(ref)
370391
cmp_root = reader.read_file(comp)
371392

372-
global all_devices
373-
all_devices = cmp_root["devices"]
393+
global all_ref_devices
394+
global all_cmp_devices
395+
all_ref_devices = ref_root["devices"]
396+
all_cmp_devices = cmp_root["devices"]
374397

375398
if ref_root["devices"] != cmp_root["devices"]:
376-
print("Device sections do not match.")
399+
print(
400+
(Fore.YELLOW if args.ignore_devices else Fore.RED)
401+
+ "Device sections do not match:"
402+
+ Fore.RESET
403+
)
377404
print(
378405
jsondiff.diff(
379406
ref_root["devices"], cmp_root["devices"], syntax="symmetric"
380407
)
381408
)
382-
sys.exit(1)
409+
if not args.ignore_devices:
410+
sys.exit(1)
383411

384412
compare_benches(
385413
ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot

0 commit comments

Comments
 (0)