Source code for coco_pipe.decoding.registry

"""
Decoding Registry
=================

Central registry for decoding estimators (classifiers, regressors, and FMs).
This allows instantiating models from string names in configuration files,
avoiding circular imports and simplifying the config layer.

Usage
-----
>>> from coco_pipe.decoding.registry import register_estimator, get_estimator_cls
>>>
>>> @register_estimator("MyModel")
... class MyModel: ...
>>>
>>> cls = get_estimator_cls("MyModel")
"""

import importlib
import pkgutil
import warnings
from importlib.metadata import entry_points
from typing import Callable, Dict, Type

# Registry Storage
# Maps string alias -> class object.
# Written by the `register_estimator` decorator and by `get_estimator_cls`,
# which caches every class it resolves through the lazy-loading map.
_ESTIMATOR_REGISTRY: Dict[str, Type] = {}


# Lazy-loading map: estimator alias -> dotted path of the module that is
# imported the first time the alias is requested. The resolver then looks the
# alias up as an attribute of that module, so each key is expected to match
# the class name exported by the module.
# `_discover_entry_points` adds entries from the 'coco_pipe.estimators'
# entry-point group without replacing any alias listed here.
_LAZY_MODULES: Dict[str, str] = {
    # MNE
    "SlidingEstimator": "mne.decoding",
    "GeneralizingEstimator": "mne.decoding",
    # Classifiers
    "LogisticRegression": "sklearn.linear_model",
    "RandomForestClassifier": "sklearn.ensemble",
    "SVC": "sklearn.svm",
    "KNeighborsClassifier": "sklearn.neighbors",
    "GradientBoostingClassifier": "sklearn.ensemble",
    "SGDClassifier": "sklearn.linear_model",
    "MLPClassifier": "sklearn.neural_network",
    "GaussianNB": "sklearn.naive_bayes",
    # NOTE(review): scikit-learn exports this class as
    # `LinearDiscriminantAnalysis`; `sklearn.discriminant_analysis` does not
    # appear to define an attribute named `LDA`, so this alias presumably only
    # resolves if something else registers "LDA" - TODO confirm.
    "LDA": "sklearn.discriminant_analysis",
    "AdaBoostClassifier": "sklearn.ensemble",
    "DummyClassifier": "sklearn.dummy",
    # Regressors
    "LinearRegression": "sklearn.linear_model",
    "Ridge": "sklearn.linear_model",
    "Lasso": "sklearn.linear_model",
    "ElasticNet": "sklearn.linear_model",
    "RandomForestRegressor": "sklearn.ensemble",
    "SVR": "sklearn.svm",
    "ARDRegression": "sklearn.linear_model",
}


def _discover_entry_points() -> None:
    """
    Populate _LAZY_MODULES from 'coco_pipe.estimators' entry points.

    This allows plugins to register estimators without modifying code.
    Only the entry-point name and its raw value string are recorded; the
    entry point itself is not loaded here. Aliases that already exist in
    ``_LAZY_MODULES`` are left untouched.
    """
    group = "coco_pipe.estimators"
    try:
        eps = entry_points(group=group)
    except TypeError:
        # Python < 3.10: entry_points() accepts no selection arguments and
        # returns a plain mapping of group name -> entry points.
        eps = entry_points().get(group, ())

    for ep in eps:
        # setdefault keeps built-in aliases authoritative over plugin ones.
        _LAZY_MODULES.setdefault(ep.name, ep.value)
def _discover_internal_modules() -> None:
    """
    Walk through the 'coco_pipe.decoding' subpackage and import all modules.

    Importing a module executes any ``@register_estimator`` decorators it
    contains, which is how those estimators end up in the registry.

    Modules that fail with ``ImportError`` are skipped so that a missing
    optional dependency does not abort discovery. Each skip is reported as
    an ``ImportWarning``, which Python hides by default; run with
    ``-W default::ImportWarning`` to see which modules were skipped.
    """
    package = importlib.import_module("coco_pipe.decoding")
    if not hasattr(package, "__path__"):
        # Not a package (no submodules to walk).
        return

    prefix = package.__name__ + "."
    for module_info in pkgutil.walk_packages(package.__path__, prefix):
        try:
            importlib.import_module(module_info.name)
        except ImportError as exc:
            # Warn but continue - we don't want to crash if deep learning
            # libs are missing.
            warnings.warn(
                f"Skipping module '{module_info.name}' during estimator "
                f"discovery: {exc}",
                ImportWarning,
                stacklevel=2,
            )
# 1. Load Entry Points on startup (lazy map update only) _discover_entry_points()
def register_estimator(name: str) -> Callable[[Type], Type]:
    """
    Decorator to register an estimator class under a specific name.

    Parameters
    ----------
    name : str
        The unique alias for the estimator (e.g., "RandomForestClassifier").

    Returns
    -------
    Callable[[Type], Type]
        A class decorator that stores the class in the registry under
        ``name`` and returns the class unchanged.

    Warns
    -----
    UserWarning
        If ``name`` is already registered; the previous entry is replaced.
    """

    def decorator(cls: Type) -> Type:
        if name in _ESTIMATOR_REGISTRY:
            # stacklevel=2 attributes the warning to the decorated class
            # definition rather than to this registry module.
            warnings.warn(
                f"Overwriting existing estimator registry for '{name}'",
                stacklevel=2,
            )
        _ESTIMATOR_REGISTRY[name] = cls
        return cls

    return decorator
def get_estimator_cls(name: str) -> Type:
    """
    Retrieve an estimator class by name.

    Resolution order:

    1. The registry of already registered / previously resolved classes.
    2. The lazy-loading map: the mapped module is imported, then the class is
       taken from the explicit ``module:attribute`` target if one is given,
       otherwise from the module attribute named ``name``, otherwise from the
       registry in case the import ran a ``@register_estimator`` decorator.
    3. A one-time import of all internal decoding modules.

    Parameters
    ----------
    name : str
        Name of the estimator.

    Returns
    -------
    Type
        The class object.

    Raises
    ------
    ImportError
        If the module mapped to ``name`` cannot be imported.
    ValueError
        If name is not found.
    """
    # 1. Check if already loaded
    if name in _ESTIMATOR_REGISTRY:
        return _ESTIMATOR_REGISTRY[name]

    # 2. Try Lazy Loading Map
    if name in _LAZY_MODULES:
        target = _LAZY_MODULES[name]
        # Entry-point values look like "pkg.module:Attr [extras]"; built-in
        # entries are a bare module path, which leaves attr_path empty.
        mod_path, _, attr_path = target.partition(":")
        mod_path = mod_path.strip()
        attr_path = attr_path.split("[", 1)[0].strip()

        try:
            module = importlib.import_module(mod_path)
        except ImportError as e:
            raise ImportError(
                f"Could not load estimator '{name}' from '{target}'. "
                f"Ensure optional dependencies are installed."
            ) from e

        # Prefer the explicitly named attribute (so an alias may differ from
        # the class name), then fall back to an attribute named like the alias.
        candidates = []
        if attr_path:
            candidates.append(tuple(attr_path.split(".")))
        candidates.append((name,))

        for parts in candidates:
            obj = module
            try:
                for part in parts:
                    obj = getattr(obj, part)
            except AttributeError:
                continue
            _ESTIMATOR_REGISTRY[name] = obj
            return obj

        # Check if the import triggered a decorator registration
        if name in _ESTIMATOR_REGISTRY:
            return _ESTIMATOR_REGISTRY[name]

    # 3. Last Ditch: Internal Discovery (performed at most once per process;
    #    the flag lives on this function and is shared with list_estimators).
    if not getattr(get_estimator_cls, "_internal_scanned", False):
        _discover_internal_modules()
        setattr(get_estimator_cls, "_internal_scanned", True)

    if name in _ESTIMATOR_REGISTRY:
        return _ESTIMATOR_REGISTRY[name]

    # Generate helpful error. Lazily loadable aliases are listed too, since
    # they are valid names even though they are not in the registry yet; the
    # requested name itself is left out because it just failed to resolve.
    available = sorted(set(_ESTIMATOR_REGISTRY) | (set(_LAZY_MODULES) - {name}))
    raise ValueError(
        f"Estimator '{name}' not found in registry.\n"
        f"Available estimators: {available}\n"
        f"Tip: Ensure the containing module is imported or registered via "
        f"entry points."
    )
def list_estimators() -> Dict[str, Type]:
    """
    Return a copy of the current registry.

    The internal decoding modules are scanned first if that has not happened
    yet, so classes registered through decorators are included. The returned
    dict is a shallow copy; mutating it does not affect the registry.
    """
    already_scanned = getattr(get_estimator_cls, "_internal_scanned", False)
    if not already_scanned:
        # Run the one-time discovery and record it on the shared flag.
        _discover_internal_modules()
        get_estimator_cls._internal_scanned = True

    return _ESTIMATOR_REGISTRY.copy()