From 9b2abd463ace0974dfb4c9c284e793e4be9e1981 Mon Sep 17 00:00:00 2001 From: michaelchin Date: Thu, 8 Jan 2026 14:57:04 +1100 Subject: [PATCH 1/3] initial draft to plot grid with pygmt --- gplately/mapping/cartopy_plot.py | 38 +++++++++++++++++ gplately/mapping/plot_engine.py | 7 ++++ gplately/mapping/pygmt_plot.py | 60 +++++++++++++++++++++++++++ gplately/plot.py | 22 +--------- tests-dir/unittest/test_pygmt_plot.py | 16 ++++--- 5 files changed, 117 insertions(+), 26 deletions(-) diff --git a/gplately/mapping/cartopy_plot.py b/gplately/mapping/cartopy_plot.py index 42d3bf53..e5b586e3 100644 --- a/gplately/mapping/cartopy_plot.py +++ b/gplately/mapping/cartopy_plot.py @@ -20,6 +20,8 @@ import cartopy.crs as ccrs from geopandas.geodataframe import GeoDataFrame +from ..grids import Raster + from ..tools import EARTH_RADIUS from ..utils.plot_utils import plot_subduction_teeth from .plot_engine import PlotEngine @@ -133,3 +135,39 @@ def plot_subduction_zones( color=color, **kwargs, ) + + def plot_grid( + self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs + ): + """Plot a grid onto a map using Cartopy + + Parameters + ---------- + ax_or_fig : cartopy.mpl.geoaxes.GeoAxes + Cartopy GeoAxes instance + grid : 2D array-like + The grid data to be plotted + projection : cartopy.crs.Projection + The projection to use for the grid + extent : tuple + The extent of the grid in the form (min_lon, max_lon, min_lat, max_lat) + + """ + # Override matplotlib default origin ('upper') + origin = kwargs.pop("origin", "lower") + + if isinstance(grid, Raster): + # extract extent and origin + extent = grid.extent + origin = grid.origin + data = grid.data + else: + data = grid + + return ax_or_fig.imshow( + data, + extent=extent, + transform=projection, + origin=origin, + **kwargs, + ) diff --git a/gplately/mapping/plot_engine.py b/gplately/mapping/plot_engine.py index 1606e422..d5962a91 100644 --- a/gplately/mapping/plot_engine.py +++ b/gplately/mapping/plot_engine.py @@ -45,3 +45,10 @@ def plot_subduction_zones( ): """Plot subduction zones with "teeth"(abstract method)""" pass # This is an abstract method, no implementation here. + + @abstractmethod + def plot_grid( + self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs + ): + """Plot a grid (abstract method)""" + pass # This is an abstract method, no implementation here. diff --git a/gplately/mapping/pygmt_plot.py b/gplately/mapping/pygmt_plot.py index 7dbb6ea1..78c7480c 100644 --- a/gplately/mapping/pygmt_plot.py +++ b/gplately/mapping/pygmt_plot.py @@ -135,3 +135,63 @@ def plot_subduction_zones( fill=color, style="f0.2/0.08+r+t", ) + + def plot_grid( + self, + ax_or_fig, + grid, + projection=None, + extent=(-180, 180, -90, 90), + nan_transparent=False, + **kwargs, + ): + """Use PyGMT to plot a grid onto a map. + + Parameters + ---------- + ax_or_fig : pygmt.Figure() + pygmt Figure object + grid : Raster + gplately Raster object or 2D array-like grid data + projection : str + GMT projection string, e.g., "M6i" for Mercator projection with 6-inch width. + extent : tuple + (min_lon, max_lon, min_lat, max_lat) + cmap : str + Colormap name + shading : str + Shading method, e.g., "a" for artificial illumination. + """ + from ..grids import Raster + import xarray as xr + + if isinstance(grid, Raster): + # extract extent and origin + extent = grid.extent + origin = grid.origin + data = xr.DataArray( + data=grid.data, + dims=["lat", "lon"], + coords=dict( + lon=(["lon"], grid.lons), + lat=(["lat"], grid.lats), + ), + ) + else: + data = xr.DataArray(grid) + + region = [extent[0], extent[1], extent[2], extent[3]] + + ax_or_fig.grdimage( + grid=data, + cmap="gmt/geo", + nan_transparent=nan_transparent, + ) + """ + region=region, + projection=projection, + # cmap=cmap, cmap="YlGnBu", + # shading=shading, + frame=False, + **kwargs, + )""" diff --git a/gplately/plot.py b/gplately/plot.py index c0fe79a7..f61df07b 100644 --- a/gplately/plot.py +++ b/gplately/plot.py @@ -1246,27 +1246,9 @@ def plot_grid(self, ax, grid, extent=(-180, 180, -90, 90), **kwargs): `here `__. """ - if not isinstance(self._plot_engine, CartopyPlotEngine): - raise NotImplementedError( - f"Plotting grid has not been implemented for {self._plot_engine.__class__} yet." - ) - # Override matplotlib default origin ('upper') - origin = kwargs.pop("origin", "lower") - - if isinstance(grid, Raster): - # extract extent and origin - extent = grid.extent - origin = grid.origin - data = grid.data - else: - data = grid - return ax.imshow( - data, - extent=extent, - transform=self.base_projection, - origin=origin, - **kwargs, + return self._plot_engine.plot_grid( + ax, grid, extent=extent, projection=self.base_projection, **kwargs ) def plot_grid_from_netCDF(self, ax, filename, **kwargs): diff --git a/tests-dir/unittest/test_pygmt_plot.py b/tests-dir/unittest/test_pygmt_plot.py index d97a7774..68c1b537 100755 --- a/tests-dir/unittest/test_pygmt_plot.py +++ b/tests-dir/unittest/test_pygmt_plot.py @@ -3,14 +3,23 @@ os.environ["DISABLE_GPLATELY_DEV_WARNING"] = "true" +from gplately import Raster from gplately.auxiliary import get_gplot, get_pygmt_basemap_figure from gplately.mapping.pygmt_plot import PygmtPlotEngine if __name__ == "__main__": gplot = get_gplot( - "merdith2021", "plate-model-repo", time=55, plot_engine=PygmtPlotEngine() + "muller2019", "plate-model-repo", time=55, plot_engine=PygmtPlotEngine() ) fig = get_pygmt_basemap_figure(projection="N180/10c", region="d") + + age_grid_raster = Raster( + data=gplot.plate_reconstruction.plate_model.get_raster("AgeGrids", 55), + plate_reconstruction=gplot.plate_reconstruction, + extent=(-180, 180, -90, 90), + ) + gplot.plot_grid(fig, age_grid_raster, nan_transparent=True) + # fig.coast(shorelines=True) gplot.plot_topological_plate_boundaries( @@ -27,11 +36,6 @@ gplot.plot_transforms(fig, pen="0.5p,red", gmtlabel="transforms") gplot.plot_subduction_teeth(fig, color="blue", gmtlabel="subduction zones") - try: - gplot.plot_grid(fig, None) - except NotImplementedError as e: - print(e) - try: gplot.plot_plate_motion_vectors(fig) except NotImplementedError as e: From 3a75de10a46e29dba129f08ec3a530a41c5aec2b Mon Sep 17 00:00:00 2001 From: michaelchin Date: Thu, 8 Jan 2026 15:59:32 +1100 Subject: [PATCH 2/3] work to plot time dependent raster by name --- gplately/mapping/cartopy_plot.py | 3 ++ gplately/plot.py | 49 ++++++++++++++++++--------- tests-dir/unittest/test_pygmt_plot.py | 22 ++++++------ 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/gplately/mapping/cartopy_plot.py b/gplately/mapping/cartopy_plot.py index e5b586e3..dcccfdea 100644 --- a/gplately/mapping/cartopy_plot.py +++ b/gplately/mapping/cartopy_plot.py @@ -151,6 +151,9 @@ def plot_grid( The projection to use for the grid extent : tuple The extent of the grid in the form (min_lon, max_lon, min_lat, max_lat) + **kwargs : + Keyword arguments for plotting the grid. See Matplotlib's ``imshow()`` keyword arguments + `here `__. """ # Override matplotlib default origin ('upper') diff --git a/gplately/plot.py b/gplately/plot.py index f61df07b..971ffe01 100644 --- a/gplately/plot.py +++ b/gplately/plot.py @@ -1,5 +1,5 @@ # -# Copyright (C) 2024-2025 The University of Sydney, Australia +# Copyright (C) 2024-2026 The University of Sydney, Australia # # This program is free software; you can redistribute it and/or modify it under # the terms of the GNU General Public License, version 2, as published by @@ -1175,7 +1175,7 @@ def plot_subduction_teeth( ) def plot_plate_polygon_by_id(self, ax, plate_id, color="black", **kwargs): - """Plot a plate polygon with the given``plate_id`` on a map. + """Plot a plate polygon with the given ``plate_id`` on a map. Parameters ---------- @@ -1221,35 +1221,52 @@ def plot_plate_id(self, *args, **kwargs): return self.plot_plate_polygon_by_id(*args, **kwargs) def plot_grid(self, ax, grid, extent=(-180, 180, -90, 90), **kwargs): - """Plot a `MaskedArray`_ raster or grid onto a map. - - .. note:: - - Plotting grid with pygmt has not been implemented yet! + """Plot a grid onto a map. The grid can be a NumPy `MaskedArray`_ object, a GPlately `Raster` object + or a time-dependent raster name. Parameters ---------- ax : - Cartopy ax. + Cartopy ax or pygmt figure object. - grid : MaskedArray or Raster + grid : NumPy `MaskedArray`_, GPlately `Raster` or a time-dependent raster name. A `MaskedArray`_ with elements that define a grid. The number of rows in the raster corresponds to the number of latitudinal coordinates, while the number of raster columns corresponds to the number of longitudinal coordinates. + Alternatively, a GPlately `Raster` object can be provided. + If a raster name is provided, the raster will be looked up from the time-dependent rasters registered in the Plate Model Manager. + The :class:`gplately.PlateReconstruction` object must be created with a valid :class:`PlateModel` object. extent : tuple, default=(-180, 180, -90, 90) A tuple of 4 (min_lon, max_lon, min_lat, max_lat) representing the extent of gird. - **kwargs : - Keyword arguments for plotting the grid. - See Matplotlib's ``imshow()`` keyword arguments - `here `__. + + .. note:: + + The parameters of this function are different for different plot engines. See `CartopyPlotEngine.plot_grid` + and `PyGMTPlotEngine.plot_grid` for details. """ - return self._plot_engine.plot_grid( - ax, grid, extent=extent, projection=self.base_projection, **kwargs - ) + if isinstance(grid, str): # grid is a raster name + if not self.plate_reconstruction.plate_model: + raise Exception( + "The 'plate_reconstruction' does not have a valid 'plate_model'. " + "Cannot look up the raster by name. Make sure to create the 'plate_reconstruction' with a valid 'plate_model'." + ) + + grid_data = Raster( + data=self.plate_reconstruction.plate_model.get_raster(grid, self.time), + plate_reconstruction=self.plate_reconstruction, + extent=(-180, 180, -90, 90), + ) + return self._plot_engine.plot_grid( + ax, grid_data, extent=extent, projection=self.base_projection, **kwargs + ) + else: # grid is a MaskedArray or Raster object + return self._plot_engine.plot_grid( + ax, grid, extent=extent, projection=self.base_projection, **kwargs + ) def plot_grid_from_netCDF(self, ax, filename, **kwargs): """Read raster data from a netCDF file, convert the data into a `MaskedArray`_ object and plot it on a map. diff --git a/tests-dir/unittest/test_pygmt_plot.py b/tests-dir/unittest/test_pygmt_plot.py index 68c1b537..1ccffe09 100755 --- a/tests-dir/unittest/test_pygmt_plot.py +++ b/tests-dir/unittest/test_pygmt_plot.py @@ -3,22 +3,22 @@ os.environ["DISABLE_GPLATELY_DEV_WARNING"] = "true" -from gplately import Raster from gplately.auxiliary import get_gplot, get_pygmt_basemap_figure from gplately.mapping.pygmt_plot import PygmtPlotEngine if __name__ == "__main__": + model_name = "muller2019" + reconstruction_name = 55 + gplot = get_gplot( - "muller2019", "plate-model-repo", time=55, plot_engine=PygmtPlotEngine() + model_name, + "plate-model-repo", + time=reconstruction_name, + plot_engine=PygmtPlotEngine(), ) fig = get_pygmt_basemap_figure(projection="N180/10c", region="d") - age_grid_raster = Raster( - data=gplot.plate_reconstruction.plate_model.get_raster("AgeGrids", 55), - plate_reconstruction=gplot.plate_reconstruction, - extent=(-180, 180, -90, 90), - ) - gplot.plot_grid(fig, age_grid_raster, nan_transparent=True) + gplot.plot_grid(fig, "AgeGrids", nan_transparent=True) # fig.coast(shorelines=True) @@ -42,7 +42,7 @@ print(e) fig.text( - text="55Ma (Merdith2021)", + text=f"{reconstruction_name}Ma ({model_name})", position="TC", no_clip=True, font="12p,Helvetica,black", @@ -51,4 +51,6 @@ fig.legend(position="jBL+o-2.7/0", box="+gwhite+p0.5p") # fig.show(width=1200) - fig.savefig("./output/test-pygmt-plot.pdf") + output_file = "./output/test-pygmt-plot.pdf" + fig.savefig(output_file) + print(f"The figure has been saved to: {output_file}.") From d9e311fafd6bdae605a0ae7c07ec10b3956b50a5 Mon Sep 17 00:00:00 2001 From: michaelchin Date: Thu, 15 Jan 2026 17:04:16 +1100 Subject: [PATCH 3/3] finish up plotting grid with pygmt --- gplately/mapping/__init__.py | 5 ++++ gplately/mapping/plot_engine.py | 2 +- gplately/mapping/pygmt_plot.py | 41 +++++++++++++-------------- gplately/plot.py | 10 ++++--- tests-dir/unittest/test_pygmt_plot.py | 33 ++++++++++++++------- 5 files changed, 55 insertions(+), 36 deletions(-) diff --git a/gplately/mapping/__init__.py b/gplately/mapping/__init__.py index 2bf22597..cf0a93bb 100644 --- a/gplately/mapping/__init__.py +++ b/gplately/mapping/__init__.py @@ -1 +1,6 @@ # This submodule contains code to plot maps. +# The folder is named "mapping" to avoid name conflicts with the "plot" submodule. +# The folder name is inspired by GMT(The Generic Mapping Tools). +# The PlotEngine abstract base class is defined in plot_engine.py. +# There are different PlotEngine subclasses, CartopyPlotEngine and PygmtPlotEngine, for different plotting libraries, +# such as Cartopy and PyGMT. diff --git a/gplately/mapping/plot_engine.py b/gplately/mapping/plot_engine.py index d5962a91..e21964b5 100644 --- a/gplately/mapping/plot_engine.py +++ b/gplately/mapping/plot_engine.py @@ -50,5 +50,5 @@ def plot_subduction_zones( def plot_grid( self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs ): - """Plot a grid (abstract method)""" + """Plot a grid (abstract method). See :meth:`CartopyPlotEngine.plot_grid()` and :meth:`PygmtPlotEngine.plot_grid()` for details.""" pass # This is an abstract method, no implementation here. diff --git a/gplately/mapping/pygmt_plot.py b/gplately/mapping/pygmt_plot.py index 78c7480c..5cdbc613 100644 --- a/gplately/mapping/pygmt_plot.py +++ b/gplately/mapping/pygmt_plot.py @@ -15,6 +15,7 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # import logging +from pathlib import Path logger = logging.getLogger("gplately") try: @@ -142,6 +143,7 @@ def plot_grid( grid, projection=None, extent=(-180, 180, -90, 90), + cmap="gmt/geo", nan_transparent=False, **kwargs, ): @@ -150,25 +152,26 @@ def plot_grid( Parameters ---------- ax_or_fig : pygmt.Figure() - pygmt Figure object + A PyGMT Figure object. grid : Raster - gplately Raster object or 2D array-like grid data + A gplately Raster object or 2D array-like grid data. projection : str - GMT projection string, e.g., "M6i" for Mercator projection with 6-inch width. - extent : tuple - (min_lon, max_lon, min_lat, max_lat) + Not used currently. + extent : str or tuple + (xmin, xmax, ymin, ymax). See details at + https://www.pygmt.org/dev/tutorials/basics/regions.html cmap : str - Colormap name - shading : str - Shading method, e.g., "a" for artificial illumination. + A built-in GMT colormaps name or a CPT file path. + nan_transparent : bool + If True, NaN values in the grid will be plotted as transparent. + **kwargs : + Additional keyword arguments. """ from ..grids import Raster import xarray as xr + # we need to convert the grid data to xarray.DataArray for pygmt.grdimage(). if isinstance(grid, Raster): - # extract extent and origin - extent = grid.extent - origin = grid.origin data = xr.DataArray( data=grid.data, dims=["lat", "lon"], @@ -180,18 +183,14 @@ def plot_grid( else: data = xr.DataArray(grid) - region = [extent[0], extent[1], extent[2], extent[3]] + # check exisence if cmap is a CPT file + if cmap.endswith(".cpt"): + if not Path(cmap).exists(): + raise FileNotFoundError(f"The CPT file '{cmap}' does not exist.") ax_or_fig.grdimage( grid=data, - cmap="gmt/geo", + cmap=cmap, + region=extent, nan_transparent=nan_transparent, ) - """ - region=region, - projection=projection, - # cmap=cmap, cmap="YlGnBu", - # shading=shading, - frame=False, - **kwargs, - )""" diff --git a/gplately/plot.py b/gplately/plot.py index 971ffe01..619c59d8 100644 --- a/gplately/plot.py +++ b/gplately/plot.py @@ -1243,16 +1243,18 @@ def plot_grid(self, ax, grid, extent=(-180, 180, -90, 90), **kwargs): .. note:: - The parameters of this function are different for different plot engines. See `CartopyPlotEngine.plot_grid` - and `PyGMTPlotEngine.plot_grid` for details. + The parameters of this function are different for different plot engines. See :meth:`CartopyPlotEngine.plot_grid` + and :meth:`PyGMTPlotEngine.plot_grid` for details. """ + # TODO: the parameters of this function need to be unified for different plot engines. + if isinstance(grid, str): # grid is a raster name if not self.plate_reconstruction.plate_model: raise Exception( - "The 'plate_reconstruction' does not have a valid 'plate_model'. " - "Cannot look up the raster by name. Make sure to create the 'plate_reconstruction' with a valid 'plate_model'." + "The 'plate_reconstruction' does not have a valid 'plate_model' object. " + "Cannot look up the raster by name. Make sure to create the 'plate_reconstruction' with a valid 'plate_model' object." ) grid_data = Raster( diff --git a/tests-dir/unittest/test_pygmt_plot.py b/tests-dir/unittest/test_pygmt_plot.py index 1ccffe09..ace85df2 100755 --- a/tests-dir/unittest/test_pygmt_plot.py +++ b/tests-dir/unittest/test_pygmt_plot.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 + +# This test script generates a sample plot using the PygmtPlotEngine. + import os +import pygmt + os.environ["DISABLE_GPLATELY_DEV_WARNING"] = "true" from gplately.auxiliary import get_gplot, get_pygmt_basemap_figure @@ -18,23 +23,30 @@ ) fig = get_pygmt_basemap_figure(projection="N180/10c", region="d") - gplot.plot_grid(fig, "AgeGrids", nan_transparent=True) + gplot.plot_grid( + fig, "AgeGrids", cmap="create-age-grids-video/agegrid.cpt", nan_transparent=True + ) # fig.coast(shorelines=True) + gplot.plot_coastlines( + fig, + edgecolor="none", + facecolor="gray", + linewidth=0.1, + central_meridian=180, + gmtlabel="Coastlines", + ) gplot.plot_topological_plate_boundaries( fig, edgecolor="black", linewidth=0.25, central_meridian=180, - gmtlabel="plate boundaries", - ) - gplot.plot_coastlines( - fig, edgecolor="none", facecolor="gray", linewidth=0.1, central_meridian=180 + gmtlabel="Plate Boundaries", ) - gplot.plot_ridges(fig, pen="0.5p,red", gmtlabel="ridges") - gplot.plot_transforms(fig, pen="0.5p,red", gmtlabel="transforms") - gplot.plot_subduction_teeth(fig, color="blue", gmtlabel="subduction zones") + gplot.plot_ridges(fig, pen="0.5p,black", gmtlabel="Ridges") + gplot.plot_transforms(fig, pen="0.5p,green", gmtlabel="Transforms") + gplot.plot_subduction_teeth(fig, color="blue", gmtlabel="Subduction Zones") try: gplot.plot_plate_motion_vectors(fig) @@ -42,13 +54,14 @@ print(e) fig.text( - text=f"{reconstruction_name}Ma ({model_name})", + text=f"{reconstruction_name}Ma", position="TC", no_clip=True, font="12p,Helvetica,black", offset="j0/-0.5c", ) - fig.legend(position="jBL+o-2.7/0", box="+gwhite+p0.5p") + with pygmt.config(FONT_ANNOT_PRIMARY=4): + fig.legend(position="jBL+o-1.0/0", box="+gwhite+p0.25p") # fig.show(width=1200) output_file = "./output/test-pygmt-plot.pdf"