Source code for coco_pipe.descriptors.extractors.parametric
"""
Parametric spectral descriptor extraction backend.
This module implements the built-in parametric spectral family for
`coco_pipe.descriptors`. The extractor operates on already segmented NumPy
inputs with shape ``(n_obs, n_channels, n_times)`` and computes one or more
specparam-derived summary descriptors per sensor, per observation.
Notes
-----
The parametric family is a PSD consumer. When used through
`DescriptorPipeline.extract()`, it can share one batch-scoped PSD computation
with other compatible PSD consumers such as the spectral band family. The
actual descriptor outputs are then derived from that shared `psds, freqs` pair.
Model fitting itself still happens one spectrum at a time. When runtime
parallelism is enabled and the planner allows it, those per-spectrum fits can
run in parallel across observation-channel units.
Author: Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""
from __future__ import annotations
from typing import Any
import numpy as np
from ..configs import ParametricDescriptorConfig
from ._parametric_fit import _ParametricFitBatch, fit_parametric_batch
from ._psd import compute_psd
from .base import BasePSDDescriptorExtractor, _DescriptorBlock, make_failure_record
[docs]
class ParametricDescriptorExtractor(BasePSDDescriptorExtractor):
    """
    Parametric spectral descriptor extractor.

    This extractor fits one specparam model per observation and sensor in a
    validated descriptor input array, then exposes scalar summaries such as
    aperiodic parameters, fit quality, and dominant peak statistics.

    Parameters
    ----------
    config : ParametricDescriptorConfig
        Parsed family configuration controlling the PSD method, fit range,
        specparam settings, and requested output groups.

    Attributes
    ----------
    config : ParametricDescriptorConfig
        Stored typed configuration for the parametric family.
    family_name : str
        Stable family identifier used in metadata and failure records.

    Notes
    -----
    The extractor always computes descriptor values per sensor first. Public
    sensor-level naming is applied afterward through
    :meth:`BaseDescriptorExtractor._finalize_descriptor`.

    When the pipeline provides a precomputed PSD batch through
    :meth:`extract_psd`, the extractor reuses that shared spectral input
    and expects an explicit shared `fit_batch`. Standalone :meth:`extract`
    remains available for family-local PSD and fit computation.
    """

    family_name = "parametric"

    def __init__(self, config: ParametricDescriptorConfig):
        super().__init__(config)
        # Re-bind explicitly so `self.config` is the family-specific config
        # object whose `psd_method`, `freq_range` and `outputs` are read below.
        self.config = config

    @property
    def capabilities(self) -> dict[str, Any]:
        """Return static parametric extractor capability metadata.

        Returns
        -------
        dict[str, Any]
            The base-class capabilities, extended with a sampling-rate
            requirement and the optional backends used by this family.
        """
        return {
            **super().capabilities,
            "requires_sfreq": True,
            "optional_dependencies": ["specparam", "mne"],
        }

    def psd_request(self) -> dict[str, Any]:
        """Describe the PSD requirements for the shared planner.

        Returns
        -------
        dict[str, Any]
            ``"method"`` from ``config.psd_method`` and ``"fmin"`` / ``"fmax"``
            from the two entries of ``config.freq_range``.
        """
        return {
            "method": self.config.psd_method,
            "fmin": self.config.freq_range[0],
            "fmax": self.config.freq_range[1],
        }

    def parametric_fit_requirements(self) -> dict[str, Any]:
        """Describe whether this family needs shared parametric-fit outputs.

        Returns
        -------
        dict[str, Any]
            Always requests a shared fit (``"needed"``) with scalar metrics
            (``"metrics"``) but without periodic PSDs (``"periodic_psds"``),
            and forwards this family's ``config``.
        """
        return {
            "needed": True,
            "metrics": True,
            "periodic_psds": False,
            "config": self.config,
        }

    def _requested_metrics(self, available_metrics: Any) -> list[str]:
        """Return the ordered metric names selected by ``config.outputs``.

        Parameters
        ----------
        available_metrics : container of str
            Names of the metric arrays produced by the shared fit. Only used to
            decide whether the optional ``"knee"`` metric can be emitted.

        Returns
        -------
        list of str
            Metric names in a fixed order: aperiodic, then fit quality, then
            peak summary, independent of the ordering of ``config.outputs``.
        """
        outputs = self.config.outputs
        metrics: list[str] = []
        if "aperiodic" in outputs:
            metrics.extend(["offset", "exponent"])
            # The knee metric is optional: emit it only when the fit produced it.
            if "knee" in available_metrics:
                metrics.append("knee")
        if "fit_quality" in outputs:
            metrics.extend(["fit_error", "r_squared"])
        if "peak_summary" in outputs:
            metrics.extend(
                [
                    "peak_count",
                    "peak_freq_dom",
                    "peak_power_dom",
                    "peak_bandwidth_dom",
                    "alpha_peak_freq",
                    "alpha_peak_power",
                ]
            )
        return metrics

    def extract_psd(
        self,
        psds: np.ndarray,
        freqs: np.ndarray,
        channel_names: list[str] | None,
        ids: np.ndarray | None,
        runtime,
        obs_offset: int = 0,
        fit_batch: _ParametricFitBatch | None = None,
    ) -> _DescriptorBlock:
        """Extract parametric descriptors from a precomputed PSD batch.

        Parameters
        ----------
        psds : np.ndarray
            Power spectral density array with shape
            ``(n_obs, n_channels, n_freqs)``.
        freqs : np.ndarray
            Frequency grid aligned with the last axis of ``psds``. Not read
            here; accepted for interface parity with other PSD consumers.
        channel_names : list of str, optional
            Explicit channel labels aligned with axis 1 of ``psds``. If
            omitted or empty, fallback names ``"ch-0"``, ``"ch-1"``, ... are
            used internally. Any sized sequence (including a NumPy array of
            labels) is accepted.
        ids : np.ndarray, optional
            Observation identifiers aligned with axis 0 of ``psds``.
        runtime : DescriptorRuntimeConfig
            Runtime execution controls shared across descriptor families.
            Only ``runtime.on_error`` is consulted here.
        obs_offset : int, default=0
            Global observation offset added to any collected failure records
            when this extractor is called on one observation batch.
        fit_batch : _ParametricFitBatch
            Shared fit payload for exactly this ``psds`` batch. Its
            ``metrics`` mapping supplies one array per metric name, its
            ``errors`` sequence supplies
            ``(obs_rel, unit_idx, exception_type, message)`` tuples, and its
            ``meta`` mapping is merged into the returned block metadata.
            Although typed as optional, it is required.

        Returns
        -------
        _DescriptorBlock
            Parametric-family descriptor block aligned with the input
            observation axis.

        Raises
        ------
        ValueError
            If ``fit_batch`` is not supplied.
        KeyError
            If ``fit_batch.metrics`` lacks a metric requested by
            ``config.outputs``.
        RuntimeError
            If ``fit_batch`` reports a per-spectrum fit error and
            ``runtime.on_error == "raise"``.

        Notes
        -----
        This method consumes explicit shared intermediates. It does not compute
        PSDs or fit models on its own.
        """
        if channel_names is None or len(channel_names) == 0:
            # Explicit length check instead of truthiness: a NumPy array of
            # labels would otherwise raise "truth value ... is ambiguous".
            channel_names = [f"ch-{idx}" for idx in range(psds.shape[1])]
        if fit_batch is None:
            raise ValueError("Parametric extract_psd() requires a supplied fit_batch.")

        metric_arrays = fit_batch.metrics
        metrics = self._requested_metrics(metric_arrays)

        # Fail early with a clear message if the shared fit was materialized
        # without metrics this family needs (same exception type as a plain
        # missing-key lookup).
        missing = [name for name in metrics if name not in metric_arrays]
        if missing:
            raise KeyError(
                f"Shared fit_batch is missing parametric metrics: {missing}."
            )

        failures: list[dict[str, Any]] = []
        for obs_rel, unit_idx, exception_type, message in fit_batch.errors:
            if runtime.on_error == "raise":
                raise RuntimeError(message)
            # NOTE(review): unit_idx is used directly as a channel index;
            # assumes fit_batch reports per-channel (not flattened) indices.
            failures.append(
                make_failure_record(
                    family=self.family_name,
                    obs_index=obs_offset + obs_rel,
                    obs_id=None if ids is None else ids[obs_rel],
                    channel_index=unit_idx,
                    channel_name=channel_names[unit_idx],
                    exception_type=exception_type,
                    message=message,
                )
            )

        chunk_features: list[np.ndarray] = []
        descriptor_names: list[str] = []
        for metric_name in metrics:
            feature, names = self._finalize_descriptor(
                metric_arrays[metric_name],
                family_prefix="param",
                metric_name=metric_name,
                channel_names=channel_names,
            )
            chunk_features.append(feature)
            descriptor_names.extend(names)

        # With no requested outputs, still return a correctly shaped
        # (n_obs, 0) matrix so blocks can be concatenated along axis 1.
        if chunk_features:
            X_block = np.concatenate(chunk_features, axis=1)
        else:
            X_block = np.empty((psds.shape[0], 0))

        return _DescriptorBlock(
            family=self.family_name,
            X=X_block,
            descriptor_names=descriptor_names,
            meta={
                **fit_batch.meta,
                "psd_method": self.config.psd_method,
            },
            failures=failures,
        )

    def extract(
        self,
        X: np.ndarray,
        sfreq: float | None,
        channel_names: list[str] | None,
        ids: np.ndarray | None,
        runtime,
        obs_offset: int = 0,
    ) -> _DescriptorBlock:
        """Extract parametric descriptors from segmented multi-channel data.

        Parameters
        ----------
        X : np.ndarray
            Input array with shape ``(n_obs, n_channels, n_times)``. Each row
            already represents one observation segment produced upstream.
        sfreq : float, optional
            Sampling frequency in Hertz.
        channel_names : list of str, optional
            Explicit channel labels aligned with axis 1 of ``X``.
        ids : np.ndarray, optional
            Observation identifiers aligned with axis 0 of ``X``.
        runtime : DescriptorRuntimeConfig
            Runtime execution controls shared across descriptor families.
        obs_offset : int, default=0
            Global observation offset added to any collected failure records.

        Returns
        -------
        _DescriptorBlock
            Parametric-family descriptor block aligned with the input
            observation axis.

        Raises
        ------
        ImportError
            If the optional `mne` or `specparam` backend is unavailable.
        ValueError
            If PSD computation encounters an invalid runtime condition.
        KeyError
            If the locally computed fit lacks a requested metric.
        RuntimeError
            If a per-spectrum fit fails and ``runtime.on_error == "raise"``.

        Notes
        -----
        This standalone path computes a PSD for the current batch over
        ``config.freq_range``, fits one explicit parametric batch payload for
        this family (metrics only, no periodic PSDs), and then delegates to
        :meth:`extract_psd`. When the family is executed through
        `DescriptorPipeline`, the shared planner supplies the PSD and fit
        payload instead.
        """
        psds, freqs = compute_psd(
            X,
            sfreq=sfreq,
            method=self.config.psd_method,
            fmin=self.config.freq_range[0],
            fmax=self.config.freq_range[1],
            n_jobs=None,
        )
        fit_batch = fit_parametric_batch(
            psds,
            freqs,
            self.config,
            runtime,
            need_periodic_psd=False,
            include_metrics=True,
        )
        return self.extract_psd(
            psds,
            freqs,
            channel_names=channel_names,
            ids=ids,
            runtime=runtime,
            obs_offset=obs_offset,
            fit_batch=fit_batch,
        )