# Source code for ml_grid.model_classes.TPOTClassifierWrapper

"""TPOT Classifier Wrapper.

This module provides a scikit-learn compatible wrapper for TPOTClassifier.
"""

import logging
from typing import Optional, Union

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_is_fitted

# TPOT is an optional dependency. When it cannot be imported the name is bound
# to ``None`` so that importing this module still succeeds; code that needs
# TPOT checks for ``None`` and raises a descriptive ImportError at call time.
try:
    from tpot import TPOTClassifier
except ImportError:
    TPOTClassifier = None

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class TPOTClassifierWrapper(BaseEstimator, ClassifierMixin):
    """A scikit-learn compatible wrapper for TPOTClassifier.

    The constructor only records hyperparameters. The underlying
    ``TPOTClassifier`` is created inside :meth:`fit`, where every constructor
    argument is forwarded by keyword under the same name (together with
    ``disable_update_check=True``). See the TPOT documentation for the meaning
    of the individual settings.

    Args:
        generations: Forwarded to ``TPOTClassifier(generations=...)``.
        population_size: Forwarded to ``TPOTClassifier(population_size=...)``.
        offspring_size: Forwarded to ``TPOTClassifier(offspring_size=...)``.
        mutation_rate: Forwarded to ``TPOTClassifier(mutation_rate=...)``.
        crossover_rate: Forwarded to ``TPOTClassifier(crossover_rate=...)``.
        scoring: Forwarded to ``TPOTClassifier(scoring=...)``.
        cv: Forwarded to ``TPOTClassifier(cv=...)``.
        subsample: Forwarded to ``TPOTClassifier(subsample=...)``.
        n_jobs: Forwarded to ``TPOTClassifier(n_jobs=...)``.
        max_time_mins: Forwarded to ``TPOTClassifier(max_time_mins=...)``.
        max_eval_time_mins: Forwarded to
            ``TPOTClassifier(max_eval_time_mins=...)``.
        random_state: Forwarded to ``TPOTClassifier(random_state=...)``.
        verbosity: Forwarded to ``TPOTClassifier(verbosity=...)``.
        early_stop: Forwarded to ``TPOTClassifier(early_stop=...)``.

    Attributes:
        model_: ``None`` until :meth:`fit` is called, then the fitted
            ``TPOTClassifier`` instance.
        classes_: Class labels, set by :meth:`fit`.
    """

    def __init__(
        self,
        generations: int = 5,
        population_size: int = 20,
        offspring_size: Optional[int] = None,
        mutation_rate: float = 0.9,
        crossover_rate: float = 0.1,
        scoring: str = "accuracy",
        cv: int = 5,
        subsample: float = 1.0,
        n_jobs: int = -1,
        max_time_mins: Optional[int] = None,
        max_eval_time_mins: float = 5,
        random_state: int = 42,
        verbosity: int = 2,
        early_stop: Optional[int] = None,
    ):
        # Hyperparameters are stored verbatim under their own names.
        self.generations = generations
        self.population_size = population_size
        self.offspring_size = offspring_size
        self.mutation_rate = mutation_rate
        self.crossover_rate = crossover_rate
        self.scoring = scoring
        self.cv = cv
        self.subsample = subsample
        self.n_jobs = n_jobs
        self.max_time_mins = max_time_mins
        self.max_eval_time_mins = max_eval_time_mins
        self.random_state = random_state
        self.verbosity = verbosity
        self.early_stop = early_stop
        # Placeholder until fit() builds the real TPOTClassifier.
        # NOTE(review): scikit-learn convention is to create trailing-underscore
        # attributes only in fit(); kept here so existing code that reads
        # ``model_`` before fitting keeps working. The fitted-state checks in
        # predict/predict_proba still work because ``classes_`` is only set
        # by fit().
        self.model_ = None

    def fit(
        self,
        X: Union[np.ndarray, pd.DataFrame],
        y: Union[np.ndarray, pd.Series],
        **kwargs,
    ) -> "TPOTClassifierWrapper":
        """Build a ``TPOTClassifier`` from the stored settings and fit it.

        Args:
            X: Training features, passed unchanged to ``TPOTClassifier.fit``.
            y: Training targets, passed unchanged to ``TPOTClassifier.fit``.
            **kwargs: Extra keyword arguments forwarded to
                ``TPOTClassifier.fit``.

        Returns:
            This wrapper instance, to allow call chaining.

        Raises:
            ImportError: If TPOT could not be imported.
        """
        if TPOTClassifier is None:
            raise ImportError(
                "TPOT is not installed. Please install it to use TPOTClassifierWrapper."
            )

        self.model_ = TPOTClassifier(
            generations=self.generations,
            population_size=self.population_size,
            offspring_size=self.offspring_size,
            mutation_rate=self.mutation_rate,
            crossover_rate=self.crossover_rate,
            scoring=self.scoring,
            cv=self.cv,
            subsample=self.subsample,
            n_jobs=self.n_jobs,
            max_time_mins=self.max_time_mins,
            max_eval_time_mins=self.max_eval_time_mins,
            random_state=self.random_state,
            verbosity=self.verbosity,
            early_stop=self.early_stop,
            disable_update_check=True,
        )

        # TPOT can be slow. For quick checks, it's useful to see it has started.
        logger.info(
            "Starting TPOT fit with generations=%s, population_size=%s...",
            self.generations,
            self.population_size,
        )
        self.model_.fit(X, y, **kwargs)

        # `classes_` must be set for scikit-learn compatibility (e.g. for
        # check_is_fitted). Use the value exposed by the fitted TPOT model when
        # it has one; otherwise infer the labels from `y`, consistent with
        # other wrappers in this project.
        if hasattr(self.model_, "classes_"):
            self.classes_ = self.model_.classes_
        else:
            self.classes_ = np.unique(y)

        logger.info("TPOT fit completed.")
        return self

    def predict(self, X: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
        """Predict with the fitted TPOT model.

        Args:
            X: Samples to predict, passed unchanged to the underlying model.

        Returns:
            The result of the underlying model's ``predict``.

        Raises:
            sklearn.exceptions.NotFittedError: Raised by ``check_is_fitted``
                when :meth:`fit` has not been called yet.
        """
        check_is_fitted(self, ["model_", "classes_"])
        return self.model_.predict(X)

    def predict_proba(self, X: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
        """Predict class probabilities with the fitted TPOT model.

        Args:
            X: Samples to predict, passed unchanged to the underlying model.

        Returns:
            The result of the underlying model's ``predict_proba``.

        Raises:
            sklearn.exceptions.NotFittedError: Raised by ``check_is_fitted``
                when :meth:`fit` has not been called yet.
        """
        check_is_fitted(self, ["model_", "classes_"])
        return self.model_.predict_proba(X)