Commit d5bef9d

feature: conversational interface (#121)
* refactor: separate model from model builder agent
* fix: build state info without model reference
* feat: add chat agent
* fix: not using yaml safe dump
1 parent 918e256 commit d5bef9d

11 files changed, with 819 additions and 281 deletions.

docs/architecture/multi-agent-system.md

Lines changed: 27 additions & 25 deletions
@@ -40,10 +40,12 @@ This approach offers several advantages:

 ```mermaid
 graph TD
-    User([User]) --> |"Intent & Datasets"| Model["Model Class"]
+    User([User]) --> |"Intent & Datasets"| ModelBuilder["ModelBuilder"]
+    User --> |"Intent & Datasets"| Model["Model Class (deprecated)"]

     subgraph "Multi-Agent System"
-        Model --> |build| Orchestrator["Manager Agent"]
+        ModelBuilder --> |build| Orchestrator["Manager Agent"]
+        Model --> |"build (deprecated)"| ModelBuilder
         Orchestrator --> |"Schema Task"| SchemaResolver["Schema Resolver"]
         Orchestrator --> |"EDA Task"| EDA["EDA Agent"]
         Orchestrator --> |"Feature Task"| FE["Feature Engineer"]
@@ -191,7 +193,7 @@ self.dataset_splitter_agent = DatasetSplitterAgent(

 ### Manager Agent (Orchestrator)

-**Class**: `PlexeAgent.manager_agent`
+**Class**: `CodeAgent`
 **Type**: `CodeAgent`

 The Manager Agent serves as the central coordinator for the entire ML development process:
@@ -339,12 +341,11 @@ class ObjectRegistry:
     """

     _instance = None
-    _items: Dict[str, Item] = dict()

     def __new__(cls):
         if cls._instance is None:
             cls._instance = super(ObjectRegistry, cls).__new__(cls)
-            cls._items = dict()
+            cls._instance._items = {}
         return cls._instance
 ```
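
The `__new__` change above moves `_items` from a class attribute to an attribute set on the singleton instance when it is first created. A minimal sketch of that pattern, using a hypothetical `Registry` class with illustrative `register` and `get` methods (not the project's `ObjectRegistry` API), might look like this:

```python
class Registry:
    """Hypothetical singleton for illustration; not the project's ObjectRegistry."""

    _instance = None

    def __new__(cls):
        # Create the single instance lazily and give it its own storage.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._items = {}
        return cls._instance

    def register(self, name, item):
        # Store an item under a name on the shared instance.
        self._items[name] = item

    def get(self, name):
        # Look up a previously registered item.
        return self._items[name]


first = Registry()
second = Registry()
first.register("dataset", [1, 2, 3])
```

Because `_items` is assigned only inside the `is None` branch, later calls to `Registry()` return the same object without resetting its contents, and the storage lives on the instance rather than on the class.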

@@ -407,12 +408,12 @@ def get_executor_tool(distributed: bool) -> Callable:
 The multi-agent workflow follows these key steps:

 1. **Initialization**:
-   - User creates a `Model` instance with intent and datasets
-   - User calls `model.build()` to start the process
+   - User creates a `ModelBuilder` instance or `Model` instance with intent and datasets
+   - User calls `ModelBuilder.build()` or `model.build()` (deprecated) to start the process

 2. **Orchestration**:
-   - Manager Agent initializes and coordinates the entire process
-   - Manager Agent tasks specialist agents based on the workflow requirements
+   - `ModelBuilder` (preferred) or `Model.build()` (deprecated) initializes the process
+   - Manager Agent coordinates the entire process and tasks specialist agents based on workflow requirements

 3. **Schema Resolution**:
    - If schemas aren't provided, SchemaResolverAgent infers them
@@ -607,19 +608,20 @@ class CustomModelValidator(Validator):

 ## References

-- [PlexeAgent Class Definition](/plexe/agents/agents.py)
-- [Model Class Definition](/plexe/models.py)
-- [EdaAgent Definition](/plexe/agents/dataset_analyser.py)
-- [SchemaResolverAgent Definition](/plexe/agents/schema_resolver.py)
-- [FeatureEngineeringAgent Definition](/plexe/agents/feature_engineer.py)
-- [DatasetSplitterAgent Definition](/plexe/agents/dataset_splitter.py)
-- [ModelTrainerAgent Definition](/plexe/agents/model_trainer.py)
-- [ModelPackagerAgent Definition](/plexe/agents/model_packager.py)
-- [ModelPlannerAgent Definition](/plexe/agents/model_planner.py)
-- [ModelTesterAgent Definition](/plexe/agents/model_tester.py)
-- [Tool Definitions](/plexe/tools/)
-- [Dataset Tools](/plexe/tools/datasets.py)
-- [Validation Tools](/plexe/tools/validation.py)
-- [Testing Tools](/plexe/tools/testing.py)
-- [Executor Implementation](/plexe/internal/models/execution/)
-- [Object Registry](/plexe/core/object_registry.py)
+- [PlexeAgent Class Definition](plexe/agents/agents.py)
+- [Model Class Definition](plexe/models.py)
+- [ModelBuilder Class Definition](plexe/model_builder.py)
+- [EdaAgent Definition](plexe/agents/dataset_analyser.py)
+- [SchemaResolverAgent Definition](plexe/agents/schema_resolver.py)
+- [FeatureEngineeringAgent Definition](plexe/agents/feature_engineer.py)
+- [DatasetSplitterAgent Definition](plexe/agents/dataset_splitter.py)
+- [ModelTrainerAgent Definition](plexe/agents/model_trainer.py)
+- [ModelPackagerAgent Definition](plexe/agents/model_packager.py)
+- [ModelPlannerAgent Definition](plexe/agents/model_planner.py)
+- [ModelTesterAgent Definition](plexe/agents/model_tester.py)
+- [Tool Definitions](plexe/tools/)
+- [Dataset Tools](plexe/tools/datasets.py)
+- [Validation Tools](plexe/tools/validation.py)
+- [Testing Tools](plexe/tools/testing.py)
+- [Executor Implementation](plexe/internal/models/execution/)
+- [Object Registry](plexe/core/object_registry.py)

plexe/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .models import Model as Model
+from .model_builder import ModelBuilder as ModelBuilder
 from .datasets import DatasetGenerator as DatasetGenerator
 from .fileio import (
     load_model as load_model,

plexe/agents/conversational.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+"""
+Conversational Agent for guiding users through ML model definition and initiation.
+
+This module defines a ConversationalAgent that helps users define their ML requirements
+through natural conversation, validates their inputs, and initiates model building
+when all necessary information has been gathered.
+"""
+
+import logging
+
+from smolagents import ToolCallingAgent, LiteLLMModel
+
+from plexe.internal.common.utils.agents import get_prompt_templates
+from plexe.tools.datasets import get_dataset_preview
+from plexe.tools.conversation import validate_dataset_files, initiate_model_build
+
+logger = logging.getLogger(__name__)
+
+
+class ConversationalAgent:
+    """
+    Agent for conversational model definition and build initiation.
+
+    This agent guides users through defining their ML requirements via natural
+    conversation, helps clarify the problem, validates dataset availability,
+    and initiates the model building process when all requirements are met.
+    """
+
+    def __init__(
+        self,
+        model_id: str = "anthropic/claude-sonnet-4-20250514",
+        verbose: bool = False,
+    ):
+        """
+        Initialize the conversational agent.
+
+        Args:
+            model_id: Model ID for the LLM to use for conversation
+            verbose: Whether to display detailed agent logs
+        """
+        self.model_id = model_id
+        self.verbose = verbose
+
+        # Set verbosity level
+        self.verbosity = 1 if verbose else 0
+
+        # Create the conversational agent with necessary tools
+        self.agent = ToolCallingAgent(
+            name="ModelDefinitionAssistant",
+            description=(
+                "Expert ML consultant that helps users define their machine learning requirements "
+                "through conversational guidance. Specializes in clarifying problem definitions, "
+                "understanding data requirements, and initiating model builds when ready. "
+                "Maintains a friendly, helpful conversation while ensuring all technical "
+                "requirements are properly defined before proceeding with model creation."
+            ),
+            model=LiteLLMModel(model_id=self.model_id),
+            tools=[
+                get_dataset_preview,
+                validate_dataset_files,
+                initiate_model_build,
+            ],
+            add_base_tools=False,
+            verbosity_level=self.verbosity,
+            prompt_templates=get_prompt_templates(
+                base_template_name="toolcalling_agent.yaml",
+                override_template_name="conversational_prompt_templates.yaml",
+            ),
+        )

plexe/callbacks.py

Lines changed: 19 additions & 3 deletions
@@ -62,9 +62,25 @@ class BuildStateInfo:
     node: Optional[Node] = None
     """The solution node being evaluated in the current iteration."""

-    # Reference to the model being built (for callbacks that need direct model access)
-    model: Any = None
-    """Reference to the model being built."""
+    # Model information fields (replacing direct model reference)
+    model_identifier: Optional[str] = None
+    """Model unique identifier."""
+
+    model_state: Optional[str] = None
+    """Current model state (BUILDING/READY/ERROR)."""
+
+    # Final model artifacts (only available at build end)
+    final_metric: Optional[Any] = None
+    """Final performance metric."""
+
+    final_artifacts: Optional[list] = None
+    """Model artifacts list."""
+
+    trainer_source: Optional[str] = None
+    """Training source code."""
+
+    predictor_source: Optional[str] = None
+    """Predictor source code."""


 class Callback(ABC):
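
The fields added above replace the direct `model` reference with plain values. One way a builder could populate them is to copy the relevant attributes off the model when the build ends. The sketch below is only an illustration under that assumption: `BuildSnapshot`, `snapshot_from_model`, and the stand-in model object are hypothetical names, and the attribute names read from the model are the ones the old MLflow callback looked up.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional


@dataclass
class BuildSnapshot:
    """Hypothetical stand-in mirroring the new BuildStateInfo fields."""

    model_identifier: Optional[str] = None
    model_state: Optional[str] = None
    final_metric: Optional[Any] = None
    final_artifacts: Optional[list] = None
    trainer_source: Optional[str] = None
    predictor_source: Optional[str] = None


def snapshot_from_model(model: Any) -> BuildSnapshot:
    """Copy plain values off a model so callbacks never hold the model itself."""
    state = getattr(model, "state", None)
    return BuildSnapshot(
        model_identifier=getattr(model, "identifier", None),
        model_state=str(state) if state is not None else None,
        final_metric=getattr(model, "metric", None),
        final_artifacts=list(getattr(model, "artifacts", None) or []),
        trainer_source=getattr(model, "trainer_source", None),
        predictor_source=getattr(model, "predictor_source", None),
    )


# Stand-in object with only some of the attributes set.
fake_model = SimpleNamespace(identifier="model-123", state="READY", artifacts=["weights.bin"])
snapshot = snapshot_from_model(fake_model)
```

Missing attributes simply stay `None`, so a callback can test each field directly instead of probing the model with a helper.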

plexe/core/storage.py

Lines changed: 5 additions & 5 deletions
@@ -81,7 +81,7 @@ def _save_model_to_tar(model: Any, path: str | Path) -> str:
         for key, value in metadata.items():
             if key in ["metrics", "metadata"]:
                 info = tarfile.TarInfo(f"metadata/{key}.yaml")
-                content = yaml.dump(value, default_flow_style=False).encode("utf-8")
+                content = yaml.safe_dump(value, default_flow_style=False).encode("utf-8")
             else:
                 info = tarfile.TarInfo(f"metadata/{key}.txt")
                 content = str(value).encode("utf-8")
@@ -92,7 +92,7 @@ def _save_model_to_tar(model: Any, path: str | Path) -> str:
         for name, schema in [("input_schema", model.input_schema), ("output_schema", model.output_schema)]:
             schema_dict = {name: field.annotation.__name__ for name, field in schema.model_fields.items()}
             info = tarfile.TarInfo(f"schemas/{name}.yaml")
-            content = yaml.dump(schema_dict, default_flow_style=False).encode("utf-8")
+            content = yaml.safe_dump(schema_dict, default_flow_style=False).encode("utf-8")
             info.size = len(content)
             tar.addfile(info, io.BytesIO(content))

@@ -134,7 +134,7 @@ def _save_model_to_tar(model: Any, path: str | Path) -> str:
         # Save evaluation report if available
         if hasattr(model, "evaluation_report") and model.evaluation_report:
             info = tarfile.TarInfo("metadata/evaluation_report.yaml")
-            content = yaml.dump(model.evaluation_report, default_flow_style=False).encode("utf-8")
+            content = yaml.safe_dump(model.evaluation_report, default_flow_style=False).encode("utf-8")
             info.size = len(content)
             tar.addfile(info, io.BytesIO(content))

@@ -340,7 +340,7 @@ def _save_checkpoint_to_tar(model: Any, iteration: int, path: Optional[str | Pat
         for key, value in metadata.items():
             if key in ["metadata"]:
                 info = tarfile.TarInfo(f"metadata/{key}.yaml")
-                content = yaml.dump(value, default_flow_style=False).encode("utf-8")
+                content = yaml.safe_dump(value, default_flow_style=False).encode("utf-8")
             else:
                 info = tarfile.TarInfo(f"metadata/{key}.txt")
                 content = str(value).encode("utf-8")
@@ -351,7 +351,7 @@ def _save_checkpoint_to_tar(model: Any, iteration: int, path: Optional[str | Pat
         for name, schema in [("input_schema", model.input_schema), ("output_schema", model.output_schema)]:
             schema_dict = {name: field.annotation.__name__ for name, field in schema.model_fields.items()}
             info = tarfile.TarInfo(f"schemas/{name}.yaml")
-            content = yaml.dump(schema_dict, default_flow_style=False).encode("utf-8")
+            content = yaml.safe_dump(schema_dict, default_flow_style=False).encode("utf-8")
             info.size = len(content)
             tar.addfile(info, io.BytesIO(content))
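
Each hunk in this file follows the same pattern: serialize a value to bytes, set `TarInfo.size` to the byte length, and add the bytes through an in-memory buffer. A minimal standard-library sketch of that pattern is shown below. The `add_text_member` helper is a hypothetical name, and a plain string stands in for the `yaml.safe_dump` output so the example has no third-party dependency.

```python
import io
import tarfile


def add_text_member(tar: tarfile.TarFile, name: str, text: str) -> None:
    """Hypothetical helper: write a text payload into an open tar archive."""
    content = text.encode("utf-8")
    info = tarfile.TarInfo(name)
    # addfile() reads exactly info.size bytes from the file object.
    info.size = len(content)
    tar.addfile(info, io.BytesIO(content))


# Write one member into an in-memory gzip-compressed archive.
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
    add_text_member(tar, "metadata/metrics.yaml", "accuracy: 0.9\n")

# Read it back to confirm the payload survived the round trip.
buffer.seek(0)
with tarfile.open(fileobj=buffer, mode="r:gz") as tar:
    names = tar.getnames()
    restored = tar.extractfile("metadata/metrics.yaml").read().decode("utf-8")
```

Setting `info.size` before `addfile` matters because a fresh `TarInfo` has size 0, which would store an empty member.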

plexe/internal/models/callbacks/mlflow.py

Lines changed: 24 additions & 33 deletions
@@ -187,9 +187,8 @@ def _extract_model_context(self, info: BuildStateInfo) -> Dict[str, Any]:
             context["max_iterations"] = info.max_iterations

         # Add model ID if available
-        model_id = self._safe_get(info.model, ["identifier"])
-        if model_id:
-            context["model_id"] = model_id
+        if info.model_identifier:
+            context["model_id"] = info.model_identifier

         # Add basic schema and dataset info
         if info.input_schema:
@@ -234,7 +233,7 @@ def on_build_start(self, info: BuildStateInfo) -> None:
         self.experiment_id = self._get_or_create_experiment()

         # Get model info and timestamp
-        model_id = self._safe_get(info.model, ["identifier"], "unknown")[0:12] + "..."
+        model_id = (info.model_identifier or "unknown")[0:12] + "..."
         timestamp = self._timestamp()

         # End any active run before starting parent
@@ -385,35 +384,27 @@ def on_build_end(self, info: BuildStateInfo) -> None:
             self._safe_log_artifact(content=report_markdown, filename=f"eda_report_{dataset_name}.md")

         # Log model information
-        model = info.model
-        if model:
-            # Log best model metric
-            metric = self._safe_get(model, ["metric"])
-            if metric and hasattr(metric, "name") and hasattr(metric, "value"):
-                mlflow.log_metric(f"best_{metric.name}", float(metric.value))
-
-            # Log model artifacts and status
-            mlflow.set_tag("best_iteration", str(info.iteration))
-
-            # Log artifact names
-            artifacts = self._safe_get(model, ["artifacts"], [])
-            if artifacts:
-                artifact_names = [a.name for a in artifacts]
-                mlflow.set_tag("model_artifacts", ", ".join(artifact_names))
-
-            # Log model state
-            state = self._safe_get(model, ["state"])
-            if state:
-                mlflow.set_tag("final_model_state", str(state))
-
-            # Log final model code
-            trainer_source = self._safe_get(model, ["trainer_source"])
-            if trainer_source:
-                self._safe_log_artifact(content=trainer_source, filename="final_trainer.py")
-
-            predictor_source = self._safe_get(model, ["predictor_source"])
-            if predictor_source:
-                self._safe_log_artifact(content=predictor_source, filename="final_predictor.py")
+        if info.final_metric and hasattr(info.final_metric, "name") and hasattr(info.final_metric, "value"):
+            mlflow.log_metric(f"best_{info.final_metric.name}", float(info.final_metric.value))
+
+        # Log model artifacts and status
+        mlflow.set_tag("best_iteration", str(info.iteration))
+
+        # Log artifact names
+        if info.final_artifacts:
+            artifact_names = [a.name for a in info.final_artifacts]
+            mlflow.set_tag("model_artifacts", ", ".join(artifact_names))
+
+        # Log model state
+        if info.model_state:
+            mlflow.set_tag("final_model_state", str(info.model_state))
+
+        # Log final model code
+        if info.trainer_source:
+            self._safe_log_artifact(content=info.trainer_source, filename="final_trainer.py")
+
+        if info.predictor_source:
+            self._safe_log_artifact(content=info.predictor_source, filename="final_predictor.py")

         # End the parent run
         mlflow.end_run()
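
The callback now reads plain fields and guards each one before logging. The two expressions that carry the most logic, the truncated model label and the duck-typed metric check, can be sketched in isolation as follows. The function names `short_model_label` and `metric_to_log` are hypothetical, and the sketch returns the values that would be logged rather than calling MLflow.

```python
from types import SimpleNamespace
from typing import Any, Optional


def short_model_label(model_identifier: Optional[str]) -> str:
    """Mirror the diff's fallback: first 12 characters of the id, or of "unknown"."""
    return (model_identifier or "unknown")[0:12] + "..."


def metric_to_log(final_metric: Any) -> Optional[tuple]:
    """Return (key, value) if the metric exposes name and value, else None."""
    if final_metric and hasattr(final_metric, "name") and hasattr(final_metric, "value"):
        return (f"best_{final_metric.name}", float(final_metric.value))
    return None


label = short_model_label("model-1234567890abcdef")
fallback = short_model_label(None)
logged = metric_to_log(SimpleNamespace(name="accuracy", value="0.93"))
skipped = metric_to_log(None)
```

Using `or "unknown"` covers both a missing identifier and an empty string, which the earlier `_safe_get(..., "unknown")` default handled only for the missing case.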

plexe/main.py

Lines changed: 4 additions & 2 deletions
@@ -2,11 +2,13 @@
 Application entry point for using the plexe package as a conversational agent.
 """

-# TODO: launch chat UI from here
+from smolagents import GradioUI
+from plexe.agents.conversational import ConversationalAgent


 def main():
-    pass
+    ui = GradioUI(ConversationalAgent().agent)
+    ui.launch()


 if __name__ == "__main__":
