# Source code for coco_pipe.descriptors.extractors.complexity

"""
Complexity descriptor extraction backend.

This module implements the built-in complexity family for
`coco_pipe.descriptors`. The extractor operates on already segmented NumPy
inputs with shape ``(n_obs, n_channels, n_times)`` and computes one or more
complexity measures per sensor, per observation.

Notes
-----
The complexity family prefers batched backend calls when the selected library
supports them. In the current implementation:

- `spectral_entropy`, `hjorth_mobility`, and `hjorth_complexity` use batched
  `antropy` calls over flattened observation-channel units
- `sample_entropy`, `perm_entropy`, `approx_entropy`, `svd_entropy`,
  `petrosian_fd`, `katz_fd`, `higuchi_fd`, and `lziv_complexity` are still
  evaluated one 1D signal at a time
- `shannon_entropy`, `fuzzy_entropy`, `dispersion_entropy`, and
  `hurst_exponent` use scalar `neurokit2` calls
- `zero_crossings`, `kurtosis`, and `rms` are computed as simple scalar
  channelwise signal descriptors

Author: Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""

from __future__ import annotations

from typing import Any

import numpy as np
from scipy.stats import kurtosis as scipy_kurtosis

from ...utils import import_optional_dependency
from ..configs import ComplexityDescriptorConfig
from .base import BaseDescriptorExtractor, _DescriptorBlock, make_failure_record

# Measures the extractor evaluates with one antropy call over all flattened
# observation-channel signals.
_ANTROPY_BATCHED_MEASURES = frozenset(
    ("hjorth_complexity", "hjorth_mobility", "spectral_entropy")
)

# Measures the extractor evaluates with antropy one 1D signal at a time.
_ANTROPY_SCALAR_MEASURES = frozenset(
    (
        "approx_entropy",
        "higuchi_fd",
        "katz_fd",
        "lziv_complexity",
        "perm_entropy",
        "petrosian_fd",
        "sample_entropy",
        "svd_entropy",
    )
)

# Measures the extractor can evaluate with neurokit2, one 1D signal at a time.
# Some names overlap with the antropy sets; backend resolution decides which
# implementation is used.
_NEUROKIT_SCALAR_MEASURES = frozenset(
    (
        "dispersion_entropy",
        "fuzzy_entropy",
        "hurst_exponent",
        "perm_entropy",
        "sample_entropy",
        "shannon_entropy",
        "spectral_entropy",
    )
)

# Simple signal summaries computed directly with NumPy/SciPy, independent of
# the configured backend.
_CUSTOM_SCALAR_MEASURES = frozenset(("kurtosis", "rms", "zero_crossings"))


def _normalize_scalar_output(value: Any) -> float:
    """Collapse one backend return value into a plain Python float.

    Parameters
    ----------
    value : Any
        Raw backend output. When it is a tuple, only its first element is
        used; the remaining elements are ignored.

    Returns
    -------
    float
        The single numeric value contained in ``value``.

    Raises
    ------
    ValueError
        If the (unwrapped) value does not contain exactly one element.
    """
    candidate = value[0] if isinstance(value, tuple) else value
    flat = np.asarray(candidate, dtype=float).ravel()
    if flat.size != 1:
        raise ValueError("Complexity backend returned a non-scalar result.")
    return float(flat[0])
class ComplexityDescriptorExtractor(BaseDescriptorExtractor):
    """
    Complexity descriptor extractor.

    This extractor computes scalar complexity measures for each observation
    and sensor in a validated descriptor input array. It is intended for
    signals that are already segmented upstream, such as epochs, windows, or
    trial blocks.

    Parameters
    ----------
    config : ComplexityDescriptorConfig
        Parsed family configuration controlling the selected measures,
        backend, and any per-measure keyword arguments.

    Attributes
    ----------
    config : ComplexityDescriptorConfig
        Stored typed configuration for the complexity family.
    family_name : str
        Stable family identifier used in metadata and failure records.

    Notes
    -----
    The extractor always computes descriptor values per sensor first. Public
    deterministic sensor-level naming is applied afterward through
    :meth:`BaseDescriptorExtractor._finalize_descriptor`.

    When `backend="auto"` is selected, the extractor resolves each measure to
    the preferred available implementation:

    - `antropy` for the existing antropy-backed measures
    - `neurokit2` for measures that are only supported there
    - built-in NumPy/SciPy implementations for simple scalar signal summaries
    """

    family_name = "complexity"

    def __init__(self, config: ComplexityDescriptorConfig):
        super().__init__(config)
        self.config = config

    @property
    def capabilities(self) -> dict[str, Any]:
        """Return static complexity extractor capability metadata.

        Returns
        -------
        dict[str, Any]
            Capability metadata describing sampling-rate requirements and the
            optional backends used by the complexity family.

        Notes
        -----
        `spectral_entropy` requires an explicit sampling rate, while the other
        currently supported measures do not.
        """
        return {
            **super().capabilities,
            "requires_sfreq": "spectral_entropy" in self.config.measures,
            "optional_dependencies": ["antropy", "neurokit2"],
        }

    def _load_antropy(self):
        """Import `antropy` lazily when the configured backend needs it.

        Returns
        -------
        module
            Imported `antropy` module.

        Raises
        ------
        ImportError
            If `antropy` is not installed.
        """
        return import_optional_dependency(
            lambda: __import__("antropy"),
            feature="complexity descriptor extraction",
            dependency="antropy",
            install_hint="pip install coco-pipe[descriptors]",
        )

    def _load_neurokit(self):
        """Import `neurokit2` lazily when the configured backend needs it.

        Returns
        -------
        module
            Imported `neurokit2` module.

        Raises
        ------
        ImportError
            If `neurokit2` is not installed.
        """
        return import_optional_dependency(
            lambda: __import__("neurokit2"),
            feature="neurokit complexity descriptor extraction",
            dependency="neurokit2",
            install_hint="pip install coco-pipe[descriptors]",
        )

    def extract(
        self,
        X: np.ndarray,
        sfreq: float | None,
        channel_names: list[str] | None,
        ids: np.ndarray | None,
        runtime,
        obs_offset: int = 0,
    ) -> _DescriptorBlock:
        """Extract complexity descriptors from segmented multi-channel data.

        Parameters
        ----------
        X : np.ndarray
            Input array with shape ``(n_obs, n_channels, n_times)``. Each row
            already represents one observation segment produced upstream.
        sfreq : float, optional
            Sampling frequency in Hertz. Required when `spectral_entropy` is
            requested.
        channel_names : list of str, optional
            Explicit channel labels aligned with axis 1 of ``X``. If omitted,
            fallback names ``"ch-0"``, ``"ch-1"``, ... are used internally.
        ids : np.ndarray, optional
            Observation identifiers aligned with axis 0 of ``X``.
        runtime : DescriptorRuntimeConfig
            Runtime execution controls shared across descriptor families.
        obs_offset : int, default=0
            Global observation offset added to any collected failure records
            when this extractor is called on one observation batch.

        Returns
        -------
        _DescriptorBlock
            Complexity-family descriptor block aligned with the input
            observation axis.

        Raises
        ------
        ImportError
            If the configured optional backend is unavailable.
        ValueError
            If a requested measure is unsupported by the selected backend, or
            if a measure yields a non-finite value while
            ``runtime.on_error == "raise"``.
        Exception
            Any exception raised by a scalar backend call is re-raised
            unchanged when ``runtime.on_error == "raise"``. Exceptions raised
            by the batched `antropy` calls always propagate.

        Notes
        -----
        The extractor uses a mixed execution strategy:

        - batched `antropy` calls for `spectral_entropy`, `hjorth_mobility`,
          and `hjorth_complexity`
        - scalar `antropy` calls for the remaining antropy-backed measures
        - scalar `neurokit2` calls for the measures resolved to `neurokit2`
          (always `shannon_entropy`, `fuzzy_entropy`, `dispersion_entropy`,
          and `hurst_exponent`; additionally `sample_entropy`, `perm_entropy`,
          and `spectral_entropy` when ``backend == "neurokit2"``)
        - built-in NumPy/SciPy implementations for `zero_crossings`,
          `kurtosis`, and `rms`, regardless of the configured backend

        Non-finite outputs are converted to `NaN` and recorded under
        ``failures`` unless `runtime.on_error == "raise"`, in which case the
        extractor fails immediately.

        Example
        -------
        With ``channel_names=["Fz", "Cz"]``, a requested measure such as
        ``perm_entropy`` yields channel-resolved names like
        ``complexity_perm_entropy_ch-Fz`` and ``complexity_perm_entropy_ch-Cz``.
        """
        n_obs, n_channels = X.shape[0], X.shape[1]
        channel_names = channel_names or [f"ch-{idx}" for idx in range(n_channels)]
        failures: list[dict[str, Any]] = []

        # One (n_obs, n_channels) result grid per measure, pre-filled with NaN
        # so that any failed unit stays NaN in the output.
        metric_arrays = {
            measure: np.full((n_obs, n_channels), np.nan, dtype=float)
            for measure in self.config.measures
        }
        # Copy the per-measure kwargs so backend calls never see shared dicts.
        measure_kwargs = {
            measure: dict(self.config.measure_kwargs.get(measure, {}))
            for measure in self.config.measures
        }
        # Flatten (n_obs, n_channels, n_times) -> (n_obs * n_channels, n_times)
        # for the batched backend calls.
        flat_signals = X.reshape(-1, X.shape[-1])

        batched_outputs: dict[str, np.ndarray] = {}
        scalar_dispatch: dict[str, Any] = {}
        measure_backends: dict[str, str] = {}

        # Every scalar callable shares the signature (signal, kwargs, sfreq).
        custom_scalar_dispatch = {
            "zero_crossings": lambda signal, kwargs, sfreq: float(
                np.count_nonzero(np.diff(np.signbit(np.asarray(signal, dtype=float))))
            ),
            "kurtosis": lambda signal, kwargs, sfreq: float(
                scipy_kurtosis(
                    np.asarray(signal, dtype=float),
                    fisher=kwargs.get("fisher", True),
                    bias=kwargs.get("bias", False),
                )
            ),
            "rms": lambda signal, kwargs, sfreq: float(
                np.sqrt(np.mean(np.square(np.asarray(signal, dtype=float))))
            ),
        }

        # Resolve each requested measure to exactly one backend.
        antropy_measures = _ANTROPY_BATCHED_MEASURES | _ANTROPY_SCALAR_MEASURES
        unsupported: list[str] = []
        for measure in self.config.measures:
            if measure in _CUSTOM_SCALAR_MEASURES:
                measure_backends[measure] = "custom"
                continue
            if self.config.backend == "antropy":
                if measure in antropy_measures:
                    measure_backends[measure] = "antropy"
                else:
                    unsupported.append(measure)
                continue
            if self.config.backend == "neurokit2":
                if measure in _NEUROKIT_SCALAR_MEASURES:
                    measure_backends[measure] = "neurokit2"
                else:
                    unsupported.append(measure)
                continue
            # Any other backend value (e.g. "auto"): prefer antropy, then
            # fall back to neurokit2.
            if measure in antropy_measures:
                measure_backends[measure] = "antropy"
            elif measure in _NEUROKIT_SCALAR_MEASURES:
                measure_backends[measure] = "neurokit2"
            else:
                unsupported.append(measure)
        if unsupported:
            raise ValueError(
                f"Measures {sorted(unsupported)} are not supported by backend "
                f"'{self.config.backend}'."
            )

        # Load antropy only when at least one measure resolved to it, and run
        # the batched measures right away.
        ant = None
        if "antropy" in measure_backends.values():
            ant = self._load_antropy()
            if measure_backends.get("spectral_entropy") == "antropy":
                batched_outputs["spectral_entropy"] = np.asarray(
                    ant.spectral_entropy(
                        flat_signals,
                        sf=sfreq,
                        axis=-1,
                        **measure_kwargs["spectral_entropy"],
                    ),
                    dtype=float,
                )
            if (
                measure_backends.get("hjorth_mobility") == "antropy"
                or measure_backends.get("hjorth_complexity") == "antropy"
            ):
                # One call yields both Hjorth parameters; keep only the
                # requested ones.
                mobility, complexity = ant.hjorth_params(
                    flat_signals,
                    axis=-1,
                )
                if measure_backends.get("hjorth_mobility") == "antropy":
                    batched_outputs["hjorth_mobility"] = np.asarray(
                        mobility,
                        dtype=float,
                    )
                if measure_backends.get("hjorth_complexity") == "antropy":
                    batched_outputs["hjorth_complexity"] = np.asarray(
                        complexity,
                        dtype=float,
                    )

        # The lambdas below close over `ant`; they are only registered (and
        # therefore only called) for measures resolved to antropy.
        antropy_scalar_dispatch = {
            "sample_entropy": lambda signal, kwargs, sfreq: (
                _normalize_scalar_output(ant.sample_entropy(signal, **kwargs))
            ),
            "perm_entropy": lambda signal, kwargs, sfreq: _normalize_scalar_output(
                ant.perm_entropy(signal, **kwargs)
            ),
            "approx_entropy": lambda signal, kwargs, sfreq: (
                _normalize_scalar_output(ant.app_entropy(signal, **kwargs))
            ),
            "svd_entropy": lambda signal, kwargs, sfreq: _normalize_scalar_output(
                ant.svd_entropy(signal, **kwargs)
            ),
            "petrosian_fd": lambda signal, kwargs, sfreq: _normalize_scalar_output(
                ant.petrosian_fd(signal, **kwargs)
            ),
            "katz_fd": lambda signal, kwargs, sfreq: _normalize_scalar_output(
                ant.katz_fd(signal, **kwargs)
            ),
            "higuchi_fd": lambda signal, kwargs, sfreq: _normalize_scalar_output(
                ant.higuchi_fd(signal, **kwargs)
            ),
            # The signal is binarized around its median before being passed
            # to the Lempel-Ziv routine.
            "lziv_complexity": lambda signal, kwargs, sfreq: (
                _normalize_scalar_output(
                    ant.lziv_complexity(
                        (signal > np.median(signal)).astype(int),
                        **kwargs,
                    )
                )
            ),
        }
        for measure, func in antropy_scalar_dispatch.items():
            if measure_backends.get(measure) == "antropy":
                scalar_dispatch[measure] = func

        # Same pattern for neurokit2: lazy import, closures over `nk`.
        nk = None
        if "neurokit2" in measure_backends.values():
            nk = self._load_neurokit()
        neurokit_scalar_dispatch = {
            "sample_entropy": lambda signal, kwargs, sfreq: (
                _normalize_scalar_output(nk.entropy_sample(signal, **kwargs))
            ),
            "perm_entropy": lambda signal, kwargs, sfreq: _normalize_scalar_output(
                nk.entropy_permutation(signal, **kwargs)
            ),
            "spectral_entropy": lambda signal, kwargs, sfreq: (
                _normalize_scalar_output(
                    nk.entropy_spectral(
                        signal,
                        sampling_rate=sfreq,
                        **kwargs,
                    )
                )
            ),
            "shannon_entropy": lambda signal, kwargs, sfreq: (
                _normalize_scalar_output(nk.entropy_shannon(signal, **kwargs))
            ),
            "fuzzy_entropy": lambda signal, kwargs, sfreq: _normalize_scalar_output(
                nk.entropy_fuzzy(signal, **kwargs)
            ),
            "dispersion_entropy": lambda signal, kwargs, sfreq: (
                _normalize_scalar_output(nk.entropy_dispersion(signal, **kwargs))
            ),
            "hurst_exponent": lambda signal, kwargs, sfreq: (
                _normalize_scalar_output(nk.fractal_hurst(signal, **kwargs))
            ),
        }
        for measure, func in neurokit_scalar_dispatch.items():
            if measure_backends.get(measure) == "neurokit2":
                scalar_dispatch[measure] = func
        for measure, func in custom_scalar_dispatch.items():
            if measure_backends.get(measure) == "custom":
                scalar_dispatch[measure] = func

        # Store batched results and report every non-finite cell.
        for measure, flat_values in batched_outputs.items():
            values = np.asarray(flat_values, dtype=float).reshape(
                n_obs,
                n_channels,
            )
            metric_arrays[measure][:] = np.where(np.isfinite(values), values, np.nan)
            bad_positions = np.argwhere(~np.isfinite(values))
            if bad_positions.size == 0:
                continue
            message = f"Complexity measure '{measure}' produced a non-finite result."
            if runtime.on_error == "raise":
                raise ValueError(message)
            for obs_rel, unit_idx in bad_positions:
                failures.append(
                    make_failure_record(
                        family=self.family_name,
                        obs_index=obs_offset + int(obs_rel),
                        obs_id=None if ids is None else ids[int(obs_rel)],
                        channel_index=int(unit_idx),
                        channel_name=channel_names[int(unit_idx)],
                        exception_type="NumericalIssue",
                        message=message,
                    )
                )

        # Everything that was not produced by a batched call is evaluated one
        # 1D signal at a time.
        scalar_measures = [
            measure for measure in self.config.measures if measure not in batched_outputs
        ]
        for obs_rel in range(n_obs):
            unit_signals = X[obs_rel]
            obs_id = None if ids is None else ids[obs_rel]
            for unit_idx, signal in enumerate(unit_signals):
                for measure in scalar_measures:
                    try:
                        value = scalar_dispatch[measure](
                            signal,
                            measure_kwargs[measure],
                            sfreq,
                        )
                        if np.isfinite(value):
                            metric_arrays[measure][obs_rel, unit_idx] = float(value)
                        else:
                            # Name the measure in both the raised error and
                            # the collected record, matching the batched path.
                            message = (
                                f"Complexity measure '{measure}' produced a "
                                "non-finite result."
                            )
                            if runtime.on_error == "raise":
                                raise ValueError(message)
                            failures.append(
                                make_failure_record(
                                    family=self.family_name,
                                    obs_index=obs_offset + obs_rel,
                                    obs_id=obs_id,
                                    channel_index=unit_idx,
                                    channel_name=channel_names[unit_idx],
                                    exception_type="NumericalIssue",
                                    message=message,
                                )
                            )
                    except Exception as exc:  # pragma: no cover - hit via failure tests
                        # Missing optional dependencies are never downgraded
                        # to a collected failure.
                        if isinstance(exc, ImportError):
                            raise
                        if runtime.on_error == "raise":
                            raise
                        failures.append(
                            make_failure_record(
                                family=self.family_name,
                                obs_index=obs_offset + obs_rel,
                                obs_id=obs_id,
                                channel_index=unit_idx,
                                channel_name=channel_names[unit_idx],
                                exception_type=type(exc).__name__,
                                message=str(exc),
                            )
                        )

        # Assemble the output columns in the configured measure order.
        chunk_features: list[np.ndarray] = []
        chunk_names: list[str] = []
        for measure in self.config.measures:
            feature, names = self._finalize_descriptor(
                metric_arrays[measure],
                family_prefix="complexity",
                metric_name=measure,
                channel_names=channel_names,
            )
            chunk_features.append(feature)
            chunk_names.extend(names)

        return _DescriptorBlock(
            family=self.family_name,
            X=np.concatenate(chunk_features, axis=1)
            if chunk_features
            else np.empty((n_obs, 0)),
            descriptor_names=chunk_names,
            meta={
                "backend": self.config.backend,
                "measures": list(self.config.measures),
                "batched_measures": sorted(batched_outputs),
                "measure_backends": dict(measure_backends),
            },
            failures=failures,
        )