Source code for ml_grid.model_classes.quadratic_discriminant_class

"""Defines the QuadraticDiscriminantAnalysis model class."""

import logging
from typing import Any, Dict, Optional

import pandas as pd
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from skopt.space import Categorical, Real

from ml_grid.util import param_space
from ml_grid.util.global_params import global_parameters

logging.getLogger("ml_grid").debug("Imported QuadraticDiscriminantAnalysis class")


class QuadraticDiscriminantAnalysisClass:
    """QuadraticDiscriminantAnalysis with support for hyperparameter tuning."""

    def __init__(
        self,
        X: Optional[pd.DataFrame] = None,
        y: Optional[pd.Series] = None,
        parameter_space_size: Optional[str] = None,
    ):
        """Build the QDA estimator wrapper and its hyperparameter search space.

        The shape of ``self.parameter_space`` depends on the global
        ``bayessearch`` flag:

        * flag false: plain lists, with ``reg_param`` and ``tol`` taken from
          the ``"log_small"`` entry of the ``ParamSpace`` lookup table;
        * flag true: scikit-optimize ``Categorical`` / ``Real`` dimensions.

        Args:
            X (Optional[pd.DataFrame]): Feature matrix stored on the instance.
                Defaults to None.
            y (Optional[pd.Series]): Target vector stored on the instance.
                Defaults to None.
            parameter_space_size (Optional[str]): Size label passed straight
                to ``param_space.ParamSpace``. Defaults to None.

        Note:
            This constructor does not validate ``parameter_space_size``
            itself. Whether an unknown value raises depends on
            ``param_space.ParamSpace``, which is not visible here (confirm).
        """
        self.X: Optional[pd.DataFrame] = X
        self.y: Optional[pd.Series] = y
        self.algorithm_implementation: QuadraticDiscriminantAnalysis = (
            QuadraticDiscriminantAnalysis()
        )
        self.method_name: str = "QuadraticDiscriminantAnalysis"
        self.parameter_vector_space: param_space.ParamSpace = param_space.ParamSpace(
            parameter_space_size
        )

        self.parameter_space: Dict[str, Any]
        if not global_parameters.bayessearch:
            # Grid/random search: explicit value lists, with the numeric
            # ranges shared with other models via the ParamSpace table.
            lookup = self.parameter_vector_space.param_dict
            self.parameter_space = {
                "priors": [None],
                "reg_param": lookup.get("log_small"),
                "store_covariance": [False],
                "tol": lookup.get("log_small"),
            }
        else:
            # Bayesian search: each key gets its own skopt dimension object.
            # They are kept as distinct instances because skopt assigns
            # per-dimension names. priors and store_covariance are pinned
            # to a single category.
            self.parameter_space = {
                "priors": Categorical([None]),
                "reg_param": Real(1e-5, 1e-2, prior="log-uniform"),
                "store_covariance": Categorical([False]),
                "tol": Real(1e-5, 1e-2, prior="log-uniform"),
            }