# Source code for coco_pipe.descriptors.extractors.spectral

"""
Band summary descriptor extraction backend.

This module implements the built-in spectral band family for
`coco_pipe.descriptors`. The extractor operates on already segmented NumPy
inputs with shape ``(n_obs, n_channels, n_times)`` and computes PSD-derived
band summaries per sensor, per observation.

Notes
-----
The spectral family is a PSD consumer. When used through
`DescriptorPipeline.extract()`, it can share one batch-scoped PSD computation
with other compatible PSD consumers such as the parametric family. The actual
descriptor outputs are then derived from that shared `psds, freqs` pair.

Within one extracted PSD batch, the family computes band integrals once and
reuses them for all requested outputs:

- absolute power
- log absolute power
- relative power
- band ratios
- corrected absolute power
- corrected log absolute power
- corrected relative power
- corrected band ratios

Corrected outputs are derived from periodic-only PSDs produced by a shared
parametric fit batch. They are therefore only available through the shared
planner path or an explicit ``fit_batch`` passed to :meth:`extract_psd`.

Author: Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""

from __future__ import annotations

from typing import Any

import numpy as np

from ..configs import BandDescriptorConfig
from ._parametric_fit import _ParametricFitBatch
from ._psd import compute_psd
from .base import BasePSDDescriptorExtractor, _DescriptorBlock, make_failure_record


class BandDescriptorExtractor(BasePSDDescriptorExtractor):
    """
    Spectral band descriptor extractor.

    This extractor computes PSD-derived band summaries for each observation
    and sensor in a validated descriptor input array. It is intended for
    signals that are already segmented upstream, such as epochs, windows, or
    trial blocks.

    Parameters
    ----------
    config : BandDescriptorConfig
        Parsed family configuration controlling the PSD method, frequency
        window, band definitions, and requested spectral outputs.
    fit_config : optional
        Parametric fit settings. Required as soon as any ``corrected_*``
        output is requested: its ``freq_range`` widens the PSD request in
        :meth:`psd_request` and the object itself is forwarded by
        :meth:`parametric_fit_requirements`.

    Attributes
    ----------
    config : BandDescriptorConfig
        Stored typed configuration for the spectral band family.
    fit_config : optional
        Stored parametric fit settings, or ``None``.
    family_name : str
        Stable family identifier used in metadata and failure records.

    Notes
    -----
    The extractor always computes descriptor values per sensor first. Public
    sensor-level naming is applied afterward through
    :meth:`BaseDescriptorExtractor._finalize_descriptor`.

    When the pipeline provides a precomputed PSD batch through
    :meth:`extract_psd`, the extractor reuses that shared spectral input
    instead of computing its own PSD.

    Corrected spectral outputs additionally require a shared parametric fit
    batch and are therefore only available through the shared planner path or
    an explicit `fit_batch`.
    """

    family_name = "bands"

    # Outputs that can only be derived from periodic-only (corrected) PSDs.
    _CORRECTED_OUTPUTS = frozenset(
        {
            "corrected_absolute_power",
            "corrected_log_absolute_power",
            "corrected_relative_power",
            "corrected_ratios",
        }
    )

    def __init__(self, config: BandDescriptorConfig, fit_config=None):
        super().__init__(config)
        self.config = config
        self.fit_config = fit_config

    @staticmethod
    def _integrate(spectra: np.ndarray, freqs: np.ndarray) -> np.ndarray:
        """Trapezoidal integral of ``spectra`` over ``freqs`` (last axis).

        ``np.trapezoid`` only exists on NumPy >= 2.0; older releases expose
        the same routine as ``np.trapz``. Resolving the name lazily keeps the
        extractor usable on both without touching the deprecated alias when
        the modern one is available.
        """
        trapezoid = getattr(np, "trapezoid", None)
        if trapezoid is None:
            trapezoid = np.trapz
        return trapezoid(spectra, freqs, axis=-1)

    @property
    def capabilities(self) -> dict[str, Any]:
        """Return static spectral extractor capability metadata.

        Returns
        -------
        dict[str, Any]
            Capability metadata describing sampling-rate requirements and the
            optional backend used by the spectral family.

        Notes
        -----
        Spectral band extraction always requires an explicit sampling rate
        because the PSD frequency axis depends on it.
        """
        return {
            **super().capabilities,
            "requires_sfreq": True,
            "optional_dependencies": ["mne"],
        }

    def psd_request(self) -> dict[str, Any]:
        """Describe the PSD requirements for the shared planner.

        Returns
        -------
        dict[str, Any]
            Minimal PSD request containing the PSD method and the required
            frequency range for this family. When corrected outputs are
            requested, the range is widened to also cover
            ``fit_config.freq_range``.

        Raises
        ------
        ValueError
            If corrected outputs are requested but no ``fit_config`` was
            supplied at construction time.

        Notes
        -----
        `DescriptorPipeline` uses this request to group compatible PSD
        consumers and decide when one batch-scoped PSD can be reused across
        families.
        """
        request: dict[str, Any] = {
            "method": self.config.psd_method,
            "fmin": self.config.fmin,
            "fmax": self.config.fmax,
        }
        if not self.needs_parametric_fit():
            return request

        if self.fit_config is None:
            raise ValueError(
                "Corrected band outputs require parametric fit settings."
            )
        # The shared PSD must span both the band window and the fit window.
        fit_low, fit_high = self.fit_config.freq_range
        request["fmin"] = min(request["fmin"], fit_low)
        request["fmax"] = max(request["fmax"], fit_high)
        return request

    def needs_parametric_fit(self) -> bool:
        """Whether corrected spectral outputs require a shared parametric fit."""
        return any(
            output in self._CORRECTED_OUTPUTS for output in self.config.outputs
        )

    def parametric_fit_requirements(self) -> dict[str, Any]:
        """Describe whether this family needs shared parametric-fit outputs.

        Returns
        -------
        dict[str, Any]
            ``needed`` and ``periodic_psds`` mirror
            :meth:`needs_parametric_fit`; ``metrics`` is always ``False``;
            ``config`` carries the stored ``fit_config``.
        """
        needed = self.needs_parametric_fit()
        return {
            "needed": needed,
            "metrics": False,
            "periodic_psds": needed,
            "config": self.fit_config,
        }

    def extract_psd(
        self,
        psds: np.ndarray,
        freqs: np.ndarray,
        channel_names: list[str] | None,
        ids: np.ndarray | None,
        runtime,
        obs_offset: int = 0,
        fit_batch: _ParametricFitBatch | None = None,
    ) -> _DescriptorBlock:
        """Extract band descriptors from a precomputed PSD batch.

        Parameters
        ----------
        psds : np.ndarray
            Power spectral density array with shape
            ``(n_obs, n_channels, n_freqs)``.
        freqs : np.ndarray
            Frequency grid aligned with the last axis of ``psds``.
        channel_names : list of str, optional
            Explicit channel labels aligned with axis 1 of ``psds``. If
            omitted, fallback names ``"ch-0"``, ``"ch-1"``, ... are used
            internally.
        ids : np.ndarray, optional
            Observation identifiers aligned with axis 0 of ``psds``.
        runtime : DescriptorRuntimeConfig
            Runtime execution controls shared across descriptor families.
            Only ``runtime.on_error`` is read here.
        obs_offset : int, default=0
            Global observation offset added to any collected failure records
            when this extractor is called on one observation batch.
        fit_batch : _ParametricFitBatch, optional
            Shared parametric fit results. Required when any ``corrected_*``
            output is requested. This method reads ``fit_batch.freqs``,
            ``fit_batch.periodic_psds`` and ``fit_batch.errors`` (an iterable
            of ``(obs_rel, ch_idx, exception_type, message)`` tuples).

        Returns
        -------
        _DescriptorBlock
            Spectral-family descriptor block aligned with the input
            observation axis.

        Raises
        ------
        ValueError
            If a configured band has no overlap with the computed PSD range
            and runtime error handling is configured to raise. Also raised
            when corrected outputs are requested without a supplied
            parametric ``fit_batch``.

        Notes
        -----
        The extractor first restricts the incoming PSD to the configured
        frequency window, then integrates one power value per configured band
        and sensor. Those band integrals are reused for all enabled outputs,
        such as absolute power, log absolute power, relative power, and
        ratios.

        Ratios are always derived from absolute band powers, not from
        relative or log-transformed outputs.

        Example
        -------
        With ``channel_names=["Fz", "Cz"]``, an absolute alpha-band request
        yields names such as ``band_abs_alpha_ch-Fz`` and
        ``band_abs_alpha_ch-Cz``.
        """
        channel_names = channel_names or [
            f"ch-{idx}" for idx in range(psds.shape[1])
        ]
        outputs = self.config.outputs
        eps = np.finfo(float).eps

        failures: list[dict[str, Any]] = []
        descriptor_names: list[str] = []
        chunk_features: list[np.ndarray] = []

        def record_failure(obs_rel, ch_idx, exception_type: str, message: str):
            """Append one failure record for an (observation, channel) pair."""
            # Plain ints keep the records free of NumPy scalar types.
            obs_rel, ch_idx = int(obs_rel), int(ch_idx)
            failures.append(
                make_failure_record(
                    family=self.family_name,
                    obs_index=obs_offset + obs_rel,
                    obs_id=None if ids is None else ids[obs_rel],
                    channel_index=ch_idx,
                    channel_name=channel_names[ch_idx],
                    exception_type=exception_type,
                    message=message,
                )
            )

        def integrate_total_power(
            spectra: np.ndarray, band_freqs: np.ndarray
        ) -> np.ndarray:
            """Integrate the whole windowed spectrum; NaN if the window is empty."""
            if band_freqs.size == 0:
                return np.full(psds.shape[:-1], np.nan, dtype=float)
            return self._integrate(spectra, band_freqs)

        def integrate_band_power(
            spectra: np.ndarray,
            band_freqs: np.ndarray,
            band_label_prefix: str,
        ) -> tuple[dict[str, np.ndarray], set[str]]:
            """Integrate every configured band; track bands with no bins."""
            range_label = (
                "computed PSD range"
                if band_label_prefix == "Raw"
                else "parametric fit range"
            )
            power_by_band: dict[str, np.ndarray] = {}
            unresolved: set[str] = set()
            for band_name, (low, high) in self.config.bands.items():
                in_band = (band_freqs >= low) & (band_freqs <= high)
                if np.any(in_band):
                    power_by_band[band_name] = self._integrate(
                        spectra[..., in_band], band_freqs[in_band]
                    )
                    continue

                message = (
                    f"{band_label_prefix} band '{band_name}' does not overlap "
                    f"the {range_label}."
                )
                if runtime.on_error == "raise":
                    raise ValueError(message)
                unresolved.add(band_name)
                placeholder = np.full(spectra.shape[:-1], np.nan, dtype=float)
                power_by_band[band_name] = placeholder
                # Every (observation, channel) cell is affected by a band
                # that cannot be resolved on this frequency grid.
                for obs_rel, ch_idx in np.argwhere(~np.isfinite(placeholder)):
                    record_failure(
                        obs_rel, ch_idx, "BandResolutionError", message
                    )
            return power_by_band, unresolved

        def emit(values: np.ndarray, metric_parts: list[str]) -> None:
            """Finalize one per-sensor metric and append it to the output."""
            feature, names = self._finalize_descriptor(
                values,
                family_prefix="band",
                metric_name="_".join(metric_parts),
                channel_names=channel_names,
            )
            chunk_features.append(feature)
            descriptor_names.extend(names)

        def report_non_finite(
            values: np.ndarray,
            message: str,
            skip_pairs: set[tuple[int, int]],
        ) -> None:
            """Record a NumericalIssue for each non-finite cell not skipped."""
            for obs_rel, ch_idx in np.argwhere(~np.isfinite(values)):
                if (int(obs_rel), int(ch_idx)) in skip_pairs:
                    continue
                record_failure(obs_rel, ch_idx, "NumericalIssue", message)

        def append_band_outputs(
            power_by_band: dict[str, np.ndarray],
            total_power_array: np.ndarray | None,
            unresolved: set[str],
            corrected: bool,
            skip_pairs: set[tuple[int, int]],
        ) -> None:
            """Emit abs / log / relative / ratio outputs for one band-power set."""
            key_prefix = "corrected_" if corrected else ""
            name_prefix = ["corr"] if corrected else []
            relative_label = (
                "Corrected relative power" if corrected else "Relative power"
            )
            ratio_label = "Corrected band ratio" if corrected else "Band ratio"
            denominator_floor = self.config.min_denominator_power

            if f"{key_prefix}absolute_power" in outputs:
                for band_name, values in power_by_band.items():
                    emit(values, name_prefix + ["abs", band_name])

            if f"{key_prefix}log_absolute_power" in outputs:
                for band_name, values in power_by_band.items():
                    # Clip to machine epsilon so log10 never sees <= 0.
                    emit(
                        np.log10(np.clip(values, eps, None)),
                        name_prefix + ["log", "abs", band_name],
                    )

            if f"{key_prefix}relative_power" in outputs:
                for band_name, values in power_by_band.items():
                    relative = np.divide(
                        values,
                        total_power_array,
                        out=np.full_like(values, np.nan, dtype=float),
                        where=total_power_array > denominator_floor,
                    )
                    # Unresolved bands were already reported once.
                    if band_name not in unresolved:
                        report_non_finite(
                            relative,
                            (
                                f"{relative_label} for band "
                                f"'{band_name}' became non-finite."
                            ),
                            skip_pairs,
                        )
                    emit(relative, name_prefix + ["rel", band_name])

            if f"{key_prefix}ratios" in outputs:
                for numerator, denominator in self.config.ratio_pairs:
                    top = power_by_band[numerator]
                    bottom = power_by_band[denominator]
                    ratio = np.divide(
                        top,
                        bottom,
                        out=np.full_like(top, np.nan, dtype=float),
                        where=bottom > denominator_floor,
                    )
                    if numerator not in unresolved and denominator not in unresolved:
                        report_non_finite(
                            ratio,
                            (
                                f"{ratio_label} "
                                f"'{numerator}/{denominator}' "
                                "became non-finite."
                            ),
                            skip_pairs,
                        )
                    emit(ratio, name_prefix + ["ratio", numerator, denominator])

        # --- raw spectrum -------------------------------------------------
        freq_mask = (freqs >= self.config.fmin) & (freqs <= self.config.fmax)
        local_freqs = freqs[freq_mask]
        local_psds = psds[..., freq_mask]

        total_power = None
        if "relative_power" in outputs:
            total_power = integrate_total_power(local_psds, local_freqs)

        band_power, missing_bands = integrate_band_power(
            local_psds, local_freqs, "Raw"
        )

        # --- periodic-only (corrected) spectrum ---------------------------
        corrected_outputs_requested = self.needs_parametric_fit()
        corrected_failed_pairs: set[tuple[int, int]] = set()
        corrected_band_power: dict[str, np.ndarray] = {}
        corrected_missing_bands: set[str] = set()
        corrected_total_power = None

        if corrected_outputs_requested:
            if fit_batch is None:
                raise ValueError(
                    "Corrected band outputs require a supplied parametric "
                    "fit_batch in extract_psd()."
                )

            # Cells whose fit failed are reported once here and then skipped
            # by the downstream non-finite checks.
            for obs_rel, ch_idx, exception_type, message in fit_batch.errors:
                corrected_failed_pairs.add((int(obs_rel), int(ch_idx)))
                record_failure(
                    obs_rel,
                    ch_idx,
                    exception_type,
                    f"Corrected band estimation unavailable: {message}",
                )

            corrected_freq_mask = (fit_batch.freqs >= self.config.fmin) & (
                fit_batch.freqs <= self.config.fmax
            )
            corrected_freqs = fit_batch.freqs[corrected_freq_mask]
            corrected_psds = fit_batch.periodic_psds[..., corrected_freq_mask]

            if "corrected_relative_power" in outputs:
                corrected_total_power = integrate_total_power(
                    corrected_psds, corrected_freqs
                )

            corrected_band_power, corrected_missing_bands = integrate_band_power(
                corrected_psds, corrected_freqs, "Corrected"
            )

        # --- assemble features: raw first, corrected second ----------------
        append_band_outputs(band_power, total_power, missing_bands, False, set())
        if corrected_outputs_requested:
            append_band_outputs(
                corrected_band_power,
                corrected_total_power,
                corrected_missing_bands,
                True,
                corrected_failed_pairs,
            )

        if chunk_features:
            X_out = np.concatenate(chunk_features, axis=1)
        else:
            X_out = np.empty((psds.shape[0], 0), dtype=float)

        return _DescriptorBlock(
            family=self.family_name,
            X=X_out,
            descriptor_names=descriptor_names,
            meta={
                "psd_method": self.config.psd_method,
                "bands": self.config.bands,
                "freq_range": [self.config.fmin, self.config.fmax],
                "n_freqs": int(local_freqs.size),
                "corrected_outputs": [
                    output
                    for output in outputs
                    if output.startswith("corrected_")
                ],
            },
            failures=failures,
        )

    def extract(
        self,
        X: np.ndarray,
        sfreq: float | None,
        channel_names: list[str] | None,
        ids: np.ndarray | None,
        runtime,
        obs_offset: int = 0,
    ) -> _DescriptorBlock:
        """Extract band descriptors from segmented multi-channel data.

        Parameters
        ----------
        X : np.ndarray
            Input array with shape ``(n_obs, n_channels, n_times)``. Each row
            already represents one observation segment produced upstream.
        sfreq : float, optional
            Sampling frequency in Hertz. Forwarded unchanged to
            ``compute_psd``.
        channel_names : list of str, optional
            Explicit channel labels aligned with axis 1 of ``X``.
        ids : np.ndarray, optional
            Observation identifiers aligned with axis 0 of ``X``.
        runtime : DescriptorRuntimeConfig
            Runtime execution controls shared across descriptor families.
        obs_offset : int, default=0
            Global observation offset added to any collected failure records.

        Returns
        -------
        _DescriptorBlock
            Spectral-family descriptor block aligned with the input
            observation axis.

        Raises
        ------
        ValueError
            If any corrected output is requested; those need a parametric fit
            batch that this standalone path cannot provide.

        Notes
        -----
        This is the standalone extraction path for raw spectral outputs. It
        computes a PSD for the provided batch and then delegates to
        :meth:`extract_psd`.

        Corrected spectral outputs are not supported here because they depend
        on an explicit shared parametric fit batch. When the family is
        executed through `DescriptorPipeline`, the shared planner provides
        that batch automatically.
        """
        if self.needs_parametric_fit():
            raise ValueError(
                "Corrected band outputs are only available through "
                "DescriptorPipeline or extract_psd(..., fit_batch=...)."
            )

        psds, freqs = compute_psd(
            X,
            sfreq=sfreq,
            method=self.config.psd_method,
            fmin=self.config.fmin,
            fmax=self.config.fmax,
            n_jobs=None,
        )
        return self.extract_psd(
            psds,
            freqs,
            channel_names=channel_names,
            ids=ids,
            runtime=runtime,
            obs_offset=obs_offset,
        )