From 14fdca3b8a434319226be600440d2ac3fc462a33 Mon Sep 17 00:00:00 2001 From: Micael Oliveira Date: Fri, 3 Oct 2025 12:41:04 +1000 Subject: [PATCH] Add function to plot scaling data. --- pyproject.toml | 1 + src/access/profiling/fms_parser.py | 2 +- src/access/profiling/parser.py | 3 +- src/access/profiling/plotting_utils.py | 63 +++++++++++++++++ src/access/profiling/scaling.py | 98 ++++++++++++++++++++++++++ tests/test_plotting_utils.py | 77 ++++++++++++++++++++ tests/test_scaling.py | 20 +++++- 7 files changed, 261 insertions(+), 3 deletions(-) create mode 100644 src/access/profiling/plotting_utils.py create mode 100644 tests/test_plotting_utils.py diff --git a/pyproject.toml b/pyproject.toml index 8b4b675..fe70314 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "xarray", "pint", "pint-xarray", + "matplotlib", ] [build-system] diff --git a/src/access/profiling/fms_parser.py b/src/access/profiling/fms_parser.py index fa1db5b..6cb68a2 100644 --- a/src/access/profiling/fms_parser.py +++ b/src/access/profiling/fms_parser.py @@ -67,7 +67,7 @@ def read(self, stream: str) -> dict: profiling_section = match.group(1) for line in profiling_region_p.finditer(profiling_section): stats["region"].append(line.group("region")) - for label, metric in zip(labels, self.metrics): + for label, metric in zip(labels, self.metrics, strict=True): stats[metric].append(_convert_from_string(line.group(label))) # Convert time fraction to percentage diff --git a/src/access/profiling/parser.py b/src/access/profiling/parser.py index a8177ee..687cd72 100644 --- a/src/access/profiling/parser.py +++ b/src/access/profiling/parser.py @@ -75,7 +75,7 @@ def parse_data_series(self, logs: list[str], varname: str, vars: Iterable) -> xr Dataset: Series profiling data. 
""" datasets = [] - for var, log in zip(vars, logs): + for var, log in zip(vars, logs, strict=True): data = self.read(log) datasets.append( xr.Dataset( @@ -86,6 +86,7 @@ def parse_data_series(self, logs: list[str], varname: str, vars: Iterable) -> xr xr.DataArray([data[metric]], dims=[varname, "region"]).pint.quantify(metric.units) for metric in self.metrics ], + strict=True, ) ), coords={varname: [var], "region": data["region"]}, diff --git a/src/access/profiling/plotting_utils.py b/src/access/profiling/plotting_utils.py new file mode 100644 index 0000000..7344bb2 --- /dev/null +++ b/src/access/profiling/plotting_utils.py @@ -0,0 +1,63 @@ +# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details. +# SPDX-License-Identifier: Apache-2.0 + + +def calculate_column_widths(table_data: list[list], first_col_fraction: float = None) -> list: + """Calculate column widths based on content character length and required width for first column. + + Args: + table_data (list[list]): Table data including headers. e.g. + [[ "ncpus", "col1", "col2", "col3"], + ["region1", 0.1, 0.2, 0.3], + ["region2", 1. , 2. , 3. ]] + first_col_fraction (float): If provided, controls the fraction of the table width + assigned to the first column. Default None. + If set to 0.0 or None, all columns have the same width. + Must be between 0.0 (inclusive) and 1.0 (exclusive). + + Returns: + list : Column width fractions, adding up to 1. + + Raises: + ValueError: If table_data has fewer than 2 rows or 2 columns. + ValueError: If table_data shape is not rectangular, i.e., all rows do not have the same number of columns. + """ + if not table_data: + return [] + + # Check that table has a header row and row-label column, and no missing elements. 
+ if len(table_data) > 1: + for row in table_data[1:]: + if len(row) != len(table_data[0]): + raise ValueError("Table rows must have the same number of elements") + else: + raise ValueError("Table must have at least 2 rows (first row is table header)") + if len(table_data[0]) < 2: + raise ValueError("Table must have at least 2 columns (first column is row label)") + + if first_col_fraction is not None and not (0 <= first_col_fraction < 1): + raise ValueError("first_col_fration must be between 0 and 1 (exclusive)") + + n_cols = len(table_data[0]) + + # Calculate max content length for each column based on no. of chars + max_lengths = [] + for col in range(n_cols): + col_lengths = [len(str(row[col])) for row in table_data] + max_lengths.append(max(col_lengths)) + + if first_col_fraction and first_col_fraction > 0: + # Set data columns to proportional widths based on content and first_col_fraction + data_cols_total = sum(max_lengths[1:]) + base_width = (1 - first_col_fraction) / data_cols_total + + col_widths = [first_col_fraction] + for length in max_lengths[1:]: + col_widths.append(length * base_width) + + else: + # Equal column width + total_length = sum(max_lengths) + col_widths = [length / total_length for length in max_lengths] + + return col_widths diff --git a/src/access/profiling/scaling.py b/src/access/profiling/scaling.py index 7dafeab..af24d9e 100644 --- a/src/access/profiling/scaling.py +++ b/src/access/profiling/scaling.py @@ -3,9 +3,13 @@ """Functions to calculate metrics related to parallel scaling of applications.""" +import matplotlib.gridspec as gridspec +import matplotlib.pyplot as plt import xarray as xr +from matplotlib.figure import Figure from access.profiling.metrics import ProfilingMetric +from access.profiling.plotting_utils import calculate_column_widths def parallel_speedup(stats: xr.Dataset, metric: ProfilingMetric) -> xr.DataArray: @@ -42,3 +46,97 @@ def parallel_efficiency(stats: xr.Dataset, metric: ProfilingMetric) -> xr.DataAr eff = 
def plot_scaling_metrics(
    stats: list[xr.Dataset],
    regions: list[list[str]],
    metric: ProfilingMetric,
    xcoordinate: str = "ncpus",
    region_relabel_map: dict | None = None,
    first_col_fraction: float = 0.4,
    show: bool = True,
) -> Figure:
    """Plots parallel speedup and efficiency from a list of datasets.

    Args:
        stats (list[xr.Dataset]): The raw times to plot.
        regions (list[list[str]]): The list of regions to plot.
            regions[0][:] corresponds to regions to plot in stats[0].
        metric (ProfilingMetric): The metric to plot for each stat.
        xcoordinate (str): The x-axis variable e.g. "ncpus".
        region_relabel_map (dict | None): Mapping of labels to use for each region instead of the
            region name. If an element of "regions" is a key in this map, the region
            will be replaced by the corresponding value in the plot.
            Default: None.
        first_col_fraction (float): The fraction of table width to assign to the row labels. Default 0.4.
        show (bool): Whether to show the generated plot. Default: True.

    Returns:
        Figure: The Matplotlib figure on which the scaling plots and table are plotted on.

    Raises:
        ValueError: If "stats" and "regions" do not have the same length.
    """

    # set default relabel map
    if region_relabel_map is None:
        region_relabel_map = {}

    # setup plots
    fig = plt.figure(figsize=(15, 6))
    # using gridspec so table can be added below the two plot axes
    gs = gridspec.GridSpec(2, 2, height_ratios=[3, 1], hspace=0.3)
    ax1, ax2 = fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])
    ax_tbl = fig.add_subplot(gs[1, :])

    # add table of raw timings; first row holds the x-coordinate values as headers
    tbl = [[xcoordinate] + list(stats[0][xcoordinate].values)]

    # Track the efficiency-axis upper bound across ALL datasets (super-linear scaling can
    # exceed 100%). BUGFIX: previously reset to 100 inside the per-stat loop, so the axis
    # limit only reflected the last dataset.
    max_eff = 100

    # strict=True raises ValueError when stats and regions differ in length
    for stat, region in zip(stats, regions, strict=True):
        # calculate efficiency and speedup
        efficiency = parallel_efficiency(stat, metric)
        speedup = parallel_speedup(stat, metric)

        # plot speedup and efficiency on their respective axes
        for r in region:
            label = region_relabel_map.get(r)
            label = label if label else r
            speedup.loc[r, :].plot.line(x=xcoordinate, ax=ax1, marker="o", label=label)
            efficiency.loc[r, :].plot.line(x=xcoordinate, ax=ax2, marker="o", label=label)
            # find max efficiency for setting efficiency axis
            max_eff = max(max_eff, efficiency.loc[r, :].max())

            # raw timings for this region, stripped of units for string formatting
            tbl.append([label] + [f"{val:.2f}" for val in stat[metric].loc[:, r].pint.dequantify().values])

        # ideal speedup/scaling reference lines (relative to the smallest x value)
        minx = stat[xcoordinate].values.min()
        nx = len(stat[xcoordinate].values)
        ideal_speedups = [i / minx for i in stat[xcoordinate].values]
        ax1.plot(stat[xcoordinate].values, ideal_speedups, "k:", label="ideal")
        ax2.plot(stat[xcoordinate].values, [100] * nx, "k:", label="ideal")

    # formatting
    ax1.legend()
    ax1.grid()
    ax2.grid()
    ax2.set_ylim((0, 1.1 * max_eff))
    ax1.set_title("Parallel Speedup")
    ax2.set_title("Parallel Efficiency")
    ax_tbl.axis("off")
    tbl_chart = ax_tbl.table(
        tbl,
        bbox=(0.05, 0, 0.9, 1),
        cellLoc="center",
        colWidths=calculate_column_widths(tbl, first_col_fraction),
    )
    # NOTE(review): units are read from the last dataset — assumes all stats share units; verify.
    ax_tbl.set_title(f"Timings ({stat[metric].pint.units})")
    # bold the header row and the row-label column
    for i in range(len(tbl[0])):
        tbl_chart[(0, i)].set_text_props(weight="bold")
    for i in range(len(tbl)):
        tbl_chart[(i, 0)].set_text_props(weight="bold")

    if show:
        plt.show()

    return fig
+# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from access.profiling.plotting_utils import calculate_column_widths + + +@pytest.fixture(scope="module") +def table_data(): + """Fixture returning sample table data.""" + return [ + ["Cell00", "Cell01", "Cell 02"], + ["Cell10", "Cell11", "Cell 12"], + ["Cell20", "Cell21", "Cell 22"], + ["Cell30", "Cell31", "Cell32"], + ] + + +@pytest.fixture(scope="module") +def nonrectangular_table_data(): + """Fixture returning sample invalid table data where table isn't rectangular.""" + return [ + ["Cell00", "Cell01", "Cell02"], + ["Cell10", "Cell11"], + ] + + +@pytest.fixture(scope="module") +def singlerow_table_data(): + """Fixture returning sample invalid table data where table only has 1 row.""" + return [ + ["Cell00", "Cell01", "Cell02"], + ] + + +@pytest.fixture(scope="module") +def singlecol_table_data(): + """Fixture returning sample invalid table data where table only has 1 column.""" + return [ + ["Cell00"], + ["Cell10"], + ] + + +def test_calculate_column_widths_flexible(table_data): + """Test the calculate_column_widths function.""" + # Test with empty table + assert calculate_column_widths([]) == [] + + # Test with multiple rows, multiple columns + col_widths = calculate_column_widths(table_data, 0.4) + assert abs(sum(col_widths) - 1.0) < 1e-6 # Sum to 1.0 + assert col_widths == pytest.approx([0.4, 0.2, 0.4]) # Proportional to content length with first column flexible + + +def test_calculate_column_widths_fixed(table_data): + # Test with first_col_flexible=False + col_widths = calculate_column_widths(table_data) + assert abs(sum(col_widths) - 1.0) < 1e-6 # Sum to 1.0 + assert col_widths == [0.25, 0.25, 0.5] # Proportional to content length + + +def test_wrong_column_widths(table_data): + with pytest.raises(ValueError): + calculate_column_widths(table_data, first_col_fraction=1.0) + with pytest.raises(ValueError): + calculate_column_widths(table_data, first_col_fraction=-1.0) + + +def 
test_invalid_tables(nonrectangular_table_data, singlerow_table_data, singlecol_table_data): + with pytest.raises(ValueError): + calculate_column_widths(nonrectangular_table_data) + with pytest.raises(ValueError): + calculate_column_widths(singlerow_table_data) + with pytest.raises(ValueError): + calculate_column_widths(singlecol_table_data) diff --git a/tests/test_scaling.py b/tests/test_scaling.py index c551ed0..e6642d5 100644 --- a/tests/test_scaling.py +++ b/tests/test_scaling.py @@ -1,9 +1,14 @@ +# Copyright 2025 ACCESS-NRI and contributors. See the top-level COPYRIGHT file for details. +# SPDX-License-Identifier: Apache-2.0 + +from unittest import mock + import pint import pytest import xarray as xr from access.profiling.metrics import count, tavg -from access.profiling.scaling import parallel_efficiency, parallel_speedup +from access.profiling.scaling import parallel_efficiency, parallel_speedup, plot_scaling_metrics @pytest.fixture() @@ -79,3 +84,16 @@ def test_incorrect_units(simple_scaling_data): """Test calculation with incorrect units.""" with pytest.raises(ValueError): parallel_speedup(simple_scaling_data, count) + + +@mock.patch("matplotlib.pyplot.show", autospec=True) +def test_plot_scaling_metrics(mock_plt, simple_scaling_data): + """Test plotting scaling metrics. Currently only checks that the function runs without errors.""" + + plot_scaling_metrics( + stats=[simple_scaling_data], + regions=[["Region 1", "Region 2"]], + metric=tavg, + xcoordinate="ncpus", + ) + mock_plt.assert_called_once()