diff --git a/.github/workflows/build-ultraplot.yml b/.github/workflows/build-ultraplot.yml index 3a2e9a9eb..4fe9693e9 100644 --- a/.github/workflows/build-ultraplot.yml +++ b/.github/workflows/build-ultraplot.yml @@ -136,9 +136,9 @@ jobs: with: path: ./ultraplot/tests/baseline # The directory to cache # Key is based on OS, Python/Matplotlib versions, and the base commit SHA - key: ${{ runner.os }}-baseline-base-v4-hs${{ env.PYTHONHASHSEED }}-${{ steps.baseline-ref.outputs.base_sha }}-${{ inputs.python-version }}-${{ inputs.matplotlib-version }} + key: ${{ runner.os }}-baseline-base-v5-hs${{ env.PYTHONHASHSEED }}-${{ steps.baseline-ref.outputs.base_sha }}-${{ inputs.python-version }}-${{ inputs.matplotlib-version }} restore-keys: | - ${{ runner.os }}-baseline-base-v4-hs${{ env.PYTHONHASHSEED }}-${{ steps.baseline-ref.outputs.base_sha }}-${{ inputs.python-version }}-${{ inputs.matplotlib-version }}- + ${{ runner.os }}-baseline-base-v5-hs${{ env.PYTHONHASHSEED }}-${{ steps.baseline-ref.outputs.base_sha }}-${{ inputs.python-version }}-${{ inputs.matplotlib-version }}- # Conditional Baseline Generation (Only runs on cache miss) - name: Generate baseline from main diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 4ec1ac1f8..62cda82aa 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -71,30 +71,6 @@ # A-b-c label string ABC_STRING = "abcdefghijklmnopqrstuvwxyz" -# Legend align options -ALIGN_OPTS = { - None: { - "center": "center", - "left": "center left", - "right": "center right", - "top": "upper center", - "bottom": "lower center", - }, - "left": { - "top": "upper right", - "center": "center right", - "bottom": "lower right", - }, - "right": { - "top": "upper left", - "center": "center left", - "bottom": "lower left", - }, - "top": {"left": "lower left", "center": "lower center", "right": "lower right"}, - "bottom": {"left": "upper left", "center": "upper center", "right": "upper right"}, -} - - # Projection docstring _proj_docstring = """ proj, projection : \\ @@ -1486,148 +1462,38 @@ def _add_legend( cols: Optional[Union[int, Tuple[int, int]]] = None, **kwargs, ): - """ - The driver function for adding axes legends. - """ - # Parse input argument units - ncol = _not_none(ncols=ncols, ncol=ncol) - order = _not_none(order, "C") - frameon = _not_none(frame=frame, frameon=frameon, default=rc["legend.frameon"]) - fontsize = _not_none(fontsize, rc["legend.fontsize"]) - titlefontsize = _not_none( - title_fontsize=kwargs.pop("title_fontsize", None), - titlefontsize=titlefontsize, - default=rc["legend.title_fontsize"], - ) - fontsize = _fontsize_to_pt(fontsize) - titlefontsize = _fontsize_to_pt(titlefontsize) - if order not in ("F", "C"): - raise ValueError( - f"Invalid order {order!r}. Please choose from " - "'C' (row-major, default) or 'F' (column-major)." - ) - - # Convert relevant keys to em-widths - for setting in rcsetup.EM_KEYS: # em-width keys - pair = setting.split("legend.", 1) - if len(pair) == 1: - continue - _, key = pair - value = kwargs.pop(key, None) - if isinstance(value, str): - value = units(value, "em", fontsize=fontsize) - if value is not None: - kwargs[key] = value - - # Generate and prepare the legend axes - if loc in ("fill", "left", "right", "top", "bottom"): - lax = self._add_guide_panel( - loc, - align, - width=width, - space=space, - pad=pad, - span=span, - row=row, - col=col, - rows=rows, - cols=cols, - ) - kwargs.setdefault("borderaxespad", 0) - if not frameon: - kwargs.setdefault("borderpad", 0) - try: - kwargs["loc"] = ALIGN_OPTS[lax._panel_side][align] - except KeyError: - raise ValueError(f"Invalid align={align!r} for legend loc={loc!r}.") - else: - lax = self - pad = kwargs.pop("borderaxespad", pad) - kwargs["loc"] = loc # simply pass to legend - kwargs["borderaxespad"] = units(pad, "em", fontsize=fontsize) - - # Handle and text properties that are applied after-the-fact - # NOTE: Set solid_capstyle to 'butt' so line does not extend past error bounds - # shading in legend entry. This change is not noticable in other situations. - kw_frame, kwargs = lax._parse_frame("legend", **kwargs) - kw_text = {} - if fontcolor is not None: - kw_text["color"] = fontcolor - if fontweight is not None: - kw_text["weight"] = fontweight - kw_title = {} - if titlefontcolor is not None: - kw_title["color"] = titlefontcolor - if titlefontweight is not None: - kw_title["weight"] = titlefontweight - kw_handle = _pop_props(kwargs, "line") - kw_handle.setdefault("solid_capstyle", "butt") - kw_handle.update(handle_kw or {}) - - # Parse the legend arguments using axes for auto-handle detection - # TODO: Update this when we no longer use "filled panels" for outer legends - pairs, multi = lax._parse_legend_handles( + return plegend.UltraLegend(self).add( handles, labels, + loc=loc, + align=align, + width=width, + pad=pad, + space=space, + frame=frame, + frameon=frameon, ncol=ncol, - order=order, - center=center, + ncols=ncols, alphabetize=alphabetize, + center=center, + order=order, + label=label, + title=title, + fontsize=fontsize, + fontweight=fontweight, + fontcolor=fontcolor, + titlefontsize=titlefontsize, + titlefontweight=titlefontweight, + titlefontcolor=titlefontcolor, + handle_kw=handle_kw, handler_map=handler_map, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, ) - title = _not_none(label=label, title=title) - kwargs.update( - { - "title": title, - "frameon": frameon, - "fontsize": fontsize, - "handler_map": handler_map, - "title_fontsize": titlefontsize, - } - ) - - # Add the legend and update patch properties - # TODO: Add capacity for categorical labels in a single legend like seaborn - # rather than manual handle overrides with multiple legends. - if multi: - objs = lax._parse_legend_centered(pairs, kw_frame=kw_frame, **kwargs) - else: - kwargs.update({key: kw_frame.pop(key) for key in ("shadow", "fancybox")}) - objs = [lax._parse_legend_aligned(pairs, ncol=ncol, order=order, **kwargs)] - objs[0].legendPatch.update(kw_frame) - for obj in objs: - if hasattr(lax, "legend_") and lax.legend_ is None: - lax.legend_ = obj # make first legend accessible with get_legend() - else: - lax.add_artist(obj) - - # Update legend patch and elements - # WARNING: legendHandles only contains the *first* artist per legend because - # HandlerBase.legend_artist() called in Legend._init_legend_box() only - # returns the first artist. Instead we try to iterate through offset boxes. - for obj in objs: - obj.set_clip_on(False) # needed for tight bounding box calculations - box = getattr(obj, "_legend_handle_box", None) - for obj in guides._iter_children(box): - if isinstance(obj, mtext.Text): - kw = kw_text - else: - kw = { - key: val - for key, val in kw_handle.items() - if hasattr(obj, "set_" + key) - } # noqa: E501 - if hasattr(obj, "set_sizes") and "markersize" in kw_handle: - kw["sizes"] = np.atleast_1d(kw_handle["markersize"]) - obj.update(kw) - - # Register location and return - if isinstance(objs[0], mpatches.FancyBboxPatch): - objs = objs[1:] - obj = objs[0] if len(objs) == 1 else tuple(objs) - self._register_guide("legend", obj, (loc, align)) # possibly replace another - - return obj def _apply_title_above(self): """ diff --git a/ultraplot/legend.py b/ultraplot/legend.py index 9d11ffb9e..043a07430 100644 --- a/ultraplot/legend.py +++ b/ultraplot/legend.py @@ -1,11 +1,94 @@ +from dataclasses import dataclass +from typing import Any, Iterable, Optional, Tuple, Union + +import numpy as np +import matplotlib.patches as mpatches +import matplotlib.text as mtext from matplotlib import legend as mlegend +from .config import rc +from .internals import _not_none, _pop_props, guides, rcsetup +from .utils import _fontsize_to_pt, units + try: from typing import override except ImportError: from typing_extensions import override +ALIGN_OPTS = { + None: { + "center": "center", + "left": "center left", + "right": "center right", + "top": "upper center", + "bottom": "lower center", + }, + "left": { + "center": "center right", + "left": "center right", + "right": "center right", + "top": "upper right", + "bottom": "lower right", + }, + "right": { + "center": "center left", + "left": "center left", + "right": "center left", + "top": "upper left", + "bottom": "lower left", + }, + "top": { + "center": "lower center", + "left": "lower left", + "right": "lower right", + "top": "lower center", + "bottom": "lower center", + }, + "bottom": { + "center": "upper center", + "left": "upper left", + "right": "upper right", + "top": "upper center", + "bottom": "upper center", + }, +} + +LegendKw = dict[str, Any] +LegendHandles = Any +LegendLabels = Any + + +@dataclass(frozen=True) +class _LegendInputs: + handles: LegendHandles + labels: LegendLabels + loc: Any + align: Any + width: Any + pad: Any + space: Any + frameon: bool + ncol: Any + order: str + label: Any + title: Any + fontsize: float + fontweight: Any + fontcolor: Any + titlefontsize: float + titlefontweight: Any + titlefontcolor: Any + handle_kw: Any + handler_map: Any + span: Optional[Union[int, Tuple[int, int]]] + row: Optional[int] + col: Optional[int] + rows: Optional[Union[int, Tuple[int, int]]] + cols: Optional[Union[int, Tuple[int, int]]] + kwargs: dict[str, Any] + + class Legend(mlegend.Legend): # Soft wrapper of matplotlib legend's class. # Currently we only override the syncing of the location. @@ -31,3 +114,432 @@ def set_loc(self, loc=None): value = self.axes._legend_dict.pop(old_loc, None) where, type = old_loc self.axes._legend_dict[(loc, type)] = value + + +def _normalize_em_kwargs(kwargs: dict[str, Any], *, fontsize: float) -> dict[str, Any]: + """ + Convert legend-related em unit kwargs to absolute values in points. + """ + for setting in rcsetup.EM_KEYS: + pair = setting.split("legend.", 1) + if len(pair) == 1: + continue + _, key = pair + value = kwargs.pop(key, None) + if isinstance(value, str): + value = units(value, "em", fontsize=fontsize) + if value is not None: + kwargs[key] = value + return kwargs + + +class UltraLegend: + """ + Centralized legend builder for axes. + """ + + def __init__(self, axes): + self.axes = axes + + @staticmethod + def _align_map() -> dict[Optional[str], dict[str, str]]: + """ + Mapping between panel side + align and matplotlib legend loc strings. + """ + return ALIGN_OPTS + + def _resolve_inputs( + self, + handles=None, + labels=None, + *, + loc=None, + align=None, + width=None, + pad=None, + space=None, + frame=None, + frameon=None, + ncol=None, + ncols=None, + alphabetize=False, + center=None, + order=None, + label=None, + title=None, + fontsize=None, + fontweight=None, + fontcolor=None, + titlefontsize=None, + titlefontweight=None, + titlefontcolor=None, + handle_kw=None, + handler_map=None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs: Any, + ): + """ + Normalize inputs, apply rc defaults, and convert units. + """ + ncol = _not_none(ncols=ncols, ncol=ncol) + order = _not_none(order, "C") + frameon = _not_none(frame=frame, frameon=frameon, default=rc["legend.frameon"]) + fontsize = _not_none(fontsize, rc["legend.fontsize"]) + titlefontsize = _not_none( + title_fontsize=kwargs.pop("title_fontsize", None), + titlefontsize=titlefontsize, + default=rc["legend.title_fontsize"], + ) + fontsize = _fontsize_to_pt(fontsize) + titlefontsize = _fontsize_to_pt(titlefontsize) + if order not in ("F", "C"): + raise ValueError( + f"Invalid order {order!r}. Please choose from " + "'C' (row-major, default) or 'F' (column-major)." + ) + + # Convert relevant keys to em-widths + kwargs = _normalize_em_kwargs(kwargs, fontsize=fontsize) + return _LegendInputs( + handles=handles, + labels=labels, + loc=loc, + align=align, + width=width, + pad=pad, + space=space, + frameon=frameon, + ncol=ncol, + order=order, + label=label, + title=title, + fontsize=fontsize, + fontweight=fontweight, + fontcolor=fontcolor, + titlefontsize=titlefontsize, + titlefontweight=titlefontweight, + titlefontcolor=titlefontcolor, + handle_kw=handle_kw, + handler_map=handler_map, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + kwargs=kwargs, + ) + + def _resolve_axes_layout(self, inputs: _LegendInputs): + """ + Determine the legend axes and layout-related kwargs. + """ + ax = self.axes + if inputs.loc in ("fill", "left", "right", "top", "bottom"): + lax = ax._add_guide_panel( + inputs.loc, + inputs.align, + width=inputs.width, + space=inputs.space, + pad=inputs.pad, + span=inputs.span, + row=inputs.row, + col=inputs.col, + rows=inputs.rows, + cols=inputs.cols, + ) + kwargs = dict(inputs.kwargs) + kwargs.setdefault("borderaxespad", 0) + if not inputs.frameon: + kwargs.setdefault("borderpad", 0) + try: + kwargs["loc"] = self._align_map()[lax._panel_side][inputs.align] + except KeyError as exc: + raise ValueError( + f"Invalid align={inputs.align!r} for legend loc={inputs.loc!r}." + ) from exc + else: + lax = ax + kwargs = dict(inputs.kwargs) + pad = kwargs.pop("borderaxespad", inputs.pad) + kwargs["loc"] = inputs.loc # simply pass to legend + kwargs["borderaxespad"] = units(pad, "em", fontsize=inputs.fontsize) + return lax, kwargs + + def _resolve_style_kwargs( + self, + *, + lax, + fontcolor, + fontweight, + handle_kw, + kwargs, + ): + """ + Parse frame settings and build per-element style kwargs. + """ + kw_frame, kwargs = lax._parse_frame("legend", **kwargs) + kw_text = {} + if fontcolor is not None: + kw_text["color"] = fontcolor + if fontweight is not None: + kw_text["weight"] = fontweight + kw_handle = _pop_props(kwargs, "line") + kw_handle.setdefault("solid_capstyle", "butt") + kw_handle.update(handle_kw or {}) + return kw_frame, kw_text, kw_handle, kwargs + + def _build_legends( + self, + *, + lax, + inputs: _LegendInputs, + center, + alphabetize, + kw_frame, + kwargs, + ): + pairs, multi = lax._parse_legend_handles( + inputs.handles, + inputs.labels, + ncol=inputs.ncol, + order=inputs.order, + center=center, + alphabetize=alphabetize, + handler_map=inputs.handler_map, + ) + title = _not_none(label=inputs.label, title=inputs.title) + kwargs.update( + { + "title": title, + "frameon": inputs.frameon, + "fontsize": inputs.fontsize, + "handler_map": inputs.handler_map, + "title_fontsize": inputs.titlefontsize, + } + ) + if multi: + objs = lax._parse_legend_centered(pairs, kw_frame=kw_frame, **kwargs) + else: + kwargs.update({key: kw_frame.pop(key) for key in ("shadow", "fancybox")}) + objs = [ + lax._parse_legend_aligned( + pairs, ncol=inputs.ncol, order=inputs.order, **kwargs + ) + ] + objs[0].legendPatch.update(kw_frame) + for obj in objs: + if hasattr(lax, "legend_") and lax.legend_ is None: + lax.legend_ = obj + else: + lax.add_artist(obj) + return objs + + def _apply_handle_styles(self, objs, *, kw_text, kw_handle): + """ + Apply per-handle styling overrides to legend artists. + """ + for obj in objs: + obj.set_clip_on(False) + box = getattr(obj, "_legend_handle_box", None) + for child in guides._iter_children(box): + if isinstance(child, mtext.Text): + kw = kw_text + else: + kw = { + key: val + for key, val in kw_handle.items() + if hasattr(child, "set_" + key) + } + if hasattr(child, "set_sizes") and "markersize" in kw_handle: + kw["sizes"] = np.atleast_1d(kw_handle["markersize"]) + child.update(kw) + + def _finalize(self, objs, *, loc, align): + """ + Register legend for guide tracking and return the public object. + """ + ax = self.axes + if isinstance(objs[0], mpatches.FancyBboxPatch): + objs = objs[1:] + obj = objs[0] if len(objs) == 1 else tuple(objs) + ax._register_guide("legend", obj, (loc, align)) + return obj + + def add( + self, + handles=None, + labels=None, + *, + loc=None, + align=None, + width=None, + pad=None, + space=None, + frame=None, + frameon=None, + ncol=None, + ncols=None, + alphabetize=False, + center=None, + order=None, + label=None, + title=None, + fontsize=None, + fontweight=None, + fontcolor=None, + titlefontsize=None, + titlefontweight=None, + titlefontcolor=None, + handle_kw=None, + handler_map=None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ): + """ + The driver function for adding axes legends. + """ + inputs = self._resolve_inputs( + handles, + labels, + loc=loc, + align=align, + width=width, + pad=pad, + space=space, + frame=frame, + frameon=frameon, + ncol=ncol, + ncols=ncols, + alphabetize=alphabetize, + center=center, + order=order, + label=label, + title=title, + fontsize=fontsize, + fontweight=fontweight, + fontcolor=fontcolor, + titlefontsize=titlefontsize, + titlefontweight=titlefontweight, + titlefontcolor=titlefontcolor, + handle_kw=handle_kw, + handler_map=handler_map, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, + ) + + lax, kwargs = self._resolve_axes_layout(inputs) + + kw_frame, kw_text, kw_handle, kwargs = self._resolve_style_kwargs( + lax=lax, + fontcolor=inputs.fontcolor, + fontweight=inputs.fontweight, + handle_kw=inputs.handle_kw, + kwargs=kwargs, + ) + + objs = self._build_legends( + lax=lax, + inputs=inputs, + center=center, + alphabetize=alphabetize, + kw_frame=kw_frame, + kwargs=kwargs, + ) + + self._apply_handle_styles(objs, kw_text=kw_text, kw_handle=kw_handle) + return self._finalize(objs, loc=inputs.loc, align=inputs.align) + + # Handle and text properties that are applied after-the-fact + # NOTE: Set solid_capstyle to 'butt' so line does not extend past error bounds + # shading in legend entry. This change is not noticable in other situations. + kw_frame, kwargs = lax._parse_frame("legend", **kwargs) + kw_text = {} + if fontcolor is not None: + kw_text["color"] = fontcolor + if fontweight is not None: + kw_text["weight"] = fontweight + kw_title = {} + if titlefontcolor is not None: + kw_title["color"] = titlefontcolor + if titlefontweight is not None: + kw_title["weight"] = titlefontweight + kw_handle = _pop_props(kwargs, "line") + kw_handle.setdefault("solid_capstyle", "butt") + kw_handle.update(handle_kw or {}) + + # Parse the legend arguments using axes for auto-handle detection + # TODO: Update this when we no longer use "filled panels" for outer legends + pairs, multi = lax._parse_legend_handles( + handles, + labels, + ncol=ncol, + order=order, + center=center, + alphabetize=alphabetize, + handler_map=handler_map, + ) + title = _not_none(label=label, title=title) + kwargs.update( + { + "title": title, + "frameon": frameon, + "fontsize": fontsize, + "handler_map": handler_map, + "title_fontsize": titlefontsize, + } + ) + + # Add the legend and update patch properties + # TODO: Add capacity for categorical labels in a single legend like seaborn + # rather than manual handle overrides with multiple legends. + if multi: + objs = lax._parse_legend_centered(pairs, kw_frame=kw_frame, **kwargs) + else: + kwargs.update({key: kw_frame.pop(key) for key in ("shadow", "fancybox")}) + objs = [lax._parse_legend_aligned(pairs, ncol=ncol, order=order, **kwargs)] + objs[0].legendPatch.update(kw_frame) + for obj in objs: + if hasattr(lax, "legend_") and lax.legend_ is None: + lax.legend_ = obj # make first legend accessible with get_legend() + else: + lax.add_artist(obj) + + # Update legend patch and elements + # WARNING: legendHandles only contains the *first* artist per legend because + # HandlerBase.legend_artist() called in Legend._init_legend_box() only + # returns the first artist. Instead we try to iterate through offset boxes. + for obj in objs: + obj.set_clip_on(False) # needed for tight bounding box calculations + box = getattr(obj, "_legend_handle_box", None) + for child in guides._iter_children(box): + if isinstance(child, mtext.Text): + kw = kw_text + else: + kw = { + key: val + for key, val in kw_handle.items() + if hasattr(child, "set_" + key) + } + if hasattr(child, "set_sizes") and "markersize" in kw_handle: + kw["sizes"] = np.atleast_1d(kw_handle["markersize"]) + child.update(kw) + + # Register location and return + if isinstance(objs[0], mpatches.FancyBboxPatch): + objs = objs[1:] + obj = objs[0] if len(objs) == 1 else tuple(objs) + ax._register_guide("legend", obj, (loc, align)) # possibly replace another + + return obj diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 3d7f1596c..03e30926e 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -200,6 +200,42 @@ def test_legend_col_spacing(rng): return fig +def test_legend_align_opts_mapping(): + """ + Basic sanity check for legend alignment mapping. + """ + from ultraplot.legend import ALIGN_OPTS + + assert ALIGN_OPTS[None]["center"] == "center" + assert ALIGN_OPTS["left"]["top"] == "upper right" + assert ALIGN_OPTS["right"]["bottom"] == "lower left" + assert ALIGN_OPTS["top"]["center"] == "lower center" + assert ALIGN_OPTS["bottom"]["right"] == "upper right" + + +def test_legend_builder_smoke(): + """ + Ensure the legend builder path returns a legend object. + """ + import matplotlib.pyplot as plt + + fig, ax = uplt.subplots() + ax.plot([0, 1, 2], label="a") + leg = ax.legend(loc="right", align="center") + assert leg is not None + plt.close(fig) + + +def test_legend_normalize_em_kwargs(): + """ + Ensure em-based legend kwargs are converted to numeric values. + """ + from ultraplot.legend import _normalize_em_kwargs + + out = _normalize_em_kwargs({"labelspacing": "2em"}, fontsize=10) + assert isinstance(out["labelspacing"], (int, float)) + + def test_sync_label_dict(rng): """ Legends are held within _legend_dict for which the key is a tuple of location and alignment.