"""
Descriptor extraction planner and execution pipeline.
This module owns the config-bound runtime orchestration for descriptor
extraction. It does not implement family-specific descriptor math; instead it:
- validates the explicit runtime inputs accepted by the module
- instantiates enabled descriptor families from typed config
- plans shared PSD computation for compatible PSD consumers
- executes one observation batch at a time with controlled parallelism
- merges aligned family outputs into one flat descriptor matrix
Author: Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""
from __future__ import annotations
import warnings
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any
import numpy as np
from .configs import DescriptorConfig, ParametricDescriptorConfig
from .extractors._parametric_fit import fit_parametric_batch
from .extractors._psd import compute_psd
from .extractors.base import (
BaseDescriptorExtractor,
BasePSDDescriptorExtractor,
_DescriptorBlock,
)
from .extractors.complexity import ComplexityDescriptorExtractor
from .extractors.parametric import ParametricDescriptorExtractor
from .extractors.spectral import BandDescriptorExtractor
from .validation import (
validate_runtime_inputs,
)
__all__ = ["DescriptorPipeline"]
@dataclass(slots=True)
class _PSDGroup:
"""Plan one shared PSD compute for a compatible set of consumers.
Attributes
----------
method : str
Shared PSD estimator name.
fmin, fmax : float
Union frequency window required by all consumers in the group.
consumers : list of BasePSDDescriptorExtractor
PSD-consuming extractors that will reuse the shared PSD output.
needs_parametric_fit : bool, default=False
Whether the group also needs one shared parametric fit.
need_parametric_metrics : bool, default=False
Whether the shared fit must expose scalar parametric metrics.
need_periodic_psds : bool, default=False
Whether the shared fit must reconstruct periodic-only PSDs.
fit_config : ParametricDescriptorConfig or None, default=None
Shared parametric fit configuration when a fit is required.
"""
method: str
fmin: float
fmax: float
consumers: list[BasePSDDescriptorExtractor]
needs_parametric_fit: bool = False
need_parametric_metrics: bool = False
need_periodic_psds: bool = False
fit_config: ParametricDescriptorConfig | None = None
def _parallel_jobs(n_jobs: int, limit: int) -> int:
"""Clamp worker count to the amount of useful parallel work.
Parameters
----------
n_jobs : int
Requested worker count from runtime config. ``-1`` means "use as much
parallelism as this stage can use".
limit : int
Number of parallel tasks available in the current stage.
Returns
-------
int
Worker count capped at ``limit``.
"""
return limit if n_jobs == -1 else min(n_jobs, limit)
def _cast_precision(values: np.ndarray, precision: str) -> np.ndarray:
"""Cast the final descriptor matrix to the configured floating precision.
Parameters
----------
values : np.ndarray
Floating descriptor matrix to cast.
precision : {"float32", "float64"}
Requested output precision.
Returns
-------
np.ndarray
``values`` cast in-place when possible.
"""
return values.astype(
np.float32 if precision == "float32" else np.float64, copy=False
)
def _merge_descriptor_blocks(
blocks: list[_DescriptorBlock],
n_obs: int,
precision: str,
) -> tuple[np.ndarray, list[str], list[dict[str, Any]]]:
"""Merge family descriptor blocks column-wise on the descriptor axis.
Parameters
----------
blocks : list of _DescriptorBlock
Family-specific descriptor block objects to concatenate column-wise.
n_obs : int
Expected number of observations in each block.
precision : {"float32", "float64"}
Precision applied to the merged descriptor matrix.
Returns
-------
tuple
``(X, descriptor_names, failures)`` where ``X`` is the merged
descriptor matrix, ``descriptor_names`` is the deterministic merged
column order, and ``failures`` concatenates all family failure records.
Raises
------
ValueError
If any block is misaligned on the observation axis.
"""
if not blocks:
empty = np.empty(
(n_obs, 0),
dtype=np.float32 if precision == "float32" else np.float64,
)
return empty, [], []
matrices = []
names: list[str] = []
failures: list[dict[str, Any]] = []
for block in blocks:
if block.X.shape[0] != n_obs:
raise ValueError(
"Descriptor block "
f"'{block.family}' is misaligned: expected {n_obs} rows, "
f"got {block.X.shape[0]}."
)
matrices.append(block.X)
names.extend(block.descriptor_names)
failures.extend(block.failures)
if len(matrices) == 1:
X = _cast_precision(matrices[0], precision)
else:
X = _cast_precision(np.concatenate(matrices, axis=1), precision)
return (
X,
names,
failures,
)
def _sequential_runtime(runtime):
"""Return a sequential runtime copy for nested work.
Parameters
----------
runtime : DescriptorRuntimeConfig
Runtime configuration for the current extraction stage.
Returns
-------
DescriptorRuntimeConfig
Copy with nested parallelism disabled.
"""
return runtime.model_copy(update={"execution_backend": "sequential", "n_jobs": 1})
def _build_psd_groups(
extractors: list[BaseDescriptorExtractor],
) -> list[_PSDGroup]:
"""Plan shared PSD groups for the enabled PSD-consuming extractors.
Parameters
----------
extractors : list of BaseDescriptorExtractor
Config-bound extractors in deterministic family order.
Returns
-------
list of _PSDGroup
Shared PSD execution groups keyed by compatible PSD method and merged
fit requirements.
Raises
------
ValueError
If consumers that would share one parametric fit disagree on the fit
configuration.
"""
groups_by_method: dict[str, _PSDGroup] = {}
for extractor in extractors:
if not isinstance(extractor, BasePSDDescriptorExtractor):
continue
request = extractor.psd_request()
method = str(request["method"])
if method not in groups_by_method:
groups_by_method[method] = _PSDGroup(
method=method,
fmin=float(request["fmin"]),
fmax=float(request["fmax"]),
consumers=[extractor],
)
else:
current = groups_by_method[method]
current.fmin = min(current.fmin, float(request["fmin"]))
current.fmax = max(current.fmax, float(request["fmax"]))
current.consumers.append(extractor)
current = groups_by_method[method]
fit_req = extractor.parametric_fit_requirements()
if fit_req["needed"]:
current.needs_parametric_fit = True
current.need_parametric_metrics = current.need_parametric_metrics or bool(
fit_req["metrics"]
)
current.need_periodic_psds = current.need_periodic_psds or bool(
fit_req["periodic_psds"]
)
if current.fit_config is None:
current.fit_config = fit_req["config"]
elif current.fit_config.model_dump() != fit_req["config"].model_dump():
raise ValueError(
"PSD consumers sharing one parametric fit must use the same "
"parametric fit configuration."
)
return list(groups_by_method.values())
def _merge_family_blocks(
batch_results: list[dict[str, _DescriptorBlock]],
family_order: list[str],
) -> list[_DescriptorBlock]:
"""Merge per-batch results row-wise within each descriptor family.
Parameters
----------
batch_results : list of dict
One family-block mapping per processed observation batch.
family_order : list of str
Deterministic family order used for the final merged output.
Returns
-------
list of _DescriptorBlock
One merged block per family, still separated by family but aligned
across all processed batches.
Raises
------
ValueError
If descriptor names drift across batches for the same family.
"""
merged_blocks: list[_DescriptorBlock] = []
for family_name in family_order:
family_blocks = [
batch_result[family_name]
for batch_result in batch_results
if family_name in batch_result
]
if not family_blocks:
continue
reference_names = family_blocks[0].descriptor_names
for block in family_blocks[1:]:
if block.descriptor_names != reference_names:
raise ValueError(
"Descriptor names changed across batches for family "
f"'{family_name}'."
)
merged_blocks.append(
_DescriptorBlock(
family=family_name,
X=np.concatenate([block.X for block in family_blocks], axis=0),
descriptor_names=list(reference_names),
meta={},
failures=[
failure for block in family_blocks for failure in block.failures
],
)
)
return merged_blocks
def _process_psd_group(
group: _PSDGroup,
X_batch: np.ndarray,
sfreq: float,
channel_names: list[str] | None,
ids_batch: np.ndarray | None,
runtime,
obs_offset: int,
joblib=None,
consumer_parallel: bool = False,
psd_n_jobs: int | None = None,
) -> dict[str, _DescriptorBlock]:
"""Execute one shared PSD group for a single observation batch.
Parameters
----------
group : _PSDGroup
Planned PSD reuse group for the current batch.
X_batch : np.ndarray
Observation batch with shape ``(n_obs_batch, n_channels, n_times)``.
sfreq : float
Sampling frequency in Hertz.
channel_names : list of str or None
Runtime channel labels.
ids_batch : np.ndarray or None
Observation identifiers aligned with ``X_batch``.
runtime : DescriptorRuntimeConfig
Runtime policy used for this stage.
obs_offset : int
Absolute observation offset of the batch in the full input array.
joblib : module, optional
Imported ``joblib`` module when the selected strategy uses it.
consumer_parallel : bool, default=False
Whether compatible PSD consumers should run in parallel after the PSD
has been computed.
psd_n_jobs : int or None, default=None
Worker count forwarded to the PSD backend when the selected strategy is
PSD-level parallelism.
Returns
-------
dict[str, _DescriptorBlock]
Family-name mapping for all consumers in the PSD group.
"""
if psd_n_jobs is not None and joblib is not None:
with joblib.parallel_backend("threading", n_jobs=psd_n_jobs):
psds, freqs = compute_psd(
X_batch,
sfreq=sfreq,
method=group.method,
fmin=group.fmin,
fmax=group.fmax,
n_jobs=psd_n_jobs,
)
else:
psds, freqs = compute_psd(
X_batch,
sfreq=sfreq,
method=group.method,
fmin=group.fmin,
fmax=group.fmax,
n_jobs=psd_n_jobs,
)
fit_batch = None
if group.needs_parametric_fit:
fit_batch = fit_parametric_batch(
psds,
freqs,
group.fit_config,
runtime,
need_periodic_psd=group.need_periodic_psds,
include_metrics=group.need_parametric_metrics,
)
consumer_runtime = _sequential_runtime(runtime) if consumer_parallel else runtime
if consumer_parallel and joblib is not None and len(group.consumers) > 1:
blocks = joblib.Parallel(
n_jobs=_parallel_jobs(runtime.n_jobs, len(group.consumers)),
prefer="threads",
)(
joblib.delayed(consumer.extract_psd)(
psds,
freqs,
channel_names=channel_names,
ids=ids_batch,
runtime=consumer_runtime,
obs_offset=obs_offset,
fit_batch=fit_batch,
)
for consumer in group.consumers
)
else:
blocks = [
consumer.extract_psd(
psds,
freqs,
channel_names=channel_names,
ids=ids_batch,
runtime=consumer_runtime,
obs_offset=obs_offset,
fit_batch=fit_batch,
)
for consumer in group.consumers
]
return {block.family: block for block in blocks}
def _process_batch(
obs_slice: slice,
X: np.ndarray,
sfreq: float | None,
channel_names: list[str] | None,
ids: np.ndarray | None,
signal_extractors: list[BaseDescriptorExtractor],
psd_groups: list[_PSDGroup],
runtime,
strategy: str,
joblib=None,
) -> dict[str, _DescriptorBlock]:
"""Execute one observation batch under the selected planner strategy.
Parameters
----------
obs_slice : slice
Observation slice for the current batch.
X : np.ndarray
Full validated input array with shape ``(n_obs, n_channels, n_times)``.
sfreq : float or None
Sampling frequency in Hertz.
channel_names : list of str or None
Runtime channel labels.
ids : np.ndarray or None
Observation identifiers aligned with ``X``.
signal_extractors : list of BaseDescriptorExtractor
Non-PSD families that consume raw signal batches directly.
psd_groups : list of _PSDGroup
Planned PSD reuse groups for this pipeline instance.
runtime : DescriptorRuntimeConfig
Runtime policy for the current extraction call.
strategy : str
Selected parallelization strategy for this execution path.
joblib : module, optional
Imported ``joblib`` module when the selected strategy uses it.
Returns
-------
dict[str, _DescriptorBlock]
Family-name mapping for all blocks produced from the batch.
"""
X_batch = X[obs_slice]
ids_batch = None if ids is None else ids[obs_slice]
obs_offset = obs_slice.start or 0
family_blocks: dict[str, _DescriptorBlock] = {}
if strategy == "work-unit" and joblib is not None:
def _signal_unit(extractor):
block = extractor.extract(
X_batch,
sfreq=sfreq,
channel_names=channel_names,
ids=ids_batch,
runtime=_sequential_runtime(runtime),
obs_offset=obs_offset,
)
return {block.family: block}
def _psd_unit(group):
return _process_psd_group(
group,
X_batch,
sfreq=sfreq,
channel_names=channel_names,
ids_batch=ids_batch,
runtime=_sequential_runtime(runtime),
obs_offset=obs_offset,
joblib=None,
consumer_parallel=False,
)
work_units = [
joblib.delayed(_signal_unit)(extractor) for extractor in signal_extractors
] + [joblib.delayed(_psd_unit)(group) for group in psd_groups]
for unit_result in joblib.Parallel(
n_jobs=_parallel_jobs(
runtime.n_jobs,
len(signal_extractors) + len(psd_groups),
),
prefer="threads",
)(work_units):
family_blocks.update(unit_result)
else:
signal_runtime = _sequential_runtime(runtime)
for extractor in signal_extractors:
block = extractor.extract(
X_batch,
sfreq=sfreq,
channel_names=channel_names,
ids=ids_batch,
runtime=signal_runtime,
obs_offset=obs_offset,
)
family_blocks[block.family] = block
for group in psd_groups:
consumer_parallel = (
strategy == "psd-consumer"
and joblib is not None
and len(group.consumers) > 1
)
psd_n_jobs = None
if strategy == "psd-n_jobs":
psd_n_jobs = runtime.n_jobs
family_blocks.update(
_process_psd_group(
group,
X_batch,
sfreq=sfreq,
channel_names=channel_names,
ids_batch=ids_batch,
runtime=runtime
if strategy == "parametric-inner" and group.needs_parametric_fit
else signal_runtime,
obs_offset=obs_offset,
joblib=joblib
if consumer_parallel or strategy == "psd-n_jobs"
else None,
consumer_parallel=consumer_parallel,
psd_n_jobs=psd_n_jobs,
)
)
return family_blocks
[docs]
class DescriptorPipeline:
"""Run config-driven descriptor extraction on explicit arrays.
Parameters
----------
config : DescriptorConfig or Mapping[str, Any]
Typed descriptors configuration or a mapping accepted by
:class:`DescriptorConfig`.
Attributes
----------
config : DescriptorConfig
Parsed descriptors configuration.
extractors : list of BaseDescriptorExtractor
Enabled family extractors in deterministic family order.
signal_extractors : list of BaseDescriptorExtractor
Enabled non-PSD extractors that consume raw signal batches directly.
psd_groups : list of _PSDGroup
Planned PSD reuse groups derived once from the enabled extractors.
family_order : list of str
Deterministic family order used when merging batch-local outputs.
Notes
-----
The pipeline is config-bound but runtime-stateless. Construction performs
config parsing, corrected-band compatibility checks, and planner setup once.
Each call to :meth:`extract` then validates the explicit runtime inputs,
executes the planned families, and returns one flat descriptor matrix plus
any collected failures.
"""
def __init__(self, config: DescriptorConfig | Mapping[str, Any]):
"""Create a config-bound descriptor extraction pipeline.
Parameters
----------
config : DescriptorConfig or Mapping[str, Any]
Typed descriptors configuration or a mapping accepted by
:class:`DescriptorConfig`.
Raises
------
ValueError
If corrected band outputs are enabled but the parametric fit range
does not cover the requested band PSD window.
"""
self.config = (
config
if isinstance(config, DescriptorConfig)
else DescriptorConfig.model_validate(config)
)
corrected_outputs = {
"corrected_absolute_power",
"corrected_log_absolute_power",
"corrected_relative_power",
"corrected_ratios",
}
if any(
output in corrected_outputs for output in self.config.families.bands.outputs
):
fit_low, fit_high = self.config.families.parametric.freq_range
band_low = self.config.families.bands.fmin
band_high = self.config.families.bands.fmax
if fit_low > band_low or fit_high < band_high:
raise ValueError(
"Corrected band outputs require families.parametric.freq_range "
f"to cover the band PSD window [{band_low}, {band_high}]."
)
self.extractors: list[BaseDescriptorExtractor] = []
if self.config.families.bands.enabled:
self.extractors.append(
BandDescriptorExtractor(
self.config.families.bands,
fit_config=self.config.families.parametric,
)
)
if self.config.families.parametric.enabled:
self.extractors.append(
ParametricDescriptorExtractor(self.config.families.parametric)
)
if self.config.families.complexity.enabled:
self.extractors.append(
ComplexityDescriptorExtractor(self.config.families.complexity)
)
self.signal_extractors = [
extractor
for extractor in self.extractors
if not isinstance(extractor, BasePSDDescriptorExtractor)
]
self.psd_groups = _build_psd_groups(self.extractors)
self.family_order = [extractor.family_name for extractor in self.extractors]
[docs]
def pool_channels(
self,
result: Mapping[str, Any],
channel_groups: Mapping[str, Sequence[str]],
) -> dict[str, Any]:
"""Pool sensor-level descriptor columns into grouped channel outputs.
Parameters
----------
result : mapping
Standard descriptor result produced by :meth:`extract`.
channel_groups : mapping of str to sequence of str
Channel groups used to replace sensor-level descriptor columns with
grouped ``"chgrp-..."`` outputs.
Returns
-------
dict[str, Any]
Descriptor result with grouped channel features and unchanged
failures.
Raises
------
ValueError
If the provided result is malformed or if any requested group
cannot be formed from the sensor-level descriptor columns.
"""
if (
"X" not in result
or "descriptor_names" not in result
or "failures" not in result
):
raise ValueError(
"pool_channels() expects a result mapping with keys "
"'X', 'descriptor_names', and 'failures'."
)
X_desc = np.asarray(result["X"], dtype=float)
descriptor_names = [str(name) for name in result["descriptor_names"]]
if X_desc.ndim != 2:
raise ValueError("pool_channels() expects result['X'] to be 2D.")
if X_desc.shape[1] != len(descriptor_names):
raise ValueError(
"pool_channels() requires result['descriptor_names'] to align with "
"result['X'] columns."
)
if not channel_groups:
raise ValueError("channel_groups must define at least one group.")
base_to_channel_cols: dict[str, dict[str, int]] = {}
known_channels: set[str] = set()
for col_idx, descriptor_name in enumerate(descriptor_names):
if "_ch-" not in descriptor_name:
continue
base_name, channel_name = descriptor_name.rsplit("_ch-", 1)
base_to_channel_cols.setdefault(base_name, {})[channel_name] = col_idx
known_channels.add(channel_name)
if not base_to_channel_cols:
raise ValueError(
"pool_channels() requires sensor-level descriptor names "
"ending in '_ch-<sensor>'."
)
normalized_groups: dict[str, list[str]] = {}
assigned: dict[str, str] = {}
for raw_group_name, raw_members in channel_groups.items():
group_name = str(raw_group_name)
if not group_name:
raise ValueError("channel_groups keys must be non-empty strings.")
members = [str(member) for member in raw_members]
if not members:
raise ValueError(
f"channel_groups['{group_name}'] must define at least one channel."
)
if len(set(members)) != len(members):
raise ValueError(
f"channel_groups['{group_name}'] must not contain duplicates."
)
for member in members:
if member not in known_channels:
raise ValueError(
f"channel_groups['{group_name}'] references unknown channel "
f"'{member}'."
)
if member in assigned:
raise ValueError(
f"Channel '{member}' is assigned to multiple channel_groups: "
f"'{assigned[member]}' and '{group_name}'."
)
assigned[member] = group_name
normalized_groups[group_name] = members
output_columns: list[np.ndarray] = []
pooled_names: list[str] = []
seen_bases: set[str] = set()
for col_idx, descriptor_name in enumerate(descriptor_names):
if "_ch-" not in descriptor_name:
output_columns.append(X_desc[:, col_idx][:, None])
pooled_names.append(descriptor_name)
continue
base_name, _ = descriptor_name.rsplit("_ch-", 1)
if base_name in seen_bases:
continue
seen_bases.add(base_name)
channel_to_col = base_to_channel_cols[base_name]
for group_name, members in normalized_groups.items():
missing = [member for member in members if member not in channel_to_col]
if missing:
raise ValueError(
"pool_channels() could not form group "
f"'{group_name}' for descriptor base '{base_name}'. "
f"Missing channels: {missing}."
)
member_indices = [channel_to_col[member] for member in members]
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
grouped = np.nanmean(X_desc[:, member_indices], axis=1)
output_columns.append(grouped[:, None])
pooled_names.append(f"{base_name}_chgrp-{group_name}")
X_pooled = _cast_precision(
np.concatenate(output_columns, axis=1)
if output_columns
else np.empty((X_desc.shape[0], 0), dtype=float),
self.config.precision,
)
return {
"X": X_pooled,
"descriptor_names": pooled_names,
"failures": list(result["failures"]),
}