2020from typing import List
2121import requests
2222import time
23- from tqdm import tqdm
2423
2524def _clean_json_block (item : str ) -> str :
2625 return item .strip ().removeprefix ("```json" ).removeprefix ("```" ).removesuffix ("```" ).strip ()
@@ -84,7 +83,7 @@ def retrieve_docs(self, query: str, original_docs: List[str], now_hop: int, topk
8483 response = requests .post (
8584 self .retriever_url ,
8685 json = {"query" : query , "topk" : topk + now_hop },
87- timeout = 60
86+ timeout = 1200
8887 )
8988 data = response .json ()
9089 all_docs = [doc .get ("contents" , "" ) for doc in data .get ("results" , [])]
@@ -97,37 +96,15 @@ def retrieve_docs(self, query: str, original_docs: List[str], now_hop: int, topk
9796 filter_docs = [d for d in unique_docs if "(number)" not in d and "(decade)" not in d ]
9897 return filter_docs [:topk ]
9998
100- def _safe_json_load (self , text : str , stage : str ):
101- """
102- Safely load JSON from LLM output.
103- Return None if parsing fails.
104- """
105- if not text or not text .strip ():
106- self .logger .warning (f"[{ stage } ] Empty LLM output" )
107- return None
108-
109- cleaned = _clean_json_block (text )
110- if not cleaned or not cleaned .strip ():
111- self .logger .warning (f"[{ stage } ] Empty cleaned JSON" )
112- return None
113-
114- try :
115- return json .loads (cleaned )
116- except json .JSONDecodeError as e :
117- self .logger .warning (
118- f"[{ stage } ] JSON decode failed: { e } | content: { cleaned [:200 ]} "
119- )
120- return None
121-
12299 def run (
123100 self ,
124101 storage : DataFlowStorage ,
125102 input_hop : int ,
126103 input_question_key : str = "question" ,
127104 input_answer_key : str = "answer" ,
128105 input_doc_key : str = "doc" ,
129- input_topk : int = 3 ,
130- input_per_doc_qa : int = 1 ,
106+ input_topk : int = 5 ,
107+ input_per_doc_qa : int = 5 ,
131108 ):
132109 self .input_hop = input_hop
133110 self .input_question_key = input_question_key
@@ -153,7 +130,7 @@ def run(
153130 # ---- Phase 1: build atomic prompts for ALL rows/docs and call model in batch ----
154131 atomic_prompts = []
155132 atomic_meta = []
156- for i , current_data in tqdm ( enumerate (rows ), total = len ( rows ), desc = "Generating atomic QA prompts" ):
133+ for i , current_data in enumerate (rows ):
157134 hop_num = input_hop
158135 hop_key = f"hop_{ hop_num } "
159136 now_question = current_data [hop_key ][input_question_key ]
@@ -182,10 +159,8 @@ def run(
182159 atomic_outputs = self .llm_serving .generate_from_input (atomic_prompts )
183160 parsed_atomic = []
184161 for out in atomic_outputs :
185- obj = self ._safe_json_load (out , stage = "atomic_qa" )
186- if obj is None :
187- continue
188- parsed_atomic .append (obj )
162+ cleaned = _clean_json_block (out )
163+ parsed_atomic .append (json .loads (cleaned ))
189164
190165 # ---- Phase 2: build merge prompts for ALL atomic qas and call model in batch ----
191166 merge_prompts = []
@@ -226,12 +201,8 @@ def run(
226201 merge_outputs = self .llm_serving .generate_from_input (merge_prompts )
227202 parsed_merges = []
228203 for out in merge_outputs :
229- obj = self ._safe_json_load (out , stage = "merge_qa" )
230- if obj is None :
231- continue
232- parsed_merges .append (obj )
233-
234- print ("parsed_merges: " , len (parsed_merges ))
204+ cleaned = _clean_json_block (out )
205+ parsed_merges .append (json .loads (cleaned ))
235206
236207 # ---- Phase 3: filter merges and build refine prompts ----
237208 refine_prompts = []
@@ -266,14 +237,9 @@ def run(
266237
267238 refine_outputs = self .llm_serving .generate_from_input (refine_prompts )
268239 parsed_refines = []
269- valid_refine_meta = []
270- for out , meta in zip (refine_outputs , refine_meta ):
271- obj = self ._safe_json_load (out , stage = "refine_answer" )
272- if obj is None :
273- continue
274- parsed_refines .append (obj )
275- valid_refine_meta .append (meta )
276- refine_meta = valid_refine_meta
240+ for out in refine_outputs :
241+ cleaned = _clean_json_block (out )
242+ parsed_refines .append (json .loads (cleaned ))
277243
278244 # ---- Phase 4: build optional prompts for ALL refines and batch call ----
279245 opt_prompts = []
@@ -302,14 +268,9 @@ def run(
302268
303269 opt_outputs = self .llm_serving .generate_from_input (opt_prompts )
304270 parsed_opts = []
305- valid_opt_meta = []
306- for out , meta in zip (opt_outputs , opt_meta ):
307- obj = self ._safe_json_load (out , stage = "optional_answer" )
308- if obj is None :
309- continue
310- parsed_opts .append (obj )
311- valid_opt_meta .append (meta )
312- opt_meta = valid_opt_meta
271+ for out in opt_outputs :
272+ cleaned = _clean_json_block (out )
273+ parsed_opts .append (json .loads (cleaned ))
313274
314275 # ---- Phase 5: assemble new_rows from opt results and corresponding meta ----
315276 new_rows = []
0 commit comments