Source code for ridgeplot._color.interpolation

"""Color interpolation utilities."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeAlias

from typing_extensions import Literal, Protocol

from ridgeplot._color.utils import apply_alpha, round_color, to_rgb, unpack_rgb
from ridgeplot._types import CollectionL2, ColorScale
from ridgeplot._utils import get_xy_extrema, normalise_min_max

if TYPE_CHECKING:
    from collections.abc import Generator

    from ridgeplot._types import Densities, Numeric


# ==============================================================
# --- Interpolation utilities
# ==============================================================


[docs] def interpolate_color(colorscale: ColorScale, p: float) -> str: """Get a color from a colorscale at a given interpolation point ``p``. This function always returns a color in the RGB format, even if the input colorscale contains colors in other formats. """ if not (0 <= p <= 1): raise ValueError( f"The interpolation point 'p' should be a float value between 0 and 1, not {p}." ) scale = [s for s, _ in colorscale] colors = [to_rgb(c) for _, c in colorscale] if p in scale: return colors[scale.index(p)] ceil = min(filter(lambda s: s > p, scale)) floor = max(filter(lambda s: s < p, scale)) color_floor = unpack_rgb(colors[scale.index(floor)]) color_ceil = unpack_rgb(colors[scale.index(ceil)]) p_norm = normalise_min_max(p, min_=floor, max_=ceil) rgb = to_rgb( ( color_floor[0] + (p_norm * (color_ceil[0] - color_floor[0])), color_floor[1] + (p_norm * (color_ceil[1] - color_floor[1])), color_floor[2] + (p_norm * (color_ceil[2] - color_floor[2])), ) ) alpha_floor = color_floor[3] if len(color_floor) == 4 else 1 alpha_ceil = color_ceil[3] if len(color_ceil) == 4 else 1 alpha = alpha_floor + (p_norm * (alpha_ceil - alpha_floor)) if alpha < 1: rgb = apply_alpha(rgb, alpha) # To address floating point errors, we round all color channels to a # reasonable precision, which should result in the exact some result # being rendered by any browsers and most Plotly output formats. return round_color(rgb, 5)
[docs] def slice_colorscale( colorscale: ColorScale, p_lower: float, p_upper: float, ) -> ColorScale: """Slice a continuous colorscale between two intermediate points. Parameters ---------- colorscale The continuous colorscale to slice. p_lower The lower bound of the slicing interval. Must be >= 0 and < p_upper. p_upper The upper bound of the slicing interval. Must be <= 1 and > p_lower. Returns ------- ColorScale The sliced colorscale. Raises ------ ValueError If ``p_lower`` is >= ``p_upper``, or if either ``p_lower`` or ``p_upper`` are outside the range [0, 1]. """ if p_lower >= p_upper: raise ValueError("p_lower should be less than p_upper.") if p_lower < 0 or p_upper > 1: raise ValueError("p_lower should be >= 0 and p_upper should be <= 1.") if p_lower == 0 and p_upper == 1: return colorscale return ( (0.0, interpolate_color(colorscale, p=p_lower)), *[ (normalise_min_max(v, min_=p_lower, max_=p_upper), c) for v, c in colorscale if p_lower < v < p_upper ], (1.0, interpolate_color(colorscale, p=p_upper)), )
# ============================================================== # --- Solid color modes # ============================================================== ColorscaleInterpolants: TypeAlias = CollectionL2[float] """A :data:`ColorscaleInterpolants` contains the interpolants for a :data:`ColorScale`. Example ------- >>> interpolants: ColorscaleInterpolants = [ ... [0.2, 0.5, 1], ... [0.3, 0.7], ... ] """
[docs] @dataclass class InterpolationContext: """Context information needed by the interpolation functions.""" densities: Densities n_rows: int n_traces: int x_min: Numeric x_max: Numeric
[docs] @classmethod def from_densities(cls, densities: Densities) -> InterpolationContext: x_min, x_max, _, _ = get_xy_extrema(densities=densities) return cls( densities=densities, n_rows=len(densities), n_traces=sum(len(row) for row in densities), x_min=x_min, x_max=x_max, )
[docs] class InterpolationFunc(Protocol):
[docs] def __call__(self, ctx: InterpolationContext) -> ColorscaleInterpolants: ...
[docs] def _mul(a: tuple[Numeric, ...], b: tuple[Numeric, ...]) -> tuple[Numeric, ...]: """Multiply two tuples element-wise.""" return tuple(a_i * b_i for a_i, b_i in zip(a, b, strict=True))
[docs] def _interpolate_row_index(ctx: InterpolationContext) -> ColorscaleInterpolants: if ctx.n_rows == 1: return [[0.0] * ctx.n_traces] return [ [((ctx.n_rows - 1) - ith_row) / (ctx.n_rows - 1)] * len(row) for ith_row, row in enumerate(ctx.densities) ]
[docs] def _interpolate_trace_index(ctx: InterpolationContext) -> ColorscaleInterpolants: if ctx.n_traces == 1: return [[0.0]] ps = [] ith_trace = 0 for row in ctx.densities: ps_row = [] for _ in row: ps_row.append(((ctx.n_traces - 1) - ith_trace) / (ctx.n_traces - 1)) ith_trace += 1 ps.append(ps_row) return ps
[docs] def _interpolate_trace_index_row_wise(ctx: InterpolationContext) -> ColorscaleInterpolants: return [ [ ((len(row) - 1) - ith_row_trace) / (len(row) - 1) if len(row) > 1 else 0.0 for ith_row_trace in range(len(row)) ] for row in ctx.densities ]
[docs] def _interpolate_mean_minmax(ctx: InterpolationContext) -> ColorscaleInterpolants: ps = [] for row in ctx.densities: ps_row = [] for trace in row: x, y = zip(*trace) ps_row.append( normalise_min_max(sum(_mul(x, y)) / sum(y), min_=ctx.x_min, max_=ctx.x_max) ) ps.append(ps_row) return ps
[docs] def _interpolate_mean_means(ctx: InterpolationContext) -> ColorscaleInterpolants: means = [] for row in ctx.densities: means_row = [] for trace in row: x, y = zip(*trace) means_row.append(sum(_mul(x, y)) / sum(y)) means.append(means_row) min_mean = min(min(row) for row in means) max_mean = max(max(row) for row in means) return [ [normalise_min_max(mean, min_=min_mean, max_=max_mean) for mean in row] for row in means ]
SolidColormode: TypeAlias = Literal[ "row-index", "trace-index", "trace-index-row-wise", "mean-minmax", "mean-means", ] """See :paramref:`ridgeplot.ridgeplot.colormode` for more information.""" SOLID_COLORMODE_MAPS: dict[SolidColormode, InterpolationFunc] = { "row-index": _interpolate_row_index, "trace-index": _interpolate_trace_index, "trace-index-row-wise": _interpolate_trace_index_row_wise, "mean-minmax": _interpolate_mean_minmax, "mean-means": _interpolate_mean_means, }
[docs] def compute_solid_colors( colorscale: ColorScale, colormode: SolidColormode, opacity: float | None, interpolation_ctx: InterpolationContext, ) -> Generator[Generator[str]]: """Compute the solid colors for all traces in the plot.""" def get_fill_color(p: float) -> str: fill_color = interpolate_color(colorscale, p=p) if opacity is not None: # Sometimes the interpolation logic can drop the alpha channel fill_color = apply_alpha(fill_color, alpha=float(opacity)) return fill_color interpolate_func = SOLID_COLORMODE_MAPS[colormode] interpolants = interpolate_func(ctx=interpolation_ctx) return ((get_fill_color(p) for p in row) for row in interpolants)