Source code for ml_grid.model_classes.catboostClassifier

from typing import Any, Union

import numpy as np
import pandas as pd
from catboost import CatBoostClassifier
from sklearn.base import BaseEstimator, ClassifierMixin


class CatBoostSKLearnWrapper(BaseEstimator, ClassifierMixin):
    """A scikit-learn compatible wrapper for the CatBoostClassifier.

    The constructor accepts arbitrary keyword arguments and forwards them to
    ``catboost.CatBoostClassifier``. Because scikit-learn discovers
    hyperparameters by inspecting explicit ``__init__`` parameters (and skips
    ``**kwargs``), this class overrides ``get_params`` / ``set_params`` so
    that ``sklearn.base.clone`` and hyperparameter search round-trip the
    CatBoost settings instead of silently dropping them.

    Attributes:
        model: The underlying ``CatBoostClassifier`` instance.
        classes_: Class labels copied from the underlying model after
            ``fit`` (only set when the model exposes them).
    """

    def __init__(self, **kwargs: Any):
        """Initializes the CatBoostSKLearnWrapper.

        Args:
            **kwargs (Any): Keyword arguments passed directly to the
                `catboost.CatBoostClassifier`.
        """
        # Shallow copy: keeps the caller's value objects (clone() checks
        # parameter identity) while decoupling from the caller's dict.
        self._init_kwargs = dict(kwargs)
        self.model = CatBoostClassifier(**self._init_kwargs)

    def get_params(self, deep: bool = True) -> dict:
        """Returns the keyword arguments this wrapper was configured with.

        Args:
            deep (bool): Accepted for scikit-learn API compatibility; the
                stored keyword arguments are returned either way.

        Returns:
            dict: A new dict mapping parameter names to their values.
        """
        # getattr guards instances created without __init__ (e.g. __new__).
        return dict(getattr(self, "_init_kwargs", {}))

    def set_params(self, **params: Any) -> "CatBoostSKLearnWrapper":
        """Updates the configuration and rebuilds the underlying model.

        The underlying ``CatBoostClassifier`` is re-created with the merged
        keyword arguments, so any previously fitted state is discarded.

        Args:
            **params (Any): Keyword arguments to add or override.

        Returns:
            CatBoostSKLearnWrapper: This estimator.
        """
        if not params:
            return self
        merged = {**getattr(self, "_init_kwargs", {}), **params}
        # Build the model first so a rejected argument leaves the wrapper
        # in its previous, consistent state.
        self.model = CatBoostClassifier(**merged)
        self._init_kwargs = merged
        return self

    def fit(
        self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]
    ) -> "CatBoostSKLearnWrapper":
        """Fits the CatBoost model.

        Args:
            X (Union[pd.DataFrame, np.ndarray]): The training input samples.
            y (Union[pd.Series, np.ndarray]): The target values.

        Returns:
            CatBoostSKLearnWrapper: The fitted estimator.
        """
        self.model.fit(X, y)
        # scikit-learn utilities look for `classes_` on fitted classifiers.
        fitted_classes = getattr(self.model, "classes_", None)
        if fitted_classes is not None:
            self.classes_ = np.asarray(fitted_classes)
        return self

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """Predicts class labels for samples in X.

        Args:
            X (Union[pd.DataFrame, np.ndarray]): The input samples to predict.

        Returns:
            np.ndarray: The predicted class labels.
        """
        return self.model.predict(X)

    def predict_proba(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """Predicts class probabilities for samples in X.

        Args:
            X (Union[pd.DataFrame, np.ndarray]): The input samples.

        Returns:
            np.ndarray: The class probabilities of the input samples.
        """
        return self.model.predict_proba(X)