diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bef1b2..f2fd960 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [0.4] - 2026-01-05 + +### Fixed + +- `build_reliability_diagram`: Error when scores are distributed uniformly. + ## [0.3] - 2025-12-26 ### Fixed @@ -17,7 +23,8 @@ - `calibration_error` function - `build_reliability_diagram` function -[Unreleased]: https://github.com/yourusername/yourproject/compare/0.3...HEAD -[0.3]: https://github.com/yourusername/yourproject/compare/0.2...0.3 -[0.2]: https://github.com/yourusername/yourproject/compare/0.1...0.2 -[0.1]: https://github.com/yourusername/yourproject/releases/tag/0.1 +[Unreleased]: https://github.com/khiopslab/khalib/compare/0.4...HEAD +[0.4]: https://github.com/khiopslab/khalib/compare/0.3...0.4 +[0.3]: https://github.com/khiopslab/khalib/compare/0.2...0.3 +[0.2]: https://github.com/khiopslab/khalib/compare/0.1...0.2 +[0.1]: https://github.com/khiopslab/khalib/releases/tag/0.1 diff --git a/pyproject.toml b/pyproject.toml index eb15340..0aa56c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "khalib" -version = "0.3" +version = "0.4" description = "Classifier Calibration with Khiops" authors = [{ name = "Felipe Olmos", email = "luisfelipe.olmosmarchant@orange.com" }] requires-python = ">=3.11" diff --git a/src/khalib/main.py b/src/khalib/main.py index 44ed919..17ba434 100644 --- a/src/khalib/main.py +++ b/src/khalib/main.py @@ -233,9 +233,6 @@ def from_data( do_data_preparation_only=True, ) results = kh.read_analysis_results_file(f"{work_dir}/report.khj") - # import shutil - - # shutil.copy(f"{work_dir}/report.khj", "tmp.khj") # Initialize the histogram if y is not None: @@ -287,10 +284,13 @@ def from_data( else: if use_finest: histogram_index = -1 - else: + elif 100 in score_stats.modl_histograms.information_rates: histogram_index = ( score_stats.modl_histograms.information_rates.index(100) ) + else: + assert len(score_stats.modl_histograms.information_rates) == 1 + histogram_index = 0 breakpoints = score_stats.modl_histograms.histograms[ histogram_index ].bounds @@ -954,31 +954,11 @@ def build_reliability_diagram( fig.subplots_adjust(hspace=0) fig.suptitle("Reliability Diagram") - # Build a unsupervised histogram to to detect the dirac masses case + # Build a unsupervised histogram uhist_y_scores = Histogram.from_data(y_scores, use_finest=True) - dirac_indexes = [] - if uhist_y_scores.freqs[1] == 0 and ( - uhist_y_scores.bins[0][1] - uhist_y_scores.bins[0][0] < dirac_threshold - ): - dirac_indexes.append(True) - else: - dirac_indexes.append(False) - for i in range(1, uhist_y_scores.n_bins - 1): - cur_left, cur_right = uhist_y_scores.bins[i] - if ( - uhist_y_scores.freqs[i - 1] == 0 - and uhist_y_scores.freqs[i + 1] == 0 - and (cur_right - cur_left < dirac_threshold) - ): - dirac_indexes.append(True) - else: - dirac_indexes.append(False) - if uhist_y_scores.freqs[-2] == 0 and ( - uhist_y_scores.bins[-1][1] - uhist_y_scores.bins[-1][0] < dirac_threshold - ): - dirac_indexes.append(True) - else: - dirac_indexes.append(False) + + # Compute the dirac mass indexes + dirac_indexes = compute_dirac_indexes(uhist_y_scores, dirac_threshold) # Compute the supervised score histogram hist_y_scores = Histogram.from_data(y_scores, y) @@ -1091,3 +1071,43 @@ def build_reliability_diagram( ) return fig, axs + + +def compute_dirac_indexes(uhist, dirac_threshold): + """Computes the dirac mass indexes of a histogram + + We declare a dirac mass bin if: + + - it is surrounded by empty bins. + - its length is less than ``dirac_threshold`` + + """ + dirac_indexes = [] + if ( + len(uhist.freqs) > 1 + and uhist.freqs[1] == 0 + and (uhist.bins[0][1] - uhist.bins[0][0] < dirac_threshold) + ): + dirac_indexes.append(True) + else: + dirac_indexes.append(False) + for i in range(1, uhist.n_bins - 1): + cur_left, cur_right = uhist.bins[i] + if ( + uhist.freqs[i - 1] == 0 + and uhist.freqs[i + 1] == 0 + and (cur_right - cur_left < dirac_threshold) + ): + dirac_indexes.append(True) + else: + dirac_indexes.append(False) + if ( + len(uhist.freqs) > 2 + and uhist.freqs[-2] == 0 + and (uhist.bins[-1][1] - uhist.bins[-1][0] < dirac_threshold) + ): + dirac_indexes.append(True) + else: + dirac_indexes.append(False) + + return dirac_indexes diff --git a/src/tests/test_all.py b/src/tests/test_all.py index 2dbf2a4..e782b49 100644 --- a/src/tests/test_all.py +++ b/src/tests/test_all.py @@ -11,7 +11,8 @@ from sklearn.svm import SVC from sklearn.utils.validation import check_is_fitted -from khalib import Histogram, KhalibClassifier, calibration_error +from khalib import Histogram, KhalibClassifier, calibrate_binary, calibration_error +from khalib.main import compute_dirac_indexes @pytest.fixture(name="data_root_dir") @@ -443,3 +444,29 @@ def test_calibrator( y_scores_calib_test, y_test, multi_class_method="top-label" ) assert ece == pytest.approx(expected_ece, rel=1e-2) + + +class TestDiracHeuristic: + def test_uniform(self): + # The uniform distribution should have only one bin and no diracs + rng = np.random.default_rng(seed=1234567) + y_scores = rng.uniform(size=2000) + uhist = Histogram.from_data(y_scores, use_finest=True) + assert compute_dirac_indexes(uhist, 1e-06) == [False, False] + + def test_dirac(self): + y_scores = np.array([0.2] * 250 + [0.5] * 250 + [0.9] * 500) + y = np.array( + [0] * 200 + [1] * 50 + [0] * 125 + [1] * 125 + [0] * 50 + [1] * 450 + ) + hist = Histogram.from_data(y_scores, y=y) + y_scores_calib = calibrate_binary(y_scores, hist, only_positive=True) + uhist = Histogram.from_data(y_scores_calib, use_finest=True) + assert compute_dirac_indexes(uhist, 1e-06) == [True, False, True, False, True] + + def test_beta(self): + # A sufficiently large continuous distribution sample should have no diracs + rng = np.random.default_rng(seed=1234567) + y_scores = rng.beta(a=0.5, b=0.5, size=2000) + uhist = Histogram.from_data(y_scores, use_finest=True) + assert not all(compute_dirac_indexes(uhist, 1e-06)) diff --git a/uv.lock b/uv.lock index f2d3110..5234e7f 100644 --- a/uv.lock +++ b/uv.lock @@ -1080,7 +1080,7 @@ wheels = [ [[package]] name = "khalib" -version = "0.2" +version = "0.4" source = { editable = "." } dependencies = [ { name = "khiops" },