Source code for coco_pipe.descriptors.extractors._psd

"""
Shared PSD computation for PSD-consuming descriptor paths.

This module holds the reusable PSD step used by the descriptors planner and by
PSD-consuming extractors when they need a standalone spectral input. It does
not define descriptor semantics. It only:

- prepares a writable runtime environment for MNE-backed PSD helpers
- lazily imports the MNE PSD functions used by descriptors
- computes Welch or multitaper PSD batches on explicit NumPy inputs

Author: Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""

from __future__ import annotations

import importlib
import os
import tempfile

import numpy as np

from ...utils import import_optional_dependency


def load_mne_psd_functions():
    """Lazily import MNE PSD helpers with writable runtime cache locations.

    Returns
    -------
    tuple
        ``(psd_array_welch, psd_array_multitaper)`` imported from
        `mne.time_frequency`.

    Raises
    ------
    Exception
        Whatever ``import_optional_dependency`` raises when MNE cannot be
        imported (its exact type is defined in ``coco_pipe.utils``).

    Notes
    -----
    MNE may write cache or config files during import/use. The descriptors
    module keeps those paths inside the system temp directory so PSD
    computation remains sandbox-friendly.

    The environment variables are only set when absent (``setdefault``), so
    locations configured explicitly by the user are never overridden.
    NOTE(review): they can only influence libraries that read them at import
    time if MNE/matplotlib have not been imported yet in this process.
    """
    tmp_root = os.path.join(tempfile.gettempdir(), "coco_pipe_descriptors")
    mpl_dir = os.path.join(tmp_root, "mplconfig")
    mne_dir = os.path.join(tmp_root, "mne")
    os.makedirs(mpl_dir, exist_ok=True)
    os.makedirs(mne_dir, exist_ok=True)
    # Respect any user-provided configuration; only fill in missing values.
    os.environ.setdefault("MPLCONFIGDIR", mpl_dir)
    os.environ.setdefault("MNE_HOME", mne_dir)
    os.environ.setdefault("MNE_DONTWRITE_HOME", "true")

    def _import_psd_functions():
        # Import the submodule once and read both helpers from it, instead of
        # going through the low-level ``__import__`` builtin twice.
        time_frequency = importlib.import_module("mne.time_frequency")
        return (
            time_frequency.psd_array_welch,
            time_frequency.psd_array_multitaper,
        )

    return import_optional_dependency(
        _import_psd_functions,
        feature="descriptor spectral extraction",
        dependency="mne",
        install_hint="pip install coco-pipe[descriptors,eeg]",
    )


def compute_psd(
    X: np.ndarray,
    sfreq: float,
    method: str,
    fmin: float,
    fmax: float,
    n_jobs: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute PSD values for one batch of segmented signals.

    Parameters
    ----------
    X : np.ndarray
        Input array with shape ``(n_obs, n_channels, n_times)``.
    sfreq : float
        Sampling frequency in Hertz.
    method : {"welch", "multitaper"}
        PSD estimator to use.
    fmin : float
        Lower frequency bound passed to the PSD backend.
    fmax : float
        Upper frequency bound passed to the PSD backend.
    n_jobs : int, optional
        Parallel worker count forwarded to the MNE PSD backend when the
        caller enables PSD-level parallelism. `None` leaves the backend
        default in place.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        PSD values with shape ``(n_obs, n_channels, n_freqs)`` and the
        aligned frequency grid with shape ``(n_freqs,)``. Both arrays are
        converted to ``float`` dtype.

    Raises
    ------
    ValueError
        If ``method`` is not ``"welch"`` or ``"multitaper"``. This is checked
        before MNE is loaded, so it is raised even when MNE is not installed.

    Notes
    -----
    For Welch PSDs, the descriptors module uses:

    - ``n_fft = min(n_times, 256)``
    - ``n_per_seg = n_fft``

    while enforcing a minimum of `8` for both values. This keeps Welch
    behavior bounded and deterministic across the current descriptor tests
    and examples.
    """
    # Validate the estimator name first: a typo should surface as a
    # ValueError, not as a missing-dependency error, and should not trigger
    # the temp-directory / environment side effects of loading MNE.
    if method not in ("welch", "multitaper"):
        raise ValueError(f"Unknown PSD method: {method}")

    psd_array_welch, psd_array_multitaper = load_mne_psd_functions()

    if method == "welch":
        # Cap the window at 256 samples but never drop below 8 samples, even
        # for very short signals; the same length is used for n_per_seg.
        n_fft = max(min(int(X.shape[-1]), 256), 8)
        psd, freqs = psd_array_welch(
            X,
            sfreq=sfreq,
            fmin=fmin,
            fmax=fmax,
            n_fft=n_fft,
            n_per_seg=n_fft,
            average="mean",
            n_jobs=n_jobs,
            verbose=False,
        )
    else:
        psd, freqs = psd_array_multitaper(
            X,
            sfreq=sfreq,
            fmin=fmin,
            fmax=fmax,
            n_jobs=n_jobs,
            verbose=False,
        )
    return np.asarray(psd, dtype=float), np.asarray(freqs, dtype=float)