Source code for ml_grid.model_classes.H2OGLMClassifier


import pandas as pd
from h2o.estimators import H2OGeneralizedLinearEstimator

from .H2OBaseClassifier import H2OBaseClassifier


class H2OGLMClassifier(H2OBaseClassifier):
    """A scikit-learn compatible wrapper for H2O's Generalized Linear Models."""

    def __init__(self, **kwargs):
        """Initializes the H2OGLMClassifier, handling the 'lambda' parameter.

        ``lambda`` is a reserved word in Python, so the regularization
        strength is carried under the key ``lambda_``. A ``lambda`` key
        (e.g. coming from a parameter grid) is accepted as an alias and
        renamed; when both keys are supplied, the explicit ``lambda_``
        value takes precedence and the ``lambda`` alias is discarded.

        Args:
            **kwargs: Keyword arguments forwarded to
                ``H2OBaseClassifier.__init__`` alongside
                ``estimator_class=H2OGeneralizedLinearEstimator``.
                Intended as GLM hyperparameters such as
                ``family='binomial'``, ``alpha=0.5`` and ``lambda_``.
        """
        # Always remove the 'lambda' alias so it is never forwarded next to
        # 'lambda_'; it is only used as a fallback when 'lambda_' is absent.
        if "lambda" in kwargs:
            alias_value = kwargs.pop("lambda")
            kwargs.setdefault("lambda_", alias_value)

        # 'estimator_class' can show up in kwargs during sklearn cloning;
        # drop it so it does not collide with the explicit keyword below.
        kwargs.pop("estimator_class", None)

        super().__init__(estimator_class=H2OGeneralizedLinearEstimator, **kwargs)

    def fit(self, X: pd.DataFrame, y: pd.Series, **kwargs) -> "H2OGLMClassifier":
        """Fits the H2O GLM model and corrects the 'lambda_' parameter name.

        Training is delegated to the parent ``fit``. Afterwards, if the
        fitted model's ``params`` mapping contains a ``lambda_`` entry, it
        is re-keyed as ``lambda``. The original author added this to avoid
        a predict-time NullPointerException in the H2O backend.

        Args:
            X: Feature matrix, passed unchanged to the parent ``fit``.
            y: Target values, passed unchanged to the parent ``fit``.
            **kwargs: Extra keyword arguments passed to the parent ``fit``.

        Returns:
            H2OGLMClassifier: The fitted classifier instance (``self``).
        """
        # The parent class performs the actual training.
        super().fit(X, y, **kwargs)

        # NOTE(review): in h2o-py, `params` looks like a computed property
        # that returns a fresh dict on each access; if so, this rename does
        # not persist on the model -- TODO confirm against the H2O version
        # in use and against how H2OBaseClassifier sets `model_`.
        if self.model_ and "lambda_" in self.model_.params:
            self.model_.params["lambda"] = self.model_.params.pop("lambda_")

        return self