Source code for coco_pipe.descriptors.validation
"""Runtime input validation helpers for descriptor extraction."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import numpy as np
from .configs import DescriptorConfig
[docs]
def validate_runtime_inputs(
    config: DescriptorConfig,
    *,
    X: Any,
    ids: Sequence[Any] | np.ndarray | None = None,
    channel_names: Sequence[str] | np.ndarray | None = None,
    sfreq: float | None = None,
) -> dict[str, Any]:
    """Validate explicit runtime inputs against the descriptor contract.

    Parameters
    ----------
    config : DescriptorConfig
        Parsed descriptor config defining the expected runtime contract.
    X : Any
        Candidate signal array expected to coerce to a float array of shape
        ``(n_obs, n_channels, n_times)``.
    ids : sequence or ndarray, optional
        Observation identifiers. If given, the first axis must have length
        ``n_obs``.
    channel_names : sequence of str or ndarray, optional
        One name per channel. If given, it must be a flat (1-D) sequence of
        length ``n_channels``. A bare string is rejected, not split into
        characters.
    sfreq : float, optional
        Sampling frequency. It is mandatory when ``config.input.require_sfreq``
        is set, when the ``bands`` or ``parametric`` families are enabled, or
        when the ``complexity`` family is enabled with ``"spectral_entropy"``
        among its measures. Whenever it is given, it must be a positive, finite
        real number.

    Returns
    -------
    dict[str, Any]
        Normalized runtime inputs ready for pipeline and extractor dispatch,
        with keys ``"X"`` (float ndarray), ``"ids"`` (ndarray or ``None``),
        ``"channel_names"`` (list of str or ``None``) and ``"sfreq"`` (float or
        ``None``).

    Raises
    ------
    ValueError
        If array dimensionality, identifier alignment, sampling frequency, or
        explicit channel-name requirements are violated.
    """
    X_arr = np.asarray(X, dtype=float)
    if X_arr.ndim != 3:
        raise ValueError(
            "Descriptors expect 3D input in 'obs_channel_time' layout; "
            f"got shape {X_arr.shape}."
        )
    n_obs, n_channels, _ = X_arr.shape

    sfreq_out = _resolve_sfreq(sfreq, required=_sfreq_required(config))
    channel_names_out = _resolve_channel_names(
        channel_names,
        n_channels=n_channels,
        required=_channel_names_required(config),
    )
    ids_out = _resolve_ids(ids, n_obs=n_obs)

    return {
        "X": X_arr,
        "ids": ids_out,
        "channel_names": channel_names_out,
        "sfreq": sfreq_out,
    }


def _sfreq_required(config: DescriptorConfig) -> bool:
    """Return True when the config or its enabled families demand ``sfreq``."""
    families = config.families
    return bool(
        config.input.require_sfreq
        or families.bands.enabled
        or families.parametric.enabled
        or (
            families.complexity.enabled
            and "spectral_entropy" in families.complexity.measures
        )
    )


def _channel_names_required(config: DescriptorConfig) -> bool:
    """Return True when the config or its enabled families demand channel names."""
    return bool(
        config.input.require_channel_names
        or any(
            getattr(config.families, family_name).enabled
            for family_name in ("bands", "parametric", "complexity")
        )
    )


def _resolve_sfreq(sfreq: Any, *, required: bool) -> float | None:
    """Validate ``sfreq`` and return it as a float, or ``None`` when absent."""
    if sfreq is None:
        if required:
            raise ValueError(
                "`sfreq` must be passed explicitly for the enabled descriptor families."
            )
        return None
    try:
        value = float(sfreq)
    except (TypeError, ValueError) as err:
        raise ValueError(f"`sfreq` must be a real number; got {sfreq!r}.") from err
    # NaN compares False against everything and inf is not a usable rate, so a
    # plain ``<= 0`` test is not enough.
    if not np.isfinite(value) or value <= 0:
        raise ValueError(f"`sfreq` must be positive and finite; got {sfreq!r}.")
    return value


def _resolve_channel_names(
    channel_names: Any, *, n_channels: int, required: bool
) -> list[str] | None:
    """Normalize ``channel_names`` to a list of str aligned with ``n_channels``."""
    if channel_names is None:
        if required:
            raise ValueError(
                "`channel_names` must be passed explicitly for channel-resolved output."
            )
        return None
    names_arr = np.asarray(channel_names)
    # A bare string becomes a 0-d array whose ``.tolist()`` is the string itself;
    # iterating it would silently yield one "name" per character.
    if names_arr.ndim != 1:
        raise ValueError(
            "`channel_names` must be a 1D sequence of names; "
            f"got array with shape {names_arr.shape}."
        )
    names = [str(name) for name in names_arr.tolist()]
    if len(names) != n_channels:
        raise ValueError(
            f"`channel_names` must align with n_channels={n_channels}; "
            f"got {len(names)}."
        )
    return names


def _resolve_ids(ids: Any, *, n_obs: int) -> np.ndarray | None:
    """Coerce ``ids`` to an array whose first axis has length ``n_obs``."""
    if ids is None:
        return None
    ids_arr = np.asarray(ids)
    # A 0-d array has no first axis; report it as misalignment instead of
    # letting ``shape[0]`` raise IndexError.
    if ids_arr.ndim == 0 or ids_arr.shape[0] != n_obs:
        raise ValueError(
            f"`ids` must align with n_obs={n_obs}; got shape {ids_arr.shape}."
        )
    return ids_arr