22 commits
c5236f0  feat: initial commit for the SIMBAUQSamplingStrategy (Apr 2, 2026)
ea51043  chore: added a separate field to mot.meta for the similarity matrix (Apr 2, 2026)
5c23a58  chore: added a second aggregation by classification CE algorithm (Apr 2, 2026)
d7f3b6a  refactor: revised and moved the SIMBAUQSamplingStrategy in docs/examples (Apr 3, 2026)
908258c  Update test/stdlib/sampling/test_simbauq.py (radum2275, Apr 7, 2026)
8b8c336  Update docs/examples/simbauq/simbauq_example.py (radum2275, Apr 7, 2026)
865e85f  Update .gitignore (radum2275, Apr 7, 2026)
a6b356a  Update docs/examples/simbauq/README.md (radum2275, Apr 7, 2026)
cbae30c  Update docs/examples/simbauq/README.md (radum2275, Apr 7, 2026)
a3c51a8  Update mellea/stdlib/sampling/simbauq.py (radum2275, Apr 7, 2026)
e9b05f1  Update mellea/stdlib/sampling/simbauq.py (radum2275, Apr 7, 2026)
372046a  Update mellea/stdlib/sampling/simbauq.py (radum2275, Apr 7, 2026)
af55899  refactor: refactored the simbauq sampling strategy (Apr 8, 2026)
da1440d  fix: added the ollama backend in simbauq example (Apr 8, 2026)
11b180f  chore: set aggregation by mean in simbauq example (Apr 9, 2026)
6c6c099  chore: fixed a typo in the simbauq README.md file (Apr 9, 2026)
78fe6c7  chore: added scikit-learn as required dependency for simbauq strategy (Apr 9, 2026)
65a1268  Update test/stdlib/sampling/test_simbauq.py (radum2275, Apr 10, 2026)
41728a5  Update test/stdlib/sampling/test_simbauq.py (radum2275, Apr 10, 2026)
f90a466  Update mellea/stdlib/sampling/simbauq.py (radum2275, Apr 10, 2026)
1cd588c  Update mellea/stdlib/sampling/simbauq.py (radum2275, Apr 10, 2026)
c8bd228  Update mellea/stdlib/sampling/simbauq.py (radum2275, Apr 10, 2026)
153 changes: 153 additions & 0 deletions docs/examples/simbauq/README.md
@@ -0,0 +1,153 @@
# SIMBA-UQ Sampling Strategy

Confidence-aware sample selection using the SIMBA-UQ framework
(Bhattacharjya et al., 2025). The strategy generates multiple samples across a
range of temperatures and selects the one with the highest estimated confidence.

**Paper:** [SIMBA UQ: Similarity-Based Aggregation for Uncertainty Quantification in Large Language Models](https://arxiv.org/abs/2510.13836)

## Files

### simbauq_example.py

Complete example demonstrating both confidence estimation methods with
Ollama and granite-4.0-micro.

## Architecture

```
User Query
|
v
Generate N samples (across temperatures)
|
v
Compute pairwise similarity matrix (N x N)
|
+---> [Aggregation] Aggregate similarities per sample -> confidence
|
+---> [Classifier] Extract features per sample -> RF predicts P(correct)
|
v
Select sample with highest confidence
|
v
Result (with confidence metadata in mot._meta["simba_uq"])
```

## Confidence Methods

### 1. Aggregation (data-free)

No training data required. For each sample, computes its similarity to every
other sample, then aggregates those values into a confidence score. Samples
that are more similar to the majority get higher confidence.
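The aggregation step can be sketched in a few lines (a minimal illustration,
not the library's implementation; `aggregation_confidence` is a hypothetical
helper):

```python
from statistics import mean, median

def aggregation_confidence(sim: list[list[float]], how: str = "mean") -> list[float]:
    """Aggregate each sample's off-diagonal similarities into a confidence score."""
    agg = {"mean": mean, "median": median, "max": max, "min": min}[how]
    return [
        agg([row[j] for j in range(len(row)) if j != i])  # skip self-similarity
        for i, row in enumerate(sim)
    ]

# 3 samples: sample 0 agrees most with the others, so it wins.
sim = [
    [1.0, 0.8, 0.7],
    [0.8, 1.0, 0.4],
    [0.7, 0.4, 1.0],
]
confidences = aggregation_confidence(sim)
best = confidences.index(max(confidences))  # index 0
```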

```python
from mellea.stdlib.sampling.simbauq import SIMBAUQSamplingStrategy

strategy = SIMBAUQSamplingStrategy(
temperatures=[0.3, 0.5, 0.7, 1.0],
n_per_temp=3,
similarity_metric="rouge",
confidence_method="aggregation",
aggregation="mean",
)

result = m.instruct("Your query here", strategy=strategy, return_sampling_results=True)
```

### 2. Classifier (trained)

Uses a random forest classifier trained on labeled examples. The classifier
learns to predict P(correct) from pairwise similarity features. Provide
either training data or a pre-trained sklearn classifier.
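The exact featurization lives in `simbauq.py`; one plausible scheme
(illustrative only, and `similarity_features` is a hypothetical name) builds
each sample's feature vector from its row of pairwise similarities, sorted so
the vector is invariant to generation order:

```python
def similarity_features(sim: list[list[float]], i: int) -> list[float]:
    # Sample i's similarities to every other sample, sorted descending,
    # so feature position does not depend on the order samples were drawn.
    row = [sim[i][j] for j in range(len(sim)) if j != i]
    return sorted(row, reverse=True)

# A classifier exposing predict_proba (e.g. a random forest) can then
# score each feature vector as P(correct).
```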

**With training data:**

```python
strategy = SIMBAUQSamplingStrategy(
temperatures=[0.3, 0.5, 0.7, 1.0],
n_per_temp=3,
similarity_metric="rouge",
confidence_method="classifier",
training_samples=[
["correct answer 1", "correct answer 2", ..., "wrong answer"], # group 1
["correct answer 1", "correct answer 2", ..., "wrong answer"], # group 2
],
training_labels=[
[1, 1, ..., 0], # labels for group 1
[1, 1, ..., 0], # labels for group 2
],
)
```

Each training group must have exactly `len(temperatures) * n_per_temp` samples
so the feature vectors match at inference time.
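A quick sanity check for that shape constraint (a hypothetical snippet with
placeholder groups; the values match the constructor arguments above):

```python
temperatures = [0.3, 0.5, 0.7, 1.0]
n_per_temp = 3
expected = len(temperatures) * n_per_temp  # 12 samples per group

# Placeholder training data: two groups of 12 samples, 9 correct / 3 wrong.
training_samples = [["answer"] * expected, ["answer"] * expected]
training_labels = [[1] * 9 + [0] * 3, [1] * 9 + [0] * 3]

for group, labels in zip(training_samples, training_labels):
    assert len(group) == expected, "group size must equal len(temperatures) * n_per_temp"
    assert len(labels) == len(group), "one 0/1 label per sample"
```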

**With pre-trained classifier:**

```python
strategy = SIMBAUQSamplingStrategy(
temperatures=[0.3, 0.5, 0.7, 1.0],
n_per_temp=3,
confidence_method="classifier",
classifier=my_pretrained_sklearn_clf,
)
```

## Constructor Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `temperatures` | `list[float]` | `[0.3, 0.5, 0.7, 1.0]` | Temperature values to sample at |
| `n_per_temp` | `int` | `4` | Number of samples per temperature |
| `similarity_metric` | `"rouge"`, `"jaccard"`, `"sbert"` | `"rouge"` | Pairwise similarity metric |
| `confidence_method` | `"aggregation"`, `"classifier"` | `"aggregation"` | Confidence estimation method |
| `aggregation` | `"mean"`, `"geometric_mean"`, `"harmonic_mean"`, `"median"`, `"max"`, `"min"` | `"mean"` | Aggregation function (for `aggregation` method) |
| `classifier` | sklearn classifier | `None` | Pre-trained classifier with `predict_proba` |
| `training_samples` | `list[list[str]]` | `None` | Training data for classifier |
| `training_labels` | `list[list[int]]` | `None` | Binary correctness labels (0/1) |
| `clf_max_depth` | `int` | `4` | Max tree depth for random forest |
| `rouge_type` | `str` | `"rougeL"` | Rouge variant |
| `sbert_model` | `str` | `"all-MiniLM-L6-v2"` | Sentence-BERT model name |
| `requirements` | `list[Requirement]` | `None` | Requirements to validate the selected sample |

## Similarity Metrics

- **rouge** (default): RougeL F-measure. Good general-purpose text similarity.
No extra dependencies beyond `rouge-score` (already in mellea).
- **jaccard**: Word-level set overlap (intersection / union). Fast, no
external dependencies, works well for short structured answers.
- **sbert**: Cosine similarity of Sentence-BERT embeddings. Best semantic
similarity but requires `sentence-transformers` (`pip install
mellea[granite_retriever]`).
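For intuition, the word-level Jaccard metric is only a few lines (a sketch;
the library's tokenization and normalization may differ):

```python
def jaccard(a: str, b: str) -> float:
    """Word-level Jaccard similarity: |A intersect B| / |A union B|."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    if not wa and not wb:
        return 1.0  # two empty strings are trivially identical
    return len(wa & wb) / len(wa | wb)

jaccard("the cat sat", "the cat ran")  # 2 shared words / 4 total -> 0.5
```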

## Inspecting Results

The selected sample's `ModelOutputThunk` stores confidence metadata:

```python
result = m.instruct(..., strategy=strategy, return_sampling_results=True)

# Best sample
best_mot = result.result
meta = best_mot._meta["simba_uq"]

meta["confidence"] # float: confidence of the selected sample
meta["all_confidences"] # list[float]: confidence for every sample
meta["similarity_matrix"] # list[list[float]]: N x N pairwise similarity matrix
meta["temperatures_used"] # list[float]: temperature used for each sample
meta["confidence_method"] # "aggregation" or "classifier"
meta["similarity_metric"] # "rouge", "jaccard", or "sbert"
meta["aggregation"] # aggregation function name

# All generated samples
for i, mot in enumerate(result.sample_generations):
print(f"Sample {i}: {mot.value}")
```

## Related Files

- `mellea/stdlib/sampling/simbauq.py` -- Strategy implementation
- `test/stdlib/sampling/test_simbauq.py` -- Unit and integration tests
202 changes: 202 additions & 0 deletions docs/examples/simbauq/simbauq_example.py
@@ -0,0 +1,202 @@
# pytest: ollama, llm, qualitative

"""SIMBA-UQ Sampling Strategy Example.

This example demonstrates the SIMBAUQSamplingStrategy using both confidence
estimation methods:

1. **Aggregation** (data-free) - Computes pairwise similarity between all
generated samples and aggregates them into per-sample confidence scores.
The sample with the highest confidence is selected.

2. **Classifier** (trained) - Uses a random forest classifier trained on
labeled examples to predict P(correct) for each sample based on its
pairwise similarity features.

Both methods generate multiple samples across different temperature values,
compute a similarity matrix, and select the most confident response.

The example uses OllamaModelBackend with granite4:micro. To run:

ollama serve
uv run python docs/examples/simbauq/simbauq_example.py
"""

import numpy as np

from mellea import MelleaSession
from mellea.backends import ModelOption
from mellea.backends.ollama import OllamaModelBackend
from mellea.core import SamplingResult
from mellea.stdlib.context import ChatContext
from mellea.stdlib.sampling.simbauq import SIMBAUQSamplingStrategy


def make_session() -> MelleaSession:
"""Create a MelleaSession with OllamaModelBackend."""
backend = OllamaModelBackend(model_options={ModelOption.MAX_NEW_TOKENS: 100})
return MelleaSession(backend, ctx=ChatContext())


def print_results(result: SamplingResult) -> None:
"""Print detailed results from a SIMBA-UQ sampling run."""
meta = result.result._meta["simba_uq"]
confidences = meta["all_confidences"]
temperatures = meta["temperatures_used"]
sim_matrix = np.array(meta["similarity_matrix"])

# --- Best response ---
print("=" * 70)
print("BEST RESPONSE")
print("=" * 70)
print(f" Index: {result.result_index}")
print(f" Confidence: {meta['confidence']:.4f}")
print(f" Method: {meta['confidence_method']}")
print(f" Metric: {meta['similarity_metric']}")
print(f" Aggregation: {meta['aggregation']}")
print(f" Text:\n {result.result!s}")
print()

# --- All samples ---
print("=" * 70)
print("ALL SAMPLES")
print("=" * 70)
print(f"{'Idx':>4} {'Temp':>5} {'Conf':>8} {'Text'}")
print("-" * 70)
for i, mot in enumerate(result.sample_generations):
text = str(mot).replace("\n", " ")
truncated = (text[:100] + "...") if len(text) > 100 else text
marker = " <-- best" if i == result.result_index else ""
print(
f"{i:>4} {temperatures[i]:>5.2f} {confidences[i]:>8.4f} "
f"{truncated}{marker}"
)
print()

# --- Similarity matrix ---
n = sim_matrix.shape[0]
print("=" * 70)
print("SIMILARITY MATRIX")
print("=" * 70)
header = " " + "".join(f" [{i:>2}] " for i in range(n))
print(header)
for i in range(n):
row = f"[{i:>2}] " + "".join(f" {sim_matrix[i, j]:.3f} " for j in range(n))
print(row)
print()


def run_aggregation_example() -> None:
"""Run SIMBA-UQ with data-free similarity aggregation."""
print("\n>>> AGGREGATION CONFIDENCE METHOD <<<\n")

m = make_session()

strategy = SIMBAUQSamplingStrategy(
temperatures=[0.3, 0.5, 0.7, 1.0],
n_per_temp=3,
similarity_metric="rouge",
confidence_method="aggregation",
aggregation="mean",
)

result: SamplingResult = m.instruct(
"Which magazine was started first Arthur's Magazine or First for Women?",
strategy=strategy,
return_sampling_results=True,
)

print(f"Total samples generated: {len(result.sample_generations)}")
print_results(result)

del m


def run_classifier_example() -> None:
"""Run SIMBA-UQ with a trained random forest classifier."""
print("\n>>> CLASSIFIER CONFIDENCE METHOD <<<\n")

m = make_session()

# Synthetic training data: 3 groups of 12 samples (4 temps * 3 per temp).
# Each group has mostly "correct" similar answers and a few outliers.
training_samples = [
[
"Paris is the capital of France.",
"The capital of France is Paris.",
"France's capital city is Paris.",
"Paris, the capital of France.",
"The capital city of France is Paris.",
"France has Paris as its capital.",
"Paris serves as France's capital.",
"In France, Paris is the capital.",
"The French capital is Paris.",
"Bananas are a yellow fruit.",
"Dogs are loyal pets.",
"The ocean is very deep.",
],
[
"Water boils at 100 degrees Celsius.",
"At 100C water reaches boiling point.",
"The boiling point of water is 100 degrees.",
"Water boils when heated to 100C.",
"100 degrees Celsius is water's boiling point.",
"Boiling occurs at 100C for water.",
"Water starts boiling at one hundred degrees.",
"At 100 degrees water boils.",
"The temperature for boiling water is 100C.",
"Cats like to sleep a lot.",
"Mountains can be very high.",
"Stars shine in the night sky.",
],
[
"Python is a programming language.",
"Python is a popular programming language.",
"The Python programming language is widely used.",
"Python is used for programming.",
"Programming in Python is common.",
"Python is a well-known language for coding.",
"Many developers use Python.",
"Python is a general-purpose language.",
"The language Python is popular.",
"Pizza originated in Italy.",
"Rain falls from clouds.",
"Books contain many pages.",
],
]
training_labels = [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
]

strategy = SIMBAUQSamplingStrategy(
temperatures=[0.3, 0.5, 0.7, 1.0],
n_per_temp=3,
similarity_metric="rouge",
confidence_method="classifier",
training_samples=training_samples,
training_labels=training_labels,
)

result: SamplingResult = m.instruct(
"Which magazine was started first Arthur's Magazine or First for Women?",
strategy=strategy,
return_sampling_results=True,
)

print(f"Total samples generated: {len(result.sample_generations)}")
print_results(result)

del m


def main():
"""Run both SIMBA-UQ confidence estimation examples."""
run_aggregation_example()
print("\n" + "=" * 70 + "\n")
run_classifier_example()


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions mellea/stdlib/sampling/__init__.py
@@ -8,13 +8,15 @@
RejectionSamplingStrategy,
RepairTemplateStrategy,
)
from .simbauq import SIMBAUQSamplingStrategy
from .sofai import SOFAISamplingStrategy

__all__ = [
"BaseSamplingStrategy",
"MultiTurnStrategy",
"RejectionSamplingStrategy",
"RepairTemplateStrategy",
"SIMBAUQSamplingStrategy",
"SamplingResult",
"SamplingStrategy",
]