# Source code for ridgeplot._ridgeplot

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, cast

from ridgeplot._figure_factory import create_ridgeplot
from ridgeplot._missing import MISSING
from ridgeplot._types import (
    Densities,
    Samples,
    is_shallow_densities,
    is_shallow_samples,
    nest_shallow_collection,
)
from ridgeplot._utils import normalise_densities

if TYPE_CHECKING:
    from collections.abc import Collection

    import plotly.graph_objects as go
    from typing_extensions import Literal

    from ridgeplot._color.interpolation import SolidColormode
    from ridgeplot._kde import (
        KDEBandwidth,
        KDEPoints,
    )
    from ridgeplot._missing import MissingType
    from ridgeplot._types import (
        Color,
        ColorScale,
        LabelsArray,
        NormalisationOption,
        SampleWeights,
        SampleWeightsArray,
        ShallowDensities,
        ShallowLabelsArray,
        ShallowSamples,
        ShallowSampleWeightsArray,
        ShallowTraceTypesArray,
        TraceType,
        TraceTypesArray,
    )


[docs] def _coerce_to_densities( samples: Samples | ShallowSamples | None, densities: Densities | ShallowDensities | None, # KDE parameters kernel: str, bandwidth: KDEBandwidth, kde_points: KDEPoints, # Histogram parameters nbins: int | None, # Common parameters for density estimation sample_weights: SampleWeightsArray | ShallowSampleWeightsArray | SampleWeights, ) -> Densities: # Importing statsmodels, scipy, and numpy can be slow, # so we're hiding the kde import here to only incur # this cost if the user actually needs this it... from ridgeplot._hist import bin_samples from ridgeplot._kde import estimate_densities # Input validation has_samples = samples is not None has_densities = densities is not None if has_samples and has_densities: raise ValueError("You may not specify both `samples` and `densities` arguments!") if not has_samples and not has_densities: raise ValueError("You must specify either `samples` or `densities`") # Exit early if densities are already provided if has_densities: if is_shallow_densities(densities): densities = nest_shallow_collection(densities) return densities # Transform samples into densities via KDE or histogram binning if is_shallow_samples(samples): samples = nest_shallow_collection(samples) samples = cast("Samples", samples) if nbins is not None: densities = bin_samples( samples=samples, nbins=nbins, sample_weights=sample_weights, ) else: densities = estimate_densities( samples=samples, points=kde_points, kernel=kernel, bandwidth=bandwidth, sample_weights=sample_weights, ) return densities
def ridgeplot(
    samples: Samples | ShallowSamples | None = None,
    densities: Densities | ShallowDensities | None = None,
    trace_type: TraceTypesArray | ShallowTraceTypesArray | TraceType | None = None,
    labels: LabelsArray | ShallowLabelsArray | None = None,
    row_labels: Collection[str] | None | Literal[False] = None,
    legendgroup: bool = False,  # TODO: document and rename to smth better!
    # KDE parameters
    kernel: str = "gau",
    bandwidth: KDEBandwidth = "normal_reference",
    kde_points: KDEPoints = 500,
    # Histogram parameters
    nbins: int | None = None,
    # Common parameters for density estimation
    sample_weights: SampleWeightsArray | ShallowSampleWeightsArray | SampleWeights = None,
    norm: NormalisationOption | None = None,
    # Coloring and styling parameters
    colorscale: ColorScale | Collection[Color] | str | None = None,
    colormode: Literal["fillgradient"] | SolidColormode = "fillgradient",
    color_discrete_map: dict[str, str] | None = None,
    opacity: float | None = None,
    line_color: Color | Literal["fill-color"] = "black",
    line_width: float | None = None,
    spacing: float = 0.5,
    xpad: float = 0.05,
    # Deprecated parameters
    coloralpha: float | None | MissingType = MISSING,
    linewidth: float | MissingType = MISSING,
    show_yticklabels: bool | MissingType = MISSING,
) -> go.Figure:
    r"""Return an interactive ridgeline (Plotly) |~go.Figure|.

    .. note::
        You must specify either :paramref:`.samples` or :paramref:`.densities`
        to this function, but not both. When specifying :paramref:`.samples`,
        the function will estimate the densities using either Kernel Density
        Estimation (KDE) or histogram binning. When specifying
        :paramref:`.densities`, the function will skip the density estimation
        step and use the provided densities directly. See the parameter
        descriptions below for more details.

    .. _bandwidths.py:
        https://www.statsmodels.org/stable/_modules/statsmodels/nonparametric/bandwidths.html

    .. _Plotly's built-in color-scales:
        https://plotly.com/python/builtin-colorscales/

    .. _ragged:
        https://en.wikipedia.org/wiki/Jagged_array

    Parameters
    ----------
    samples : Samples or ShallowSamples
        If ``samples`` data is specified, either Kernel Density Estimation
        (KDE) or histogram binning will be performed to estimate the
        underlying densities.

        See :paramref:`.kernel`, :paramref:`.bandwidth`, and
        :paramref:`.kde_points` for more details on the different KDE
        parameters. See :paramref:`.nbins` for more details on histogram
        binning. The :paramref:`.sample_weights` parameter can be used for
        both KDE and histogram binning.

        The ``samples`` argument should be an array of shape
        :math:`(R, T_r, S_t)`. Note that we support irregular (`ragged`_)
        arrays, where:

        - :math:`R` is the number of rows in the plot
        - :math:`T_r` is the number of traces per row, where each row
          :math:`r \in R` can have a different number of traces.
        - :math:`S_t` is the number of samples per trace, where each trace
          :math:`t \in T_r` can also have a different number of samples.

        The density estimation step will be performed over the sample values
        (:math:`S_t`) for all traces. The resulting array will be a (4D)
        :paramref:`.densities` array of shape :math:`(R, T_r, P_t, 2)`
        (see :paramref:`.densities` below for more details).

    densities : Densities or ShallowDensities
        If a ``densities`` array is specified, the density estimation step
        will be skipped and all associated arguments ignored. The
        ``densities`` array should have shape :math:`(R, T_r, P_t, 2)` (4D).
        Just like the :paramref:`.samples` argument, we also support
        irregular (`ragged`_) ``densities`` arrays, where:

        - :math:`R` is the number of rows in the plot
        - :math:`T_r` is the number of traces per row, where each row
          :math:`r \in R` can have a different number of traces.
        - :math:`P_t` is the number of points per trace, where each trace
          :math:`t \in T_r` can also have a different number of points.
        - :math:`2` is the number of coordinates per point (x and y)

        See :paramref:`.samples` above for more details.

    trace_type : TraceTypesArray or ShallowTraceTypesArray or TraceType or None
        The type of trace to display. Choices are ``'area'`` or ``'bar'``.
        If a single value is passed, it will be used for all traces. If a
        list of values is passed, it should have the same shape as the
        samples array. If not specified (default), the traces will be
        displayed as area plots (``trace_type='area'``) unless histogram
        binning is used, in which case the traces will be displayed as bar
        plots (``trace_type='bar'``).

        .. versionadded:: 0.3.0

    labels : LabelsArray or ShallowLabelsArray or None
        A collection of string labels for each trace. If not specified
        (default), the labels will be automatically generated as
        ``"Trace {i}"``, where ``i`` is the trace's index. If instead a
        collection of labels is specified, it should have the same shape as
        the samples array.

    row_labels : Collection[str] or None or False
        A collection of string labels for each row in the ridgeline plot. If
        specified, the length of this collection should match the number of
        rows in the plot (i.e., the :math:`R` dimension in the
        :paramref:`.samples` or :paramref:`.densities` parameter). If not
        specified (default), the row labels displayed on the y-axis will be
        automatically generated based on the :paramref:`.labels` argument.
        If set to ``False``, the row labels won't be displayed at all.

        .. versionadded:: 0.4.0
            Added support for custom row labels, and replaced the deprecated
            :paramref:`.show_yticklabels` parameter.

    kernel : str
        The Kernel to be used during Kernel Density Estimation. The default
        is a Gaussian Kernel (``"gau"``). Choices are:

        - ``"biw"`` for biweight
        - ``"cos"`` for cosine
        - ``"epa"`` for Epanechnikov
        - ``"gau"`` for Gaussian
        - ``"tri"`` for triangular
        - ``"triw"`` for triweight
        - ``"uni"`` for uniform

    bandwidth : KDEBandwidth
        The bandwidth to use during Kernel Density Estimation. The default is
        ``"normal_reference"``. Choices are:

        - ``"scott"`` - 1.059 * A * nobs ** (-1/5.), where A is
          ``min(std(x),IQR/1.34)``
        - ``"silverman"`` - .9 * A * nobs ** (-1/5.), where A is
          ``min(std(x),IQR/1.34)``
        - ``"normal_reference"`` - C * A * nobs ** (-1/5.), where C is
          calculated from the kernel. Equivalent (up to 2 dp) to the
          ``"scott"`` bandwidth for gaussian kernels. See `bandwidths.py`_.
        - If a float is given, its value is used as the bandwidth.
        - If a callable is given, its return value is used. The callable
          should take exactly two arguments, i.e., ``fn(x, kern)``, and
          return a float, where:

          - ``x``: the clipped input data
          - ``kern``: the kernel instance used

    kde_points : KDEPoints
        This parameter controls the points at which KDE is computed. If an
        ``int`` value is passed (default=500), the densities will be
        evaluated at ``kde_points`` evenly spaced points between the min and
        max of each set of samples. Optionally, you can also pass a custom 1D
        numerical array, which will be used for all traces.

    nbins : int or None
        The number of bins to use when applying histogram binning. If not
        specified (default), KDE will be used instead of histogram binning.

        .. versionadded:: 0.3.0

    sample_weights : SampleWeightsArray or ShallowSampleWeightsArray or SampleWeights or None
        An (optional) array of weights corresponding to each sample, used for
        both KDE and histogram binning. The weights should have the same
        shape as the samples array. If not specified (default), all samples
        will be weighted equally.

    norm : NormalisationOption or None
        The normalisation option to use when normalising the densities. If
        not specified (default), no normalisation will be applied and the
        densities will be used *as is*. The following normalisation options
        are available:

        - ``"probability"`` - normalise the densities by dividing each trace
          by its sum.
        - ``"percent"`` - same as ``"probability"``, but the normalised
          values are multiplied by 100.

        .. versionadded:: 0.2.0

    colorscale : ColorScale or Collection[Color] or str or None
        A continuous color scale used to color the different traces in the
        ridgeline plot. It can be represented by a string name (e.g.,
        ``"viridis"``), a :data:`~ridgeplot._types.ColorScale` object, or a
        list of valid :data:`~ridgeplot._types.Color` objects. If a string
        name is provided, it must be one of the built-in color scales (see
        :func:`~ridgeplot.list_all_colorscale_names()` and
        `Plotly's built-in color-scales`_). If a list of colors is provided,
        it must be a list of valid CSS colors (e.g.,
        ``["rgb(255, 0, 0)", "blue", "hsl(120, 100%, 50%)"]``). The list
        will ultimately be converted into a
        :data:`~ridgeplot._types.ColorScale` object, assuming the colors
        provided are evenly spaced. If not specified (default), the color
        scale will be inferred from the current Plotly template.

    colormode : "fillgradient" or SolidColormode
        This parameter controls the logic used for the coloring of each
        ridgeline trace.

        The ``"fillgradient"`` mode (default) will fill each trace with a
        gradient using the specified :paramref:`.colorscale`. The gradient
        normalisation is done using the minimum and maximum x-values over all
        densities.

        All other modes provide different methods for calculating
        interpolation values from the specified :paramref:`.colorscale`
        (i.e., a float value between 0 and 1) for each trace. The
        interpolated color will be used to color each trace with a solid
        color. The available modes are:

        - ``"row-index"`` - uses the row's index. This is useful when the
          desired effect is to have the same color for all traces on the
          same row. e.g., if a ridgeplot has 3 rows of traces, then the color
          scale interpolation values will be ``[[0, ...], [0.5, ...], [1, ...]]``,
          respectively.
        - ``"trace-index"`` - uses the trace's index. e.g., if a ridgeplot
          has a total of 3 traces (across all rows), then the color scale
          interpolation values will be 0, 0.5, and 1, respectively, and
          regardless of each trace's row.
        - ``"trace-index-row-wise"`` - uses the row-wise trace index. This is
          similar to the ``"trace-index"`` mode, but the trace index is reset
          for each row. e.g., if a ridgeplot has a row with only one trace
          and another with two traces, then the color scale interpolation
          values will be ``[[0], [0, 1]]``, respectively.
        - ``"mean-minmax"`` - uses the min-max normalised (weighted) mean of
          each density to calculate the interpolation values. The
          normalisation min and max values are the *absolute* minimum and
          maximum x-values over all densities. This mode is useful when the
          desired effect is to have the color of each trace reflect the mean
          of the distribution, while also taking into account the
          distributions' spread.
        - ``"mean-means"`` - similar to the ``"mean-minmax"`` mode, but where
          the normalisation min and max values are the minimum and maximum
          *mean* x-values over all densities. This mode is useful when the
          desired effect is to have the color of each trace reflect the mean
          of the distribution, but without taking into account the entire
          variability of the distributions.

        .. versionchanged:: 0.2.0
            The default value changed from ``"mean-minmax"`` to
            ``"fillgradient"``.

    color_discrete_map : dict[str, str] or None
        A mapping from trace labels to specific colors. This parameter is
        useful when you want to have full manual control over the colors
        assigned to each trace. If specified, the assigned colors are
        determined by looking up the trace's label as a key in this
        dictionary. All labels must be present as keys in the dictionary.
        Note that this parameter overrides any value specified for
        :paramref:`.colorscale` and :paramref:`.colormode`. In this case, the
        color assigned to each trace will be a solid color, as specified in
        this dictionary. If not specified (default), the colors will be
        determined using the :paramref:`.colorscale` and
        :paramref:`.colormode` parameters.

        .. versionadded:: 0.5.0

    opacity : float or None
        If None (default), this parameter will be ignored and the
        transparency values of the specified color-scale will remain
        untouched. Otherwise, if a float value is passed, it will be used to
        overwrite the opacity/transparency of the color-scale's colors.

        .. versionadded:: 0.2.0
            Replaces the deprecated :paramref:`.coloralpha` parameter.

    line_color : Color or "fill-color"
        The color of the traces' lines. Any valid CSS color is allowed
        (default: ``"black"``). If the value is set to ``"fill-color"``, the
        line color will be the same as the fill color of the traces (see
        :paramref:`.colormode`). If ``colormode='fillgradient'``, the line
        color will be the mean color of the fill gradient (i.e., equivalent
        to the fill color when ``colormode='mean-minmax'``).

        .. versionadded:: 0.2.0

    line_width : float or None
        The traces' line width (in px). If not specified (default), area
        plots will have a line width of 1.5 px, and bar plots will have a
        line width of 0.5 px.

        .. versionadded:: 0.2.0
            Replaces the deprecated :paramref:`.linewidth` parameter.

        .. versionchanged:: 0.2.0
            The default value changed from 1 to 1.5

    spacing : float
        The vertical spacing between density traces, which is defined in
        units of the highest distribution (i.e., the maximum y-value).

    xpad : float
        Specifies the extra padding to use on the x-axis. It is defined in
        units of the range between the minimum and maximum x-values from all
        distributions.

    coloralpha : float or None
        .. deprecated:: 0.2.0
            Use :paramref:`.opacity` instead.

    linewidth : float
        .. deprecated:: 0.2.0
            Use :paramref:`.line_width` instead.

    show_yticklabels : bool
        .. deprecated:: 0.4.0
            Use :paramref:`.row_labels` instead.

    Returns
    -------
    :class:`plotly.graph_objects.Figure`
        A Plotly :class:`~plotly.graph_objects.Figure` with a ridgeline plot.
        You can further customize this figure to your liking (e.g. using the
        :meth:`~plotly.graph_objects.Figure.update_layout()` method).

    Raises
    ------
    :exc:`ValueError`
        If both :paramref:`.samples` and :paramref:`.densities` are
        specified, or if neither of them is specified. i.e., you may only
        specify one of them. Also raised if a deprecated parameter is given
        together with its replacement (:paramref:`.coloralpha` with
        :paramref:`.opacity`, :paramref:`.linewidth` with
        :paramref:`.line_width`, or :paramref:`.show_yticklabels` with
        :paramref:`.row_labels`).

    Warns
    -----
    :exc:`DeprecationWarning`
        If any of the deprecated parameters (:paramref:`.coloralpha`,
        :paramref:`.linewidth`, :paramref:`.show_yticklabels`) is used, or if
        ``colorscale='default'`` is passed.
    """
    # Resolve the deprecated parameters first, so that conflicting arguments
    # fail fast instead of after a (potentially expensive) density estimation.
    if coloralpha is not MISSING:
        if opacity is not None:
            raise ValueError(
                "You may not specify both the 'coloralpha' and 'opacity' arguments! "
                "HINT: Use the new 'opacity' argument instead of the deprecated 'coloralpha'."
            )
        warnings.warn(
            "The 'coloralpha' argument has been deprecated in favor of 'opacity'. "
            "Support for the deprecated argument will be removed in a future version.",
            DeprecationWarning,
            stacklevel=2,
        )
        opacity = coloralpha

    if linewidth is not MISSING:
        if line_width is not None:
            raise ValueError(
                "You may not specify both the 'linewidth' and 'line_width' arguments! "
                "HINT: Use the new 'line_width' argument instead of the deprecated 'linewidth'."
            )
        warnings.warn(
            "The 'linewidth' argument has been deprecated in favor of 'line_width'. "
            "Support for the deprecated argument will be removed in a future version.",
            DeprecationWarning,
            stacklevel=2,
        )
        line_width = linewidth

    if show_yticklabels is not MISSING:
        if row_labels is not None:
            raise ValueError(
                "You may not specify both the 'show_yticklabels' and 'row_labels' arguments! "
                "HINT: Use the new 'row_labels' argument instead of the deprecated "
                "'show_yticklabels'."
            )
        warnings.warn(
            "The 'show_yticklabels' argument has been deprecated in favor of 'row_labels'. "
            "Support for the deprecated argument will be removed in a future version.",
            DeprecationWarning,
            stacklevel=2,
        )
        # show_yticklabels=False hides the row labels; True keeps the default ones
        row_labels = False if not show_yticklabels else None

    if colorscale == "default":
        warnings.warn(
            "colorscale='default' is deprecated and support for it will be removed in a future "
            "version. Please use colorscale=px.colors.DEFAULT_PLOTLY_COLORS for the same effect. "
            "To list all supported colorscale names, please refer to Plotly's "
            "px.colors.named_colorscales(), or visit: "
            "https://plotly.com/python/builtin-colorscales/#named-builtin-continuous-color-scales",
            DeprecationWarning,
            stacklevel=2,
        )

    # The deprecated values have been folded into their replacements above
    del coloralpha, linewidth, show_yticklabels

    # Histogram binning (nbins given) defaults to bars; KDE defaults to areas
    if trace_type is None:
        trace_type = "area" if nbins is None else "bar"

    densities = _coerce_to_densities(
        samples=samples,
        densities=densities,
        kernel=kernel,
        bandwidth=bandwidth,
        kde_points=kde_points,
        nbins=nbins,
        sample_weights=sample_weights,
    )
    # None of the estimation inputs are needed past this point
    del samples, kernel, bandwidth, kde_points, nbins, sample_weights

    if norm:
        densities = normalise_densities(densities, norm=norm)

    return create_ridgeplot(
        densities=densities,
        trace_labels=labels,
        trace_types=trace_type,
        row_labels=row_labels,
        legendgroup=legendgroup,
        colorscale=colorscale,
        colormode=colormode,
        color_discrete_map=color_discrete_map,
        opacity=opacity,
        line_color=line_color,
        line_width=line_width,
        spacing=spacing,
        xpad=xpad,
    )