Source code for coco_pipe.viz.plots

#!/usr/bin/env python3
"""
coco_pipe/viz/plots.py
----------------------
Minimal plotting helpers:

- plot_topomap: 2D topographic map using MNE's plot_topomap for EEG layouts.
- plot_bar:     Generic ranked bar plots with optional error bars.

Both functions are data-agnostic: pass any metric (e.g., accuracy, importance)
for any analysis unit (sensors, features, regions).
"""

from __future__ import annotations

import math
import re
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


[docs] def _coerce_series( data: Union[pd.Series, Mapping[str, float], Sequence[float]], index: Optional[Sequence[str]] = None, ) -> pd.Series: if isinstance(data, pd.Series): return data.dropna() if isinstance(data, Mapping): return pd.Series(data).dropna() values = np.asarray(list(data), dtype=float) if index is None: index = [str(i) for i in range(len(values))] return pd.Series(values, index=index).dropna()
[docs] def _coerce_coords( coords: Union[ pd.DataFrame, Mapping[str, Tuple[float, float]], Sequence[Tuple[float, float]] ], index: Optional[Sequence[str]] = None, ) -> pd.DataFrame: if isinstance(coords, pd.DataFrame): if coords.shape[1] < 2: raise ValueError("coords DataFrame must have at least two columns (x, y)") df = coords.iloc[:, :2].copy() df.columns = ["x", "y"] return df if isinstance(coords, Mapping): items = [(k, v[0], v[1]) for k, v in coords.items()] return pd.DataFrame(items, columns=["name", "x", "y"]).set_index("name") # assume sequence of (x,y) arr = np.asarray(coords, dtype=float) if arr.ndim != 2 or arr.shape[1] < 2: raise ValueError("coords must be Nx2 sequence for (x,y) pairs") idx = list(index) if index is not None else [str(i) for i in range(arr.shape[0])] return pd.DataFrame({"x": arr[:, 0], "y": arr[:, 1]}, index=idx)
[docs] def plot_topomap( values: Union[pd.Series, Mapping[str, float], Sequence[float]], coords: Union[ pd.DataFrame, Mapping[str, Tuple[float, float]], Sequence[Tuple[float, float]] ], *, index: Optional[Sequence[str]] = None, vmin: Optional[float] = None, vmax: Optional[float] = None, cmap: str = "RdBu_r", head_radius: float = 0.5, # kept for API compatibility; not used by MNE levels: int = 64, # kept for API compatibility; not used by MNE fill: bool = True, # kept for API compatibility; not used by MNE sensors: str = "markers", # "markers"->True, "labels"->'labels', "none"->False sensor_size: float = 30.0, # kept for API compatibility outlines: bool = True, # True->'head', False->'none' contours: int = 0, symmetric: bool = True, title: Optional[str] = None, cbar: bool = True, cbar_label: Optional[str] = None, text_size: Optional[float] = None, title_size: Optional[float] = None, title_loc: Optional[str] = None, tick_size: Optional[float] = None, cbar_size: Optional[float] = None, figsize: Tuple[float, float] = (5, 5), ax: Optional[plt.Axes] = None, save: Optional[str] = None, ) -> Tuple[plt.Figure, plt.Axes]: """ Plot a topographic map for arbitrary sensor values using MNE. Parameters ---------- values : Series | Mapping | Sequence Sensor values. If Series or Mapping, keys should match coords names. coords : DataFrame | Mapping | Sequence Sensor coordinates. Accepts: - DataFrame with columns ['x','y'] and index = sensor names - Mapping[name] -> (x,y) - Sequence[(x,y)], plus `index` to name them index : sequence, optional Names for values/coords when passing sequences. sensors : str One of {"markers", "labels", "none"} to show sensor markers/labels. contours : int Number of contour lines to draw on top (0 disables). symmetric : bool, default=True If True, use symmetric color limits around 0 (vlim = [-a, +a]) where a is a rounded-up amplitude from the data range. Set False to use asymmetric nice-rounded limits. Notes ----- - Requires the optional dependency 'mne'. Install with `pip install mne`. - Uses mne.viz.plot_topomap for robust EEG topographies. - vmin/vmax are rounded to "nice" values: we floor vmin and ceil vmax to the nearest power-of-10 step (e.g., 0.87 -> 0.9, 0.012 -> 0.02). - Use `text_size` to scale all text; override with `title_size`, `tick_size`, or `cbar_size` for finer control. - Use `title_loc` to set title alignment: 'left', 'center', or 'right'. """ try: import mne # type: ignore except Exception as e: raise ImportError( "plot_topomap requires 'mne'. Install it via 'pip install mne'." ) from e vals = _coerce_series(values, index=index) cdf = _coerce_coords(coords, index=index) # Align on intersection of names common = vals.index.intersection(cdf.index) if common.empty: raise ValueError("No overlapping sensor names between values and coords") vals = vals.loc[common] cdf = cdf.loc[common] # Prepare arrays for MNE pos = cdf[["x", "y"]].to_numpy(dtype=float) data = vals.to_numpy(dtype=float) # Determine and round vmin/vmax (even if provided) to nice boundaries def _nice_step(x: float) -> float: ax = abs(x) if ax == 0 or not np.isfinite(ax): return 1.0 e = math.floor(math.log10(ax)) return 10.0**e def _nice_floor(x: float) -> float: s = _nice_step(x) return math.floor(x / s) * s def _nice_ceil(x: float) -> float: s = _nice_step(x) return math.ceil(x / s) * s raw_vmin = float(np.nanmin(data)) if vmin is None else float(vmin) raw_vmax = float(np.nanmax(data)) if vmax is None else float(vmax) if raw_vmin > raw_vmax: raw_vmin, raw_vmax = raw_vmax, raw_vmin if symmetric: amp = max(abs(raw_vmin), abs(raw_vmax)) if not np.isfinite(amp) or amp == 0: amp = 1.0 vmax_r = _nice_ceil(amp) vmin_r = -vmax_r else: vmin_r = _nice_floor(raw_vmin) vmax_r = _nice_ceil(raw_vmax) if not np.isfinite(vmin_r) or not np.isfinite(vmax_r) or vmin_r == vmax_r: span = max(abs(raw_vmin), abs(raw_vmax), 1e-12) step = _nice_step(span) vmin_r = _nice_floor(raw_vmin - step) vmax_r = _nice_ceil(raw_vmax + step) # Map sensors option if sensors == "none": sensors_opt: Union[bool, str] = False elif sensors == "labels": sensors_opt = "labels" else: sensors_opt = True # markers outlines_opt: Union[str, dict] = "head" if outlines else "none" # Create axes if needed if ax is None: fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True) else: fig = ax.figure # Call MNE's plot_topomap im, cn = mne.viz.plot_topomap( data, pos, axes=ax, vlim=(vmin_r, vmax_r), cmap=cmap, sensors=sensors_opt, names=list(common) if sensors == "labels" else None, contours=contours, outlines=outlines_opt, show=False, ) # Resolve text sizes _title_fs = title_size if title_size is not None else text_size _tick_fs = tick_size if tick_size is not None else text_size _cbar_fs = cbar_size if cbar_size is not None else text_size _title_loc = ( None if title_loc is None else ( "center" if str(title_loc).lower() in {"center", "centre"} else str(title_loc).lower() ) ) if title: ax.set_title(title, fontsize=_title_fs, loc=_title_loc) if cbar: cb = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) if _cbar_fs is not None: cb.ax.tick_params(labelsize=_cbar_fs) if cbar_label: cb.set_label(cbar_label, fontsize=_cbar_fs) if _tick_fs is not None: ax.tick_params(axis="both", labelsize=_tick_fs) if save: fig.savefig(save, dpi=150) return fig, ax
[docs] def plot_bar( scores: Union[pd.Series, Mapping[str, float], Sequence[float]], *, errors: Optional[Union[pd.Series, Mapping[str, float], Sequence[float]]] = None, labels: Optional[Sequence[str]] = None, label_map: Optional[Mapping[str, str]] = None, top_n: Optional[int] = None, ascending: bool = False, orientation: str = "vertical", # "vertical" or "horizontal" color: Optional[Union[str, Sequence[str]]] = None, cmap: Optional[str] = None, abs_values: bool = False, nice_axis_limits: bool = False, axis_break_orders: Optional[float] = None, axis_break_pad: float = 0.05, remove_spines: Optional[Union[str, Sequence[str]]] = None, remove_ticks: Optional[Union[str, Sequence[str]]] = None, text_size: Optional[float] = None, title_size: Optional[float] = None, title_loc: Optional[str] = None, axis_label_size: Optional[float] = None, tick_size: Optional[float] = None, grid_axis: Optional[str] = "y", title: Optional[str] = None, xlabel: Optional[str] = None, ylabel: Optional[str] = None, axis_lim: Optional[Tuple[float, float]] = None, figsize: Tuple[float, float] = (7, 4), ax: Optional[plt.Axes] = None, save: Optional[str] = None, ) -> Tuple[plt.Figure, plt.Axes]: """ Plot a ranked bar chart of scores with optional error bars. Parameters ---------- scores : Series | Mapping | Sequence Values to plot. If Sequence, provide `labels`. errors : Series | Mapping | Sequence, optional Error bars (std/sem). Keys or order must match `scores`. label_map : Mapping[str, str], optional Mapping from original item names (index of `scores`) to custom display labels for tick names. Unmapped labels fall back to their original names. top_n : int, optional Show only the top-N items after sorting by value. ascending : bool Sort ascending (False for descending/top highest). orientation : str "vertical" or "horizontal". axis_lim : tuple(float, float), optional Limits for the numeric axis (y for vertical, x for horizontal). If provided, applies via `ax.set_ylim(axis_lim)` (vertical) or `ax.set_xlim(axis_lim)` (horizontal). color : str | list, optional Single color or list per bar. If not provided and `cmap` is set, colors are mapped from values. cmap : str, optional Colormap name to map bar colors by value. abs_values : bool, default=False If True, plot absolute values of `scores` and also sort/select top_n by absolute magnitude. nice_axis_limits : bool, default=False If True and `axis_lim` is not provided, set the numeric axis limits to "nice-rounded" floor(min(values)) and ceil(max(values)). Applies to the y-axis for vertical orientation and x-axis for horizontal orientation. remove_spines : str | sequence, optional Remove axis spines by name. Accepts any of {"left","right","top","bottom"}, or the shortcut "right_top" to remove both right and top. You can also pass a list like ["right", "bottom"]. remove_ticks : str | sequence, optional Hide tick marks (not labels) on the x and/or y axes. Accepts "x", "y", or "both". You can combine as a space/comma-separated string (e.g., "x y" or "x,y") or pass a list like ["x", "y"]. text_size : float, optional Base font size for all text in the plot (title, axis labels, ticks). title_size : float, optional Overrides title font size; falls back to `text_size`. axis_label_size : float, optional Overrides axis label font size; falls back to `text_size`. tick_size : float, optional Overrides tick label font size; falls back to `text_size`. title_loc : str, optional Title alignment: one of {"left", "center", "right"}. Defaults to matplotlib's default (center) if not provided. grid_axis : str | None, default="y" Which axis to draw grid lines on: one of {"x", "y", "both", "auto"}. Use "none" or None to disable the grid. "auto" chooses 'y' for vertical bars and 'x' for horizontal bars. axis_break_orders : float, optional If provided, automatically create a broken numeric axis (y for vertical, x for horizontal) when the largest value is at least 10**axis_break_orders times the second largest. Only applied when values are non-negative and `axis_lim` is not set. Keeps API compatible by returning the primary axis (bottom/left) of the broken pair. axis_break_pad : float, default=0.05 Fractional padding around the split limits for the broken axes. """ s = _coerce_series(scores, index=labels) # Sort and optionally take top-N. When abs_values=True, order by |scores|. if abs_values: order_idx = s.abs().sort_values(ascending=ascending).index s = s.loc[order_idx] else: s = s.sort_values(ascending=ascending) if top_n is not None and top_n > 0: s = s.head(top_n) if errors is not None: e = _coerce_series(errors, index=s.index) e = e.reindex(s.index) else: e = None created_fig = False if ax is None: fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True) created_fig = True else: fig = ax.figure labels = s.index.tolist() if label_map: disp_labels = [label_map.get(lbl, lbl) for lbl in labels] else: disp_labels = labels # Optionally transform values for plotting s_plot = s.abs() if abs_values else s vals = s_plot.values.astype(float) if cmap is not None and color is None: cmap_obj = plt.get_cmap(cmap) # normalize by min/max vmin, vmax = np.nanmin(vals), np.nanmax(vals) if vmax == vmin: colors = [cmap_obj(0.5)] * len(vals) else: colors = [cmap_obj((v - vmin) / (vmax - vmin)) for v in vals] else: if isinstance(color, (list, tuple, np.ndarray)): colors = list(color) else: colors = color # helper to compute nice-rounded bounds def _nice_step(x: float) -> float: ax = abs(x) if ax == 0 or not np.isfinite(ax): return 1.0 e = math.floor(math.log10(ax)) return 10.0**e def _nice_floor(x: float) -> float: s = _nice_step(x) return math.floor(x / s) * s def _nice_ceil(x: float) -> float: s = _nice_step(x) return math.ceil(x / s) * s # Potential broken-axis branch (applies only when explicit axis limits are not set) # Conditions: axis_break_orders set, at least two values, all non-negative after # optional abs, and axis_lim is None. Uses simple two-panel broken axis with # diagonal marks. do_axis_break = False if ( axis_break_orders is not None and axis_lim is None and len(vals) >= 2 and np.all(np.isfinite(vals)) and np.all(vals >= 0) ): # Consider errors in determining extremes if provided errs = e.values if e is not None else np.zeros_like(vals) vals_with_err = vals + np.maximum(0.0, errs) # Find top-1 and top-2 order = np.argsort(vals_with_err)[::-1] vmax1 = float(vals_with_err[order[0]]) vmax2 = float(vals_with_err[order[1]]) if len(order) > 1 else 0.0 ratio = (vmax1 / max(vmax2, 1e-12)) if vmax1 > 0 else 1.0 threshold = 10.0 ** float(axis_break_orders) if ratio >= threshold and np.isfinite(ratio): do_axis_break = True if do_axis_break: # Compute split limits errs = e.values if e is not None else np.zeros_like(vals) vals_with_err = vals + np.maximum(0.0, errs) order = np.argsort(vals_with_err)[::-1] vmax1 = float(vals_with_err[order[0]]) vmax2 = float(vals_with_err[order[1]]) if len(order) > 1 else 0.0 # Lower panel upper bound based on second max (or fraction of top # if others are 0) lower_upper = ( vmax2 if vmax2 > 0 else (vmax1 / (10.0 ** float(axis_break_orders))) ) lower_upper = max(lower_upper, 0.0) * (1.0 + float(axis_break_pad)) upper_lower = lower_upper # join point upper_upper = vmax1 * (1.0 + float(axis_break_pad)) # Create two axes either stacked (vertical) or side-by-side (horizontal) if orientation == "horizontal": # Side-by-side axes sharing y if created_fig and ax is not None and len(fig.axes) == 1: plt.close(fig) # close single-axes fig before creating new layout # to avoid duplicates fig, (ax_left, ax_right) = plt.subplots( 1, 2, sharey=True, figsize=figsize, constrained_layout=True, gridspec_kw={"width_ratios": [3, 1]}, ) else: # Replace provided ax with two new axes in the same position bbox = ax.get_position() ax.set_visible(False) w_total = bbox.width gap = 0.02 * w_total w_left = w_total * 0.72 w_right = w_total * 0.28 ax_left = fig.add_axes( [bbox.x0, bbox.y0, w_left - gap / 2, bbox.height] ) ax_right = fig.add_axes( [ bbox.x0 + w_left + gap / 2, bbox.y0, w_right - gap / 2, bbox.height, ] ) # share y manually ax_right.get_shared_y_axes().join(ax_right, ax_left) # Draw bars; left shows all, right shows only the outlier (largest) y_pos = np.arange(len(labels)) idx_out = int(order[0]) mask = np.zeros(len(vals), dtype=bool) mask[idx_out] = True # color handling for subset def _subset_colors(mask_bool): if isinstance(colors, (list, tuple, np.ndarray)): return [colors[i] for i, m in enumerate(mask_bool) if m] return colors ax_left.barh( y_pos, vals, xerr=e.values if e is not None else None, color=colors, ) ax_right.barh( y_pos[mask], vals[mask], xerr=(e.values[mask] if e is not None else None), color=_subset_colors(mask), ) # Ticks/labels ax_left.set_yticks(y_pos, labels=disp_labels) ax_right.tick_params(axis="y", labelleft=False) # Limits per panel ax_left.set_xlim((0.0, lower_upper)) ax_right.set_xlim((upper_lower, upper_upper)) # Diagonal break marks (only on primary/left axis) d = 0.015 kwargs = dict(transform=ax_left.transAxes, color="k", clip_on=False) ax_left.plot((1 - d, 1 + d), (-d, +d), **kwargs) ax_left.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # No start lines on the secondary/right axis # Titles/labels and text sizes _title_fs = title_size if title_size is not None else text_size _label_fs = axis_label_size if axis_label_size is not None else text_size _tick_fs = tick_size if tick_size is not None else text_size if title: _title_loc = ( None if title_loc is None else ( "center" if str(title_loc).lower() in {"center", "centre"} else str(title_loc).lower() ) ) ax_left.set_title(title, fontsize=_title_fs, loc=_title_loc) if xlabel: ax_left.set_xlabel(xlabel, fontsize=_label_fs) if ylabel: ax_left.set_ylabel(ylabel, fontsize=_label_fs) if _tick_fs is not None: ax_left.tick_params(axis="both", labelsize=_tick_fs) ax_right.tick_params(axis="both", labelsize=_tick_fs) # Grid configuration axis_opt = None if grid_axis is None else str(grid_axis).lower() if axis_opt not in {None, "none", "off", "false"}: if axis_opt in {"xy", "yx", "all"}: axis_val = "both" elif axis_opt == "auto": axis_val = "x" # numeric axis for horizontal bars elif axis_opt in {"x", "y", "both"}: axis_val = axis_opt else: axis_val = "x" ax_left.grid(True, axis=axis_val, linestyle="--", alpha=0.3) ax_right.grid(True, axis=axis_val, linestyle="--", alpha=0.3) # Apply spines/ticks removal to both axes if requested def _apply_styling_axes(ax_list: List[plt.Axes]): # Spines if remove_spines is not None: if isinstance(remove_spines, str): opts = [t for t in re.split(r"[\s,]+", remove_spines) if t] else: opts = list(remove_spines) spines_to_remove = set() for opt in opts: if opt in ("right_top", "top_right"): spines_to_remove.update(["right", "top"]) elif opt == "all": spines_to_remove.update(["left", "right", "top", "bottom"]) else: spines_to_remove.add(opt) for axx in ax_list: for sp in ("left", "right", "top", "bottom"): if sp in spines_to_remove: axx.spines[sp].set_visible(False) # Ticks if remove_ticks is not None: if isinstance(remove_ticks, str): toks = [t for t in re.split(r"[\s,]+", remove_ticks) if t] else: toks = list(remove_ticks) axes_rm = set() for t in toks: tl = str(t).lower() if tl in ("both", "xy", "yx"): axes_rm.update(["x", "y"]) elif tl in ("x", "y"): axes_rm.add(tl) if "x" in axes_rm: ax_left.tick_params(axis="x", which="both", length=0, width=0) ax_right.tick_params(axis="x", which="both", length=0, width=0) if "y" in axes_rm: ax_left.tick_params(axis="y", which="both", length=0, width=0) ax_right.tick_params(axis="y", which="both", length=0, width=0) _apply_styling_axes([ax_left, ax_right]) if save: fig.savefig(save, dpi=150) return fig, ax_left else: # Vertical bars: stacked axes sharing x if created_fig and ax is not None and len(fig.axes) == 1: plt.close(fig) # close single-axes fig fig, (ax_top, ax_bottom) = plt.subplots( 2, 1, sharex=True, figsize=figsize, constrained_layout=True, gridspec_kw={"height_ratios": [1, 3]}, ) else: bbox = ax.get_position() ax.set_visible(False) h_total = bbox.height gap = 0.02 * h_total h_top = h_total * 0.35 h_bottom = h_total * 0.65 ax_top = fig.add_axes( [bbox.x0, bbox.y0 + h_bottom + gap / 2, bbox.width, h_top - gap / 2] ) ax_bottom = fig.add_axes( [bbox.x0, bbox.y0, bbox.width, h_bottom - gap / 2] ) # share x manually ax_top.get_shared_x_axes().join(ax_top, ax_bottom) x_pos = np.arange(len(labels)) idx_out = int(order[0]) mask = np.zeros(len(vals), dtype=bool) mask[idx_out] = True def _subset_colors(mask_bool): if isinstance(colors, (list, tuple, np.ndarray)): return [colors[i] for i, m in enumerate(mask_bool) if m] return colors # Top shows only outlier; bottom shows all ax_top.bar( x_pos[mask], vals[mask], yerr=(e.values[mask] if e is not None else None), color=_subset_colors(mask), ) ax_bottom.bar( x_pos, vals, yerr=e.values if e is not None else None, color=colors, ) # Tick labels only on bottom ax_bottom.set_xticks(x_pos, labels=disp_labels, rotation=45, ha="right") ax_top.tick_params(axis="x", labelbottom=False) # Limits per panel ax_bottom.set_ylim((0.0, lower_upper)) ax_top.set_ylim((upper_lower, upper_upper)) # Diagonal marks (only on primary/top axis) d = 0.015 kwargs = dict(transform=ax_top.transAxes, color="k", clip_on=False) ax_top.plot((-d, +d), (-d, +d), **kwargs) ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # No start lines on the secondary/bottom axis # Titles/labels and text sizes _title_fs = title_size if title_size is not None else text_size _label_fs = axis_label_size if axis_label_size is not None else text_size _tick_fs = tick_size if tick_size is not None else text_size if title: _title_loc = ( None if title_loc is None else ( "center" if str(title_loc).lower() in {"center", "centre"} else str(title_loc).lower() ) ) ax_top.set_title(title, fontsize=_title_fs, loc=_title_loc) if xlabel: ax_bottom.set_xlabel(xlabel, fontsize=_label_fs) if ylabel: ax_bottom.set_ylabel(ylabel, fontsize=_label_fs) if _tick_fs is not None: ax_top.tick_params(axis="both", labelsize=_tick_fs) ax_bottom.tick_params(axis="both", labelsize=_tick_fs) # Grid configuration axis_opt = None if grid_axis is None else str(grid_axis).lower() if axis_opt not in {None, "none", "off", "false"}: if axis_opt in {"xy", "yx", "all"}: axis_val = "both" elif axis_opt == "auto": axis_val = "y" # numeric axis for vertical bars elif axis_opt in {"x", "y", "both"}: axis_val = axis_opt else: axis_val = "y" ax_top.grid(True, axis=axis_val, linestyle="--", alpha=0.3) ax_bottom.grid(True, axis=axis_val, linestyle="--", alpha=0.3) # Apply spines/ticks removal to both axes if requested def _apply_styling_axes(ax_list: List[plt.Axes]): if remove_spines is not None: if isinstance(remove_spines, str): opts = [t for t in re.split(r"[\s,]+", remove_spines) if t] else: opts = list(remove_spines) spines_to_remove = set() for opt in opts: if opt in ("right_top", "top_right"): spines_to_remove.update(["right", "top"]) elif opt == "all": spines_to_remove.update(["left", "right", "top", "bottom"]) else: spines_to_remove.add(opt) for axx in ax_list: for sp in ("left", "right", "top", "bottom"): if sp in spines_to_remove: axx.spines[sp].set_visible(False) # Ticks removal if remove_ticks is not None: if isinstance(remove_ticks, str): toks = [t for t in re.split(r"[\s,]+", remove_ticks) if t] else: toks = list(remove_ticks) axes_rm = set() for t in toks: tl = str(t).lower() if tl in ("both", "xy", "yx"): axes_rm.update(["x", "y"]) elif tl in ("x", "y"): axes_rm.add(tl) if "x" in axes_rm: ax_top.tick_params(axis="x", which="both", length=0, width=0) ax_bottom.tick_params(axis="x", which="both", length=0, width=0) if "y" in axes_rm: ax_top.tick_params(axis="y", which="both", length=0, width=0) ax_bottom.tick_params(axis="y", which="both", length=0, width=0) _apply_styling_axes([ax_top, ax_bottom]) if save: fig.savefig(save, dpi=150) return fig, ax_bottom # Standard single-axis branch if orientation == "horizontal": y_pos = np.arange(len(labels)) ax.barh(y_pos, vals, xerr=e.values if e is not None else None, color=colors) ax.set_yticks(y_pos, labels=disp_labels) # Resolve sizes _title_fs = title_size if title_size is not None else text_size _label_fs = axis_label_size if axis_label_size is not None else text_size _tick_fs = tick_size if tick_size is not None else text_size if xlabel: ax.set_xlabel(xlabel, fontsize=_label_fs) if ylabel: ax.set_ylabel(ylabel, fontsize=_label_fs) if axis_lim is not None: ax.set_xlim(axis_lim) elif nice_axis_limits: if e is not None: vmin = float(np.nanmin(vals - e.values)) if len(vals) else 0.0 vmax = float(np.nanmax(vals + e.values)) if len(vals) else 1.0 else: vmin = float(np.nanmin(vals)) if len(vals) else 0.0 vmax = float(np.nanmax(vals)) if len(vals) else 1.0 if vmin > vmax: vmin, vmax = vmax, vmin ax.set_xlim((_nice_floor(vmin), _nice_ceil(vmax))) else: x_pos = np.arange(len(labels)) ax.bar(x_pos, vals, yerr=e.values if e is not None else None, color=colors) ax.set_xticks(x_pos, labels=disp_labels, rotation=45, ha="right") # Resolve sizes _title_fs = title_size if title_size is not None else text_size _label_fs = axis_label_size if axis_label_size is not None else text_size _tick_fs = tick_size if tick_size is not None else text_size if xlabel: ax.set_xlabel(xlabel, fontsize=_label_fs) if ylabel: ax.set_ylabel(ylabel, fontsize=_label_fs) if axis_lim is not None: ax.set_ylim(axis_lim) elif nice_axis_limits: if e is not None: vmin = float(np.nanmin(vals - e.values)) if len(vals) else 0.0 vmax = float(np.nanmax(vals + e.values)) if len(vals) else 1.0 else: vmin = float(np.nanmin(vals)) if len(vals) else 0.0 vmax = float(np.nanmax(vals)) if len(vals) else 1.0 if vmin > vmax: vmin, vmax = vmax, vmin ax.set_ylim((_nice_floor(vmin), _nice_ceil(vmax))) if title: _title_fs = title_size if title_size is not None else text_size _title_loc = ( None if title_loc is None else ( "center" if str(title_loc).lower() in {"center", "centre"} else str(title_loc).lower() ) ) ax.set_title(title, fontsize=_title_fs, loc=_title_loc) # Grid configuration axis_opt = None if grid_axis is None else str(grid_axis).lower() if axis_opt not in {None, "none", "off", "false"}: if axis_opt in {"xy", "yx", "all"}: axis_val = "both" elif axis_opt == "auto": axis_val = "x" if orientation == "horizontal" else "y" elif axis_opt in {"x", "y", "both"}: axis_val = axis_opt else: axis_val = "y" # fallback to previous default ax.grid(True, axis=axis_val, linestyle="--", alpha=0.3) # Apply tick label size if requested _tick_fs = tick_size if tick_size is not None else text_size if _tick_fs is not None: ax.tick_params(axis="both", labelsize=_tick_fs) # Optionally remove specific spines if remove_spines is not None: if isinstance(remove_spines, str): # split on commas/whitespace so inputs like "right top" or "right,top" work opts = [t for t in re.split(r"[\s,]+", remove_spines) if t] else: opts = list(remove_spines) spines_to_remove = set() for opt in opts: if opt in ("right_top", "top_right"): spines_to_remove.update(["right", "top"]) elif opt == "all": spines_to_remove.update(["left", "right", "top", "bottom"]) else: spines_to_remove.add(opt) for sp in ("left", "right", "top", "bottom"): if sp in spines_to_remove: ax.spines[sp].set_visible(False) # Optionally remove ticks on x and/or y axes if remove_ticks is not None: if isinstance(remove_ticks, str): toks = [t for t in re.split(r"[\s,]+", remove_ticks) if t] else: toks = list(remove_ticks) axes_rm = set() for t in toks: tl = str(t).lower() if tl in ("both", "xy", "yx"): axes_rm.update(["x", "y"]) elif tl in ("x", "y"): axes_rm.add(tl) if "x" in axes_rm: # Set tick mark length/width to 0 to hide marks but keep labels ax.tick_params(axis="x", which="both", length=0, width=0) if "y" in axes_rm: ax.tick_params(axis="y", which="both", length=0, width=0) if save: fig.savefig(save, dpi=150) return fig, ax
[docs] def plot_scatter2d( x: Union[pd.Series, Sequence[float], np.ndarray], y: Union[pd.Series, Sequence[float], np.ndarray], *, labels: Optional[Union[pd.Series, Sequence[Any], np.ndarray]] = None, label_map: Optional[Mapping[Any, str]] = None, palette: Optional[Sequence[str]] = None, alpha: float = 0.8, s: float = 25.0, text_size: Optional[float] = None, title_size: Optional[float] = None, title_loc: Optional[str] = None, axis_label_size: Optional[float] = None, tick_size: Optional[float] = None, legend_size: Optional[float] = None, title: Optional[str] = None, xlabel: Optional[str] = None, ylabel: Optional[str] = None, legend: bool = True, figsize: Tuple[float, float] = (5, 5), ax: Optional[plt.Axes] = None, save: Optional[str] = None, ) -> Tuple[plt.Figure, plt.Axes]: """ Scatter plot with class-based coloring. Parameters ---------- x, y : array-like Coordinates for points. labels : array-like, optional Class labels for coloring. If None, a single color is used. label_map : mapping, optional Mapping from raw labels to display names in legend. palette : sequence of colors, optional Colors to cycle over classes (defaults to matplotlib tab10). legend : bool Whether to show legend for classes. """ x_arr = x.values if isinstance(x, pd.Series) else np.asarray(x) y_arr = y.values if isinstance(y, pd.Series) else np.asarray(y) if ax is None: fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True) else: fig = ax.figure if labels is None: ax.scatter(x_arr, y_arr, s=s, alpha=alpha) else: lab_arr = labels.values if isinstance(labels, pd.Series) else np.asarray(labels) uniq = pd.unique(lab_arr) # Colors if palette is None: cmap = plt.get_cmap("tab10") colors = [cmap(i % 10) for i in range(len(uniq))] else: colors = list(palette) if len(colors) < len(uniq): # repeat if needed k = int(np.ceil(len(uniq) / len(colors))) colors = (colors * k)[: len(uniq)] for c, u in zip(colors, uniq): mask = lab_arr == u ax.scatter( x_arr[mask], y_arr[mask], s=s, alpha=alpha, label=label_map.get(u, u) if label_map else u, color=c, ) # Resolve sizes _title_fs = title_size if title_size is not None else text_size _label_fs = axis_label_size if axis_label_size is not None else text_size _tick_fs = tick_size if tick_size is not None else text_size _legend_fs = legend_size if legend_size is not None else text_size if title: _title_loc = ( None if title_loc is None else ( "center" if str(title_loc).lower() in {"center", "centre"} else str(title_loc).lower() ) ) ax.set_title(title, fontsize=_title_fs, loc=_title_loc) if xlabel: ax.set_xlabel(xlabel, fontsize=_label_fs) if ylabel: ax.set_ylabel(ylabel, fontsize=_label_fs) if legend and labels is not None: ax.legend(frameon=False, fontsize=_legend_fs) if _tick_fs is not None: ax.tick_params(axis="both", labelsize=_tick_fs) if save: fig.savefig(save, dpi=150) return fig, ax