Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions gplately/mapping/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
# This submodule contains code to plot maps.
# The folder is named "mapping" to avoid name conflicts with the "plot" submodule.
# The folder name is inspired by GMT(The Generic Mapping Tools).
# The PlotEngine abstract base class is defined in plot_engine.py.
# There are different PlotEngine subclasses, CartopyPlotEngine and PygmtPlotEngine, for different plotting libraries,
# such as Cartopy and PyGMT.
41 changes: 41 additions & 0 deletions gplately/mapping/cartopy_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import cartopy.crs as ccrs
from geopandas.geodataframe import GeoDataFrame

from ..grids import Raster

from ..tools import EARTH_RADIUS
from ..utils.plot_utils import plot_subduction_teeth
from .plot_engine import PlotEngine
Expand Down Expand Up @@ -133,3 +135,42 @@ def plot_subduction_zones(
color=color,
**kwargs,
)

def plot_grid(
self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs
):
"""Plot a grid onto a map using Cartopy

Parameters
----------
ax_or_fig : cartopy.mpl.geoaxes.GeoAxes
Cartopy GeoAxes instance
grid : 2D array-like
The grid data to be plotted
projection : cartopy.crs.Projection
The projection to use for the grid
extent : tuple
The extent of the grid in the form (min_lon, max_lon, min_lat, max_lat)
**kwargs :
Keyword arguments for plotting the grid. See Matplotlib's ``imshow()`` keyword arguments
`here <https://matplotlib.org/3.5.1/api/_as_gen/matplotlib.axes.Axes.imshow.html>`__.

"""
# Override matplotlib default origin ('upper')
origin = kwargs.pop("origin", "lower")

if isinstance(grid, Raster):
# extract extent and origin
extent = grid.extent
origin = grid.origin
data = grid.data
else:
data = grid

return ax_or_fig.imshow(
data,
extent=extent,
transform=projection,
origin=origin,
**kwargs,
)
7 changes: 7 additions & 0 deletions gplately/mapping/plot_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,10 @@ def plot_subduction_zones(
):
"""Plot subduction zones with "teeth"(abstract method)"""
pass # This is an abstract method, no implementation here.

@abstractmethod
def plot_grid(
self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs
):
"""Plot a grid (abstract method). See :meth:`CartopyPlotEngine.plot_grid()` and :meth:`PygmtPlotEngine.plot_grid()` for details."""
pass # This is an abstract method, no implementation here.
59 changes: 59 additions & 0 deletions gplately/mapping/pygmt_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
import logging
from pathlib import Path

logger = logging.getLogger("gplately")
try:
Expand Down Expand Up @@ -135,3 +136,61 @@ def plot_subduction_zones(
fill=color,
style="f0.2/0.08+r+t",
)

def plot_grid(
self,
ax_or_fig,
grid,
projection=None,
extent=(-180, 180, -90, 90),
cmap="gmt/geo",
nan_transparent=False,
**kwargs,
):
"""Use PyGMT to plot a grid onto a map.

Parameters
----------
ax_or_fig : pygmt.Figure()
A PyGMT Figure object.
grid : Raster
A gplately Raster object or 2D array-like grid data.
projection : str
Not used currently.
extent : str or tuple
(xmin, xmax, ymin, ymax). See details at
https://www.pygmt.org/dev/tutorials/basics/regions.html
cmap : str
A built-in GMT colormaps name or a CPT file path.
nan_transparent : bool
If True, NaN values in the grid will be plotted as transparent.
**kwargs :
Additional keyword arguments.
"""
from ..grids import Raster
import xarray as xr

# we need to convert the grid data to xarray.DataArray for pygmt.grdimage().
if isinstance(grid, Raster):
data = xr.DataArray(
data=grid.data,
dims=["lat", "lon"],
coords=dict(
lon=(["lon"], grid.lons),
lat=(["lat"], grid.lats),
),
)
else:
data = xr.DataArray(grid)

# check exisence if cmap is a CPT file
if cmap.endswith(".cpt"):
if not Path(cmap).exists():
raise FileNotFoundError(f"The CPT file '{cmap}' does not exist.")

ax_or_fig.grdimage(
grid=data,
cmap=cmap,
region=extent,
nan_transparent=nan_transparent,
)
67 changes: 34 additions & 33 deletions gplately/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (C) 2024-2025 The University of Sydney, Australia
# Copyright (C) 2024-2026 The University of Sydney, Australia
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License, version 2, as published by
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def plot_subduction_teeth(
)

def plot_plate_polygon_by_id(self, ax, plate_id, color="black", **kwargs):
"""Plot a plate polygon with the given``plate_id`` on a map.
"""Plot a plate polygon with the given ``plate_id`` on a map.

Parameters
----------
Expand Down Expand Up @@ -1221,53 +1221,54 @@ def plot_plate_id(self, *args, **kwargs):
return self.plot_plate_polygon_by_id(*args, **kwargs)

def plot_grid(self, ax, grid, extent=(-180, 180, -90, 90), **kwargs):
"""Plot a `MaskedArray`_ raster or grid onto a map.

.. note::

Plotting grid with pygmt has not been implemented yet!
"""Plot a grid onto a map. The grid can be a NumPy `MaskedArray`_ object, a GPlately `Raster` object
or a time-dependent raster name.

Parameters
----------
ax :
Cartopy ax.
Cartopy ax or pygmt figure object.

grid : MaskedArray or Raster
grid : NumPy `MaskedArray`_, GPlately `Raster` or a time-dependent raster name.
A `MaskedArray`_ with elements that define a grid. The number of rows in the raster
corresponds to the number of latitudinal coordinates, while the number of raster
columns corresponds to the number of longitudinal coordinates.
Alternatively, a GPlately `Raster` object can be provided.
If a raster name is provided, the raster will be looked up from the time-dependent rasters registered in the Plate Model Manager.
The :class:`gplately.PlateReconstruction` object must be created with a valid :class:`PlateModel` object.

extent : tuple, default=(-180, 180, -90, 90)
A tuple of 4 (min_lon, max_lon, min_lat, max_lat) representing the extent of gird.

**kwargs :
Keyword arguments for plotting the grid.
See Matplotlib's ``imshow()`` keyword arguments
`here <https://matplotlib.org/3.5.1/api/_as_gen/matplotlib.axes.Axes.imshow.html>`__.

.. note::

The parameters of this function are different for different plot engines. See :meth:`CartopyPlotEngine.plot_grid`
and :meth:`PyGMTPlotEngine.plot_grid` for details.

"""
if not isinstance(self._plot_engine, CartopyPlotEngine):
raise NotImplementedError(
f"Plotting grid has not been implemented for {self._plot_engine.__class__} yet."
)
# Override matplotlib default origin ('upper')
origin = kwargs.pop("origin", "lower")

if isinstance(grid, Raster):
# extract extent and origin
extent = grid.extent
origin = grid.origin
data = grid.data
else:
data = grid
# TODO: the parameters of this function need to be unified for different plot engines.

return ax.imshow(
data,
extent=extent,
transform=self.base_projection,
origin=origin,
**kwargs,
)
if isinstance(grid, str): # grid is a raster name
if not self.plate_reconstruction.plate_model:
raise Exception(
"The 'plate_reconstruction' does not have a valid 'plate_model' object. "
"Cannot look up the raster by name. Make sure to create the 'plate_reconstruction' with a valid 'plate_model' object."
)

grid_data = Raster(
data=self.plate_reconstruction.plate_model.get_raster(grid, self.time),
plate_reconstruction=self.plate_reconstruction,
extent=(-180, 180, -90, 90),
)
return self._plot_engine.plot_grid(
ax, grid_data, extent=extent, projection=self.base_projection, **kwargs
)
else: # grid is a MaskedArray or Raster object
return self._plot_engine.plot_grid(
ax, grid, extent=extent, projection=self.base_projection, **kwargs
)

def plot_grid_from_netCDF(self, ax, filename, **kwargs):
"""Read raster data from a netCDF file, convert the data into a `MaskedArray`_ object and plot it on a map.
Expand Down
51 changes: 35 additions & 16 deletions tests-dir/unittest/test_pygmt_plot.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,69 @@
#!/usr/bin/env python3

# This test script generates a sample plot using the PygmtPlotEngine.

import os

import pygmt

os.environ["DISABLE_GPLATELY_DEV_WARNING"] = "true"

from gplately.auxiliary import get_gplot, get_pygmt_basemap_figure
from gplately.mapping.pygmt_plot import PygmtPlotEngine

if __name__ == "__main__":
model_name = "muller2019"
reconstruction_name = 55

gplot = get_gplot(
"merdith2021", "plate-model-repo", time=55, plot_engine=PygmtPlotEngine()
model_name,
"plate-model-repo",
time=reconstruction_name,
plot_engine=PygmtPlotEngine(),
)
fig = get_pygmt_basemap_figure(projection="N180/10c", region="d")

gplot.plot_grid(
fig, "AgeGrids", cmap="create-age-grids-video/agegrid.cpt", nan_transparent=True
)

# fig.coast(shorelines=True)

gplot.plot_coastlines(
fig,
edgecolor="none",
facecolor="gray",
linewidth=0.1,
central_meridian=180,
gmtlabel="Coastlines",
)
gplot.plot_topological_plate_boundaries(
fig,
edgecolor="black",
linewidth=0.25,
central_meridian=180,
gmtlabel="plate boundaries",
)
gplot.plot_coastlines(
fig, edgecolor="none", facecolor="gray", linewidth=0.1, central_meridian=180
gmtlabel="Plate Boundaries",
)
gplot.plot_ridges(fig, pen="0.5p,red", gmtlabel="ridges")
gplot.plot_transforms(fig, pen="0.5p,red", gmtlabel="transforms")
gplot.plot_subduction_teeth(fig, color="blue", gmtlabel="subduction zones")

try:
gplot.plot_grid(fig, None)
except NotImplementedError as e:
print(e)
gplot.plot_ridges(fig, pen="0.5p,black", gmtlabel="Ridges")
gplot.plot_transforms(fig, pen="0.5p,green", gmtlabel="Transforms")
gplot.plot_subduction_teeth(fig, color="blue", gmtlabel="Subduction Zones")

try:
gplot.plot_plate_motion_vectors(fig)
except NotImplementedError as e:
print(e)

fig.text(
text="55Ma (Merdith2021)",
text=f"{reconstruction_name}Ma",
position="TC",
no_clip=True,
font="12p,Helvetica,black",
offset="j0/-0.5c",
)
fig.legend(position="jBL+o-2.7/0", box="+gwhite+p0.5p")
with pygmt.config(FONT_ANNOT_PRIMARY=4):
fig.legend(position="jBL+o-1.0/0", box="+gwhite+p0.25p")

# fig.show(width=1200)
fig.savefig("./output/test-pygmt-plot.pdf")
output_file = "./output/test-pygmt-plot.pdf"
fig.savefig(output_file)
print(f"The figure has been saved to: {output_file}.")