Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"xarray",
"pint",
"pint-xarray",
"matplotlib",
]

[build-system]
Expand Down
2 changes: 1 addition & 1 deletion src/access/profiling/fms_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def read(self, stream: str) -> dict:
profiling_section = match.group(1)
for line in profiling_region_p.finditer(profiling_section):
stats["region"].append(line.group("region"))
for label, metric in zip(labels, self.metrics):
for label, metric in zip(labels, self.metrics, strict=True):
stats[metric].append(_convert_from_string(line.group(label)))

# Convert time fraction to percentage
Expand Down
3 changes: 2 additions & 1 deletion src/access/profiling/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def parse_data_series(self, logs: list[str], varname: str, vars: Iterable) -> xr
Dataset: Series profiling data.
"""
datasets = []
for var, log in zip(vars, logs):
for var, log in zip(vars, logs, strict=True):
data = self.read(log)
datasets.append(
xr.Dataset(
Expand All @@ -86,6 +86,7 @@ def parse_data_series(self, logs: list[str], varname: str, vars: Iterable) -> xr
xr.DataArray([data[metric]], dims=[varname, "region"]).pint.quantify(metric.units)
for metric in self.metrics
],
strict=True,
)
),
coords={varname: [var], "region": data["region"]},
Expand Down
63 changes: 63 additions & 0 deletions src/access/profiling/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
# SPDX-License-Identifier: Apache-2.0


def calculate_column_widths(table_data: list[list], first_col_fraction: float = None) -> list:
"""Calculate column widths based on content character length and required width for first column.

Args:
table_data (list[list]): Table data including headers. e.g.
[[ "ncpus", "col1", "col2", "col3"],
["region1", 0.1, 0.2, 0.3],
["region2", 1. , 2. , 3. ]]
first_col_fraction (float): If provided, controls the fraction of the table width
assigned to the first column. Default None.
If set to 0.0 or None, all columns have the same width.
Must be between 0.0 (inclusive) and 1.0 (exclusive).

Returns:
list : Column width fractions, adding up to 1.

Raises:
ValueError: If table_data has fewer than 2 rows or 2 columns.
ValueError: If table_data shape is not rectangular, i.e., all rows do not have the same number of columns.
"""
if not table_data:
return []

# Check that table has a header row and row-label column, and no missing elements.
if len(table_data) > 1:
for row in table_data[1:]:
if len(row) != len(table_data[0]):
raise ValueError("Table rows must have the same number of elements")
else:
raise ValueError("Table must have at least 2 rows (first row is table header)")
if len(table_data[0]) < 2:
raise ValueError("Table must have at least 2 columns (first column is row label)")

if first_col_fraction is not None and not (0 <= first_col_fraction < 1):
raise ValueError("first_col_fration must be between 0 and 1 (exclusive)")

n_cols = len(table_data[0])

# Calculate max content length for each column based on no. of chars
max_lengths = []
for col in range(n_cols):
col_lengths = [len(str(row[col])) for row in table_data]
max_lengths.append(max(col_lengths))

if first_col_fraction and first_col_fraction > 0:
# Set data columns to proportional widths based on content and first_col_fraction
data_cols_total = sum(max_lengths[1:])
base_width = (1 - first_col_fraction) / data_cols_total

col_widths = [first_col_fraction]
for length in max_lengths[1:]:
col_widths.append(length * base_width)

else:
# Equal column width
total_length = sum(max_lengths)
col_widths = [length / total_length for length in max_lengths]

return col_widths
98 changes: 98 additions & 0 deletions src/access/profiling/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

"""Functions to calculate metrics related to parallel scaling of applications."""

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import xarray as xr
from matplotlib.figure import Figure

from access.profiling.metrics import ProfilingMetric
from access.profiling.plotting_utils import calculate_column_widths


def parallel_speedup(stats: xr.Dataset, metric: ProfilingMetric) -> xr.DataArray:
Expand Down Expand Up @@ -42,3 +46,97 @@ def parallel_efficiency(stats: xr.Dataset, metric: ProfilingMetric) -> xr.DataAr
eff = eff.pint.to("percent")
eff.name = "parallel efficiency"
return eff


def plot_scaling_metrics(
    stats: list[xr.Dataset],
    regions: list[list[str]],
    metric: ProfilingMetric,
    xcoordinate: str = "ncpus",
    region_relabel_map: dict | None = None,
    first_col_fraction: float = 0.4,
    show: bool = True,
) -> Figure:
    """Plots parallel speedup and efficiency from a list of datasets

    Args:
        stats (list[xr.Dataset]): The raw times to plot.
        regions (list[list[str]]): The list of regions to plot.
                                   regions[0][:] corresponds to regions to plot in stats[0].
        metric (ProfilingMetric): The metric to plot for each stat.
        xcoordinate (str): The x-axis variable e.g. "ncpus".
        region_relabel_map (dict | None): Mapping of labels to use for each region instead of the region name.
                                   If an element of "regions" is a key in this map, the region
                                   will be replaced by the corresponding value in the plot.
                                   Default: None.
        first_col_fraction (float): The fraction of table width to assign to the row labels. Default 0.4.
        show (bool): Whether to show the generated plot. Default: True.

    Returns:
        Figure: The Matplotlib figure on which the scaling plots and table are plotted on.

    Raises:
        ValueError: If stats and regions do not have the same length (zip strict=True).
    """

    # set default relabel map
    if region_relabel_map is None:
        region_relabel_map = {}

    # setup plots
    fig = plt.figure(figsize=(15, 6))
    # using gridspec so table can be added
    gs = gridspec.GridSpec(2, 2, height_ratios=[3, 1], hspace=0.3)
    ax1, ax2 = fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])
    ax_tbl = fig.add_subplot(gs[1, :])

    # add table of raw timings
    # NOTE(review): header row comes from stats[0] only -- assumes every dataset shares
    # the same xcoordinate values; confirm for multi-dataset use.
    tbl = [[xcoordinate] + list(stats[0][xcoordinate].values)]  # first row
    # Track max efficiency across ALL datasets for the y-axis limit below.
    # (Initialising this inside the loop would reset it per dataset, so the
    # limit would only reflect the last dataset's curves.)
    max_eff = 100
    for stat, region in zip(stats, regions, strict=True):
        # calculate efficiency and speedup
        efficiency = parallel_efficiency(stat, metric)
        speedup = parallel_speedup(stat, metric)

        # plots speedup and efficiency on their respective axes.
        for r in region:
            # Fall back to the region name when no (truthy) relabel entry exists.
            label = region_relabel_map.get(r) or r
            speedup.loc[r, :].plot.line(x=xcoordinate, ax=ax1, marker="o", label=label)
            efficiency.loc[r, :].plot.line(x=xcoordinate, ax=ax2, marker="o", label=label)
            # find max efficiency for setting efficiency axis
            max_eff = max(max_eff, efficiency.loc[r, :].max())

            tbl.append([label] + [f"{val:.2f}" for val in stat[metric].loc[:, r].pint.dequantify().values])

    # ideal speedup/scaling
    # NOTE(review): uses the last dataset's xcoordinate values for the ideal lines;
    # assumes all datasets cover the same range -- confirm.
    minx = stat[xcoordinate].values.min()
    nx = len(stat[xcoordinate].values)
    ideal_speedups = [i / minx for i in stat[xcoordinate].values]
    ax1.plot(stat[xcoordinate].values, ideal_speedups, "k:", label="ideal")
    ax2.plot(stat[xcoordinate].values, [100] * nx, "k:", label="ideal")

    # formatting
    ax1.legend()
    ax1.grid()
    ax2.grid()
    ax2.set_ylim((0, 1.1 * max_eff))
    ax1.set_title("Parallel Speedup")
    ax2.set_title("Parallel Efficiency")
    ax_tbl.axis("off")
    tbl_chart = ax_tbl.table(
        tbl,
        bbox=(0.05, 0, 0.9, 1),
        cellLoc="center",
        colWidths=calculate_column_widths(tbl, first_col_fraction),
    )
    ax_tbl.set_title(f"Timings ({stat[metric].pint.units})")
    # Bold the header row and the row-label column.
    for i in range(len(tbl[0])):
        tbl_chart[(0, i)].set_text_props(weight="bold")
    for i in range(len(tbl)):
        tbl_chart[(i, 0)].set_text_props(weight="bold")

    if show:
        plt.show()

    return fig
77 changes: 77 additions & 0 deletions tests/test_plotting_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
# SPDX-License-Identifier: Apache-2.0

import pytest

from access.profiling.plotting_utils import calculate_column_widths


@pytest.fixture(scope="module")
def table_data():
    """Fixture returning sample table data.

    The third column is deliberately padded with spaces so that the column
    max content lengths are [6, 6, 12], giving the exact width fractions
    asserted in the tests ([0.25, 0.25, 0.5] and [0.4, 0.2, 0.4]).
    """
    return [
        ["Cell00", "Cell01", "Cell      02"],
        ["Cell10", "Cell11", "Cell      12"],
        ["Cell20", "Cell21", "Cell      22"],
        ["Cell30", "Cell31", "Cell32"],
    ]


@pytest.fixture(scope="module")
def nonrectangular_table_data():
    """Fixture returning invalid table data: the second row is one cell short."""
    ragged_rows = [
        ["Cell00", "Cell01", "Cell02"],
        ["Cell10", "Cell11"],
    ]
    return ragged_rows


@pytest.fixture(scope="module")
def singlerow_table_data():
    """Fixture returning invalid table data: a header row with no data rows."""
    header_only = [["Cell00", "Cell01", "Cell02"]]
    return header_only


@pytest.fixture(scope="module")
def singlecol_table_data():
    """Fixture returning invalid table data: a label column with no data columns."""
    labels_only = [
        ["Cell00"],
        ["Cell10"],
    ]
    return labels_only


def test_calculate_column_widths_flexible(table_data):
    """Test calculate_column_widths with a fixed first-column fraction."""
    # An empty table yields an empty width list.
    assert calculate_column_widths([]) == []

    # Multiple rows/columns: widths sum to 1, first column pinned at the requested
    # fraction and the remainder split proportionally to content length.
    widths = calculate_column_widths(table_data, 0.4)
    assert abs(sum(widths) - 1.0) < 1e-6  # Sum to 1.0
    assert widths == pytest.approx([0.4, 0.2, 0.4])  # Proportional to content length with first column flexible


def test_calculate_column_widths_fixed(table_data):
    """Test calculate_column_widths with no first-column fraction (equal treatment)."""
    widths = calculate_column_widths(table_data)
    assert abs(sum(widths) - 1.0) < 1e-6  # Sum to 1.0
    assert widths == [0.25, 0.25, 0.5]  # Proportional to content length


def test_wrong_column_widths(table_data):
    """Out-of-range first_col_fraction values must raise ValueError."""
    for bad_fraction in (1.0, -1.0):
        with pytest.raises(ValueError):
            calculate_column_widths(table_data, first_col_fraction=bad_fraction)


def test_invalid_tables(nonrectangular_table_data, singlerow_table_data, singlecol_table_data):
    """Malformed tables (ragged, single-row, single-column) must raise ValueError."""
    for bad_table in (nonrectangular_table_data, singlerow_table_data, singlecol_table_data):
        with pytest.raises(ValueError):
            calculate_column_widths(bad_table)
20 changes: 19 additions & 1 deletion tests/test_scaling.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details.
# SPDX-License-Identifier: Apache-2.0

from unittest import mock

import pint
import pytest
import xarray as xr

from access.profiling.metrics import count, tavg
from access.profiling.scaling import parallel_efficiency, parallel_speedup
from access.profiling.scaling import parallel_efficiency, parallel_speedup, plot_scaling_metrics


@pytest.fixture()
Expand Down Expand Up @@ -79,3 +84,16 @@ def test_incorrect_units(simple_scaling_data):
"""Test calculation with incorrect units."""
with pytest.raises(ValueError):
parallel_speedup(simple_scaling_data, count)


@mock.patch("matplotlib.pyplot.show", autospec=True)
def test_plot_scaling_metrics(mock_show, simple_scaling_data):
    """Test plotting scaling metrics. Currently only checks that the function runs without errors."""

    plot_scaling_metrics(
        stats=[simple_scaling_data],
        regions=[["Region 1", "Region 2"]],
        metric=tavg,
        xcoordinate="ncpus",
    )
    # show defaults to True, so plt.show must have been invoked exactly once.
    mock_show.assert_called_once()
Loading