Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/access/profiling/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from access.profiling.experiment import ProfilingExperiment, ProfilingExperimentStatus, ProfilingLog
from access.profiling.metrics import ProfilingMetric
from access.profiling.plotting_utils import plot_bar_metrics
from access.profiling.scaling import plot_scaling_metrics

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -238,3 +239,55 @@ def plot_scaling_data(
scaling_data.append(component_data)

return plot_scaling_metrics(scaling_data, metric)

def plot_bar_chart(
self,
components: list[str],
regions: list[list[str]],
metric: ProfilingMetric,
region_relabel_map: dict | None = None,
experiments: list[str] | None = None,
show: bool = True,
) -> Figure:
"""Plots a bar chart of a profiling metric over regions, grouped by experiment.

Regions are placed along the x-axis. Within each region group, there is one bar per
experiment, coloured by experiment name.

Args:
components (list[str]): List of component names to include.
regions (list[list[str]]): List of regions to include for each component.
metric (ProfilingMetric): Metric to plot.
region_relabel_map (dict | None): Optional mapping to relabel regions in the plot.
experiments (list[str] | None): Optional list of experiment names to include. If None, all experiments
are included.
show (bool): Whether to show the generated plot. Default: True.

Returns:
Figure: The Matplotlib figure containing the bar chart.

Raises:
ValueError: If no profiling data is found for a specified component in any experiment.
"""
exp_names = experiments if experiments is not None else list(self.data.keys())
relabel = region_relabel_map or {}

# Build a lookup from display label to (component, original_region) and preserve input order.
region_info: list[tuple[str, str, str]] = [] # (component, original_region, display_label)
for component, component_regions in zip(components, regions, strict=True):
for region in component_regions:
region_info.append((component, region, relabel.get(region, region)))
region_labels = [label for _, _, label in region_info]

# Extract metric values per experiment, reading directly from the datasets
bar_data: dict[str, list[float]] = {}
for exp_name in exp_names:
values = []
for component, region, _ in region_info:
ds = self.data[exp_name].get(component)
if ds is None:
raise ValueError(f"No profiling data found for component '{component}' in experiment '{exp_name}'.")
values.append(float(ds[metric].sel(region=region).pint.dequantify().values))
bar_data[exp_name] = values

return plot_bar_metrics(bar_data, region_labels, metric, show=show)
53 changes: 53 additions & 0 deletions src/access/profiling/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
# SPDX-License-Identifier: Apache-2.0

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from access.profiling.metrics import ProfilingMetric


def calculate_column_widths(table_data: list[list], first_col_fraction: float = None) -> list:
"""Calculate column widths based on content character length and required width for first column.
Expand Down Expand Up @@ -61,3 +66,51 @@ def calculate_column_widths(table_data: list[list], first_col_fraction: float =
col_widths = [length / total_length for length in max_lengths]

return col_widths


def plot_bar_metrics(
data: dict[str, list[float]],
region_labels: list[str],
metric: ProfilingMetric,
show: bool = True,
) -> Figure:
"""Plots a grouped bar chart of a profiling metric over regions.

Regions are placed along the x-axis. Within each region group, there is one bar per
experiment, coloured by experiment name.

Args:
data (dict[str, list[float]]): Mapping of experiment name to a list of metric values,
one per region (in the same order as ``region_labels``).
region_labels (list[str]): Ordered list of region display labels for the x-axis.
metric (ProfilingMetric): The metric being plotted (used for axis labels and title).
show (bool): Whether to call ``plt.show()``. Default: True.

Returns:
Figure: The Matplotlib figure containing the bar chart.
"""
exp_names = list(data.keys())
n_experiments = len(exp_names)
n_regions = len(region_labels)

fig, ax = plt.subplots(figsize=(max(8, n_experiments * n_regions * 0.8), 6))
bar_width = 0.8 / n_experiments
group_positions = list(range(n_regions))

for i, exp_name in enumerate(exp_names):
offsets = [pos + (i - (n_experiments - 1) / 2) * bar_width for pos in group_positions]
ax.bar(offsets, data[exp_name], width=bar_width, label=exp_name)

ax.set_xticks(group_positions)
ax.set_xticklabels(region_labels)
ax.set_xlabel("Region")
ax.set_ylabel(f"{metric.name} ({metric.units})")
ax.set_title(f"{metric.description}")
ax.legend(title="Experiment")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

if show:
plt.show()

return fig
63 changes: 63 additions & 0 deletions tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,66 @@ def test_scaling_data(mock_plot, scaling_data):
assert component_data[count].sel(region="Total").values.tolist() == [1, 1]
assert component_data[tavg].sel(region="Total").values.tolist() == [600365.0, 300182.5]
assert mock_plot.call_args.args[1] == tavg


@mock.patch("access.profiling.manager.plot_bar_metrics")
def test_bar_chart_data(mock_plot, scaling_data):
"""Test the plot_bar_chart method of ProfilingManager.

This test checks that bar chart data is correctly extracted from the datasets and that the
plotting function is called with the right arguments.
"""
paths, ncpus, datasets = scaling_data
manager = MockProfilingManager(paths, ncpus, datasets)

# Test plotting bar chart for non-existing component
with pytest.raises(ValueError):
manager.plot_bar_chart(
components=["non_existing_component"],
regions=[["Region 1"]],
metric=tavg,
)

# Test plotting bar chart with region selection, relabelling, and experiment filtering
manager.plot_bar_chart(
components=["component"],
regions=[["Region 1", "Region 2"]],
metric=tavg,
region_relabel_map={"Region 1": "Total"},
experiments=["1cpu", "4cpu"],
)
assert mock_plot.call_count == 1

# Verify the data dict passed to plot_bar_metrics
bar_data = mock_plot.call_args.args[0]
assert isinstance(bar_data, dict)
assert set(bar_data.keys()) == {"1cpu", "4cpu"}
assert bar_data["1cpu"] == pytest.approx([600365.0, 2.345388])
assert bar_data["4cpu"] == pytest.approx([300182.5, 1.172694])

# Verify region labels
region_labels = mock_plot.call_args.args[1]
assert region_labels == ["Total", "Region 2"]

# Verify metric
assert mock_plot.call_args.args[2] == tavg

# Verify show kwarg is passed through
assert mock_plot.call_args.kwargs["show"] is True


@mock.patch("access.profiling.manager.plot_bar_metrics")
def test_bar_chart_all_experiments(mock_plot, scaling_data):
"""Test plot_bar_chart includes all experiments when none are specified."""
paths, ncpus, datasets = scaling_data
manager = MockProfilingManager(paths, ncpus, datasets)

manager.plot_bar_chart(
components=["component"],
regions=[["Region 1"]],
metric=tavg,
show=False,
)
bar_data = mock_plot.call_args.args[0]
assert set(bar_data.keys()) == {"1cpu", "4cpu", "2cpu"}
assert mock_plot.call_args.kwargs["show"] is False
67 changes: 66 additions & 1 deletion tests/test_plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
# SPDX-License-Identifier: Apache-2.0

from unittest import mock

import pytest
from matplotlib.figure import Figure

from access.profiling.plotting_utils import calculate_column_widths
from access.profiling.metrics import tavg
from access.profiling.plotting_utils import calculate_column_widths, plot_bar_metrics


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -75,3 +79,64 @@ def test_invalid_tables(nonrectangular_table_data, singlerow_table_data, singlec
calculate_column_widths(singlerow_table_data)
with pytest.raises(ValueError):
calculate_column_widths(singlecol_table_data)


def test_plot_bar_metrics_returns_figure():
"""Test that plot_bar_metrics returns a Figure with correct structure."""
data = {
"exp_A": [100.0, 50.0],
"exp_B": [80.0, 40.0],
}
region_labels = ["Region 1", "Region 2"]

fig = plot_bar_metrics(data, region_labels, tavg, show=False)

assert isinstance(fig, Figure)
ax = fig.axes[0]

# Check axis labels and title
assert ax.get_xlabel() == "Region"
assert tavg.name in ax.get_ylabel()
assert str(tavg.units) in ax.get_ylabel()
assert ax.get_title() == tavg.description

# Check x-tick labels match the region labels
tick_labels = [t.get_text() for t in ax.get_xticklabels()]
assert tick_labels == region_labels

# Check legend shows experiment names
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
assert legend_labels == ["exp_A", "exp_B"]

# Check correct number of bars: 2 experiments * 2 regions = 4 bars
assert len(ax.patches) == 4

# matplotlib groups bars by series: first all exp_A bars, then all exp_B bars
heights = [p.get_height() for p in ax.patches]
assert heights == pytest.approx([100.0, 50.0, 80.0, 40.0])


def test_plot_bar_metrics_single_experiment():
"""Test plot_bar_metrics with a single experiment."""
data = {"exp_A": [10.0, 20.0, 30.0]}
region_labels = ["R1", "R2", "R3"]

fig = plot_bar_metrics(data, region_labels, tavg, show=False)
ax = fig.axes[0]

assert len(ax.patches) == 3
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
assert legend_labels == ["exp_A"]


@mock.patch("access.profiling.plotting_utils.plt.show")
def test_plot_bar_metrics_show(mock_show):
"""Test that plt.show() is called when show=True and not called when show=False."""
data = {"exp_A": [10.0]}
region_labels = ["R1"]

plot_bar_metrics(data, region_labels, tavg, show=True)
assert mock_show.call_count == 1

plot_bar_metrics(data, region_labels, tavg, show=False)
assert mock_show.call_count == 1 # No additional call
Loading