coco_pipe.decoding.utils

Helper functions and classes for the decoding module, primarily focused on Cross-Validation (CV) strategy management.

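This module provides:

  • get_cv_splitter: A factory function to instantiate Scikit-Learn cross-validators from a Pydantic CVConfig.

  • SimpleSplit: A custom cross-validator for a single train/test split.

  • _CVWithGroups: A wrapper that binds group labels to a splitter, so group constraints hold even if cross_val_score or another caller passes different groups, or none, at runtime.

  • get_scorer and cross_validate_score: Helpers for resolving a metric by name and computing a mean cross-validated score.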

Classes

_CVWithGroups

Internal wrapper to bind specific groups to a CV splitter.

SimpleSplit

A unified 1-fold CV strategy wrapping train_test_split.

Functions

get_cv_splitter(...)

Factory function to create a Scikit-Learn compliant cross-validator.

get_scorer(→ Callable)

Retrieve or construct a Scikit-Learn compliant scorer by name.

cross_validate_score(→ float)

Compute the mean cross-validated score of an estimator.

Module Contents

class coco_pipe.decoding.utils._CVWithGroups(cv, groups)

Bases: sklearn.model_selection.BaseCrossValidator

Internal wrapper to bind specific groups to a CV splitter.

This ensures that .split(X, y) always uses the groups supplied at initialization, ignoring any groups argument passed at runtime. This is critical for preventing data leakage when complex grouping logic is defined upstream.

Parameters:
  • cv (BaseCrossValidator) – The underlying Scikit-Learn cross-validator (e.g., GroupKFold).

  • groups (array-like) – The group labels to enforce for all splits.

cv
groups
split(X, y=None, groups=None)

Generate indices to split data into training and test set.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training data, where n_samples is the number of samples and n_features is the number of features.

  • y (array-like of shape (n_samples,), default=None) – The target variable for supervised learning problems.

  • groups (array-like of shape (n_samples,), default=None) – Ignored. The group labels bound at initialization are always used instead.

Yields:
  • train (ndarray) – The training set indices for that split.

  • test (ndarray) – The testing set indices for that split.

get_n_splits(X=None, y=None, groups=None)

Returns the number of splitting iterations in the cross-validator.
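The groups-binding behaviour described above can be sketched with a small BaseCrossValidator subclass. This is a minimal illustration, not the coco_pipe source: the class name BoundGroupsCV is hypothetical, and it assumes the wrapper simply forwards the stored groups to the inner splitter's split() and get_n_splits().

```python
import numpy as np
from sklearn.model_selection import BaseCrossValidator, GroupKFold


class BoundGroupsCV(BaseCrossValidator):
    """Hypothetical sketch of a splitter with group labels fixed at construction."""

    def __init__(self, cv, groups):
        self.cv = cv
        self.groups = np.asarray(groups)

    def split(self, X, y=None, groups=None):
        # Any ``groups`` passed at call time is ignored on purpose.
        yield from self.cv.split(X, y, groups=self.groups)

    def get_n_splits(self, X=None, y=None, groups=None):
        # Delegate to the inner splitter, again using the bound groups.
        return self.cv.get_n_splits(X, y, groups=self.groups)


# Toy data: six samples from three hypothetical subjects.
X = np.arange(12).reshape(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])
subjects = np.array([1, 1, 2, 2, 3, 3])

cv = BoundGroupsCV(GroupKFold(n_splits=3), subjects)
for train, test in cv.split(X, y, groups=np.zeros(6)):
    # Each fold holds out one whole subject; none appears on both sides.
    print(sorted(set(subjects[test])), "held out")
```

Passing an all-zero groups array at call time would make GroupKFold(n_splits=3) raise a ValueError, because it needs at least three distinct groups; the loop runs only because the bound subject labels are used instead.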

class coco_pipe.decoding.utils.SimpleSplit(test_size: float = 0.2, shuffle: bool = True, random_state: int | None = None, stratify: pandas.Series | numpy.ndarray | None = None)

Bases: sklearn.model_selection.BaseCrossValidator

A unified 1-fold CV strategy wrapping train_test_split.

This allows “hold-out” validation to be treated as a Cross-Validation strategy with n_splits=1, integrating seamlessly into loops that expect a generator of indices.

Parameters:
  • test_size (float, default=0.2) – Proportion of the dataset to include in the test split.

  • shuffle (bool, default=True) – Whether to shuffle the data before splitting.

  • random_state (int, optional) – Controls the shuffling applied to the data before applying the split.

  • stratify (array-like, optional) – If not None, data is split in a stratified fashion, using this array as the class labels.

test_size = 0.2
shuffle = True
random_state = None
stratify = None
split(X: pandas.DataFrame | numpy.ndarray, y: pandas.Series | numpy.ndarray | None = None, groups: Sequence | None = None)

Yield a single (train_index, test_index) tuple.

get_n_splits(X: Any = None, y: Any = None, groups: Any = None) → int

Always returns 1 split.
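A hold-out splitter of this kind can be sketched as follows, assuming it calls train_test_split on the row indices and yields the result once. HoldOutSplit is a hypothetical name whose parameters mirror the documented ones. Because get_n_splits() returns 1, cross_val_score accepts it like any other splitter and returns a single score.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import BaseCrossValidator, cross_val_score, train_test_split


class HoldOutSplit(BaseCrossValidator):
    """Hypothetical sketch of a 1-fold CV strategy wrapping train_test_split."""

    def __init__(self, test_size=0.2, shuffle=True, random_state=None, stratify=None):
        self.test_size = test_size
        self.shuffle = shuffle
        self.random_state = random_state
        self.stratify = stratify

    def split(self, X, y=None, groups=None):
        # Split row indices rather than the data itself, as CV splitters do.
        indices = np.arange(len(X))
        train_idx, test_idx = train_test_split(
            indices,
            test_size=self.test_size,
            shuffle=self.shuffle,
            random_state=self.random_state,
            stratify=self.stratify,
        )
        yield train_idx, test_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        return 1


X = np.random.default_rng(0).normal(size=(10, 3))
y = np.array([0, 1] * 5)
cv = HoldOutSplit(test_size=0.2, random_state=0, stratify=y)

train_idx, test_idx = next(cv.split(X, y))
scores = cross_val_score(LogisticRegression(), X, y, cv=cv)
print(scores.shape)  # (1,)
```

train_test_split raises a ValueError when shuffle=False is combined with a non-None stratify, so the same restriction applies to this sketch.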

coco_pipe.decoding.utils.get_cv_splitter(config: coco_pipe.decoding.configs.CVConfig, groups: Sequence | None = None) → sklearn.model_selection.BaseCrossValidator

Factory function to create a Scikit-Learn compliant cross-validator.

Constructs the appropriate splitter based on the provided CVConfig strategy. If groups are provided, they are bound to the splitter using _CVWithGroups to guarantee consistent grouping across pipeline steps.

Parameters:
  • config (CVConfig) – Validated configuration object specifying:
      ◦ strategy: ‘stratified’, ‘kfold’, ‘group_kfold’, ‘leave_p_out’, etc.
      ◦ n_splits: Number of folds (where applicable).
      ◦ shuffle: Whether to shuffle data (where applicable).
      ◦ random_state: Seed for reproducibility.

  • groups (sequence, optional) – Group labels for the samples. Required for ‘group_kfold’, ‘leave_p_out’, and ‘stratified_group_kfold’. If provided, the returned validator will ignore any groups passed to its .split() method and use these instead.

Returns:

An initialized cross-validator instance.

Return type:

BaseCrossValidator

Raises:

ValueError – If an unknown CV strategy is specified or if required parameters (like n_groups for leave_p_out) are missing from the configuration.

coco_pipe.decoding.utils.get_scorer(name: str) → Callable

Retrieve or construct a Scikit-Learn compliant scorer by name.

Parameters:

name (str) – The name of the metric (e.g., ‘accuracy’, ‘f1_macro’, ‘neg_mean_squared_error’).

Returns:

A scoring function with signature (y_true, y_pred) -> float.

Return type:

Callable

Raises:

ValueError – If the metric name is unknown.
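A name-to-metric lookup matching the documented (y_true, y_pred) -> float signature might look like the sketch below. The registry, lookup_metric, and the selection of metrics are hypothetical. Note that Scikit-Learn's own sklearn.metrics.get_scorer returns scorer objects called as scorer(estimator, X, y) rather than plain metric functions.

```python
from typing import Callable, Dict

from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    mean_squared_error,
)


def _f1_macro(y_true, y_pred) -> float:
    # Unweighted mean of the per-class F1 scores.
    return float(f1_score(y_true, y_pred, average="macro"))


def _neg_mse(y_true, y_pred) -> float:
    # Negated so that, as with Scikit-Learn's neg_* scorers, higher is better.
    return -float(mean_squared_error(y_true, y_pred))


# Hypothetical registry of supported metric names.
_METRICS: Dict[str, Callable[..., float]] = {
    "accuracy": accuracy_score,
    "balanced_accuracy": balanced_accuracy_score,
    "f1_macro": _f1_macro,
    "neg_mean_squared_error": _neg_mse,
}


def lookup_metric(name: str) -> Callable[..., float]:
    """Return a (y_true, y_pred) -> float metric, or raise ValueError."""
    try:
        return _METRICS[name]
    except KeyError:
        raise ValueError(
            f"Unknown metric {name!r}; expected one of {sorted(_METRICS)}"
        ) from None


print(lookup_metric("accuracy")([0, 1, 1, 0], [0, 1, 0, 0]))  # 0.75
```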

coco_pipe.decoding.utils.cross_validate_score(estimator: sklearn.base.BaseEstimator, X: numpy.ndarray, y: Sequence, *, groups: Sequence | None = None, cv_config: coco_pipe.decoding.configs.CVConfig | None = None, metric: str = 'balanced_accuracy', use_scaler: bool = False) → float

Compute the mean cross-validated score of an estimator.

Parameters:
  • estimator (BaseEstimator) – Estimator to fit inside each fold.

  • X (np.ndarray) – Input features with shape (n_samples, n_features).

  • y (sequence) – Target labels aligned with X.

  • groups (sequence, optional) – Group labels aligned with X.

  • cv_config (CVConfig, optional) – Cross-validation configuration. Defaults to a 5-fold stratified strategy, or a 5-fold stratified group strategy when groups are provided.

  • metric (str, default="balanced_accuracy") – Metric name resolved through get_scorer().

  • use_scaler (bool, default=False) – When True, wraps the estimator in a StandardScaler pipeline.

Returns:

Mean cross-validated score.

Return type:

float