Source code for ml_grid.model_classes.tabpfn_classifier_class

"""Defines the TabPFN Classifier model class."""

import logging
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
import torch
from sklearn.base import BaseEstimator, ClassifierMixin
from skopt.space import Categorical, Integer

from ml_grid.util import param_space
from ml_grid.util.global_params import global_parameters

try:
    from tabpfn import TabPFNClassifier
    from tabpfn.constants import ModelVersion

    TABPFN_AVAILABLE = True
except ImportError:
    TABPFN_AVAILABLE = False
    logging.getLogger("ml_grid").warning(
        "TabPFN not available. Install with: pip install tabpfn"
    )

logging.getLogger("ml_grid").debug("Imported TabPFNClassifier class")

class TabPFNClassifierClass(BaseEstimator, ClassifierMixin):
    """TabPFN Classifier with support for hyperparameter tuning.

    TabPFN is a foundation model for tabular data that performs well on
    small to medium-sized datasets (up to 50,000 rows). It requires a GPU
    for optimal performance on datasets larger than ~1000 samples.

    Note:
        TabPFN-2.5 model weights require accepting license terms at:
        https://huggingface.co/Prior-Labs/tabpfn_2_5
    """

    def __init__(
        self,
        parameter_space_size: Optional[str] = None,
        # Hyperparameters for scikit-learn compatibility
        model_version: str = "v2.5_default",
        device: str = "cpu",
        n_estimators: int = 4,
        subsample_samples: Optional[int] = None,
        random_state: int = 42,
    ):
        """Initializes the TabPFNClassifierClass.

        Args:
            parameter_space_size (Optional[str]): Size of the parameter space
                for optimization. Defaults to None.
            model_version (str): The version of the TabPFN model to use.
            device (str): The device to run the model on ('cpu' or 'cuda').
            n_estimators (int): Number of ensemble members.
            subsample_samples (Optional[int]): Subsample size for large
                datasets.
            random_state (int): Random state for reproducibility.

        Raises:
            ImportError: If TabPFN is not installed.
        """
        if not TABPFN_AVAILABLE:
            raise ImportError(
                "TabPFN is not installed. Install with: pip install tabpfn"
            )

        # Store scikit-learn hyperparameters
        self.model_version = model_version
        self.device = device
        self.n_estimators = n_estimators
        self.subsample_samples = subsample_samples
        self.random_state = random_state

        global_params = global_parameters

        self.parameter_space_size = parameter_space_size
        self.algorithm_implementation = self  # The instance itself is the estimator
        self.method_name: str = "TabPFNClassifier"
        self.parameter_vector_space: param_space.ParamSpace = param_space.ParamSpace(
            parameter_space_size
        )

        self.parameter_space: Dict[str, Any]
        if global_params.bayessearch:
            self.parameter_space = {
                # Model version selection
                "model_version": Categorical(
                    [
                        "v2.5_default",  # Default: finetuned on real data
                        "v2.5_synthetic",  # Trained on synthetic data only
                        "v2",  # TabPFN v2
                    ]
                ),
                # Device selection - can be optimized based on availability
                "device": Categorical(["cuda", "cpu"]),
                # Number of ensemble members (more = better but slower)
                "n_estimators": Integer(1, 8),
                # Training subsample size (for large datasets)
                "subsample_samples": Categorical([None, 5000, 10000, 20000]),
                # Random state for reproducibility
                "random_state": Categorical([42]),
            }
        else:
            self.parameter_space = {
                "model_version": ["v2.5_default", "v2.5_synthetic", "v2"],
                "device": ["cuda", "cpu"],
                "n_estimators": [1, 2, 4, 8],
                "subsample_samples": [None, 5000, 10000, 20000],
                "random_state": [42],
            }
    def fit(self, X: pd.DataFrame, y: pd.Series):
        """Fits the TabPFN model.

        This method uses the hyperparameters set on the instance to create
        and fit the underlying TabPFNClassifier.
        """
        # Apply subsampling if configured
        if self.subsample_samples is not None and len(X) > self.subsample_samples:
            # Use numpy for stable random sampling
            rng = np.random.RandomState(self.random_state)
            indices = rng.choice(len(X), self.subsample_samples, replace=False)

            # Handle DataFrame/Series or numpy arrays
            if isinstance(X, pd.DataFrame):
                X = X.iloc[indices]
            else:
                X = X[indices]
            if isinstance(y, pd.Series):
                y = y.iloc[indices]
            else:
                y = y[indices]

        # Get the hyperparameters from the instance itself
        params = self.get_params()

        # Check for GPU availability and fall back to CPU if necessary
        if params.get("device") == "cuda" and not torch.cuda.is_available():
            logging.getLogger("ml_grid").warning(
                "TabPFN device set to 'cuda' but no CUDA GPU found. Falling back to 'cpu'."
            )
            params["device"] = "cpu"

        # This logic was originally in create_model
        model_version = params.pop("model_version", "v2.5_default")

        # Filter out non-TabPFN params that might be in get_params()
        valid_tabpfn_params = ["device", "n_estimators", "random_state"]
        params_copy = {k: v for k, v in params.items() if k in valid_tabpfn_params}

        if model_version == "v2.5_synthetic":
            params_copy["model_path"] = "tabpfn-v2.5-classifier-v2.5_default-2.ckpt"

        if model_version == "v2":
            self._estimator = TabPFNClassifier.create_default_for_version(
                ModelVersion.V2, **params_copy
            )
        else:
            self._estimator = TabPFNClassifier(**params_copy)

        self._estimator.fit(X, y)
        self.classes_ = self._estimator.classes_
        return self
    def predict(self, X: pd.DataFrame) -> pd.Series:
        """Makes predictions using the fitted model."""
        return self._estimator.predict(X)

    def predict_proba(self, X: pd.DataFrame) -> pd.DataFrame:
        """Returns probability estimates for predictions."""
        return self._estimator.predict_proba(X)
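
# Usage sketch: a minimal, hypothetical example of driving the wrapper above
# directly, assuming TabPFN and its model weights are installed and that
# ml_grid's ParamSpace accepts the default parameter_space_size. The synthetic
# data, split sizes, and hyperparameter settings below are illustrative only.
if __name__ == "__main__":
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split

    # Small synthetic binary classification problem (hypothetical sizes)
    X_demo, y_demo = make_classification(n_samples=500, n_features=20, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(
        X_demo, y_demo, test_size=0.25, random_state=0
    )

    # fit() builds the underlying TabPFNClassifier from the stored
    # hyperparameters, falling back to CPU when CUDA is unavailable.
    clf = TabPFNClassifierClass(device="cuda", n_estimators=4, random_state=42)
    clf.fit(pd.DataFrame(X_train), pd.Series(y_train))

    proba = clf.predict_proba(pd.DataFrame(X_test))
    accuracy = clf.score(pd.DataFrame(X_test), pd.Series(y_test))  # ClassifierMixin
    print(f"Hold-out accuracy: {accuracy:.3f}, probability shape: {proba.shape}")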