Skip to content

Commit 4cb81f5

Browse files
fix: registration of best performing code (#127)
* fix: manager sometimes passes wrong code id to mlops agent
* feat: add tool for retrieving model performances
* chore: bump to 0.23.4
1 parent fad9f27 commit 4cb81f5

File tree

10 files changed

+74
-18
lines changed

10 files changed

+74
-18
lines changed

plexe/agents/agents.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@
2727
from plexe.internal.models.entities.metric import MetricComparator, ComparisonMethod
2828
from plexe.core.interfaces.predictor import Predictor
2929
from plexe.tools.datasets import create_input_sample, get_latest_datasets
30-
from plexe.tools.evaluation import get_review_finalised_model
30+
from plexe.tools.evaluation import get_review_finalised_model, get_model_performances
3131
from plexe.tools.metrics import get_select_target_metric
3232
from plexe.tools.response_formatting import (
3333
format_final_orchestrator_agent_response,
3434
)
35+
from plexe.tools.training import register_best_training_code
3536

3637
logger = logging.getLogger(__name__)
3738

@@ -167,6 +168,8 @@ def __init__(
167168
get_review_finalised_model(self.tool_model_id),
168169
create_input_sample,
169170
get_latest_datasets,
171+
get_model_performances,
172+
register_best_training_code,
170173
format_final_orchestrator_agent_response,
171174
],
172175
managed_agents=[

plexe/agents/model_packager.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,8 @@ def __init__(
5353
name="MLOperationsEngineer",
5454
description=(
5555
"Expert ML operations engineer that analyzes training code and creates high-quality production-ready "
56-
"inference code for ML models. To work effectively, as part of the 'task' prompt the agent STRICTLY requires:"
57-
"- input schema for the model"
58-
"- output schema for the model"
59-
"- the 'training code id' of the training code produced by the MLEngineer agent"
56+
"inference code for ML models. This agent STRICTLY requires the training code of the best model to have "
57+
"been registered in the object registry."
6058
),
6159
model=LiteLLMModel(model_id=model_id),
6260
tools=[

plexe/internal/models/entities/code.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ class Code:
1010
"""Represents a code object."""
1111

1212
code: str = field()
13+
performance: float = field(default=None)

plexe/templates/prompts/agent/agent_manager_prompt.jinja

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ Ensure the output maximizes model performance while adhering to all constraints.
5959
exactly 0 is almost certainly a bugged model or a sign of overfitting, so this should be ignored.
6060
- 'MLEngineer' should only be asked to work on implementing ONE plan at a time.
6161
- 'MLOperationsEngineer' only needs to work on the final, best performing model.
62+
- Use the 'get_model_performances' tool to remind yourself of the performance of the models
63+
produced by 'MLEngineer' before deciding which one is the best.
64+
- Use the 'register_best_training_code' tool to make the training code of the best performing model available for
65+
subsequent instructions.
6266
- 'MLEngineer' and 'MLOperationsEngineer' return IDs that identify the code they produce. Use these IDs to refer to the
6367
code they produce in any subsequent instructions.
6468
- 'ModelTester' should only be called once the model has been completed and is ready for testing.

plexe/templates/prompts/agent/mlops_prompt_templates.yaml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ managed_agent:
1515
1616
## Process
1717
1. First, gather all necessary context:
18-
- Use `get_inference_context` tool with the training_code_id
18+
- Use `get_inference_context` tool to get the training code, schemas, and other relevant information.
1919
- Use `get_feature_transformer_code` to check for feature transformations, if required
2020
2121
2. Analyze the context to understand:
@@ -43,10 +43,6 @@ managed_agent:
4343
4444
6. Once validation succeeds, use the `format_final_mlops_agent_response` tool with the inference_code_id.
4545
46-
## Information Required
47-
To complete this task, you need:
48-
- The 'training_code_id' from the MLEngineer agent (must be provided in your task)
49-
5046
## Available Tools
5147
- get_inference_context: Retrieve training code, schemas, interface definitions, and other context
5248
- validate_inference_code: Validate your generated inference code

plexe/tools/context.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,22 @@ def get_inference_context_tool(llm_to_use: str) -> Callable:
2121
"""Returns a tool function to get inference context with the model ID pre-filled."""
2222

2323
@tool
24-
def get_inference_context(training_code_id: str) -> Dict[str, Any]:
24+
def get_inference_context() -> Dict[str, Any]:
2525
"""
2626
Provides comprehensive context needed for generating inference code. Use this tool to retrieve
2727
a summary of the training code, schemas, expected inputs for the purpose of planning the inference
2828
code.
2929
30-
Args:
31-
training_code_id: The ID of the code that was used to train the model
32-
3330
Returns:
3431
A dictionary containing all context needed for inference code generation
3532
"""
3633
object_registry = ObjectRegistry()
3734

3835
# Retrieve training code
3936
try:
40-
training_code = object_registry.get(Code, training_code_id).code
37+
training_code = object_registry.get(Code, "best_performing_training_code").code
4138
except Exception as e:
42-
raise ValueError(f"Training code with ID {training_code_id} not found: {str(e)}")
39+
raise ValueError(f"Training code with ID 'best_performing_training_code' not found: {str(e)}")
4340

4441
# Retrieve schemas
4542
try:

plexe/tools/evaluation.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,30 @@ def review_finalised_model(
6161
)
6262

6363
return review_finalised_model
64+
65+
66+
@tool
def get_model_performances() -> Dict[str, float]:
    """
    Returns the performance of all successfully trained models so far. The performances are returned as a dictionary
    mapping the 'model training ID' to the performance score. Use this function to remind yourself of the performance
    of all models, so that you can do things such as select the best performing model for deployment.

    Returns:
        A dictionary mapping model IDs to their performance scores with structure:
        {
            "model_training_id_1": performance_score_1,
            "model_training_id_2": performance_score_2,
        }
    """
    # NOTE: the docstring above is kept verbatim; presumably the @tool decorator
    # exposes it as the tool description shown to the agent.
    from plexe.core.object_registry import ObjectRegistry

    registry = ObjectRegistry()

    # Fetch every registered Code object once, keyed by its registry ID.
    registered_code = {code_id: registry.get(Code, code_id) for code_id in registry.list_by_type(Code)}

    # Only entries that carry a recorded performance value are reported.
    return {
        code_id: entry.performance
        for code_id, entry in registered_code.items()
        if entry.performance is not None
    }

plexe/tools/execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def execute_training_code(
154154
artifact_paths = node.model_artifacts if node.model_artifacts else []
155155
artifacts = [Artifact.from_path(p) for p in artifact_paths]
156156
object_registry.register_multiple(Artifact, {a.name: a for a in artifacts})
157-
object_registry.register(Code, execution_id, Code(node.training_code))
157+
object_registry.register(Code, execution_id, Code(node.training_code, node.performance.value))
158158

159159
# Return results
160160
return {

plexe/tools/training.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,42 @@
88

99
from smolagents import tool
1010

11+
from plexe.core.object_registry import ObjectRegistry
1112
from plexe.internal.common.provider import Provider
13+
from plexe.internal.models.entities.code import Code
1214
from plexe.internal.models.generation.training import TrainingCodeGenerator
1315

1416
logger = logging.getLogger(__name__)
1517

1618

19+
@tool
def register_best_training_code(best_training_code_id: str) -> str:
    """
    Register the identifier returned by the MLEngineer for the solution with the best performance in the object
    registry. This step is required in order for the training code to be available for future use.

    Args:
        best_training_code_id: 'training_code_id' of the best performing model

    Returns:
        Success message confirming registration
    """
    # Any failure below (unknown ID, registry error) is logged and re-raised as a
    # RuntimeError chained to the original exception.
    object_registry = ObjectRegistry()

    try:
        # Re-register the selected training code under a fixed, well-known ID so it
        # can be looked up later without passing the original ID around.
        code_id = "best_performing_training_code"
        code = object_registry.get(Code, best_training_code_id).code
        # NOTE(review): the copy is built from the code text only, so it carries no
        # performance value — confirm this is intended.
        object_registry.register(Code, code_id, Code(code), overwrite=True, immutable=True)

        logger.debug(f"✅ Registered model training code with ID '{code_id}'")
        return f"Successfully registered model training code with ID '{code_id}' for the best performing model."

    except Exception as e:
        logger.warning(f"⚠️ Error registering training code: {str(e)}")
        # Chain the cause so the original traceback is preserved for debugging.
        raise RuntimeError(f"Failed to register training code: {str(e)}") from e
45+
46+
1747
def get_training_code_generation_tool(llm_to_use: str) -> Callable:
1848
"""Returns a tool function to generate training code with the model ID pre-filled."""
1949

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "plexe"
3-
version = "0.23.3"
3+
version = "0.23.4"
44
description = "An agentic framework for building ML models from natural language"
55
authors = [
66
"marcellodebernardi <[email protected]>",

0 commit comments

Comments
 (0)