Skip to content

Commit b4d6dfc

Browse files
fix: improve predictor construction robustness (#112)
* fix: simplify santander example * refactor: remove 'llm_to_use' from tool signature * chore: yell at claude about imports * feat: initial implementation of more agentic predictor production * feat: extract input sample as dict * feat: simplify inference generation tools * chore: bump to 0.18.3 * fix: remove unused agent inputs * fix: include prompt templates in dumpcode.py * feat: move predictor generation from tools to agent * fix: register schemas * fix: remove unused inference prompts * fix: allow plexe imports for mlops engineer * feat: extract artifacts in inference context * feat: add house prices example * fix: artifact list extraction defined incorrectly * fix: incorrect sampling in examples * fix: add io and plexe to allowed imports * fix: setting llm for extraction incorrectly * fix: get schemas from registry at inference validation * fix: artifact extraction can fail silently * fix: extra space in house prices example * fix: only one integration test per module --------- Co-authored-by: Vaibhav Dubey <[email protected]>
1 parent 256e8b5 commit b4d6dfc

34 files changed

+592
-974
lines changed

CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
- **Linting**: Ruff with E203/E501/E402 ignored
2929
- **Typing**: Use type hints and Pydantic models
3030
- **Naming**: snake_case (functions/vars), CamelCase (classes)
31-
- **Imports**: Group stdlib, third-party, then local imports; NO LOCAL IMPORTS, always import at the top of the file
31+
- **Imports**: Group stdlib, third-party, then local imports; NO IMPORTS INSIDE FUNCTIONS, always import at the top of the file
3232
- **__init__.py**: No code in __init__.py files except in plexe/__init__.py for convenience
3333
- **Docstrings**: Required for public modules/classes/functions; Sphinx style without type hints
3434
- **Testing**: Write pytest tests for all new functionality

examples/house_prices.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
This script demonstrates how to run the plexe ML engineering agent to build a predictive model. The example
3+
uses the Kaggle 'House Prices - Advanced Regression Techniques' competition's training dataset.
4+
5+
The dataset is owned and hosted by Kaggle, and is available for download at
6+
https://www.kaggle.com/competitions/house-prices-advanced-regression-techniques/data under the MIT license
7+
(https://www.mit.edu/~amini/LICENSE.md). This dataset is not part of the plexe package or in any way
8+
affiliated with it, and Plexe AI claims no rights over it. The dataset is used here for demonstration purposes
9+
only. Please refer to the Kaggle competition page for more details on the dataset and its usage.
10+
11+
Citation:
12+
Anna Montoya and DataCanary. House Prices - Advanced Regression Techniques.
13+
https://kaggle.com/competitions/house-prices-advanced-regression-techniques, 2016. Kaggle.
14+
"""
15+
16+
# NOTE: you must download the dataset from Kaggle for this example to work
17+
18+
from datetime import datetime
19+
import pandas as pd
20+
21+
import plexe
22+
from plexe.internal.common.provider import ProviderConfig
23+
24+
25+
# Step 1: Define the model
26+
# Note: for conciseness we leave the input schema empty and let plexe infer it
27+
model = plexe.Model(
28+
intent=(
29+
"With 79 explanatory variables describing aspects of residential homes in Ames, Iowa, predict "
30+
"the final price of each home. Use only linear regression and decision tree models, no ensembling. "
31+
"The models must be extremely simple and quickly trainable on extremely constrained hardware."
32+
),
33+
output_schema={
34+
"SalePrice": float,
35+
},
36+
)
37+
38+
# Step 2: Build the model using the training dataset
39+
# 2A [OPTIONAL]: Define MLFlow callback for tracking
40+
mlflow_callback = plexe.callbacks.MLFlowCallback(
41+
tracking_uri="http://127.0.0.1:8080",
42+
experiment_name=f"house-prices-{datetime.now().strftime('%Y%m%d-%H%M%S') }",
43+
)
44+
# 2B: Build the model with the dataset
45+
# NOTE: In order to run this example, you will need to download the dataset from Kaggle
46+
model.build(
47+
datasets=[pd.read_csv("examples/datasets/house-prices-train.csv")],
48+
provider=ProviderConfig(
49+
default_provider="openai/gpt-4o",
50+
orchestrator_provider="anthropic/claude-3-7-sonnet-20250219",
51+
research_provider="openai/gpt-4o",
52+
engineer_provider="anthropic/claude-3-7-sonnet-20250219",
53+
ops_provider="anthropic/claude-3-7-sonnet-20250219",
54+
tool_provider="openai/gpt-4o",
55+
),
56+
max_iterations=2,
57+
timeout=1800, # 30 minute timeout
58+
run_timeout=180,
59+
verbose=False,
60+
callbacks=[mlflow_callback],
61+
chain_of_thought=True, # Enable chain of thought output
62+
)
63+
64+
# Step 3: Save the model
65+
plexe.save_model(model, "house-prices.tar.gz")
66+
67+
# Step 4: Run a prediction on the built model
68+
test_df = pd.read_csv("examples/datasets/house-prices-test.csv").sample(10)
69+
predictions = pd.DataFrame.from_records([model.predict(x) for x in test_df.to_dict(orient="records")])
70+
71+
# Step 5: Print the predictions for the sampled rows
72+
print(predictions)
73+
74+
# Step 6: Print model description
75+
description = model.describe()
76+
print(description.as_text())

examples/santander_customer_transactions.py renamed to examples/santander_transactions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
model = plexe.Model(
2727
intent=(
2828
"Identify which customers will make a specific transaction in the future, irrespective of the amount "
29-
"of money transacted. For each Id, make a binary prediction of the 'target' variable."
29+
"of money transacted. For each Id, make a binary prediction of the 'target' variable. Use only linear "
30+
"regression and decision tree models, no ensembling. The models must be extremely simple and quickly "
31+
"trainable on extremely constrained hardware."
3032
),
3133
output_schema={
3234
"target": int,
@@ -51,7 +53,7 @@
5153
ops_provider="anthropic/claude-3-7-sonnet-20250219",
5254
tool_provider="openai/gpt-4o",
5355
),
54-
max_iterations=8,
56+
max_iterations=5,
5557
timeout=1800, # 30 minute timeout
5658
run_timeout=180,
5759
verbose=False,
@@ -63,11 +65,11 @@
6365
plexe.save_model(model, "santander_transactions_model.tar.gz")
6466

6567
# Step 4: Run a prediction on the built model
66-
test_df = pd.read_csv("examples/datasets/santander-transactions-test-mini.csv")
68+
test_df = pd.read_csv("examples/datasets/santander-transactions-test-mini.csv").sample(10)
6769
predictions = pd.DataFrame.from_records([model.predict(x) for x in test_df.to_dict(orient="records")])
6870

6971
# Step 5: print a sample of predictions
70-
print(predictions.sample(10))
72+
print(predictions)
7173

7274
# Step 6: Print model description
7375
description = model.describe()

examples/spaceship_titanic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
model = plexe.Model(
2525
intent=(
2626
"From features describing a Spaceship Titanic passenger's information, determine whether they were "
27-
"transported or not."
27+
"transported or not. Use only linear regression and decision tree models, no ensembling. The models "
28+
"must be extremely simple and quickly trainable on extremely constrained hardware."
2829
),
2930
input_schema={
3031
"PassengerId": str,
@@ -64,7 +65,7 @@
6465
ops_provider="anthropic/claude-3-7-sonnet-20250219",
6566
tool_provider="openai/gpt-4o",
6667
),
67-
max_iterations=4,
68+
max_iterations=1,
6869
timeout=300, # 5 minute timeout
6970
run_timeout=150,
7071
verbose=False,
@@ -76,11 +77,11 @@
7677
plexe.save_model(model, "spaceship_titanic_model.tar.gz")
7778

7879
# Step 4: Run a prediction on the built model
79-
test_df = pd.read_csv("examples/datasets/spaceship-titanic-test.csv")
80+
test_df = pd.read_csv("examples/datasets/spaceship-titanic-test.csv").sample(10)
8081
predictions = pd.DataFrame.from_records([model.predict(x) for x in test_df.to_dict(orient="records")])
8182

8283
# Step 5: print a sample of predictions
83-
print(predictions.sample(10))
84+
print(predictions)
8485

8586
# Step 6: Print model description
8687
description = model.describe()

plexe/config.py

Lines changed: 2 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class _CodeGenerationConfig:
9393
"typing",
9494
"dataclasses",
9595
"json",
96+
"io",
9697
"time",
9798
"datetime",
9899
"os",
@@ -109,6 +110,7 @@ class _CodeGenerationConfig:
109110
"logging",
110111
"importlib",
111112
"types",
113+
"plexe",
112114
]
113115
)
114116

@@ -252,79 +254,6 @@ def training_review(self, problem_statement, plan, training_code, problems, allo
252254
allowed_packages=allowed_packages,
253255
)
254256

255-
def inference_system(self) -> str:
256-
return self._render("inference/system_prompt.jinja")
257-
258-
def inference_load(self, predictor_template, training_code) -> str:
259-
return self._render(
260-
"inference/load.jinja",
261-
predictor_template=predictor_template,
262-
training_code=training_code,
263-
)
264-
265-
def inference_preprocess(self, inference_code, input_schema, training_code) -> str:
266-
return self._render(
267-
"inference/preprocess.jinja",
268-
inference_code=inference_code,
269-
input_schema=input_schema,
270-
training_code=training_code,
271-
)
272-
273-
def inference_postprocess(self, inference_code, output_schema, training_code) -> str:
274-
return self._render(
275-
"inference/postprocess.jinja",
276-
inference_code=inference_code,
277-
output_schema=output_schema,
278-
training_code=training_code,
279-
)
280-
281-
def inference_predict(self, output_schema, input_schema, training_code, inference_code) -> str:
282-
return self._render(
283-
"inference/predict.jinja",
284-
output_schema=output_schema,
285-
input_schema=input_schema,
286-
training_code=training_code,
287-
inference_code=inference_code,
288-
)
289-
290-
def inference_combine(self, inference_code, predictor_interface_source) -> str:
291-
return self._render(
292-
"inference/combine.jinja",
293-
inference_code=inference_code,
294-
predictor_interface_source=predictor_interface_source,
295-
)
296-
297-
def inference_fix(self, predictor_interface_source, predictor_template, inference_code, review, problems) -> str:
298-
return self._render(
299-
"inference/fix.jinja",
300-
predictor_interface_source=predictor_interface_source,
301-
predictor_template=predictor_template,
302-
inference_code=inference_code,
303-
review=review,
304-
problems=problems,
305-
)
306-
307-
def inference_review(
308-
self,
309-
predictor_interface_source,
310-
predictor_template,
311-
inference_code,
312-
input_schema,
313-
output_schema,
314-
training_code,
315-
problems,
316-
) -> str:
317-
return self._render(
318-
"inference/review.jinja",
319-
predictor_interface_source=predictor_interface_source,
320-
predictor_template=predictor_template,
321-
inference_code=inference_code,
322-
input_schema=input_schema,
323-
output_schema=output_schema,
324-
training_code=training_code,
325-
problems=problems,
326-
)
327-
328257
def review_system(self) -> str:
329258
return self._render("review/system_prompt.jinja")
330259

plexe/internal/agents.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,21 @@
1717
from plexe.internal.models.entities.metric import Metric
1818
from plexe.internal.models.entities.metric import MetricComparator, ComparisonMethod
1919
from plexe.internal.models.interfaces.predictor import Predictor
20-
from plexe.internal.models.tools.code_generation import (
21-
generate_inference_code,
22-
fix_inference_code,
23-
generate_training_code,
24-
fix_training_code,
20+
from plexe.internal.models.tools.training import (
21+
get_generate_training_code,
22+
get_fix_training_code,
2523
)
24+
from plexe.internal.models.tools.evaluation import get_review_finalised_model
25+
from plexe.internal.models.tools.metrics import get_select_target_metric
2626
from plexe.internal.models.tools.datasets import split_datasets, create_input_sample
27-
from plexe.internal.models.tools.evaluation import review_finalised_model
2827
from plexe.internal.models.tools.execution import get_executor_tool
29-
from plexe.internal.models.tools.metrics import select_target_metric
3028
from plexe.internal.models.tools.response_formatting import (
3129
format_final_orchestrator_agent_response,
3230
format_final_mle_agent_response,
3331
format_final_mlops_agent_response,
3432
)
35-
from plexe.internal.models.tools.validation import validate_inference_code, validate_training_code
33+
from plexe.internal.models.tools.context import get_inference_context_tool
34+
from plexe.internal.models.tools.validation import validate_training_code, validate_inference_code
3635

3736
logger = logging.getLogger(__name__)
3837

@@ -62,6 +61,7 @@ def __init__(
6261
ml_researcher_model_id: str = "openai/gpt-4o",
6362
ml_engineer_model_id: str = "anthropic/claude-3-7-sonnet-20250219",
6463
ml_ops_engineer_model_id: str = "anthropic/claude-3-7-sonnet-20250219",
64+
tool_model_id: str = "openai/gpt-4o",
6565
verbose: bool = False,
6666
max_steps: int = 30,
6767
distributed: bool = False,
@@ -75,6 +75,7 @@ def __init__(
7575
ml_researcher_model_id: Model ID for the ML researcher agent
7676
ml_engineer_model_id: Model ID for the ML engineer agent
7777
ml_ops_engineer_model_id: Model ID for the ML ops engineer agent
78+
tool_model_id: Model ID for the model used inside tool calls
7879
verbose: Whether to display detailed agent logs
7980
max_steps: Maximum number of steps for the orchestrator agent
8081
distributed: Whether to run the agents in a distributed environment
@@ -84,6 +85,7 @@ def __init__(
8485
self.ml_researcher_model_id = ml_researcher_model_id
8586
self.ml_engineer_model_id = ml_engineer_model_id
8687
self.ml_ops_engineer_model_id = ml_ops_engineer_model_id
88+
self.tool_model_id = tool_model_id
8789
self.verbose = verbose
8890
self.max_steps = max_steps
8991
self.distributed = distributed
@@ -103,7 +105,6 @@ def __init__(
103105
"- input schema for the model"
104106
"- output schema for the model"
105107
"- the name and comparison method of the metric to optimise"
106-
"- the identifier of the LLM that should be used for plan generation"
107108
),
108109
model=LiteLLMModel(model_id=self.ml_researcher_model_id),
109110
tools=[],
@@ -126,13 +127,12 @@ def __init__(
126127
"- the full solution plan that outlines how to solve this problem"
127128
"- the split train/validation dataset names"
128129
"- the working directory to use for model execution"
129-
"- the identifier of the LLM that should be used for code generation"
130130
),
131131
model=LiteLLMModel(model_id=self.ml_engineer_model_id),
132132
tools=[
133-
generate_training_code,
133+
get_generate_training_code(self.tool_model_id),
134134
validate_training_code,
135-
fix_training_code,
135+
get_fix_training_code(self.tool_model_id),
136136
get_executor_tool(distributed),
137137
format_final_mle_agent_response,
138138
],
@@ -143,27 +143,25 @@ def __init__(
143143
)
144144

145145
# Create predictor builder agent - creates inference code
146-
self.mlops_engineer = ToolCallingAgent(
146+
self.mlops_engineer = CodeAgent(
147147
name="MLOperationsEngineer",
148148
description=(
149-
"Expert ML operations engineer that writes inference code for ML models to be used in production. "
150-
"To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
149+
"Expert ML operations engineer that analyzes training code and creates high-quality production-ready "
150+
"inference code for ML models. To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
151151
"- input schema for the model"
152152
"- output schema for the model"
153153
"- the 'training code id' of the training code produced by the MLEngineer agent"
154-
"- the identifier of the LLM that should be used for code generation"
155154
),
156155
model=LiteLLMModel(model_id=self.ml_ops_engineer_model_id),
157156
tools=[
158-
split_datasets,
159-
generate_inference_code,
157+
get_inference_context_tool(self.tool_model_id),
160158
validate_inference_code,
161-
fix_inference_code,
162159
format_final_mlops_agent_response,
163160
],
164161
add_base_tools=False,
165162
verbosity_level=self.specialist_verbosity,
166-
prompt_templates=get_prompt_templates("toolcalling_agent.yaml", "mlops_prompt_templates.yaml"),
163+
additional_authorized_imports=config.code_generation.authorized_agent_imports,
164+
prompt_templates=get_prompt_templates("code_agent.yaml", "mlops_prompt_templates.yaml"),
167165
planning_interval=8,
168166
step_callbacks=[self.chain_of_thought_callable],
169167
)
@@ -173,8 +171,8 @@ def __init__(
173171
name="Orchestrator",
174172
model=LiteLLMModel(model_id=self.orchestrator_model_id),
175173
tools=[
176-
select_target_metric,
177-
review_finalised_model,
174+
get_select_target_metric(self.tool_model_id),
175+
get_review_finalised_model(self.tool_model_id),
178176
split_datasets,
179177
create_input_sample,
180178
format_final_orchestrator_agent_response,

0 commit comments

Comments
 (0)