@@ -18,7 +18,8 @@ def version_tuple(v):
1818
1919tabulate_version = version_tuple (tabulate .__version__ )
2020
21- all_devices = []
21+ all_ref_devices = []
22+ all_cmp_devices = []
2223config_count = 0
2324unknown_count = 0
2425failure_count = 0
@@ -32,7 +33,7 @@ def find_matching_bench(needle, haystack):
3233 return None
3334
3435
35- def find_device_by_id (device_id ):
36+ def find_device_by_id (device_id , all_devices ):
3637 for device in all_devices :
3738 if device ["id" ] == device_id :
3839 return device
@@ -113,7 +114,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
113114
114115 print ("# %s\n " % (cmp_bench ["name" ]))
115116
116- device_ids = cmp_bench ["devices" ]
117+ cmp_device_ids = cmp_bench ["devices" ]
117118 axes = cmp_bench ["axes" ]
118119 ref_states = ref_bench ["states" ]
119120 cmp_states = cmp_bench ["states" ]
@@ -138,7 +139,7 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
138139 headers .append ("Status" )
139140 colalign .append ("center" )
140141
141- for device_id in device_ids :
142+ for cmp_device_id in cmp_device_ids :
142143 rows = []
143144 plot_data = {"cmp" : {}, "ref" : {}, "cmp_noise" : {}, "ref_noise" : {}}
144145
@@ -284,8 +285,21 @@ def extract_value(summary):
284285 if len (rows ) == 0 :
285286 continue
286287
287- device = find_device_by_id (device_id )
288- print ("## [%d] %s\n " % (device ["id" ], device ["name" ]))
288+ cmp_device = find_device_by_id (cmp_device_id , all_cmp_devices )
289+ ref_device = find_device_by_id (ref_state ["device" ], all_ref_devices )
290+
291+ if cmp_device == ref_device :
292+ print ("## [%d] %s\n " % (cmp_device ["id" ], cmp_device ["name" ]))
293+ else :
294+ print (
295+ "## [%d] %s vs. [%d] %s\n "
296+ % (
297+ ref_device ["id" ],
298+ ref_device ["name" ],
299+ cmp_device ["id" ],
300+ cmp_device ["name" ],
301+ )
302+ )
289303 # colalign and github format require tabulate 0.8.3
290304 if tabulate_version >= (0 , 8 , 3 ):
291305 print (
@@ -303,7 +317,7 @@ def extract_value(summary):
303317 plt .yscale ("log" )
304318 plt .xlabel (plot )
305319 plt .ylabel ("time [s]" )
306- plt .title (device ["name" ])
320+ plt .title (cmp_device ["name" ])
307321
308322 def plot_line (key , shape , label ):
309323 x = [float (x ) for x in plot_data [key ][axis ].keys ()]
@@ -328,6 +342,13 @@ def plot_line(key, shape, label):
328342def main ():
329343 help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
330344 parser = argparse .ArgumentParser (prog = "nvbench_compare" , usage = help_text )
345+ parser .add_argument (
346+ "--ignore-devices" ,
347+ dest = "ignore_devices" ,
348+ default = False ,
349+ help = "Ignore differences in the device sections and compare anyway" ,
350+ action = "store_true" ,
351+ )
331352 parser .add_argument (
332353 "--threshold-diff" ,
333354 type = float ,
@@ -369,17 +390,24 @@ def main():
369390 ref_root = reader .read_file (ref )
370391 cmp_root = reader .read_file (comp )
371392
372- global all_devices
373- all_devices = cmp_root ["devices" ]
393+ global all_ref_devices
394+ global all_cmp_devices
395+ all_ref_devices = ref_root ["devices" ]
396+ all_cmp_devices = cmp_root ["devices" ]
374397
375398 if ref_root ["devices" ] != cmp_root ["devices" ]:
376- print ("Device sections do not match." )
399+ print (
400+ (Fore .YELLOW if args .ignore_devices else Fore .RED )
401+ + "Device sections do not match:"
402+ + Fore .RESET
403+ )
377404 print (
378405 jsondiff .diff (
379406 ref_root ["devices" ], cmp_root ["devices" ], syntax = "symmetric"
380407 )
381408 )
382- sys .exit (1 )
409+ if not args .ignore_devices :
410+ sys .exit (1 )
383411
384412 compare_benches (
385413 ref_root ["benchmarks" ], cmp_root ["benchmarks" ], args .threshold , args .plot
0 commit comments