Skip to content

Commit cd7a0b9

Browse files
fix: feature engineering workflow improvements (#125)
* chore: bump to 0.23.2 * feat: add some observability methods to registry * fix: add fileio tests to ensure backwards comp * feat: add 0.23.2 legacy model loading test * refactor: tools for object registry entities * fix: convert tags that can't be written with safe_dump * feat: add dataset listing tool to manager * fix: simplify schema resolver tool use * fix: schema resolution before proper dataset processing
1 parent c73d2d1 commit cd7a0b9

28 files changed

+590
-123
lines changed

plexe/agents/agents.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525
from plexe.internal.models.entities.metric import Metric
2626
from plexe.internal.models.entities.metric import MetricComparator, ComparisonMethod
2727
from plexe.core.interfaces.predictor import Predictor
28-
from plexe.tools.datasets import (
29-
create_input_sample,
30-
)
28+
from plexe.tools.datasets import create_input_sample, get_latest_datasets
3129
from plexe.tools.evaluation import get_review_finalised_model
3230
from plexe.tools.metrics import get_select_target_metric
3331
from plexe.tools.response_formatting import (
@@ -167,6 +165,7 @@ def __init__(
167165
get_select_target_metric(self.tool_model_id),
168166
get_review_finalised_model(self.tool_model_id),
169167
create_input_sample,
168+
get_latest_datasets,
170169
format_final_orchestrator_agent_response,
171170
],
172171
managed_agents=[
@@ -183,7 +182,9 @@ def __init__(
183182
verbosity_level=self.orchestrator_verbosity,
184183
additional_authorized_imports=config.code_generation.authorized_agent_imports,
185184
max_steps=self.max_steps,
186-
prompt_templates=get_prompt_templates("code_agent.yaml", "manager_prompt_templates.yaml"),
185+
prompt_templates=get_prompt_templates(
186+
base_template_name="code_agent.yaml", override_template_name="manager_prompt_templates.yaml"
187+
),
187188
planning_interval=7,
188189
step_callbacks=[self.chain_of_thought_callable],
189190
)

plexe/agents/dataset_analyser.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from plexe.config import config, prompt_templates
1414
from plexe.internal.common.utils.agents import get_prompt_templates
15-
from plexe.tools.datasets import register_eda_report, drop_null_columns
15+
from plexe.tools.datasets import register_eda_report, drop_null_columns, get_latest_datasets
1616
from plexe.tools.schemas import get_raw_dataset_schema
1717

1818
logger = logging.getLogger(__name__)
@@ -55,10 +55,10 @@ def __init__(
5555
"and recommendations for ML modeling. Will analyse existing datasets, not create new ones.\n"
5656
"To work effectively, as part of the 'task' prompt the agent STRICTLY requires:\n"
5757
"- the ML task definition (i.e. 'intent')\n"
58-
"- the name of the dataset to use for training"
58+
"- the name of the dataset to be analysed"
5959
),
6060
model=LiteLLMModel(model_id=self.model_id),
61-
tools=[drop_null_columns, register_eda_report, get_raw_dataset_schema],
61+
tools=[drop_null_columns, register_eda_report, get_raw_dataset_schema, get_latest_datasets],
6262
add_base_tools=False,
6363
verbosity_level=self.verbosity,
6464
# planning_interval=3,

plexe/agents/dataset_splitter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313

1414
from plexe.config import config
1515
from plexe.internal.common.utils.agents import get_prompt_templates
16-
from plexe.tools.datasets import get_dataset_preview, get_eda_report
17-
from plexe.tools.datasets import register_split_datasets
16+
from plexe.tools.datasets import get_dataset_preview, register_split_datasets, get_latest_datasets, get_eda_reports
1817

1918
logger = logging.getLogger(__name__)
2019

@@ -59,8 +58,9 @@ def __init__(
5958
model=LiteLLMModel(model_id=model_id),
6059
tools=[
6160
get_dataset_preview,
62-
get_eda_report,
6361
register_split_datasets,
62+
get_latest_datasets,
63+
get_eda_reports,
6464
],
6565
planning_interval=5,
6666
add_base_tools=False,

plexe/agents/feature_engineer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212

1313
from plexe.config import config
1414
from plexe.internal.common.utils.agents import get_prompt_templates
15-
from plexe.tools.datasets import get_dataset_preview, get_eda_report
15+
from plexe.tools.datasets import get_dataset_preview, get_eda_reports, get_latest_datasets
1616
from plexe.tools.execution import apply_feature_transformer
1717
from plexe.tools.validation import validate_feature_transformations
18+
from plexe.tools.schemas import get_model_schemas
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -58,9 +59,11 @@ def __init__(
5859
model=LiteLLMModel(model_id=model_id),
5960
tools=[
6061
get_dataset_preview,
61-
get_eda_report,
6262
validate_feature_transformations,
6363
apply_feature_transformer,
64+
get_latest_datasets,
65+
get_eda_reports,
66+
get_model_schemas,
6467
],
6568
add_base_tools=False,
6669
additional_authorized_imports=config.code_generation.authorized_agent_imports

plexe/agents/model_packager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from plexe.tools.context import get_inference_context_tool
1616
from plexe.tools.response_formatting import format_final_mlops_agent_response
1717
from plexe.tools.validation import validate_inference_code
18+
from plexe.tools.code_analysis import get_feature_transformer_code
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -62,6 +63,7 @@ def __init__(
6263
get_inference_context_tool(tool_model_id),
6364
validate_inference_code,
6465
format_final_mlops_agent_response,
66+
get_feature_transformer_code,
6567
],
6668
add_base_tools=False,
6769
verbosity_level=self.verbosity,

plexe/agents/model_planner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from smolagents import ToolCallingAgent, LiteLLMModel
44

55
from plexe.internal.common.utils.agents import get_prompt_templates
6-
from plexe.tools.datasets import get_dataset_preview, get_eda_report
6+
from plexe.tools.datasets import get_dataset_preview, get_latest_datasets, get_eda_reports
7+
from plexe.tools.schemas import get_model_schemas
78

89
logger = logging.getLogger(__name__)
910

@@ -52,7 +53,12 @@ def __init__(
5253
"- the name of the dataset to use for training"
5354
),
5455
model=LiteLLMModel(model_id=model_id),
55-
tools=[get_dataset_preview, get_eda_report],
56+
tools=[
57+
get_dataset_preview,
58+
get_latest_datasets,
59+
get_eda_reports,
60+
get_model_schemas,
61+
],
5662
add_base_tools=False,
5763
verbosity_level=self.verbosity,
5864
prompt_templates=get_prompt_templates("toolcalling_agent.yaml", "mls_prompt_templates.yaml"),

plexe/agents/model_tester.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
from plexe.config import config
1414
from plexe.internal.common.utils.agents import get_prompt_templates
1515
from plexe.tools.testing import register_testing_code, register_evaluation_report
16+
from plexe.tools.datasets import get_test_dataset
17+
from plexe.tools.schemas import get_model_schemas
18+
from plexe.tools.code_analysis import get_feature_transformer_code
1619

1720
logger = logging.getLogger(__name__)
1821

@@ -62,6 +65,9 @@ def __init__(
6265
tools=[
6366
register_testing_code,
6467
register_evaluation_report,
68+
get_test_dataset,
69+
get_model_schemas,
70+
get_feature_transformer_code,
6571
],
6672
add_base_tools=False,
6773
verbosity_level=self.verbosity,

plexe/agents/model_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
from plexe.internal.common.utils.agents import get_prompt_templates
1313
from plexe.tools.execution import get_executor_tool
1414
from plexe.tools.response_formatting import format_final_mle_agent_response
15-
from plexe.tools.schemas import get_raw_dataset_schema
15+
from plexe.tools.schemas import get_raw_dataset_schema, get_model_schemas
1616
from plexe.tools.training import get_training_code_generation_tool, get_training_code_fixing_tool
1717
from plexe.tools.validation import validate_training_code
18+
from plexe.tools.datasets import get_training_datasets
19+
from plexe.tools.code_analysis import get_feature_transformer_code
1820

1921
logger = logging.getLogger(__name__)
2022

@@ -60,6 +62,9 @@ def __init__(
6062
get_training_code_fixing_tool(tool_model_id),
6163
get_executor_tool(distributed),
6264
format_final_mle_agent_response,
65+
get_training_datasets,
66+
get_model_schemas,
67+
get_feature_transformer_code,
6368
],
6469
add_base_tools=False,
6570
additional_authorized_imports=[

plexe/agents/schema_resolver.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from smolagents import LiteLLMModel, CodeAgent
1313

1414
from plexe.internal.common.utils.agents import get_prompt_templates
15-
from plexe.tools.datasets import get_dataset_preview, get_eda_report
16-
from plexe.tools.schemas import register_final_model_schemas
15+
from plexe.tools.datasets import get_dataset_preview, get_eda_reports, get_latest_datasets
16+
from plexe.tools.schemas import register_final_model_schemas, get_model_schemas
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -53,12 +53,17 @@ def __init__(
5353
"Expert schema resolver that determines appropriate input and output schemas for ML models. "
5454
"To work effectively, as part of the 'task' prompt the agent STRICTLY requires:\n"
5555
"- the ML task definition (i.e. 'intent')\n"
56-
"- the NAME of the dataset to be used for training\n"
57-
"- input schema already defined for this task, if available\n"
58-
"- output schema already defined for this task, if available\n"
56+
"- the name of the feature-engineered dataset that will be used for training"
57+
"Important: the agent requires the feature-engineered dataset to have been created"
5958
),
6059
model=LiteLLMModel(model_id=self.model_id),
61-
tools=[get_dataset_preview, get_eda_report, register_final_model_schemas],
60+
tools=[
61+
get_dataset_preview,
62+
get_model_schemas,
63+
register_final_model_schemas,
64+
get_latest_datasets,
65+
get_eda_reports,
66+
],
6267
add_base_tools=False,
6368
verbosity_level=self.verbosity,
6469
step_callbacks=[chain_of_thought_callable],

plexe/core/object_registry.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,11 @@ class ObjectRegistry:
2828
"""
2929

3030
_instance = None
31-
_items: Dict[str, Item] = dict()
3231

3332
def __new__(cls):
3433
if cls._instance is None:
3534
cls._instance = super(ObjectRegistry, cls).__new__(cls)
36-
cls._items = dict()
35+
cls._instance._items = {}
3736
return cls._instance
3837

3938
@staticmethod
@@ -51,10 +50,16 @@ def register(self, t: Type[T], name: str, item: T, overwrite: bool = False, immu
5150
:param immutable: whether the item should be treated as immutable (not modifiable)
5251
"""
5352
uri = self._get_uri(t, name)
53+
was_overwrite = overwrite and uri in self._items
54+
5455
if not overwrite and uri in self._items:
5556
raise ValueError(f"Item '{uri}' already registered, use a different name")
57+
5658
self._items[uri] = Item(item, immutable=immutable)
57-
logger.info(f"Registered item '{uri}'")
59+
60+
# Enhanced logging with context
61+
action = "overwrote" if was_overwrite else "registered"
62+
logger.debug(f"Registry: {action} {uri} (immutable={immutable}, total: {len(self._items)} items)")
5863

5964
def register_multiple(
6065
self, t: Type[T], items: Dict[str, T], overwrite: bool = False, immutable: bool = False
@@ -131,6 +136,17 @@ def list(self) -> List[str]:
131136
"""
132137
return list(self._items.keys())
133138

139+
def list_by_type(self, t: Type[T]) -> List[str]:
140+
"""
141+
List all registered names for a specific type.
142+
143+
:param t: type prefix for the items
144+
:return: List of item names (without the type prefix) for the given type
145+
"""
146+
prefix = str(t)
147+
return [uri.split("://")[1] for uri in self._items.keys() if uri.startswith(prefix)]
148+
149+
# TODO: unclear if this is needed, consider deleting
134150
def get_all_solutions(self) -> List[Dict[str, Any]]:
135151
"""
136152
Get all solutions tracked during model building.

0 commit comments

Comments
 (0)