coco_pipe.io
Submodules
Classes
BaseDatasetConfig | Base configuration shared by BIDSConfig, EmbeddingConfig, and TabularConfig.
BIDSConfig | Configuration for BIDSDataset.
DatasetConfig | Master configuration container for IO.
EmbeddingConfig | Configuration for EmbeddingDataset.
TabularConfig | Configuration for TabularDataset.
DataContainer | Generic container for N-dimensional neurophysiological data.
SklearnWrapper | Generic wrapper for ANY scikit-learn transformer (Scaler, PCA, etc.).
SpatialWhitener | M/EEG Spatial Whitening using Covariance Decorrelation.
Functions
load_data | Universal data loader factory.
Package Contents
- class coco_pipe.io.BaseDatasetConfig(/, **data: Any)
Bases: pydantic.BaseModel
A base class for creating Pydantic models.
- __class_vars__
The names of the class variables defined on the model.
- __private_attributes__
Metadata about the private attributes of the model.
- __signature__
The synthesized __init__ signature (inspect.Signature) of the model.
- __pydantic_complete__
Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__
The core schema of the model.
- __pydantic_custom_init__
Whether the model has a custom __init__ function.
- __pydantic_decorators__
Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
- __pydantic_generic_metadata__
A dictionary containing metadata about generic Pydantic models. The origin and args items map to the __origin__ and __args__ attributes of generic aliases, and the parameters item maps to the __parameters__ attribute of generic classes.
- __pydantic_parent_namespace__
Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__
The name of the post-init method for the model, if defined.
- __pydantic_root_model__
Whether the model is a RootModel (pydantic.root_model.RootModel).
- __pydantic_serializer__
The pydantic-core SchemaSerializer used to dump instances of the model.
- __pydantic_validator__
The pydantic-core SchemaValidator used to validate instances of the model.
- __pydantic_fields__
A dictionary of field names and their corresponding FieldInfo (pydantic.fields.FieldInfo) objects.
- __pydantic_computed_fields__
A dictionary of computed field names and their corresponding ComputedFieldInfo (pydantic.fields.ComputedFieldInfo) objects.
- __pydantic_extra__
A dictionary containing extra values, if extra (pydantic.config.ConfigDict.extra) is set to ‘allow’.
- __pydantic_fields_set__
The names of fields explicitly set during instantiation.
- __pydantic_private__
Values of private attributes set on the model instance.
- path: pathlib.Path = None
- subjects: int | List[str | int] | None = None
- class coco_pipe.io.BIDSConfig(/, **data: Any)
Bases: BaseDatasetConfig
Configuration for BIDSDataset.
- mode: Literal['bids'] = 'bids'
- task: str | None = None
- session: str | List[str] | None = None
- datatype: str = 'eeg'
- suffix: str | None = None
- loading_mode: str = None
- window_length: float | None = None
- stride: float | None = None
- class coco_pipe.io.DatasetConfig(/, **data: Any)
Bases: pydantic.BaseModel
Master configuration container for IO.
- dataset: TabularConfig | BIDSConfig | EmbeddingConfig = None
- class coco_pipe.io.EmbeddingConfig(/, **data: Any)
Bases: BaseDatasetConfig
Configuration for EmbeddingDataset.
- mode: Literal['embedding'] = 'embedding'
- pattern: str = '*.pkl'
- dims: Tuple[str, Ellipsis] = ('obs', 'feature')
- coords: Dict[str, List | Any] | None = None
- task: str | None = None
- run: str | None = None
- processing: str | None = None
- class coco_pipe.io.TabularConfig(/, **data: Any)
Bases: BaseDatasetConfig
Configuration for TabularDataset.
- mode: Literal['tabular'] = 'tabular'
- target_col: str | None = None
- index_col: str | int | None = None
- sep: str = None
- header: int | List[int] | None = 0
- sheet_name: str | int = 0
- columns_to_dims: List[str] | None = None
- col_sep: str = '_'
- meta_columns: List[str] | None = None
- clean: bool = False
- clean_kwargs: Dict[str, Any] = None
- select_kwargs: Dict[str, Any] = None
- coco_pipe.io.load_data(path: str | pathlib.Path, mode: str = 'auto', target_col: str | None = None, index_col: str | int | None = None, sep: str = '\t', header: int | List[int] | None = 0, sheet_name: str | int | None = 0, columns_to_dims: List[str] | None = None, col_sep: str = '_', meta_columns: List[str] | None = None, clean: bool = False, clean_kwargs: Dict[str, Any] | None = None, task: str | None = None, session: str | List[str] | None = None, datatype: str = 'eeg', suffix: str | None = None, loading_mode: str = 'epochs', window_length: float | None = None, stride: float | None = None, subject_metadata_df: Any | None = None, subject_key: str | None = None, pattern: str = '*.pkl', dims: Tuple[str, Ellipsis] = ('obs', 'feature'), coords: Dict[str, List | numpy.ndarray] | None = None, reader: Any | None = None, id_fn: Any | None = None, subjects: str | List[str] | int | List[int] | None = None, **kwargs) -> coco_pipe.io.structures.DataContainer
Universal data loader factory. Dispatches to BIDSDataset, TabularDataset, or EmbeddingDataset based on mode.
- Parameters:
path (str or Path) – Path to data source (file or directory).
mode ({"auto", "tabular", "bids", "embedding"}, default="auto") – Type of data to load.
“auto”: Infers the type from the file extension or directory structure.
“tabular”: Uses TabularDataset (CSV, TSV, Excel, TXT).
“bids”: Uses BIDSDataset (BIDS-compliant directories).
“embedding”: Uses EmbeddingDataset (NPY, PKL, H5, JSON).
Tabular Arguments (mode="tabular"):
target_col (str, optional) – Name of the column to extract as target y. Removed from features X.
index_col (str or int, optional) – Column to use as index (observation IDs).
sep (str, default='\t') – Separator for text files (e.g., ‘,’ for CSV). The default is a tab.
header (int or list of int, default=0) – Row number(s) to use as column names.
sheet_name (str or int, default=0) – Sheet name or index for Excel files.
columns_to_dims (list of str, optional) – If provided, attempts to reshape 2D feature columns into N-D dimensions. Columns must follow: dim1_dim2_…_feature.
col_sep (str, default='_') – Separator used in column names for reshaping.
meta_columns (list of str, optional) – Columns to extract as metadata coordinates instead of features.
clean (bool, default=False) – Whether to perform automated cleaning (drop NaNs/Infs).
clean_kwargs (dict, optional) – Arguments passed to TabularDataset.clean.
BIDS Arguments (mode="bids"):
task (str, optional) – BIDS task name (e.g., ‘rest’, ‘audiovisual’).
session (str or List[str], optional) – Session ID(s) to load. Defaults to all available.
datatype (str, default='eeg') – Data type folder (e.g., ‘eeg’, ‘meg’, ‘ieeg’).
suffix (str, optional) – File suffix to load (e.g., ‘eeg’, ‘epo’, ‘ave’).
loading_mode (str, default='epochs') – How to process the data. Passed as mode to BIDSDataset.
‘epochs’: Splits continuous data into fixed-length windows.
‘continuous’: Loads the data as single continuous segments.
‘load_existing’: Loads pre-computed epochs.
window_length (float, optional) – Window length in seconds (for ‘epochs’ mode).
stride (float, optional) – Stride in seconds (for ‘epochs’ mode).
subject_metadata_df (DataFrame, optional) – External subject-level metadata to merge by subject during BIDS loading.
subject_key (str, optional) – Column in subject_metadata_df containing the BIDS subject identifier.
subjects (int or list, optional) – Specific subject IDs to load (without ‘sub-’).
Embedding Arguments (mode="embedding"):
pattern (str, default='*.pkl') – Glob pattern to match files.
dims (tuple of str, default=('obs', 'feature')) – Dimension labels for the data arrays.
coords (dict, optional) – Dictionary of coordinates for dimensions.
reader (callable, optional) – Custom file reader function.
id_fn (callable, optional) – Custom subject ID extraction function.
subjects (int or list, optional) – If int, loads the first N subjects. If list, filters by ID.
- Returns:
Standardized data container with attributes:
X: (N_obs, …) data array
y: Targets (if available)
ids: Observation identifiers
coords: Coordinate metadata
- Return type:
DataContainer
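The "auto" mode is documented only as inferring the type from the file extension or directory structure. The following is a minimal sketch of one such rule. The helper infer_mode and its extension tables are hypothetical, built from the extension lists under mode, and are not coco_pipe's implementation. Treating a directory as BIDS when it contains dataset_description.json, and as a folder of embedding files otherwise, is also an assumption.

```python
from pathlib import Path

# Hypothetical extension tables, derived from the lists documented under `mode`.
TABULAR_EXTS = {".csv", ".tsv", ".txt", ".xls", ".xlsx"}
EMBEDDING_EXTS = {".npy", ".pkl", ".h5", ".json"}


def infer_mode(path):
    """Guess a load_data mode for `path` (illustrative sketch only)."""
    path = Path(path)
    if path.is_dir():
        # Assumption: a BIDS root carries dataset_description.json;
        # any other directory is treated as a folder of embedding files.
        if (path / "dataset_description.json").exists():
            return "bids"
        return "embedding"
    suffix = path.suffix.lower()
    if suffix in TABULAR_EXTS:
        return "tabular"
    if suffix in EMBEDDING_EXTS:
        return "embedding"
    raise ValueError(f"Cannot infer mode for {path!s}; pass mode explicitly.")


print(infer_mode("scores.tsv"))    # tabular
print(infer_mode("features.npy"))  # embedding
```

Raising on an unknown extension, instead of silently picking a loader, keeps a misnamed file from being parsed with the wrong reader.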
- class coco_pipe.io.DataContainer
Generic container for N-dimensional neurophysiological data.
Acts as a lightweight labelled array (like xarray but simpler), managing dimensions, coordinates, and associated target labels (y) and IDs.
- X
The primary data tensor. Shape must match dims.
- Type:
np.ndarray
- dims
Labels for each dimension of X. Examples: (‘obs’, ‘feature’), (‘obs’, ‘channel’, ‘time’). Note: The ‘obs’ dimension is special and typically represents independent samples.
- Type:
Tuple[str, …]
- coords
Coordinates/Labels for dimensions. Keys must be in dims. Values must match the length of the corresponding dimension in X.
- Type:
Dict[str, Union[List, np.ndarray]]
- y
Target labels corresponding to the ‘obs’ dimension. Used for supervised learning or coloring plots.
- Type:
Optional[np.ndarray], optional
- ids
Identifiers for observations (e.g., subject IDs, trial names). Should correspond to the ‘obs’ dimension in coords if provided. Kept separate from coords for convenient tracking.
- Type:
Optional[np.ndarray], optional
- meta
Arbitrary metadata (sfreq, units, source path, etc.).
- Type:
Dict[str, Any]
Examples
Accessing data:
>>> container.X.shape
(10, 64, 500)
Accessing coordinates:
>>> container.coords['channel'][:3]
['Fz', 'Cz', 'Pz']
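The attribute contract above can be sketched as a small dataclass. X must have one axis per entry in dims, each coord must match the length of its dimension, and y and ids must have one entry per observation. MiniContainer is a hypothetical stand-in for illustration, not the real DataContainer class. It enforces the "keys must be in dims" rule literally.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import numpy as np


@dataclass
class MiniContainer:
    """Hypothetical, stripped-down stand-in for DataContainer."""

    X: np.ndarray
    dims: Tuple[str, ...]
    coords: Dict[str, Any] = field(default_factory=dict)
    y: Optional[np.ndarray] = None
    ids: Optional[np.ndarray] = None
    meta: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        self.X = np.asarray(self.X)
        if self.X.ndim != len(self.dims):
            raise ValueError(f"X has {self.X.ndim} axes but dims={self.dims}")
        for name, values in self.coords.items():
            if name not in self.dims:
                raise ValueError(f"coord {name!r} is not a dimension")
            if len(values) != self.X.shape[self.dims.index(name)]:
                raise ValueError(f"coord {name!r} has the wrong length")
        # y and ids are aligned with the 'obs' axis when present.
        n_obs = self.X.shape[self.dims.index("obs")] if "obs" in self.dims else None
        for label, arr in (("y", self.y), ("ids", self.ids)):
            if arr is not None and len(arr) != n_obs:
                raise ValueError(f"{label} must have one entry per observation")

    @property
    def shape(self):
        return self.X.shape


c = MiniContainer(
    np.zeros((10, 64, 500)),
    dims=("obs", "channel", "time"),
    coords={"channel": [f"ch{i}" for i in range(64)]},
    y=np.arange(10),
)
print(c.shape)  # (10, 64, 500)
```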
- X: numpy.ndarray
- dims: Tuple[str, Ellipsis]
- coords: Dict[str, List | numpy.ndarray | Sequence]
- y: numpy.ndarray | None = None
- ids: numpy.ndarray | None = None
- meta: Dict[str, Any]
- property shape: Tuple[int, Ellipsis]
- save(path: str | Any) -> None
Save the DataContainer to disk using joblib.
- Parameters:
path (str or Path) – Destination file path.
- classmethod load(path: str | Any) -> DataContainer
Load a DataContainer from disk.
- Parameters:
path (str or Path) – Source file path.
- Return type:
DataContainer
- obs_table(include_ids: bool = False, id_col: str = 'obs_id', include_y: bool = False, y_col: str = 'y', include_obs_coord: bool = False) -> pandas.DataFrame
Return one-dimensional coordinates aligned to the observation axis.
This helper is useful when exporting a row-wise table from a container. It only materializes metadata that can map cleanly to one row per observation, skipping coordinates that belong to other axes such as channel, time, feature, or stat.
- Parameters:
include_ids (bool, default=False) – If True, include self.ids as the first column.
id_col (str, default="obs_id") – Column name used when exporting self.ids.
include_y (bool, default=False) – If True, include self.y as a column when present.
y_col (str, default="y") – Column name used when exporting self.y.
include_obs_coord (bool, default=False) – If True, include coords["obs"] when present.
- Returns:
DataFrame containing only one-dimensional observation-aligned metadata columns.
- Return type:
pandas.DataFrame
- Raises:
ValueError – If the container has no obs dimension, or if include_ids is requested when self.ids is missing.
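The filtering rule above can be sketched without pandas. The sketch keeps only coordinates that hold one value per observation and are not keyed by another axis. obs_columns is a hypothetical helper that returns a plain dict of columns, whereas the real method returns a pandas.DataFrame and may apply further rules. The sketch also assumes that per-observation metadata such as 'subject' can be stored in coords under a key that is not a dimension name, as the balance and select examples suggest.

```python
import numpy as np


def obs_columns(X, dims, coords, ids=None, y=None, include_ids=False,
                id_col="obs_id", include_y=False, y_col="y",
                include_obs_coord=False):
    """Hypothetical sketch of the obs_table filtering rule."""
    if "obs" not in dims:
        raise ValueError("container has no 'obs' dimension")
    n_obs = X.shape[dims.index("obs")]
    table = {}
    if include_ids:
        if ids is None:
            raise ValueError("include_ids=True but ids is missing")
        table[id_col] = list(ids)  # ids become the first column
    for name, values in coords.items():
        if name == "obs":
            if include_obs_coord:
                table["obs"] = list(values)
            continue
        if name in dims:
            continue  # belongs to another axis (channel, time, feature, ...)
        values = np.asarray(values)
        if values.ndim == 1 and len(values) == n_obs:
            table[name] = values.tolist()
    if include_y and y is not None:
        table[y_col] = list(y)
    return table


X = np.zeros((3, 4))
coords = {"feature": ["a", "b", "c", "d"], "subject": ["s1", "s1", "s2"], "obs": [0, 1, 2]}
table = obs_columns(X, ("obs", "feature"), coords,
                    ids=np.array(["e1", "e2", "e3"]), include_ids=True)
print(list(table))  # ['obs_id', 'subject']
```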
- isel(**indexers) -> DataContainer
Select data by integer indices on specified dimensions.
This method is the integer-index equivalent of select. It operates directly on the dimensions of the data tensor X and slices the associated metadata along with it, so coordinates, ids, and y stay aligned with the selected data.
- Parameters:
**indexers (dict) –
Key: Dimension name (e.g., ‘obs’, ‘channel’, ‘time’). Value: Integer indices to select. Can be:
List or numpy array of integers: [0, 1, 5]
Slice object: slice(0, 10)
Single integer: 0
Note: If you provide a list of indices with repeats (e.g., [0, 0, 1]), the output will be oversampled accordingly.
- Returns:
A new DataContainer instance with the sliced data and coordinates.
- Return type:
DataContainer
Examples
>>> # Select first 10 observations
>>> subset = container.isel(obs=slice(0, 10))
>>> # Select specific channels by index
>>> subset = container.isel(channel=[0, 5, 12])
>>> # Select time range by index
>>> subset = container.isel(time=slice(100, 200))
>>> # Bootstrap/resample (select index 0 five times)
>>> bootstrap = container.isel(obs=[0, 0, 0, 0, 0])
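On plain arrays, integer selection by dimension name reduces to numpy.take on the matching axis, plus the same indexing on that axis's coordinate. Repeated indices therefore oversample. isel_axis is a hypothetical helper, not the DataContainer implementation. Wrapping a single integer in a list, so that the axis is kept, is a choice made for this sketch only.

```python
import numpy as np


def isel_axis(X, dims, coords, **indexers):
    """Hypothetical sketch: integer selection on named axes of X."""
    coords = dict(coords)
    for dim, index in indexers.items():
        axis = dims.index(dim)
        if isinstance(index, slice):
            index = np.arange(X.shape[axis])[index]  # slice -> explicit indices
        index = np.atleast_1d(index)  # a single int keeps the axis in this sketch
        X = np.take(X, index, axis=axis)
        if dim in coords:
            coords[dim] = np.asarray(coords[dim])[index]  # keep labels aligned
    return X, coords


X = np.arange(2 * 3 * 4).reshape(2, 3, 4)
dims = ("obs", "channel", "time")
coords = {"channel": ["Fz", "Cz", "Pz"]}
X_sub, c_sub = isel_axis(X, dims, coords, channel=[2, 0], time=slice(0, 2))
print(X_sub.shape, c_sub["channel"].tolist())  # (2, 2, 2) ['Pz', 'Fz']
```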
- balance(target: str = 'y', strategy: str = 'undersample', covariates: List[str] | None = None, random_state: int = 42, **kwargs) -> DataContainer
Balance the dataset classes using undersampling or oversampling.
This method adjusts the number of observations (rows) in the container so that class counts in target are equalized. It supports simple random sampling and stratified sampling based on covariates.
- Parameters:
target (str, default='y') – Name of the target variable. - ‘y’: Uses self.y. - Any other string: Looks for the variable in self.coords.
strategy ({'undersample', 'oversample', 'auto'}, default='undersample') –
‘undersample’: Downsample majority classes to match the minority class count.
‘oversample’: Upsample minority classes (with replacement) to match the majority class.
‘auto’: Heuristic choice. Uses undersampling if the total size remains above 50% of the original, otherwise oversampling.
covariates (list of str, optional) – List of covariate names in self.coords to preserve distribution of. If provided, the balancing is performed within strata defined by these covariates.
random_state (int, default=42) – Seed for the random number generator. Change this value to produce different random subsets (e.g., for bagging).
**kwargs (dict) –
Additional arguments passed to internal logic:
n_bins (int): Number of bins for continuous covariates (default 5).
binning (str): ‘quantile’ (default) or ‘uniform’ binning.
prefer_clean_rows (bool): If True, weights sampling to prefer rows with fewer NaNs/artifacts.
- Returns:
A new DataContainer instance with balanced classes.
- Return type:
DataContainer
Examples
>>> # 1. Simple undersampling of 'y'
>>> balanced = container.balance(strategy='undersample')
>>> # 2. Balance based on a metadata column 'condition'
>>> balanced = container.balance(target='condition')
>>> # 3. Stratified balancing (balance 'y' while preserving 'sex' and 'age'
>>> # ratios)
>>> balanced = container.balance(target='y', covariates=['sex', 'age'])
>>> # 4. Iterative bootstrapping (different seeds)
>>> for seed in [1, 2, 3]:
...     subset = container.balance(strategy='undersample', random_state=seed)
...     # process subset...
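The two sampling strategies can be sketched with numpy as follows. First pick a target count per class: the minority count for undersampling, the majority count for oversampling. Classes that already have that size keep every row. The other classes are drawn to the target, without replacement when shrinking and with replacement when growing. balance_indices is a hypothetical helper. It does not reproduce the covariate stratification or the 'auto' heuristic, and because its random generator is an assumption, its subsets will not match the library's for the same random_state.

```python
import numpy as np


def balance_indices(labels, strategy="undersample", random_state=42):
    """Hypothetical sketch: row indices that equalize class counts."""
    if strategy not in ("undersample", "oversample"):
        raise ValueError(f"unsupported strategy {strategy!r}")
    labels = np.asarray(labels)
    rng = np.random.default_rng(random_state)
    classes, counts = np.unique(labels, return_counts=True)
    target = counts.min() if strategy == "undersample" else counts.max()
    picked = []
    for cls, count in zip(classes, counts):
        rows = np.flatnonzero(labels == cls)
        if count == target:
            picked.append(rows)  # already the right size: keep every row
        else:
            # Shrink without replacement, or grow with replacement.
            picked.append(rng.choice(rows, size=target, replace=count < target))
    return np.sort(np.concatenate(picked))


y = np.array(["ctrl"] * 6 + ["patient"] * 2)
idx = balance_indices(y, random_state=0)
print(np.unique(y[idx], return_counts=True)[1])  # [2 2]
```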
- select(ignore_case: bool = False, fuzzy: bool = False, **selections) -> DataContainer
Select data subsets based on coordinates, ids, or y.
This method supports exact matching, wildcard matching, operator-based filtering, and custom callable filters.
- Parameters:
ignore_case (bool, default=False) – If True, string matching is case-insensitive (e.g., ‘fz’ matches ‘Fz’).
fuzzy (bool, default=False) – If True, uses difflib to find closest matches for string queries (e.g., ‘Alpha’ matches ‘alpha’). Useful for handling typos.
**selections (dict) –
Key is the dimension name (or special keys ‘y’, ‘ids’). Value is the query. Supported query types:
List/Array (Exact or Wildcard): Matches values present in the list. Strings can use shell-style wildcards (‘*’, ‘?’).
Dictionary (Operator Queries): Filters numerical or string values using operators. Keys: ‘>’, ‘<’, ‘>=’, ‘<=’, ‘==’, ‘!=’, ‘in’.
Callable: A function taking the coordinate array and returning a boolean mask.
- Returns:
A new DataContainer instance containing the selected subset.
- Return type:
DataContainer
Examples
>>> # 1. Exact selection (sensors)
>>> sub = container.select(channel=['Fz', 'Cz'])
>>> # 2. Wildcard selection (all alpha features)
>>> sub = container.select(feature='*alpha*')
>>> # 3. Range selection (time)
>>> sub = container.select(time={'>=': 0.1, '<': 0.5})
>>> # 4. Case-insensitive matching
>>> sub = container.select(channel=['fz'], ignore_case=True)
>>> # 5. Filter by target (y)
>>> sub = container.select(y=['Patient'])
>>> # 6. Complex logic (subjects 1-5 via operator)
>>> sub = container.select(subject_id={'>=': 1, '<=': 5})
>>> # 7. Stratified selection (first 2 epochs per subject via callable)
>>> import numpy as np
>>> def first_n(ids, n=2):
...     # Keep the first n observations of each subject ID
...     seen = {}
...     mask = np.zeros(len(ids), dtype=bool)
...     for i, subject in enumerate(ids):
...         seen[subject] = seen.get(subject, 0) + 1
...         mask[i] = seen[subject] <= n
...     return mask
>>> sub = container.select(ids=first_n)
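Each of the three query types reduces to a boolean mask over one coordinate array. The following is a minimal sketch. query_mask is hypothetical and not the select implementation. It uses fnmatch for shell-style wildcards and the operator module for the comparison keys, and it omits the difflib-based fuzzy matching.

```python
import fnmatch
import operator

import numpy as np

OPS = {">": operator.gt, "<": operator.lt, ">=": operator.ge,
       "<=": operator.le, "==": operator.eq, "!=": operator.ne}


def query_mask(values, query, ignore_case=False):
    """Hypothetical sketch: boolean mask for one select() query."""
    values = np.asarray(values)
    if callable(query):
        return np.asarray(query(values), dtype=bool)
    if isinstance(query, dict):
        # Operator query: every condition must hold (logical AND).
        mask = np.ones(len(values), dtype=bool)
        for op, ref in query.items():
            if op == "in":
                mask &= np.isin(values, list(ref))
            else:
                mask &= OPS[op](values, ref)
        return mask
    # Exact or wildcard match against one pattern or a list of patterns.
    patterns = [query] if isinstance(query, str) else list(query)
    norm = str.lower if ignore_case else (lambda s: s)
    return np.array([
        any(fnmatch.fnmatchcase(norm(str(v)), norm(str(p))) for p in patterns)
        for v in values
    ], dtype=bool)


channels = np.array(["Fz", "Cz", "Pz", "Oz"])
times = np.array([0.0, 0.1, 0.3, 0.5, 0.7])
print(channels[query_mask(channels, ["fz", "Cz"], ignore_case=True)].tolist())  # ['Fz', 'Cz']
print(times[query_mask(times, {">=": 0.1, "<": 0.5})].tolist())  # [0.1, 0.3]
```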
- flatten(preserve: str | List[str] = 'obs') -> DataContainer
Flatten dimensions NOT in preserve into a single ‘feature’ dimension.
This is useful for preparing N-dimensional data for standard 2D machine learning algorithms (scikit-learn). It automatically generates composite feature names (e.g., ‘Fz_0.1s’) for tracking.
- Parameters:
preserve (str or List[str], default='obs') –
Dimensions to keep. All other dimensions are collapsed into a single ‘feature’ dimension.
‘obs’: Result shape (N_obs, N_features). This is the standard 2D layout expected by scikit-learn.
[‘obs’, ‘time’]: Result shape (N_obs, N_time, N_features). Useful for time-resolved decoding.
- Returns:
A new DataContainer with reshaped X and generated ‘feature’ coordinates.
- Return type:
DataContainer
Examples
>>> # Flatten (10, 64, 500) -> (10, 32000)
>>> flat = container.flatten(preserve='obs')
>>> flat.shape
(10, 32000)
>>> flat.coords['feature'][0]
'Fz_0.0'
>>> # Flatten spatial only, keep time: (10, 64, 500) -> (10, 500, 64)
>>> time_resolved = container.flatten(preserve=['obs', 'time'])
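Flattening amounts to three steps: move the preserved axes to the front, reshape the remaining axes into one, and build composite names from the Cartesian product of the collapsed coordinates. In C order the last collapsed axis varies fastest, which yields names like 'Fz_0.0', 'Fz_0.1', and so on. flatten_sketch is hypothetical. Its '_'-joined naming and C-order layout are assumptions chosen to match the example above.

```python
from itertools import product

import numpy as np


def flatten_sketch(X, dims, coords, preserve=("obs",)):
    """Hypothetical sketch: collapse every non-preserved axis into 'feature'."""
    preserve = [preserve] if isinstance(preserve, str) else list(preserve)
    collapsed = [d for d in dims if d not in preserve]
    # Preserved axes first, then the axes to collapse (in their original order).
    moved = np.transpose(X, [dims.index(d) for d in preserve + collapsed])
    flat = moved.reshape(*moved.shape[: len(preserve)], -1)
    # Names follow C order: the last collapsed axis varies fastest.
    labels = [coords.get(d, range(X.shape[dims.index(d)])) for d in collapsed]
    names = ["_".join(str(v) for v in combo) for combo in product(*labels)]
    return flat, tuple(preserve) + ("feature",), names


X = np.random.default_rng(0).normal(size=(10, 3, 4))
coords = {"channel": ["Fz", "Cz", "Pz"], "time": [0.0, 0.1, 0.2, 0.3]}
flat, new_dims, names = flatten_sketch(X, ("obs", "channel", "time"), coords)
print(flat.shape, names[0], names[1])  # (10, 12) Fz_0.0 Fz_0.1
```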
- stack(dims: Sequence[str], new_dim: str = 'obs') -> DataContainer
Stack multiple dimensions into a single new dimension.
This reshapes N-dimensional data into N-K+1 dimensions by combining K specified dimensions into one. It is useful for transforming spatiotemporal data (Trials, Channels, Time) -> (Trials*Time, Channels) for trajectory analysis.
- Parameters:
dims (sequence of str) – Dimensions to stack. The order determines the nesting (slowest to fastest). e.g., (‘obs’, ‘time’) means ‘obs’ changes slowly, ‘time’ cycles fast.
new_dim (str, default='obs') – Name of the resulting stacked dimension.
- Returns:
New container with stacked dimension. Metadata (coords/ids) are expanded/tiled to match the new shape.
- Return type:
DataContainer
Examples
>>> # Stack time into observations:
>>> # (10 obs, 64 ch, 500 time) -> (5000 obs, 64 ch)
>>> stacked = container.stack(dims=('obs', 'time'), new_dim='obs')
>>> stacked.shape
(5000, 64)
- unstack(dim: str) -> DataContainer
Unstack a dimension into multiple dimensions.
Inverse operation of stack. Reshapes the data tensor by splitting one dimension into multiple using metadata stored during the stack operation.
- Parameters:
dim (str) – Dimension to unstack (e.g. ‘obs’).
- Returns:
New container with unstacked dimensions.
- Return type:
DataContainer
- Raises:
ValueError – If the container was not previously stacked (missing metadata).
Examples
>>> # Stack 'trials' and 'time' -> 'obs'
>>> stacked = container.stack(('trials', 'time'), new_dim='obs')
>>> # Unstack 'obs' -> ('trials', 'time') (automatically inferred)
>>> unstacked = stacked.unstack('obs')
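Stacking is a transpose followed by a reshape. Unstacking reverses it, provided the original sizes were recorded. In the following sketch those sizes are kept in a returned info dict, standing in for the metadata that unstack relies on. stack_sketch, unstack_sketch, and the info layout are hypothetical. The sketch also shows how an obs-aligned label such as y is expanded with np.repeat when 'obs' is the slow axis.

```python
import numpy as np


def stack_sketch(X, dims, stack_dims, new_dim="obs"):
    """Hypothetical sketch: merge stack_dims (slowest first) into one axis."""
    rest = [d for d in dims if d not in stack_dims]
    moved = np.transpose(X, [dims.index(d) for d in list(stack_dims) + rest])
    sizes = moved.shape[: len(stack_dims)]
    stacked = moved.reshape(-1, *moved.shape[len(stack_dims):])
    # Record what unstack needs: the original names and sizes, slowest first.
    info = {"dims": tuple(stack_dims), "sizes": sizes}
    return stacked, (new_dim, *rest), info


def unstack_sketch(X, dims, dim, info):
    """Inverse of stack_sketch; the stacked axis is assumed to come first."""
    if dims[0] != dim:
        raise ValueError("this sketch expects the stacked axis first")
    restored = X.reshape(*info["sizes"], *X.shape[1:])
    return restored, (*info["dims"], *dims[1:])


X = np.random.default_rng(1).normal(size=(10, 64, 500))
y = np.arange(10)
stacked, s_dims, info = stack_sketch(X, ("obs", "channel", "time"), ("obs", "time"))
y_stacked = np.repeat(y, info["sizes"][1])  # 'obs' is slow: each label repeats 500 times
print(stacked.shape, s_dims)  # (5000, 64) ('obs', 'channel')
restored, r_dims = unstack_sketch(stacked, s_dims, "obs", info)
print(restored.shape, r_dims)  # (10, 500, 64) ('obs', 'time', 'channel')
```

In this sketch the restored axes come back in stacked order ('obs', 'time', 'channel'), so a final transpose would be needed to recover the original layout.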
- center(dim: str = 'time', inplace: bool = False) -> DataContainer
Remove mean along a specified dimension (Centering/Baseline Correction).
This operation computes the mean along dim (ignoring NaNs) and subtracts it. Commonly used in EEG for baseline correction (subtracting mean of pre-stimulus interval) or centering features before covariance calculation.
- Parameters:
dim (str, default='time') – Dimension name to center over (e.g., ‘time’, ‘channel’, ‘obs’).
inplace (bool, default=False) – If True, modifies X in-place to save memory. Returns self.
- Returns:
Container with centered data.
- Return type:
DataContainer
Examples
>>> # Baseline correction over time (returns a new container)
>>> centered = container.center(dim='time')
- zscore(dim: str = 'time', eps: float = 1e-08, inplace: bool = False) -> DataContainer
Standardize (Z-score) along a specified dimension.
Computes (X - mean) / std along the given dimension. Robust to NaNs. Useful for normalizing features or standardizing temporal dynamics.
- Parameters:
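dim (str, default='time') – Dimension to standardize.
eps (float, default=1e-08) – Stability epsilon to avoid division by zero.
inplace (bool, default=False) – If True, modifies X in-place to save memory. Returns self.
- Return type:
DataContainer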
Examples
>>> # Standardize each channel's timecourse (returns a new container)
>>> standardized = container.zscore(dim='time')
- rms_scale(dim: str = 'time', eps: float = 1e-08, inplace: bool = False) -> DataContainer
Scale by Root Mean Square (RMS) amplitude along a dimension.
Divides data by sqrt(mean(X**2)) along the dimension. Preserves relative shape but normalizes energy.
- Parameters:
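dim (str, default='time') – Dimension to scale.
eps (float, default=1e-08) – Stability epsilon.
inplace (bool, default=False) – If True, modifies X in-place to save memory. Returns self.
- Return type:
DataContainer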
- baseline_correction(dim: str = 'time', inplace: bool = False) -> DataContainer
Alias for center(). Common in EEG.
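center, zscore, and rms_scale all reduce to NaN-aware reductions along one axis. Using keepdims=True lets the result broadcast back over X. The following is a minimal numpy sketch of the formulas given in those docstrings, applied to plain arrays with hypothetical function names. Adding eps to the denominator is an assumption. The library may instead add it to the variance or use it as a floor.

```python
import numpy as np


def center(X, axis):
    # Subtract the NaN-ignoring mean along `axis`.
    return X - np.nanmean(X, axis=axis, keepdims=True)


def zscore(X, axis, eps=1e-8):
    # (X - mean) / std along `axis`; eps guards against constant signals.
    std = np.nanstd(X, axis=axis, keepdims=True)
    return center(X, axis) / (std + eps)


def rms_scale(X, axis, eps=1e-8):
    # Divide by sqrt(mean(X**2)) along `axis`, preserving the waveform shape.
    rms = np.sqrt(np.nanmean(X ** 2, axis=axis, keepdims=True))
    return X / (rms + eps)


X = np.random.default_rng(2).normal(loc=5.0, size=(4, 3, 200))  # obs, channel, time
time_axis = 2
Z = zscore(X, time_axis)
print(np.allclose(Z.mean(axis=time_axis), 0), np.allclose(Z.std(axis=time_axis), 1))  # True True
```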
- aggregate(by: str | numpy.ndarray | List[Any], stats: str | Sequence[str] = 'mean', min_count: int = 1, on_insufficient: str = 'raise') -> DataContainer
Aggregate observations into grouped summaries along the obs axis.
- Parameters:
by (str or array-like) –
Group definition for the observation axis.
If str: resolve the key from self.coords, or from self.y when by == "y".
If array-like: explicit group labels aligned with obs.
stats (str or sequence of str, default="mean") – Aggregation statistic or ordered list of statistics. Supported tokens are "mean", "median", "std", "var", "sem", "mad", "iqr", "min", "max", "count", and "first". Legacy "obs-*" aliases are accepted and normalized.
min_count (int, default=1) – Minimum number of valid observations required per group. A valid observation is one with at least one finite value across the non-observation axes.
on_insufficient ({"raise", "warn", "collect"}, default="raise") – Policy applied when a group has fewer than min_count valid observations.
- Returns:
Aggregated container with grouped observations on the obs axis. When multiple stats are requested, a stat dimension is inserted immediately after obs.
- Return type:
DataContainer
- Raises:
ValueError – If the container has no obs dimension, grouping is invalid, requested stats are unsupported, or min_count/on_insufficient are invalid.
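The following stripped-down sketch shows the grouping logic with a single NaN-aware "mean" statistic, the min_count validity rule stated above, and only the "raise" policy. aggregate_mean is hypothetical. Returning groups in order of first appearance is an assumption, and the "warn" and "collect" policies are not modelled.

```python
import numpy as np


def aggregate_mean(X, groups, min_count=1):
    """Hypothetical sketch: per-group NaN-aware mean along axis 0 ('obs')."""
    X = np.asarray(X, dtype=float)
    groups = np.asarray(groups)
    # An observation is valid if it has at least one finite value elsewhere.
    valid = np.isfinite(X.reshape(len(X), -1)).any(axis=1)
    labels = list(dict.fromkeys(groups.tolist()))  # first-appearance order
    rows = []
    for label in labels:
        members = (groups == label) & valid
        if members.sum() < min_count:
            raise ValueError(f"group {label!r} has {members.sum()} valid observations")
        rows.append(np.nanmean(X[members], axis=0))
    return np.stack(rows), labels


X = np.array([[1.0, 2.0], [3.0, 4.0], [np.nan, np.nan], [10.0, 20.0]])
subjects = ["s1", "s1", "s2", "s2"]
means, labels = aggregate_mean(X, subjects)
print(labels, means.tolist())  # ['s1', 's2'] [[2.0, 3.0], [10.0, 20.0]]
```

Here the all-NaN row of s2 does not count toward min_count, so min_count=2 would raise for that group.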
- aggregate_groups(by: str | numpy.ndarray | List[Any], groups: Sequence[Dict[str, Any]], min_count: int = 1, on_insufficient: str = 'raise', skip_empty: bool = True) -> DataContainer
Aggregate selected feature groups with different statistics.
This is a thin wrapper around aggregate() for tabular feature containers. Each group spec selects a subset of feature columns and applies one or more stats to that subset. The outputs are concatenated along the feature dimension, and each resulting feature name is prefixed with its stat (for example "mean_band_log_abs_alpha").
- Parameters:
by (str or array-like) – Group definition for the observation axis. Passed through to aggregate().
groups (sequence of dict) –
Ordered group specifications. Each group must provide "stats" and may optionally provide include/exclude selectors:
names / exclude_names
prefixes / exclude_prefixes
suffixes / exclude_suffixes
contains / exclude_contains
regex / exclude_regex
If a group provides no include selectors, it starts from all features and then applies exclusions.
min_count (int, default=1) – Minimum number of valid observations required per group. Passed through to aggregate().
on_insufficient ({"raise", "warn", "collect"}, default="raise") – Policy applied when a group has fewer than min_count valid observations. Passed through to aggregate().
skip_empty (bool, default=True) – If True, silently skip group specs that match no features. If False, raise a ValueError when a group matches nothing.
- Returns:
Aggregated container with dims ("obs", "feature") and stat-prefixed feature names.
- Return type:
DataContainer
- Raises:
ValueError – If the container lacks a feature dimension or coord, no groups are provided, a group spec is invalid, multiple groups would emit the same output feature name, or no non-empty grouped outputs are produced.
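Each group spec amounts to a feature filter followed by a stat prefix. The sketch below implements the include/exclude semantics described above. With no include selector, selection starts from all features. select_features and the list-valued selector layout are hypothetical. The docstring does not say whether multiple include selectors are combined by union or intersection, so the union used here is an assumption.

```python
import re


def select_features(names, spec):
    """Hypothetical sketch of one aggregate_groups() selector spec."""
    def matches(name, kind, values):
        if kind == "names":
            return name in values
        if kind == "prefixes":
            return any(name.startswith(v) for v in values)
        if kind == "suffixes":
            return any(name.endswith(v) for v in values)
        if kind == "contains":
            return any(v in name for v in values)
        if kind == "regex":
            return any(re.search(v, name) for v in values)
        raise ValueError(f"unknown selector {kind!r}")

    kinds = ["names", "prefixes", "suffixes", "contains", "regex"]
    includes = [k for k in kinds if k in spec]
    # No include selector: start from every feature (assumed union otherwise).
    selected = [n for n in names
                if not includes or any(matches(n, k, spec[k]) for k in includes)]
    for k in kinds:
        if f"exclude_{k}" in spec:
            selected = [n for n in selected if not matches(n, k, spec[f"exclude_{k}"])]
    return selected


features = ["band_log_abs_alpha", "band_log_abs_beta", "age", "sex"]
groups = [
    {"stats": ["mean"], "prefixes": ["band_"], "exclude_suffixes": ["_beta"]},
    {"stats": ["first"], "names": ["age", "sex"]},
]
out = [f"{stat}_{n}" for g in groups for n in select_features(features, g)
       for stat in g["stats"]]
print(out)  # ['mean_band_log_abs_alpha', 'first_age', 'first_sex']
```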
- class coco_pipe.io.SklearnWrapper(transformer: sklearn.base.BaseEstimator)
Bases: sklearn.base.BaseEstimator, sklearn.base.TransformerMixin
Generic wrapper for ANY scikit-learn transformer (Scaler, PCA, etc.).
This wrapper applies a standard scikit-learn transformer to the .X data matrix of a DataContainer, ensuring that the resulting container has correctly updated data while checking for dimension compatibility.
- Parameters:
transformer (BaseEstimator) – An instantiated scikit-learn transformer (e.g., StandardScaler(), PCA(n_components=10)).
- estimator_
The fitted scikit-learn estimator.
- Type:
BaseEstimator
Examples
>>> from sklearn.preprocessing import RobustScaler
>>> from coco_pipe.io import DataContainer, SklearnWrapper
>>> import numpy as np
>>> # Create formatted data (100 obs, 10 features)
>>> X = np.random.randn(100, 10)
>>> container = DataContainer(X, dims=('obs', 'feature'))
>>> # Wrap a scaler
>>> scaler = SklearnWrapper(RobustScaler())
>>> scaled_container = scaler.fit_transform(container)
>>> # Metadata is preserved
>>> scaled_container.dims == container.dims
True
- transformer
- estimator_ = None
- fit(container: coco_pipe.io.structures.DataContainer, y=None)
- transform(container: coco_pipe.io.structures.DataContainer) -> coco_pipe.io.structures.DataContainer
- fit_transform(container: coco_pipe.io.structures.DataContainer, y=None)
Fit to data, then transform it.
Fits transformer to X and y with optional parameters fit_params and returns a transformed version of X.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Input samples.
y (array-like of shape (n_samples,) or (n_samples, n_outputs), default=None) – Target values (None for unsupervised transformations).
**fit_params (dict) – Additional fit parameters.
- Returns:
X_new – Transformed array.
- Return type:
ndarray array of shape (n_samples, n_features_new)
- inverse_transform(container: coco_pipe.io.structures.DataContainer) -> coco_pipe.io.structures.DataContainer
- class coco_pipe.io.SpatialWhitener(method: str = 'pca', n_components: int | float | None = None)
Bases: sklearn.base.BaseEstimator, sklearn.base.TransformerMixin
M/EEG Spatial Whitening using Covariance Decorrelation.
This transformer removes spatial correlations between channels, effectively transforming the noise covariance matrix towards the identity matrix. It supports standard PCA, ZCA (Zero-phase Component Analysis, which preserves topography), and robust shrinkage covariance estimation (OAS).
It requires a dimension named ‘channel’ in the input DataContainer. The operation is performed spatially: \(X_{white} = X \cdot W^T\)
- Parameters:
method ({'pca', 'zca', 'shrinkage'}, default='pca') –
Type of whitening transformation:
‘pca’: Principal Component Analysis. Rotates data to principal axes and scales to unit variance.
‘zca’: Zero-phase Component Analysis. Rotates, scales, and rotates back. Preserves spatial topography (sensors stay in place).
‘shrinkage’: Uses Oracle Approximating Shrinkage (OAS) for robust covariance estimation in high dimensions.
n_components (int or float, optional) – Number of components to keep (only for the ‘pca’/‘zca’ methods). If None, all components are kept.
- whitener_
The estimated whitening matrix (W). Shape (n_components, n_channels).
- Type:
np.ndarray
- mean_
Per-channel mean vector.
- Type:
np.ndarray
- inverse_whitener_
The inverse matrix used to project back to sensor space.
- Type:
np.ndarray
Examples
>>> # Whitening EEG epochs (100 epochs, 64 channels, 500 times)
>>> container = DataContainer(
...     np.random.randn(100, 64, 500), dims=('obs', 'channel', 'time')
... )
>>> # Use shrinkage for robust covariance
>>> whitener = SpatialWhitener(method='shrinkage')
>>> white_data = whitener.fit_transform(container)
>>> # Project back to sensor space for plotting
>>> sensor_data = whitener.inverse_transform(white_data)
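For the 'pca' and 'zca' methods, the whitening matrix follows from the eigendecomposition of the channel covariance, C = U diag(λ) Uᵀ. PCA uses W = diag(λ)^(-1/2) Uᵀ, and ZCA rotates back with W = U diag(λ)^(-1/2) Uᵀ, so W C Wᵀ = I in both cases. The following numpy sketch estimates C by pooling every time sample of every epoch, which is an assumption about how SpatialWhitener builds its covariance. It applies X_white = (X - mean) · Wᵀ along the channel axis and projects back with W⁻¹. The function names are hypothetical. The OAS shrinkage variant (available in scikit-learn as sklearn.covariance.OAS) and n_components truncation are not shown.

```python
import numpy as np


def fit_whitener(X, method="pca", eps=1e-12):
    """Hypothetical sketch: (channels x channels) whitening matrix for epochs."""
    n_ch = X.shape[1]
    # Pool every time sample of every epoch: rows = samples, columns = channels.
    samples = X.transpose(0, 2, 1).reshape(-1, n_ch)
    mean = samples.mean(axis=0)
    cov = np.cov(samples - mean, rowvar=False)
    eigvals, eigvecs = np.linalg.eigh(cov)
    W = np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T  # PCA: rotate, then rescale
    if method == "zca":
        W = eigvecs @ W  # ZCA: rotate back so each output stays tied to a sensor
    return W, mean


def apply_whitener(X, W, mean):
    # X_white = (X - mean) . W^T, applied to every (obs, time) sample.
    return np.einsum("kc,oct->okt", W, X - mean[None, :, None])


rng = np.random.default_rng(3)
mixing = rng.normal(size=(8, 8))
X = np.einsum("ij,ojt->oit", mixing, rng.normal(size=(50, 8, 200)))  # correlated channels
W, mean = fit_whitener(X, method="zca")
white = apply_whitener(X, W, mean)
cov_white = np.cov(white.transpose(0, 2, 1).reshape(-1, 8), rowvar=False)
print(np.allclose(cov_white, np.eye(8), atol=1e-6))  # True
# Project back to sensor space with the inverse whitener.
X_back = np.einsum("ck,okt->oct", np.linalg.inv(W), white) + mean[None, :, None]
print(np.allclose(X_back, X))  # True
```

The ZCA matrix is symmetric, which is why it leaves each whitened channel aligned with its original sensor, unlike the PCA matrix.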
- method = 'pca'
- n_components = None
- whitener_ = None
- mean_ = None
- inverse_whitener_ = None
- fit(container: coco_pipe.io.structures.DataContainer, y=None)
- transform(container: coco_pipe.io.structures.DataContainer) -> coco_pipe.io.structures.DataContainer
- fit_transform(container: coco_pipe.io.structures.DataContainer, y=None)
Fit to data, then transform it.
Fits transformer to X and y with optional parameters fit_params and returns a transformed version of X.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Input samples.
y (array-like of shape (n_samples,) or (n_samples, n_outputs), default=None) – Target values (None for unsupervised transformations).
**fit_params (dict) – Additional fit parameters.
- Returns:
X_new – Transformed array.
- Return type:
ndarray array of shape (n_samples, n_features_new)
- inverse_transform(container: coco_pipe.io.structures.DataContainer) -> coco_pipe.io.structures.DataContainer
- _apply_linear_op(container: coco_pipe.io.structures.DataContainer, W: numpy.ndarray, mean: numpy.ndarray | None) -> numpy.ndarray