"""Functions to validate model-specific hyperparameters before grid search."""
from typing import Any, Dict
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from ml_grid.model_classes.knn_gpu_classifier_class import knn__gpu_wrapper_class
from ml_grid.model_classes.knn_wrapper_class import KNNWrapper
[docs]
def validate_knn_parameters(
    parameters: Dict[str, Any], ml_grid_object: Any
) -> Dict[str, Any]:
    """Validates the `n_neighbors` parameter for KNN classifiers.

    Every candidate value in `parameters["n_neighbors"]` that exceeds
    `n_samples - 1` (where `n_samples` is the row count of
    `ml_grid_object.X_train`) is capped at that limit. The limit itself is
    never allowed to drop below 1, so a training set with a single row (or
    none) cannot turn the candidates into 0 or a negative number.

    The candidate sequence is modified in place and also written back to
    `parameters`. Progress is reported with `print`.

    Args:
        parameters (Dict[str, Any]): The dictionary of parameters to validate.
            If it has no `n_neighbors` entry it is returned untouched.
        ml_grid_object (Any): The main pipeline object containing the training
            data (`X_train`); only `X_train.shape[0]` is read.

    Returns:
        Dict[str, Any]: The validated parameters dictionary (the same object
        that was passed in).
    """
    print("Validating KNN parameters")
    n_samples = ml_grid_object.X_train.shape[0]
    print(f" n_samples: {n_samples}")

    # Largest usable neighbour count. Floored at 1 because a neighbour count
    # of 0 (or below) is never meaningful, which is what `n_samples - 1`
    # would yield for a training set of one row or fewer.
    max_neighbors = max(1, n_samples - 1)
    print(f" max_neighbors: {max_neighbors}")

    n_neighbors = parameters.get("n_neighbors")
    print(f" n_neighbors: {n_neighbors}")

    # Nothing to validate when the grid does not tune n_neighbors.
    if n_neighbors is None:
        return parameters

    # Only element values are replaced (the length never changes), so
    # assigning into the sequence while enumerating it is safe.
    for i, value in enumerate(n_neighbors):
        if value > max_neighbors:
            print(f" n_neighbors[{i}] is greater than max_neighbors")
            print(
                f" Reducing n_neighbors[{i}] from {value} to {max_neighbors}"
            )
            n_neighbors[i] = max_neighbors

    parameters["n_neighbors"] = n_neighbors
    return parameters
[docs]
def validate_XGB_parameters(
    parameters: Dict[str, Any], ml_grid_object: Any
) -> Dict[str, Any]:
    """Validates the `max_bin` parameter for XGBoost.

    Every candidate value in `parameters["max_bin"]` that is below 2 is
    raised to 2. The candidate sequence is modified in place and also written
    back to `parameters`. If the dictionary has no `max_bin` entry it is
    returned untouched instead of failing.

    Args:
        parameters (Dict[str, Any]): The dictionary of parameters to validate.
        ml_grid_object (Any): The main pipeline object (currently unused).

    Returns:
        Dict[str, Any]: The validated parameters dictionary (the same object
        that was passed in).
    """
    max_bin_array = parameters.get("max_bin")

    # A grid that does not tune max_bin has nothing to validate; without this
    # guard the loop below would call len()/iterate on None and raise.
    if max_bin_array is None:
        return parameters

    # Only element values are replaced (the length never changes), so
    # assigning into the sequence while enumerating it is safe.
    for i, value in enumerate(max_bin_array):
        if value < 2:
            max_bin_array[i] = 2

    parameters["max_bin"] = max_bin_array
    return parameters
[docs]
def validate_parameters_helper(
    algorithm_implementation: Any, parameters: Dict[str, Any], ml_grid_object: Any
) -> Dict[str, Any]:
    """Routes a parameter dictionary to the validator matching the estimator.

    The three KNN estimator classes are sent to `validate_knn_parameters`,
    `XGBClassifier` is sent to `validate_XGB_parameters`, and any other
    estimator gets its parameters back unchanged.

    Args:
        algorithm_implementation (Any): The estimator instance whose type
            selects the validator.
        parameters (Dict[str, Any]): The dictionary of parameters to validate.
        ml_grid_object (Any): The main pipeline object, forwarded to the
            selected validator.

    Returns:
        Dict[str, Any]: The validated parameters dictionary.
    """
    # Ordered (estimator class, validator) pairs; the first match wins.
    dispatch = (
        (KNeighborsClassifier, validate_knn_parameters),
        (KNNWrapper, validate_knn_parameters),
        (knn__gpu_wrapper_class, validate_knn_parameters),
        (XGBClassifier, validate_XGB_parameters),
    )

    implementation_type = type(algorithm_implementation)
    for estimator_class, validator in dispatch:
        # Exact type comparison: subclasses of the listed estimators are
        # deliberately not matched and fall through to the default below.
        if implementation_type == estimator_class:
            return validator(parameters, ml_grid_object)

    # No dedicated validator for this estimator.
    return parameters