ml_grid.util.validate_parameters

Functions to validate model-specific hyperparameters before grid search.

Functions

validate_knn_parameters(→ Union[Dict[str, Any], ...)

Validates the n_neighbors parameter for KNN classifiers.

validate_XGB_parameters(→ Union[Dict[str, Any], ...)

Validates the max_bin parameter for XGBoost.

validate_parameters_helper(→ Union[Dict[str, Any], ...)

Dispatches to model-specific validation or performs generic filtering.

Module Contents

ml_grid.util.validate_parameters.validate_knn_parameters(parameters: Dict[str, Any] | List[Dict[str, Any]], ml_grid_object: Any) → Dict[str, Any] | List[Dict[str, Any]]

Validates the n_neighbors parameter for KNN classifiers.

This function ensures that the values for n_neighbors do not exceed the number of samples in the training data (X_train). Any value that is too large is capped at n_samples - 1, where n_samples is the number of training samples.

Parameters:
  • parameters (Union[Dict[str, Any], List[Dict[str, Any]]]) – The dictionary or list of dictionaries of parameters to validate.

  • ml_grid_object (Any) – The main pipeline object containing the training data (X_train).

Returns:

The validated parameters.

Return type:

Union[Dict[str, Any], List[Dict[str, Any]]]
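
The capping rule described above can be illustrated with a minimal sketch. This is not the library's implementation: cap_n_neighbors is a hypothetical stand-in, the search space is assumed to hold plain lists of integers, "too large" is assumed to mean greater than n_samples - 1, and the pipeline object is mocked with a SimpleNamespace that only carries X_train.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Union

Params = Union[Dict[str, Any], List[Dict[str, Any]]]


def cap_n_neighbors(parameters: Params, ml_grid_object: Any) -> Params:
    """Hypothetical sketch: cap n_neighbors candidates at n_samples - 1."""
    n_samples = len(ml_grid_object.X_train)
    # Assumption: never go below 1 neighbor, even for tiny training sets.
    max_k = max(1, n_samples - 1)

    def fix(space: Dict[str, Any]) -> Dict[str, Any]:
        # Leave search spaces without n_neighbors untouched (shallow copy).
        if "n_neighbors" not in space:
            return dict(space)
        capped = [min(int(k), max_k) for k in space["n_neighbors"]]
        return {**space, "n_neighbors": capped}

    # Accept either a single dict or a list of dicts, as the signature does.
    if isinstance(parameters, list):
        return [fix(space) for space in parameters]
    return fix(parameters)


# Mock pipeline object with 10 training rows (assumed shape of X_train).
demo_object = SimpleNamespace(X_train=[[0.0, 1.0]] * 10)
search_space = {"n_neighbors": [3, 5, 50], "weights": ["uniform"]}
print(cap_n_neighbors(search_space, demo_object))
```

With 10 training rows, the candidate 50 is reduced to 9 while 3 and 5 pass through unchanged. The sketch returns new dictionaries rather than mutating its input; whether the real function does the same is not stated in this page.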

ml_grid.util.validate_parameters.validate_XGB_parameters(parameters: Dict[str, Any] | List[Dict[str, Any]], ml_grid_object: Any) → Dict[str, Any] | List[Dict[str, Any]]

Validates the max_bin parameter for XGBoost.

This function checks that every max_bin value is at least 2; any smaller value is replaced with 2.

Parameters:
  • parameters (Union[Dict[str, Any], List[Dict[str, Any]]]) – The dictionary or list of dictionaries of parameters to validate.

  • ml_grid_object (Any) – The main pipeline object (currently unused).

Returns:

The validated parameters.

Return type:

Union[Dict[str, Any], List[Dict[str, Any]]]
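
The lower bound on max_bin can be sketched in the same way. The helper floor_max_bin below is hypothetical and assumes that max_bin candidates are given as a plain list of integers; it omits the ml_grid_object argument because the documentation marks it as unused.

```python
from typing import Any, Dict, List, Union

Params = Union[Dict[str, Any], List[Dict[str, Any]]]


def floor_max_bin(parameters: Params, minimum: int = 2) -> Params:
    """Hypothetical sketch: raise any max_bin candidate below `minimum`."""

    def fix(space: Dict[str, Any]) -> Dict[str, Any]:
        # Search spaces without max_bin are returned as a shallow copy.
        if "max_bin" not in space:
            return dict(space)
        floored = [max(int(v), minimum) for v in space["max_bin"]]
        return {**space, "max_bin": floored}

    # Accept either a single dict or a list of dicts.
    if isinstance(parameters, list):
        return [fix(space) for space in parameters]
    return fix(parameters)


xgb_space = {"max_bin": [0, 1, 256], "max_depth": [3, 6]}
print(floor_max_bin(xgb_space))
```

Here the candidates 0 and 1 both become 2, and 256 is kept. The sketch does not remove the resulting duplicate; whether the real function deduplicates is not stated in this page.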

ml_grid.util.validate_parameters.validate_parameters_helper(algorithm_implementation: Any, parameters: Dict[str, Any] | List[Dict[str, Any]], ml_grid_object: Any) → Dict[str, Any] | List[Dict[str, Any]]

Dispatches to model-specific validation or performs generic filtering.

This function first checks for model-specific validation routines (e.g., for KNN, XGBoost). If no specific routine is found, it performs a generic validation that removes any parameters from the search space that are not valid for the given algorithm instance. This prevents TypeError exceptions from scikit-learn’s search classes.

Parameters:
  • algorithm_implementation (Any) – The scikit-learn estimator instance.

  • parameters (Union[Dict[str, Any], List[Dict[str, Any]]]) – The parameters to validate.

  • ml_grid_object (Any) – The main pipeline object containing training data.

Returns:

The validated parameters.

Return type:

Union[Dict[str, Any], List[Dict[str, Any]]]
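
The generic filtering step can be approximated as follows. This sketch assumes that the set of valid names comes from the estimator's get_params() method, which scikit-learn estimators provide; the page does not confirm that the library uses it. ToyEstimator and drop_unknown_parameters are hypothetical names, and the toy class stands in for a real estimator so that the example runs without scikit-learn installed.

```python
from typing import Any, Dict, List, Union

Params = Union[Dict[str, Any], List[Dict[str, Any]]]


class ToyEstimator:
    """Hypothetical stand-in exposing a scikit-learn style get_params()."""

    def __init__(self, alpha: float = 1.0, max_iter: int = 100) -> None:
        self.alpha = alpha
        self.max_iter = max_iter

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        return {"alpha": self.alpha, "max_iter": self.max_iter}


def drop_unknown_parameters(estimator: Any, parameters: Params) -> Params:
    """Hypothetical sketch: keep only keys the estimator reports as valid."""
    valid = set(estimator.get_params().keys())

    def fix(space: Dict[str, Any]) -> Dict[str, Any]:
        return {name: values for name, values in space.items() if name in valid}

    # Accept either a single dict or a list of dicts.
    if isinstance(parameters, list):
        return [fix(space) for space in parameters]
    return fix(parameters)


raw_space = {"alpha": [0.1, 1.0], "gamma": [1, 10]}
print(drop_unknown_parameters(ToyEstimator(), raw_space))
```

The key gamma is dropped because ToyEstimator does not list it, so a search class built on the filtered space is never asked to set a parameter the estimator lacks.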