"""
Descriptor Configuration
========================
Strict Pydantic configuration models for the descriptors module.
This module defines the static, typed configuration surface for descriptor
extraction:
- explicit runtime input requirements
- family-specific configs for bands, parametric fitting, and complexity
- final output precision control
- runtime execution controls
These models validate local field structure and family-local constraints. The
remaining cross-family compatibility rule for corrected spectral outputs is
enforced by :class:`coco_pipe.descriptors.core.DescriptorPipeline` after config
parsing, because it depends on how multiple family configs interact.
Author: Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
# Public configuration models re-exported by ``from ... import *``.
__all__ = [
    "DescriptorInputConfig",
    "BandDescriptorConfig",
    "ParametricDescriptorConfig",
    "ComplexityDescriptorConfig",
    "DescriptorFamiliesConfig",
    "DescriptorRuntimeConfig",
    "DescriptorConfig",
]
# Default band layout used by ``BandDescriptorConfig.bands``: band name mapped
# to ``(low, high)`` frequency bounds, in the same units as ``fmin``/``fmax``.
CANONICAL_BANDS = dict(
    delta=(1.0, 4.0),
    theta=(4.0, 8.0),
    alpha=(8.0, 13.0),
    beta=(13.0, 30.0),
    gamma=(30.0, 45.0),
)

# Every supported band output: the four plain summaries first, then the same
# four with a ``corrected_`` prefix. Keep in sync with the ``Literal`` used by
# ``BandDescriptorConfig.outputs``.
_BAND_OUTPUTS = tuple(
    prefix + summary
    for prefix in ("", "corrected_")
    for summary in (
        "absolute_power",
        "log_absolute_power",
        "relative_power",
        "ratios",
    )
)

# Output groups accepted by ``ParametricDescriptorConfig.outputs``.
_PARAM_OUTPUTS = ("aperiodic", "fit_quality", "peak_summary")

# Supported complexity measures. The order matters: it is also the default
# order of ``ComplexityDescriptorConfig.measures``.
_COMPLEXITY_MEASURES = (
    "sample_entropy", "perm_entropy", "spectral_entropy",
    "approx_entropy", "svd_entropy",
    "petrosian_fd", "katz_fd", "higuchi_fd",
    "shannon_entropy", "fuzzy_entropy", "dispersion_entropy",
    "hurst_exponent", "zero_crossings", "kurtosis", "rms",
    "hjorth_mobility", "hjorth_complexity", "lziv_complexity",
)
class _StrictConfigModel(BaseModel):
    """
    Private base class shared by every descriptor config model.

    ``extra="forbid"`` makes unknown keys a validation error, so misspelled or
    unsupported options are reported instead of being silently ignored.
    """

    model_config = ConfigDict(extra="forbid")
class BandDescriptorConfig(_StrictConfigModel):
    """
    Configuration for PSD-based band summary descriptors.

    Parameters
    ----------
    enabled : bool, default=False
        Whether the band family is enabled.
    psd_method : {"welch", "multitaper"}, default="welch"
        PSD estimator used before computing band summaries.
    fmin, fmax : float, default=1.0, 45.0
        Global frequency window within which PSDs and bands are evaluated.
        ``fmin`` must be non-negative, ``fmax`` positive, and ``fmin < fmax``.
    bands : dict of str to tuple of float, default=canonical EEG bands
        Mapping from band name to ``(low, high)`` boundaries. Every band must
        satisfy ``low < high`` and lie inside ``[fmin, fmax]``. Passing
        ``None`` selects the canonical bands.
    outputs : list of str, default=["absolute_power"]
        Band descriptors to emit, drawn from ``"absolute_power"``,
        ``"log_absolute_power"``, ``"relative_power"``, ``"ratios"`` and their
        ``"corrected_"``-prefixed counterparts. Duplicates are rejected.
    ratio_pairs : list of tuple of str, default=[]
        Explicit ``(numerator, denominator)`` band names for ratio outputs.
        Required when ``"ratios"`` or ``"corrected_ratios"`` is requested.
        Every name must be a key of ``bands``.
    min_denominator_power : float, default=0.0
        Minimum denominator power required for relative-power and ratio
        outputs. Any denominator at or below this threshold is treated as
        undefined and yields ``NaN`` instead of an unstable division result.

    Notes
    -----
    Corrected band outputs are configured here, but their cross-family
    compatibility with the parametric fit range is checked later by the
    descriptor pipeline because that rule depends on both the band and
    parametric family configs together.
    """

    enabled: bool = False
    psd_method: Literal["welch", "multitaper"] = "welch"
    fmin: float = Field(1.0, ge=0.0)
    fmax: float = Field(45.0, gt=0.0)
    bands: dict[str, tuple[float, float]] = Field(
        default_factory=lambda: dict(CANONICAL_BANDS)
    )
    outputs: list[
        Literal[
            "absolute_power",
            "log_absolute_power",
            "relative_power",
            "ratios",
            "corrected_absolute_power",
            "corrected_log_absolute_power",
            "corrected_relative_power",
            "corrected_ratios",
        ]
    ] = Field(default_factory=lambda: ["absolute_power"])
    ratio_pairs: list[tuple[str, str]] = Field(default_factory=list)
    min_denominator_power: float = Field(0.0, ge=0.0)

    @field_validator("bands", mode="before")
    @classmethod
    def _coerce_bands(cls, value: Any) -> Any:
        """Normalize band keys to ``str`` and band bounds to tuples."""
        if value is None:
            return dict(CANONICAL_BANDS)
        # ``dict(...)`` also accepts an iterable of ``(name, bounds)`` pairs.
        # Inputs it cannot convert are handed to pydantic unchanged so the
        # caller gets a ValidationError instead of a raw TypeError.
        try:
            items = dict(value).items()
        except (TypeError, ValueError):
            return value
        coerced: dict[str, Any] = {}
        for key, bounds in items:
            try:
                coerced[str(key)] = tuple(bounds)
            except TypeError:
                # Non-iterable bounds: let the tuple[float, float] check reject it.
                coerced[str(key)] = bounds
        return coerced

    @field_validator("outputs", mode="before")
    @classmethod
    def _validate_outputs(cls, value: Any) -> Any:
        """Reject duplicate or unknown band outputs with explicit messages."""
        # Only collections of strings are inspected here. Anything else (None,
        # a bare string, unhashable items) is left to pydantic's list/Literal
        # validation, which raises a proper ValidationError.
        if not isinstance(value, (list, tuple, set, frozenset)) or not all(
            isinstance(item, str) for item in value
        ):
            return value
        if len(set(value)) != len(value):
            raise ValueError("Band outputs must not contain duplicates.")
        invalid = sorted(set(value) - set(_BAND_OUTPUTS))
        if invalid:
            raise ValueError(f"Unknown band outputs: {invalid}")
        return value

    @field_validator("ratio_pairs", mode="before")
    @classmethod
    def _coerce_ratio_pairs(cls, value: Any) -> Any:
        """Convert each ratio pair to a tuple; ``None`` means no pairs."""
        if value is None:
            return []
        try:
            return [tuple(pair) for pair in value]
        except TypeError:
            # Non-iterable input or pair: defer to pydantic's type validation.
            return value

    @model_validator(mode="after")
    def _validate_model(self) -> "BandDescriptorConfig":
        """Check the frequency window, band bounds and ratio configuration."""
        if self.fmin >= self.fmax:
            raise ValueError("Band descriptor config requires fmin < fmax.")
        for name, (low, high) in self.bands.items():
            if low >= high:
                raise ValueError(f"Band '{name}' requires low < high.")
            if low < self.fmin or high > self.fmax:
                raise ValueError(
                    f"Band '{name}' must stay within the configured "
                    f"[{self.fmin}, {self.fmax}] range."
                )
        wants_ratios = "ratios" in self.outputs or "corrected_ratios" in self.outputs
        if wants_ratios and not self.ratio_pairs:
            raise ValueError("Band ratios require explicit ratio_pairs.")
        # Ratio pairs must name configured bands; otherwise the ratio could
        # never be computed from this config.
        unknown = sorted(
            {band for pair in self.ratio_pairs for band in pair} - set(self.bands)
        )
        if unknown:
            raise ValueError(f"ratio_pairs reference unknown bands: {unknown}")
        return self
class ParametricDescriptorConfig(_StrictConfigModel):
    """
    Configuration for specparam-based spectral summary descriptors.

    Parameters
    ----------
    enabled : bool, default=False
        Whether the parametric family is enabled.
    backend : {"specparam"}, default="specparam"
        Parametric modeling backend.
    psd_method : {"welch", "multitaper"}, default="welch"
        PSD estimator used before fitting the parametric model.
    freq_range : tuple of float, default=(1.0, 45.0)
        Frequency range passed to the parametric model. Requires
        ``0 <= low < high``.
    peak_width_limits : tuple of float, default=(1.0, 12.0)
        Peak width bounds forwarded to the model backend. Requires
        ``0 <= low < high``.
    max_n_peaks : int, default=6
        Maximum number of periodic peaks to fit (non-negative).
    aperiodic_mode : {"fixed", "knee"}, default="fixed"
        Aperiodic model form used by specparam.
    outputs : list of {"aperiodic", "fit_quality", "peak_summary"}
        Parametric descriptor groups to emit. Defaults to all three, and
        duplicates are rejected.

    Notes
    -----
    This config describes how the shared parametric fit is produced. The same
    fit can be reused by the parametric family itself and by corrected spectral
    outputs when the planner detects compatible requests.
    """

    enabled: bool = False
    backend: Literal["specparam"] = "specparam"
    psd_method: Literal["welch", "multitaper"] = "welch"
    freq_range: tuple[float, float] = (1.0, 45.0)
    peak_width_limits: tuple[float, float] = (1.0, 12.0)
    max_n_peaks: int = Field(6, ge=0)
    aperiodic_mode: Literal["fixed", "knee"] = "fixed"
    outputs: list[Literal["aperiodic", "fit_quality", "peak_summary"]] = Field(
        default_factory=lambda: ["aperiodic", "fit_quality", "peak_summary"]
    )

    @field_validator("outputs", mode="before")
    @classmethod
    def _validate_outputs(cls, value: Any) -> Any:
        """Reject duplicate or unknown parametric outputs with explicit messages."""
        # Non-collections and non-string items are left to pydantic's own
        # list/Literal validation so they surface as a ValidationError rather
        # than a raw TypeError from set().
        if not isinstance(value, (list, tuple, set, frozenset)) or not all(
            isinstance(item, str) for item in value
        ):
            return value
        if len(set(value)) != len(value):
            raise ValueError("Parametric outputs must not contain duplicates.")
        invalid = sorted(set(value) - set(_PARAM_OUTPUTS))
        if invalid:
            raise ValueError(f"Unknown parametric outputs: {invalid}")
        return value

    @model_validator(mode="after")
    def _validate_model(self) -> "ParametricDescriptorConfig":
        """Check ordering and non-negativity of the range-valued fields."""
        freq_low, freq_high = self.freq_range
        if freq_low >= freq_high:
            raise ValueError("Parametric freq_range requires low < high.")
        # Mirrors the ``fmin >= 0`` constraint of the band family.
        if freq_low < 0.0:
            raise ValueError("Parametric freq_range requires a non-negative low bound.")
        width_low, width_high = self.peak_width_limits
        if width_low >= width_high:
            raise ValueError("peak_width_limits requires low < high.")
        if width_low < 0.0:
            raise ValueError("peak_width_limits requires a non-negative low bound.")
        return self
class ComplexityDescriptorConfig(_StrictConfigModel):
    """
    Configuration for signal-complexity descriptors.

    Parameters
    ----------
    enabled : bool, default=False
        Whether the complexity family is enabled.
    backend : {"antropy", "neurokit2", "auto"}, default="antropy"
        Complexity backend used for supported measures.
    measures : list of str, default=all supported measures
        Complexity measures to compute. Duplicates and unknown names are
        rejected.
    measure_kwargs : dict of str to dict, default={}
        Per-measure keyword arguments forwarded to the backend implementation.
        Keys must be supported measure names.

    Notes
    -----
    Complexity measures are signal-domain descriptors. Unlike the PSD-based
    families, they do not participate in shared PSD planning.
    """

    enabled: bool = False
    backend: Literal["antropy", "neurokit2", "auto"] = "antropy"
    measures: list[str] = Field(default_factory=lambda: list(_COMPLEXITY_MEASURES))
    measure_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)

    @field_validator("measures", mode="before")
    @classmethod
    def _validate_measures(cls, value: Any) -> Any:
        """Reject duplicate or unknown complexity measures."""
        # Only collections of strings are checked here; other inputs go to
        # pydantic's list[str] validation so they raise a ValidationError
        # instead of a raw TypeError from set().
        if not isinstance(value, (list, tuple, set, frozenset)) or not all(
            isinstance(item, str) for item in value
        ):
            return value
        if len(set(value)) != len(value):
            raise ValueError("Complexity measures must not contain duplicates.")
        invalid = sorted(set(value) - set(_COMPLEXITY_MEASURES))
        if invalid:
            raise ValueError(f"Unknown complexity measures: {invalid}")
        return value

    @field_validator("measure_kwargs")
    @classmethod
    def _validate_measure_kwargs(
        cls, value: dict[str, dict[str, Any]]
    ) -> dict[str, dict[str, Any]]:
        """Reject kwargs keyed by unsupported measure names."""
        # Without this check, a misspelled key would be silently ignored.
        unknown = sorted(set(value) - set(_COMPLEXITY_MEASURES))
        if unknown:
            raise ValueError(
                f"measure_kwargs reference unknown complexity measures: {unknown}"
            )
        return value
class DescriptorFamiliesConfig(_StrictConfigModel):
    """
    Group descriptor-family configuration under one top-level field.

    Each family defaults to its own config class constructed with no
    arguments.

    Attributes
    ----------
    bands : BandDescriptorConfig
        Configuration for PSD-based band summaries.
    parametric : ParametricDescriptorConfig
        Configuration for specparam-based summaries.
    complexity : ComplexityDescriptorConfig
        Configuration for complexity measures.
    """

    bands: BandDescriptorConfig = Field(default_factory=BandDescriptorConfig)
    parametric: ParametricDescriptorConfig = Field(
        default_factory=ParametricDescriptorConfig
    )
    complexity: ComplexityDescriptorConfig = Field(
        default_factory=ComplexityDescriptorConfig
    )
class DescriptorRuntimeConfig(_StrictConfigModel):
    """
    Runtime execution controls for descriptor extraction.

    Parameters
    ----------
    execution_backend : {"joblib", "sequential"}, default="joblib"
        Execution backend used by the pipeline.
    n_jobs : int, default=1
        Number of worker slots requested for supported parallel paths.
        ``-1`` means "use as much useful parallelism as the current stage can
        use", while positive integers request an explicit worker count. ``0``
        and values below ``-1`` are rejected.
    obs_chunk : int, default=128
        Number of observations processed per batch (must be positive).
    on_error : {"raise", "warn", "collect"}, default="collect"
        Failure policy applied during extraction.

    Notes
    -----
    Runtime config controls execution only. It does not add provenance,
    reporting, or persistence metadata to the returned descriptor result.
    """

    execution_backend: Literal["joblib", "sequential"] = "joblib"
    n_jobs: int = 1
    obs_chunk: int = Field(128, gt=0)
    on_error: Literal["raise", "warn", "collect"] = Field(
        "collect",
        description=(
            "Policies: "
            "'raise' re-raises the first exception immediately; "
            "'warn' collects all failures and emits one aggregate warning; "
            "'collect' stores failures silently for inspection in result['failures']."
        ),
    )

    @field_validator("n_jobs")
    @classmethod
    def _validate_n_jobs(cls, value: int) -> int:
        """Allow only ``-1`` (all useful workers) or a positive worker count."""
        if value == 0 or value < -1:
            raise ValueError("n_jobs must be -1 or a positive integer.")
        return value
class DescriptorConfig(_StrictConfigModel):
    """
    Top-level descriptors configuration object.

    Attributes
    ----------
    input : DescriptorInputConfig
        Runtime input requirements for explicit array extraction.
    families : DescriptorFamiliesConfig
        Enabled descriptor families and their typed configs.
    precision : {"float32", "float64"}, default="float32"
        Output dtype used for the final descriptor matrix.
    runtime : DescriptorRuntimeConfig
        Runtime execution and error-handling settings.

    Notes
    -----
    This object is the stable config boundary for
    :class:`coco_pipe.descriptors.core.DescriptorPipeline`. Parsing this config
    validates local structure here, then the pipeline applies the remaining
    cross-family compatibility checks when it builds the execution plan.
    """

    # NOTE(review): ``DescriptorInputConfig`` is exported in ``__all__`` but its
    # definition is not visible in this section. It must be defined before this
    # class, because ``default_factory=DescriptorInputConfig`` is evaluated when
    # the class body executes.
    input: DescriptorInputConfig = Field(default_factory=DescriptorInputConfig)
    families: DescriptorFamiliesConfig = Field(default_factory=DescriptorFamiliesConfig)
    precision: Literal["float32", "float64"] = "float32"
    runtime: DescriptorRuntimeConfig = Field(default_factory=DescriptorRuntimeConfig)