1010- [ Overview] ( #overview )
1111- [ Architecture Diagram] ( #architecture-diagram )
1212- [ Key Components] ( #key-components )
13+ - [ EDA Agent] ( #eda-agent )
1314 - [ Schema Resolver Agent] ( #schema-resolver-agent )
15+ - [ Dataset Splitter Agent] ( #dataset-splitter-agent )
1416 - [ Manager Agent (Orchestrator)] ( #manager-agent-orchestrator )
1517 - [ ML Research Scientist Agent] ( #ml-research-scientist-agent )
1618 - [ ML Engineer Agent] ( #ml-engineer-agent )
@@ -45,10 +47,12 @@ graph TD
4547 SchemaResolver --> |"Schemas"| Orchestrator
4648 Model --> |build| Orchestrator["Manager Agent"]
4749 Orchestrator --> |"Plan Task"| MLS["ML Researcher"]
50+ Orchestrator --> |"Split Task"| DS["Dataset Splitter"]
4851 Orchestrator --> |"Implement Task"| MLE["ML Engineer"]
4952 Orchestrator --> |"Inference Task"| MLOPS["ML Operations"]
5053
5154 MLS --> |"Solution Plans"| Orchestrator
55+ DS --> |"Split Datasets"| Orchestrator
5256 MLE --> |"Training Code"| Orchestrator
5357 MLOPS --> |"Inference Code"| Orchestrator
5458 end
@@ -82,6 +86,8 @@ graph TD
8286 SchemaResolver <--> EdaReports
8387 EDA <--> Registry
8488 EDA <--> Tools
89+ DS <--> Registry
90+ DS <--> Tools
8591
8692 Orchestrator --> Result([Trained Model])
8793 Result --> Model
@@ -115,15 +121,15 @@ eda_agent = EdaAgent(
115121### Schema Resolver Agent
116122
117123** Class** : ` SchemaResolverAgent `
118- ** Type** : ` ToolCallingAgent `
124+ ** Type** : ` CodeAgent `
119125
120126The Schema Resolver Agent infers input and output schemas from intent and dataset samples:
121127
122128``` python
123129schema_resolver = SchemaResolverAgent(
124- model_id = provider_config.tool_provider,
130+ model_id = provider_config.orchestrator_provider,
131+ verbose = verbose,
125132 chain_of_thought_callable = cot_callable,
126- verbosity_level = 1 ,
127133)
128134```
129135
@@ -133,6 +139,27 @@ schema_resolver = SchemaResolverAgent(
133139- Registering schemas with the Object Registry
134140- Providing automatic schema resolution when schemas aren't specified
135141
142+ ### Dataset Splitter Agent
143+
144+ ** Class** : ` DatasetSplitterAgent `
145+ ** Type** : ` CodeAgent `
146+
147+ The Dataset Splitter Agent handles the intelligent partitioning of datasets:
148+
149+ ``` python
150+ dataset_splitter_agent = DatasetSplitterAgent(
151+ model_id = orchestrator_model_id,
152+ verbose = verbose,
153+ chain_of_thought_callable = chain_of_thought_callable,
154+ )
155+ ```
156+
157+ ** Responsibilities** :
158+ - Analyzing datasets to determine appropriate splitting strategies
159+ - Handling specialized splitting needs (time-series, imbalanced data)
160+ - Creating train/validation/test splits with proper stratification
161+ - Registering split datasets in the Object Registry for downstream use
162+
136163### Manager Agent (Orchestrator)
137164
138165** Class** : ` PlexeAgent.manager_agent `
@@ -147,13 +174,10 @@ self.manager_agent = CodeAgent(
147174 tools = [
148175 get_select_target_metric(self .tool_model_id),
149176 get_review_finalised_model(self .tool_model_id),
150- split_datasets,
151177 create_input_sample,
152- get_dataset_preview,
153- get_raw_dataset_schema,
154178 format_final_orchestrator_agent_response,
155179 ],
156- managed_agents = [self .ml_research_agent, self .mle_agent, self .mlops_engineer],
180+ managed_agents = [self .ml_research_agent, self .dataset_splitter_agent, self . mle_agent, self .mlops_engineer],
157181 add_base_tools = False ,
158182 verbosity_level = self .orchestrator_verbosity,
159183 additional_authorized_imports = config.code_generation.authorized_agent_imports,
@@ -188,9 +212,10 @@ self.ml_research_agent = ToolCallingAgent(
188212 " - input schema for the model"
189213 " - output schema for the model"
190214 " - the name and comparison method of the metric to optimise"
215+ " - the name of the dataset to use for training"
191216 ),
192217 model = LiteLLMModel(model_id = self .ml_researcher_model_id),
193- tools = [get_dataset_preview],
218+ tools = [get_dataset_preview, get_eda_report ],
194219 add_base_tools = False ,
195220 verbosity_level = self .specialist_verbosity,
196221 prompt_templates = get_prompt_templates(" toolcalling_agent.yaml" , " mls_prompt_templates.yaml" ),
@@ -206,38 +231,19 @@ self.ml_research_agent = ToolCallingAgent(
206231
207232### ML Engineer Agent
208233
209- ** Class** : ` PlexeAgent.mle_agent `
210- ** Type** : ` ToolCallingAgent `
234+ ** Class** : ` ModelTrainerAgent `
235+ ** Type** : ` CodeAgent `
211236
212237This agent handles the implementation and training of models:
213238
214239``` python
215- self .mle_agent = ToolCallingAgent(
216- name = " MLEngineer" ,
217- description = (
218- " Expert ML engineer that implements, trains and validates ML models based on provided plans. "
219- " To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
220- " - the ML task definition (i.e. 'intent')"
221- " - input schema for the model"
222- " - output schema for the model"
223- " - the name and comparison method of the metric to optimise"
224- " - the full solution plan that outlines how to solve this problem"
225- " - the split train/validation dataset names"
226- " - the working directory to use for model execution"
227- ),
228- model = LiteLLMModel(model_id = self .ml_engineer_model_id),
229- tools = [
230- get_generate_training_code(self .tool_model_id),
231- validate_training_code,
232- get_fix_training_code(self .tool_model_id),
233- get_executor_tool(self .distributed),
234- format_final_mle_agent_response,
235- ],
236- add_base_tools = False ,
237- verbosity_level = self .specialist_verbosity,
238- prompt_templates = get_prompt_templates(" toolcalling_agent.yaml" , " mle_prompt_templates.yaml" ),
239- step_callbacks = [self .chain_of_thought_callable],
240- )
240+ self .mle_agent = ModelTrainerAgent(
241+ ml_engineer_model_id = self .ml_engineer_model_id,
242+ tool_model_id = self .tool_model_id,
243+ distributed = self .distributed,
244+ verbose = verbose,
245+ chain_of_thought_callable = self .chain_of_thought_callable,
246+ ).agent
241247```
242248
243249** Responsibilities** :
@@ -258,21 +264,21 @@ This agent focuses on productionizing the model through inference code:
258264self .mlops_engineer = CodeAgent(
259265 name = " MLOperationsEngineer" ,
260266 description = (
261- " Expert ML operations engineer that writes inference code for ML models to be used in production. "
262- " To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
267+ " Expert ML operations engineer that analyzes training code and creates high-quality production-ready "
268+ " inference code for ML models. To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
263269 " - input schema for the model"
264270 " - output schema for the model"
265271 " - the 'training code id' of the training code produced by the MLEngineer agent"
266272 ),
267273 model = LiteLLMModel(model_id = self .ml_ops_engineer_model_id),
268274 tools = [
269- get_generate_inference_code (self .tool_model_id),
275+ get_inference_context_tool (self .tool_model_id),
270276 validate_inference_code,
271- get_fix_inference_code(self .tool_model_id),
272277 format_final_mlops_agent_response,
273278 ],
274279 add_base_tools = False ,
275280 verbosity_level = self .specialist_verbosity,
281+ additional_authorized_imports = config.code_generation.authorized_agent_imports + [" plexe" , " plexe.*" ],
276282 prompt_templates = get_prompt_templates(" code_agent.yaml" , " mlops_prompt_templates.yaml" ),
277283 planning_interval = 8 ,
278284 step_callbacks = [self .chain_of_thought_callable],
@@ -326,13 +332,13 @@ The system includes specialized tools that agents can use to perform specific ta
326332def get_select_target_metric (model_id : str ) -> Callable:
327333 """ Factory function that returns a tool for selecting appropriate target metrics."""
328334 @tool
329- def select_target_metric (task : str , provider : str ) -> Dict:
335+ def select_target_metric (task : str ) -> Dict:
330336 """ Selects the appropriate target metric to optimise for the given task."""
331337```
332338
333339** Code Generation Tools** :
334340``` python
335- def get_generate_training_code ( model_id : str ) -> Callable:
341+ def get_training_code_generation_tool ( llm_to_use : str ) -> Callable:
336342 """ Factory function that returns a tool for generating training code."""
337343 @tool
338344 def generate_training_code (
@@ -342,14 +348,16 @@ def get_generate_training_code(model_id: str) -> Callable:
342348 """ Generates training code based on the solution plan."""
343349```
344350
345- ** Validation Tools** :
351+ ** Dataset Tools** :
346352``` python
347353@tool
348- def validate_inference_code (
349- inference_code : str , model_artifact_names : List[str ],
350- input_schema : Dict[str , str ], output_schema : Dict[str , str ],
351- ) -> Dict:
352- """ Validates inference code for syntax, security, and correctness."""
354+ def register_split_datasets (
355+ dataset_names : List[str ],
356+ train_datasets : List[pd.DataFrame],
357+ validation_datasets : List[pd.DataFrame],
358+ test_datasets : List[pd.DataFrame],
359+ ) -> Dict[str , List[str ]]:
360+ """ Register train, validation, and test datasets in the object registry."""
353361```
354362
355363** Execution Tools** :
@@ -380,23 +388,28 @@ The multi-agent workflow follows these key steps:
380388 - Schemas are registered in the Object Registry
381389
3823904 . ** Orchestration** :
383- - Manager Agent selects metrics and splits datasets
391+ - Manager Agent selects metrics and coordinates the process
384392 - Manager Agent initializes the solution planning phase
385393
386- 4 . ** Solution Planning** :
394+ 5 . ** Dataset Splitting** :
395+ - Dataset Splitter Agent analyzes data characteristics
396+ - Creates appropriate train/validation/test splits
397+ - Registers split datasets in the Object Registry
398+
399+ 6 . ** Solution Planning** :
387400 - ML Research Scientist proposes solution approaches
388401 - Manager Agent evaluates and selects approaches
389402
390- 5 . ** Model Implementation** :
403+ 7 . ** Model Implementation** :
391404 - ML Engineer generates and executes training code
392405 - Model artifacts are registered in the Object Registry
393406 - Process may iterate through multiple approaches
394407
395- 6 . ** Inference Code Generation** :
408+ 8 . ** Inference Code Generation** :
396409 - ML Operations Engineer generates compatible inference code
397410 - Code is validated with sample inputs
398411
399- 7 . ** Finalization** :
412+ 9 . ** Finalization** :
400413 - Manager Agent reviews and finalizes the model
401414 - All artifacts and code are collected
402415 - Completed model is returned to the user
@@ -421,7 +434,6 @@ result = self.manager_agent.run(
421434 " working_dir" : self .working_dir,
422435 " input_schema" : format_schema(self .input_schema),
423436 " output_schema" : format_schema(self .output_schema),
424- " provider" : provider_config.tool_provider,
425437 " max_iterations" : max_iterations,
426438 " timeout" : timeout,
427439 " run_timeout" : run_timeout,
@@ -440,7 +452,7 @@ class ProcessExecutor(Executor):
440452 def run (self ) -> ExecutionResult:
441453 """ Execute code in a subprocess and return results."""
442454 process = subprocess.Popen(
443- [sys.executable, str (code_file)],
455+ [sys.executable, str (self . code_file)],
444456 stdout = subprocess.PIPE ,
445457 stderr = subprocess.PIPE ,
446458 cwd = str (self .working_dir),
@@ -553,6 +565,8 @@ class CustomModelValidator(Validator):
553565- [ Model Class Definition] ( /plexe/models.py )
554566- [ EdaAgent Definition] ( /plexe/agents/dataset_analyser.py )
555567- [ SchemaResolverAgent Definition] ( /plexe/agents/schema_resolver.py )
568+ - [ DatasetSplitterAgent Definition] ( /plexe/agents/dataset_splitter.py )
569+ - [ ModelTrainerAgent Definition] ( /plexe/agents/model_trainer.py )
556570- [ Tool Definitions] ( /plexe/internal/models/tools/ )
557571- [ Dataset Tools] ( /plexe/internal/models/tools/datasets.py )
558572- [ Executor Implementation] ( /plexe/internal/models/execution/ )
0 commit comments