diff --git a/src/access/profiling/manager.py b/src/access/profiling/manager.py index d0ad053..5af74db 100644 --- a/src/access/profiling/manager.py +++ b/src/access/profiling/manager.py @@ -11,6 +11,7 @@ from access.profiling.experiment import ProfilingExperiment, ProfilingExperimentStatus, ProfilingLog from access.profiling.metrics import ProfilingMetric +from access.profiling.plotting_utils import plot_bar_metrics from access.profiling.scaling import plot_scaling_metrics logger = logging.getLogger(__name__) @@ -238,3 +239,55 @@ def plot_scaling_data( scaling_data.append(component_data) return plot_scaling_metrics(scaling_data, metric) + + def plot_bar_chart( + self, + components: list[str], + regions: list[list[str]], + metric: ProfilingMetric, + region_relabel_map: dict | None = None, + experiments: list[str] | None = None, + show: bool = True, + ) -> Figure: + """Plots a bar chart of a profiling metric over regions, grouped by experiment. + + Regions are placed along the x-axis. Within each region group, there is one bar per + experiment, coloured by experiment name. + + Args: + components (list[str]): List of component names to include. + regions (list[list[str]]): List of regions to include for each component. + metric (ProfilingMetric): Metric to plot. + region_relabel_map (dict | None): Optional mapping to relabel regions in the plot. + experiments (list[str] | None): Optional list of experiment names to include. If None, all experiments + are included. + show (bool): Whether to show the generated plot. Default: True. + + Returns: + Figure: The Matplotlib figure containing the bar chart. + + Raises: + ValueError: If no profiling data is found for a specified component in any experiment. + """ + exp_names = experiments if experiments is not None else list(self.data.keys()) + relabel = region_relabel_map or {} + + # Build a lookup from display label to (component, original_region) and preserve input order. 
def plot_bar_metrics(
    data: dict[str, list[float]],
    region_labels: list[str],
    metric: ProfilingMetric,
    show: bool = True,
) -> Figure:
    """Plots a grouped bar chart of a profiling metric over regions.

    Regions are placed along the x-axis. Within each region group, there is one bar per
    experiment, coloured by experiment name.

    Args:
        data (dict[str, list[float]]): Mapping of experiment name to a list of metric values,
            one per region (in the same order as ``region_labels``).
        region_labels (list[str]): Ordered list of region display labels for the x-axis.
        metric (ProfilingMetric): The metric being plotted (used for axis labels and title).
        show (bool): Whether to call ``plt.show()``. Default: True.

    Returns:
        Figure: The Matplotlib figure containing the bar chart.

    Raises:
        ValueError: If ``data`` is empty, or if an experiment does not provide exactly one value
            per region label.
    """
    exp_names = list(data)
    n_experiments = len(exp_names)
    n_regions = len(region_labels)

    # Validate before creating the figure so that a failed call does not leave an orphaned pyplot figure.
    if n_experiments == 0:
        # The bar width below is 0.8 / n_experiments, which would otherwise divide by zero.
        raise ValueError("No experiment data provided to plot.")
    for exp_name, values in data.items():
        if len(values) != n_regions:
            # Matplotlib silently broadcasts a single value across all bars, so check lengths explicitly.
            raise ValueError(
                f"Experiment '{exp_name}' has {len(values)} values, but {n_regions} region labels were given."
            )

    fig, ax = plt.subplots(figsize=(max(8, n_experiments * n_regions * 0.8), 6))
    # Each region group occupies 0.8 of the unit spacing between ticks; the bars share that width equally.
    bar_width = 0.8 / n_experiments
    group_positions = list(range(n_regions))

    for i, exp_name in enumerate(exp_names):
        # Centre the group of bars on the tick: the middle experiment (if any) sits at offset zero.
        shift = (i - (n_experiments - 1) / 2) * bar_width
        offsets = [pos + shift for pos in group_positions]
        ax.bar(offsets, data[exp_name], width=bar_width, label=exp_name)

    ax.set_xticks(group_positions)
    ax.set_xticklabels(region_labels)
    ax.set_xlabel("Region")
    ax.set_ylabel(f"{metric.name} ({metric.units})")
    ax.set_title(f"{metric.description}")
    ax.legend(title="Experiment")
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    fig.tight_layout()

    if show:
        plt.show()

    return fig
@mock.patch("access.profiling.manager.plot_bar_metrics")
def test_bar_chart_data(mock_plot, scaling_data):
    """Check what ProfilingManager.plot_bar_chart forwards to the plotting helper.

    An unknown component must raise ValueError. For a valid request, the per-experiment values,
    the (relabelled) region labels, the metric and the ``show`` flag passed to the mocked
    ``plot_bar_metrics`` are verified.
    """
    paths, ncpus, datasets = scaling_data
    manager = MockProfilingManager(paths, ncpus, datasets)

    # A component that is absent from the datasets is rejected.
    unknown_component_request = {
        "components": ["non_existing_component"],
        "regions": [["Region 1"]],
        "metric": tavg,
    }
    with pytest.raises(ValueError):
        manager.plot_bar_chart(**unknown_component_request)

    # Region selection, relabelling and experiment filtering combined in a single call.
    manager.plot_bar_chart(
        components=["component"],
        regions=[["Region 1", "Region 2"]],
        metric=tavg,
        region_relabel_map={"Region 1": "Total"},
        experiments=["1cpu", "4cpu"],
    )
    assert mock_plot.call_count == 1

    positional = mock_plot.call_args.args
    keywords = mock_plot.call_args.kwargs

    # First positional argument: experiment name -> list of values.
    values_by_experiment = positional[0]
    assert isinstance(values_by_experiment, dict)
    assert set(values_by_experiment.keys()) == {"1cpu", "4cpu"}
    assert values_by_experiment["1cpu"] == pytest.approx([600365.0, 2.345388])
    assert values_by_experiment["4cpu"] == pytest.approx([300182.5, 1.172694])

    # Second positional argument: display labels in input order.
    assert positional[1] == ["Total", "Region 2"]

    # Third positional argument: the metric itself.
    assert positional[2] == tavg

    # The default ``show`` value is forwarded as a keyword argument.
    assert keywords["show"] is True
@mock.patch("access.profiling.manager.plot_bar_metrics")
def test_bar_chart_all_experiments(mock_plot, scaling_data):
    """Check that plot_bar_chart falls back to every experiment when no selection is given."""
    paths, ncpus, datasets = scaling_data
    manager = MockProfilingManager(paths, ncpus, datasets)

    request = {
        "components": ["component"],
        "regions": [["Region 1"]],
        "metric": tavg,
        "show": False,
    }
    manager.plot_bar_chart(**request)

    forwarded = mock_plot.call_args
    plotted_experiments = set(forwarded.args[0].keys())
    assert plotted_experiments == {"1cpu", "4cpu", "2cpu"}
    assert forwarded.kwargs["show"] is False
def test_plot_bar_metrics_returns_figure():
    """Check the structure of the Figure produced by plot_bar_metrics for two experiments."""
    region_labels = ["Region 1", "Region 2"]
    data = {
        "exp_A": [100.0, 50.0],
        "exp_B": [80.0, 40.0],
    }

    fig = plot_bar_metrics(data, region_labels, tavg, show=False)
    assert isinstance(fig, Figure)

    ax = fig.axes[0]

    # Axis labels and title are derived from the metric.
    assert ax.get_xlabel() == "Region"
    y_label = ax.get_ylabel()
    assert tavg.name in y_label
    assert str(tavg.units) in y_label
    assert ax.get_title() == tavg.description

    # One x-tick per region, labelled in input order.
    tick_labels = [tick.get_text() for tick in ax.get_xticklabels()]
    assert tick_labels == region_labels

    # One legend entry per experiment.
    legend_labels = [entry.get_text() for entry in ax.get_legend().get_texts()]
    assert legend_labels == ["exp_A", "exp_B"]

    # 2 experiments * 2 regions = 4 bars.
    bars = ax.patches
    assert len(bars) == 4

    # matplotlib groups bars by series: first all exp_A bars, then all exp_B bars
    heights = [bar.get_height() for bar in bars]
    assert heights == pytest.approx([100.0, 50.0, 80.0, 40.0])


def test_plot_bar_metrics_single_experiment():
    """Check plot_bar_metrics when only one experiment is supplied."""
    region_labels = ["R1", "R2", "R3"]
    data = {"exp_A": [10.0, 20.0, 30.0]}

    ax = plot_bar_metrics(data, region_labels, tavg, show=False).axes[0]

    # One bar per region and a single legend entry.
    assert len(ax.patches) == 3
    legend_labels = [entry.get_text() for entry in ax.get_legend().get_texts()]
    assert legend_labels == ["exp_A"]
@mock.patch("access.profiling.plotting_utils.plt.show")
def test_plot_bar_metrics_show(mock_show):
    """Check that plt.show() is invoked only when ``show`` is True."""
    region_labels = ["R1"]
    data = {"exp_A": [10.0]}

    # show=True triggers exactly one call to plt.show().
    plot_bar_metrics(data, region_labels, tavg, show=True)
    assert mock_show.call_count == 1

    # show=False must leave the call count untouched.
    plot_bar_metrics(data, region_labels, tavg, show=False)
    assert mock_show.call_count == 1  # No additional call