"""H2O RuleFit Classifier Wrapper.
This module provides a scikit-learn compatible wrapper for H2O's RuleFitEstimator.
"""
import pandas as pd
from h2o.estimators import H2ORuleFitEstimator
from .H2OBaseClassifier import H2OBaseClassifier
[docs]
class H2ORuleFitClassifier(H2OBaseClassifier):
    """Scikit-learn style classifier backed by H2O's ``H2ORuleFitEstimator``.

    No ``fit()`` is defined here: it is inherited from ``H2OBaseClassifier``
    and uses the values returned by the ``_prepare_fit()`` override below.
    """

    def __init__(self, **kwargs):
        """Set up the wrapper around ``H2ORuleFitEstimator``.

        Args:
            **kwargs: Keyword arguments forwarded to the base class and
                intended for ``H2ORuleFitEstimator`` (for example
                ``max_rule_length=3`` or ``model_type='rules_and_linear'``).
                An ``estimator_class`` entry is discarded, because this
                class always supplies its own.
        """
        # During an sklearn clone ``estimator_class`` can come back in as a
        # keyword; filter it out so it cannot collide with the explicit
        # argument passed below.
        forwarded = {
            key: value
            for key, value in kwargs.items()
            if key != "estimator_class"
        }
        super().__init__(estimator_class=H2ORuleFitEstimator, **forwarded)

    def _prepare_fit(self, X: pd.DataFrame, y: pd.Series):
        """Run the base preparation, then apply RuleFit-specific checks.

        Args:
            X: Feature table; it is indexed by feature name for the
                constant-column check.
            y: Target values, handed to the base implementation.

        Returns:
            The 4-tuple ``(train_h2o, x_vars, outcome_var, model_params)``
            obtained from the base class. ``model_params`` may have had
            ``min_rule_length`` lowered to ``max_rule_length``.

        Raises:
            ValueError: If there is exactly one feature and it holds at
                most one distinct value.
        """
        prepared = super()._prepare_fit(X, y)
        train_h2o, x_vars, outcome_var, model_params = prepared

        # A lone constant predictor is reported to crash the H2O server
        # under RuleFit, so such a dataset is rejected here.
        if len(x_vars) == 1:
            sole_feature = x_vars[0]
            if X[sole_feature].nunique() <= 1:
                raise ValueError(
                    "H2ORuleFitClassifier: Dataset has a single constant feature, which is "
                    "unstable for RuleFit. Halting execution."
                )

        # Repair an inverted rule-length range instead of failing: only
        # acts when both bounds were supplied explicitly.
        shortest = model_params.get("min_rule_length")
        longest = model_params.get("max_rule_length")
        both_given = not (shortest is None or longest is None)
        if both_given and shortest > longest:
            message = (
                f"Warning: Invalid H2ORuleFit params detected: min_rule_length ({shortest}) > max_rule_length ({longest}). "
                f"Correcting min_rule_length to {longest} to proceed."
            )
            self.logger.warning(message)
            model_params["min_rule_length"] = longest

        return train_h2o, x_vars, outcome_var, model_params