Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## [0.4] - 2026-01-05

### Fixed

- `build_reliability_diagram`: Error when scores are distributed uniformly.

## [0.3] - 2025-12-26

### Fixed
Expand All @@ -17,7 +23,8 @@
- `calibration_error` function
- `build_reliability_diagram` function

[Unreleased]: https://github.com/yourusername/yourproject/compare/0.3...HEAD
[0.3]: https://github.com/yourusername/yourproject/compare/0.2...0.3
[0.2]: https://github.com/yourusername/yourproject/compare/0.1...0.2
[0.1]: https://github.com/yourusername/yourproject/releases/tag/0.1
[Unreleased]: https://github.com/khiopslab/khalib/compare/0.4...HEAD
[0.4]: https://github.com/khiopslab/khalib/compare/0.3...0.4
[0.3]: https://github.com/khiopslab/khalib/compare/0.2...0.3
[0.2]: https://github.com/khiopslab/khalib/compare/0.1...0.2
[0.1]: https://github.com/khiopslab/khalib/releases/tag/0.1
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "khalib"
version = "0.3"
version = "0.4"
description = "Classifier Calibration with Khiops"
authors = [{ name = "Felipe Olmos", email = "[email protected]" }]
requires-python = ">=3.11"
Expand Down
76 changes: 48 additions & 28 deletions src/khalib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,6 @@ def from_data(
do_data_preparation_only=True,
)
results = kh.read_analysis_results_file(f"{work_dir}/report.khj")
# import shutil

# shutil.copy(f"{work_dir}/report.khj", "tmp.khj")

# Initialize the histogram
if y is not None:
Expand Down Expand Up @@ -287,10 +284,13 @@ def from_data(
else:
if use_finest:
histogram_index = -1
else:
elif 100 in score_stats.modl_histograms.information_rates:
histogram_index = (
score_stats.modl_histograms.information_rates.index(100)
)
else:
assert len(score_stats.modl_histograms.information_rates) == 1
histogram_index = 0
breakpoints = score_stats.modl_histograms.histograms[
histogram_index
].bounds
Expand Down Expand Up @@ -954,31 +954,11 @@ def build_reliability_diagram(
fig.subplots_adjust(hspace=0)
fig.suptitle("Reliability Diagram")

# Build a unsupervised histogram to to detect the dirac masses case
# Build an unsupervised histogram
uhist_y_scores = Histogram.from_data(y_scores, use_finest=True)
dirac_indexes = []
if uhist_y_scores.freqs[1] == 0 and (
uhist_y_scores.bins[0][1] - uhist_y_scores.bins[0][0] < dirac_threshold
):
dirac_indexes.append(True)
else:
dirac_indexes.append(False)
for i in range(1, uhist_y_scores.n_bins - 1):
cur_left, cur_right = uhist_y_scores.bins[i]
if (
uhist_y_scores.freqs[i - 1] == 0
and uhist_y_scores.freqs[i + 1] == 0
and (cur_right - cur_left < dirac_threshold)
):
dirac_indexes.append(True)
else:
dirac_indexes.append(False)
if uhist_y_scores.freqs[-2] == 0 and (
uhist_y_scores.bins[-1][1] - uhist_y_scores.bins[-1][0] < dirac_threshold
):
dirac_indexes.append(True)
else:
dirac_indexes.append(False)

# Compute the dirac mass indexes
dirac_indexes = compute_dirac_indexes(uhist_y_scores, dirac_threshold)

# Compute the supervised score histogram
hist_y_scores = Histogram.from_data(y_scores, y)
Expand Down Expand Up @@ -1091,3 +1071,43 @@ def build_reliability_diagram(
)

return fig, axs


def compute_dirac_indexes(uhist, dirac_threshold):
    """Flag the bins of a histogram that look like Dirac masses.

    A bin is flagged when both conditions hold:

    - every bin adjacent to it is empty (its frequency equals 0);
    - its width is strictly less than ``dirac_threshold``.

    Parameters
    ----------
    uhist
        Histogram-like object exposing ``freqs`` (per-bin frequencies), ``bins``
        (per-bin ``(left, right)`` bounds) and ``n_bins``.
    dirac_threshold
        Width under which a bin is considered narrow enough to be a Dirac mass.

    Returns
    -------
    list of bool
        One flag for the first bin, one per interior bin and one for the last
        bin, in that order. The first and last flags are always emitted, so a
        histogram with fewer than two bins still yields ``[False, False]``.
        The last bin is only examined when there are more than two bins.
    """

    def _edge_flag(min_len, neighbour, edge):
        # The length guard keeps the neighbour lookup in range on tiny histograms
        return bool(
            len(uhist.freqs) > min_len
            and uhist.freqs[neighbour] == 0
            and uhist.bins[edge][1] - uhist.bins[edge][0] < dirac_threshold
        )

    # First bin: only its right-hand neighbour exists
    flags = [_edge_flag(1, 1, 0)]

    # Interior bins: both neighbours must be empty
    for idx in range(1, uhist.n_bins - 1):
        lower, upper = uhist.bins[idx]
        isolated = uhist.freqs[idx - 1] == 0 and uhist.freqs[idx + 1] == 0
        flags.append(bool(isolated and upper - lower < dirac_threshold))

    # Last bin: only its left-hand neighbour exists
    flags.append(_edge_flag(2, -2, -1))

    return flags
29 changes: 28 additions & 1 deletion src/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from sklearn.svm import SVC
from sklearn.utils.validation import check_is_fitted

from khalib import Histogram, KhalibClassifier, calibration_error
from khalib import Histogram, KhalibClassifier, calibrate_binary, calibration_error
from khalib.main import compute_dirac_indexes


@pytest.fixture(name="data_root_dir")
Expand Down Expand Up @@ -443,3 +444,29 @@ def test_calibrator(
y_scores_calib_test, y_test, multi_class_method="top-label"
)
assert ece == pytest.approx(expected_ece, rel=1e-2)


class TestDiracHeuristic:
    """Checks of ``compute_dirac_indexes`` on finest-resolution histograms."""

    _THRESHOLD = 1e-06
    _SEED = 1234567

    def test_uniform(self):
        # A uniform sample is expected to land in a single bin with no diracs;
        # the result still carries the two edge flags, both False
        scores = np.random.default_rng(seed=self._SEED).uniform(size=2000)
        finest_hist = Histogram.from_data(scores, use_finest=True)
        flags = compute_dirac_indexes(finest_hist, self._THRESHOLD)
        assert flags == [False, False]

    def test_dirac(self):
        # Three point masses in the raw scores, each with its own positive rate
        raw_scores = np.repeat([0.2, 0.5, 0.9], [250, 250, 500])
        labels = np.repeat([0, 1, 0, 1, 0, 1], [200, 50, 125, 125, 50, 450])
        supervised_hist = Histogram.from_data(raw_scores, y=labels)
        calibrated = calibrate_binary(raw_scores, supervised_hist, only_positive=True)
        finest_hist = Histogram.from_data(calibrated, use_finest=True)
        flags = compute_dirac_indexes(finest_hist, self._THRESHOLD)
        assert flags == [True, False, True, False, True]

    def test_beta(self):
        # A sufficiently large continuous distribution sample should have no diracs
        # NOTE(review): ``not all`` only rules out every bin being flagged; the
        # stated intent (no diracs at all) would be ``not any`` -- confirm before
        # tightening the assertion
        scores = np.random.default_rng(seed=self._SEED).beta(a=0.5, b=0.5, size=2000)
        finest_hist = Histogram.from_data(scores, use_finest=True)
        flags = compute_dirac_indexes(finest_hist, self._THRESHOLD)
        assert not all(flags)
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.