# Source code for coco_pipe.descriptors.extractors.base
"""
Base interfaces for descriptor extraction backends.
This module defines the internal contracts shared by built-in descriptor
extractors. The module exposes:
- `BaseDescriptorExtractor` for families that consume validated raw signal
batches
- `BasePSDDescriptorExtractor` for families that consume shared PSD batches
- `_DescriptorBlock` as the private family output payload
- `make_failure_record` as the shared normalized failure-record helper
The surrounding descriptors stack uses these interfaces to provide:
- explicit runtime dispatch from `DescriptorPipeline`
- deterministic sensor-level descriptor naming
- family-wise metadata and failure collection
- safe merging of family outputs into one stable result dictionary
Notes
-----
`BaseDescriptorExtractor` is an internal extension point for descriptor
families. Unlike dim-reduction reducers, descriptor extractors are stateless at
runtime and do not expose `fit`, persistence, or model objects.
Author: Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from ..configs import DescriptorRuntimeConfig
# Names exported by ``from <module> import *``. `_DescriptorBlock` is not
# listed: the module docstring describes it as the private family output
# payload.
__all__ = [
"BaseDescriptorExtractor",
"BasePSDDescriptorExtractor",
"make_failure_record",
]
@dataclass(slots=True)
class _DescriptorBlock:
"""Private in-memory descriptor payload for one family.
Attributes
----------
family : str
Canonical family name that produced the block.
X : np.ndarray
Family-specific descriptor matrix aligned on the observation axis.
descriptor_names : list of str
Deterministic column names aligned with the columns of ``X``.
meta : dict
Family-specific metadata to preserve under the merged result.
failures : list of dict
Normalized failure records collected during extraction.
Notes
-----
``X.shape[0]`` must always match the input observation count seen by the
extractor, and ``len(descriptor_names)`` must always match ``X.shape[1]``.
The pipeline depends on these alignment guarantees when merging family
outputs.
"""
family: str
X: np.ndarray
descriptor_names: list[str]
meta: dict[str, Any] = field(default_factory=dict)
failures: list[dict[str, Any]] = field(default_factory=list)
def make_failure_record(
    family: str,
    obs_index: int,
    obs_id: Any = None,
    channel_index: int | None = None,
    channel_name: str | None = None,
    exception_type: str | None = None,
    message: str | None = None,
) -> dict[str, Any]:
    """Build a single failure record in the shared normalized layout.

    Parameters
    ----------
    family : str
        Name of the descriptor family reporting the failure.
    obs_index : int
        Index of the observation the failure refers to.
    obs_id : Any, optional
        Identifier of the observation, stored as given.
    channel_index : int, optional
        Index of the channel the failure refers to.
    channel_name : str, optional
        Label of the channel the failure refers to.
    exception_type : str, optional
        Name of the exception type associated with the failure.
    message : str, optional
        Human-readable failure message.

    Returns
    -------
    dict[str, Any]
        A new dictionary with exactly the keys ``family``, ``obs_index``,
        ``obs_id``, ``channel_index``, ``channel_name``, ``exception_type``
        and ``message`` (in that order), each holding the matching argument
        unchanged.
    """
    record: dict[str, Any] = dict(
        family=family,
        obs_index=obs_index,
        obs_id=obs_id,
        channel_index=channel_index,
        channel_name=channel_name,
        exception_type=exception_type,
        message=message,
    )
    return record
class BaseDescriptorExtractor(ABC):
    """
    Abstract base class for descriptor extraction families.

    Subclasses receive already validated NumPy inputs and must return one
    `_DescriptorBlock` aligned on the observation axis. The base class keeps
    the extractor API narrow and provides a shared helper for sensor-level
    finalization and deterministic descriptor naming.

    Parameters
    ----------
    config : Any
        Typed family configuration parsed by `DescriptorConfig`.

    Attributes
    ----------
    config : Any
        Stored family-specific configuration object.
    family_name : str
        Stable family identifier used in failure records and merged metadata.

    Notes
    -----
    Extractors are stateless at runtime. They do not learn parameters across
    calls; all runtime state is provided explicitly through `extract()`.

    Concrete extractors are expected to:

    1. compute family-specific values with shape ``(n_obs, n_channels)`` for
       each metric
    2. pass those values through :meth:`_finalize_descriptor`
    3. return one `_DescriptorBlock` with aligned names, metadata, and failures

    Examples
    --------
    A minimal concrete extractor typically looks like:

    >>> class MeanOverTimeExtractor(BaseDescriptorExtractor):
    ...     family_name = "toy"
    ...
    ...     def extract(
    ...         self,
    ...         X,
    ...         sfreq,
    ...         channel_names,
    ...         ids,
    ...         runtime,
    ...         obs_offset=0,
    ...     ):
    ...         values = X.mean(axis=-1)
    ...         X_out, names = self._finalize_descriptor(
    ...             values,
    ...             family_prefix="toy",
    ...             metric_name="mean",
    ...             channel_names=channel_names,
    ...         )
    ...         return _DescriptorBlock(
    ...             family=self.family_name,
    ...             X=X_out,
    ...             descriptor_names=names,
    ...         )
    """

    family_name = "base"

    def __init__(self, config: Any):
        """Store the typed family configuration."""
        self.config = config

    @property
    def capabilities(self) -> dict[str, Any]:
        """Return static extractor capability metadata.

        Returns
        -------
        dict[str, Any]
            Static metadata describing optional dependencies and general
            execution properties for the extractor. A new dictionary is
            built on every access.

        Notes
        -----
        The descriptors pipeline currently uses this mapping only as lightweight
        backend metadata. It is intentionally much smaller than the reducer
        capability surface in `dim_reduction`.
        """
        return {
            "requires_sfreq": False,
            "supports_batching": True,
            "supports_channelwise": True,
            "deterministic": True,
            "optional_dependencies": [],
        }

    @abstractmethod
    def extract(
        self,
        X: np.ndarray,
        sfreq: float | None,
        channel_names: list[str] | None,
        ids: np.ndarray | None,
        runtime: DescriptorRuntimeConfig,
        obs_offset: int = 0,
    ) -> _DescriptorBlock:
        """Extract descriptors from a validated input array.

        Parameters
        ----------
        X : np.ndarray
            Input array with shape ``(n_obs, n_channels, n_times)``.
        sfreq : float, optional
            Sampling frequency in Hertz.
        channel_names : list of str, optional
            Explicit channel labels aligned with axis 1 of ``X``.
        ids : np.ndarray, optional
            Observation identifiers aligned with axis 0 of ``X``.
        runtime : DescriptorRuntimeConfig
            Runtime execution controls shared across extractors.
        obs_offset : int, default=0
            Global observation offset applied to any collected failure records.

        Returns
        -------
        _DescriptorBlock
            Family-specific descriptor matrix plus metadata and failures.

        Raises
        ------
        ImportError
            If an optional backend required by the extractor is unavailable.
        ValueError
            If the extractor encounters an invalid runtime condition and the
            configured error policy requires raising.

        Notes
        -----
        The recommended pattern is to keep family-specific computation local to
        the extractor and delegate sensor-level naming behavior to
        :meth:`_finalize_descriptor`.
        """

    def _finalize_descriptor(
        self,
        values: np.ndarray,
        family_prefix: str,
        metric_name: str,
        channel_names: list[str] | None,
    ) -> tuple[np.ndarray, list[str]]:
        """Build deterministic sensor-level descriptor names.

        Parameters
        ----------
        values : np.ndarray
            Family metric values with shape ``(n_obs, n_channels)`` or
            ``(n_obs,)``. Array-likes are accepted and converted with
            ``np.asanyarray``; an existing array is passed through as is.
        family_prefix : str
            Stable family prefix, for example ``"band"`` or ``"param"``.
        metric_name : str
            Family-local metric identifier used in the descriptor name.
        channel_names : sequence of str, optional
            Channel labels used when building channel-resolved descriptor
            names. ``None`` or an empty sequence selects positional labels
            ``ch-0``, ``ch-1``, ... Labels are converted with ``str``.

        Returns
        -------
        tuple
            ``(X_metric, names)`` where ``X_metric`` is the finalized 2-D
            metric matrix (a 1-D input gains a trailing axis of length one)
            and ``names`` is the aligned list of descriptor names, one per
            column of ``X_metric``.

        Raises
        ------
        ValueError
            If ``values`` is neither 1-D nor 2-D, or if the number of channel
            labels differs from the number of columns of the finalized matrix.

        Notes
        -----
        This helper assumes ``values`` already represents descriptor values, not
        raw signals. It therefore only handles the stable sensor-level naming
        convention used by the public extract result.

        Examples
        --------
        Given ``family_prefix="band"``, ``metric_name="abs_alpha"`` and
        ``channel_names=["Fz", "Cz", "Pz"]``, the returned names are
        ``["band_abs_alpha_ch-Fz", "band_abs_alpha_ch-Cz",
        "band_abs_alpha_ch-Pz"]``.
        """
        values = np.asanyarray(values)
        if values.ndim == 1:
            # One value per observation: expose it as a single column.
            values = values[:, None]
        elif values.ndim != 2:
            raise ValueError(
                "`values` must have shape (n_obs, n_channels) or (n_obs,); "
                f"got an array with {values.ndim} dimension(s)."
            )

        n_columns = values.shape[1]
        # Test for None/emptiness explicitly: plain truthiness is ambiguous
        # for NumPy arrays of channel names.
        if channel_names is None or len(channel_names) == 0:
            labels = [f"ch-{idx}" for idx in range(n_columns)]
        else:
            labels = [str(name) for name in channel_names]
            if len(labels) != n_columns:
                # Names must stay aligned with the matrix columns, otherwise
                # merged outputs would be mislabeled.
                raise ValueError(
                    f"Got {len(labels)} channel name(s) for {n_columns} "
                    "descriptor column(s)."
                )

        # Labels that already carry the ``ch-`` scope marker are kept verbatim.
        scopes = [
            label if label.startswith("ch-") else f"ch-{label}"
            for label in labels
        ]
        names = ["_".join((family_prefix, metric_name, scope)) for scope in scopes]
        return values, names
class BasePSDDescriptorExtractor(BaseDescriptorExtractor):
    """
    Abstract base for descriptor families that work from PSD batches.

    Families built on this class remain part of the shared descriptor
    contract and add an explicit PSD-facing surface:

    - `psd_request()` tells the planner which PSD method and frequency range
      the family needs
    - `extract_psd(...)` consumes precomputed `psds, freqs`

    Keeping these separate leaves the generic raw-signal interface narrow
    while giving the planner one formal PSD-consumer contract shared by
    spectral and parametric families.

    Notes
    -----
    A PSD consumer may still expose `extract()` to satisfy the generic family
    interface. Once PSD intermediates have been materialized, the shared
    planner uses `psd_request()` and `extract_psd()` exclusively.
    """

    @abstractmethod
    def psd_request(self) -> dict[str, Any]:
        """Report what PSD this family needs from the shared planner.

        Returns
        -------
        dict[str, Any]
            Minimal request payload carrying the PSD method and the frequency
            range this family requires.
        """

    def parametric_fit_requirements(self) -> dict[str, Any]:
        """Report whether this PSD consumer needs a shared parametric fit.

        The base implementation declares that no shared fit is needed.

        Returns
        -------
        dict[str, Any]
            A new dictionary with the keys:

            - `needed`
            - `metrics`
            - `periodic_psds`
            - `config`

            Here the first three are ``False`` and `config` is ``None``.
        """
        requirements: dict[str, Any] = dict.fromkeys(
            ("needed", "metrics", "periodic_psds"), False
        )
        requirements["config"] = None
        return requirements

    @abstractmethod
    def extract_psd(
        self,
        psds: np.ndarray,
        freqs: np.ndarray,
        channel_names: list[str] | None,
        ids: np.ndarray | None,
        runtime: DescriptorRuntimeConfig,
        obs_offset: int = 0,
        fit_batch: Any | None = None,
    ) -> _DescriptorBlock:
        """Compute descriptors from precomputed PSD intermediates.

        Parameters
        ----------
        psds : np.ndarray
            Batch of PSDs shaped ``(n_obs, n_channels, n_freqs)``.
        freqs : np.ndarray
            Frequency grid matching the last axis of ``psds``.
        channel_names : list of str, optional
            Explicit labels for the channel axis.
        ids : np.ndarray, optional
            Observation identifiers for the observation axis.
        runtime : DescriptorRuntimeConfig
            Runtime execution controls shared across extractors.
        obs_offset : int, default=0
            Global observation offset applied to collected failure records.
        fit_batch : Any, optional
            Extra shared fit payload that some PSD consumers require.

        Returns
        -------
        _DescriptorBlock
            Family-specific descriptor block aligned with the input PSD batch.
        """