Source code for ridgeplot._kde

"""Kernel density estimation (KDE) utilities."""

from __future__ import annotations

from collections.abc import Callable, Collection
from functools import partial
from typing import TYPE_CHECKING, TypeAlias, cast

import numpy as np
import numpy.typing as npt
import statsmodels.api as sm
from statsmodels.sandbox.nonparametric.kernels import CustomKernel as StatsmodelsKernel
from typing_extensions import Any, TypeIs

from ridgeplot._types import (
    CollectionL1,
    DensityTrace,
    Numeric,
    SampleWeights,
    SampleWeightsArray,
    ShallowSampleWeightsArray,
    is_flat_numeric_collection,
    nest_shallow_collection,
)
from ridgeplot._utils import normalise_row_attrs

if TYPE_CHECKING:
    from ridgeplot._types import Densities, Samples, SamplesTrace


# Where the KDE is evaluated. In `estimate_density_trace`, an `int` is
# treated as the number of equally spaced points spanning
# [min(samples), max(samples)]; anything else is converted with
# `np.asarray` and used directly as the evaluation grid (it must be
# at most one-dimensional).
KDEPoints: TypeAlias = int | CollectionL1[Numeric]
"""The :paramref:`ridgeplot.ridgeplot.kde_points` parameter."""

# Bandwidth specification. `estimate_density_trace` forwards this value
# unchanged as the `bw` argument of statsmodels' `KDEUnivariate.fit()`.
# The callable form takes the samples and a statsmodels kernel object and
# returns the bandwidth as a float.
KDEBandwidth: TypeAlias = str | float | Callable[[CollectionL1[Numeric], StatsmodelsKernel], float]
"""The :paramref:`ridgeplot.ridgeplot.bandwidth` parameter."""


[docs] def _is_sample_weights(obj: Any) -> TypeIs[SampleWeights]: """Type guard for :data:`SampleWeights`. Examples -------- >>> _is_sample_weights("definitely not") False >>> _is_sample_weights([1, 2, 3.14]) True >>> _is_sample_weights([1, 2, "3"]) False >>> _is_sample_weights(None) True """ return obj is None or is_flat_numeric_collection(obj)
[docs] def _is_shallow_sample_weights(obj: Any) -> TypeIs[ShallowSampleWeightsArray]: """Type guard for :data:`ShallowSampleWeightsArray`. Examples -------- >>> _is_shallow_sample_weights("definitely not") False >>> _is_shallow_sample_weights([1, 2, 3]) False >>> _is_shallow_sample_weights([[1, 2, 3], [4, 5, 6]]) True >>> _is_shallow_sample_weights([[1, 2, "3"], [4, 5, None]]) False >>> _is_shallow_sample_weights([[1, 2, 3], None]) True """ return isinstance(obj, Collection) and all(map(_is_sample_weights, obj))
def normalize_sample_weights(
    sample_weights: SampleWeightsArray | ShallowSampleWeightsArray | SampleWeights,
    samples: Samples,
) -> SampleWeightsArray:
    """Normalize the sample weights to the correct shape.

    The result mirrors the row/trace nesting of ``samples``, so that
    each trace in ``samples`` is paired with its own weights entry
    (which may be ``None``).

    Examples
    --------
    >>> samples = [[[1, 2], [3, 4]], [[5, 6]]]
    >>> normalize_sample_weights(None, samples)
    [[None, None], [None]]
    >>> normalize_sample_weights([8, 9], samples)
    [[[8, 9], [8, 9]], [[8, 9]]]
    >>> weights = [[[0, 1], None], [[2, 3]]]
    >>> normalize_sample_weights(weights, samples) == weights
    True
    >>> normalize_sample_weights([None, [0, 1]], samples)
    [[None, None], [[0, 1]]]
    """
    if _is_sample_weights(sample_weights):
        # A single weights spec (or None) is shared by every trace.
        broadcast = []
        for row in samples:
            broadcast.append([sample_weights] * len(row))
        return broadcast

    # One weights spec per row: add the missing nesting level first.
    nested = (
        nest_shallow_collection(sample_weights)
        if _is_shallow_sample_weights(sample_weights)
        else sample_weights
    )
    return normalise_row_attrs(attrs=nested, l2_target=samples)
def estimate_density_trace(
    trace_samples: SamplesTrace,
    points: KDEPoints,
    kernel: str,
    bandwidth: KDEBandwidth,
    weights: SampleWeights = None,
) -> DensityTrace:
    """Estimate a density trace from a set of samples.

    For a given set of sample values, computes the kernel density
    estimate (KDE) at the given points.

    Parameters
    ----------
    trace_samples
        The sample values. Converted to a float array; must be
        non-empty and contain only finite values.
    points
        Either the number of equally spaced points at which to evaluate
        the KDE over ``[min(samples), max(samples)]`` (if an ``int``),
        or the explicit (at most one-dimensional) x-values to use.
    kernel
        The kernel name, forwarded to statsmodels' ``KDEUnivariate.fit``.
    bandwidth
        The bandwidth specification, forwarded as ``bw`` to
        statsmodels' ``KDEUnivariate.fit``.
    weights
        Optional per-sample weights. When given, must have the same
        length as ``trace_samples`` and contain only finite values.

    Returns
    -------
    DensityTrace
        A list of ``(x, y)`` pairs, one per evaluation point.

    Raises
    ------
    ValueError
        If the samples are empty or contain non-finite values, if
        ``points`` has more than one dimension, or if ``weights`` has
        the wrong length or contains non-finite values.
    RuntimeError
        If the densities returned by statsmodels fail validation
        (raised by ``_validate_densities``).
    """
    trace_samples = np.asarray(trace_samples, dtype=float)
    if trace_samples.size == 0:
        # Fail early with a clear message instead of an opaque
        # "min() arg is an empty sequence" further down.
        raise ValueError("The samples array should not be empty.")
    if not np.isfinite(trace_samples).all():
        raise ValueError("The samples array should not contain any infs or NaNs.")

    if isinstance(points, int):
        # By default, we'll use a 'hard' KDE span. That is, we'll
        # evaluate the densities at N equally spaced points
        # over the range [min(samples), max(samples)]
        density_x = np.linspace(
            start=trace_samples.min(),
            stop=trace_samples.max(),
            num=points,
        )
    else:
        # Unless a specific range is specified...
        density_x = np.asarray(points)
        if density_x.ndim > 1:
            raise ValueError(
                f"The 'points' at which KDE is computed should be represented by a "
                f"one-dimensional array, got an array of shape {density_x.shape} instead."
            )

    if weights is not None:
        weights = np.asarray(weights, dtype=float)
        if len(weights) != len(trace_samples):
            raise ValueError("The weights array should have the same length as the samples array.")
        if not np.isfinite(weights).all():
            raise ValueError("The weights array should not contain any infs or NaNs.")

    # ref: https://github.com/tpvasconcelos/ridgeplot/issues/116
    dens = sm.nonparametric.KDEUnivariate(trace_samples)
    # I'm hard-coding `fft=kernel == "gau" and weights is None`
    # (i.e., FFT is only used for the Gaussian kernel and only when
    # no weights are given; see the statsmodels docs for `fit()`)
    # to avoid exposing yet another KDE parameter in ridgeplot().
    # If we ever find any issues with this heuristic, I would
    # prefer just leaving `fft=False` here and *not* expose
    # this parameter to the user. If the user wants more
    # control over the KDE estimation, they can always
    # implement their own logic and pass `densities`
    # directly to the ridgeplot() figure factory.
    dens.fit(
        kernel=kernel,
        fft=kernel == "gau" and weights is None,
        bw=bandwidth,  # pyright: ignore[reportArgumentType]
        weights=weights,
    )
    density_y = dens.evaluate(density_x)
    density_y = _validate_densities(x=density_x, y=density_y, kernel=kernel)
    return list(zip(density_x, density_y))
[docs] def _validate_densities( x: npt.NDArray[np.floating[Any]], y: Any, kernel: str, ) -> npt.NDArray[np.floating[Any]]: # I haven't investigated the root of this issue yet # but statsmodels' KDEUnivariate implementation # can return a float('NaN') if something goes # wrong internally. As to avoid confusion # further down the pipeline, I decided # to check whether the correct object # (and shape) are being returned. msg = ( f"statsmodels failed to evaluate densities using the {kernel!r} kernel. " "Try setting kernel='gau' (the default kernel)." if kernel != "gau" else "" ) if not isinstance(y, np.ndarray): # Fail early if the return type is incorrect # Otherwise, the remaining checks will fail raise RuntimeError(msg) # noqa: TRY004 y = cast("npt.NDArray[np.floating[Any]]", y) wrong_shape = y.shape != x.shape not_finite = ~np.isfinite(y).all() if wrong_shape or not_finite: raise RuntimeError(msg) return y
def estimate_densities(
    samples: Samples,
    points: KDEPoints,
    kernel: str,
    bandwidth: KDEBandwidth,
    sample_weights: SampleWeightsArray | ShallowSampleWeightsArray | SampleWeights = None,
) -> Densities:
    """Perform KDE for a set of samples.

    The sample weights are first normalized to mirror the nesting of
    ``samples``; then a density trace is estimated for every trace of
    every row, using the same ``points``, ``kernel``, and ``bandwidth``
    throughout.
    """
    weights_by_row = normalize_sample_weights(sample_weights=sample_weights, samples=samples)
    densities = []
    # `strict=True` makes any samples/weights length mismatch an error.
    for samples_row, weights_row in zip(samples, weights_by_row, strict=True):
        densities_row = []
        for samples_trace, trace_weights in zip(samples_row, weights_row, strict=True):
            densities_row.append(
                estimate_density_trace(
                    samples_trace,
                    points=points,
                    kernel=kernel,
                    bandwidth=bandwidth,
                    weights=trace_weights,
                )
            )
        densities.append(densities_row)
    return densities