Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions docs/architecture/multi-agent-system.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ graph TD
User([User]) --> |"Intent & Datasets"| Model["Model Class"]

subgraph "Multi-Agent System"
Model --> |"Data Registration"| EDA["EDA Agent"]
EDA --> |"Analysis & Reports"| SchemaResolver
Model --> |"Schema Resolution"| SchemaResolver["Schema Resolver"]
SchemaResolver --> |"Schemas"| Orchestrator
Model --> |build| Orchestrator["Manager Agent"]
Expand All @@ -53,6 +55,7 @@ graph TD

subgraph Registry["Object Registry"]
Datasets[(Datasets)]
EdaReports[(EDA Reports)]
Artifacts[(Model Artifacts)]
Code[(Code Snippets)]
Schemas[(I/O Schemas)]
Expand All @@ -70,10 +73,15 @@ graph TD
Orchestrator <--> Registry
Orchestrator <--> Tools
MLS <--> Tools
MLS <--> EdaReports
MLE <--> Tools
MLE <--> EdaReports
MLOPS <--> Tools
SchemaResolver <--> Registry
SchemaResolver <--> Tools
SchemaResolver <--> EdaReports
EDA <--> Registry
EDA <--> Tools

Orchestrator --> Result([Trained Model])
Result --> Model
Expand All @@ -82,6 +90,28 @@ graph TD

## Key Components

### EDA Agent

**Class**: `EdaAgent`
**Type**: `CodeAgent`

The EDA Agent performs exploratory data analysis on datasets early in the workflow:

```python
eda_agent = EdaAgent(
model_id=provider_config.research_provider,
verbose=verbose,
chain_of_thought_callable=cot_callable,
)
```

**Responsibilities**:
- Analyzing datasets to understand structure, distributions, and relationships
- Identifying data quality issues, outliers, and missing values
- Generating key insights about the data
- Providing recommendations for preprocessing and modeling
- Registering EDA reports in the Object Registry for use by downstream agents

### Schema Resolver Agent

**Class**: `SchemaResolverAgent`
Expand Down Expand Up @@ -339,11 +369,17 @@ The multi-agent workflow follows these key steps:
- User creates a `Model` instance with intent and datasets
- User calls `model.build()` to start the process

2. **Schema Resolution**:
2. **Exploratory Data Analysis**:
- EdaAgent analyzes datasets to understand structure and characteristics
- Generates insights about data patterns, quality issues, and modeling considerations
- EDA reports are registered in the Object Registry for use by other agents

3. **Schema Resolution**:
- If schemas aren't provided, SchemaResolverAgent infers them
- The agent can leverage EDA findings to determine appropriate schemas
- Schemas are registered in the Object Registry

3. **Orchestration**:
4. **Orchestration**:
- Manager Agent selects metrics and splits datasets
- Manager Agent initializes the solution planning phase

Expand Down Expand Up @@ -372,7 +408,7 @@ The multi-agent workflow follows these key steps:
The system uses a hierarchical communication pattern:

```
User → Model → Schema Resolver → Manager Agent → Specialist Agents → Manager Agent → Model → User
User → Model → EDA Agent → Schema Resolver → Manager Agent → Specialist Agents → Manager Agent → Model → User
```

Each agent communicates through structured task descriptions and responses:
Expand Down Expand Up @@ -515,7 +551,9 @@ class CustomModelValidator(Validator):

- [PlexeAgent Class Definition](/plexe/internal/agents.py)
- [Model Class Definition](/plexe/models.py)
- [EdaAgent Definition](/plexe/agents/dataset_analyser.py)
- [SchemaResolverAgent Definition](/plexe/agents/schema_resolver.py)
- [Tool Definitions](/plexe/internal/models/tools/)
- [Dataset Tools](/plexe/internal/models/tools/datasets.py)
- [Executor Implementation](/plexe/internal/models/execution/)
- [Object Registry](/plexe/internal/common/registries/objects.py)
98 changes: 98 additions & 0 deletions plexe/agents/dataset_analyser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Exploratory Data Analysis (EDA) Agent for data analysis and insights in ML models.

This module defines an EdaAgent that analyzes datasets to generate comprehensive
exploratory data analysis reports before model building begins.
"""

import logging
from typing import List, Callable

from smolagents import LiteLLMModel, CodeAgent

from plexe.config import prompt_templates
from plexe.internal.common.utils.agents import get_prompt_templates
from plexe.internal.models.tools.datasets import register_eda_report
from plexe.internal.models.tools.schemas import get_raw_dataset_schema

logger = logging.getLogger(__name__)


class EdaAgent:
"""
Agent for performing exploratory data analysis on datasets.

This agent analyzes the available datasets to produce a comprehensive EDA report
containing data overview, feature analysis, relationships, data quality issues,
key insights, and recommendations for modeling.
"""

def __init__(
self,
model_id: str = "openai/gpt-4o",
verbose: bool = False,
chain_of_thought_callable: Callable = None,
):
"""
Initialize the EDA agent.

Args:
model_id: Model ID for the LLM to use for data analysis
verbose: Whether to display detailed agent logs
chain_of_thought_callable: Optional callable for chain of thought logging
"""
self.model_id = model_id
self.verbose = verbose

# Set verbosity level
self.verbosity = 1 if verbose else 0

# Create the EDA agent with the necessary tools
self.agent = CodeAgent(
name="DatasetAnalyser",
description=(
"Expert data analyst that performs exploratory data analysis on datasets "
"to generate insights and recommendations for ML modeling."
),
model=LiteLLMModel(model_id=self.model_id),
tools=[register_eda_report, get_raw_dataset_schema],
add_base_tools=False,
verbosity_level=self.verbosity,
# planning_interval=3,
max_steps=30,
step_callbacks=[chain_of_thought_callable],
additional_authorized_imports=["pandas", "numpy", "plexe"],
prompt_templates=get_prompt_templates("code_agent.yaml", "eda_prompt_templates.yaml"),
)

def run(
self,
intent: str,
dataset_names: List[str],
) -> bool:
"""
Run the EDA agent to analyze datasets and create EDA reports.

Args:
intent: Natural language description of the model's purpose
dataset_names: List of dataset registry names available for analysis

Returns:
Dictionary containing:
- eda_report_names: List of registered EDA report names in the Object Registry
- dataset_names: List of datasets that were analyzed
- summary: Brief summary of key findings
"""
# Use the template system to create the prompt
datasets_str = ", ".join(dataset_names)

# Generate the prompt using the template system
task_description = prompt_templates.eda_agent_prompt(
intent=intent,
datasets=datasets_str,
)

# Run the agent to get analysis
self.agent.run(task_description)

return True
10 changes: 5 additions & 5 deletions plexe/agents/schema_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import logging
from typing import Dict, List, Any, Callable

from smolagents import ToolCallingAgent, LiteLLMModel
from smolagents import LiteLLMModel, CodeAgent

from plexe.config import prompt_templates
from plexe.internal.common.registries.objects import ObjectRegistry
from plexe.internal.models.tools.datasets import get_dataset_preview
from plexe.internal.models.tools.schemas import get_raw_dataset_schema, register_final_model_schemas
from plexe.internal.models.tools.datasets import get_dataset_preview, get_eda_report
from plexe.internal.models.tools.schemas import register_final_model_schemas

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,14 +49,14 @@ def __init__(
self.verbosity = 1 if verbose else 0

# Create the schema resolver agent with the necessary tools
self.agent = ToolCallingAgent(
self.agent = CodeAgent(
name="SchemaResolver",
description=(
"Expert schema resolver that determines the appropriate input and output "
"schemas for ML models based on intent and available datasets."
),
model=LiteLLMModel(model_id=self.model_id),
tools=[get_dataset_preview, get_raw_dataset_schema, register_final_model_schemas],
tools=[get_dataset_preview, get_eda_report, register_final_model_schemas],
add_base_tools=False,
verbosity_level=self.verbosity,
step_callbacks=[chain_of_thought_callable],
Expand Down
16 changes: 7 additions & 9 deletions plexe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,6 @@ def planning_system(self) -> str:
def planning_select_metric(self, problem_statement) -> str:
return self._render("planning/select_metric.jinja", problem_statement=problem_statement)

def planning_generate(self, problem_statement, metric_to_optimise) -> str:
return self._render(
"planning/generate.jinja",
problem_statement=problem_statement,
metric_to_optimise=metric_to_optimise,
allowed_packages=config.code_generation.allowed_packages,
deep_learning_available=config.code_generation.deep_learning_available,
)

def schema_base(self) -> str:
return self._render("schemas/base.jinja")

Expand All @@ -225,6 +216,13 @@ def schema_resolver_prompt(
has_output_schema=has_output_schema,
)

def eda_agent_prompt(self, intent, datasets) -> str:
return self._render(
"agent/agent_data_analyser_prompt.jinja",
intent=intent,
datasets=datasets,
)

def training_system(self) -> str:
return self._render("training/system_prompt.jinja")

Expand Down
10 changes: 8 additions & 2 deletions plexe/internal/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
)
from plexe.internal.models.tools.evaluation import get_review_finalised_model
from plexe.internal.models.tools.metrics import get_select_target_metric
from plexe.internal.models.tools.datasets import split_datasets, create_input_sample, get_dataset_preview
from plexe.internal.models.tools.datasets import (
split_datasets,
create_input_sample,
get_dataset_preview,
get_eda_report,
)
from plexe.internal.models.tools.schemas import get_raw_dataset_schema
from plexe.internal.models.tools.execution import get_executor_tool
from plexe.internal.models.tools.response_formatting import (
Expand Down Expand Up @@ -106,9 +111,10 @@ def __init__(
"- input schema for the model"
"- output schema for the model"
"- the name and comparison method of the metric to optimise"
"- the name of the dataset to use for training"
),
model=LiteLLMModel(model_id=self.ml_researcher_model_id),
tools=[get_dataset_preview],
tools=[get_dataset_preview, get_eda_report],
add_base_tools=False,
verbosity_level=self.specialist_verbosity,
prompt_templates=get_prompt_templates("toolcalling_agent.yaml", "mls_prompt_templates.yaml"),
Expand Down
47 changes: 41 additions & 6 deletions plexe/internal/common/datasets/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,58 @@ def split(
test_ratio: float = 0.15,
stratify_column: Optional[str] = None,
random_state: Optional[int] = None,
is_time_series: bool = False,
time_index_column: Optional[str] = None,
) -> Tuple["TabularDataset", "TabularDataset", "TabularDataset"]:
"""
Split dataset into train, validation and test sets.

:param train_ratio: Proportion of data to use for training
:param val_ratio: Proportion of data to use for validation
:param test_ratio: Proportion of data to use for testing
:param stratify_column: Column to use for stratified splitting
:param random_state: Random seed for reproducibility
:param stratify_column: Column to use for stratified splitting (not used for time series)
:param random_state: Random seed for reproducibility (not used for time series)
:param is_time_series: Whether the data is chronological time series data
:param time_index_column: Column name that represents the time index, required if is_time_series=True
:returns: A tuple of (train_dataset, val_dataset, test_dataset)
:raises ValueError: If ratios don't sum to approximately 1.0
:raises ValueError: If ratios don't sum to approximately 1.0 or if time_index_column is missing for time series
"""
from sklearn.model_selection import train_test_split

if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-10:
raise ValueError("Split ratios must sum to 1.0")

# Handle time series data
if is_time_series:
if not time_index_column:
raise ValueError("time_index_column must be provided when is_time_series=True")

if time_index_column not in self._data.columns:
raise ValueError(f"time_index_column '{time_index_column}' not found in dataset columns")

# Sort by time index
sorted_data = self._data.sort_values(by=time_index_column).reset_index(drop=True)

# Calculate split indices
n_samples = len(sorted_data)
train_end = int(n_samples * train_ratio)
val_end = train_end + int(n_samples * val_ratio)

# Split the data sequentially
train_data = sorted_data.iloc[:train_end]
val_data = sorted_data.iloc[train_end:val_end]
test_data = sorted_data.iloc[val_end:]

# Handle edge cases for empty splits
empty_df = pd.DataFrame(columns=self._data.columns)
if val_ratio < 1e-10:
val_data = empty_df
if test_ratio < 1e-10:
test_data = empty_df

return TabularDataset(train_data), TabularDataset(val_data), TabularDataset(test_data)

# Regular random splitting for non-time series data
from sklearn.model_selection import train_test_split

# Handle all-data-to-train edge case
if val_ratio < 1e-10 and test_ratio < 1e-10:
return (
Expand Down Expand Up @@ -101,7 +136,7 @@ def split(
stratify=temp_data[stratify_column] if stratify_column else None,
random_state=random_state,
)
return (TabularDataset(train_data), TabularDataset(val_data), TabularDataset(test_data))
return TabularDataset(train_data), TabularDataset(val_data), TabularDataset(test_data)

def sample(
self, n: int = None, frac: float = None, replace: bool = False, random_state: int = None
Expand Down
8 changes: 5 additions & 3 deletions plexe/internal/common/utils/chain_of_thought/emitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,12 @@ def _get_agent_color(agent_name: str) -> str:
"""Get the color for an agent based on its role."""
agent_colors = {
"System": "bright_blue",
"ML Research Scientist": "green",
"ML Engineer": "yellow",
"ML Ops Engineer": "magenta",
"MLResearchScientist": "green",
"MLEngineer": "yellow",
"MLOperationsEngineer": "magenta",
"Orchestrator": "cyan",
"DatasetAnalyser": "red",
"SchemaResolver": "orange",
# Default color
"default": "blue",
}
Expand Down
7 changes: 7 additions & 0 deletions plexe/internal/models/callbacks/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,15 @@ def on_iteration_end(self, info: BuildStateInfo) -> None:
with open(code_path, "w") as f:
f.write(info.node.training_code)
mlflow.log_artifact(str(code_path))
# Clean up the temporary file after logging
code_path.unlink(missing_ok=True)
except Exception as e:
logger.warning(f"Could not log trainer source: {e}")
# Attempt to clean up the file even if logging failed
try:
Path("trainer_source.py").unlink(missing_ok=True)
except Exception:
pass

# Log node performance if available
if info.node.performance:
Expand Down
Loading