Commits
38 commits
5ea3b73
fix: trainer_source.py not cleaned up
marcellodebernardi May 14, 2025
4b2a668
fix: handle dataset splitting for chronological data
marcellodebernardi May 14, 2025
819aa39
fix: switch to codeagent for schema resolver
marcellodebernardi May 15, 2025
0e8979a
feat: add data analyser agent
marcellodebernardi May 15, 2025
1113e69
feat: add data analyser agent
marcellodebernardi May 15, 2025
0d8df15
fix: put eda report as dict in metadata
marcellodebernardi May 15, 2025
df517a3
feat: update multi-agent-system.md
marcellodebernardi May 15, 2025
05b8684
chore: bump to 0.20.0
marcellodebernardi May 15, 2025
684cda3
fix: misc improvements to dataset analyser
marcellodebernardi May 15, 2025
f2023ba
fix: eda agent using wrong prompt template
marcellodebernardi May 17, 2025
b20dbbe
chore: remove unused prompt template
marcellodebernardi May 17, 2025
cc22569
chore: remove unused plan generation template
marcellodebernardi May 17, 2025
b6eef55
fix: emitter agent colors defined incorrectly
marcellodebernardi May 17, 2025
b94c721
feat: make chain of thought summaries follow t/a/o structure
marcellodebernardi May 17, 2025
35ba360
feat: remove combined data generator in favour of simple
marcellodebernardi May 18, 2025
fa2fa90
Merge branch 'refs/heads/main' into fix/dataset-generator-cleanup
marcellodebernardi May 18, 2025
22a1a7d
fix: strip split suffix from eda report name
marcellodebernardi May 18, 2025
ad068c5
fix: give dataset analyser all required imports
marcellodebernardi May 18, 2025
42b1dda
feat: enable mlflow tracing
marcellodebernardi May 18, 2025
6c5ff99
chore: update vulnerable dependencies
marcellodebernardi May 18, 2025
d385c2a
fix: allow scipy.* import for dataset analyser
marcellodebernardi May 18, 2025
0f936b5
fix: split_datasets to return dataset sizes
marcellodebernardi May 18, 2025
eb80f54
chore: remove smote oversampling
marcellodebernardi May 18, 2025
4889b85
chore: clean up dataset generator config
marcellodebernardi May 18, 2025
cf8ecca
refactor: clean up datasets.py
marcellodebernardi May 18, 2025
5b786ff
refactor: clean up data generation async logic
marcellodebernardi May 18, 2025
c011aa1
refactor: clean up data generation async logic
marcellodebernardi May 18, 2025
3f0af16
feat: add dataset generation example
marcellodebernardi May 18, 2025
8559bce
chore: fix up base data generator interface
marcellodebernardi May 18, 2025
b07786e
fix: column addition not working plus noisy logging
marcellodebernardi May 18, 2025
831989c
feat: add dataset augmentation example
marcellodebernardi May 18, 2025
2c75417
chore: bump to 0.21.0
marcellodebernardi May 18, 2025
908e20a
feat: add eda report to model bundle
marcellodebernardi May 18, 2025
f42dcad
fix: re-enable i/o schema logging
marcellodebernardi May 18, 2025
a020be7
fix: add pandas.* to dataset analyser imports
marcellodebernardi May 18, 2025
09b376a
fix: give better instructions to schema resolver
marcellodebernardi May 18, 2025
8bfb1a1
fix: remove silly naming from system prompts
marcellodebernardi May 18, 2025
045fefc
refactor: make schema resolver prompt more concise
marcellodebernardi May 18, 2025
70 changes: 70 additions & 0 deletions examples/dataset_augmentation.py
@@ -0,0 +1,70 @@
"""
Example demonstrating dataset augmentation with Plexe:
1. Adding a new column to an existing dataset
2. Adding more rows to an existing dataset
"""

from pydantic import BaseModel, Field

from plexe import DatasetGenerator


class PurchaseSchema(BaseModel):
"""Base schema for purchase data."""

product_name: str = Field(description="Name of the purchased product")
category: str = Field(description="Product category")
price: float = Field(description="Purchase price in USD")
customer_id: str = Field(description="Unique customer identifier")


class AugmentedSchema(PurchaseSchema):
"""Augmented schema with product recommendation field."""

recommendation: str = Field(description="Recommended related product")


def main():
# Step 1: Create base dataset (10 purchase records)
base_dataset = DatasetGenerator(
description="E-commerce purchase data with product and customer information",
provider="openai/gpt-4o",
schema=PurchaseSchema,
)
base_dataset.generate(10)
df_base = base_dataset.data

print("Original dataset (10 records):")
print(df_base.head(3))
print(f"Shape: {df_base.shape}")

# Check if we have data before proceeding
if len(df_base) == 0:
print("Failed to generate base dataset. Exiting.")
return

# Step 2: Add a new column by extending the schema
augmented_dataset = DatasetGenerator(
description="E-commerce purchase data with product recommendations",
provider="openai/gpt-4o",
schema=AugmentedSchema,
data=df_base,
)
augmented_dataset.generate(0) # 0 means just transform existing data
df_column_added = augmented_dataset.data

print("\nDataset with new 'recommendation' column:")
print(df_column_added.head(3))
print(f"Shape: {df_column_added.shape}")

# Step 3: Add more rows to the augmented dataset
augmented_dataset.generate(5) # Add 5 more records
df_rows_added = augmented_dataset.data

print("\nFinal dataset with 5 additional records:")
print(f"Shape: {df_rows_added.shape}")
print(df_rows_added.tail(3))


if __name__ == "__main__":
main()
64 changes: 64 additions & 0 deletions examples/dataset_generation.py
@@ -0,0 +1,64 @@
"""
Example script demonstrating synthetic data generation with Plexe.

This script creates a synthetic restaurant review dataset that could be used for
sentiment analysis or restaurant recommendation systems.
"""

from typing import Literal

from pydantic import BaseModel, Field

from plexe import DatasetGenerator


class RestaurantReviewSchema(BaseModel):
"""Schema definition for restaurant reviews dataset."""

restaurant_name: str = Field(description="Name of the restaurant")
cuisine_type: str = Field(description="Type of cuisine (Italian, Chinese, Mexican, etc.)")
price_range: Literal["$", "$$", "$$$", "$$$$"] = Field(
description="Price category from $ (cheap) to $$$$ (very expensive)"
)
location: str = Field(description="City or neighborhood where the restaurant is located")
rating: float = Field(description="Overall customer rating from 1.0 to 5.0")
service_rating: int = Field(description="Rating for service quality from 1 to 5")
food_rating: int = Field(description="Rating for food quality from 1 to 5")


def main():
# Create dataset generator
print("Creating synthetic restaurant reviews dataset...")
dataset = DatasetGenerator(
description=(
"Restaurant reviews dataset for sentiment analysis and recommendation systems. "
"Each record represents a customer review of a restaurant, including a rating."
),
provider="openai/gpt-4o", # Use your preferred provider
schema=RestaurantReviewSchema,
)

# Generate 20 synthetic records
print("Generating 20 synthetic reviews...")
dataset.generate(20)

# Convert to pandas DataFrame for analysis
df = dataset.data

# Display statistics and sample data
print(f"\nGenerated {len(df)} restaurant reviews")
print("\nSample reviews:")
print(df.head(5))

# Only try to display samples if we have data
if len(df) > 0:
for i, row in df.sample(min(5, len(df))).iterrows():
print(f"\n{'-' * 70}")
print(f"{row['restaurant_name']} - {row['cuisine_type']} ({row['price_range']}) - {row['location']}")
print(
f"Overall: {row['rating']:.1f}/5.0 | Service: {row['service_rating']}/5 | Food: {row['food_rating']}/5"
)


if __name__ == "__main__":
main()
11 changes: 10 additions & 1 deletion plexe/agents/dataset_analyser.py
@@ -61,7 +61,16 @@ def __init__(
            # planning_interval=3,
            max_steps=30,
            step_callbacks=[chain_of_thought_callable],
            additional_authorized_imports=["pandas", "numpy", "plexe"],
            additional_authorized_imports=[
                "pandas",
                "pandas.*",
                "numpy",
                "numpy.*",
                "plexe",
                "plexe.*",
                "scipy",
                "scipy.*",
            ],
            prompt_templates=get_prompt_templates("code_agent.yaml", "eda_prompt_templates.yaml"),
        )

131 changes: 92 additions & 39 deletions plexe/datasets.py
@@ -11,10 +11,13 @@
Users can either pass raw datasets directly to models or leverage this class for dataset management and augmentation.
"""

from typing import Iterator, Type, Dict
from typing import Iterator, Type, Dict, Optional
import logging
import pandas as pd
from pydantic import BaseModel

logger = logging.getLogger(__name__)

from plexe.internal.common.datasets.interface import TabularConvertible
from plexe.internal.common.provider import Provider
from plexe.internal.common.datasets.adapter import DatasetAdapter
@@ -31,14 +34,15 @@ class DatasetGenerator:
- Wrap real datasets (pandas etc.).
- Generate synthetic data from scratch.
- Augment existing datasets with synthetic samples.
- Add new columns to existing datasets using an extended schema.

Example:
>>> synthetic_dataset = DatasetGenerator(
>>> description="Synthetic reviews",
>>> provider="openai/gpt-4",
>>> provider="openai/gpt-4o",
>>> schema=MovieReviewSchema,
>>> num_samples=100
>>> )
>>> synthetic_dataset.generate(100) # Generate 100 samples
>>> model.build(datasets={"train": synthetic_dataset})
"""

Expand All @@ -50,6 +54,8 @@ def __init__(
data: pd.DataFrame = None,
) -> None:
"""
Initialize a new DatasetGenerator.

:param description: A human-readable description of the dataset
:param provider: LLM provider used for synthetic data generation
:param schema: The schema the data should match, if any
@@ -60,79 +66,126 @@ def __init__(
self.provider = Provider(provider)

# Internal attributes for data management
self._data: pd.DataFrame = data
self._data: Optional[pd.DataFrame] = None
self._index = 0
self.schema = None

# TODO: simplify this logic and use DatasetAdapter to support more dataset types
if schema is not None and data is not None:
# Process schema and data inputs
if schema is not None:
# Convert schema to Pydantic BaseModel if it's a dictionary
self.schema = map_to_basemodel("data", schema)
self._validate_schema(data)
data_wrapper = DatasetAdapter.coerce(data)
if isinstance(data_wrapper, TabularConvertible):
self._data = data_wrapper.to_pandas()
else:
raise ValueError("Dataset must be convertible to pandas DataFrame.")
elif data is not None:

if data is not None:
# Convert and validate input data
data_wrapper = DatasetAdapter.coerce(data)
if isinstance(data_wrapper, TabularConvertible):
self._data = data_wrapper.to_pandas()
else:
raise ValueError("Dataset must be convertible to pandas DataFrame.")

schemas = SchemaResolver(self.provider, self.description).resolve({"data": self._data})
self.schema = merge_models("data", list(schemas))
elif schema is not None:
self.schema = map_to_basemodel("data", schema)
# If schema is provided, validate data against schema
# but only validate existing columns, not new ones being added
if schema is not None:
self._validate_schema(self._data, allow_new_columns=True)
# If no schema provided, infer it from data
else:
schemas = SchemaResolver(self.provider, self.description).resolve({"data": self._data})
self.schema = merge_models("data", list(schemas))

# Initialize data generator
self.data_generator = DataGenerator(self.provider, self.description, self.schema)

def generate(self, num_samples: int):
"""Generates synthetic data if a provider is available."""
self._data = pd.concat([self._data, self.data_generator.generate(num_samples, self._data)], ignore_index=True)
"""
Generate synthetic data samples or augment existing data.

If num_samples is 0 and existing data is provided with a new schema,
this will transform the existing data to match the new schema (adding columns).

:param num_samples: Number of new samples to generate
"""
generated_data = self.data_generator.generate(num_samples, self._data)

if self._data is None:
self._data = generated_data
elif num_samples == 0:
# When num_samples is 0, we're just adding columns to existing data
# SimpleLLMDataGenerator.generate already handles this correctly by returning
# the existing data with new columns added, so we just replace _data directly
self._data = generated_data
else:
# When adding new rows, concatenate them with existing data
self._data = pd.concat([self._data, generated_data], ignore_index=True)

def _validate_schema(self, data: pd.DataFrame):
"""Ensures data matches the schema."""
def _validate_schema(self, data: pd.DataFrame, allow_new_columns: bool = False):
"""
Ensure data matches the schema by checking column presence.

:param data: DataFrame to validate against the schema
:param allow_new_columns: If True, allow schema to have columns that don't exist in data yet
:raises ValueError: If required columns from schema are missing and not allowed
"""
for key in self.schema.model_fields.keys():
if key not in data.columns:
raise ValueError(f"Dataset does not match schema, missing column in dataset: {key}")
if not allow_new_columns:
raise ValueError(f"Dataset does not match schema, missing column in dataset: {key}")
else:
# When augmenting with new columns, we'll skip validation for those columns
logger.debug(f"Allowing new column that will be added through augmentation: {key}")

@property
def data(self) -> pd.DataFrame:
"""Returns the dataset."""
"""
Get the dataset as a pandas DataFrame.

:return: The dataset as a DataFrame
:raises ValueError: If no data has been set or generated
"""
if self._data is None:
raise ValueError("No data has been set or generated.")
raise ValueError("No data has been set or generated. Call generate() first.")
return self._data

def __len__(self) -> int:
"""Returns the number of samples in the dataset."""
if isinstance(self._data, pd.DataFrame):
"""
Get the number of samples in the dataset.

:return: Number of rows in the dataset, or 0 if no data
"""
if self._data is not None:
return len(self._data)
return 0

def __iter__(self) -> Iterator:
"""Returns an iterator over the dataset."""
"""
Get an iterator over the dataset rows.

:return: Self as iterator
"""
self._index = 0
return self

def __next__(self):
"""Returns the next item in the dataset."""
"""
Get the next item when iterating over the dataset.

:return: Dictionary representing the next row
:raises StopIteration: When all rows have been processed
"""
if self._data is None or self._index >= len(self):
raise StopIteration

if isinstance(self._data, pd.DataFrame):
row = self._data.iloc[self._index].to_dict()
else:
raise TypeError("Unsupported data type in dataset.")

row = self._data.iloc[self._index].to_dict()
self._index += 1
return row

def __getitem__(self, index: int):
"""Returns the dataset item at a given index."""
"""
Get a dataset item by index.

:param index: Row index to retrieve
:return: Dictionary representing the row at the given index
:raises IndexError: If dataset is empty
"""
if self._data is None:
raise IndexError("Dataset is empty.")

if isinstance(self._data, pd.DataFrame):
return self._data.iloc[index].to_dict()
else:
raise TypeError("Unsupported data type in dataset.")
return self._data.iloc[index].to_dict()
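Taken together, the dunder methods above let a DatasetGenerator be used as a simple row container. Below is a minimal usage sketch of that surface; the schema, description, and sample count are illustrative and not part of this PR:

from pydantic import BaseModel, Field

from plexe import DatasetGenerator


class ReviewSchema(BaseModel):
    text: str = Field(description="Review text")
    sentiment: str = Field(description="positive, neutral, or negative")


dataset = DatasetGenerator(
    description="Short product reviews with sentiment labels",
    provider="openai/gpt-4o",
    schema=ReviewSchema,
)
dataset.generate(5)

print(len(dataset))   # __len__: number of rows, or 0 if nothing has been generated yet
print(dataset[0])     # __getitem__: dict for the first row
for row in dataset:   # __iter__/__next__: one dict per row
    print(row["sentiment"])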
20 changes: 20 additions & 0 deletions plexe/fileio.py
@@ -118,6 +118,14 @@ def save_model(model: Model, path: str | Path) -> str:
                info.size = len(content)
                tar.addfile(info, io.BytesIO(content))

            # Save EDA markdown reports if available
            if "eda_markdown_reports" in model.metadata and model.metadata["eda_markdown_reports"]:
                for dataset_name, report_markdown in model.metadata["eda_markdown_reports"].items():
                    info = tarfile.TarInfo(f"metadata/eda_report_{dataset_name}.md")
                    content = report_markdown.encode("utf-8")
                    info.size = len(content)
                    tar.addfile(info, io.BytesIO(content))

    except Exception as e:
        logger.error(f"Error saving model: {e}")
        if Path(path).exists():
@@ -169,6 +177,14 @@ def load_model(path: str | Path) -> Model:
        if "metadata/constraints.pkl" in [m.name for m in tar.getmembers()]:
            constraints = pickle.loads(tar.extractfile("metadata/constraints.pkl").read())

        # Load EDA markdown reports if available
        eda_markdown_reports = {}
        for member in tar.getmembers():
            if member.name.startswith("metadata/eda_report_") and member.name.endswith(".md"):
                dataset_name = member.name.replace("metadata/eda_report_", "").replace(".md", "")
                report_content = tar.extractfile(member).read().decode("utf-8")
                eda_markdown_reports[dataset_name] = report_content

        # Get handles for all model artifacts
        artifact_handles = []
        for member in tar.getmembers():
@@ -206,6 +222,10 @@ def type_from_name(type_name: str) -> type:
    model.metadata = metadata
    model.identifier = identifier
    model.trainer_source = trainer_source

    # Add to the metadata if reports were found
    if eda_markdown_reports:
        model.metadata["eda_markdown_reports"] = eda_markdown_reports
    model.predictor_source = predictor_source

    if predictor_source:
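For reference, a minimal sketch of reading the persisted EDA reports back from a saved bundle. It assumes plexe.fileio is importable as a module, and the bundle path and preview length are illustrative:

from plexe import fileio

# Load a previously saved model bundle (path is illustrative)
model = fileio.load_model("exports/my-model.tar.gz")

# load_model re-attaches any bundled EDA reports to model.metadata under this key
for dataset_name, report_md in model.metadata.get("eda_markdown_reports", {}).items():
    print(f"=== EDA report for '{dataset_name}' ===")
    print(report_md[:300])  # preview the start of the markdown report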