# Source code for coco_pipe.io.dataset

"""
coco_pipe/io/dataset.py
-----------------------
Specialized Dataset classes that produce standardized DataContainer objects.
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd

from .structures import DataContainer
from .utils import (
    _get_bids_path,
    default_id_extractor,
    detect_runs,
    detect_sessions,
    detect_subjects,
    load_participants_tsv,
    read_bids_entry,
    smart_reader,
    split_column,
)

logger = logging.getLogger(__name__)


class BaseDataset(ABC):
    """
    Abstract interface shared by the dataset loaders in this module.

    A concrete subclass only has to implement :meth:`load`, which returns the
    loaded data wrapped in a ``DataContainer``.
    """

    @abstractmethod
    def load(self) -> DataContainer:
        """Read the underlying source and return it as a ``DataContainer``."""
class TabularDataset(BaseDataset):
    """
    Dataset for loading tabular feature data (CSV, TSV, Excel).

    This class handles loading, optional cleaning, and reshaping of 2D tabular
    data into multi-dimensional DataContainers.

    Parameters
    ----------
    path : str or Path
        Path to the tabular file (csv, tsv, txt, xls, xlsx). The extension is
        matched case-insensitively.
    target_col : str, optional
        Name of the column to extract as target `y`. Removed from features `X`.
    index_col : str or int, optional
        Column to use as index (observation IDs).
    sep : str, default='\\t'
        Separator for text files. Note that the default is a tab even for
        ``.csv`` files; pass ``sep=","`` for comma-separated data.
    header : int or list of int, default=0
        Row number(s) to use as column names.
    sheet_name : str or int, default=0
        Sheet name or index for Excel files.
    columns_to_dims : list of str, optional
        If provided, attempts to reshape the 2D feature columns into N-D
        dimensions. Columns must follow the naming convention:
        `dim1_dim2_..._feature`.
    col_sep : str, default='_'
        Separator used in column names for reshaping.
    meta_columns : list of str, optional
        List of columns to extract as metadata coordinates instead of features.
    clean : bool, default=False
        Whether to perform automated cleaning (drop NaNs/Infs).
    clean_kwargs : dict, optional
        Arguments passed to `TabularDataset.clean`.
    select_kwargs : dict, optional
        Arguments for feature selection (stored but not used by `load`).

    Examples
    --------
    >>> # Load a simple CSV
    >>> ds = TabularDataset("data.csv", target_col="label", sep=",")
    >>> container = ds.load()

    >>> # Load and reshape wide data (e.g. time series in columns)
    >>> # Columns: T0_F1, T0_F2, T1_F1... -> dims=('time', 'freq')
    >>> ds = TabularDataset("wide.csv", columns_to_dims=['time', 'freq'],
    ...                     col_sep='_', sep=",")
    """

    def __init__(
        self,
        path: Union[str, Path],
        target_col: Optional[str] = None,
        index_col: Optional[Union[str, int]] = None,
        sep: str = "\t",
        header: Optional[Union[int, List[int]]] = 0,
        sheet_name: Optional[Union[str, int]] = 0,
        columns_to_dims: Optional[List[str]] = None,
        col_sep: str = "_",
        meta_columns: Optional[List[str]] = None,
        clean: bool = False,
        clean_kwargs: Optional[Dict[str, Any]] = None,
        select_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.path = Path(path)
        self.target_col = target_col
        self.index_col = index_col
        self.sep = sep
        self.header = header
        self.sheet_name = sheet_name
        self.columns_to_dims = columns_to_dims
        self.col_sep = col_sep
        self.meta_columns = meta_columns or []
        self.do_clean = clean
        self.clean_kwargs = clean_kwargs or {}
        self.select_kwargs = select_kwargs or {}
        # When True, columns that do not fit the reshaping pattern are reported
        # at ERROR level (they are still skipped, never raised on).
        self.strict_reshaping = True

    def load(self) -> DataContainer:
        """
        Read the file and assemble a ``DataContainer``.

        Returns
        -------
        DataContainer
            ``X`` is 2D ``(obs, feature)`` by default, or N-D
            ``(obs, *columns_to_dims)`` when `columns_to_dims` is set. ``y`` is
            the target column (or None), ``coords`` holds observation IDs,
            metadata columns and per-dimension labels, and ``meta`` carries
            the file name and the cleaning report (None if not cleaned).

        Raises
        ------
        FileNotFoundError
            If `path` does not exist.
        ValueError
            If the extension is unsupported, or reshaping is requested but the
            columns do not form a complete, duplicate-free grid.
        """
        if not self.path.exists():
            raise FileNotFoundError(f"File not found: {self.path}")

        # 1. Load DataFrame
        df = self._read_frame()

        # 2. Dtype check, then split off the target column
        self._warn_non_numeric(df)
        if self.target_col and self.target_col in df.columns:
            y = df[self.target_col].values
            X_df = df.drop(columns=[self.target_col])
        else:
            y = None
            X_df = df

        # 3. Handle Demographics / Meta Columns
        covariates = {}
        found_meta = [c for c in self.meta_columns if c in X_df.columns]
        if found_meta:
            for col in found_meta:
                covariates[col] = X_df[col].values
            X_df = X_df.drop(columns=found_meta)

        # 4. Cleaning (drops columns only, so rows stay aligned with y / ids)
        if self.do_clean:
            X_df, report = self.clean(X_df, **self.clean_kwargs)
        else:
            report = None

        ids = df.index.astype(str).values if self.index_col is not None else None

        # 5. Reshaping Logic (2D -> ND)
        coords = {}
        if ids is not None:
            coords["obs"] = np.array(ids)
        # Add covariates to coords (Auxiliary Coords)
        coords.update(covariates)

        if self.columns_to_dims:
            X_final, level_coords = self._reshape_to_nd(X_df)
            coords.update(level_coords)
            dims = tuple(["obs"] + list(self.columns_to_dims))
        else:
            # Default 2D
            X_final = X_df.values
            coords["feature"] = np.array(X_df.columns.tolist())
            dims = ("obs", "feature")

        return DataContainer(
            X=X_final,
            y=y,
            ids=np.array(ids) if ids is not None else None,
            dims=dims,
            coords=coords,
            meta={"filename": self.path.name, "cleaning_report": report},
        )

    def _read_frame(self) -> pd.DataFrame:
        """Read `self.path` with the pandas reader matching its extension."""
        suffix = self.path.suffix.lower()
        if suffix in (".csv", ".tsv", ".txt"):
            return pd.read_csv(
                self.path, sep=self.sep, index_col=self.index_col, header=self.header
            )
        if suffix in (".xls", ".xlsx"):
            return pd.read_excel(
                self.path,
                index_col=self.index_col,
                header=self.header,
                sheet_name=self.sheet_name,
            )
        raise ValueError(f"Unsupported file extension: {self.path.suffix}")

    def _warn_non_numeric(self, df: pd.DataFrame) -> None:
        """Warn about non-numeric columns that are neither target nor metadata."""
        non_numeric = df.select_dtypes(include=["object", "string", "category"])
        # The target and the declared metadata columns are expected to be
        # allowed to hold strings/categories, so they are not reported.
        expected = set(self.meta_columns)
        if self.target_col is not None:
            expected.add(self.target_col)
        unexpected = [c for c in non_numeric.columns if c not in expected]
        if unexpected:
            logger.warning(
                f"Tabular data contains non-numeric columns: "
                f"{unexpected}. "
                "These might cause issues during processing if not intended as "
                "metadata."
            )

    def _reshape_to_nd(
        self, X_df: pd.DataFrame
    ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """
        Fold the feature columns into an N-D array using `columns_to_dims`.

        Each column name is split on `col_sep`; names that do not yield exactly
        ``len(columns_to_dims)`` parts are logged and skipped.

        Returns
        -------
        X_final : np.ndarray
            Array of shape ``(n_obs, len(dim1), len(dim2), ...)``.
        level_coords : dict
            Maps every dimension name to its (lexically sorted) labels.
        """
        n_parts = len(self.columns_to_dims)
        parsed_cols = []
        valid_cols = []
        failed_cols = []
        for col in X_df.columns:
            parts = str(col).split(self.col_sep)
            if len(parts) == n_parts:
                parsed_cols.append(tuple(parts))
                valid_cols.append(col)
            else:
                failed_cols.append(col)

        if failed_cols:
            msg = (
                f"{len(failed_cols)} columns failed reshaping pattern "
                f"(sep='{self.col_sep}', expected {n_parts} "
                f"parts). Examples: {failed_cols[:5]}"
            )
            # Report once, at a level that depends on strictness. Raising is
            # intentionally not done so partially matching files still load.
            if self.strict_reshaping:
                logger.error(msg)
            else:
                logger.warning(msg)

        if not parsed_cols:
            raise ValueError(
                f"No columns matched the reshaping pattern with sep="
                f"'{self.col_sep}'."
            )

        # Create MultiIndex to sort and structure
        mi = pd.MultiIndex.from_tuples(parsed_cols, names=self.columns_to_dims)
        if not mi.is_unique:
            # Colliding labels cannot be placed on a grid; fail with a clear
            # message instead of a pandas length-mismatch error further down.
            dupes = mi[mi.duplicated()].unique().tolist()
            raise ValueError(
                f"Reshaping failed. Duplicate column labels after splitting on "
                f"'{self.col_sep}': {dupes[:5]}"
            )

        # Reorder columns to match sorted MultiIndex (Cartesian Product order)
        # This ensures reshape works correctly
        X_subset = X_df[valid_cols].copy()
        X_subset.columns = mi
        X_sorted = X_subset.sort_index(axis=1)  # Sorts lexically by levels

        # Extract Levels to Coords
        level_coords = {}
        for i, dim_name in enumerate(self.columns_to_dims):
            level_coords[dim_name] = X_sorted.columns.unique(level=i).values

        # Verify Shape: unique labels + matching count == full product
        n_obs = X_sorted.shape[0]
        dim_sizes = [len(level_coords[d]) for d in self.columns_to_dims]
        expected_total = np.prod(dim_sizes)
        if X_sorted.shape[1] != expected_total:
            raise ValueError(
                f"Reshaping failed. Found {X_sorted.shape[1]} columns, expected "
                f"full product {expected_total} ({dim_sizes}). Missing "
                f"combinations?"
            )

        # Reshape: (N_obs, Dim1, Dim2, ...)
        new_shape = (n_obs,) + tuple(dim_sizes)
        return X_sorted.values.reshape(new_shape), level_coords

    @staticmethod
    def clean(
        X: pd.DataFrame,
        mode: str = "any",
        sep: str = "_",
        reverse: bool = False,
        verbose: bool = False,
        min_abs_value: Optional[float] = None,
        min_abs_fraction: float = 0.0,
    ) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """
        Remove invalid feature columns containing NaN, ±Inf, and optionally
        very small values.

        Parameters
        ----------
        X : pd.DataFrame
            Feature table; only columns are dropped, rows are untouched.
        mode : {'any', 'sensor_wide'}, default='any'
            'any' drops exactly the offending columns. 'sensor_wide' drops
            every column sharing a feature name (second element returned by
            `split_column`) with an offending column.
        sep : str, default='_'
            Separator forwarded to `split_column` in 'sensor_wide' mode.
        reverse : bool, default=False
            Forwarded to `split_column` in 'sensor_wide' mode.
        verbose : bool, default=False
            If True, log a one-line summary of the result at INFO level.
        min_abs_value : float, optional
            If given, numeric values with ``abs(value) < min_abs_value`` count
            as "tiny".
        min_abs_fraction : float, default=0.0
            With ``<= 0`` a single tiny value flags the column; otherwise the
            column is flagged when the fraction of tiny values is at least
            this number.

        Returns
        -------
        X_clean : pd.DataFrame
            Table without the dropped columns.
        report : dict
            Keys: 'mode', 'dropped_columns', 'dropped_features', 'n_before',
            'n_after'.

        Raises
        ------
        ValueError
            If `mode` is not one of the supported values (only checked when
            `X` has at least one column).
        """
        if X.shape[1] == 0:
            return X.copy(), {
                "dropped_columns": [],
                "dropped_features": [],
                "mode": mode,
                "n_before": 0,
                "n_after": 0,
            }

        # Identify columns with NaN/Inf
        num = X.select_dtypes(include=[np.number])
        other = X.drop(columns=num.columns, errors="ignore")

        bad_cols = []
        if not num.empty:
            arr = num.to_numpy()
            with np.errstate(divide="ignore", invalid="ignore"):
                inf_mask = np.isinf(arr)
            bad_mask = num.isna().to_numpy() | inf_mask
            bad_cols.extend(num.columns[bad_mask.any(axis=0)].tolist())

            if min_abs_value is not None:
                with np.errstate(invalid="ignore"):
                    # NaN compares False, so missing values never count as tiny
                    tiny_mask = np.abs(arr) < float(min_abs_value)
                if min_abs_fraction <= 0.0:
                    tiny_cols = num.columns[tiny_mask.any(axis=0)].tolist()
                else:
                    frac = tiny_mask.mean(axis=0)
                    tiny_cols = num.columns[frac >= min_abs_fraction].tolist()
                bad_cols.extend(tiny_cols)

        if not other.empty:
            # Non-numeric columns are only dropped when entirely missing
            obj_bad = other.isna().all(axis=0)
            bad_cols.extend(other.columns[obj_bad].tolist())

        dropped_columns = []
        dropped_features = []

        if mode == "any":
            dropped_columns = sorted(set(bad_cols))
            X_clean = X.drop(columns=dropped_columns, errors="ignore")
        elif mode == "sensor_wide":
            feature_to_cols = {}
            for col in X.columns:
                _, feat = split_column(col, sep=sep, reverse=reverse)
                feature_to_cols.setdefault(feat, []).append(col)

            bad_features = set()
            for col in bad_cols:
                _, feat = split_column(col, sep=sep, reverse=reverse)
                bad_features.add(feat)

            for feat in sorted(bad_features):
                dropped_columns.extend(feature_to_cols.get(feat, []))
            dropped_columns = sorted(set(dropped_columns))
            dropped_features = sorted(bad_features)
            X_clean = X.drop(columns=dropped_columns, errors="ignore")
        else:
            raise ValueError("mode must be one of {'any','sensor_wide'}")

        report = {
            "mode": mode,
            "dropped_columns": dropped_columns,
            "dropped_features": dropped_features,
            "n_before": X.shape[1],
            "n_after": X_clean.shape[1],
        }
        if verbose:
            logger.info(
                "clean(mode=%s): kept %d of %d columns",
                mode,
                report["n_after"],
                report["n_before"],
            )
        return X_clean, report
class EmbeddingDataset(BaseDataset):
    """
    Generic Dataset for loading embedding files (Pickle, NPY, JSON, H5).

    This class decouples file discovery (via patterns and IDs) from content
    reading. It supports structured formats (e.g., Layers x Features) and
    user-supplied metadata coordinates.

    Parameters
    ----------
    path : str or Path
        Root directory containing the embedding files (searched recursively).
    pattern : str, default='*.pkl'
        Glob pattern to match files (e.g., "*.npy", "sub-*_emb.pkl"). Ignored
        when any of `task`, `run` or `processing` is given.
    dims : tuple of str, default=('obs', 'feature')
        Dimension labels of a single observation, typically ('feature',) or
        ('layer', 'feature'). A leading 'obs' entry is accepted and treated as
        the implicit observation axis, so ('obs', 'feature') and ('feature',)
        are equivalent.
    coords : dict, optional
        Dictionary of coordinates for dimensions. E.g., {'layer': ['L1', 'L2']}.
    reader : callable, optional
        Custom function to read a Path and return an array-like or a dict of
        array-likes. If None, uses `smart_reader`.
    id_fn : callable, optional
        Custom function to extract subject ID from a Path.
        If None, uses `default_id_extractor`.
    task : str, optional
        (Legacy BIDS) Task name to construct search pattern.
    run : str, optional
        (Legacy BIDS) Run name to construct search pattern.
    processing : str, optional
        (Legacy BIDS) Processing label appended to ``embeddings`` in the file
        name.
    subjects : int or list, optional
        If int, keeps the first N matched files in sorted path order (this is
        N subjects only when there is one file per subject). If list, keeps
        the files whose `id_fn` result equals ``str()`` of one of the entries.

    Notes
    -----
    NOTE(review): the default pattern targets ``.pkl`` files; if `smart_reader`
    unpickles them (not visible from this class), only point `path` at trusted
    data, since unpickling can execute arbitrary code.

    Examples
    --------
    >>> # Load loose numpy files
    >>> ds = EmbeddingDataset("./embeddings", pattern="*.npy", dims=('feature',))
    >>> container = ds.load()
    """

    def __init__(
        self,
        path: Union[str, Path],
        pattern: str = "*.pkl",
        dims: Tuple[str, ...] = ("obs", "feature"),
        coords: Optional[Dict[str, Union[List, np.ndarray]]] = None,
        reader: Optional[Any] = None,
        id_fn: Optional[Any] = None,
        task: Optional[str] = None,
        run: Optional[str] = None,
        processing: Optional[str] = None,
        subjects: Optional[Union[int, List[int]]] = None,
    ):
        self.path = Path(path)
        self.subjects = subjects
        self.dims = dims
        self.coords_in = coords or {}

        # 1. Determine Pattern
        if any((task, run, processing)):
            # Legacy BIDS-like construction
            p_parts = ["sub-*"]
            if task:
                p_parts.append(f"task-{task}")
            if run:
                p_parts.append(f"run-{run}")
            p_parts.append(f"embeddings{processing or ''}.pkl")
            self.pattern = "_".join(p_parts)
        else:
            self.pattern = pattern

        # 2. Set Reader
        self.reader = reader if reader else smart_reader
        self.id_fn = id_fn if id_fn else default_id_extractor

    def load(self) -> DataContainer:
        """
        Discover, read and stack all matching embedding files.

        Each file may contain one array or a dict of arrays. An array with as
        many axes as one observation becomes a single observation; an array
        with one extra leading axis is treated as a batch of observations.
        Arrays of any other rank are skipped with a warning, and files whose
        reading raises are skipped with an error log entry.

        Returns
        -------
        DataContainer
            ``X`` stacked along axis 0 with ``dims = ('obs', *item_dims)``,
            generated observation IDs in ``ids`` / ``coords['obs']`` and the
            user-supplied coordinates merged on top.

        Raises
        ------
        FileNotFoundError
            If no file matches the pattern.
        RuntimeError
            If nothing could be loaded from the matched files.
        ValueError
            If the loaded arrays cannot be concatenated.
        """
        # Find files
        files = sorted(self.path.rglob(self.pattern))
        if not files:
            raise FileNotFoundError(
                f"No files matched pattern '{self.pattern}' in {self.path}"
            )

        # Filter by subjects
        if self.subjects is not None:
            if isinstance(self.subjects, int):
                files = files[: self.subjects]
            else:
                target_ids = {str(s) for s in self.subjects}
                files = [f for f in files if self.id_fn(f) in target_ids]

        item_dims = self._item_dims()
        item_ndim = len(item_dims)

        data_list = []
        ids_list = []

        logger.info(f"Loading {len(files)} embedding files...")

        for fpath in files:
            try:
                # Reader returns (N, ...) or Dict
                content = self.reader(fpath)
                sid = self.id_fn(fpath)

                if isinstance(content, dict):
                    # Dict {segment_id: array}
                    for seg_k in sorted(content.keys()):
                        arr = np.array(content[seg_k])  # Ensure array
                        if not self._collect(
                            arr, f"{sid}_{seg_k}", item_ndim, data_list, ids_list
                        ):
                            logger.warning(
                                f"Shape mismatch in {fpath.name} key {seg_k}: "
                                f"{arr.shape} vs dims {item_dims}"
                            )
                else:
                    # Single Array or List
                    arr = np.array(content)
                    if not self._collect(arr, sid, item_ndim, data_list, ids_list):
                        logger.warning(
                            f"Shape mismatch in {fpath.name}: {arr.shape} vs dims "
                            f"{item_dims}"
                        )
            except Exception as e:
                # Best effort: one unreadable file must not abort the dataset
                logger.error(f"Failed to read {fpath.name}: {e}")
                continue

        if not data_list:
            raise RuntimeError("No valid data loaded.")

        try:
            X_final = np.concatenate(data_list, axis=0)
        except ValueError as e:
            shapes = [d.shape for d in data_list[:5]]
            raise ValueError(f"Concatenation failed. Shapes vary? {shapes}") from e

        # Obs Dim + User Dims
        final_dims = ("obs",) + item_dims

        # Build coordinates
        coords = {}
        if ids_list:
            coords["obs"] = np.array(ids_list)

        # Add user-provided coords (e.g. {'layer': ['L1', 'L2']})
        coords.update(self.coords_in)

        return DataContainer(
            X=X_final, dims=final_dims, coords=coords, ids=np.array(ids_list)
        )

    def _item_dims(self) -> Tuple[str, ...]:
        """Return `dims` without a leading 'obs', i.e. the per-observation axes."""
        dims = tuple(self.dims)
        # 'obs' is always prepended by load(); keeping a user-supplied one
        # would produce a duplicated ('obs', 'obs', ...) axis label.
        if dims and dims[0] == "obs":
            return dims[1:]
        return dims

    @staticmethod
    def _collect(
        arr: np.ndarray,
        base_id: Any,
        item_ndim: int,
        data_list: List[np.ndarray],
        ids_list: List[Any],
    ) -> bool:
        """
        Append `arr` (single observation or batch) to the accumulators.

        Returns False, leaving the accumulators untouched, when the rank of
        `arr` matches neither a single observation nor a batch.
        """
        if arr.ndim == item_ndim + 1:
            # Batch: one ID per row, suffixed with the row index
            data_list.append(arr)
            ids_list.extend(f"{base_id}_{i}" for i in range(len(arr)))
            return True
        if arr.ndim == item_ndim:
            # Single observation: add the leading observation axis
            data_list.append(arr[np.newaxis, ...])
            ids_list.append(base_id)
            return True
        return False
class BIDSDataset(BaseDataset):
    """
    Dataset for loading M/EEG data formatted according to the BIDS standard.

    This class supports loading valid BIDS structures, handling multiple
    subjects, sessions, and data types (Raw, Epoched, Evoked). It automatically
    extracts metadata from `participants.tsv` and aligns it with the loaded
    data.

    Parameters
    ----------
    root : str or Path
        The root directory of the BIDS dataset.
    task : str, optional
        The task name (e.g., 'rest', 'audiovisual').
    session : str or List[str], optional
        The session ID(s) to load. If None, detects all available sessions.
    datatype : str, default='eeg'
        The data type to load (e.g., 'eeg', 'meg').
    suffix : str, optional
        The suffix of the files to load.
        - If None, defaults to `datatype`.
        - Use 'epo' to load pre-computed epochs.
        - Use 'ave' to load evoked data.
    mode : str, default='epochs'
        The loading mode:
        - 'epochs': Splices raw continuous data into fixed-length windows.
        - 'continuous': Loads raw data as single continuous segments
          (1 epoch per run).
        - 'load_existing': treated as pre-computed epochs
          (requires `suffix='epo'`).
    target_col : str, optional
        Name of a coordinate (typically a subject metadata column) to use as
        `y`. When set, it replaces any labels returned by the reader.
    window_length : float, optional
        Length of window in seconds for 'epochs' mode.
    stride : float, optional
        Stride between windows in seconds. If None, defaults to `window_length`
        (no overlap).
    subjects : str or List[str], optional
        Specific subject IDs to load (without 'sub-' prefix). If None, detects
        all subjects.
    runs : str or List[str], optional
        Run ID(s) to load. If None, runs are detected per subject and session.
    event_id : dict, str or list of str, optional
        Forwarded unchanged to `read_bids_entry`.
        NOTE(review): presumably selects the events to epoch on - confirm
        against `read_bids_entry`.
    subject_metadata_df : pd.DataFrame, optional
        Extra per-subject metadata, merged on top of `participants.tsv`.
        Requires `subject_key`.
    subject_key : str, optional
        Column of `subject_metadata_df` holding the subject ID (a 'sub-'
        prefix is stripped).
    tmin, tmax : float, default=-0.2, 0.5
        Forwarded unchanged to `read_bids_entry`.
        NOTE(review): presumably the epoch window in seconds - confirm.
    baseline : tuple of (float or None, float or None), optional
        Forwarded unchanged to `read_bids_entry`.

    Examples
    --------
    >>> # Load resting state EEG for all subjects, sliced into 1s windows
    >>> ds = BIDSDataset(root="/data/bids", task="rest", window_length=1.0)
    >>> container = ds.load()
    """

    def __init__(
        self,
        root: Union[str, Path],
        task: Optional[str] = None,
        session: Optional[Union[str, List[str]]] = None,
        datatype: str = "eeg",
        suffix: Optional[str] = None,
        mode: str = "epochs",
        target_col: Optional[str] = None,
        window_length: Optional[float] = None,
        stride: Optional[float] = None,
        subjects: Optional[Union[str, List[str]]] = None,
        runs: Optional[Union[str, List[str]]] = None,
        event_id: Optional[Union[Dict[str, int], str, List[str]]] = None,
        subject_metadata_df: Optional[pd.DataFrame] = None,
        subject_key: Optional[str] = None,
        tmin: float = -0.2,
        tmax: float = 0.5,
        baseline: Optional[Tuple[Optional[float], Optional[float]]] = None,
    ):
        self.root = Path(root)
        self.task = task
        self.session = session
        self.datatype = datatype
        self.suffix = suffix
        self.mode = mode
        self.target_col = target_col
        self.window_length = window_length
        self.stride = stride
        self.subjects = subjects
        self.runs = runs
        self.event_id = event_id
        self.subject_metadata_df = subject_metadata_df
        self.subject_key = subject_key
        self.tmin = tmin
        self.tmax = tmax
        self.baseline = baseline

    def load(self) -> DataContainer:
        """
        Load the BIDS dataset into a DataContainer.

        Entries (subject / session / run) that fail to load are skipped with a
        warning; loading only fails if nothing at all could be read.

        Returns
        -------
        DataContainer
            A container with:
            - X: Data array of shape (N_obs, N_channels, N_time).
            - ids: Unique identifiers for each observation.
            - coords: Dictionary containing 'channel', 'time', 'obs', and
              metadata.
            - dims: ('obs', 'channel', 'time').

        Raises
        ------
        ValueError
            If `subject_metadata_df` is given without a valid `subject_key`,
            if the loaded arrays cannot be concatenated, or if `target_col` is
            not available as a per-observation coordinate.
        RuntimeError
            If no entry could be loaded.
        """
        subjects = self._resolve_subjects()
        meta_lookup = self._build_meta_lookup()

        data_list = []
        ids_list = []
        # Seed the metadata columns from the first known participant
        meta_columns = (
            {k: [] for k in next(iter(meta_lookup.values())).keys()}
            if meta_lookup
            else {}
        )
        labels_list = []

        ch_names = None
        times = None
        sfreq = None

        # Determine Loading Strategy
        # If suffix implies pre-computed data
        is_pre_epoched = bool(
            (self.suffix and "epo" in self.suffix) or self.mode == "load_existing"
        )
        is_evoked = bool(
            (self.suffix and "ave" in self.suffix) or self.datatype == "ave"
        )

        for sub in subjects:
            sessions = self._resolve_sessions(sub)
            sub_meta = meta_lookup.get(sub, {})

            for ses in sessions:
                for run in self._resolve_runs(sub, ses):
                    bids_path = self._locate_entry(sub, ses, run, is_pre_epoched)

                    try:
                        # --- LOAD STRATEGY (Delegated) ---
                        data, current_times, current_ch, current_sfreq, current_y = (
                            read_bids_entry(
                                bids_path,
                                is_pre_epoched=is_pre_epoched,
                                is_evoked=is_evoked,
                                mode=self.mode,
                                window_length=self.window_length,
                                stride=self.stride,
                                event_id=self.event_id,
                                tmin=self.tmin,
                                tmax=self.tmax,
                                baseline=self.baseline,
                            )
                        )

                        # --- CONSISTENCY CHECKS ---
                        if ch_names is None:
                            # First successful entry defines the reference
                            ch_names = current_ch
                            sfreq = current_sfreq
                            times = current_times
                        else:
                            # 1. Channel Consistency
                            if list(current_ch) != list(ch_names):
                                diff = set(current_ch) ^ set(ch_names)
                                logger.warning(
                                    f"Channel mismatch for sub-{sub} ses-{ses}. "
                                    f"Expected {len(ch_names)}, got "
                                    f"{len(current_ch)}. "
                                    f"Differing channels: {list(diff)[:5]}..."
                                )

                            # 2. Time/Length Consistency
                            if len(current_times) != len(times):
                                logger.warning(
                                    f"Time length mismatch for sub-{sub} "
                                    f"ses-{ses}. "
                                    f"Expected {len(times)}, got "
                                    f"{len(current_times)}. "
                                    "This may cause concatenation failure."
                                )
                            elif not np.allclose(current_times, times, atol=1e-5):
                                # Often simple jitter in start times; the time
                                # axis of the first entry is kept regardless.
                                logger.debug(
                                    f"Time axis values differ for sub-{sub} "
                                    f"ses-{ses} run-{run}; using the first "
                                    "entry's time axis."
                                )

                        # --- APPEND DATA ---
                        data_list.append(data)
                        if current_y is not None:
                            labels_list.append(current_y)

                        # --- GENERATE IDs & METADATA ---
                        # data shape is (N_epochs, C, T)
                        n_epochs = data.shape[0]
                        sid_base = f"{sub}"
                        if ses:
                            sid_base += f"_{ses}"
                        if run:
                            sid_base += f"_run-{run}"
                        ids_list.extend(f"{sid_base}_{i}" for i in range(n_epochs))

                        # Repeatedly append subject metadata for each epoch
                        for k, v in sub_meta.items():
                            meta_columns.setdefault(k, []).extend([v] * n_epochs)

                    except Exception as e:
                        # Best effort, but visible: a dataset in which every
                        # entry fails must not do so silently.
                        logger.warning(
                            f"Failed to load subject {sub} session {ses} "
                            f"run {run}: {e}"
                        )
                        continue

        if not data_list:
            raise RuntimeError(f"No valid data found in {self.root}")

        # --- CONCATENATE ---
        try:
            # data_list contains (N_i, C, T)
            X_out = np.concatenate(data_list, axis=0)
            y_out = np.concatenate(labels_list, axis=0) if labels_list else None
        except ValueError as e:
            shapes = [d.shape for d in data_list[:5]]
            raise ValueError(f"Concatenation failed. Shapes vary? {shapes}") from e

        if (
            y_out is not None
            and self.target_col is None
            and len(y_out) != len(ids_list)
        ):
            # Happens when only some entries returned labels
            logger.warning(
                f"Reader labels cover {len(y_out)} of {len(ids_list)} "
                "observations; y is not aligned with X."
            )

        coords = {}
        if ch_names is not None and len(ch_names) > 0:
            coords["channel"] = np.array(ch_names)
        if times is not None:
            coords["time"] = times
        if ids_list:
            coords["obs"] = np.array(ids_list)

        # Add metadata coords (only columns present for every observation)
        for k, v in meta_columns.items():
            if len(v) == len(ids_list):
                coords[k] = np.array(v)
            else:
                logger.warning(
                    f"Dropping metadata column '{k}': {len(v)} values for "
                    f"{len(ids_list)} observations (missing for some subjects?)."
                )

        if self.target_col is not None:
            if self.target_col not in coords:
                raise ValueError(
                    f"target_col '{self.target_col}' not found in BIDS coords."
                )
            if len(coords[self.target_col]) != len(ids_list):
                raise ValueError(
                    f"target_col '{self.target_col}' length "
                    f"{len(coords[self.target_col])} does not match "
                    f"the number of observations {len(ids_list)}."
                )
            y_out = np.array(coords[self.target_col])

        dims = ("obs", "channel", "time")

        return DataContainer(
            X=X_out,
            y=y_out,
            ids=np.array(ids_list),
            dims=dims,
            coords=coords,
            meta={"sfreq": sfreq, "source": str(self.root)},
        )

    def _resolve_subjects(self) -> List[str]:
        """Return the subject IDs to load (auto-detected when unspecified)."""
        if self.subjects is None:
            return sorted(detect_subjects(self.root))
        if isinstance(self.subjects, str):
            return [self.subjects]
        return sorted(self.subjects)

    def _resolve_sessions(self, sub: str) -> List[Optional[str]]:
        """Return the sessions to load for `sub`; ``[None]`` if there are none."""
        if self.session is None:
            return detect_sessions(self.root, sub) or [None]
        if isinstance(self.session, str):
            return [self.session]
        return self.session

    def _resolve_runs(self, sub: str, ses: Optional[str]) -> List[Optional[str]]:
        """Return the runs to load for `sub`/`ses`; ``[None]`` if there are none."""
        if self.runs is None:
            detected = detect_runs(
                self.root, sub, ses, task=self.task, datatype=self.datatype
            )
            return detected or [None]
        if isinstance(self.runs, str):
            return [self.runs]
        return self.runs

    def _build_meta_lookup(self) -> Dict[str, Dict[str, Any]]:
        """Merge `participants.tsv` with the optional `subject_metadata_df`."""
        meta_lookup = load_participants_tsv(self.root) or {}
        if self.subject_metadata_df is None:
            return meta_lookup

        if self.subject_key is None:
            raise ValueError(
                "subject_key must be provided when subject_metadata_df is used."
            )
        if self.subject_key not in self.subject_metadata_df.columns:
            raise ValueError(
                f"subject_key '{self.subject_key}' not found "
                "in subject_metadata_df."
            )
        for _, row in self.subject_metadata_df.iterrows():
            sub = str(row[self.subject_key]).replace("sub-", "")
            # User-supplied values win over participants.tsv on key clashes
            meta_lookup.setdefault(sub, {}).update(row.to_dict())
        return meta_lookup

    def _locate_entry(
        self,
        sub: str,
        ses: Optional[str],
        run: Optional[str],
        is_pre_epoched: bool,
    ) -> Any:
        """
        Build the path object handed to `read_bids_entry` for one entry.

        Raw data uses the path class returned by `_get_bids_path()`.
        Pre-epoched ``.fif`` files are globbed directly and wrapped in a
        namespace exposing the two members used here: ``fpath`` and
        ``match()``.
        """
        if not is_pre_epoched:
            return _get_bids_path()(
                subject=sub,
                session=ses,
                task=self.task,
                run=run,
                datatype=self.datatype,
                root=self.root,
                suffix=self.suffix or self.datatype,
            )

        pre_epoched_dir = self.root / f"sub-{sub}"
        if ses:
            pre_epoched_dir = pre_epoched_dir / f"ses-{ses}"
        pre_epoched_dir = pre_epoched_dir / self.datatype

        stem_parts = [f"sub-{sub}"]
        if ses:
            stem_parts.append(f"ses-{ses}")
        if self.task:
            stem_parts.append(f"task-{self.task}")
        if run:
            stem_parts.append(f"run-{run}")

        pre_epoched_suffix = self.suffix or "epo"
        # Wildcards between entities tolerate extra BIDS key-value pairs
        stem = "*_".join(stem_parts)
        matches = sorted(pre_epoched_dir.glob(f"{stem}*_{pre_epoched_suffix}.fif"))
        # The fallback is a plain file name (no glob characters) so that a
        # "file not found" error downstream shows a real path.
        fallback = pre_epoched_dir / (
            "_".join(stem_parts) + f"_{pre_epoched_suffix}.fif"
        )
        return SimpleNamespace(
            fpath=matches[0] if matches else fallback,
            match=lambda matches=matches: matches,
        )