Commit 529a9fe

[debug] update consistentchat & condor text operators for prompt_template arguments (#441)
* [debug] update consistentchat & condor text operators for prompt_template arguments
* [debug] add DIYPromptABC in type definition
1 parent af3a3b3 commit 529a9fe

File tree

9 files changed: 166 additions & 128 deletions


dataflow/operators/conversations/generate/consistent_chat_generator.py

Lines changed: 10 additions & 9 deletions
@@ -6,25 +6,24 @@
 from dataflow.utils.storage import DataFlowStorage
 import pandas as pd
 from dataflow.core import LLMServingABC
-from dataflow.prompts.general_text import ConsistentQueryPrompt, ConsistentResponsePrompt
-from dataflow.core.prompt import prompt_restrict
+from dataflow.prompts.general_text import ConsistentChatPrompt
+from dataflow.core.prompt import DIYPromptABC, prompt_restrict
+from typing import Union
 
 @prompt_restrict(
-    ConsistentQueryPrompt,
-    ConsistentResponsePrompt
+    ConsistentChatPrompt
 )
 
 @OPERATOR_REGISTRY.register()
 class ConsistentChatGenerator(OperatorABC):
-    def __init__(self, llm_serving: LLMServingABC = None, num_dialogs_per_intent = 20, num_turns_per_dialog = 6, temperature = 0.9):
+    def __init__(self, llm_serving: LLMServingABC = None, num_dialogs_per_intent = 20, num_turns_per_dialog = 6, temperature = 0.9, prompt_template : Union[ConsistentChatPrompt, DIYPromptABC] = None):
         self.logger = get_logger()
         self.logger.info(f'Initializing {self.__class__.__name__}...')
         self.llm_serving = llm_serving
         self.num_dialogs_per_intent = num_dialogs_per_intent # Based on the topic_dict in the existing prompt, it is recommended to set the value to below 1000 (which can generate 9000 conversation data). Otherwise, it is recommended to add more topic_dict in dataflow.prompts.general_text.ConsistentChatPrompt to increase data richness
         self.num_turns_per_dialog = num_turns_per_dialog
         self.temperature = temperature
-        self.query_prompt = ConsistentQueryPrompt()
-        self.response_prompt = ConsistentResponsePrompt()
+        self.prompt_template = prompt_template
         self.logger.info(f'{self.__class__.__name__} initialized.')
 
     @staticmethod
@@ -37,6 +36,7 @@ def get_desc(lang: str = "zh"):
             "- num_dialogs_per_intent:每个意图生成的对话数量,默认20\n"
             "- num_turns_per_dialog:每个对话的轮次数量,默认6\n"
             "- temperature:生成温度,控制输出随机性,默认0.9\n"
+            "- prompt_template:提示词模板对象,用于定义提示结构\n"
             "输出参数:\n"
             "- 包含category和conversation字段的DataFrame,其中conversation为多轮对话列表"
         )
@@ -48,6 +48,7 @@ def get_desc(lang: str = "zh"):
             "- num_dialogs_per_intent: Number of dialogs generated per intent, default 20\n"
             "- num_turns_per_dialog: Number of turns per dialog, default 6\n"
             "- temperature: Sampling temperature for generation, default 0.9\n"
+            "- prompt_template: Prompt template object, for defining the prompt structure\n"
             "Output Parameters:\n"
             "- DataFrame containing 'category' and 'conversation' fields, where conversation is a list of multi-turn dialogues"
         )
@@ -57,7 +58,7 @@ def get_desc(lang: str = "zh"):
     def run(self, storage: DataFlowStorage):
 
         # Step 1: Generate all queries using LLM
-        all_query_prompts = self.query_prompt.build_prompt(num_dialogs_per_intent=self.num_dialogs_per_intent)
+        all_query_prompts = self.prompt_template.build_prompt(mode="query", num_dialogs_per_intent=self.num_dialogs_per_intent)
         # Step 2: Generate queries by calling llm_serving once
         self.logger.info("Generating queries...")
         queries_list = self.llm_serving.generate_from_input(user_inputs=all_query_prompts)
@@ -78,7 +79,7 @@ def run(self, storage: DataFlowStorage):
         for queries in valid_queries:
             category = queries.get("category")
             turns = queries.get("turns")
-            all_response_prompts.append(self.response_prompt.build_prompt(topic=category, queries=turns))
+            all_response_prompts.append(self.prompt_template.build_prompt(mode="response", topic=category, queries=turns))
         self.logger.info("Generating responses...")
         responses_list = self.llm_serving.generate_from_input(user_inputs=all_response_prompts)
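
Usage note: with this change the operator routes both the query and the response prompts through a single prompt_template via build_prompt(mode=...). Below is a minimal sketch of how the new argument might be passed; it assumes DIYPromptABC can be satisfied by implementing build_prompt and that @prompt_restrict admits DIYPromptABC subclasses, as the Union type hint suggests. my_llm_serving and MyChatPrompt are illustrative placeholders, not part of this commit.

from dataflow.core.prompt import DIYPromptABC
from dataflow.prompts.general_text import ConsistentChatPrompt
from dataflow.operators.conversations.generate.consistent_chat_generator import ConsistentChatGenerator

# Reproduce the previous behaviour by passing the built-in template explicitly.
generator = ConsistentChatGenerator(
    llm_serving=my_llm_serving,  # placeholder: any LLMServingABC implementation
    num_dialogs_per_intent=20,
    num_turns_per_dialog=6,
    temperature=0.9,
    prompt_template=ConsistentChatPrompt(),
)

# Or supply a custom template. The operator calls build_prompt(mode="query", ...)
# once for all queries and build_prompt(mode="response", topic=..., queries=...)
# per dialog, as the diff above shows.
class MyChatPrompt(DIYPromptABC):  # hypothetical subclass
    def build_prompt(self, mode, **kwargs):
        if mode == "query":
            # assumed to return a list of prompts, matching how run() feeds generate_from_input
            return [f"Generate {kwargs['num_dialogs_per_intent']} multi-turn query sets per intent."]
        return f"Topic: {kwargs['topic']}\nWrite coherent answers for these turns: {kwargs['queries']}"

generator = ConsistentChatGenerator(llm_serving=my_llm_serving, prompt_template=MyChatPrompt())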

dataflow/operators/text_sft/generate/condor_generator.py

Lines changed: 6 additions & 3 deletions
@@ -7,21 +7,22 @@
 import pandas as pd
 from dataflow.core import LLMServingABC
 from dataflow.prompts.general_text import CondorQuestionPrompt
-from dataflow.core.prompt import prompt_restrict
+from dataflow.core.prompt import DIYPromptABC, prompt_restrict
+from typing import Union
 
 @prompt_restrict(
     CondorQuestionPrompt
 )
 
 @OPERATOR_REGISTRY.register()
 class CondorGenerator(OperatorABC):
-    def __init__(self, llm_serving: LLMServingABC = None, num_samples=15, use_task_diversity=True):
+    def __init__(self, llm_serving: LLMServingABC = None, num_samples=15, use_task_diversity=True, prompt_template: Union[CondorQuestionPrompt, DIYPromptABC] = None):
         # Based on the existing topics, it is recommended to set num_samples below 5000. Otherwise, it is recommended to add topics in dataflow.prompts.general_text.CondorPrompt on your own to increase data richness
         self.logger = get_logger()
         self.logger.info(f'Initializing {self.__class__.__name__}...')
         self.llm_serving = llm_serving
         self.num_questions = num_samples // 3 # 每个prompt生成3个难度的问题
-        self.prompt = CondorQuestionPrompt()
+        self.prompt = prompt_template
         self.use_task_diversity = use_task_diversity # 是否使用任务场景增强多样性
         self.logger.info(f'{self.__class__.__name__} initialized.')
 
@@ -33,6 +34,7 @@ def get_desc(lang: str = "zh"):
             "输入参数:\n"
             "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
             "- num_samples:生成样本总数,建议小于5000,默认值为15\n"
+            "- prompt_template:提示词模板对象,用于定义提示结构\n"
             "输出参数:\n"
             "- 包含'difficulty'、'instruction'和'output'字段的DataFrame\n"
             "- 返回生成的DataFrame用于后续处理"
@@ -44,6 +46,7 @@ def get_desc(lang: str = "zh"):
             "Input Parameters:\n"
             "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
             "- num_samples: Total number of samples to generate, recommended to be less than 5000, default is 15\n\n"
+            "- prompt_template: Prompt template object, for defining the prompt structure\n"
             "Output Parameters:\n"
             "- DataFrame containing 'difficulty', 'instruction', and 'output' fields\n"
             "- Returns generated DataFrame for subsequent processing"

dataflow/operators/text_sft/refine/condor_refiner.py

Lines changed: 10 additions & 9 deletions
@@ -6,22 +6,21 @@
 from dataflow.utils.storage import DataFlowStorage
 import pandas as pd
 from dataflow.core import LLMServingABC
-from dataflow.prompts.general_text import CondorCritiquePrompt, CondorRefinePrompt
-from dataflow.core.prompt import prompt_restrict
+from dataflow.prompts.general_text import CondorRefinePrompt
+from dataflow.core.prompt import prompt_restrict, DIYPromptABC
+from typing import Union
 
 @prompt_restrict(
-    CondorCritiquePrompt,
     CondorRefinePrompt
 )
 
 @OPERATOR_REGISTRY.register()
 class CondorRefiner(OperatorABC):
-    def __init__(self, llm_serving: LLMServingABC = None):
+    def __init__(self, llm_serving: LLMServingABC = None, prompt_template: Union[CondorRefinePrompt, DIYPromptABC] = None):
         self.logger = get_logger()
         self.logger.info(f'Initializing {self.__class__.__name__}...')
         self.llm_serving = llm_serving
-        self.critique_prompt = CondorCritiquePrompt() # 创建 CondorPrompt 类的实例
-        self.refine_prompt = CondorRefinePrompt()
+        self.prompt_template = prompt_template
         self.logger.info(f'{self.__class__.__name__} initialized.')
 
     @staticmethod
@@ -33,6 +32,7 @@ def get_desc(lang: str = "zh"):
             "- llm_serving:LLM服务对象,需实现LLMServingABC接口\n"
             "- input_instruction_key:输入指令字段名,默认为'instruction'\n"
             "- input_output_key:输入回复字段名,默认为'output'\n"
+            "- prompt_template:提示词模板对象,用于定义提示结构\n"
             "输出参数:\n"
             "- 包含优化后回复的DataFrame\n"
             "- 返回包含优化后回复字段名的列表,用于后续算子引用"
@@ -44,7 +44,8 @@ def get_desc(lang: str = "zh"):
             "Input Parameters:\n"
             "- llm_serving: LLM serving object implementing LLMServingABC interface\n"
             "- input_instruction_key: Field name for input instructions, default is 'instruction'\n"
-            "- input_output_key: Field name for input responses, default is 'output'\n\n"
+            "- input_output_key: Field name for input responses, default is 'output'\n"
+            "- prompt_template: Prompt template object, for defining the prompt structure\n"
             "Output Parameters:\n"
             "- DataFrame containing refined responses\n"
             "- List containing refined response field name for subsequent operator reference"
@@ -56,13 +57,13 @@ def get_desc(lang: str = "zh"):
 
     def generate_critique(self, question, answer):
         # 批量生成 Critique
-        critique_prompts = [self.critique_prompt.build_prompt(q, a) for q, a in zip(question, answer)]
+        critique_prompts = [self.prompt_template.build_prompt(mode="critique", question=q, answer=a) for q, a in zip(question, answer)]
         critique_responses = self.llm_serving.generate_from_input(critique_prompts)
         return critique_responses
 
     def generate_refined_answer(self, question, answer, critique):
         # 批量生成修改后的答案
-        refine_prompts = [self.refine_prompt.build_prompt(q, a, c) for q, a, c in zip(question, answer, critique)]
+        refine_prompts = [self.prompt_template.build_prompt(mode="refine", question=q, answer=a, critique=c) for q, a, c in zip(question, answer, critique)]
         refined_answers = self.llm_serving.generate_from_input(refine_prompts)
         refined_answers = [answer.replace('[Improved Answer Start]', '').replace('[Improved Answer End]', '').strip() for answer in refined_answers]
         return refined_answers
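
Usage note: the refiner now drives both the critique and the refine step through one template via build_prompt(mode=...). A hedged sketch of a DIY replacement, assuming DIYPromptABC only requires build_prompt and that @prompt_restrict admits DIYPromptABC subclasses as the type hint suggests; MyRefinePrompt and my_llm_serving are illustrative names, not part of this commit.

from dataflow.core.prompt import DIYPromptABC
from dataflow.operators.text_sft.refine.condor_refiner import CondorRefiner

class MyRefinePrompt(DIYPromptABC):  # hypothetical subclass
    def build_prompt(self, mode, question, answer, critique=None):
        # Called with mode="critique" (question, answer) and mode="refine"
        # (question, answer, critique), matching the calls in the diff above.
        if mode == "critique":
            return f"Critique this answer.\nQuestion: {question}\nAnswer: {answer}"
        return (
            "Rewrite the answer using the critique, wrapped in "
            "[Improved Answer Start] ... [Improved Answer End].\n"
            f"Question: {question}\nAnswer: {answer}\nCritique: {critique}"
        )

refiner = CondorRefiner(llm_serving=my_llm_serving, prompt_template=MyRefinePrompt())
# refiner.generate_critique(...) and refiner.generate_refined_answer(...) then use this template.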
