Commit 271b827

feature: flexible model schema inference (#116)
* feat: add experiments/ to gitignore
* feat: add agentic schema resolution
* fix: update CLAUDE.md because claude is annoying
* fix: schema resolver called even if not needed
* fix: improve prompt template for schema resolver
* fix: move schema generation logic to agent itself
* fix: duplicated schema registration
* fix: duplicated registration logic
* fix: improve schema resolver prompt
* fix: "already registered" failure for schema resolver
* fix: resolver not used if schemas provided
* chore: bump to 0.19.0
* docs: update how it works doc
* fix: remove unused schema resolution templates

---------

Co-authored-by: marcellodebernardi <[email protected]>
1 parent 70696bf commit 271b827

13 files changed: +475 -133 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -194,4 +194,8 @@ plexe-full-codebase.txt
 examples/datasets/
 examples/datasets/*
 
+# Experiments
+experiments/
+notebooks/
+
 **/.claude/settings.local.json

CLAUDE.md

Lines changed: 24 additions & 26 deletions
@@ -1,39 +1,37 @@
 # CLAUDE.md: Plexe Coding Reference
 
-## Project Structure
-- `plexe/`: Main package directory
-- `plexe/models.py`: Implemented the top-level `Model` class
-- `plexe/fileio.py`: Saving and loading models
-- `plexe/config.py`: Configuration for model building, including LLM prompts
-- `plexe/internal/common/`: Package containing common utilities and functions
-- `plexe/internal/models/`: Package containing model building and training logic
-- `plexe/internal/datasets`: Package containing synthetic data generation logic
-- `plexe/internal/schemas/`: Package containing schema validation and inference logic
-- `plexe/internal/models/generators.py`: Main implementation of the model building and training logic
+## Project Overview
+Plexe is a framework for building ML models using natural language. It employs a multi-agent architecture where
+specialized AI agents collaborate to analyze requirements, generate solutions, and build functional ML models.
+
+The core architecture is as follows: agents go in `plexe/agents/*`, tools in `plexe/tools/*`, prompt templates in
+`plexe/templates/prompts/*`, and the main model code in `plexe/models.py`. This structure must be followed.
+
+## Key Components
+- `plexe/models.py`: Core `Model` class with build/predict functionality
+- `plexe/agents/schema_resolver.py`: Agent inferring input/output schemas
+- `plexe/internal/agents.py`: Multi-agent system implementation (`PlexeAgent` class)
+- `plexe/internal/models/tools/`: Tools for code generation, execution, validation
+- `plexe/config.py`: Configuration management and prompt templates
+- `plexe/internal/common/registries/objects.py`: Shared object registry for agents
+- `plexe/datasets.py`: Dataset handling and synthetic data generation
+- `docs/architecture/multi-agent-system.md`: Architectural documentation
+- `plexe/templates/prompts/`: Prompt templates for agents and LLM calls
 
 ## Build/Run Commands
 - Install deps: `poetry install`
 - Format code: `poetry run black .`
 - Lint code: `poetry run ruff check . --fix`
-- Run all tests: `poetry run pytest tests/`
-- Run single test: `poetry run pytest tests/path/to/test_file.py::test_function_name`
-- Run unit tests: `poetry run pytest tests/unit/`
-- Run integration tests: `poetry run pytest tests/integration/`
+- Run tests: `poetry run pytest tests/`
 - Run with coverage: `poetry run pytest --cov=plexe tests/`
 
 ## Code Style
-- **Paradigm**: object-oriented structure, functional implementations where appropriate
-- **Functions**: 50 lines max (not including docstrings)
+- **Functions**: Max 50 lines (excluding docstrings)
 - **Formatting**: Black with 120 char line length
 - **Linting**: Ruff with E203/E501/E402 ignored
-- **Typing**: Use type hints and Pydantic models
-- **Naming**: snake_case (functions/vars), CamelCase (classes)
-- **Imports**: Group stdlib, third-party, then local imports; NO IMPORTS INSIDE FUNCTIONS, always import at the top of the file
-- **__init__.py**: No code in __init__.py files except in plexe/__init__.py for convenience
-- **Docstrings**: Required for public modules/classes/functions; Sphinx style without type hints
+- **Typing**: Type hints and Pydantic models required
+- **Imports**: ALWAYS at top level in order: stdlib, third-party, local; NEVER inside functions
+- **__init__.py**: No implementation code except in `plexe/__init__.py`
+- **Docstrings**: Required for public APIs; Sphinx style
 - **Testing**: Write pytest tests for all new functionality
-
-## Commit Messages
-- Format: `<type>: <subject>`
-- Types: feat, fix, docs, style, refactor, test, chore
-- Example: `feat: add support for deepseek`
+- **Elegance**: Write the simplest solution possible; avoid over-engineering; prefer deleting code over adding code

docs/architecture/multi-agent-system.md

Lines changed: 90 additions & 47 deletions
@@ -10,6 +10,7 @@
 - [Overview](#overview)
 - [Architecture Diagram](#architecture-diagram)
 - [Key Components](#key-components)
+- [Schema Resolver Agent](#schema-resolver-agent)
 - [Manager Agent (Orchestrator)](#manager-agent-orchestrator)
 - [ML Research Scientist Agent](#ml-research-scientist-agent)
 - [ML Engineer Agent](#ml-engineer-agent)
@@ -38,6 +39,8 @@ graph TD
 User([User]) --> |"Intent & Datasets"| Model["Model Class"]
 
 subgraph "Multi-Agent System"
+Model --> |"Schema Resolution"| SchemaResolver["Schema Resolver"]
+SchemaResolver --> |"Schemas"| Orchestrator
 Model --> |build| Orchestrator["Manager Agent"]
 Orchestrator --> |"Plan Task"| MLS["ML Researcher"]
 Orchestrator --> |"Implement Task"| MLE["ML Engineer"]
@@ -52,6 +55,7 @@ graph TD
 Datasets[(Datasets)]
 Artifacts[(Model Artifacts)]
 Code[(Code Snippets)]
+Schemas[(I/O Schemas)]
 end
 
 subgraph Tools["Tool System"]
@@ -68,6 +72,8 @@ graph TD
 MLS <--> Tools
 MLE <--> Tools
 MLOPS <--> Tools
+SchemaResolver <--> Registry
+SchemaResolver <--> Tools
 
 Orchestrator --> Result([Trained Model])
 Result --> Model
@@ -76,9 +82,30 @@ graph TD
 
 ## Key Components
 
+### Schema Resolver Agent
+
+**Class**: `SchemaResolverAgent`
+**Type**: `ToolCallingAgent`
+
+The Schema Resolver Agent infers input and output schemas from intent and dataset samples:
+
+```python
+schema_resolver = SchemaResolverAgent(
+    model_id=provider_config.tool_provider,
+    chain_of_thought_callable=cot_callable,
+    verbosity_level=1,
+)
+```
+
+**Responsibilities**:
+- Analyzing the problem description and sample data
+- Inferring appropriate input and output schemas
+- Registering schemas with the Object Registry
+- Providing automatic schema resolution when schemas aren't specified
+
 ### Manager Agent (Orchestrator)
 
-**Class**: `PlexeAgent` attribute `manager_agent`
+**Class**: `PlexeAgent.manager_agent`
 **Type**: `CodeAgent`
 
 The Manager Agent serves as the central coordinator for the entire ML development process:
@@ -87,8 +114,15 @@ The Manager Agent serves as the central coordinator for the entire ML developmen
 self.manager_agent = CodeAgent(
     name="Orchestrator",
     model=LiteLLMModel(model_id=self.orchestrator_model_id),
-    tools=[select_target_metric, review_finalised_model, split_datasets,
-           create_input_sample, format_final_orchestrator_agent_response],
+    tools=[
+        get_select_target_metric(self.tool_model_id),
+        get_review_finalised_model(self.tool_model_id),
+        split_datasets,
+        create_input_sample,
+        get_dataset_preview,
+        get_raw_dataset_schema,
+        format_final_orchestrator_agent_response,
+    ],
     managed_agents=[self.ml_research_agent, self.mle_agent, self.mlops_engineer],
     add_base_tools=False,
     verbosity_level=self.orchestrator_verbosity,
@@ -124,14 +158,13 @@ self.ml_research_agent = ToolCallingAgent(
         "- input schema for the model"
         "- output schema for the model"
         "- the name and comparison method of the metric to optimise"
-        "- the identifier of the LLM that should be used for plan generation"
     ),
     model=LiteLLMModel(model_id=self.ml_researcher_model_id),
-    tools=[],
+    tools=[get_dataset_preview],
     add_base_tools=False,
     verbosity_level=self.specialist_verbosity,
     prompt_templates=get_prompt_templates("toolcalling_agent.yaml", "mls_prompt_templates.yaml"),
-    step_callbacks=[self.chain_of_thought_callable]
+    step_callbacks=[self.chain_of_thought_callable],
 )
 ```
 
@@ -161,14 +194,13 @@ self.mle_agent = ToolCallingAgent(
         "- the full solution plan that outlines how to solve this problem"
         "- the split train/validation dataset names"
         "- the working directory to use for model execution"
-        "- the identifier of the LLM that should be used for code generation"
     ),
     model=LiteLLMModel(model_id=self.ml_engineer_model_id),
     tools=[
-        generate_training_code,
+        get_generate_training_code(self.tool_model_id),
         validate_training_code,
-        fix_training_code,
-        get_executor_tool(distributed),
+        get_fix_training_code(self.tool_model_id),
+        get_executor_tool(self.distributed),
         format_final_mle_agent_response,
     ],
     add_base_tools=False,
@@ -188,32 +220,30 @@ self.mle_agent = ToolCallingAgent(
 ### ML Operations Engineer Agent
 
 **Class**: `PlexeAgent.mlops_engineer`
-**Type**: `ToolCallingAgent`
+**Type**: `CodeAgent`
 
 This agent focuses on productionizing the model through inference code:
 
 ```python
-self.mlops_engineer = ToolCallingAgent(
+self.mlops_engineer = CodeAgent(
     name="MLOperationsEngineer",
     description=(
         "Expert ML operations engineer that writes inference code for ML models to be used in production. "
         "To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
         "- input schema for the model"
         "- output schema for the model"
         "- the 'training code id' of the training code produced by the MLEngineer agent"
-        "- the identifier of the LLM that should be used for code generation"
     ),
     model=LiteLLMModel(model_id=self.ml_ops_engineer_model_id),
     tools=[
-        split_datasets,
-        generate_inference_code,
+        get_generate_inference_code(self.tool_model_id),
         validate_inference_code,
-        fix_inference_code,
+        get_fix_inference_code(self.tool_model_id),
         format_final_mlops_agent_response,
     ],
     add_base_tools=False,
     verbosity_level=self.specialist_verbosity,
-    prompt_templates=get_prompt_templates("toolcalling_agent.yaml", "mlops_prompt_templates.yaml"),
+    prompt_templates=get_prompt_templates("code_agent.yaml", "mlops_prompt_templates.yaml"),
     planning_interval=8,
     step_callbacks=[self.chain_of_thought_callable],
 )
@@ -254,27 +284,32 @@ class ObjectRegistry:
 **Key Features**:
 - Type-safe storage and retrieval
 - Shared access across agents
-- Registration of multiple item types (datasets, artifacts, code)
+- Registration of multiple item types (datasets, artifacts, code, schemas)
+- Batch operations with register_multiple and get_multiple
 
 ### Tool System
 
-The system includes specialized tools that agents can use to perform specific tasks:
+The system includes specialized tools that agents can use to perform specific tasks, implemented using factory patterns:
 
 **Metric Selection Tool**:
 ```python
-@tool
-def select_target_metric(task: str, provider: str) -> Dict:
-    """Selects the appropriate target metric to optimise for the given task."""
+def get_select_target_metric(model_id: str) -> Callable:
+    """Factory function that returns a tool for selecting appropriate target metrics."""
+    @tool
+    def select_target_metric(task: str, provider: str) -> Dict:
+        """Selects the appropriate target metric to optimise for the given task."""
 ```
 
 **Code Generation Tools**:
 ```python
-@tool
-def generate_training_code(
-    task: str, solution_plan: str, train_datasets: List[str],
-    validation_datasets: List[str], llm_to_use: str
-) -> str:
-    """Generates training code based on the solution plan."""
+def get_generate_training_code(model_id: str) -> Callable:
+    """Factory function that returns a tool for generating training code."""
+    @tool
+    def generate_training_code(
+        task: str, solution_plan: str, train_datasets: List[str],
+        validation_datasets: List[str]
+    ) -> str:
+        """Generates training code based on the solution plan."""
 ```
 
 **Validation Tools**:
289324

290325
**Execution Tools**:
291326
```python
292-
@tool
293-
def execute_training_code(
294-
node_id: str, code: str, working_dir: str, dataset_names: List[str],
295-
timeout: int, metric_to_optimise_name: str, metric_to_optimise_comparison_method: str,
296-
) -> Dict:
297-
"""Executes training code in an isolated environment."""
327+
def get_executor_tool(distributed: bool) -> Callable:
328+
"""Factory function that returns the appropriate executor tool."""
329+
if distributed:
330+
return execute_training_code_distributed
331+
return execute_training_code
298332
```
299333

300334
## Workflow
@@ -305,24 +339,28 @@ The multi-agent workflow follows these key steps:
 - User creates a `Model` instance with intent and datasets
 - User calls `model.build()` to start the process
 
-2. **Orchestration**:
+2. **Schema Resolution**:
+- If schemas aren't provided, SchemaResolverAgent infers them
+- Schemas are registered in the Object Registry
+
+3. **Orchestration**:
 - Manager Agent selects metrics and splits datasets
 - Manager Agent initializes the solution planning phase
 
-3. **Solution Planning**:
+4. **Solution Planning**:
 - ML Research Scientist proposes solution approaches
 - Manager Agent evaluates and selects approaches
 
-4. **Model Implementation**:
+5. **Model Implementation**:
 - ML Engineer generates and executes training code
 - Model artifacts are registered in the Object Registry
 - Process may iterate through multiple approaches
 
-5. **Inference Code Generation**:
+6. **Inference Code Generation**:
 - ML Operations Engineer generates compatible inference code
 - Code is validated with sample inputs
 
-6. **Finalization**:
+7. **Finalization**:
 - Manager Agent reviews and finalizes the model
 - All artifacts and code are collected
 - Completed model is returned to the user
@@ -334,7 +372,7 @@ The multi-agent workflow follows these key steps:
 The system uses a hierarchical communication pattern:
 
 ```
-User → Model → Manager Agent → Specialist Agents → Manager Agent → Model → User
+User → Model → Schema Resolver → Manager Agent → Specialist Agents → Manager Agent → Model → User
 ```
 
 Each agent communicates through structured task descriptions and responses:
@@ -447,14 +485,17 @@ class CustomPlexeAgent(PlexeAgent):
 
 ### Implementing Custom Tools
 
-You can add new tools by using the `@tool` decorator:
+You can add new tools using the factory pattern with the `@tool` decorator:
 
 ```python
-@tool
-def custom_tool(param1: str, param2: int) -> Dict:
-    """Description of what this tool does."""
-    # Tool implementation
-    return {"result": "Output of the tool"}
+def get_custom_tool(model_id: str) -> Callable:
+    """Factory function that returns a custom tool."""
+    @tool
+    def custom_tool(param1: str, param2: int) -> Dict:
+        """Description of what this tool does."""
+        # Tool implementation
+        return {"result": "Output of the tool"}
+    return custom_tool
 ```
 
 ### Supporting New Model Types
@@ -474,5 +515,7 @@ class CustomModelValidator(Validator):
 
 - [PlexeAgent Class Definition](/plexe/internal/agents.py)
 - [Model Class Definition](/plexe/models.py)
+- [SchemaResolverAgent Definition](/plexe/agents/schema_resolver.py)
 - [Tool Definitions](/plexe/internal/models/tools/)
-- [Executor Implementation](/plexe/internal/models/execution/)
+- [Executor Implementation](/plexe/internal/models/execution/)
+- [Object Registry](/plexe/internal/common/registries/objects.py)
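
Reviewer note: the factory pattern that the documentation diff above introduces for tools can be illustrated with a small standalone sketch. It assumes nothing about the real Plexe tools: `get_describe_task` and `describe_task` are hypothetical names, and the `@tool` decorator is left out so the snippet runs without extra dependencies. The only point is that the factory binds `model_id` once through a closure, so the returned function no longer needs a model identifier among its own arguments.

```python
from typing import Callable, Dict


def get_describe_task(model_id: str) -> Callable[[str], Dict[str, str]]:
    """Hypothetical factory: returns a tool-like function with model_id bound by closure."""

    def describe_task(task: str) -> Dict[str, str]:
        """Hypothetical tool body; a real tool would call the LLM named by model_id."""
        return {"task": task, "model_id": model_id}

    # Returning the inner function is what makes the factory usable in a tools list.
    return describe_task


# Each call to the factory yields an independent tool bound to its own model id.
tool_a = get_describe_task("provider/model-a")
tool_b = get_describe_task("provider/model-b")
print(tool_a("predict churn"))
print(tool_b("predict churn"))
```

This would be consistent with the removal of the "identifier of the LLM" requirement from the agent descriptions in the same diff, since the identifier no longer has to be passed through the task prompt.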

plexe/agents/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+"""
+Agents for the Plexe ML platform.
+
+This package contains agent implementations for various tasks in the Plexe platform.
+"""
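
Reviewer note: a minimal sketch of the resolution flow that the commit message and workflow docs describe, where inference runs only when a schema is missing and the resulting schemas are registered once. `SchemaStore`, `resolve_schemas`, and `fake_infer` are hypothetical stand-ins for illustration, not the actual `SchemaResolverAgent` or Object Registry API.

```python
from typing import Callable, Dict, Optional, Tuple

Schema = Dict[str, str]


class SchemaStore:
    """Hypothetical stand-in for a shared registry keyed by name."""

    def __init__(self) -> None:
        self._items: Dict[str, Schema] = {}

    def register(self, name: str, schema: Schema) -> None:
        # Refuse silent overwrites so duplicate registration is caught early.
        if name in self._items:
            raise ValueError(f"'{name}' is already registered")
        self._items[name] = schema

    def get(self, name: str) -> Schema:
        return self._items[name]


def resolve_schemas(
    store: SchemaStore,
    infer: Callable[[], Tuple[Schema, Schema]],
    input_schema: Optional[Schema] = None,
    output_schema: Optional[Schema] = None,
) -> Tuple[Schema, Schema]:
    """Call `infer` only if a schema is missing, then register both schemas once."""
    if input_schema is None or output_schema is None:
        inferred_in, inferred_out = infer()
        input_schema = inferred_in if input_schema is None else input_schema
        output_schema = inferred_out if output_schema is None else output_schema
    store.register("input_schema", input_schema)
    store.register("output_schema", output_schema)
    return input_schema, output_schema


calls = []


def fake_infer() -> Tuple[Schema, Schema]:
    """Stands in for the agent; records each call so skipping is observable."""
    calls.append("infer")
    return {"text": "str"}, {"label": "str"}


# Both schemas supplied: inference is skipped.
resolve_schemas(SchemaStore(), fake_infer, {"x": "float"}, {"y": "float"})
# No schemas supplied: inference runs once.
store = SchemaStore()
resolve_schemas(store, fake_infer)
print(len(calls), store.get("output_schema"))
```

Keeping registration in one place, after the optional inference step, is one way to avoid the duplicated-registration failures that several of the fix commits mention.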
