
Commit 36265ca

Revert "modify code and add flash_rag_serving"
This reverts commit 734fe06.
1 parent 5301df5 commit 36265ca

4 files changed: +95 additions, -266 deletions

dataflow/operators/agentic_rag/eval/agenticrag_multihop_sample_evaluator.py

Lines changed: 13 additions & 61 deletions
@@ -47,28 +47,6 @@ def __init__(
     def get_desc(lang: str = "zh"):
         return "MultiHopRAG 验证算子：对 multi_hop_data 中每个候选进行多步验证并返回合格的数据。" if lang == "zh" else "Verifier for MultiHop RAG."

-    def _safe_json_load(self, text: str, stage: str):
-        """
-        Safely load JSON from LLM output.
-        Return None if parsing fails.
-        """
-        if not text or not text.strip():
-            self.logger.warning(f"[{stage}] Empty LLM output")
-            return None
-
-        cleaned = _clean_json_block(text)
-        if not cleaned or not cleaned.strip():
-            self.logger.warning(f"[{stage}] Empty cleaned JSON")
-            return None
-
-        try:
-            return json.loads(cleaned)
-        except json.JSONDecodeError as e:
-            self.logger.warning(
-                f"[{stage}] JSON decode failed: {e} | content: {cleaned[:200]}"
-            )
-            return None
-
     def run(self, storage: DataFlowStorage):
         df = storage.read("dataframe")

@@ -132,15 +110,9 @@ def run(self, storage: DataFlowStorage):

         check_outputs = self.llm_serving.generate_from_input(check_prompts) if check_prompts else []
         parsed_checks = []
-        valid_check_meta = []
-
-        for out, meta in zip(check_outputs, check_meta):
-            check_obj = self._safe_json_load(out, stage="phase1_check")
-            if check_obj is None:
-                continue
-            parsed_checks.append(check_obj)
-            valid_check_meta.append(meta)
-        check_meta = valid_check_meta
+        for out in check_outputs:
+            cleaned = _clean_json_block(out)
+            parsed_checks.append(json.loads(cleaned))

         passed_after_check = []
         for idx, check_result in enumerate(parsed_checks):
@@ -168,6 +140,7 @@ def run(self, storage: DataFlowStorage):
         # ---- Phase 2: reasoning prompts (one per passed row) ----
         reasoning_prompts = []
         reasoning_meta = []
+        print("passed_after_check: ", len(passed_after_check))
         for item in passed_after_check:
             qa_type = item["qa_type"]
             final_question = item["final_question"]
@@ -195,16 +168,9 @@ def run(self, storage: DataFlowStorage):

         judge_outputs = self.llm_serving.generate_from_input(judge_prompts) if judge_prompts else []
         parsed_judges = []
-        valid_judge_meta = []
-
-        for out, meta in zip(judge_outputs, judge_meta):
-            judge_obj = self._safe_json_load(out, stage="phase3_reasoning_judge")
-            if judge_obj is None:
-                continue
-            parsed_judges.append(judge_obj)
-            valid_judge_meta.append(meta)
-
-        judge_meta = valid_judge_meta
+        for out in judge_outputs:
+            cleaned = _clean_json_block(out)
+            parsed_judges.append(json.loads(cleaned))

         passed_after_reasoning = []
         for idx, judge_res in enumerate(parsed_judges):
@@ -261,16 +227,9 @@ def run(self, storage: DataFlowStorage):

         single_judge_outputs = self.llm_serving.generate_from_input(single_judge_prompts) if single_judge_prompts else []
         parsed_single_judges = []
-        valid_single_judge_meta = []
-
-        for out, meta in zip(single_judge_outputs, single_judge_meta):
-            judge_obj = self._safe_json_load(out, stage="phase5_singlehop_judge")
-            if judge_obj is None:
-                continue
-            parsed_single_judges.append(judge_obj)
-            valid_single_judge_meta.append(meta)
-
-        single_judge_meta = valid_single_judge_meta
+        for out in single_judge_outputs:
+            cleaned = _clean_json_block(out)
+            parsed_single_judges.append(json.loads(cleaned))

         row_fail_map = {}
         for idx, judge_res in enumerate(parsed_single_judges):
@@ -345,16 +304,9 @@ def run(self, storage: DataFlowStorage):

         final_judge_outputs = self.llm_serving.generate_from_input(final_judge_prompts) if final_judge_prompts else []
         parsed_final_judges = []
-        valid_final_judge_meta = []
-
-        for out, meta in zip(final_judge_outputs, final_judge_meta):
-            judge_obj = self._safe_json_load(out, stage="phase7_final_judge")
-            if judge_obj is None:
-                continue
-            parsed_final_judges.append(judge_obj)
-            valid_final_judge_meta.append(meta)
-
-        final_judge_meta = valid_final_judge_meta
+        for out in final_judge_outputs:
+            cleaned = _clean_json_block(out)
+            parsed_final_judges.append(json.loads(cleaned))

         verified_rows = []
         for idx, judge_res in enumerate(parsed_final_judges):
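
The net effect of this file's revert: the tolerant _safe_json_load helper is gone, and every phase goes back to stripping Markdown fences with _clean_json_block and calling json.loads directly, so a malformed LLM reply now raises json.JSONDecodeError instead of being logged and skipped (the per-item meta bookkeeping is dropped with it). A minimal standalone sketch of the two behaviors; the logger setup and the sample outputs are illustrative, only the helper bodies come from the diff:

import json
import logging

logger = logging.getLogger("agenticrag_sketch")

def _clean_json_block(item: str) -> str:
    # Same one-liner as in the generator module: strip ```json fences.
    return item.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()

def safe_json_load(text: str, stage: str):
    # Reconstruction of the removed helper: warn and return None instead of raising.
    if not text or not text.strip():
        logger.warning(f"[{stage}] Empty LLM output")
        return None
    cleaned = _clean_json_block(text)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError as e:
        logger.warning(f"[{stage}] JSON decode failed: {e} | content: {cleaned[:200]}")
        return None

outputs = ['```json\n{"pass": true}\n```', "not json at all"]

# Tolerant path (removed by this commit): bad outputs are skipped with a warning.
parsed = [obj for out in outputs if (obj := safe_json_load(out, stage="phase1_check")) is not None]
print(parsed)  # [{'pass': True}]

# Reinstated strict path: the second output would raise json.JSONDecodeError.
# parsed = [json.loads(_clean_json_block(out)) for out in outputs]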

dataflow/operators/agentic_rag/generate/agenticrag_multihop_qa_generator.py

Lines changed: 14 additions & 53 deletions
@@ -20,7 +20,6 @@
 from typing import List
 import requests
 import time
-from tqdm import tqdm

 def _clean_json_block(item: str) -> str:
     return item.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()
@@ -84,7 +83,7 @@ def retrieve_docs(self, query: str, original_docs: List[str], now_hop: int, topk
         response = requests.post(
             self.retriever_url,
             json={"query": query, "topk": topk + now_hop},
-            timeout=60
+            timeout=1200
         )
         data = response.json()
         all_docs = [doc.get("contents", "") for doc in data.get("results", [])]
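
The only change in this hunk is the retriever timeout, put back to 1200 seconds after the reverted commit had lowered it to 60. A rough standalone sketch of the call this method makes; the endpoint URL and the dedup against earlier-hop documents are assumptions for illustration, while the request/response handling follows the lines shown above:

from typing import List
import requests

def retrieve_docs_sketch(retriever_url: str, query: str, original_docs: List[str],
                         now_hop: int, topk: int) -> List[str]:
    # POST the query to the retrieval service, over-fetching by the hop count
    # so documents already used in earlier hops can be dropped.
    response = requests.post(
        retriever_url,
        json={"query": query, "topk": topk + now_hop},
        timeout=1200,  # value restored by this revert
    )
    data = response.json()
    all_docs = [doc.get("contents", "") for doc in data.get("results", [])]
    # Assumed dedup criterion; the real unique_docs construction is outside this hunk.
    unique_docs = [d for d in all_docs if d not in original_docs]
    # Skip list-style pages, as the original filter does.
    filter_docs = [d for d in unique_docs if "(number)" not in d and "(decade)" not in d]
    return filter_docs[:topk]

# Hypothetical usage against a locally running retriever:
# docs = retrieve_docs_sketch("http://localhost:8000/search", "Who founded DataFlow?", [], now_hop=1, topk=5)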
@@ -97,37 +96,15 @@ def retrieve_docs(self, query: str, original_docs: List[str], now_hop: int, topk
         filter_docs = [d for d in unique_docs if "(number)" not in d and "(decade)" not in d]
         return filter_docs[:topk]

-    def _safe_json_load(self, text: str, stage: str):
-        """
-        Safely load JSON from LLM output.
-        Return None if parsing fails.
-        """
-        if not text or not text.strip():
-            self.logger.warning(f"[{stage}] Empty LLM output")
-            return None
-
-        cleaned = _clean_json_block(text)
-        if not cleaned or not cleaned.strip():
-            self.logger.warning(f"[{stage}] Empty cleaned JSON")
-            return None
-
-        try:
-            return json.loads(cleaned)
-        except json.JSONDecodeError as e:
-            self.logger.warning(
-                f"[{stage}] JSON decode failed: {e} | content: {cleaned[:200]}"
-            )
-            return None
-
     def run(
         self,
         storage: DataFlowStorage,
         input_hop: int,
         input_question_key: str = "question",
         input_answer_key: str = "answer",
         input_doc_key: str = "doc",
-        input_topk: int = 3,
-        input_per_doc_qa: int = 1,
+        input_topk: int = 5,
+        input_per_doc_qa: int = 5,
     ):
         self.input_hop = input_hop
         self.input_question_key = input_question_key
@@ -153,7 +130,7 @@ def run(
         # ---- Phase 1: build atomic prompts for ALL rows/docs and call model in batch ----
         atomic_prompts = []
         atomic_meta = []
-        for i, current_data in tqdm(enumerate(rows), total=len(rows), desc="Generating atomic QA prompts"):
+        for i, current_data in enumerate(rows):
             hop_num = input_hop
             hop_key = f"hop_{hop_num}"
             now_question = current_data[hop_key][input_question_key]
@@ -182,10 +159,8 @@ def run(

         atomic_outputs = self.llm_serving.generate_from_input(atomic_prompts)
         parsed_atomic = []
         for out in atomic_outputs:
-            obj = self._safe_json_load(out, stage="atomic_qa")
-            if obj is None:
-                continue
-            parsed_atomic.append(obj)
+            cleaned = _clean_json_block(out)
+            parsed_atomic.append(json.loads(cleaned))

         # ---- Phase 2: build merge prompts for ALL atomic qas and call model in batch ----
         merge_prompts = []
@@ -226,12 +201,8 @@ def run(
         merge_outputs = self.llm_serving.generate_from_input(merge_prompts)
         parsed_merges = []
         for out in merge_outputs:
-            obj = self._safe_json_load(out, stage="merge_qa")
-            if obj is None:
-                continue
-            parsed_merges.append(obj)
-
-        print("parsed_merges: ", len(parsed_merges))
+            cleaned = _clean_json_block(out)
+            parsed_merges.append(json.loads(cleaned))

         # ---- Phase 3: filter merges and build refine prompts ----
         refine_prompts = []
@@ -266,14 +237,9 @@ def run(

         refine_outputs = self.llm_serving.generate_from_input(refine_prompts)
         parsed_refines = []
-        valid_refine_meta = []
-        for out, meta in zip(refine_outputs, refine_meta):
-            obj = self._safe_json_load(out, stage="refine_answer")
-            if obj is None:
-                continue
-            parsed_refines.append(obj)
-            valid_refine_meta.append(meta)
-        refine_meta = valid_refine_meta
+        for out in refine_outputs:
+            cleaned = _clean_json_block(out)
+            parsed_refines.append(json.loads(cleaned))

         # ---- Phase 4: build optional prompts for ALL refines and batch call ----
         opt_prompts = []
@@ -302,14 +268,9 @@ def run(

         opt_outputs = self.llm_serving.generate_from_input(opt_prompts)
         parsed_opts = []
-        valid_opt_meta = []
-        for out, meta in zip(opt_outputs, opt_meta):
-            obj = self._safe_json_load(out, stage="optional_answer")
-            if obj is None:
-                continue
-            parsed_opts.append(obj)
-            valid_opt_meta.append(meta)
-        opt_meta = valid_opt_meta
+        for out in opt_outputs:
+            cleaned = _clean_json_block(out)
+            parsed_opts.append(json.loads(cleaned))

         # ---- Phase 5: assemble new_rows from opt results and corresponding meta ----
         new_rows = []
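
Taken together, every phase of run() follows the same shape: build all prompts for the phase, send the batch through llm_serving.generate_from_input, parse each JSON reply, and feed the results into the next phase's prompts (this revert also restores the defaults input_topk=5 and input_per_doc_qa=5). A toy sketch of that pattern with a stubbed serving object; the stub and the prompt text are assumptions, only the call pattern and the parsing come from the diff:

import json

def _clean_json_block(item: str) -> str:
    return item.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()

class StubLLMServing:
    # Stand-in for the real llm_serving object: one fixed JSON reply per prompt.
    def generate_from_input(self, prompts):
        return ['```json\n{"question": "q?", "answer": "a"}\n```' for _ in prompts]

llm_serving = StubLLMServing()
rows = [{"doc": "passage one"}, {"doc": "passage two"}]

# One phase: build prompts for all rows, one batched call, one parse pass
# (direct json.loads, as reinstated by this revert).
atomic_prompts = [f"Write one QA pair grounded in:\n{row['doc']}" for row in rows]
atomic_outputs = llm_serving.generate_from_input(atomic_prompts)
parsed_atomic = [json.loads(_clean_json_block(out)) for out in atomic_outputs]
print(parsed_atomic)  # two {'question': ..., 'answer': ...} dicts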

dataflow/serving/flash_rag_serving.py

Lines changed: 0 additions & 80 deletions
This file was deleted.
