"""
Layered Sankey diagram
======================

An example of UltraPlot's layered Sankey renderer for publication-ready
flow diagrams.

Why UltraPlot here?
-------------------
``sankey`` in layered mode handles node ordering, flow styling, and
label placement without manual geometry.

Key function: :py:meth:`ultraplot.axes.PlotAxes.sankey`.

See also
--------
* :doc:`2D plot types `
"""

import ultraplot as uplt

# Node identifiers; passing them explicitly selects the layered renderer.
node_names = ["Budget", "Operations", "R&D", "Marketing", "Support", "Infra"]

# Each link is a (source, target, value, label) tuple.
budget_links = [
    ("Budget", "Operations", 5.0, "Ops"),
    ("Budget", "R&D", 3.0, "R&D"),
    ("Budget", "Marketing", 2.0, "Mkt"),
    ("Operations", "Support", 1.5, "Support"),
    ("Operations", "Infra", 2.0, "Infra"),
]

fig, ax = uplt.subplots(refwidth=3.6)
ax.sankey(
    flows=budget_links,
    nodes=node_names,
    style="budget",
    # Label every ribbon at its midpoint.
    flow_labels=True,
    flow_label_pos=0.5,
    value_format="{:.1f}",
    # Draw a rounded box behind each node label.
    node_label_box=True,
)
ax.format(title="Budget allocation")
fig.show()
+nodes : sequence or dict, optional + Node identifiers or dicts with ``id``/``label``/``color`` keys. If omitted, + nodes are inferred from flow sources/targets. +labels : sequence of str, optional + Labels for each flow in Matplotlib's Sankey mode. +orientations : sequence of int, optional + Flow orientations (-1: down, 0: right, 1: up) for Matplotlib's Sankey. +pathlengths : float or sequence of float, optional + Path lengths for each flow in Matplotlib's Sankey. +trunklength : float, optional + Length of the trunk between the input and output flows. +patchlabel : str, optional + Label for the main patch in Matplotlib's Sankey mode. +scale, unit, format, gap, radius, shoulder, offset, head_angle, margin, tolerance : optional + Passed to `matplotlib.sankey.Sankey`. +prior : int, optional + Index of a prior diagram to connect to. +connect : (int, int), optional + Flow indices for the prior and current diagram connection. +rotation : float, optional + Rotation angle in degrees. +node_kw, flow_kw, label_kw : dict-like, optional + Style dictionaries for the layered Sankey renderer. +node_label_kw, flow_label_kw : dict-like, optional + Label style dictionaries for node and flow labels in layered mode. +node_label_box : bool or dict-like, optional + If ``True``, draw a rounded box behind node labels. If dict-like, used as + the ``bbox`` argument for node label styling. +style : {'budget', 'pastel', 'mono'}, optional + Built-in styling presets for layered mode. +node_order : sequence, optional + Explicit node ordering for layered mode. +layer_order : sequence, optional + Explicit layer ordering for layered mode. +group_cycle : sequence, optional + Cycle for flow group colors (defaults to flow cycle). +flow_other : float, optional + Aggregate flows below this threshold into a single ``other_label``. +other_label : str, optional + Label for the aggregated flow target. +value_format : str or callable, optional + Formatter for flow labels when not explicitly provided. 
+node_label_outside : {'auto', True, False}, optional + Place node labels outside narrow nodes. +node_label_offset : float, optional + Offset for outside node labels (axes-relative units). +flow_sort : bool, optional + Whether to sort flows by target position to reduce crossings. +flow_label_pos : float, optional + Horizontal placement for single flow labels (0 to 1 along the ribbon). + When flow labels overlap, positions are redistributed between 0.25 and 0.75. +node_labels, flow_labels : bool, optional + Whether to draw node or flow labels in layered mode. +align : {'center', 'top', 'bottom'}, optional + Vertical alignment for nodes within each layer in layered mode. +layers : dict-like, optional + Manual layer assignments for nodes in layered mode. +**kwargs + Patch properties passed to `matplotlib.sankey.Sankey.add` in Matplotlib mode. + +Returns +------- +matplotlib.sankey.Sankey or list or SankeyDiagram + The Sankey diagram instance, or a list for multi-diagram usage. For layered + mode, returns a `~ultraplot.axes.plot_types.sankey.SankeyDiagram`. 
@docstring._snippet_manager
def sankey(
    self,
    flows: Any,
    labels: Sequence[str] | None = None,
    orientations: Sequence[int] | None = None,
    pathlengths: float | Sequence[float] = 0.25,
    trunklength: float = 1.0,
    patchlabel: str = "",
    *,
    nodes: Any = None,
    links: Any = None,
    node_kw: Mapping[str, Any] | None = None,
    flow_kw: Mapping[str, Any] | None = None,
    label_kw: Mapping[str, Any] | None = None,
    node_label_kw: Mapping[str, Any] | None = None,
    flow_label_kw: Mapping[str, Any] | None = None,
    node_label_box: bool | Mapping[str, Any] | None = None,
    style: str | None = None,
    node_order: Sequence[Any] | None = None,
    layer_order: Sequence[int] | None = None,
    group_cycle: Sequence[Any] | None = None,
    flow_other: float | None = None,
    other_label: str = "Other",
    value_format: str | Callable[[float], str] | None = None,
    node_label_outside: bool | str = "auto",
    node_label_offset: float = 0.01,
    flow_sort: bool = True,
    flow_label_pos: float = 0.5,
    node_labels: bool = True,
    flow_labels: bool = False,
    align: str = "center",
    layers: Mapping[Any, int] | None = None,
    scale: float | None = None,
    unit: str | None = None,
    format: str | None = None,
    gap: float | None = None,
    radius: float | None = None,
    shoulder: float | None = None,
    offset: float | None = None,
    head_angle: float | None = None,
    margin: float | None = None,
    tolerance: float | None = None,
    prior: int | None = None,
    connect: tuple[int, int] | None = (0, 0),
    rotation: float = 0,
    **kwargs: Any,
) -> Any:
    """
    %(plot.sankey)s
    """

    def _is_link_spec(candidate):
        # A dict, or a non-empty list/tuple whose first item is a dict or a
        # (source, target, value, ...) sequence, describes layered links.
        # Anything else (None, arrays, flat numeric lists) does not.
        if isinstance(candidate, dict):
            return True
        if not isinstance(candidate, (list, tuple)) or not candidate:
            return False
        head = candidate[0]
        if isinstance(head, dict):
            return True
        return isinstance(head, (list, tuple)) and len(head) >= 3

    # Layered mode is chosen by explicit nodes/links or by link-shaped flows.
    if nodes is not None or links is not None or _is_link_spec(flows):
        from .plot_types.sankey import sankey_diagram

        # Fall back to one freshly drawn color if the rc cycle has none.
        palette = rc["axes.prop_cycle"].by_key().get("color", [])
        if not palette:
            palette = [self._get_lines.get_next_color()]

        return sankey_diagram(
            self,
            nodes=nodes,
            flows=flows if links is None else links,
            layers=layers,
            flow_cycle=palette,
            group_cycle=group_cycle,
            node_order=node_order,
            layer_order=layer_order,
            style=style,
            flow_other=flow_other,
            other_label=other_label,
            value_format=value_format,
            node_kw=node_kw or {},
            flow_kw=flow_kw or {},
            label_kw=label_kw or {},
            node_label_kw=node_label_kw,
            flow_label_kw=flow_label_kw,
            node_label_box=node_label_box,
            node_label_outside=node_label_outside,
            node_label_offset=node_label_offset,
            flow_sort=flow_sort,
            flow_label_pos=flow_label_pos,
            node_labels=node_labels,
            flow_labels=flow_labels,
            align=align,
            node_pad=rc["sankey.nodepad"],
            node_width=rc["sankey.nodewidth"],
            margin=rc["sankey.margin"],
            flow_alpha=rc["sankey.flow.alpha"],
            flow_curvature=rc["sankey.flow.curvature"],
            node_facecolor=rc["sankey.node.facecolor"],
        )

    # Matplotlib mode: forward only the options the caller actually set.
    from matplotlib.sankey import Sankey

    constructor_options = (
        ("scale", scale),
        ("unit", unit),
        ("format", format),
        ("gap", gap),
        ("radius", radius),
        ("shoulder", shoulder),
        ("offset", offset),
        ("head_angle", head_angle),
        ("margin", margin),
        ("tolerance", tolerance),
    )
    sankey_kw = {
        name: setting for name, setting in constructor_options if setting is not None
    }

    # Pull a face color from the axes cycle unless the caller chose one.
    if not ("facecolor" in kwargs or "color" in kwargs):
        kwargs["facecolor"] = self._get_lines.get_next_color()

    builder = Sankey(ax=self, **sankey_kw)
    add_kw = {
        "flows": flows,
        "trunklength": trunklength,
        "patchlabel": patchlabel,
        "rotation": rotation,
        "pathlengths": pathlengths,
    }
    optional_add = (
        ("labels", labels),
        ("orientations", orientations),
        ("prior", prior),
        ("connect", connect),
    )
    add_kw.update(
        {name: setting for name, setting in optional_add if setting is not None}
    )

    builder.add(**add_kw, **kwargs)
    finished = builder.finish()
    # A single diagram is returned bare; several are returned as the list.
    if len(finished) == 1:
        return finished[0]
    return finished
def _tint(color: Any, amount: float) -> tuple[float, float, float]:
    """Blend a color toward white; ``amount=0`` keeps it, ``amount=1`` is white."""
    keep = 1 - amount
    red, green, blue = mcolors.to_rgb(color)
    # Each channel is a linear mix of the original value and 1.0 (white).
    return (keep * red + amount, keep * green + amount, keep * blue + amount)
+ node_map = {} + order = [] + if isinstance(nodes, dict): + nodes = [{"id": key, **value} for key, value in nodes.items()] + for node in nodes: + if isinstance(node, dict): + node_id = node.get("id", node.get("name")) + if node_id is None: + raise ValueError("Node dicts must include an 'id' or 'name'.") + label = node.get("label", str(node_id)) + color = node.get("color", None) + else: + node_id = node + label = str(node_id) + color = None + node_map[node_id] = {"id": node_id, "label": label, "color": color} + order.append(node_id) + return node_map, order + + +def _normalize_flows(flows: Any) -> list[dict[str, Any]]: + """Normalize flow definitions into a list of dicts.""" + if flows is None: + raise ValueError("Flows are required to draw a sankey diagram.") + normalized = [] + for flow in flows: + # Support dict flows or tuple-like flows. + if isinstance(flow, dict): + source = flow["source"] + target = flow["target"] + value = flow["value"] + label = flow.get("label", None) + color = flow.get("color", None) + else: + if len(flow) < 3: + raise ValueError( + "Flow tuples must have at least (source, target, value)." + ) + source, target, value = flow[:3] + label = flow[3] if len(flow) > 3 else None + color = flow[4] if len(flow) > 4 else None + if value is None or value < 0: + raise ValueError("Flow values must be non-negative.") + if value == 0: + continue + # Store a consistent flow record for downstream layout/drawing. 
+ normalized.append( + { + "source": source, + "target": target, + "value": float(value), + "label": label, + "color": color, + "group": flow.get("group", None) if isinstance(flow, dict) else None, + } + ) + if not normalized: + raise ValueError("Flows must include at least one non-zero value.") + return normalized + + +def _assign_layers( + flows: Sequence[Mapping[str, Any]], + nodes: Sequence[Any], + layers: Mapping[Any, int] | None, +) -> dict[Any, int]: + """Assign layer indices for nodes using a DAG topological pass.""" + if layers is not None: + # Honor explicit layer assignments when provided. + layer_map = dict(layers) + missing = [node for node in nodes if node not in layer_map] + if missing: + raise ValueError(f"Missing layer assignments for nodes: {missing}") + return layer_map + + # Build adjacency for a simple topological layer assignment. + successors = {node: set() for node in nodes} + predecessors = {node: set() for node in nodes} + for flow in flows: + source = flow["source"] + target = flow["target"] + successors[source].add(target) + predecessors[target].add(source) + + layer_map = {node: 0 for node in nodes} + indegree = {node: len(preds) for node, preds in predecessors.items()} + queue = [node for node, deg in indegree.items() if deg == 0] + visited = 0 + # Kahn's algorithm to assign layers from sources outward. 
def _compute_layout(
    nodes: Sequence[Any],
    flows: Sequence[Mapping[str, Any]],
    *,
    node_pad: float,
    node_width: float,
    align: str,
    layers: Mapping[Any, int] | None,
    margin: float,
    layer_order: Sequence[int] | None = None,
) -> tuple[
    dict[str, Any],
    dict[Any, list[dict[str, Any]]],
    dict[Any, list[dict[str, Any]]],
    dict[Any, float],
]:
    """
    Compute node and flow layout geometry in axes-relative coordinates.

    Parameters
    ----------
    nodes : sequence
        Node ids, in the order they are stacked within each layer.
    flows : sequence of mapping
        Flow records with ``source``, ``target`` and ``value`` keys.
    node_pad : float
        Vertical gap between stacked nodes, in axes-relative units.
    node_width : float
        Node width, in axes-relative units.
    align : str
        ``"top"`` or ``"bottom"``; any other value centers each layer.
    layers : mapping or None
        Explicit node-to-layer assignment forwarded to ``_assign_layers``.
    margin : float
        Margin kept free on every side of the unit square.
    layer_order : sequence of int, optional
        Layer ids to lay out; defaults to the sorted set of assigned layers.

    Returns
    -------
    layout : dict
        ``{"nodes": {id: {"x", "y", "width", "height"}}, "scale", "layers"}``.
    flow_in, flow_out : dict
        Incoming and outgoing flow records per node.
    node_value : dict
        Per node, the larger of its incoming and outgoing totals.
    """
    # Split flows into incoming/outgoing for node sizing.
    flow_in = {node: [] for node in nodes}
    flow_out = {node: [] for node in nodes}
    for flow in flows:
        flow_out[flow["source"]].append(flow)
        flow_in[flow["target"]].append(flow)

    node_value = {}
    for node in nodes:
        incoming = sum(flow["value"] for flow in flow_in[node])
        outgoing = sum(flow["value"] for flow in flow_out[node])
        # Nodes size to the larger of in/out totals.
        node_value[node] = max(incoming, outgoing)

    layer_map = _assign_layers(flows, nodes, layers)
    max_layer = max(layer_map.values()) if layer_map else 0
    if layer_order is None:
        layer_order = sorted(set(layer_map.values()))
    # Group nodes by layer in the desired order.
    grouped = {layer: [] for layer in layer_order}
    for node in nodes:
        grouped[layer_map[node]].append(node)

    height_available = 1.0 - 2 * margin
    # Keep value totals (data units) separate from padding totals (axes
    # units): only the former is multiplied by the scale.
    layer_value = {}
    layer_padding = {}
    for layer, layer_nodes in grouped.items():
        layer_value[layer] = sum(node_value[node] for node in layer_nodes)
        layer_padding[layer] = node_pad * max(len(layer_nodes) - 1, 0)

    # The scale must let every layer fit: scaled values plus unscaled
    # padding may not exceed the available height. Adding padding to the
    # data-unit total before dividing (the previous formula) let the
    # tallest layer spill past the margins by its padding.
    fits = [
        (height_available - layer_padding[layer]) / layer_value[layer]
        for layer in grouped
        if layer_value[layer] > 0
    ]
    scale = min(fits) if fits else 1.0
    if scale <= 0:
        # Padding alone uses up the height of some layer, so an exact fit
        # is impossible; fall back to the previous approximate scale rather
        # than producing non-positive node heights.
        legacy_extent = max(
            layer_value[layer] + layer_padding[layer] for layer in grouped
        )
        scale = height_available / legacy_extent

    # Lay out nodes within each layer using the same scale.
    layout = {"nodes": {}, "scale": scale, "layers": layer_map}
    for layer in layer_order:
        layer_nodes = grouped[layer]
        total = layer_value[layer] * scale + layer_padding[layer]
        if align == "top":
            start = margin + (height_available - total)
        elif align == "bottom":
            start = margin
        else:
            start = margin + (height_available - total) / 2
        y = start
        for node in layer_nodes:
            height = node_value[node] * scale
            layout["nodes"][node] = {
                # Layers are spread evenly between the left and right margin.
                "x": margin
                + (1.0 - 2 * margin - node_width) * (layer / max(max_layer, 1)),
                "y": y,
                "width": node_width,
                "height": height,
            }
            y += height + node_pad
    return layout, flow_in, flow_out, node_value
- x0 + if dx <= 0: + dx = max(thickness, 0.02) + cx0 = x0 + dx * curvature + cx1 = x1 - dx * curvature + target_x = x0 + (x1 - x0) * frac + if x1 == x0: + t = frac + else: + lo, hi = 0.0, 1.0 + for _ in range(24): + mid = (lo + hi) / 2 + mid_x = _bezier_point(x0, cx0, cx1, x1, mid) + if mid_x < target_x: + lo = mid + else: + hi = mid + t = (lo + hi) / 2 + x = _bezier_point(x0, cx0, cx1, x1, t) + y = _bezier_point(y0, y0, y1, y1, t) + return x, y + + +def _apply_style( + style: str | None, + *, + flow_cycle: Sequence[Any] | None, + node_facecolor: Any, + flow_alpha: float, + flow_curvature: float, + node_label_box: bool | Mapping[str, Any] | None, + node_label_kw: Mapping[str, Any], +) -> dict[str, Any]: + """Apply a named style preset and merge overrides.""" + if style is None: + return { + "flow_cycle": flow_cycle, + "node_facecolor": node_facecolor, + "flow_alpha": flow_alpha, + "flow_curvature": flow_curvature, + "node_label_box": node_label_box, + "node_label_kw": node_label_kw, + } + presets = { + "budget": dict( + node_facecolor="0.8", + flow_alpha=0.85, + flow_curvature=0.55, + node_label_box=True, + node_label_kw=dict(fontsize=9, color="0.2"), + ), + "pastel": dict( + node_facecolor="0.88", + flow_alpha=0.7, + flow_curvature=0.6, + node_label_box=True, + ), + "mono": dict( + node_facecolor="0.7", + flow_alpha=0.5, + flow_curvature=0.45, + node_label_box=False, + flow_cycle=["0.55"], + ), + } + if style not in presets: + raise ValueError(f"Unknown sankey style {style!r}.") + preset = presets[style] + # Merge preset overrides with caller-provided defaults. 
+ return { + "flow_cycle": preset.get("flow_cycle", flow_cycle), + "node_facecolor": preset.get("node_facecolor", node_facecolor), + "flow_alpha": preset.get("flow_alpha", flow_alpha), + "flow_curvature": preset.get("flow_curvature", flow_curvature), + "node_label_box": preset.get("node_label_box", node_label_box), + "node_label_kw": {**preset.get("node_label_kw", {}), **node_label_kw}, + } + + +def _apply_flow_other( + flows: list[dict[str, Any]], flow_other: float | None, other_label: str +) -> list[dict[str, Any]]: + """Aggregate small flows into a single 'Other' target per source.""" + if flow_other is None: + return flows + # Collapse small values per source into an "Other" flow. + other_sums = {} + filtered = [] + for flow in flows: + if flow["value"] < flow_other: + other_sums[flow["source"]] = ( + other_sums.get(flow["source"], 0.0) + flow["value"] + ) + else: + filtered.append(flow) + flows = filtered + for source, other_sum in other_sums.items(): + if other_sum <= 0: + continue + flows.append( + { + "source": source, + "target": other_label, + "value": other_sum, + "label": None, + "color": None, + "group": None, + } + ) + return flows + + +def _ensure_nodes( + nodes: Any, + flows: Sequence[Mapping[str, Any]], + node_order: Sequence[Any] | None, +) -> tuple[dict[Any, dict[str, Any]], list[Any]]: + """Ensure all flow endpoints exist in nodes and validate ordering.""" + node_map, node_order_default = _normalize_nodes(nodes, flows) + # Add any missing flow endpoints to the node list if ordering is implicit. 
+ flow_nodes = {flow["source"] for flow in flows} | {flow["target"] for flow in flows} + missing_nodes = [node for node in flow_nodes if node not in node_map] + if missing_nodes and node_order is not None: + raise ValueError("node_order must include every node exactly once.") + if missing_nodes: + for node in missing_nodes: + node_map[node] = {"id": node, "label": str(node), "color": None} + node_order_default.append(node) + node_order = node_order or node_order_default + if set(node_order) != set(node_map.keys()): + raise ValueError("node_order must include every node exactly once.") + return node_map, node_order + + +def _assign_flow_colors( + flows: Sequence[Mapping[str, Any]], + flow_cycle: Sequence[Any] | None, + group_cycle: Sequence[Any] | None, +) -> dict[Any, Any]: + """Assign colors to flows by group or source.""" + if flow_cycle is None: + flow_cycle = ["0.6"] + if group_cycle is None: + group_cycle = flow_cycle + group_iter = iter(group_cycle) + flow_color_map = {} + # Assign a stable color per group (or per source if no group). + for flow in flows: + if flow["color"] is not None: + continue + group = flow["group"] or flow["source"] + if group not in flow_color_map: + try: + flow_color_map[group] = next(group_iter) + except StopIteration: + group_iter = iter(group_cycle) + flow_color_map[group] = next(group_iter) + flow["color"] = flow_color_map[group] + return flow_color_map + + +def _sort_flows( + flows: Sequence[Mapping[str, Any]], + node_order: Sequence[Any], + layout: Mapping[str, Any], +) -> list[dict[str, Any]]: + """Sort flows by target position to reduce crossings.""" + # Order outgoing links by target center to reduce line crossings. 
+ node_centers = { + node: layout["nodes"][node]["y"] + layout["nodes"][node]["height"] / 2 + for node in node_order + } + ordered = [] + seen = set() + for source in node_order: + outgoing = [flow for flow in flows if flow["source"] == source] + outgoing = sorted(outgoing, key=lambda f: node_centers[f["target"]]) + for flow in outgoing: + ordered.append(flow) + seen.add(id(flow)) + for flow in flows: + if id(flow) not in seen: + ordered.append(flow) + return ordered + + +def _flow_label_text( + flow: Mapping[str, Any], value_format: str | Callable[[float], str] | None +) -> str: + """Resolve the text for a flow label.""" + label_text = flow.get("label", None) + if label_text is not None: + return label_text + if value_format is None: + return f"{flow['value']:.3g}" + if callable(value_format): + return value_format(flow["value"]) + return value_format.format(flow["value"]) + + +def _flow_label_frac(idx: int, count: int, base: float) -> float: + """Return alternating label positions around the midpoint.""" + if count <= 1: + return base + step = 0.25 if count == 2 else 0.2 + offset = (idx // 2 + 1) * step + frac = base - offset if idx % 2 == 0 else base + offset + return min(max(frac, 0.05), 0.95) + + +def _prepare_inputs( + *, + nodes: Any, + flows: Any, + flow_other: float | None, + other_label: str, + node_order: Sequence[Any] | None, + style: str | None, + flow_cycle: Sequence[Any] | None, + node_facecolor: Any, + flow_alpha: float, + flow_curvature: float, + node_label_box: bool | Mapping[str, Any] | None, + node_label_kw: Mapping[str, Any], + group_cycle: Sequence[Any] | None, +) -> tuple[ + list[dict[str, Any]], + dict[Any, dict[str, Any]], + list[Any], + dict[str, Any], + dict[Any, Any], +]: + """Normalize inputs, apply style, and assign colors.""" + # Parse flows and optional "other" aggregation. + flows = _normalize_flows(flows) + flows = _apply_flow_other(flows, flow_other, other_label) + # Ensure nodes include all flow endpoints. 
+ node_map, node_order = _ensure_nodes(nodes, flows, node_order) + # Apply style presets and merge overrides. + style_config = _apply_style( + style, + flow_cycle=flow_cycle, + node_facecolor=node_facecolor, + flow_alpha=flow_alpha, + flow_curvature=flow_curvature, + node_label_box=node_label_box, + node_label_kw=node_label_kw, + ) + # Resolve flow colors after style is applied. + flow_color_map = _assign_flow_colors(flows, style_config["flow_cycle"], group_cycle) + return flows, node_map, node_order, style_config, flow_color_map + + +def _validate_layer_order( + layer_order: Sequence[int] | None, + flows: Sequence[Mapping[str, Any]], + node_order: Sequence[Any], + layers: Mapping[Any, int] | None, +) -> None: + """Validate that layer_order is consistent with computed layers.""" + if layer_order is None: + return + # Compare explicit ordering with the computed layer set. + layer_map = _assign_layers(flows, node_order, layers) + if set(layer_order) != set(layer_map.values()): + raise ValueError("layer_order must include every layer.") + + +def _layer_positions( + layout: Mapping[str, Any], layer_order: Sequence[int] | None +) -> tuple[dict[Any, int], dict[int, int]]: + """Return layer maps and positions for label placement.""" + # Map layer ids to positions for outside-label placement. + layer_map = layout["layers"] + if layer_order is not None: + layer_position = {layer: idx for idx, layer in enumerate(layer_order)} + else: + layer_position = {layer: layer for layer in set(layer_map.values())} + return layer_map, layer_position + + +def _label_box( + node_label_box: bool | Mapping[str, Any] | None, +) -> dict[str, Any] | None: + """Return a bbox dict for node labels, if requested.""" + if not node_label_box: + return None + if node_label_box is True: + # Default rounded box styling. 
def _draw_flows(
    ax,
    *,
    flows: Sequence[Mapping[str, Any]],
    node_order: Sequence[Any],
    layout: Mapping[str, Any],
    flow_color_map: Mapping[Any, Any],
    flow_kw: Mapping[str, Any],
    label_kw: Mapping[str, Any],
    flow_label_kw: Mapping[str, Any],
    flow_labels: bool,
    value_format: str | Callable[[float], str] | None,
    flow_label_pos: float,
    flow_alpha: float,
    flow_curvature: float,
) -> tuple[list[mpatches.PathPatch], dict[Any, Any]]:
    """
    Draw flow ribbons and optional labels.

    Parameters
    ----------
    ax
        Axes that receives the patches (``add_patch``) and labels (``text``).
    flows : sequence of mapping
        Flow records; ribbons are stacked on each node in this order.
    node_order : sequence
        Node ids used to initialize the per-node stacking offsets.
    layout : mapping
        Provides ``layout["scale"]`` and, per node, ``x``, ``y`` and
        ``width`` under ``layout["nodes"]``.
    flow_color_map : mapping
        Fallback colors keyed by group or source, used when a flow's own
        ``color`` is falsy.
    flow_kw : mapping
        Patch properties merged over ``edgecolor="none", linewidth=0.0``;
        a ``facecolor`` entry replaces the tinted flow color.
    label_kw, flow_label_kw : mapping
        Text properties for flow labels; `flow_label_kw` wins on conflicts.
    flow_labels : bool
        Whether to draw labels at all.
    value_format : str, callable or None
        Forwarded to ``_flow_label_text``.
    flow_label_pos : float
        Base fraction forwarded to ``_flow_label_frac``.
    flow_alpha, flow_curvature : float
        Ribbon transparency and curvature.

    Returns
    -------
    list of PathPatch
        One patch per flow, in drawing order.
    dict
        Label artists keyed by ``(source, target, index)``.
    """
    flow_patches = []
    labels_out = {}
    label_items = []
    # Track running offsets per node so flows stack without overlap.
    out_offsets = {node: 0.0 for node in node_order}
    in_offsets = {node: 0.0 for node in node_order}
    link_counts = {}
    link_seen = {}
    if flow_labels:
        # Count links so multiple labels on the same link can be spaced.
        for flow in flows:
            key = (flow["source"], flow["target"])
            link_counts[key] = link_counts.get(key, 0) + 1
    for flow in flows:
        source = flow["source"]
        target = flow["target"]
        thickness = flow["value"] * layout["scale"]
        src = layout["nodes"][source]
        tgt = layout["nodes"][target]
        # The ribbon runs from the source's right edge to the target's left
        # edge; y0/y1 are the ribbon centerline at each end.
        x0 = src["x"] + src["width"]
        x1 = tgt["x"]
        y0 = src["y"] + out_offsets[source] + thickness / 2
        y1 = tgt["y"] + in_offsets[target] + thickness / 2
        # Reserve this ribbon's slot so the next one stacks directly above.
        out_offsets[source] += thickness
        in_offsets[target] += thickness
        # Resolve color and build the ribbon patch.
        color = flow["color"] or flow_color_map.get(flow["group"] or source, "0.6")
        facecolor = _tint(color, 0.35)
        path = _ribbon_path(x0, y0, x1, y1, thickness, flow_curvature)
        base_flow_kw = {"edgecolor": "none", "linewidth": 0.0}
        base_flow_kw.update(flow_kw)
        # A user-supplied facecolor replaces the tinted color for every flow.
        flow_facecolor = base_flow_kw.pop("facecolor", facecolor)
        # NOTE(review): an "alpha" key in flow_kw would be passed twice here
        # (explicit alpha= plus **base_flow_kw) -- confirm whether callers
        # are expected to use the flow_alpha setting instead.
        patch = mpatches.PathPatch(
            path,
            facecolor=flow_facecolor,
            alpha=flow_alpha,
            **base_flow_kw,
        )
        ax.add_patch(patch)
        flow_patches.append(patch)

        if flow_labels:
            # Place label along the ribbon length.
            label_text = _flow_label_text(flow, value_format)
            if label_text:
                key = (source, target)
                count = link_counts.get(key, 1)
                # idx counts labels already placed on this (source, target).
                idx = link_seen.get(key, 0)
                link_seen[key] = idx + 1
                frac = _flow_label_frac(idx, count, flow_label_pos)
                label_x, label_y = _flow_label_point(
                    x0, y0, x1, y1, thickness, flow_curvature, frac
                )
                text = ax.text(
                    label_x,
                    label_y,
                    str(label_text),
                    ha="center",
                    va="center",
                    **{**label_kw, **flow_label_kw},
                )
                labels_out[(source, target, idx)] = text
                # Keep the geometry so the label can be repositioned below.
                label_items.append(
                    {
                        "text": text,
                        "source": source,
                        "target": target,
                        "x0": x0,
                        "x1": x1,
                        "y0": y0,
                        "y1": y1,
                        "thickness": thickness,
                        "curvature": flow_curvature,
                        "frac": frac,
                        "adjusted": False,
                    }
                )

    if flow_labels and len(label_items) > 1:

        def _set_label_position(item: dict[str, Any], frac: float) -> None:
            """Move a label to fraction ``frac`` of its ribbon and record it."""
            label_x, label_y = _flow_label_point(
                item["x0"],
                item["y0"],
                item["x1"],
                item["y1"],
                item["thickness"],
                item["curvature"],
                frac,
            )
            item["text"].set_position((label_x, label_y))
            item["frac"] = frac

        # Compare every pair of labelled ribbons once.
        for i in range(len(label_items)):
            for j in range(i + 1, len(label_items)):
                a = label_items[i]
                b = label_items[j]
                # A negative product means the two centerlines swap vertical
                # order between their start and end, i.e. the ribbons cross.
                if (a["y0"] - b["y0"]) * (a["y1"] - b["y1"]) < 0:
                    if not a["adjusted"] and not b["adjusted"]:
                        # Neither label has moved yet: push them apart.
                        _set_label_position(a, 0.25)
                        _set_label_position(b, 0.75)
                        a["adjusted"] = True
                        b["adjusted"] = True
                    elif a["adjusted"] ^ b["adjusted"]:
                        # Exactly one label is already fixed; send the other
                        # to the opposite slot (0.25 if the fixed one is at
                        # neither 0.25 nor 0.75).
                        primary = a if a["adjusted"] else b
                        secondary = b if a["adjusted"] else a
                        if abs(primary["frac"] - 0.25) < 1.0e-6:
                            target = 0.75
                        elif abs(primary["frac"] - 0.75) < 1.0e-6:
                            target = 0.25
                        else:
                            target = 0.25
                        _set_label_position(secondary, target)
                        secondary["adjusted"] = True
                    # If both labels were already adjusted, leave them alone.
    return flow_patches, labels_out
def sankey_diagram(
    ax,
    *,
    nodes=None,
    flows=None,
    layers=None,
    flow_cycle=None,
    group_cycle=None,
    node_order=None,
    layer_order=None,
    style=None,
    flow_other=None,
    other_label="Other",
    value_format=None,
    node_pad=0.02,
    node_width=0.03,
    node_kw=None,
    flow_kw=None,
    label_kw=None,
    node_label_kw=None,
    flow_label_kw=None,
    node_label_box=None,
    node_labels=True,
    flow_labels=False,
    flow_sort=True,
    flow_label_pos=0.5,
    node_label_outside="auto",
    node_label_offset=0.01,
    align="center",
    margin=0.05,
    flow_alpha=0.75,
    flow_curvature=0.5,
    node_facecolor="0.75",
) -> SankeyDiagram:
    """
    Render a layered Sankey diagram on ``ax`` and return its artists.

    The inputs are normalized and styled, the layout is computed, flows are
    drawn first and nodes second, and the axes are then fixed to the unit
    square with the axis turned off. The returned `SankeyDiagram` holds the
    node patches, flow patches, all label artists and the layout dict.
    """
    # Treat missing (or empty) style dictionaries as "no overrides".
    node_kw = node_kw if node_kw else {}
    flow_kw = flow_kw if flow_kw else {}
    label_kw = label_kw if label_kw else {}
    node_label_kw = node_label_kw if node_label_kw else {}
    flow_label_kw = flow_label_kw if flow_label_kw else {}

    # Normalize flows and nodes, resolve the style preset, color the flows.
    flows, node_map, node_order, resolved, flow_color_map = _prepare_inputs(
        nodes=nodes,
        flows=flows,
        flow_other=flow_other,
        other_label=other_label,
        node_order=node_order,
        style=style,
        flow_cycle=flow_cycle,
        node_facecolor=node_facecolor,
        flow_alpha=flow_alpha,
        flow_curvature=flow_curvature,
        node_label_box=node_label_box,
        node_label_kw=node_label_kw,
        group_cycle=group_cycle,
    )
    # From here on the style-resolved values replace the raw arguments.
    node_facecolor, flow_alpha, flow_curvature, node_label_box, node_label_kw = (
        resolved[key]
        for key in (
            "node_facecolor",
            "flow_alpha",
            "flow_curvature",
            "node_label_box",
            "node_label_kw",
        )
    )

    # Reject an inconsistent layer_order before any geometry is computed.
    _validate_layer_order(layer_order, flows, node_order, layers)

    # Only the layout dict is needed; the per-node flow lists are discarded.
    layout = _compute_layout(
        node_order,
        flows,
        node_pad=node_pad,
        node_width=node_width,
        align=align,
        layers=layers,
        margin=margin,
        layer_order=layer_order,
    )[0]
    layout["groups"] = flow_color_map

    # Layer lookups used when placing node labels outside the nodes.
    layer_map, layer_position = _layer_positions(layout, layer_order)

    if flow_sort:
        # Reorder flows to reduce crossings.
        flows = _sort_flows(flows, node_order, layout)

    # Ribbons are added to the axes before the node blocks.
    flow_patches, flow_label_artists = _draw_flows(
        ax,
        flows=flows,
        node_order=node_order,
        layout=layout,
        flow_color_map=flow_color_map,
        flow_kw=flow_kw,
        label_kw=label_kw,
        flow_label_kw=flow_label_kw,
        flow_labels=flow_labels,
        value_format=value_format,
        flow_label_pos=flow_label_pos,
        flow_alpha=flow_alpha,
        flow_curvature=flow_curvature,
    )
    node_patches, node_label_artists = _draw_nodes(
        ax,
        node_order=node_order,
        node_map=node_map,
        layout=layout,
        layer_map=layer_map,
        layer_position=layer_position,
        node_facecolor=node_facecolor,
        node_kw=node_kw,
        label_kw=label_kw,
        node_label_kw=node_label_kw,
        node_label_box=node_label_box,
        node_labels=node_labels,
        node_label_outside=node_label_outside,
        node_label_offset=node_label_offset,
    )
    # One lookup for every label; node labels win on a key collision.
    all_labels = dict(flow_label_artists)
    all_labels.update(node_label_artists)

    # Lock axes to the unit square.
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_axis_off()

    return SankeyDiagram(
        nodes=node_patches,
        flows=flow_patches,
        labels=all_labels,
        layout=layout,
    )
_validate_color, + "Default node facecolor for layered sankey diagrams.", + ), # Stylesheet "style": ( None, diff --git a/ultraplot/tests/test_config.py b/ultraplot/tests/test_config.py index 11a308b56..e097e621d 100644 --- a/ultraplot/tests/test_config.py +++ b/ultraplot/tests/test_config.py @@ -1,6 +1,9 @@ -import ultraplot as uplt, pytest import importlib +import pytest + +import ultraplot as uplt + def test_wrong_keyword_reset(): """ @@ -34,9 +37,22 @@ def test_cycle_in_rc_file(tmp_path): assert uplt.rc["cycle"] == "colorblind" +def test_sankey_rc_defaults(): + """ + Sanity check the new sankey defaults in rc. + """ + assert uplt.rc["sankey.nodepad"] == 0.02 + assert uplt.rc["sankey.nodewidth"] == 0.03 + assert uplt.rc["sankey.margin"] == 0.05 + assert uplt.rc["sankey.flow.alpha"] == 0.75 + assert uplt.rc["sankey.flow.curvature"] == 0.5 + assert uplt.rc["sankey.node.facecolor"] == "0.75" + + import io -from unittest.mock import patch, MagicMock from importlib.metadata import PackageNotFoundError +from unittest.mock import MagicMock, patch + from ultraplot.utils import check_for_update diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 1bcb69684..b29fe0f61 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -722,6 +722,361 @@ def test_curved_quiver_color_and_cmap(rng, cmap): return fig +@pytest.mark.mpl_image_compare +def test_sankey_basic(): + """ + Basic sanity check for Sankey diagrams. + """ + fig, ax = uplt.subplots() + diagram = ax.sankey( + flows=[1.0, -0.6, -0.4], + labels=["in", "out_a", "out_b"], + orientations=[0, 1, -1], + trunklength=1.1, + ) + assert getattr(diagram, "patch", None) is not None + assert getattr(diagram, "flows", None) is not None + return fig + + +@pytest.mark.mpl_image_compare +def test_sankey_layered_nodes_flows(): + """ + Check that layered sankey accepts nodes and flows. 
+ """ + fig, ax = uplt.subplots() + nodes = ["Budget", "Ops", "R&D", "Marketing"] + flows = [ + ("Budget", "Ops", 5), + ("Budget", "R&D", 3), + ("Budget", "Marketing", 2), + ] + diagram = ax.sankey(nodes=nodes, flows=flows) + assert len(diagram.nodes) == len(nodes) + assert len(diagram.flows) == len(flows) + return fig + + +@pytest.mark.mpl_image_compare +def test_sankey_layered_labels_and_style(): + """ + Check that style presets and label boxes are accepted. + """ + fig, ax = uplt.subplots() + nodes = ["Budget", "Ops", "R&D", "Marketing"] + flows = [ + ("Budget", "Ops", 5), + ("Budget", "R&D", 3), + ("Budget", "Marketing", 2), + ] + diagram = ax.sankey( + nodes=nodes, + flows=flows, + style="budget", + flow_labels=True, + value_format="{:.1f}", + node_label_box=True, + ) + flow_label_keys = [key for key in diagram.labels if isinstance(key, tuple)] + assert flow_label_keys + return fig + + +def test_sankey_invalid_flows(): + """Validate error handling for malformed flow inputs.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + with pytest.raises(ValueError): + sankey_mod._normalize_flows(None) + with pytest.raises(ValueError): + sankey_mod._normalize_flows([("A", "B", -1)]) + with pytest.raises(ValueError): + sankey_mod._normalize_flows([("A", "B", 0)]) + + +def test_sankey_cycle_layers_error(): + """Cycles in the graph should raise a clear error.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + flows = [ + {"source": "A", "target": "B", "value": 1.0}, + {"source": "B", "target": "A", "value": 1.0}, + ] + with pytest.raises(ValueError): + sankey_mod._assign_layers(flows, ["A", "B"], None) + + +def test_sankey_flow_label_frac_alternates(): + """Label fractions should alternate around the midpoint.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + base = 0.5 + assert sankey_mod._flow_label_frac(0, 2, base) == 0.25 + assert sankey_mod._flow_label_frac(1, 2, base) == 0.75 + frac0 = sankey_mod._flow_label_frac(0, 
3, base) + frac1 = sankey_mod._flow_label_frac(1, 3, base) + frac2 = sankey_mod._flow_label_frac(2, 3, base) + assert 0.05 <= frac0 <= 0.95 + assert 0.05 <= frac1 <= 0.95 + assert 0.05 <= frac2 <= 0.95 + assert frac0 < base < frac1 + + +def test_sankey_node_labels_outside_auto(): + """Auto outside labels should flip to the left/right on edge layers.""" + fig, ax = uplt.subplots() + diagram = ax.sankey( + nodes=["A", "B", "C"], + flows=[("A", "B", 2.0), ("B", "C", 2.0)], + node_labels=True, + flow_labels=False, + ) + label_a = diagram.labels["A"] + label_c = diagram.labels["C"] + node_a = diagram.nodes["A"] + node_c = diagram.nodes["C"] + ax_a, _ = label_a.get_position() + ax_c, _ = label_c.get_position() + assert ax_a < node_a.get_x() + assert ax_c > node_c.get_x() + node_c.get_width() + uplt.close(fig) + + +def test_sankey_flow_other_creates_other_node(): + """Small flows should be aggregated into an 'Other' node when requested.""" + fig, ax = uplt.subplots() + diagram = ax.sankey( + flows=[("A", "X", 0.2), ("A", "Y", 2.0)], + flow_other=0.5, + other_label="Other", + node_labels=True, + ) + assert "Other" in diagram.nodes + assert "Other" in diagram.labels + uplt.close(fig) + + +def test_sankey_unknown_style_error(): + """Unknown style presets should raise.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + with pytest.raises(ValueError): + sankey_mod._apply_style( + "nope", + flow_cycle=["C0"], + node_facecolor="0.7", + flow_alpha=0.8, + flow_curvature=0.5, + node_label_box=False, + node_label_kw={}, + ) + + +def test_sankey_links_parameter_uses_layered(): + """Links should force layered sankey even with numeric flows input.""" + fig, ax = uplt.subplots() + diagram = ax.sankey( + flows=[1.0, -1.0], + links=[("A", "B", 1.0)], + node_labels=False, + flow_labels=False, + ) + assert "A" in diagram.nodes + assert "B" in diagram.nodes + assert diagram.layout["scale"] > 0 + uplt.close(fig) + + +def test_sankey_tuple_flows_use_layered(): + """Tuple flows 
without nodes should trigger layered sankey.""" + fig, ax = uplt.subplots() + diagram = ax.sankey(flows=[("A", "B", 1.0)]) + assert "A" in diagram.nodes + assert "B" in diagram.nodes + uplt.close(fig) + + +def test_sankey_dict_flows_use_layered(): + """Dict flows should trigger layered sankey.""" + fig, ax = uplt.subplots() + diagram = ax.sankey(flows=[{"source": "A", "target": "B", "value": 1.0}]) + assert "A" in diagram.nodes + assert "B" in diagram.nodes + assert "nodes" in diagram.layout + uplt.close(fig) + + +def test_sankey_mixed_flow_formats_layered(): + """Mixed dict/tuple flows should still render in layered mode.""" + fig, ax = uplt.subplots() + flows = [ + {"source": "A", "target": "B", "value": 1.0}, + ("B", "C", 2.0), + ] + diagram = ax.sankey(flows=flows) + assert set(diagram.nodes.keys()) == {"A", "B", "C"} + assert len(diagram.flows) == 2 + uplt.close(fig) + + +def test_sankey_numpy_flows_use_matplotlib(): + """1D numeric flows should use Matplotlib Sankey.""" + import numpy as np + + fig, ax = uplt.subplots() + diagram = ax.sankey(flows=np.array([1.0, -1.0])) + assert hasattr(diagram, "patch") + assert not hasattr(diagram, "layout") + uplt.close(fig) + + +def test_sankey_matplotlib_kwargs_passthrough(): + """Matplotlib sankey should pass patch kwargs through.""" + from matplotlib.colors import to_rgba + + fig, ax = uplt.subplots() + diagram = ax.sankey( + flows=[1.0, -1.0], + orientations=[0, 0], + facecolor="red", + edgecolor="blue", + linewidth=1.5, + ) + assert np.allclose(diagram.patch.get_facecolor(), to_rgba("red")) + assert np.allclose(diagram.patch.get_edgecolor(), to_rgba("blue")) + assert diagram.patch.get_linewidth() == 1.5 + uplt.close(fig) + + +def test_sankey_matplotlib_connect_none(): + """Matplotlib sankey should allow connect=None.""" + fig, ax = uplt.subplots() + diagram = ax.sankey( + flows=[1.0, -1.0], + orientations=[0, 0], + connect=None, + ) + assert hasattr(diagram, "patch") + uplt.close(fig) + + +def 
test_sankey_normalize_nodes_dict_order_and_labels(): + """Node dict inputs should preserve order and resolve labels.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + nodes = {"A": {"label": "Alpha"}, "B": {"label": "Beta"}} + flows = [{"source": "A", "target": "B", "value": 1.0}] + node_map, order = sankey_mod._normalize_nodes(nodes, flows) + assert order == ["A", "B"] + assert node_map["A"]["label"] == "Alpha" + assert node_map["B"]["label"] == "Beta" + + +def test_sankey_layer_order_missing_raises(): + """layer_order must include every layer.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + flows = [ + {"source": "A", "target": "B", "value": 1.0}, + {"source": "B", "target": "C", "value": 1.0}, + ] + with pytest.raises(ValueError): + sankey_mod._validate_layer_order([0], flows, ["A", "B", "C"], None) + + +def test_sankey_label_box_dict_copy(): + """Label box dicts should be copied so callers can reuse input.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + box = {"boxstyle": "round", "facecolor": "white"} + resolved = sankey_mod._label_box(box) + assert resolved == box + assert resolved is not box + + +def test_sankey_label_box_default(): + """node_label_box=True should create a default box style.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + resolved = sankey_mod._label_box(True) + assert resolved["boxstyle"].startswith("round") + assert resolved["facecolor"] == "white" + + +def test_sankey_assign_flow_colors_group_cycle(): + """Group cycle should be used for flow colors.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + flows = [ + {"source": "A", "target": "B", "value": 1.0, "group": "g1", "color": None}, + {"source": "A", "target": "C", "value": 1.0, "group": "g2", "color": None}, + ] + color_map = sankey_mod._assign_flow_colors( + flows, flow_cycle=None, group_cycle=["C0", "C1"] + ) + assert color_map["g1"] == "C0" + assert color_map["g2"] == "C1" + assert 
flows[0]["color"] == "C0" + assert flows[1]["color"] == "C1" + + +def test_sankey_assign_flow_colors_preserves_explicit(): + """Explicit flow colors should be preserved.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + flows = [ + {"source": "A", "target": "B", "value": 1.0, "group": "g1", "color": "red"} + ] + color_map = sankey_mod._assign_flow_colors(flows, flow_cycle=None, group_cycle=None) + assert flows[0]["color"] == "red" + assert color_map == {} + + +def test_sankey_node_dict_missing_id_raises(): + """Node dicts must include id or name.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + flows = [{"source": "A", "target": "B", "value": 1.0}] + with pytest.raises(ValueError): + sankey_mod._normalize_nodes([{"label": "missing"}], flows) + + +def test_sankey_node_order_missing_nodes_raises(): + """node_order must include all flow endpoints.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + flows = [{"source": "A", "target": "B", "value": 1.0}] + with pytest.raises(ValueError): + sankey_mod._ensure_nodes(["A"], flows, node_order=["A"]) + + +def test_sankey_flow_other_multiple_sources(): + """flow_other should aggregate per source.""" + from ultraplot.axes.plot_types import sankey as sankey_mod + + flows = [ + {"source": "A", "target": "X", "value": 0.2, "label": None, "color": None}, + {"source": "A", "target": "Y", "value": 0.1, "label": None, "color": None}, + {"source": "B", "target": "Z", "value": 0.3, "label": None, "color": None}, + {"source": "B", "target": "W", "value": 2.0, "label": None, "color": None}, + ] + result = sankey_mod._apply_flow_other(flows, 0.5, "Other") + others = [flow for flow in result if flow["target"] == "Other"] + assert len(others) == 2 + sums = {flow["source"]: flow["value"] for flow in others} + assert np.isclose(sums["A"], 0.3) + assert np.isclose(sums["B"], 0.3) + + +def test_sankey_flow_label_text_callable(): + """Callable value_format should be used for flow labels.""" + 
from ultraplot.axes.plot_types import sankey as sankey_mod + + flow = {"value": 1.234, "label": None} + text = sankey_mod._flow_label_text(flow, lambda v: f"{v:.1f}") + assert text == "1.2" + + def test_histogram_norms(): """ Check that all histograms-like plotting functions