Skip to content

Commit f8e225e

Browse files
committed
feat(plotting): Replace group_cmaps with group_colors using OKLab interpolation
- Implemented group_colors parameter accepting dict of colors - Added OKLab white-to-color gradient generation in _utils.py - Updated DotPlot to use dynamic ListedColormaps - Added colour-science as optional dependency - Updated default legend width to 2.0 to fix layout bugs - Updated reference images for dotplots and affected plots
1 parent 525aaa6 commit f8e225e

File tree

42 files changed

+221
-52
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+221
-52
lines changed

docs/release-notes/3764.feature.md

Lines changed: 1 addition & 1 deletion

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ skmisc = [ "scikit-misc>=0.5.1" ] # highly_variable_genes m
144144
harmony = [ "harmonypy" ] # Harmony dataset integration
145145
scanorama = [ "scanorama" ] # Scanorama dataset integration
146146
scrublet = [ "scikit-image>=0.23" ] # Doublet detection with automatic thresholds
147+
# Plotting
148+
plotting = [ "colour-science" ]
147149
# Acceleration
148150
rapids = [ "cudf>=0.9", "cuml>=0.9", "cugraph>=0.9" ] # GPU accelerated calculation of neighbors
149151
dask = [ "dask[array]>=2024.5.1", "anndata[dask]" ] # Use the Dask parallelization engine

src/scanpy/plotting/_dotplot.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,19 @@
66
from matplotlib import colormaps
77

88
from .. import logging as logg
9-
from .._compat import old_positionals
9+
from .._compat import old_positionals, warn
1010
from .._settings import settings
1111
from .._utils import _doc_params, _empty
1212
from ._baseplot_class import BasePlot, doc_common_groupby_plot_args
1313
from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm
14-
from ._utils import _dk, check_colornorm, fix_kwds, make_grid_spec, savefig_or_show
14+
from ._utils import (
15+
_create_white_to_color_gradient,
16+
_dk,
17+
check_colornorm,
18+
fix_kwds,
19+
make_grid_spec,
20+
savefig_or_show,
21+
)
1522

1623
if TYPE_CHECKING:
1724
from collections.abc import Mapping, Sequence
@@ -156,7 +163,7 @@ def __init__( # noqa: PLR0913
156163
vmax: float | None = None,
157164
vcenter: float | None = None,
158165
norm: Normalize | None = None,
159-
group_cmaps: Mapping[str, str] | None = None,
166+
group_colors: Mapping[str, ColorLike] | None = None,
160167
**kwds,
161168
) -> None:
162169
BasePlot.__init__(
@@ -209,23 +216,30 @@ def __init__( # noqa: PLR0913
209216
self.standard_scale = standard_scale
210217
self.expression_cutoff = expression_cutoff
211218
self.mean_only_expressed = mean_only_expressed
212-
self.group_cmaps = group_cmaps
219+
self.group_colors = group_colors
220+
self.group_cmaps = None
213221

214222
self.dot_color_df, self.dot_size_df = self._prepare_dot_data(
215223
dot_color_df, dot_size_df
216224
)
217-
218-
# If group_cmaps is used, validate that all plotted groups have a defined colormap.
219-
if self.group_cmaps is not None:
220-
plotted_groups = set(self.dot_color_df.index)
221-
defined_groups = set(self.group_cmaps.keys())
222-
missing_groups = plotted_groups - defined_groups
225+
if self.group_colors is not None:
226+
self.group_cmaps = {}
227+
plotted_groups = self.dot_color_df.index
228+
missing_groups = []
229+
for group in plotted_groups:
230+
if group in self.group_colors:
231+
self.group_cmaps[group] = _create_white_to_color_gradient(
232+
self.group_colors[group]
233+
)
234+
else:
235+
self.group_cmaps[group] = self.cmap
236+
missing_groups.append(group)
223237
if missing_groups:
224-
msg = (
225-
"The following groups are in the plot data but are missing from the `group_cmaps` dictionary. "
226-
f"Please define a colormap for them: {sorted(missing_groups)}"
238+
warn(
239+
f"The following groups will use the default colormap as no "
240+
f"specific colors were assigned: {missing_groups}",
241+
UserWarning,
227242
)
228-
raise ValueError(msg)
229243

230244
def _prepare_dot_data(self, dot_color_df, dot_size_df):
231245
"""Calculate the dataframes for dot size and color."""
@@ -638,7 +652,10 @@ def _plot_stacked_colorbars(self, fig, colorbar_area_spec, normalize):
638652
ax = fig.add_subplot(
639653
colorbar_gs[i, 0]
640654
) # Place the colorbar Axes in the first, wider column
641-
cmap = colormaps.get_cmap(self.group_cmaps[group_name])
655+
cmap = self.group_cmaps[group_name]
656+
# Handle fallback case where cmap might be a string
657+
if isinstance(cmap, str):
658+
cmap = colormaps.get_cmap(cmap)
642659
mappable = ScalarMappable(norm=legend_norm, cmap=cmap)
643660

644661
cb = matplotlib.colorbar.Colorbar(
@@ -894,8 +911,10 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915
894911

895912
# Here we loop through each group and plot it with its own cmap
896913
for group_idx, group_name in enumerate(groups_iter):
897-
group_cmap_name = group_cmaps[group_name]
898-
group_cmap = colormaps.get_cmap(group_cmap_name)
914+
group_cmap = group_cmaps[group_name]
915+
# Handle fallback case where group_cmap might be a string
916+
if isinstance(group_cmap, str):
917+
group_cmap = colormaps.get_cmap(group_cmap)
899918

900919
# Slice the flattened data arrays correctly depending on orientation
901920
if not are_axes_swapped:
@@ -1017,7 +1036,7 @@ def dotplot( # noqa: PLR0913
10171036
norm: Normalize | None = None,
10181037
# Style parameters
10191038
cmap: Colormap | str | None = DotPlot.DEFAULT_COLORMAP,
1020-
group_cmaps: Mapping[str, str] | None = None,
1039+
group_colors: Mapping[str, ColorLike] | None = None,
10211040
dot_max: float | None = DotPlot.DEFAULT_DOT_MAX,
10221041
dot_min: float | None = DotPlot.DEFAULT_DOT_MIN,
10231042
smallest_dot: float = DotPlot.DEFAULT_SMALLEST_DOT,
@@ -1056,11 +1075,14 @@ def dotplot( # noqa: PLR0913
10561075
mean_only_expressed
10571076
If True, gene expression is averaged only over the cells
10581077
expressing the given genes.
1059-
group_cmaps
1060-
A mapping of group names to colormap names, e.g.
1061-
`{{'T-cell': 'Blues', 'B-cell': 'Reds'}}`. This allows for specifying a
1062-
different colormap for each group. If used, all groups in the plot
1063-
must have a colormap defined in this mapping.
1078+
group_colors
1079+
A mapping of group names to colors.
1080+
e.g. `{{'T-cell': 'blue', 'B-cell': '#aa40fc'}}`.
1081+
Colors can be specified as any valid matplotlib color.
1082+
If `group_colors` is used, a colormap is generated from white
1083+
to the given color for each group.
1084+
If a group is not present in the dictionary, the value of `cmap`
1085+
is used.
10641086
dot_max
10651087
If ``None``, the maximum dot size is set to the maximum fraction value found
10661088
(e.g. 0.6). If given, the value should be a number between 0 and 1.
@@ -1129,6 +1151,14 @@ def dotplot( # noqa: PLR0913
11291151
# instead of `cmap`
11301152
cmap = kwds.pop("color_map", cmap)
11311153

1154+
# Warn if both cmap and group_colors are specified
1155+
if group_colors is not None and cmap != DotPlot.DEFAULT_COLORMAP:
1156+
warn(
1157+
"Both `cmap` and `group_colors` are specified. "
1158+
"`group_colors` takes precedence for the specified groups.",
1159+
UserWarning,
1160+
)
1161+
11321162
dp = DotPlot(
11331163
adata,
11341164
var_names,
@@ -1148,7 +1178,7 @@ def dotplot( # noqa: PLR0913
11481178
var_group_rotation=var_group_rotation,
11491179
layer=layer,
11501180
dot_color_df=dot_color_df,
1151-
group_cmaps=group_cmaps,
1181+
group_colors=group_colors,
11521182
ax=ax,
11531183
vmin=vmin,
11541184
vmax=vmax,

src/scanpy/plotting/_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,3 +1105,56 @@ def _deprecated_scale(
11051105
def _dk(dendrogram: bool | str | None) -> str | None: # noqa: FBT001
11061106
"""Convert the `dendrogram` parameter to a `dendrogram_key` parameter."""
11071107
return None if isinstance(dendrogram, bool) else dendrogram
1108+
1109+
1110+
def _create_white_to_color_gradient(color: ColorLike, n_steps: int = 256):
1111+
"""Generate a perceptually uniform colormap from white to a target color.
1112+
1113+
This function uses the OKLab color space for interpolation to ensure that
1114+
the brightness of the generated colormap changes uniformly.
1115+
1116+
Parameters
1117+
----------
1118+
color
1119+
The target color for the gradient. Can be any valid matplotlib color.
1120+
n_steps
1121+
The number of steps in the colormap.
1122+
1123+
Returns
1124+
-------
1125+
A `matplotlib.colors.ListedColormap` object.
1126+
"""
1127+
try:
1128+
import colour
1129+
except ImportError:
1130+
msg = (
1131+
"Please install the `colour-science` package to use `group_colors`: "
1132+
"`pip install colour-science` or `pip install scanpy[plotting]`"
1133+
)
1134+
raise ImportError(msg) from None
1135+
from matplotlib.colors import ListedColormap, to_hex
1136+
1137+
# Convert the input color to a hex string
1138+
hex_color = to_hex(color, keep_alpha=False)
1139+
1140+
# Define the color space for interpolation
1141+
space = "OKLab"
1142+
1143+
# Convert start (white) and end (target color) to the OKLab color space
1144+
target_oklab = colour.convert(hex_color, "Hexadecimal", space)
1145+
white_oklab = colour.convert("#ffffff", "Hexadecimal", space)
1146+
1147+
# Create the gradient through linear interpolation in OKLab
1148+
gradient = colour.algebra.lerp(
1149+
np.linspace(0, 1, n_steps)[..., np.newaxis],
1150+
white_oklab,
1151+
target_oklab,
1152+
)
1153+
1154+
# Convert the gradient back to sRGB for display
1155+
rgb_gradient = colour.convert(gradient, space, "sRGB")
1156+
1157+
# Clip values to be within the valid [0, 1] range for RGB
1158+
clipped_rgb = np.clip(rgb_gradient, 0, 1)
1159+
1160+
return ListedColormap(clipped_rgb)

src/testing/scanpy/_pytest/marks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def _generate_next_value_(
3030

3131
mod: str
3232

33+
colour = "colour-science"
3334
dask = auto()
3435
dask_ml = auto()
3536
fa2 = auto()
-3.5 KB
-2.96 KB

tests/_images/dotplot/expected.png

-100 Bytes
-39 Bytes
-66 Bytes

0 commit comments

Comments
 (0)