"""
Data Structures
===============
Standardized containers for passing data between Datasets, Preprocessing, and main
modules.
This module provides the `DataContainer`, an N-dimensional tensor wrapper that manages
metadata, coordinates, and labels alongside the raw data matrix. It serves as the
common currency for the entire pipeline.
Examples
--------
>>> import numpy as np
>>> from coco_pipe.io import DataContainer
# 1. Creating a container for EEG Epochs (N_epochs, N_channels, N_time)
>>> X = np.random.randn(10, 64, 500)
>>> container = DataContainer(
... X=X,
... dims=('obs', 'channel', 'time'),
... coords={
... 'channel': ['Fz', 'Cz', 'Pz'], # ... etc
... 'time': np.linspace(0, 1.0, 500)
... },
... y=np.random.randint(0, 2, 10),
... ids=[f'sub-01_trial-{i}' for i in range(10)]
... )
# 2. Creating a container for simple Tabular Features (N_subjects, N_features)
>>> X_tab = np.random.randn(20, 5)
>>> container_tab = DataContainer(
... X=X_tab,
... dims=('obs', 'feature'),
... coords={'feature': ['age', 'IQ', 'response_time', 'power_alpha', 'power_beta']}
... )
"""
import difflib
import fnmatch
import itertools
import logging
import re
import warnings
from copy import deepcopy
from dataclasses import dataclass, field, replace
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
from .utils import make_strata, sample_indices
logger = logging.getLogger(__name__)
[docs]
@dataclass
class DataContainer:
"""
Generic container for N-dimensional neurophysiological data.
Acts as a lightweight labelled array (like xarray but simpler), managing
dimensions, coordinates, and associated target labels (y) and IDs.
Attributes
----------
X : np.ndarray
The primary data tensor. Shape must match `dims`.
dims : Tuple[str, ...]
Labels for each dimension of X.
Examples: ('obs', 'feature'), ('obs', 'channel', 'time').
Note: The 'obs' dimension is special and typically represents independent
samples.
coords : Dict[str, Union[List, np.ndarray]]
Coordinates/Labels for dimensions. Keys must be in `dims`.
Values must match the length of the corresponding dimension in X.
y : Optional[np.ndarray], optional
Target labels corresponding to the 'obs' dimension.
Used for supervised learning or coloring plots.
ids : Optional[np.ndarray], optional
Identifiers for observations (e.g., subject IDs, trial names).
Should correspond to 'obs' dim in coords if provided.
Kept separate from coords for convenient tracking.
meta : Dict[str, Any]
Arbitrary metadata (sfreq, units, source path, etc).
Examples
--------
Accessing data:
>>> container.X.shape
(10, 64, 500)
Accessing coordinates:
>>> container.coords['channel'][:3]
['Fz', 'Cz', 'Pz']
"""
X: np.ndarray
dims: Tuple[str, ...]
coords: Dict[str, Union[List, np.ndarray, Sequence]] = field(default_factory=dict)
y: Optional[np.ndarray] = None
ids: Optional[np.ndarray] = None
meta: Dict[str, Any] = field(default_factory=dict)
[docs]
def __post_init__(self):
    """
    Check that ``dims`` and ``coords`` are consistent with ``X``.

    Raises
    ------
    ValueError
        If the number of labels in ``dims`` differs from ``X.ndim``.

    Notes
    -----
    A coordinate named after a dimension whose length differs from the
    matching axis of ``X`` is only reported at DEBUG level; it does not
    raise.
    """
    n_axes = self.X.ndim
    n_labels = len(self.dims)
    if n_axes != n_labels:
        raise ValueError(
            f"Shape mismatch: X has {self.X.ndim} dims {self.X.shape}, "
            f"but `dims` has {len(self.dims)} labels {self.dims}."
        )

    # Only coordinates named after a dimension can be compared to an axis.
    for dim in self.coords:
        if dim not in self.dims:
            continue
        labels = self.coords[dim]
        axis = self.dims.index(dim)
        if len(labels) == self.X.shape[axis]:
            continue
        logger.debug(
            f"Coord '{dim}' length ({len(labels)}) does not match "
            f"X dimension {axis} ({self.X.shape[axis]})."
        )
@property
def shape(self) -> Tuple[int, ...]:
    """Shape of the data tensor ``X``."""
    data = self.X
    return data.shape
[docs]
def save(self, path: Union[str, Any]) -> None:
    """
    Serialize this DataContainer to ``path`` with joblib.

    Missing parent directories of ``path`` are created before writing.

    Parameters
    ----------
    path : str or Path
        Destination file path.
    """
    from pathlib import Path

    import joblib

    target = Path(path)
    # Make sure the destination folder exists so the dump cannot fail on it.
    target.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(self, target)
    logger.info(f"DataContainer saved to {target}")
[docs]
@classmethod
def load(cls, path: Union[str, Any]) -> "DataContainer":
    """
    Read a DataContainer from a file written with joblib.

    Parameters
    ----------
    path : str or Path
        Source file path.

    Returns
    -------
    DataContainer
        The deserialized object.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    TypeError
        If the deserialized object is not an instance of ``cls``.

    Notes
    -----
    NOTE(review): joblib persistence is pickle-based according to the joblib
    documentation, so only load files from trusted sources.
    """
    from pathlib import Path

    import joblib

    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"File not found: {source}")

    loaded = joblib.load(source)
    if isinstance(loaded, cls):
        return loaded
    raise TypeError(f"Loaded object is {type(loaded)}, expected {cls.__name__}")
[docs]
def __repr__(self) -> str:
    """Return a compact summary: per-dimension sizes and coordinate names."""
    sizes = " x ".join(
        f"{name}={size}" for name, size in zip(self.dims, self.X.shape)
    )
    coord_names = list(self.coords.keys())
    return f"<DataContainer [{sizes}], coords={coord_names}>"
[docs]
def obs_table(
    self,
    include_ids: bool = False,
    id_col: str = "obs_id",
    include_y: bool = False,
    y_col: str = "y",
    include_obs_coord: bool = False,
) -> pd.DataFrame:
    """
    Return one-dimensional metadata aligned to the observation axis.

    This helper is useful when exporting a row-wise table from a container.
    It only materializes metadata that maps cleanly to one row per
    observation. Coordinates that label another axis of ``X`` (any key that
    is itself a dimension name, such as ``channel``, ``time``, ``feature``
    or ``stat``) are skipped, even when their length happens to equal the
    number of observations.

    Parameters
    ----------
    include_ids : bool, default=False
        If True, include ``self.ids`` as the first column.
    id_col : str, default="obs_id"
        Column name used when exporting ``self.ids``.
    include_y : bool, default=False
        If True, include ``self.y`` as a column when present.
    y_col : str, default="y"
        Column name used when exporting ``self.y``. The column is not added
        if a column with that name already exists.
    include_obs_coord : bool, default=False
        If True, include ``coords["obs"]`` when present and aligned.

    Returns
    -------
    pandas.DataFrame
        DataFrame containing only one-dimensional observation-aligned
        metadata columns, in the order: ids, ``obs`` coordinate, auxiliary
        coordinates (in ``coords`` order), ``y``.

    Raises
    ------
    ValueError
        If the container has no ``obs`` dimension, if ``include_ids`` is
        requested while ``self.ids`` is missing, or if the requested
        ``ids`` / ``y`` are not 1D arrays of the observation length.
    """
    if "obs" not in self.dims:
        raise ValueError("Observation metadata export requires an 'obs' dimension.")

    obs_len = self.X.shape[self.dims.index("obs")]
    data: Dict[str, np.ndarray] = {}

    if include_ids:
        if self.ids is None:
            raise ValueError("`include_ids=True` requires `DataContainer.ids`.")
        ids = np.asarray(self.ids, dtype=object)
        if ids.ndim != 1 or len(ids) != obs_len:
            raise ValueError("`DataContainer.ids` must be 1D and aligned to 'obs'.")
        data[id_col] = ids

    if include_obs_coord and "obs" in self.coords:
        obs_coord = np.asarray(self.coords["obs"], dtype=object)
        if obs_coord.ndim == 1 and len(obs_coord) == obs_len:
            data["obs"] = obs_coord

    for key, values in self.coords.items():
        # Dimension coordinates describe their own axis ('obs' is handled
        # above). A bare length comparison would wrongly export e.g. the
        # 'feature' labels of a square (n x n) container as a row column.
        if key in self.dims:
            continue
        arr = np.asarray(values, dtype=object)
        if arr.ndim == 1 and len(arr) == obs_len:
            data[key] = arr

    if include_y and self.y is not None:
        y = np.asarray(self.y, dtype=object)
        if y.ndim != 1 or len(y) != obs_len:
            raise ValueError("`DataContainer.y` must be 1D and aligned to 'obs'.")
        # Never clobber an existing column that already uses this name.
        if y_col not in data:
            data[y_col] = y

    return pd.DataFrame(data)
[docs]
def isel(self, **indexers) -> "DataContainer":
    """
    Select data by integer indices on specified dimensions.

    This method is the integer-index equivalent of `select`. It operates
    directly on the dimensions of the data tensor `X` and keeps coordinates,
    `y` and `ids` aligned with the selection.

    Parameters
    ----------
    **indexers : dict
        Key: Dimension name (e.g., 'obs', 'channel', 'time').
        Value: Integer indices to select. Can be:

        - List or numpy array of integers: [0, 1, 5]
        - Slice object: slice(0, 10)
        - Single integer (Python ``int`` or NumPy integer scalar): 0.
          The dimension is kept with length 1.

        Note: If you provide a list of indices with repeats (e.g., [0, 0, 1]),
        the output will be oversampled accordingly.
        Unknown dimension names are skipped with a warning.

    Returns
    -------
    DataContainer
        A new DataContainer instance with the sliced data and coordinates.
        If no indexers are given, the container itself is returned.

    Notes
    -----
    Auxiliary coordinates (keys that are not dimension names) are matched to
    the indexed dimension by length only, so they follow any dimension whose
    length they share.

    Examples
    --------
    >>> # Select first 10 observations
    >>> subset = container.isel(obs=slice(0, 10))
    >>> # Select specific channels by index
    >>> subset = container.isel(channel=[0, 5, 12])
    >>> # Select time range by index
    >>> subset = container.isel(time=slice(100, 200))
    >>> # Bootstrap/Resample (Select index 0 five times)
    >>> bootstrap = container.isel(obs=[0, 0, 0, 0, 0])
    """
    if not indexers:
        return self

    def _untouched(sel) -> bool:
        # True when an axis still carries the default "take everything" slice.
        return isinstance(sel, slice) and sel == slice(None)

    slices = [slice(None)] * self.X.ndim
    new_coords = self.coords.copy()

    for dim_name, indices in indexers.items():
        if dim_name not in self.dims:
            logger.warning(
                f"Dimension {dim_name} not in {self.dims}, skipping isel."
            )
            continue
        d_idx = self.dims.index(dim_name)

        # Wrap scalars (including NumPy integer scalars) in a list so the
        # axis is kept; a bare scalar index would drop it and break `dims`.
        if isinstance(indices, (int, np.integer)):
            indices = [int(indices)]
        slices[d_idx] = indices

        # Re-slice the dimension's own coordinate and every auxiliary
        # coordinate whose length matches this axis (heuristic alignment).
        dim_len_old = self.X.shape[d_idx]
        for k, v in self.coords.items():
            is_own_coord = k == dim_name
            is_aux_match = k not in self.dims and len(v) == dim_len_old
            if is_own_coord or is_aux_match:
                new_coords[k] = np.array(v)[indices]

    # Apply the selections one axis at a time (orthogonal indexing) so that
    # several index lists do not broadcast against each other.
    try:
        new_X = self.X
        for axis, sl in enumerate(slices):
            if _untouched(sl):
                continue
            indexer = [slice(None)] * new_X.ndim
            indexer[axis] = sl
            new_X = new_X[tuple(indexer)]
    except Exception as e:
        logger.error(f"Slicing failed with slices {slices}: {e}")
        raise

    new_y = self.y
    new_ids = self.ids
    if "obs" in self.dims:
        obs_sl = slices[self.dims.index("obs")]
        if not _untouched(obs_sl):
            # y / ids may have been supplied as plain lists, which do not
            # support index-list selection; go through an array view.
            if self.y is not None:
                new_y = np.asarray(self.y)[obs_sl]
            if self.ids is not None:
                new_ids = np.asarray(self.ids)[obs_sl]

    return replace(
        self,
        X=new_X,
        y=new_y,
        ids=new_ids,
        coords=new_coords,
        meta=deepcopy(self.meta),
    )
[docs]
def balance(
    self,
    target: str = "y",
    strategy: str = "undersample",
    covariates: Optional[List[str]] = None,
    random_state: int = 42,
    **kwargs,
) -> "DataContainer":
    """
    Balance the dataset classes using undersampling or oversampling.

    This method adjusts the number of observations (rows) in the container
    so that class counts in `target` are equalized. It supports simple
    random sampling and stratified sampling based on covariates. The row
    indices are chosen by the project helper ``sample_indices`` and applied
    with ``self.isel(obs=...)``.

    Parameters
    ----------
    target : str, default='y'
        Name of the target variable.

        - 'y': Uses `self.y`.
        - Any other string: Looks for the variable in `self.coords`.
    strategy : {'undersample', 'oversample', 'auto'}, default='undersample'
        Matched case-insensitively.

        - 'undersample': Downsample majority classes to match the minority
          class count.
        - 'oversample': Upsample minority classes to match the majority
          class.
        - 'auto': Heuristic choice. Uses undersampling when
          ``min_class_count * n_classes >= n_rows / 2``, else oversampling.
    covariates : list of str, optional
        List of covariate names in `self.coords` to preserve distribution of.
        If provided, the balancing is performed *within* strata defined by
        these covariates (strata come from the project helper
        ``make_strata``).
    random_state : int, default=42
        Seed for ``numpy.random.default_rng``.
        Change this value to produce different random subsets (e.g., for
        bagging).
    **kwargs : dict
        Additional arguments forwarded to the helpers:

        - n_bins (int): forwarded to ``make_strata`` (default 5).
        - binning (str): forwarded to ``make_strata`` (default 'quantile').
        - prefer_clean_rows (bool): forwarded to ``sample_indices``
          (default False).

    Returns
    -------
    DataContainer
        A new DataContainer instance produced by ``isel`` on the sampled
        observation indices.

    Raises
    ------
    ValueError
        If ``target='y'`` but the container has no ``y``, if ``target`` is
        neither 'y' nor a key of ``coords``, or if a covariate is not a key
        of ``coords``.

    Examples
    --------
    >>> # 1. Simple Undersampling of 'y'
    >>> balanced = container.balance(strategy='undersample')
    >>> # 2. Balance based on a metadata column 'condition'
    >>> balanced = container.balance(target='condition')
    >>> # 3. Stratified Balancing (Balance 'y' while preserving 'sex' and 'age'
    >>> # ratios)
    >>> balanced = container.balance(target='y', covariates=['sex', 'age'])
    >>> # 4. Iterative Bootstrapping (Different seeds)
    >>> for seed in [1, 2, 3]:
    ...     subset = container.balance(strategy='undersample', random_state=seed)
    ...     # process subset...
    """
    # 1. Build a temporary metadata DataFrame (one row per observation)
    data_dict = {}
    # Target column
    if target == "y":
        if self.y is None:
            raise ValueError("Container has no y data.")
        data_dict["y"] = self.y
    elif target in self.coords:
        data_dict[target] = self.coords[target]
    else:
        raise ValueError(f"Target '{target}' not found in y or coords.")
    # Covariate columns
    if covariates:
        for c in covariates:
            if c not in self.coords:
                raise ValueError(f"Covariate '{c}' not found in coords.")
            data_dict[c] = self.coords[c]
    df_meta = pd.DataFrame(data_dict)
    # The default RangeIndex (0..N-1) doubles as the positional index later
    # passed to `isel` -- presumably `sample_indices` returns index labels of
    # the frame it is given; verify against its implementation.
    # 2. Get indices
    # Defensive check; the target column was inserted above.
    if target not in df_meta.columns:
        raise ValueError(f"Target '{target}' missing")
    strategy = strategy.lower()
    counts = df_meta[target].value_counts()
    min_c, max_c = int(counts.min()), int(counts.max())
    if strategy == "auto":
        # Undersampling keeps min_c rows per class; accept it only when at
        # least half of the rows would survive.
        strategy = (
            "undersample"
            if (min_c * counts.size) >= len(df_meta) / 2
            else "oversample"
        )
    exclude = [target] + (covariates or [])
    # A single generator is shared by every sampling call below, so the
    # order of those calls determines the result for a given seed.
    rng = np.random.default_rng(random_state)
    indices_val = None
    # 2a. Simple case (no covariates): one global per-class target size
    if not covariates:
        size = {
            c: min_c if strategy == "undersample" else max_c for c in counts.index
        }
        # NOTE(review): the 5th positional argument looks like a
        # "sample with replacement" flag (True unless undersampling) --
        # confirm against `sample_indices`.
        indices_val = sample_indices(
            df_meta,
            target,
            size,
            rng,
            strategy != "undersample",
            kwargs.get("prefer_clean_rows", False),
            exclude,
        )
    else:
        # 2b. Covariate balancing (stratified)
        strata_s = make_strata(
            df_meta,
            covariates,
            kwargs.get("n_bins", 5),
            kwargs.get("binning", "quantile"),
        )
        tmp = df_meta.assign(__strata__=strata_s)
        indices_parts = []
        for _, g in tmp.groupby("__strata__"):
            sc = g[target].value_counts()
            if len(sc) <= 1:
                # Cannot balance within a single-class stratum: when
                # undersampling the stratum contributes no rows; otherwise
                # its only class is sampled up to the global majority count.
                if strategy != "undersample":
                    M = int(counts.max())
                    size_map = {c: M for c in sc.index}
                    indices_parts.append(
                        sample_indices(
                            g,
                            target,
                            size_map,
                            rng,
                            True,
                            kwargs.get("prefer_clean_rows", False),
                            exclude,
                        )
                    )
                continue
            # Balance locally: per-class size is the stratum's own min/max
            sz = {
                c: int(sc.min()) if strategy == "undersample" else int(sc.max())
                for c in sc.index
            }
            indices_parts.append(
                sample_indices(
                    g,
                    target,
                    sz,
                    rng,
                    strategy != "undersample",
                    kwargs.get("prefer_clean_rows", False),
                    exclude,
                )
            )
        if not indices_parts:
            # Fallback: no stratum produced indices, balance globally
            sz = {
                c: min_c if strategy == "undersample" else max_c
                for c in counts.index
            }
            indices_val = sample_indices(
                df_meta,
                target,
                sz,
                rng,
                strategy != "undersample",
                kwargs.get("prefer_clean_rows", False),
                exclude,
            )
        else:
            # Concatenate the per-stratum indices and shuffle them with the
            # same generator so strata are not left in contiguous runs.
            combined = pd.concat([pd.Series(i) for i in indices_parts]).sample(
                frac=1.0, random_state=rng
            )
            indices_val = pd.Index(combined.values)
    # 3. Apply indexing
    # `.values` converts the index object to a plain array for `isel`
    # (assumes `sample_indices` returns a pandas Index -- TODO confirm).
    return self.isel(obs=indices_val.values)
[docs]
def select(
    self, ignore_case: bool = False, fuzzy: bool = False, **selections
) -> "DataContainer":
    """
    Select data subsets based on coordinates, ids, or y.

    This method supports exact matching, wildcard matching, operator-based
    filtering, and custom callable filters. Several selections on the same
    axis are intersected.

    Parameters
    ----------
    ignore_case : bool, default=False
        If True, string matching is case-insensitive (e.g., 'fz' matches 'Fz').
    fuzzy : bool, default=False
        If True, uses `difflib` to find closest matches for string queries
        (up to 3 matches per query, cutoff 0.6). Useful for handling typos.
    **selections : dict
        Key is the dimension name, an auxiliary coordinate name, or one of
        the special keys 'y' and 'ids'. Value is the query. Supported query
        types:

        1. **List/Array/Scalar (Exact or Wildcard)**:
           Matches values present in the list. Strings can use shell-style
           wildcards ('*', '?').
        2. **Dictionary (Operator Queries)**:
           Filters values using operators; all conditions must hold.
           Keys: '>', '<', '>=', '<=', '==', '!=', 'in'.
        3. **Callable**:
           A function taking the coordinate array and returning a boolean
           mask (list or array) of the same length.

        Keys that cannot be resolved (unknown name, dimension without
        coordinates, 'y'/'ids' on a container without an 'obs' dimension,
        empty coordinate) are skipped with a warning.

    Returns
    -------
    DataContainer
        A new DataContainer instance containing the selected subset.
        Coordinates that align with no dimension are dropped.

    Raises
    ------
    ValueError
        If a selection matches nothing, if selections on the same axis do
        not overlap, if an unknown operator is used, or if a callable
        returns something other than a list/array of the expected length.

    Notes
    -----
    An auxiliary coordinate (a key of ``coords`` that is not a dimension
    name) is aligned by length: the 'obs' axis is preferred, otherwise the
    first axis of matching length is used. The same rule is used both to
    pick the axis being filtered and to re-slice coordinates afterwards.

    Examples
    --------
    >>> # 1. Exact Selection (Sensors)
    >>> sub = container.select(channel=['Fz', 'Cz'])
    >>> # 2. Wildcard Selection (All Alpha features)
    >>> sub = container.select(feature='*alpha*')
    >>> # 3. Range Selection (Time)
    >>> sub = container.select(time={'>=': 0.1, '<': 0.5})
    >>> # 4. Case-Insensitive Matching
    >>> sub = container.select(channel=['fz'], ignore_case=True)
    >>> # 5. Filter by Target (y)
    >>> sub = container.select(y=['Patient'])
    >>> # 6. Complex Logic (Subjects 1-5 via Operator)
    >>> sub = container.select(subject_id={'>=': 1, '<=': 5})
    >>> # 7. Stratified Selection (First 2 epochs per subject via Callable)
    >>> def first_n(ids, n=2):
    ...     # ... logic ...
    ...     return mask
    >>> sub = container.select(ids=first_n)
    """
    slices = [slice(None)] * self.X.ndim
    obs_dim_idx = self.dims.index("obs") if "obs" in self.dims else -1

    def _untouched(sel) -> bool:
        # True when an axis still carries the default "take everything" slice.
        return isinstance(sel, slice) and sel == slice(None)

    def _aux_axis(labels) -> int:
        # Axis an auxiliary coordinate aligns with (by length), or -1.
        # 'obs' wins ties so selection and re-slicing always agree.
        n_labels = len(labels)
        if obs_dim_idx != -1 and n_labels == self.X.shape[obs_dim_idx]:
            return obs_dim_idx
        for d_i, d_len in enumerate(self.X.shape):
            if n_labels == d_len:
                return d_i
        return -1

    # Operator table for dictionary queries (built once, not per key).
    ops = {
        ">": lambda a, b: a > b,
        "<": lambda a, b: a < b,
        ">=": lambda a, b: a >= b,
        "<=": lambda a, b: a <= b,
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "in": lambda a, b: np.isin(a, b),
    }

    for key, query in selections.items():
        # --- Resolve the array being queried and the axis it filters ---
        if key == "y" and self.y is not None:
            # y / ids may be plain lists; the matching code needs arrays.
            target_arr, axis = np.asarray(self.y), obs_dim_idx
        elif key == "ids" and self.ids is not None:
            target_arr, axis = np.asarray(self.ids), obs_dim_idx
        elif key in self.dims:
            if key not in self.coords:
                logger.warning(
                    f"Selection on dim '{key}' ignored (no coordinates)."
                )
                continue
            target_arr, axis = np.array(self.coords[key]), self.dims.index(key)
        elif key in self.coords:
            target_arr = np.array(self.coords[key])
            axis = _aux_axis(target_arr)
            if axis == -1:
                logger.warning(
                    f"Aux coordinate '{key}' len={len(target_arr)} matches no "
                    f"dimension shape {self.X.shape}. Ignoring selection."
                )
                continue
        else:
            logger.warning(
                f"Selection key '{key}' not found in logical dims, y, ids, or "
                f"aux coords. Ignoring."
            )
            continue

        if axis == -1:
            # y / ids are observation-aligned; without an 'obs' axis there
            # is nothing to filter (never fall through to the last axis).
            logger.warning(
                f"Selection on '{key}' ignored: container has no 'obs' dimension."
            )
            continue
        if len(target_arr) == 0:
            logger.warning(
                f"Target array for '{key}' is empty. Skipping selection."
            )
            continue

        # --- Build the boolean mask for this query ---
        if callable(query):
            mask = query(target_arr)
            if not isinstance(mask, (np.ndarray, list)) or len(mask) != len(
                target_arr
            ):
                raise ValueError(
                    f"Callable query for '{key}' must return boolean array of "
                    f"shape {target_arr.shape}."
                )
            mask = np.array(mask, dtype=bool)
        elif isinstance(query, dict):
            # Operator mode: {'>': 5}; conditions are AND-ed together.
            mask = np.ones(len(target_arr), dtype=bool)
            for op, val in query.items():
                if op not in ops:
                    raise ValueError(
                        f"Unknown operator '{op}'. Supported: {list(ops.keys())}"
                    )
                mask &= ops[op](target_arr, val)
        else:
            # Standard list/value match (exact / wildcard / fuzzy)
            query_arr = np.array(query, ndmin=1)
            if target_arr.dtype.kind in ("U", "S", "O"):
                # String matching is done on str() of every value.
                if ignore_case:
                    target_lookup = [str(x).lower() for x in target_arr.tolist()]
                    query_list_proc = [str(q).lower() for q in query_arr.tolist()]
                else:
                    target_lookup = [str(x) for x in target_arr.tolist()]
                    query_list_proc = [str(q) for q in query_arr.tolist()]

                # 1. Optionally replace each query by its closest targets
                if fuzzy:
                    final_queries = set()
                    for q in query_list_proc:
                        matches = difflib.get_close_matches(
                            q, target_lookup, n=3, cutoff=0.6
                        )
                        final_queries.update(matches)
                        if not matches:
                            logger.warning(f"No fuzzy match found for '{q}'.")
                else:
                    final_queries = set(query_list_proc)

                # 2. Split into shell-style patterns and exact strings
                patterns = [q for q in final_queries if "*" in q or "?" in q]
                exacts = [q for q in final_queries if q not in patterns]

                mask = np.zeros(len(target_arr), dtype=bool)
                if exacts:
                    mask |= np.isin(np.array(target_lookup), exacts)
                for pat in patterns:
                    regex = re.compile(fnmatch.translate(pat))
                    mask |= np.array([bool(regex.match(x)) for x in target_lookup])
            else:
                # Numeric exact match
                mask = np.isin(target_arr, query_arr)

        indices = np.where(mask)[0]
        if len(indices) == 0:
            raise ValueError(
                f"Selection for '{key}' resulted in empty set. Query: {query}"
            )

        # --- Intersect with any earlier selection on the same axis ---
        if _untouched(slices[axis]):
            slices[axis] = indices
        else:
            common = np.intersect1d(slices[axis], indices)
            if len(common) == 0:
                raise ValueError(
                    f"Conflicting selections for axis {axis} ({key}) resulted "
                    f"in empty set."
                )
            slices[axis] = common

    # Apply the selections one axis at a time (orthogonal indexing) to avoid
    # broadcasting several index arrays against each other.
    X_new = self.X
    for axis, sl in enumerate(slices):
        if _untouched(sl):
            continue
        indexer = [slice(None)] * X_new.ndim
        indexer[axis] = sl
        X_new = X_new[tuple(indexer)]

    # Re-slice every coordinate along the axis it aligns with.
    final_coords = {}
    for coord_name, labels in self.coords.items():
        if coord_name in self.dims:
            aligned_dim_idx = self.dims.index(coord_name)
        else:
            aligned_dim_idx = _aux_axis(labels)
        if aligned_dim_idx == -1:
            # Matches no dimension: dropped rather than carried along stale.
            continue
        final_coords[coord_name] = np.array(labels)[slices[aligned_dim_idx]]

    # Keep y / ids aligned when the obs axis was filtered (even indirectly).
    y_new = self.y
    ids_new = self.ids
    if obs_dim_idx != -1 and not _untouched(slices[obs_dim_idx]):
        obs_sl = slices[obs_dim_idx]
        if y_new is not None:
            y_new = np.asarray(y_new)[obs_sl]
        if ids_new is not None:
            ids_new = np.asarray(ids_new)[obs_sl]

    return replace(self, X=X_new, coords=final_coords, y=y_new, ids=ids_new)
[docs]
def flatten(self, preserve: Union[str, List[str]] = "obs") -> "DataContainer":
    """
    Flatten dimensions NOT in `preserve` into a single 'feature' dimension.

    This is useful for preparing N-dimensional data for standard 2D machine
    learning algorithms (scikit-learn). It automatically generates composite
    feature names (e.g., 'Fz_0.1') for tracking.

    Parameters
    ----------
    preserve : str or List[str], default='obs'
        Dimensions to keep, in the order they should appear in the result.
        All other dimensions will be collapsed into a single trailing
        'feature' dimension.

        - 'obs': Result shape (N_obs, N_features). Standard specification.
        - ['obs', 'time']: Result shape (N_obs, N_time, N_features).
          Useful for time-resolved decoding distributions.

    Returns
    -------
    DataContainer
        A new DataContainer with reshaped X and generated 'feature'
        coordinates. If there is nothing to flatten, the container itself
        is returned.

    Raises
    ------
    ValueError
        If a name in `preserve` is not a dimension of the container.

    Notes
    -----
    Feature labels are the "_"-joined Cartesian product of the labels of the
    flattened dimensions. A flattened dimension without coordinates, or whose
    coordinate length does not match its axis, is labelled by position
    (0..n-1) so that the number of labels always equals the number of
    features. Labels are only generated below 200000 features.

    Examples
    --------
    >>> # Flatten (10, 64, 500) -> (10, 32000)
    >>> flat = container.flatten(preserve='obs')
    >>> flat.shape
    (10, 32000)
    >>> flat.coords['feature'][0]
    'Fz_0.0'
    >>> # Flatten spatial only, keep time (10, 64, 500) -> (10, 500, 64)
    >>> time_resolved = container.flatten(preserve=['obs', 'time'])
    """
    if isinstance(preserve, str):
        preserve = [preserve]

    for p in preserve:
        if p not in self.dims:
            raise ValueError(f"Dimension '{p}' not found in {self.dims}.")

    to_flatten = [d for d in self.dims if d not in preserve]
    if not to_flatten:
        return self  # Nothing to do

    keep_axes = [self.dims.index(p) for p in preserve]
    flat_axes = [self.dims.index(d) for d in to_flatten]

    # Move preserved axes to the front, then merge the trailing ones.
    X_trans = np.transpose(self.X, axes=keep_axes + flat_axes)
    preserved_shape = tuple(self.X.shape[a] for a in keep_axes)
    flattened_len = int(np.prod([self.X.shape[a] for a in flat_axes]))
    X_flat = X_trans.reshape(preserved_shape + (flattened_len,))

    new_dims = tuple(preserve) + ("feature",)

    # Coordinates of preserved dimensions are carried over unchanged.
    new_coords = {k: v for k, v in self.coords.items() if k in preserve}
    if "obs" in preserve and "obs" in self.dims:
        # Auxiliary coordinates aligned (by length) with 'obs' stay valid.
        n_obs = self.X.shape[self.dims.index("obs")]
        for k, v in self.coords.items():
            if k not in self.dims and len(v) == n_obs:
                new_coords[k] = v

    # Per-dimension labels used to name the flattened features.
    flat_coords_list = []
    for d, axis in zip(to_flatten, flat_axes):
        size = self.X.shape[axis]
        labels = self.coords.get(d)
        if labels is None or len(labels) != size:
            # Missing or misaligned coordinate: label by position so the
            # product below has exactly `flattened_len` entries.
            labels = np.arange(size)
        flat_coords_list.append(labels)

    # Skip label generation for very wide outputs to bound memory use.
    label_limit = 200000
    if flattened_len < label_limit:
        new_coords["feature"] = [
            "_".join(map(str, combo))
            for combo in itertools.product(*flat_coords_list)
        ]

    return replace(
        self,
        X=X_flat,
        dims=new_dims,
        coords=new_coords,
        meta={**self.meta, "flattened_from": self.dims},
    )
[docs]
def stack(self, dims: Sequence[str], new_dim: str = "obs") -> "DataContainer":
    """
    Stack multiple dimensions into a single new dimension.

    This reshapes N-dimensional data into (N-K+1) dimensions by combining
    the specified dimensions into one leading dimension. It is useful for
    transforming spatiotemporal data (Trials, Channels, Time) ->
    (Trials*Time, Channels) for trajectory analysis.

    Parameters
    ----------
    dims : sequence of str
        Dimensions to stack. The order determines the nesting (slowest to
        fastest), e.g., ('obs', 'time') means 'obs' changes slowly and
        'time' cycles fast.
    new_dim : str, default='obs'
        Name of the resulting stacked dimension. It becomes the first
        dimension of the result, followed by the untouched dimensions in
        their original order.

    Returns
    -------
    DataContainer
        New container with the stacked dimension. ``meta`` records
        ``stacked_from`` and ``stacked_shapes`` for use by `unstack`.

        - If 'obs' is not among `dims`, the observation axis is untouched
          and `y` / `ids` are carried over unchanged.
        - If 'obs' is stacked into a new 'obs' dimension, `y` and the
          obs-aligned auxiliary coordinates are repeated/tiled to the new
          length and `ids` become composite strings joined by "_".
        - If 'obs' is stacked into a dimension with another name, `y` and
          `ids` are dropped (set to None).

        Coordinates of the stacked dimensions are removed.

    Raises
    ------
    ValueError
        If a name in `dims` is not a dimension of the container.

    Examples
    --------
    >>> # Stack time into observations:
    >>> # (10 obs, 64 ch, 500 time) -> (5000 obs, 64 ch)
    >>> stacked = container.stack(dims=('obs', 'time'), new_dim='obs')
    >>> stacked.shape
    (5000, 64)
    """
    for d in dims:
        if d not in self.dims:
            raise ValueError(f"Dimension '{d}' not found in {self.dims}")

    # 1. Permute: stacked axes first (in the requested order), rest after.
    preserved = [d for d in self.dims if d not in dims]
    stack_indices = [self.dims.index(d) for d in dims]
    preserve_indices = [self.dims.index(d) for d in preserved]
    X_trans = np.transpose(self.X, axes=stack_indices + preserve_indices)

    # 2. Reshape: merge the leading stacked axes into one.
    stack_shape = [self.X.shape[i] for i in stack_indices]
    preserved_shape = [self.X.shape[i] for i in preserve_indices]
    X_new = X_trans.reshape((int(np.prod(stack_shape)),) + tuple(preserved_shape))

    # 3. Metadata. Coordinates of stacked dimensions no longer apply.
    new_coords = {k: v for k, v in self.coords.items() if k not in dims}

    if "obs" not in dims:
        # The observation axis is untouched, so labels and ids stay valid.
        new_y, new_ids = self.y, self.ids
    else:
        new_y = None
        new_ids = None
        if new_dim == "obs":
            obs_idx = dims.index("obs")
            n_obs = self.X.shape[self.dims.index("obs")]
            # Each obs value repeats once per combination of the faster
            # (later) dims and the whole block tiles once per combination
            # of the slower (earlier) dims.
            n_repeats = int(
                np.prod(
                    [self.X.shape[self.dims.index(d)] for d in dims[obs_idx + 1 :]]
                )
            )
            n_tiles = int(
                np.prod([self.X.shape[self.dims.index(d)] for d in dims[:obs_idx]])
            )

            if self.y is not None:
                new_y = np.tile(np.repeat(self.y, n_repeats), n_tiles)

            # Obs-aligned auxiliary coordinates expand the same way.
            for k, v in self.coords.items():
                if k not in self.dims and len(v) == n_obs:
                    new_coords[k] = np.tile(
                        np.repeat(np.array(v), n_repeats), n_tiles
                    )

            if self.ids is not None:
                # Composite ids such as "sub-0_0", "sub-0_1": one component
                # per stacked dim, using labels when available, else range.
                idx_components = []
                for d in dims:
                    if d == "obs":
                        idx_components.append(self.ids)
                        continue
                    c = self.coords.get(d)
                    if c is None:
                        c = np.arange(self.X.shape[self.dims.index(d)])
                    idx_components.append(c)
                mi = pd.MultiIndex.from_product(idx_components, names=dims)
                new_ids = (
                    mi.to_frame(index=False).astype(str).agg("_".join, axis=1).values
                )

    return replace(
        self,
        X=X_new,
        dims=(new_dim,) + tuple(preserved),
        ids=new_ids,
        y=new_y,
        coords=new_coords,
        meta={
            **self.meta,
            "stacked_from": dims,
            "stacked_shapes": tuple(stack_shape),
        },
    )
[docs]
def unstack(self, dim: str) -> "DataContainer":
    """
    Unstack a dimension into multiple dimensions.

    Inverse operation of `stack`. Reshapes the data tensor by splitting one
    dimension into multiple using the ``stacked_from`` / ``stacked_shapes``
    entries stored in ``meta`` during the `stack` operation. The restored
    dimensions are inserted at the position of `dim`.

    Parameters
    ----------
    dim : str
        Dimension to unstack (e.g. 'obs').

    Returns
    -------
    DataContainer
        New container with unstacked dimensions. The coordinate of `dim` is
        removed, and `y` / `ids` are dropped (set to None) when their length
        equals the length of the unstacked dimension. ``meta`` gains an
        ``unstacked_from`` entry.

    Raises
    ------
    ValueError
        If `dim` is not a dimension, if the stacking metadata is missing,
        or if the product of the stored shapes differs from the current
        length of `dim`.

    Examples
    --------
    >>> # Stack 'trials' and 'time' -> 'obs'
    >>> stacked = container.stack(('trials', 'time'), new_dim='obs')
    >>> # Unstack 'obs' -> ('trials', 'time') (automatically inferred)
    >>> unstacked = stacked.unstack('obs')
    """
    if dim not in self.dims:
        raise ValueError(f"Dimension '{dim}' not found in {self.dims}")

    # Strict metadata check: the split cannot be inferred without it.
    if "stacked_from" not in self.meta or "stacked_shapes" not in self.meta:
        raise ValueError(
            "Cannot unstack: Metadata 'stacked_from' or 'stacked_shapes' not found. "
            "Ensure data was processed with .stack() or metadata is preserved."
        )

    new_dims = self.meta["stacked_from"]
    new_shapes = self.meta["stacked_shapes"]
    dim_idx = self.dims.index(dim)
    current_len = self.X.shape[dim_idx]
    target_len = int(np.prod(new_shapes))
    if target_len != current_len:
        raise ValueError(
            f"Shape mismatch: {dim} has length {current_len}, "
            f"but product of new_shapes {new_shapes} is {target_len}"
        )

    # 1. Reshape: bring `dim` to the front and split it into the new axes,
    #    giving (new_d1, ..., new_dk, other_dims...).
    X_moved = np.moveaxis(self.X, dim_idx, 0)
    X_reshaped = X_moved.reshape(*new_shapes, *X_moved.shape[1:])

    # Move the k new axes from the front back to the original position:
    # [axes that were before dim] + [new axes] + [axes that were after dim].
    n_new = len(new_dims)
    perm = (
        list(range(n_new, n_new + dim_idx))
        + list(range(n_new))
        + list(range(n_new + dim_idx, X_reshaped.ndim))
    )
    X_final = np.transpose(X_reshaped, perm)

    # 2. Update metadata
    final_dims = []
    for d in self.dims:
        if d == dim:
            final_dims.extend(new_dims)
        else:
            final_dims.append(d)

    new_coords = {name: v for name, v in self.coords.items() if name != dim}

    # y / ids that were aligned with the unstacked dimension no longer map
    # onto a single axis, so they are dropped.
    new_y = self.y if (self.y is None or len(self.y) != current_len) else None
    new_ids = (
        self.ids if (self.ids is None or len(self.ids) != current_len) else None
    )

    return replace(
        self,
        X=X_final,
        dims=tuple(final_dims),
        y=new_y,
        ids=new_ids,
        coords=new_coords,
        meta={**self.meta, "unstacked_from": dim},
    )
[docs]
def center(self, dim: str = "time", inplace: bool = False) -> "DataContainer":
    """
    Remove mean along a specified dimension (Centering/Baseline Correction).

    This operation computes the mean along `dim` (ignoring NaNs) and
    subtracts it. Commonly used in EEG for baseline correction or for
    centering features before covariance calculation.

    Parameters
    ----------
    dim : str, default='time'
        Dimension name to center over (e.g., 'time', 'channel', 'obs').
    inplace : bool, default=False
        If True, modifies X in-place to save memory and returns self.
        In-place centering requires a floating-point X, since the result
        cannot be stored in an integer array. If False, a new container is
        returned and integer data is promoted to floating point.

    Returns
    -------
    DataContainer
        Container with centered data.

    Raises
    ------
    ValueError
        If `dim` is not a dimension of the container.

    Examples
    --------
    >>> # Baseline correction over time
    >>> container.center(dim='time')
    """
    if dim not in self.dims:
        raise ValueError(f"Dimension '{dim}' not found in {self.dims}")

    axis = self.dims.index(dim)
    mean = np.nanmean(self.X, axis=axis, keepdims=True)

    if inplace:
        X = self.X
        X -= mean
        return self

    # Out-of-place subtraction lets NumPy pick a result dtype that can hold
    # the centered values (an in-place update on a copy fails for ints).
    return replace(self, X=self.X - mean)
[docs]
def zscore(
    self, dim: str = "time", eps: float = 1e-8, inplace: bool = False
) -> "DataContainer":
    """
    Standardize (Z-score) along a specified dimension.

    Computes `(X - mean) / (std + eps)` along the given dimension, using
    NaN-aware mean and standard deviation.
    Useful for normalizing features or standardizing temporal dynamics.

    Parameters
    ----------
    dim : str
        Dimension to standardize.
    eps : float
        Stability epsilon added to the standard deviation to avoid
        division by zero.
    inplace : bool
        If True, modify X in place and return self. Integer or boolean
        data cannot store the standardized values, so in that case
        ``self.X`` is rebound to a new float64 array instead.

    Returns
    -------
    DataContainer
        Container with standardized data (``self`` when ``inplace=True``).

    Raises
    ------
    ValueError
        If `dim` is not one of the container's dimensions.

    Examples
    --------
    >>> # Standardize each channel's timecourse
    >>> container.zscore(dim='time')
    """
    if dim not in self.dims:
        raise ValueError(f"Dimension '{dim}' not found in {self.dims}")
    axis = self.dims.index(dim)

    # In-place float arithmetic on an integer/bool buffer raises a casting
    # error, so such data is promoted to float64 up front.
    needs_upcast = self.X.dtype.kind in "biu"
    if needs_upcast:
        X = self.X.astype(np.float64)  # astype always returns a fresh copy
    else:
        X = self.X if inplace else self.X.copy()

    # Both statistics are taken from the un-centered data before mutating X.
    mean = np.nanmean(X, axis=axis, keepdims=True)
    std = np.nanstd(X, axis=axis, keepdims=True)
    X -= mean
    X /= std + eps

    if inplace:
        if needs_upcast:
            # The original buffer could not be reused; publish the new one.
            self.X = X
        return self
    return replace(self, X=X)
[docs]
def rms_scale(
    self, dim: str = "time", eps: float = 1e-8, inplace: bool = False
) -> "DataContainer":
    """
    Scale by Root Mean Square (RMS) amplitude along a dimension.

    Divides data by `sqrt(mean(X**2)) + eps` along the dimension, where the
    mean ignores NaNs. Preserves relative shape but normalizes energy.

    Parameters
    ----------
    dim : str
        Dimension to scale.
    eps : float
        Stability epsilon added to the RMS to avoid division by zero.
    inplace : bool
        If True, modify X in place and return self. Integer or boolean
        data cannot store the scaled values, so in that case ``self.X``
        is rebound to a new float64 array instead.

    Returns
    -------
    DataContainer
        Container with RMS-scaled data (``self`` when ``inplace=True``).

    Raises
    ------
    ValueError
        If `dim` is not one of the container's dimensions.
    """
    if dim not in self.dims:
        raise ValueError(f"Dimension '{dim}' not found in {self.dims}")
    axis = self.dims.index(dim)

    # Promote integer/bool data before squaring: the in-place true division
    # below cannot write floats into such a buffer, and squaring fixed-width
    # integers may overflow.
    needs_upcast = self.X.dtype.kind in "biu"
    if needs_upcast:
        X = self.X.astype(np.float64)  # astype always returns a fresh copy
    else:
        X = self.X if inplace else self.X.copy()

    mean_sq = np.nanmean(X**2, axis=axis, keepdims=True)
    rms = np.sqrt(mean_sq)
    X /= rms + eps

    if inplace:
        if needs_upcast:
            # The original buffer could not be reused; publish the new one.
            self.X = X
        return self
    return replace(self, X=X)
[docs]
def baseline_correction(
    self, dim: str = "time", inplace: bool = False
) -> "DataContainer":
    """
    Subtract the mean along `dim`; EEG-flavoured alias for :meth:`center`.

    Parameters
    ----------
    dim : str, default='time'
        Dimension forwarded to :meth:`center`.
    inplace : bool, default=False
        Forwarded to :meth:`center` unchanged.

    Returns
    -------
    DataContainer
        Whatever :meth:`center` returns for the same arguments.
    """
    centered = self.center(dim=dim, inplace=inplace)
    return centered
[docs]
def aggregate(
    self,
    by: Union[str, np.ndarray, List[Any]],
    stats: Union[str, Sequence[str]] = "mean",
    min_count: int = 1,
    on_insufficient: str = "raise",
) -> "DataContainer":
    """
    Aggregate observations into grouped summaries along the ``obs`` axis.

    Parameters
    ----------
    by : str or array-like
        Group definition for the observation axis.

        - If str: resolve the key from ``self.coords`` or from ``self.y``
          when ``by == "y"``.
        - If array-like: explicit group labels aligned with ``obs``.
    stats : str or sequence of str, default="mean"
        Aggregation statistic or ordered list of statistics. Supported
        tokens are ``"mean"``, ``"median"``, ``"std"``, ``"var"``,
        ``"sem"``, ``"mad"``, ``"iqr"``, ``"min"``, ``"max"``,
        ``"count"``, and ``"first"``. Legacy ``"obs-*"`` aliases are
        accepted and normalized. All statistics except ``"first"`` ignore
        non-finite entries; ``"sem"`` is ``nanstd / sqrt(count)`` using
        NumPy's default (population) standard deviation.
    min_count : int, default=1
        Minimum number of valid observations required per group. A valid
        observation is one with at least one finite value across the
        non-observation axes.
    on_insufficient : {"raise", "warn", "collect"}, default="raise"
        Policy applied when a group has fewer than ``min_count`` valid
        observations. ``"raise"`` raises ``ValueError``; ``"warn"`` emits
        a warning; both ``"warn"`` and ``"collect"`` fill the group's
        output with NaN and record the failure in
        ``meta["aggregate_failures"]``.

    Returns
    -------
    DataContainer
        Aggregated container with grouped observations on the ``obs`` axis.
        When multiple stats are requested, a ``stat`` dimension is inserted
        immediately after ``obs``. Groups appear in order of first
        occurrence; ``ids`` and ``coords["obs"]`` hold the group labels and
        ``coords["epoch_count"]`` the number of rows per group.
        ``meta["aggregate_failures"]`` is present only if this call
        recorded failures.

    Raises
    ------
    ValueError
        If the container has no ``obs`` dimension, grouping is invalid,
        requested stats are unsupported, or ``min_count`` /
        ``on_insufficient`` are invalid.
    """
    if "obs" not in self.dims:
        raise ValueError("Aggregation requires 'obs' dimension.")
    obs_idx = self.dims.index("obs")
    n_obs = self.X.shape[obs_idx]
    if min_count < 1:
        raise ValueError("`min_count` must be at least 1.")
    if on_insufficient not in {"raise", "warn", "collect"}:
        raise ValueError("`on_insufficient` must be one of: raise, warn, collect.")

    # --- Normalize and validate the requested statistics -----------------
    stat_aliases = {
        "obs-mean": "mean",
        "obs-median": "median",
        "obs-std": "std",
        "obs-var": "var",
        "obs-sem": "sem",
        "obs-mad": "mad",
        "obs-iqr": "iqr",
        "obs-min": "min",
        "obs-max": "max",
        "obs-count": "count",
    }
    supported_stats = {
        "mean",
        "median",
        "std",
        "var",
        "sem",
        "mad",
        "iqr",
        "min",
        "max",
        "count",
        "first",
    }
    if isinstance(stats, str):
        stats_out = [stat_aliases.get(stats, stats)]
    else:
        stats_out = [stat_aliases.get(str(stat), str(stat)) for stat in stats]
    if not stats_out:
        raise ValueError("`stats` must not be empty.")
    invalid_stats = sorted(set(stats_out) - supported_stats)
    if invalid_stats:
        raise ValueError(
            f"Unknown stats: {invalid_stats}. Supported stats are: "
            f"{sorted(supported_stats)}"
        )

    # --- Resolve the per-observation group labels -------------------------
    if isinstance(by, str):
        if by == "y" and self.y is not None:
            groups_raw = self.y
        elif by in self.coords:
            groups_raw = self.coords[by]
        else:
            raise ValueError(f"Grouping key '{by}' not found in coords or y.")
    else:
        groups_raw = by
    # Fill an object array element-wise so labels are kept as-is rather
    # than being coerced or broadcast by NumPy.
    labels_list = list(groups_raw)
    groups = np.empty(len(labels_list), dtype=object)
    groups[:] = labels_list
    if len(groups) != n_obs:
        raise ValueError(
            f"Grouping array length {len(groups)} must match obs length {n_obs}."
        )

    # Work with ``obs`` as the leading axis; it is moved back at the end.
    if obs_idx != 0:
        X_moved = np.moveaxis(self.X, obs_idx, 0)
    else:
        X_moved = self.X
    other_dims = tuple(dim for dim in self.dims if dim != "obs")
    rest_shape = X_moved.shape[1:]

    # Map each label to its row positions, preserving first-seen order.
    group_positions: Dict[Any, List[int]] = {}
    ordered_groups: List[Any] = []
    for obs_position, group_id in enumerate(groups.tolist()):
        if group_id not in group_positions:
            ordered_groups.append(group_id)
            group_positions[group_id] = []
        group_positions[group_id].append(obs_position)

    def _reshape_reduced(values_flat: np.ndarray) -> Union[np.ndarray, np.float64]:
        # Restore the non-obs shape of a flattened per-group reduction.
        if rest_shape:
            return np.asarray(values_flat, dtype=np.float64).reshape(rest_shape)
        return np.asarray(values_flat, dtype=np.float64)[0]

    def _reduce_group(
        group_X: np.ndarray,
        group_X_flat: np.ndarray,
        counts_flat: np.ndarray,
        stat: str,
    ) -> Union[np.ndarray, np.float64]:
        # Reduce one group's rows to a single slice for the given statistic.
        if stat == "count":
            return _reshape_reduced(counts_flat)
        if stat == "first":
            return np.asarray(group_X[0], dtype=np.float64)
        with warnings.catch_warnings():
            # All-NaN columns and zero counts trigger RuntimeWarnings in the
            # nan-aware reducers; those cells are set to NaN explicitly below.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            if stat == "mean":
                values_flat = np.nanmean(group_X_flat, axis=0)
            elif stat == "median":
                values_flat = np.nanmedian(group_X_flat, axis=0)
            elif stat == "std":
                values_flat = np.nanstd(group_X_flat, axis=0)
            elif stat == "var":
                values_flat = np.nanvar(group_X_flat, axis=0)
            elif stat == "sem":
                values_flat = np.nanstd(group_X_flat, axis=0) / np.sqrt(
                    counts_flat.astype(np.float64)
                )
            elif stat == "mad":
                medians_flat = np.nanmedian(group_X_flat, axis=0)
                values_flat = np.nanmedian(
                    np.abs(group_X_flat - medians_flat),
                    axis=0,
                )
            elif stat == "iqr":
                values_flat = np.nanpercentile(
                    group_X_flat,
                    75,
                    axis=0,
                ) - np.nanpercentile(
                    group_X_flat,
                    25,
                    axis=0,
                )
            elif stat == "min":
                values_flat = np.nanmin(group_X_flat, axis=0)
            elif stat == "max":
                values_flat = np.nanmax(group_X_flat, axis=0)
            else:  # pragma: no cover - guarded above
                raise ValueError(f"Unknown stat '{stat}'")
        values_flat = np.asarray(values_flat, dtype=np.float64)
        if counts_flat.size:
            # Cells with no finite contribution are reported as NaN.
            values_flat = np.where(counts_flat == 0, np.nan, values_flat)
        return _reshape_reduced(values_flat)

    def _failure_record(
        group_id: Any,
        group_index: int,
        row_count: int,
        valid_row_count: int,
        message: str,
    ) -> Dict[str, Any]:
        # Structured description of a group that missed ``min_count``.
        return {
            "group_id": group_id,
            "group_index": group_index,
            "row_count": row_count,
            "valid_row_count": valid_row_count,
            "exception_type": "InsufficientObservations",
            "message": message,
        }

    # --- Reduce every group for every statistic ---------------------------
    n_groups = len(ordered_groups)
    reduced_shape = (n_groups, len(stats_out)) + rest_shape
    agg_moved = np.empty(reduced_shape, dtype=np.float64)
    epoch_counts = np.empty(n_groups, dtype=np.int64)
    failures: List[Dict[str, Any]] = []
    for group_index, group_id in enumerate(ordered_groups):
        obs_positions = np.asarray(group_positions[group_id], dtype=int)
        group_X = X_moved[obs_positions]
        row_count = int(obs_positions.size)
        epoch_counts[group_index] = row_count
        if rest_shape:
            group_X_flat = group_X.reshape(row_count, -1)
        else:
            group_X_flat = group_X.reshape(row_count, 1)
        if group_X_flat.shape[1] == 0:
            # No values per row to inspect: treat every row as valid.
            valid_row_count = row_count
        else:
            valid_row_count = int(np.isfinite(group_X_flat).any(axis=1).sum())
        if valid_row_count < min_count:
            message = (
                f"Group {group_id!r} has {valid_row_count} valid rows, "
                f"requires at least {min_count}."
            )
            failure = _failure_record(
                group_id=group_id,
                group_index=group_index,
                row_count=row_count,
                valid_row_count=valid_row_count,
                message=message,
            )
            if on_insufficient == "raise":
                raise ValueError(message)
            if on_insufficient == "warn":
                warnings.warn(message, stacklevel=2)
            failures.append(failure)
            agg_moved[group_index] = np.full((len(stats_out),) + rest_shape, np.nan)
            continue
        counts_flat = np.isfinite(group_X_flat).sum(axis=0, dtype=np.int64)
        for stat_index, stat in enumerate(stats_out):
            agg_moved[group_index, stat_index] = _reduce_group(
                group_X=group_X,
                group_X_flat=group_X_flat,
                counts_flat=counts_flat,
                stat=stat,
            )

    # --- Restore the original axis order (plus ``stat`` if requested) -----
    if len(stats_out) == 1:
        moved_dims = ("obs",) + other_dims
        final_dims = self.dims
        agg_values = agg_moved[:, 0, ...]
    else:
        moved_dims = ("obs", "stat") + other_dims
        final_dims_list: List[str] = []
        for dim in self.dims:
            final_dims_list.append(dim)
            if dim == "obs":
                final_dims_list.append("stat")
        final_dims = tuple(final_dims_list)
        agg_values = agg_moved
    permutation = [moved_dims.index(dim) for dim in final_dims]
    X_agg = np.transpose(agg_values, axes=permutation)

    unique_groups = np.empty(n_groups, dtype=object)
    unique_groups[:] = ordered_groups

    # Targets survive only if every group maps to exactly one y value.
    new_y = None
    if self.y is not None:
        grouped_y: List[Any] = []
        y_consistent = True
        for group_id in ordered_groups:
            values = np.asarray(self.y)[group_positions[group_id]]
            if len(set(values.tolist())) != 1:
                y_consistent = False
                break
            grouped_y.append(values[0])
        if y_consistent:
            new_y = np.asarray(grouped_y)

    # --- Coordinates -------------------------------------------------------
    new_coords = {
        dim: deepcopy(values)
        for dim, values in self.coords.items()
        if dim in self.dims and dim != "obs"
    }
    new_coords["obs"] = unique_groups
    if len(stats_out) > 1:
        new_coords["stat"] = np.asarray(stats_out, dtype=object)
    new_coords["epoch_count"] = epoch_counts
    # Non-dimension coords aligned with ``obs`` are kept when they are
    # constant within every group; otherwise they are dropped.
    for key, values in self.coords.items():
        if key == "obs" or key in self.dims:
            continue
        if len(values) != n_obs:
            continue
        grouped_values: List[Any] = []
        consistent = True
        values_array = np.asarray(values, dtype=object)
        for group_id in ordered_groups:
            group_values = values_array[group_positions[group_id]]
            if len(set(group_values.tolist())) != 1:
                consistent = False
                break
            grouped_values.append(group_values[0])
        if consistent:
            coord_out = np.empty(n_groups, dtype=object)
            coord_out[:] = grouped_values
            new_coords[key] = coord_out

    # --- Metadata ----------------------------------------------------------
    meta = deepcopy(self.meta)
    meta.update(
        {
            "aggregated": True,
            "agg_by": by if isinstance(by, str) else None,
            "agg_stats": list(stats_out),
            "min_count": int(min_count),
        }
    )
    if failures:
        meta["aggregate_failures"] = failures
    else:
        # Failures inherited from an earlier aggregation describe different
        # groups; do not let them masquerade as results of this call.
        meta.pop("aggregate_failures", None)
    return replace(
        self,
        X=X_agg,
        y=new_y,
        dims=final_dims,
        ids=unique_groups,
        coords=new_coords,
        meta=meta,
    )
[docs]
def aggregate_groups(
    self,
    by: Union[str, np.ndarray, List[Any]],
    groups: Sequence[Dict[str, Any]],
    min_count: int = 1,
    on_insufficient: str = "raise",
    skip_empty: bool = True,
) -> "DataContainer":
    """
    Aggregate selected feature groups with different statistics.

    This is a thin wrapper around :meth:`aggregate` for tabular feature
    containers. Each group spec selects a subset of feature columns and
    applies one or more stats to that subset. The outputs are concatenated
    along the ``feature`` dimension, and each resulting feature name is
    prefixed with its stat (for example ``"mean_band_log_abs_alpha"``).

    Parameters
    ----------
    by : str or array-like
        Group definition for the observation axis. Passed through to
        :meth:`aggregate`.
    groups : sequence of dict
        Ordered group specifications. Each group must provide ``"stats"``
        and may optionally provide a ``"name"`` (used in messages and
        failure records) and include/exclude selectors:

        - ``names`` / ``exclude_names``
        - ``prefixes`` / ``exclude_prefixes``
        - ``suffixes`` / ``exclude_suffixes``
        - ``contains`` / ``exclude_contains``
        - ``regex`` / ``exclude_regex``

        Selectors of the same polarity are OR-ed together. If a group
        provides no include selectors, it starts from all features and
        then applies exclusions.
    min_count : int, default=1
        Minimum number of valid observations required per group. Passed
        through to :meth:`aggregate`.
    on_insufficient : {"raise", "warn", "collect"}, default="raise"
        Policy applied when a group has fewer than ``min_count`` valid
        observations. Passed through to :meth:`aggregate`.
    skip_empty : bool, default=True
        If True, silently skip group specs that match no features. If
        False, raise a ``ValueError`` when a group matches nothing.

    Returns
    -------
    DataContainer
        Aggregated container with dims ``("obs", "feature")`` and
        stat-prefixed feature names. ``meta["aggregate_failures"]`` is
        present only when one of the delegated aggregations in this call
        recorded failures.

    Raises
    ------
    ValueError
        If the container lacks a ``feature`` dimension or coord, no groups
        are provided, a group spec is invalid, multiple groups would emit
        the same output feature name, or no non-empty grouped outputs are
        produced.
    """
    if "feature" not in self.dims:
        raise ValueError("aggregate_groups requires a 'feature' dimension.")
    if "feature" not in self.coords:
        raise ValueError("aggregate_groups requires a 'feature' coordinate.")
    if not groups:
        raise ValueError("`groups` must not be empty.")

    feature_names = np.asarray(self.coords["feature"], dtype=object)
    feature_axis = self.dims.index("feature")
    include_keys = ("names", "prefixes", "suffixes", "contains", "regex")
    exclude_keys = tuple(f"exclude_{key}" for key in include_keys)
    allowed_group_keys = {"name", "stats", *include_keys, *exclude_keys}

    def _normalize_patterns(value: Any) -> Tuple[str, ...]:
        # Accept None, a single string, or an iterable of patterns.
        if value is None:
            return ()
        if isinstance(value, str):
            return (value,)
        return tuple(str(item) for item in value)

    def _selector_mask(spec: Dict[str, Any], *, exclude: bool) -> np.ndarray:
        # Boolean mask over features matched by ANY include (or exclude)
        # selector present in ``spec``.
        selector_keys = exclude_keys if exclude else include_keys
        mask = np.zeros(feature_names.size, dtype=bool)
        for key in selector_keys:
            patterns = _normalize_patterns(spec.get(key))
            if not patterns:
                continue
            base_key = key.removeprefix("exclude_")
            if base_key == "names":
                mask |= np.isin(feature_names.astype(str), patterns)
            elif base_key == "prefixes":
                mask |= np.array(
                    [
                        any(str(name).startswith(pattern) for pattern in patterns)
                        for name in feature_names
                    ],
                    dtype=bool,
                )
            elif base_key == "suffixes":
                mask |= np.array(
                    [
                        any(str(name).endswith(pattern) for pattern in patterns)
                        for name in feature_names
                    ],
                    dtype=bool,
                )
            elif base_key == "contains":
                mask |= np.array(
                    [
                        any(pattern in str(name) for pattern in patterns)
                        for name in feature_names
                    ],
                    dtype=bool,
                )
            elif base_key == "regex":
                compiled = [re.compile(pattern) for pattern in patterns]
                mask |= np.array(
                    [
                        any(pattern.search(str(name)) for pattern in compiled)
                        for name in feature_names
                    ],
                    dtype=bool,
                )
        return mask

    combined_parts: List[np.ndarray] = []
    combined_feature_names: List[str] = []
    emitted_stats: List[str] = []
    aggregate_failures: List[Dict[str, Any]] = []
    base_agg: Optional["DataContainer"] = None

    for group_index, group in enumerate(groups):
        # --- Validate the spec --------------------------------------------
        if not isinstance(group, dict):
            raise ValueError("Each entry in `groups` must be a dict.")
        unknown_keys = sorted(set(group) - allowed_group_keys)
        if unknown_keys:
            raise ValueError(
                f"Unknown aggregate_groups keys: {unknown_keys}. "
                f"Supported keys are: {sorted(allowed_group_keys)}"
            )
        if "stats" not in group:
            raise ValueError("Each aggregate_groups spec must include `stats`.")
        stats_spec = group["stats"]
        if isinstance(stats_spec, str):
            stats_out = [stats_spec]
        else:
            stats_out = [str(stat) for stat in stats_spec]
        if not stats_out:
            raise ValueError(
                "Each aggregate_groups spec must include at least one stat."
            )

        # --- Select the feature columns for this spec ---------------------
        include_mask = _selector_mask(group, exclude=False)
        has_include_selectors = any(
            group.get(key) is not None for key in include_keys
        )
        if not has_include_selectors:
            include_mask = np.ones(feature_names.size, dtype=bool)
        exclude_mask = _selector_mask(group, exclude=True)
        selected_mask = include_mask & ~exclude_mask
        feature_indices = np.flatnonzero(selected_mask)
        if feature_indices.size == 0:
            if skip_empty:
                continue
            group_name = group.get("name", f"index {group_index}")
            raise ValueError(
                f"aggregate_groups spec {group_name!r} matched no features."
            )

        subset = self.isel(feature=feature_indices.tolist())
        if "aggregate_failures" in subset.meta:
            # Failures recorded by an earlier aggregation of the source
            # container must not be re-reported as failures of this call.
            subset = replace(
                subset,
                meta={
                    key: value
                    for key, value in subset.meta.items()
                    if key != "aggregate_failures"
                },
            )

        # --- One delegated aggregation per requested stat ------------------
        for stat in stats_out:
            grouped = subset.aggregate(
                by=by,
                stats=stat,
                min_count=min_count,
                on_insufficient=on_insufficient,
            )
            prefixed_names = [
                f"{stat}_{name}"
                for name in np.asarray(grouped.coords["feature"], dtype=object)
            ]
            duplicate_names = sorted(
                set(prefixed_names).intersection(combined_feature_names)
            )
            if duplicate_names:
                raise ValueError(
                    "aggregate_groups would emit duplicate feature names: "
                    f"{duplicate_names}"
                )
            if base_agg is None:
                # The first output supplies dims, ids and coords for the result.
                base_agg = grouped
            else:
                if grouped.dims != base_agg.dims:
                    raise ValueError(
                        "aggregate_groups requires all grouped outputs to "
                        "share the same dimensions."
                    )
                if not np.array_equal(grouped.ids, base_agg.ids):
                    raise ValueError(
                        "aggregate_groups requires all grouped outputs to "
                        "share the same grouped observation ids."
                    )
            combined_parts.append(np.asarray(grouped.X, dtype=np.float64))
            combined_feature_names.extend(prefixed_names)
            emitted_stats.append(stat)

            # Tag each delegated failure with the spec/stat that produced it.
            failures = grouped.meta.get("aggregate_failures", [])
            for failure in failures:
                failure_out = deepcopy(failure)
                failure_out["aggregate_group_index"] = group_index
                if "name" in group:
                    failure_out["aggregate_group_name"] = group["name"]
                failure_out["aggregate_stat"] = stat
                aggregate_failures.append(failure_out)

    if base_agg is None:
        # Reached when every spec was skipped because it matched no features.
        raise ValueError("aggregate_groups produced no aggregated outputs.")

    new_coords = deepcopy(base_agg.coords)
    new_coords["feature"] = np.asarray(combined_feature_names, dtype=object)

    meta = deepcopy(self.meta)
    # Stats in first-use order, without repeats.
    unique_stats = list(dict.fromkeys(emitted_stats))
    meta.update(
        {
            "aggregated": True,
            "agg_by": by if isinstance(by, str) else None,
            "agg_stats": unique_stats,
            "agg_groups": deepcopy(list(groups)),
            "min_count": int(min_count),
        }
    )
    if aggregate_failures:
        meta["aggregate_failures"] = aggregate_failures
    elif "aggregate_failures" in meta:
        del meta["aggregate_failures"]

    X_out = np.concatenate(combined_parts, axis=feature_axis)
    return replace(
        base_agg,
        X=X_out,
        coords=new_coords,
        meta=meta,
    )