66from dataflow .utils .storage import DataFlowStorage
77import pandas as pd
88from dataflow .core import LLMServingABC
9- from dataflow .prompts .general_text import ConsistentQueryPrompt , ConsistentResponsePrompt
10- from dataflow .core .prompt import prompt_restrict
9+ from dataflow .prompts .general_text import ConsistentChatPrompt
10+ from dataflow .core .prompt import DIYPromptABC , prompt_restrict
11+ from typing import Union
1112
1213@prompt_restrict (
13- ConsistentQueryPrompt ,
14- ConsistentResponsePrompt
14+ ConsistentChatPrompt
1515)
1616
1717@OPERATOR_REGISTRY .register ()
1818class ConsistentChatGenerator (OperatorABC ):
19- def __init__ (self , llm_serving : LLMServingABC = None , num_dialogs_per_intent = 20 , num_turns_per_dialog = 6 , temperature = 0.9 ):
19+ def __init__ (self , llm_serving : LLMServingABC = None , num_dialogs_per_intent = 20 , num_turns_per_dialog = 6 , temperature = 0.9 , prompt_template : Union [ ConsistentChatPrompt , DIYPromptABC ] = None ):
2020 self .logger = get_logger ()
2121 self .logger .info (f'Initializing { self .__class__ .__name__ } ...' )
2222 self .llm_serving = llm_serving
2323 self .num_dialogs_per_intent = num_dialogs_per_intent # Based on the topic_dict in the existing prompt, it is recommended to set the value to below 1000 (which can generate 9000 conversation data). Otherwise, it is recommended to add more topic_dict in dataflow.prompts.general_text.ConsistentChatPrompt to increase data richness
2424 self .num_turns_per_dialog = num_turns_per_dialog
2525 self .temperature = temperature
26- self .query_prompt = ConsistentQueryPrompt ()
27- self .response_prompt = ConsistentResponsePrompt ()
26+ self .prompt_template = prompt_template
2827 self .logger .info (f'{ self .__class__ .__name__ } initialized.' )
2928
3029 @staticmethod
@@ -37,6 +36,7 @@ def get_desc(lang: str = "zh"):
3736 "- num_dialogs_per_intent:每个意图生成的对话数量,默认20\n "
3837 "- num_turns_per_dialog:每个对话的轮次数量,默认6\n "
3938 "- temperature:生成温度,控制输出随机性,默认0.9\n "
39+ "- prompt_template:提示词模板对象,用于定义提示结构\n "
4040 "输出参数:\n "
4141 "- 包含category和conversation字段的DataFrame,其中conversation为多轮对话列表"
4242 )
@@ -48,6 +48,7 @@ def get_desc(lang: str = "zh"):
4848 "- num_dialogs_per_intent: Number of dialogs generated per intent, default 20\n "
4949 "- num_turns_per_dialog: Number of turns per dialog, default 6\n "
5050 "- temperature: Sampling temperature for generation, default 0.9\n "
51+ "- prompt_template: Prompt template object, for defining the prompt structure\n "
5152 "Output Parameters:\n "
5253 "- DataFrame containing 'category' and 'conversation' fields, where conversation is a list of multi-turn dialogues"
5354 )
@@ -57,7 +58,7 @@ def get_desc(lang: str = "zh"):
5758 def run (self , storage : DataFlowStorage ):
5859
5960 # Step 1: Generate all queries using LLM
60- all_query_prompts = self .query_prompt .build_prompt (num_dialogs_per_intent = self .num_dialogs_per_intent )
61+ all_query_prompts = self .prompt_template .build_prompt (mode = "query" , num_dialogs_per_intent = self .num_dialogs_per_intent )
6162 # Step 2: Generate queries by calling llm_serving once
6263 self .logger .info ("Generating queries..." )
6364 queries_list = self .llm_serving .generate_from_input (user_inputs = all_query_prompts )
@@ -78,7 +79,7 @@ def run(self, storage: DataFlowStorage):
7879 for queries in valid_queries :
7980 category = queries .get ("category" )
8081 turns = queries .get ("turns" )
81- all_response_prompts .append (self .response_prompt .build_prompt (topic = category , queries = turns ))
82+ all_response_prompts .append (self .prompt_template .build_prompt (mode = "response" , topic = category , queries = turns ))
8283 self .logger .info ("Generating responses..." )
8384 responses_list = self .llm_serving .generate_from_input (user_inputs = all_response_prompts )
8485
0 commit comments