# code_quality.py
import argparse
import concurrent.futures
import json
import math
import os
import pathlib
import pickle
import random
import re
import time
from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Callable, Dict, List

import datasets
import json5
import numpy as np
import pandas as pd
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.outputs import Generation
from langchain_core.utils.json import _parse_json, parse_partial_json
from langchain_openai import ChatOpenAI
from tqdm import tqdm

quality = '''\
As an experienced and professional reverse engineer, you can evaluate the code produced by two different decompilers objectively and impartially.
I will provide you with the source code to be evaluated, followed by two decompiled versions of that code.
Your task is to assess each decompiler's performance across several aspects.
Some aspects have three scoring options (-1, 0, 1), while others have only two (0, 1). The meaning of each score is given in the *Evaluation Choices* section of that aspect.
When evaluating the decompiled code, compare it to the original source code, focusing primarily on the following 12 aspects.
## **Readability**
1. **Typecast Issues**
    - Example:
        - Source code:
            ```c
            if (n < 1 || m < 1)
                error_exit("Invalid size");
            ```
        - Decompiled code:
            ```c
            if (n < 1 || m < 1)
                error_exit((long long) "Invalid size");
            ```
    - **Explanation**: The decompiler introduces an unnecessary and incorrect `(long long)` cast on a string literal. The cast serves no purpose, adds redundancy, and makes the code harder to read.
    - Evaluation Choices:
        - -1: The decompiled code contains an incorrect typecast.
        - 0: The decompiled code contains an unnecessary typecast.
        - 1: The decompiled code does not contain any incorrect or unnecessary typecasts.
2. **Non-idiomatic Literal Representation**
    - Example:
        - Source code:
            ```c
            strcat(buffer, "}\\n");
            ```
        - Decompiled code:
            ```c
            *( (WORD *) (v3)) = 2685;
            ```
    - **Explanation**: Non-idiomatic representations of literals, such as turning the string `"}\\n"` into the integer `2685` (its two bytes read as one little-endian 16-bit word), obscure the original meaning and make the logic harder to follow.
    - Evaluation Choices:
        - 0: The decompiled code contains a non-idiomatic representation of literals.
        - 1: The decompiled code does not contain any non-idiomatic representations of literals.
3. **Obfuscated Control Flow**
    - Example:
        - Source code:
            ```c
            while (pack->next_object != obj) {
                pack = pack->next_object;
            }
            ```
        - Decompiled code:
            ```c
            for(i=a2; a1 != (*((_QWORD *) (i + 64))); i = *((_QWORD *) (i + 64)));
            ```
    - **Explanation**: Overly complex pointer dereferencing in loops, such as `(*((_QWORD *) (i + 64)))`, obscures what was originally a simple `while` loop. This diminishes readability and makes the original control flow hard to reconstruct.
    - Evaluation Choices:
        - 0: The decompiled code contains obfuscated control flow.
        - 1: The decompiled code does not contain any obfuscated control flow.
4. **Use of Decompiler-Specific Macros**
    - Example:
        - Decompiled code:
            ```c
            LOWWORD(v5)
            ```
    - **Explanation**: Decompiler-specific macros (e.g., `LOWWORD(v5)`) deviate from standard C, reducing the readability and portability of the decompiled code.
    - Evaluation Choices:
        - 0: The decompiled code uses decompiler-specific macros.
        - 1: The decompiled code does not use any decompiler-specific macros.
5. **Incorrect Return Behavior**
    - Example:
        - Source code:
            ```c
            // No return statement
            ```
        - Decompiled code:
            ```c
            return _readfsqword(0x28u) ^ v3;
            ```
    - **Explanation**: Incorrect return statements, such as `return _readfsqword(0x28u) ^ v3` (a residue of the compiler's stack-canary check), introduce behavior that was not part of the original logic.
    - Evaluation Choices:
        - 0: The decompiled code has incorrect return behavior.
        - 1: The decompiled code does not have any incorrect return behavior.
## **Helpfulness**
1. **Meaningless Identifier Names**
    - Example:
        - Source code:
            ```c
            return buffer;
            ```
        - Decompiled code:
            ```c
            return v4;
            ```
    - **Explanation**: Generic identifier names like `v4` instead of meaningful names like `buffer` significantly reduce the helpfulness of the decompiled code.
    - Evaluation Choices:
        - 0: The decompiled code contains meaningless identifier names, like `v4` instead of `buffer`.
        - 1: The decompiled code does not contain any semantically incorrect or meaningless identifier names.
2. **Incorrect Identifier Names**
    - Example:
        - Source code:
            ```c
            int total_count;
            ```
        - Decompiled code:
            ```c
            int error_flag;
            ```
    - **Explanation**: Incorrect identifier names (e.g., `error_flag` instead of `total_count`) make the decompiled code misleading and harder to reason about.
    - Evaluation Choices:
        - -1: The decompiled code contains incorrect identifier names, like `success_flag` instead of `total_count`.
        - 0: The decompiled code contains confusing identifier names, like `number` instead of `total_count`.
        - 1: The decompiled code does not contain any incorrect or confusing identifier names.
3. **Expanded Symbols**
    - Example:
        - Source code:
            ```c
            sizeof(int *)
            ```
        - Decompiled code:
            ```c
            8
            ```
    - **Explanation**: Replacing `sizeof` expressions with hardcoded numbers like `8` can be misleading and makes the code less readable and less portable.
    - Evaluation Choices:
        - -1: The decompiled code contains misleading symbols, like `0xFFFFFFFF` instead of `sizeof(int *)`.
        - 0: The decompiled code contains expanded symbols, like `8` instead of `sizeof(int *)`.
        - 1: The decompiled code does not contain any expanded symbols.
4. **Overall Function Correctness**
    - **Explanation**: This aspect evaluates whether the decompiled code captures the core functionality of the original source code. For example, if the original source code implements an MD5 hashing function, the decompiled code should make this clear, even if the syntax or identifiers are somewhat altered. If the overall logic matches but is hard to identify, the code is less helpful.
    - Evaluation Choices:
        - -1: The decompiled code significantly deviates from the functionality of the original source code.
        - 0: The decompiled code captures only part of the functionality.
        - 1: The decompiled code captures the core functionality of the original source code.
5. **Overall Functionality Precision**
    - **Explanation**: This goes a step beyond correctness and evaluates how clearly the decompiled code conveys the *exact* functionality. If the original source code implements MD5 but the decompiled version suggests a different hash function (such as SHA-256), the code is not precise. Being able to pinpoint the specific implementation adds to the helpfulness.
    - Evaluation Choices:
        - 0: The decompiled code does not capture the exact functionality of the original source code.
        - 1: The decompiled code captures the exact functionality of the original source code.
## **Both**
1. **Non-Idiomatic Dereferencing**
    - **Example**:
        - Source code:
            ```c
            current->next = malloc(sizeof(Node));
            current = current->next;
            current->x = 0;
            current->y = 0;
            ```
        - Decompiled code:
            ```c
            *((_QWORD *)v5 + 8) = malloc(24LL);
            v5 = *((_QWORD *)(v5 + 8));
            *((_QWORD *)v5) = 0;
            ```
    - **Explanation**: The decompiled code uses cryptic pointer arithmetic and raw memory layout, such as `((_QWORD *)v5 + 8)`, instead of the natural structured-data access of the source (`current->next`). This obscures the underlying logic and reduces both readability and helpfulness: a reverse engineer now has to decode not just the logic but also the layout of the data structure.
    - Evaluation Choices:
        - 1: The decompiled code does not contain any non-idiomatic dereferencing.
        - 0: The decompiled code uses pointer arithmetic and raw memory layout, such as `((_QWORD *)v5 + 8)`, instead of natural structured-data access and object dereferencing (`current->next`).
2. **Abuse of Memory Layout**
    - **Example**:
        - Decompiled code:
            ```c
            (*(void (__stdcall **)(int, _DWORD, _DWORD, _DWORD, _DWORD))(*(_DWORD *)lpD3DDevice_1 + 68))(
                lpD3DDevice_1,
                0,
                0,
                0,
                0);
            ```
    - **Explanation**: Here the decompiled code does not recover the original function structure (a method call through the device object's vtable) and instead resorts to manual dereferencing and extensive type casting. These over-complicated type manipulations make it hard to tell what is being invoked without further investigation into the memory layout or the device object itself.
    - Evaluation Choices:
        - 1: The decompiled code does not abuse memory layout; it reflects structured data usage and object dereferencing.
        - 0: The decompiled code abuses memory layout, relying on complex pointer arithmetic and manual dereferencing instead of straightforward object-oriented access patterns.
Consider all of the above points comprehensively, then summarize and categorize them to evaluate the decompilers. **Think step by step** and present the evaluation results in a clear, structured way.
'''
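

# --- Illustration (hypothetical helper, not used by this script) ---
# The "Non-idiomatic Literal Representation" example above stores the string
# "}\n" as the integer 2685. A minimal sketch of why, assuming a little-endian
# target such as x86-64: the first two bytes of the string ('}' = 0x7D and
# '\n' = 0x0A), read as one 16-bit WORD, give 0x0A7D == 2685.
def _example_word_literal(text: str) -> int:
    """Return the first two ASCII bytes of `text` read as a little-endian WORD."""
    return int.from_bytes(text.encode("ascii")[:2], "little")

# _example_word_literal("}\n") evaluates to 2685, and
# (2685).to_bytes(2, "little") recovers b"}\n".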
evaluation_prompt_1 = quality + '''\
**First**, evaluate each decompiler against the source code on every aspect and record a score.
The scoring options differ between aspects, so refer to the evaluation choices above when scoring each aspect.
**Then**, determine which decompiler is better on each aspect based on those scores.
**Finally**, output a JSON object that collects both decompilers' scores and the winner for every aspect, following the format below. Include an entry for every aspect listed above, grouped under its category.
```json
{
    "Readability": {
        "Typecast Issues": {
            "A_score": -1,
            "B_score": 1,
            "winner": "B"
        }
    },
    "Helpfulness": {
        "Meaningless Identifier Names": {
            "A_score": 0,
            "B_score": 0,
            "winner": "Tie"
        }
    },
    "Both": {
        "Non-Idiomatic Dereferencing": {
            "A_score": 1,
            "B_score": 0,
            "winner": "A"
        }
    }
}
```
'''
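

# --- Illustration (hypothetical helper, not used by this script) ---
# A sketch of a consistency check for one aspect entry of the JSON requested
# above. Assumptions: the aspects with a -1 option are exactly those listed
# with three choices in `quality`, and the winner is the decompiler with the
# higher score ("Tie" on equal scores). The script itself does not validate
# entries; it uses each reported "winner" as-is.
def _example_check_aspect_entry(aspect: str, entry: dict) -> bool:
    three_way = {
        "Typecast Issues",
        "Incorrect Identifier Names",
        "Expanded Symbols",
        "Overall Function Correctness",
    }
    allowed = {-1, 0, 1} if aspect in three_way else {0, 1}
    a, b = entry.get("A_score"), entry.get("B_score")
    if not isinstance(a, int) or not isinstance(b, int):
        return False
    if a not in allowed or b not in allowed:
        return False
    expected = "A" if a > b else "B" if b > a else "Tie"
    return entry.get("winner") == expected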

parser = argparse.ArgumentParser()
parser.add_argument('--run', action='store_true')
parser.add_argument('--rate', action='store_true')
parser.add_argument('--model', type=str,
                    default='Qwen/Qwen2.5-Coder-32B-Instruct')
parser.add_argument('--dataset', type=str, default='./decompiled_ds_all')
parser.add_argument('--output', type=str, default='./code_quality')
# The calibration target is fractional, so parse it as a float (int() would
# reject a command-line value such as 1141.86).
parser.add_argument('--calibrate_elo', type=float, default=1141.86)
args = parser.parse_args()
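

def json5_loads(x, *args, **kwargs):
    # Parse with JSON5, which tolerates trailing commas, comments and single
    # quotes that LLMs often emit; re-raise failures as JSONDecodeError so
    # callers that catch the standard exception keep working.
    try:
        return json5.loads(x)
    except ValueError as e:
        raise JSONDecodeError("Expecting value", x, 0) from e


# Monkey-patch the json module so that langchain's JSON helpers, which call
# json.loads internally, use the lenient JSON5 parser instead.
json.loads = json5_loads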
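

def enforce_prefix_parse_json_markdown(
    json_string: str, *args, parser: Callable[[str], Any] = parse_partial_json, require_prefix=True,
) -> dict:
    """Parse JSON from an LLM reply, preferring a ```json fenced block.

    The whole reply is tried first; on failure the body of a ```json fence is
    tried, and if that also fails, the body of any ``` fence is tried (a
    retry with require_prefix=False).
    """
    try:
        return _parse_json(json_string, parser=parser)
    except json.JSONDecodeError:
        if require_prefix:
            pattern = r"```json(.*)```"
        else:
            pattern = r"```(.*)```"
        match = re.search(pattern, json_string, re.DOTALL)
        if match is None:
            json_str = json_string
        else:
            json_str = match.group(1)
        try:
            return _parse_json(json_str, parser=parser)
        except json.JSONDecodeError as e:
            if require_prefix is True:
                # Pass the parser by keyword: positionally it would be swallowed
                # by *args and silently replaced with the default parser.
                return enforce_prefix_parse_json_markdown(
                    json_string, *args, parser=parser, require_prefix=False)
            else:
                raise e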
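

# --- Illustration (hypothetical helper, not used by this script) ---
# A stdlib-only sketch of the fallback order implemented by
# enforce_prefix_parse_json_markdown: the whole reply first, then the body of
# a ```json fence, then the body of any ``` fence. Unlike the real function it
# uses plain json.loads, so it does not try to repair truncated JSON the way
# langchain's parse_partial_json may.
def _example_extract_reply_json(reply: str):
    import json
    import re

    candidates = [reply]
    for pattern in (r"```json(.*)```", r"```(.*)```"):
        match = re.search(pattern, reply, re.DOTALL)
        if match is not None:
            candidates.append(match.group(1))
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except ValueError:  # JSONDecodeError is a subclass of ValueError
            continue
    raise ValueError("no parseable JSON found in reply")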


class EnforcePrefixJsonOutputParser(JsonOutputParser):
    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        text = result[0].text
        text = text.strip()
        if partial:
            try:
                return enforce_prefix_parse_json_markdown(text)
            except JSONDecodeError:
                return None
        else:
            try:
                return enforce_prefix_parse_json_markdown(text)
            except JSONDecodeError as e:
                msg = f"Invalid json output: {text}"
                raise OutputParserException(msg, llm_output=text) from e


llm = ChatOpenAI(
    model=args.model,
    max_completion_tokens=8192,
    timeout=60 * 60,
)
json_output_parser = EnforcePrefixJsonOutputParser()
str_output_parser = StrOutputParser()


def parse_generation(text, partial: bool = False):
    text = text.strip()
    if partial:
        try:
            return enforce_prefix_parse_json_markdown(text)
        except JSONDecodeError:
            return None
    else:
        try:
            return enforce_prefix_parse_json_markdown(text)
        except JSONDecodeError as e:
            msg = f"Invalid json output: {text}"
            raise OutputParserException(msg, llm_output=text) from e
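

# The system prompt is passed as a SystemMessage instance, which is kept
# verbatim, so the braces in its JSON example are not parsed as template
# variables; only the human message is a template.
prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=evaluation_prompt_1),
    HumanMessagePromptTemplate.from_template(template='''\
Source Code:
{source_code}
Decompiled Code A:
{decompile_code_a}
Decompiled Code B:
{decompile_code_b}
'''),
])
chain = (prompt | llm | {
    "raw_output": str_output_parser,
    "parsed_output": json_output_parser,
})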

decompilers = ["angr", "binja", "dewolf", "ghidra",
               "hexrays", "mlm", "deepseek", "qwen", "retdec", 'llm4decompile', 'gpt-4o-mini', 'gpt-4o']
INIT_RATING = 1000
aspects = [
    "Typecast Issues",
    "Non-idiomatic Literal Representation",
    "Obfuscated Control Flow",
    "Use of Decompiler-Specific Macros",
    "Incorrect Return Behavior",
    "Meaningless Identifier Names",
    "Incorrect Identifier Names",
    "Expanded Symbols",
    "Overall Function Correctness",
    "Overall Functionality Precision",
    "Non-Idiomatic Dereferencing",
    "Abuse of Memory Layout",
]
rating = {decompiler: INIT_RATING for decompiler in decompilers}


def format_message(message, role):
    return f"<s>{role}\n{message}</s>\n"


def prompt_format_decompile(src, dec_a, dec_b=None) -> str:
    prompt_filled = prompt.format(
        source_code=src, decompile_code_a=dec_a, decompile_code_b=dec_b)
    return prompt_filled


def invoke(src, dec_a, dec_b, metadata):
    # One attempt per pair; increase the range() to retry failed LLM calls.
    for i in range(1):
        try:
            ret = chain.invoke({
                'source_code': src,
                'decompile_code_a': dec_a,
                'decompile_code_b': dec_b,
            })
            break
        except Exception as e:
            print(i, e)
            ret = None
    return {
        'ret': ret,
        'metadata': metadata,
    }


def execute_from_generator(
    generator,
    fn,
    max_workers=8,
    parallel=False,
):
    """Call fn(**task) for every task dict from `generator` and yield the results.

    With parallel=True, tasks run in a process pool with at most
    2 * max_workers of them in flight; results are yielded as tasks complete,
    and failed tasks are logged and skipped. With parallel=False, tasks run
    one by one in this process.
    """
    futures = {}
    finished = False
    pbar = tqdm()
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        while True:
            if futures:
                done, not_done = concurrent.futures.wait(
                    futures, return_when=concurrent.futures.FIRST_COMPLETED)
            else:
                done, not_done = set(), set()
            # Top the queue back up to 2 * max_workers pending tasks.
            available_workers = max(
                0, max_workers * 2 - len(not_done))
            for _ in range(available_workers):
                try:
                    task = next(generator)
                except StopIteration:
                    finished = True
                    break
                if parallel:
                    future = executor.submit(fn, **task)
                    futures[future] = task
                else:
                    # Sequential mode: run inline and yield the result at once.
                    pbar.update()
                    yield fn(**task)
            for future in done:
                task_exception = future.exception()
                pbar.update()
                if task_exception is not None:
                    print(f'Task failed: {task_exception}')
                else:
                    yield future.result()
                del futures[future]
            pbar.set_description(
                f"Remaining Tasks: {len(futures)}, Finished: {finished}")  # average 1min+ per item
            # Stop once the generator is exhausted and nothing is pending
            # (this also covers an empty generator and sequential mode).
            if finished and not futures:
                break
            if available_workers == 0 and len(done) == 0:
                time.sleep(1)
    pbar.close()
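

# --- Illustration (hypothetical helper, not used by this script) ---
# A compact sketch of the bounded-submission pattern behind
# execute_from_generator: keep at most `max_in_flight` tasks pending, wait for
# the first completions, yield their results, and refill from the iterator.
# It uses a thread pool so that it also works with lambdas and closures (the
# script uses a process pool), and, unlike the script, it lets task
# exceptions propagate to the caller instead of logging and skipping them.
def _example_bounded_map(fn, items, max_in_flight=4):
    import concurrent.futures as cf
    import itertools

    it = iter(items)
    with cf.ThreadPoolExecutor(max_workers=max_in_flight) as pool:
        pending = {pool.submit(fn, x) for x in itertools.islice(it, max_in_flight)}
        while pending:
            done, pending = cf.wait(pending, return_when=cf.FIRST_COMPLETED)
            for future in done:
                yield future.result()
            # Refill one slot for every task that just finished.
            for x in itertools.islice(it, len(done)):
                pending.add(pool.submit(fn, x))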


def choose_by_elo(rating: Dict[str, int], model_a):
    """Sample an opponent for model_a, favouring models with a close rating."""
    selection_probs = []
    rating_diff_list = []
    model_list = [model for model in rating.keys() if model != model_a]
    for model_b in model_list:
        rating_diff = abs(rating[model_a] - rating[model_b])
        rating_diff_list.append(rating_diff)
    rating_diff_min = min(rating_diff_list) + 1e-6
    for rating_diff in rating_diff_list:
        selection_prob = 1 / (1 + rating_diff / rating_diff_min)
        selection_probs.append(selection_prob)
    total_prob = sum(selection_probs)
    normalized_probs = [prob / total_prob for prob in selection_probs]
    selected_b = np.random.choice(len(model_list), p=normalized_probs)
    return model_list[selected_b]
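

# --- Illustration (hypothetical helper, not used by this script) ---
# choose_by_elo weights each candidate opponent by 1 / (1 + d / d_min), where d
# is its rating gap to model_a and d_min is the smallest gap (+1e-6). This
# deterministic sketch returns the normalized probabilities instead of
# sampling, which makes the weighting easy to inspect.
def _example_opponent_probs(rating: dict, model_a: str) -> dict:
    others = [m for m in rating if m != model_a]
    gaps = [abs(rating[model_a] - rating[m]) for m in others]
    gap_min = min(gaps) + 1e-6
    weights = [1 / (1 + gap / gap_min) for gap in gaps]
    total = sum(weights)
    return {m: w / total for m, w in zip(others, weights)}

# With gaps of 50 and 100, the closer opponent gets about 0.6 and the other
# about 0.4. If some opponents have exactly the same rating as model_a,
# d_min becomes 1e-6 and those opponents receive almost all of the mass.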


def get_tasks(df: pd.DataFrame, rating: Dict[str, int]):
    decompilers = list(rating.keys())
    for idx, row in df.iterrows():
        dec_a: str = np.random.choice(decompilers)
        source_code = row['func']
        dec_b: str = choose_by_elo(rating, dec_a)
        # Randomly swap A and B so neither decompiler always appears first.
        if np.random.random() < 0.5:
            dec_a, dec_b = dec_b, dec_a
        dec_a_res = row[dec_a].item()
        dec_b_res = row[dec_b].item()
        if not dec_a_res or not dec_b_res or dec_a_res == 'none' or dec_b_res == 'none':
            continue
        yield {
            "src": source_code,
            "dec_a": dec_a_res,
            "dec_b": dec_b_res,
            "metadata": {
                "idx": idx,
                "a": dec_a,
                "b": dec_b,
            }
        }


def run(df: pd.DataFrame, rating: Dict[str, int], max_workers: int):
    result = []
    print("=" * 15)
    for ret in execute_from_generator(
        get_tasks(df, rating),
        invoke,
        max_workers=max_workers,
        parallel=True,
    ):
        result.append(ret)
    print(f"result: {len(result)}")
    return result


def main():
    save_dir = args.output + '/all_chunks'
    if not pathlib.Path(save_dir).exists():
        pathlib.Path(save_dir).mkdir(parents=True)
    # Delete empty chunk results (and their rating snapshots) so that those
    # chunks are evaluated again instead of being skipped.
    pkl_files = [f for f in os.listdir(save_dir) if f.endswith('.pkl')]
    for pkl_file in pkl_files:
        try:
            with open(os.path.join(save_dir, pkl_file), 'rb') as file:
                data = pickle.load(file)
            if len(data) == 0:
                print(f"Removing empty file: {pkl_file}")
                os.remove(os.path.join(save_dir, pkl_file))
                idx = int(pkl_file.split('_')[-1].split('.')[0])
                result_path = f'ratings_{idx}.json'
                if os.path.exists(os.path.join(save_dir, result_path)):
                    os.remove(os.path.join(save_dir, result_path))
        except Exception as e:
            print(f"Error loading {pkl_file}: {e}")

    def compute_online_elo(battles, calibrate_model, K=4, SCALE=400, BASE=10):
        # Sequential Elo updates over one chunk's battles. `rating` is the
        # dict created further down in main() and captured by this closure.
        for model_a, model_b, winner in battles:
            ra = rating[model_a]
            rb = rating[model_b]
            ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))
            eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))
            if winner == "a" or winner == "A":
                sa = 1
            elif winner == "b" or winner == "B":
                sa = 0
            else:
                # "Tie", "tie", "tie (bothbad)" and unrecognized values.
                sa = 0.5
            rating[model_a] += K * (sa - ea)
            rating[model_b] += K * (1 - sa - eb)
        # Shift every rating by the same amount so that the calibration model
        # sits at 800; shifting only some models would distort the gaps.
        delta = (800 - rating[calibrate_model])
        for model in rating:
            rating[model] += delta
        return rating

    calibrate_model = 'ghidra'
    INIT_RATING = 1000
    rating: Dict[str, int] = defaultdict(lambda: INIT_RATING)
    for model in decompilers:
        rating[model] = INIT_RATING
    needed_decompiler_ds = datasets.load_from_disk(args.dataset)
    assert isinstance(needed_decompiler_ds, datasets.Dataset)
    df = needed_decompiler_ds.to_pandas()
    assert isinstance(df, pd.DataFrame)
    left_idx = 0
    right_idx = len(df)
    result_list = os.listdir(save_dir)
    result_list = [
        result for result in result_list if result.startswith('ratings_')]
    result_list = sorted(result_list, key=lambda x: int(
        x.split('_')[-1].split('.')[0]))
    chunk_size = 100
    chunks = [df[i:i + chunk_size] for i in range(0, df.shape[0], chunk_size)]
    random.shuffle(chunks)
    chunk_idx = 0
    for chunk in chunks:
        with open(f'{save_dir}/ratings_{chunk_idx}.json', 'w') as file:
            json.dump(rating, file, indent=2)
        save_path = f"{save_dir}/chunk_{left_idx}_{chunk_size}_{chunk_idx}.pkl"
        chunk_idx += 1
        if os.path.exists(save_path):
            continue
        ret = run(chunk, rating, chunk_size // 2)
        battles = []
        for _ret in ret:
            try:
                eval_ret = _ret['ret']['parsed_output']
                idx = _ret['metadata']['idx']
                dec_a = _ret['metadata']['a']
                dec_b = _ret['metadata']['b']
                for k, v in eval_ret.items():
                    for kk, vv in v.items():
                        winner = vv['winner']
                        battles.append((dec_a, dec_b, winner))
            except Exception as e:
                print(f"no eval_ret: {e}")
                continue
        print(f"battles: {battles}")
        rating = compute_online_elo(battles, calibrate_model)
        print(f"chunk_idx: {chunk_idx}, rating: {rating}")
        with open(save_path, "wb") as f:
            pickle.dump(ret, f)
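

# --- Illustration (hypothetical helper, not used by this script) ---
# One step of the online Elo update in compute_online_elo (K=4, SCALE=400,
# BASE=10): the expected score of A is
#   E_A = 1 / (1 + BASE ** ((R_B - R_A) / SCALE)),
# and each rating moves by K * (actual score - expected score).
def _example_elo_step(ra, rb, score_a, k=4, scale=400, base=10):
    ea = 1 / (1 + base ** ((rb - ra) / scale))
    eb = 1 / (1 + base ** ((ra - rb) / scale))
    return ra + k * (score_a - ea), rb + k * ((1 - score_a) - eb)

# Two decompilers at 1000: a win for A gives (1002.0, 998.0). Because
# E_A + E_B = 1, every update preserves the sum of the two ratings.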


def compute_elo(save_dir='./'):
    def get_bootstrap_result(battles, func_compute_elo, num_round):
        rows = []
        for i in tqdm(range(num_round), desc="bootstrap"):
            try:
                rows.append(func_compute_elo(
                    battles.sample(frac=1.0, replace=True)))
            except Exception as e:
                print(e)
        df = pd.DataFrame(rows)
        return df[df.median().sort_values(ascending=False).index]

    def compute_mle_elo(
        df, SCALE=800, BASE=10, INIT_RATING=1000
    ):
        from sklearn.linear_model import LogisticRegression
        ptbl_a_win = pd.pivot_table(
            df[df["winner"] == "model_a"],
            index="model_a",
            columns="model_b",
            aggfunc="size",
            fill_value=0,
        )
        if sum(df["winner"].isin(["tie", "tie (bothbad)"])) == 0:
            ptbl_tie = pd.DataFrame(0, index=ptbl_a_win.index,
                                    columns=ptbl_a_win.columns)
        else:
            ptbl_tie = pd.pivot_table(
                df[df["winner"].isin(["tie", "tie (bothbad)"])],
                index="model_a",
                columns="model_b",
                aggfunc="size",
                fill_value=0,
            )
            ptbl_tie = ptbl_tie + ptbl_tie.T
        ptbl_b_win = pd.pivot_table(
            df[df["winner"] == "model_b"],
            index="model_a",
            columns="model_b",
            aggfunc="size",
            fill_value=0,
        )
        ptbl_win = ptbl_a_win * 2 + ptbl_b_win.T * 2 + ptbl_tie
        models = pd.Series(np.arange(len(ptbl_win.index)),
                           index=ptbl_win.index)
        p = len(models)
        X = np.zeros([p * (p - 1) * 2, p])
        Y = np.zeros(p * (p - 1) * 2)
        cur_row = 0
        sample_weights = []
        for m_a in ptbl_win.index:
            for m_b in ptbl_win.columns:
                if m_a == m_b:
                    continue
                # if nan skip
                if math.isnan(ptbl_win.loc[m_a, m_b]) or math.isnan(ptbl_win.loc[m_b, m_a]):
                    continue
                X[cur_row, models[m_a]] = +math.log(BASE)
                X[cur_row, models[m_b]] = -math.log(BASE)
                Y[cur_row] = 1.0
                sample_weights.append(ptbl_win.loc[m_a, m_b])
                X[cur_row + 1, models[m_a]] = math.log(BASE)
                X[cur_row + 1, models[m_b]] = -math.log(BASE)
                Y[cur_row + 1] = 0.0
                sample_weights.append(ptbl_win.loc[m_b, m_a])
                cur_row += 2
        X = X[:cur_row]
        Y = Y[:cur_row]
        lr = LogisticRegression(fit_intercept=False, penalty=None, tol=1e-6)
        lr.fit(X, Y, sample_weight=sample_weights)
        elo_scores = SCALE * lr.coef_[0] + INIT_RATING
        if calibrate_model in models.index:
            elo_scores += args.calibrate_elo - \
                elo_scores[models[calibrate_model]]
        return pd.Series(elo_scores, index=models.index).sort_values(ascending=False)

    BOOTSTRAP_ROUNDS = 100  # 1000
    calibrate_model = 'ghidra'
    chunk_dir = f'{save_dir}/all_chunks'
    pkl_list = os.listdir(chunk_dir)
    df_list = []
    os.makedirs(f"{save_dir}/all_scores", exist_ok=True)
    print(f"pkl_list: {len(pkl_list)}")
    for pkl in pkl_list:
        if not pkl.endswith("pkl"):
            continue
        print(f"pkl: {pkl}")
        # The chunk pickles live in chunk_dir, not directly in save_dir.
        with open(f"{chunk_dir}/{pkl}", "rb") as f:
            ret = pickle.load(f)
        for battle in ret:
            metadata = battle['metadata']
            eval_ret = battle['ret']
            if not eval_ret:
                continue
            parsed_output = eval_ret['parsed_output']
            if not parsed_output:
                continue
            model_a = metadata['a']
            model_b = metadata['b']
            if type(parsed_output) is not dict:
                print(f"parsed_output: {parsed_output}")
                continue
            for k, v in parsed_output.items():
                try:
                    for kk, vv in v.items():
                        if 'winner' not in vv:
                            print(vv.keys())
                            continue
                        winner_value = vv['winner']
                        if not winner_value:
                            winner = 'tie'
                        elif winner_value.lower() == 'a':
                            winner = 'model_a'
                        elif winner_value.lower() == 'b':
                            winner = 'model_b'
                        else:
                            winner = 'tie'
                        df_list.append({
                            'model_a': model_a,
                            'model_b': model_b,
                            "ds_idx": metadata['idx'],
                            'aspect': kk,
                            "winner": winner,
                        })
                except Exception as e:
                    print(f"parse output error: {e}")
                    continue
    print(f"df_list: {len(df_list)}")
    df = pd.DataFrame(df_list)
    bootstrap_elo_lu = get_bootstrap_result(
        df, compute_mle_elo, BOOTSTRAP_ROUNDS)
    bootstrap_elo_lu.to_csv(f"{save_dir}/all_scores/elo_scores.csv")
    for aspect in aspects:
        df_aspect = df[df['aspect'] == aspect]
        bootstrap_elo_lu_aspect = get_bootstrap_result(
            df_aspect, compute_mle_elo, BOOTSTRAP_ROUNDS)
        bootstrap_elo_lu_aspect.to_csv(
            f"{save_dir}/all_scores/elo_scores_{aspect}.csv")
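

# --- Illustration (hypothetical helper, not used by this script) ---
# With only two models, the logistic-regression fit in compute_mle_elo has a
# closed form. Using its weighting (each win counted twice, each tie once for
# both sides), the fitted gap satisfies
#   w_ab / (w_ab + w_ba) = 1 / (1 + BASE ** (-gap / SCALE)),
# i.e. gap = SCALE * log_BASE(w_ab / w_ba). This sketch assumes each side has
# at least one win or tie; otherwise the unregularized fit diverges.
def _example_two_model_elo_gap(wins_a, wins_b, ties=0, scale=800, base=10):
    import math

    w_ab = 2 * wins_a + ties  # weight of "A beats B" rows
    w_ba = 2 * wins_b + ties  # weight of "B beats A" rows
    return scale * math.log(w_ab / w_ba, base)

# Three wins for A and one for B give 800 * log10(3), about 382 points;
# equal records give a gap of 0.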


def compute_aspect_scores(save_dir='arena_general'):
    model_scores = {}
    for aspect in aspects:
        try:
            df = pd.read_csv(f"{save_dir}/all_scores/elo_scores_{aspect}.csv")
            df = df.drop(columns=["Unnamed: 0"])
            median_scores = df.median()
            for model in median_scores.index:
                if model not in model_scores:
                    model_scores[model] = {}
                model_scores[model][aspect] = median_scores[model]
        except FileNotFoundError:
            print(f"No scores found for aspect: {aspect}")
            continue
    scores_df = pd.DataFrame(model_scores).T
    scores_df.to_csv(f"{save_dir}/all_scores/model_aspect_scores.csv")
    scores_dict = scores_df.to_dict('index')
    with open(f"{save_dir}/all_scores/model_aspect_scores.json", 'w') as f:
        json.dump(scores_dict, f, indent=2)
    return scores_df


def save_scores(save_dir='./code_quality'):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # compute_elo already bootstraps every aspect, so a single call suffices.
    compute_elo(save_dir)
    compute_aspect_scores(save_dir)


if __name__ == "__main__":
    if args.run:
        main()
    if args.rate:
        save_scores(args.output)