Skip to content

Commit 574b032

Browse files
feature: various improvements to dataset handling (#119)
* fix: trainer_source.py not cleaned up * fix: handle dataset splitting for chronological data * fix: switch to codeagent for schema resolver * feat: add data analyser agent * feat: add data analyser agent * fix: put eda report as dict in metadata * feat: update multi-agent-system.md * chore: bump to 0.20.0 * fix: misc improvements to dataset analyser * fix: eda agent using wrong prompt template * chore: remove unused prompt template * chore: remove unused plan generation template * fix: emitter agent colors defined incorrectly * feat: make chain of thought summaries follow t/a/o structure * feat: remove combined data generator in favour of simple * fix: strip split suffix from eda report name * fix: give dataset analyser all required imports * feat: enable mlflow tracing * chore: update vulnerable dependencies * fix: allow scipy.* import for dataset analyser * fix: split_datasets to return dataset sizes * chore: remove smote oversampling * chore: clean up dataset generator config * refactor: clean up datasets.py * refactor: clean up data generation async logic * refactor: clean up data generation async logic * feat: add dataset generation example * chore: fix up base data generator interface * fix: column addition not working plus noisy logging * feat: add dataset augmentation example * chore: bump to 0.21.0 * feat: add eda report to model bundle * fix: re-enable i/o schema logging * fix: add pandas.* to dataset analyser imports * fix: give better instructions to schema resolver * fix: remove silly naming from system prompts * refactor: make schema resolver prompt more concise * fix: tell schema resolver to remove useless fields * feat: add tool for dropping null columns * feat: add heuristic for dropping all kinds of bad columns * fix: remove dataset preview from manager * feat: utils for eda report formatting * fix: handle ints in EDA report keys * fix: tell manager to ignore obviously bugged runs * fix: remove data preview from mle agent * refactor: move mle agent to its own file * fix: handle non-serialisable values when saving * fix: rename tool closures for clarity * fix: convert mle agent to codeagent * fix: convert mle agent to codeagent * fix: set mle max steps to 10 to prevent long failure cycles * fix: focus data analysis on actionable insights * fix: mle agent to use tools * feat: add dataset splitter agent * fix: let mlops engineer import plexe * docs: update multi-agent-system.md to match changes * docs: update README.md
1 parent de96588 commit 574b032

37 files changed

+2574
-2085
lines changed

README.md

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ Generate synthetic data or infer schemas automatically:
129129
```python
130130
# Generate synthetic data
131131
dataset = plexe.DatasetGenerator(
132+
description="Example dataset with features and target",
133+
provider="openai/gpt-4o-mini",
132134
schema={"features": str, "target": int}
133135
)
134136
dataset.generate(500) # Generate 500 samples
@@ -180,16 +182,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines. Join our [Discord](https:
180182
## 6. License
181183
[Apache-2.0 License](LICENSE)
182184

183-
## 7. Product Roadmap
184-
185-
- [X] Fine-tuning and transfer learning for small pre-trained models
186-
- [X] Use Pydantic for schemas and split data generation into a separate module
187-
- [X] Plexe self-hosted platform ⭐ (More details coming soon!)
188-
- [X] Lightweight installation option without heavy deep learning dependencies
189-
- [X] Distributed training with Ray on AWS
190-
- [ ] Support for non-tabular data types in model generation
191-
192-
## 8. Citation
185+
## 7. Citation
193186
If you use Plexe in your research, please cite it as follows:
194187

195188
```bibtex

docs/architecture/multi-agent-system.md

Lines changed: 69 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
- [Overview](#overview)
1111
- [Architecture Diagram](#architecture-diagram)
1212
- [Key Components](#key-components)
13+
- [EDA Agent](#eda-agent)
1314
- [Schema Resolver Agent](#schema-resolver-agent)
15+
- [Dataset Splitter Agent](#dataset-splitter-agent)
1416
- [Manager Agent (Orchestrator)](#manager-agent-orchestrator)
1517
- [ML Research Scientist Agent](#ml-research-scientist-agent)
1618
- [ML Engineer Agent](#ml-engineer-agent)
@@ -45,10 +47,12 @@ graph TD
4547
SchemaResolver --> |"Schemas"| Orchestrator
4648
Model --> |build| Orchestrator["Manager Agent"]
4749
Orchestrator --> |"Plan Task"| MLS["ML Researcher"]
50+
Orchestrator --> |"Split Task"| DS["Dataset Splitter"]
4851
Orchestrator --> |"Implement Task"| MLE["ML Engineer"]
4952
Orchestrator --> |"Inference Task"| MLOPS["ML Operations"]
5053
5154
MLS --> |"Solution Plans"| Orchestrator
55+
DS --> |"Split Datasets"| Orchestrator
5256
MLE --> |"Training Code"| Orchestrator
5357
MLOPS --> |"Inference Code"| Orchestrator
5458
end
@@ -82,6 +86,8 @@ graph TD
8286
SchemaResolver <--> EdaReports
8387
EDA <--> Registry
8488
EDA <--> Tools
89+
DS <--> Registry
90+
DS <--> Tools
8591
8692
Orchestrator --> Result([Trained Model])
8793
Result --> Model
@@ -115,15 +121,15 @@ eda_agent = EdaAgent(
115121
### Schema Resolver Agent
116122

117123
**Class**: `SchemaResolverAgent`
118-
**Type**: `ToolCallingAgent`
124+
**Type**: `CodeAgent`
119125

120126
The Schema Resolver Agent infers input and output schemas from intent and dataset samples:
121127

122128
```python
123129
schema_resolver = SchemaResolverAgent(
124-
model_id=provider_config.tool_provider,
130+
model_id=provider_config.orchestrator_provider,
131+
verbose=verbose,
125132
chain_of_thought_callable=cot_callable,
126-
verbosity_level=1,
127133
)
128134
```
129135

@@ -133,6 +139,27 @@ schema_resolver = SchemaResolverAgent(
133139
- Registering schemas with the Object Registry
134140
- Providing automatic schema resolution when schemas aren't specified
135141

142+
### Dataset Splitter Agent
143+
144+
**Class**: `DatasetSplitterAgent`
145+
**Type**: `CodeAgent`
146+
147+
The Dataset Splitter Agent handles the intelligent partitioning of datasets:
148+
149+
```python
150+
dataset_splitter_agent = DatasetSplitterAgent(
151+
model_id=orchestrator_model_id,
152+
verbose=verbose,
153+
chain_of_thought_callable=chain_of_thought_callable,
154+
)
155+
```
156+
157+
**Responsibilities**:
158+
- Analyzing datasets to determine appropriate splitting strategies
159+
- Handling specialized splitting needs (time-series, imbalanced data)
160+
- Creating train/validation/test splits with proper stratification
161+
- Registering split datasets in the Object Registry for downstream use
162+
136163
### Manager Agent (Orchestrator)
137164

138165
**Class**: `PlexeAgent.manager_agent`
@@ -147,13 +174,10 @@ self.manager_agent = CodeAgent(
147174
tools=[
148175
get_select_target_metric(self.tool_model_id),
149176
get_review_finalised_model(self.tool_model_id),
150-
split_datasets,
151177
create_input_sample,
152-
get_dataset_preview,
153-
get_raw_dataset_schema,
154178
format_final_orchestrator_agent_response,
155179
],
156-
managed_agents=[self.ml_research_agent, self.mle_agent, self.mlops_engineer],
180+
managed_agents=[self.ml_research_agent, self.dataset_splitter_agent, self.mle_agent, self.mlops_engineer],
157181
add_base_tools=False,
158182
verbosity_level=self.orchestrator_verbosity,
159183
additional_authorized_imports=config.code_generation.authorized_agent_imports,
@@ -188,9 +212,10 @@ self.ml_research_agent = ToolCallingAgent(
188212
"- input schema for the model"
189213
"- output schema for the model"
190214
"- the name and comparison method of the metric to optimise"
215+
"- the name of the dataset to use for training"
191216
),
192217
model=LiteLLMModel(model_id=self.ml_researcher_model_id),
193-
tools=[get_dataset_preview],
218+
tools=[get_dataset_preview, get_eda_report],
194219
add_base_tools=False,
195220
verbosity_level=self.specialist_verbosity,
196221
prompt_templates=get_prompt_templates("toolcalling_agent.yaml", "mls_prompt_templates.yaml"),
@@ -206,38 +231,19 @@ self.ml_research_agent = ToolCallingAgent(
206231

207232
### ML Engineer Agent
208233

209-
**Class**: `PlexeAgent.mle_agent`
210-
**Type**: `ToolCallingAgent`
234+
**Class**: `ModelTrainerAgent`
235+
**Type**: `CodeAgent`
211236

212237
This agent handles the implementation and training of models:
213238

214239
```python
215-
self.mle_agent = ToolCallingAgent(
216-
name="MLEngineer",
217-
description=(
218-
"Expert ML engineer that implements, trains and validates ML models based on provided plans. "
219-
"To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
220-
"- the ML task definition (i.e. 'intent')"
221-
"- input schema for the model"
222-
"- output schema for the model"
223-
"- the name and comparison method of the metric to optimise"
224-
"- the full solution plan that outlines how to solve this problem"
225-
"- the split train/validation dataset names"
226-
"- the working directory to use for model execution"
227-
),
228-
model=LiteLLMModel(model_id=self.ml_engineer_model_id),
229-
tools=[
230-
get_generate_training_code(self.tool_model_id),
231-
validate_training_code,
232-
get_fix_training_code(self.tool_model_id),
233-
get_executor_tool(self.distributed),
234-
format_final_mle_agent_response,
235-
],
236-
add_base_tools=False,
237-
verbosity_level=self.specialist_verbosity,
238-
prompt_templates=get_prompt_templates("toolcalling_agent.yaml", "mle_prompt_templates.yaml"),
239-
step_callbacks=[self.chain_of_thought_callable],
240-
)
240+
self.mle_agent = ModelTrainerAgent(
241+
ml_engineer_model_id=self.ml_engineer_model_id,
242+
tool_model_id=self.tool_model_id,
243+
distributed=self.distributed,
244+
verbose=verbose,
245+
chain_of_thought_callable=self.chain_of_thought_callable,
246+
).agent
241247
```
242248

243249
**Responsibilities**:
@@ -258,21 +264,21 @@ This agent focuses on productionizing the model through inference code:
258264
self.mlops_engineer = CodeAgent(
259265
name="MLOperationsEngineer",
260266
description=(
261-
"Expert ML operations engineer that writes inference code for ML models to be used in production. "
262-
"To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
267+
"Expert ML operations engineer that analyzes training code and creates high-quality production-ready "
268+
"inference code for ML models. To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
263269
"- input schema for the model"
264270
"- output schema for the model"
265271
"- the 'training code id' of the training code produced by the MLEngineer agent"
266272
),
267273
model=LiteLLMModel(model_id=self.ml_ops_engineer_model_id),
268274
tools=[
269-
get_generate_inference_code(self.tool_model_id),
275+
get_inference_context_tool(self.tool_model_id),
270276
validate_inference_code,
271-
get_fix_inference_code(self.tool_model_id),
272277
format_final_mlops_agent_response,
273278
],
274279
add_base_tools=False,
275280
verbosity_level=self.specialist_verbosity,
281+
additional_authorized_imports=config.code_generation.authorized_agent_imports + ["plexe", "plexe.*"],
276282
prompt_templates=get_prompt_templates("code_agent.yaml", "mlops_prompt_templates.yaml"),
277283
planning_interval=8,
278284
step_callbacks=[self.chain_of_thought_callable],
@@ -326,13 +332,13 @@ The system includes specialized tools that agents can use to perform specific ta
326332
def get_select_target_metric(model_id: str) -> Callable:
327333
"""Factory function that returns a tool for selecting appropriate target metrics."""
328334
@tool
329-
def select_target_metric(task: str, provider: str) -> Dict:
335+
def select_target_metric(task: str) -> Dict:
330336
"""Selects the appropriate target metric to optimise for the given task."""
331337
```
332338

333339
**Code Generation Tools**:
334340
```python
335-
def get_generate_training_code(model_id: str) -> Callable:
341+
def get_training_code_generation_tool(llm_to_use: str) -> Callable:
336342
"""Factory function that returns a tool for generating training code."""
337343
@tool
338344
def generate_training_code(
@@ -342,14 +348,16 @@ def get_generate_training_code(model_id: str) -> Callable:
342348
"""Generates training code based on the solution plan."""
343349
```
344350

345-
**Validation Tools**:
351+
**Dataset Tools**:
346352
```python
347353
@tool
348-
def validate_inference_code(
349-
inference_code: str, model_artifact_names: List[str],
350-
input_schema: Dict[str, str], output_schema: Dict[str, str],
351-
) -> Dict:
352-
"""Validates inference code for syntax, security, and correctness."""
354+
def register_split_datasets(
355+
dataset_names: List[str],
356+
train_datasets: List[pd.DataFrame],
357+
validation_datasets: List[pd.DataFrame],
358+
test_datasets: List[pd.DataFrame],
359+
) -> Dict[str, List[str]]:
360+
"""Register train, validation, and test datasets in the object registry."""
353361
```
354362

355363
**Execution Tools**:
@@ -380,23 +388,28 @@ The multi-agent workflow follows these key steps:
380388
- Schemas are registered in the Object Registry
381389

382390
4. **Orchestration**:
383-
- Manager Agent selects metrics and splits datasets
391+
- Manager Agent selects metrics and coordinates the process
384392
- Manager Agent initializes the solution planning phase
385393

386-
4. **Solution Planning**:
394+
5. **Dataset Splitting**:
395+
- Dataset Splitter Agent analyzes data characteristics
396+
- Creates appropriate train/validation/test splits
397+
- Registers split datasets in the Object Registry
398+
399+
6. **Solution Planning**:
387400
- ML Research Scientist proposes solution approaches
388401
- Manager Agent evaluates and selects approaches
389402

390-
5. **Model Implementation**:
403+
7. **Model Implementation**:
391404
- ML Engineer generates and executes training code
392405
- Model artifacts are registered in the Object Registry
393406
- Process may iterate through multiple approaches
394407

395-
6. **Inference Code Generation**:
408+
8. **Inference Code Generation**:
396409
- ML Operations Engineer generates compatible inference code
397410
- Code is validated with sample inputs
398411

399-
7. **Finalization**:
412+
9. **Finalization**:
400413
- Manager Agent reviews and finalizes the model
401414
- All artifacts and code are collected
402415
- Completed model is returned to the user
@@ -421,7 +434,6 @@ result = self.manager_agent.run(
421434
"working_dir": self.working_dir,
422435
"input_schema": format_schema(self.input_schema),
423436
"output_schema": format_schema(self.output_schema),
424-
"provider": provider_config.tool_provider,
425437
"max_iterations": max_iterations,
426438
"timeout": timeout,
427439
"run_timeout": run_timeout,
@@ -440,7 +452,7 @@ class ProcessExecutor(Executor):
440452
def run(self) -> ExecutionResult:
441453
"""Execute code in a subprocess and return results."""
442454
process = subprocess.Popen(
443-
[sys.executable, str(code_file)],
455+
[sys.executable, str(self.code_file)],
444456
stdout=subprocess.PIPE,
445457
stderr=subprocess.PIPE,
446458
cwd=str(self.working_dir),
@@ -553,6 +565,8 @@ class CustomModelValidator(Validator):
553565
- [Model Class Definition](/plexe/models.py)
554566
- [EdaAgent Definition](/plexe/agents/dataset_analyser.py)
555567
- [SchemaResolverAgent Definition](/plexe/agents/schema_resolver.py)
568+
- [DatasetSplitterAgent Definition](/plexe/agents/dataset_splitter.py)
569+
- [ModelTrainerAgent Definition](/plexe/agents/model_trainer.py)
556570
- [Tool Definitions](/plexe/internal/models/tools/)
557571
- [Dataset Tools](/plexe/internal/models/tools/datasets.py)
558572
- [Executor Implementation](/plexe/internal/models/execution/)

examples/dataset_augmentation.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
Example demonstrating dataset augmentation with Plexe:
3+
1. Adding a new column to an existing dataset
4+
2. Adding more rows to an existing dataset
5+
"""
6+
7+
from pydantic import BaseModel, Field
8+
9+
from plexe import DatasetGenerator
10+
11+
12+
class PurchaseSchema(BaseModel):
13+
"""Base schema for purchase data."""
14+
15+
product_name: str = Field(description="Name of the purchased product")
16+
category: str = Field(description="Product category")
17+
price: float = Field(description="Purchase price in USD")
18+
customer_id: str = Field(description="Unique customer identifier")
19+
20+
21+
class AugmentedSchema(PurchaseSchema):
22+
"""Augmented schema with product recommendation field."""
23+
24+
recommendation: str = Field(description="Recommended related product")
25+
26+
27+
def main():
28+
# Step 1: Create base dataset (10 purchase records)
29+
base_dataset = DatasetGenerator(
30+
description="E-commerce purchase data with product and customer information",
31+
provider="openai/gpt-4o",
32+
schema=PurchaseSchema,
33+
)
34+
base_dataset.generate(10)
35+
df_base = base_dataset.data
36+
37+
print("Original dataset (10 records):")
38+
print(df_base.head(3))
39+
print(f"Shape: {df_base.shape}")
40+
41+
# Check if we have data before proceeding
42+
if len(df_base) == 0:
43+
print("Failed to generate base dataset. Exiting.")
44+
return
45+
46+
# Step 2: Add a new column by extending the schema
47+
augmented_dataset = DatasetGenerator(
48+
description="E-commerce purchase data with product recommendations",
49+
provider="openai/gpt-4o",
50+
schema=AugmentedSchema,
51+
data=df_base,
52+
)
53+
augmented_dataset.generate(0) # 0 means just transform existing data
54+
df_column_added = augmented_dataset.data
55+
56+
print("\nDataset with new 'recommendation' column:")
57+
print(df_column_added.head(3))
58+
print(f"Shape: {df_column_added.shape}")
59+
60+
# Step 3: Add more rows to the augmented dataset
61+
augmented_dataset.generate(5) # Add 5 more records
62+
df_rows_added = augmented_dataset.data
63+
64+
print("\nFinal dataset with 5 additional records:")
65+
print(f"Shape: {df_rows_added.shape}")
66+
print(df_rows_added.tail(3))
67+
68+
69+
if __name__ == "__main__":
70+
main()

0 commit comments

Comments
 (0)