Skip to content

Commit fad9f27

Browse files
feature: nicer chat ui (#126)
* chore: bump to 0.23.2 * feat: add some observability methods to registry * fix: add fileio tests to ensure backwards comp * feat: add 0.23.2 legacy model loading test * refactor: tools for object registry entities * fix: convert tags that can't be written with safe_dump * feat: add dataset listing tool to manager * fix: simplify schema resolver tool use * fix: schema resolution before proper dataset processing * feat: add ui for local testing of chat * fix: encourage schema resolver to be more 'minimal' * fix: rename get_raw_dataset_schema for clarity * fix: memory flushed between messages * chore: bump to 0.23.3 * fix: transformed dataset not marked immutable * fix: schema resolver adds unnecessary fields * chore: poetry lock
1 parent cd7a0b9 commit fad9f27

File tree

13 files changed

+1115
-583
lines changed

13 files changed

+1115
-583
lines changed

plexe/agents/agents.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
This module defines a multi-agent ML engineering system for building machine learning models.
33
"""
44

5+
import json
56
import logging
67
import types
78
from dataclasses import dataclass, field
89
from typing import List, Dict, Optional, Callable
910

10-
from smolagents import CodeAgent, LiteLLMModel
11+
from smolagents import CodeAgent, LiteLLMModel, AgentText
1112

1213
from plexe.agents.dataset_analyser import EdaAgent
1314
from plexe.agents.dataset_splitter import DatasetSplitterAgent
@@ -206,6 +207,9 @@ def run(self, task, additional_args: dict) -> ModelGenerationResult:
206207
if logger.isEnabledFor(logging.DEBUG):
207208
logger.debug("Agent result: %s", result)
208209

210+
if isinstance(result, AgentText):
211+
result = json.loads(str(result))
212+
209213
# Extract data from the agent result
210214
training_code_id = result.get("training_code_id", "")
211215
inference_code_id = result.get("inference_code_id", "")

plexe/agents/dataset_analyser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from plexe.config import config, prompt_templates
1414
from plexe.internal.common.utils.agents import get_prompt_templates
1515
from plexe.tools.datasets import register_eda_report, drop_null_columns, get_latest_datasets
16-
from plexe.tools.schemas import get_raw_dataset_schema
16+
from plexe.tools.schemas import get_dataset_schema
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -58,7 +58,7 @@ def __init__(
5858
"- the name of the dataset to be analysed"
5959
),
6060
model=LiteLLMModel(model_id=self.model_id),
61-
tools=[drop_null_columns, register_eda_report, get_raw_dataset_schema, get_latest_datasets],
61+
tools=[drop_null_columns, register_eda_report, get_dataset_schema, get_latest_datasets],
6262
add_base_tools=False,
6363
verbosity_level=self.verbosity,
6464
# planning_interval=3,

plexe/agents/model_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from plexe.internal.common.utils.agents import get_prompt_templates
1313
from plexe.tools.execution import get_executor_tool
1414
from plexe.tools.response_formatting import format_final_mle_agent_response
15-
from plexe.tools.schemas import get_raw_dataset_schema, get_model_schemas
15+
from plexe.tools.schemas import get_dataset_schema, get_model_schemas
1616
from plexe.tools.training import get_training_code_generation_tool, get_training_code_fixing_tool
1717
from plexe.tools.validation import validate_training_code
1818
from plexe.tools.datasets import get_training_datasets
@@ -58,7 +58,7 @@ def __init__(
5858
tools=[
5959
get_training_code_generation_tool(tool_model_id),
6060
validate_training_code,
61-
get_raw_dataset_schema,
61+
get_dataset_schema,
6262
get_training_code_fixing_tool(tool_model_id),
6363
get_executor_tool(distributed),
6464
format_final_mle_agent_response,

plexe/main.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,46 @@
22
Application entry point for using the plexe package as a conversational agent.
33
"""
44

5-
from smolagents import GradioUI
6-
from plexe.agents.conversational import ConversationalAgent
5+
import threading
6+
import time
7+
import webbrowser
8+
import logging
9+
10+
import uvicorn
11+
12+
logging.basicConfig(level=logging.INFO)
13+
logger = logging.getLogger(__name__)
714

815

916
def main():
10-
ui = GradioUI(ConversationalAgent().agent)
11-
ui.launch()
17+
"""Launch the Plexe assistant with a web UI."""
18+
host = "127.0.0.1"
19+
port = 8000
20+
21+
# Configure uvicorn to run in a thread
22+
config = uvicorn.Config("plexe.server:app", host=host, port=port, log_level="info", reload=False)
23+
server = uvicorn.Server(config)
24+
25+
# Start server in a background thread
26+
thread = threading.Thread(target=server.run, daemon=True)
27+
thread.start()
28+
29+
# Give the server a moment to start
30+
time.sleep(4)
31+
32+
# Open the browser
33+
url = f"http://{host}:{port}"
34+
logger.info(f"Opening browser at {url}")
35+
webbrowser.open(url)
36+
37+
# Keep the main thread alive
38+
try:
39+
logger.info("Plexe Assistant is running. Press Ctrl+C to stop.")
40+
while True:
41+
time.sleep(1)
42+
except KeyboardInterrupt:
43+
logger.info("\nShutting down Plexe Assistant...")
44+
server.should_exit = True
1245

1346

1447
if __name__ == "__main__":

plexe/server.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""
2+
FastAPI server for the Plexe conversational agent.
3+
4+
This module provides a lightweight WebSocket API for the conversational agent
5+
and serves the assistant-ui frontend for local execution.
6+
"""
7+
8+
import json
9+
import logging
10+
import uuid
11+
from pathlib import Path
12+
13+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
14+
from fastapi.staticfiles import StaticFiles
15+
from fastapi.responses import FileResponse
16+
17+
from plexe.agents.conversational import ConversationalAgent
18+
19+
logger = logging.getLogger(__name__)
20+
21+
app = FastAPI(title="Plexe Assistant", version="1.0.0")
22+
23+
# Serve static files from the ui directory
24+
ui_dir = Path(__file__).parent / "ui"
25+
if ui_dir.exists():
26+
app.mount("/static", StaticFiles(directory=str(ui_dir)), name="static")
27+
28+
29+
@app.get("/")
30+
async def root():
31+
"""Serve the main HTML page."""
32+
index_path = ui_dir / "index.html"
33+
if index_path.exists():
34+
return FileResponse(str(index_path))
35+
return {"error": "Frontend not found. Please ensure plexe/ui/index.html exists."}
36+
37+
38+
@app.websocket("/ws")
39+
async def websocket_endpoint(websocket: WebSocket):
40+
"""WebSocket endpoint for real-time chat communication."""
41+
await websocket.accept()
42+
session_id = str(uuid.uuid4())
43+
logger.info(f"New WebSocket connection: {session_id}")
44+
45+
# Create a new agent instance for this session
46+
agent = ConversationalAgent()
47+
48+
try:
49+
while True:
50+
# Receive message from client
51+
data = await websocket.receive_text()
52+
53+
try:
54+
message_data = json.loads(data)
55+
user_message = message_data.get("content", "")
56+
57+
# Process the message with the agent
58+
logger.debug(f"Processing message: {user_message[:100]}...")
59+
response = agent.agent.run(user_message, reset=False)
60+
61+
# Send response back to client
62+
await websocket.send_json({"role": "assistant", "content": response, "id": str(uuid.uuid4())})
63+
64+
except json.JSONDecodeError:
65+
# Handle plain text messages for compatibility
66+
response = agent.agent.run(data, reset=False)
67+
await websocket.send_json({"role": "assistant", "content": response, "id": str(uuid.uuid4())})
68+
69+
except Exception as e:
70+
logger.error(f"Error processing message: {e}")
71+
await websocket.send_json(
72+
{
73+
"role": "assistant",
74+
"content": f"I encountered an error: {str(e)}. Please try again.",
75+
"id": str(uuid.uuid4()),
76+
"error": True,
77+
}
78+
)
79+
80+
except WebSocketDisconnect:
81+
logger.info(f"WebSocket disconnected: {session_id}")
82+
except Exception as e:
83+
logger.error(f"WebSocket error for session {session_id}: {e}")
84+
await websocket.close()
85+
86+
87+
@app.get("/health")
88+
async def health_check():
89+
"""Health check endpoint."""
90+
return {"status": "healthy", "service": "plexe-assistant"}

plexe/templates/prompts/agent/eda_prompt_templates.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ managed_agent:
1414
## Available Tools:
1515
- `get_latest_datasets`: Returns all available datasets with their roles (raw, transformed, train, val, test)
1616
- `drop_null_columns`: Clean datasets by removing problematic columns
17-
- `get_raw_dataset_schema`: Get column names and types for a dataset
17+
- `get_dataset_schema`: Get column names and types for a dataset
1818
- `register_eda_report`: Store your analysis findings
1919
2020
To access datasets, USE EXACTLY THIS PATTERN:

plexe/templates/prompts/agent/mle_prompt_templates.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ managed_agent:
1010
## Available Tools:
1111
- `get_training_datasets`: Get training and validation dataset names automatically
1212
- `get_model_schemas`: Get input/output schemas directly
13-
- `get_raw_dataset_schema`: Understand dataset structure
13+
- `get_dataset_schema`: Understand dataset structure
1414
- `get_feature_transformer_code`: Retrieve feature transformation code (if exists) if you need to review it
1515
- `generate_training_code`: Generate ML training code
1616
- `validate_training_code`: Validate generated code

plexe/templates/prompts/agent/schema_resolver_prompt_templates.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@ managed_agent:
3131
5. Call register_final_model_schemas with your determined schemas and reasoning
3232
3333
## Key requirements:
34-
1. IMPORTANT: keep schemas aligned with dataset structure unless the task clearly requires otherwise
34+
1. IMPORTANT: keep schemas conceptually aligned with dataset structure
3535
2. Use only Python types: "int", "float", "str", "bool"
3636
3. DO NOT add new input or output fields unless absolutely necessary for the task
3737
4. DO NOT add features that can be straightforwardly derived from existing data
3838
5. Schemas should include only necessary fields for the model's purpose
3939
6. You can REMOVE fields that are unnecessary, irrelevant, redundant, or contain bad data; this is highly encouraged
40-
6. Include reasoning for any deviations from the dataset structure
40+
7. Include reasoning for any deviations from the dataset structure
41+
8. Ensure the schemas are 'minimal' and 'sufficient': for example, if A is a categorical variable and X, Y, Z are
42+
boolean indicators of the possible values of A, you can remove X, Y, Z from the input schema and keep only A.
4143
4244
When calling register_final_model_schemas, use this format:
4345
- input_schema: dictionary mapping field names to types

plexe/tools/execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def apply_feature_transformer(dataset_name: str) -> Dict:
283283
# Register transformed dataset
284284
transformed_name = f"{dataset_name}_transformed"
285285
transformed_ds = DatasetAdapter.coerce(transformed_df)
286-
object_registry.register(TabularConvertible, transformed_name, transformed_ds, overwrite=True)
286+
object_registry.register(TabularConvertible, transformed_name, transformed_ds, overwrite=True, immutable=True)
287287

288288
logger.debug(f"✅ Applied feature transformer to {dataset_name}{transformed_name}")
289289

plexe/tools/schemas.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,10 @@ def register_final_model_schemas(
7474

7575

7676
@tool
77-
def get_raw_dataset_schema(dataset_name: str) -> Dict[str, Any]:
77+
def get_dataset_schema(dataset_name: str) -> Dict[str, Any]:
7878
"""
79-
Extract the schema (column names and types) from a raw dataset.
79+
Extract the schema (column names and types) from a dataset. This is useful for understanding the structure
80+
of the dataset and how it can be used in model training.
8081
8182
Args:
8283
dataset_name: Name of the dataset in the registry

0 commit comments

Comments
 (0)