# Source code for plot_pipeline_parameters

# plot_pipeline_parameters.py
"""
Pipeline parameter analysis plotting module for ML results analysis.
Focuses on visualizing the impact of data transformations and pipeline settings on model performance.
"""

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Optional

from ml_grid.results_processing.core import get_clean_data

class PipelineParameterPlotter:
    """Visualizes the impact of pipeline parameters on model performance.

    Attributes:
        data: The DataFrame passed to the constructor, stored unchanged.
        clean_data: The result of ``get_clean_data(data)``; every plot reads
            from this frame.
        categorical_params: Column names treated as categorical parameters.
        continuous_params: Mapping of continuous parameter column names to
            the titles shown on their plots.
        available_categorical: Entries of ``categorical_params`` that are
            columns of ``clean_data``.
        available_continuous: Entries of ``continuous_params`` whose key is a
            column of ``clean_data``.
    """

    def __init__(self, data: pd.DataFrame):
        """Initializes the PipelineParameterPlotter.

        Besides storing the data, this resets the global matplotlib style to
        ``'default'`` and sets the global seaborn palette to ``"viridis"``.

        Args:
            data (pd.DataFrame): Results DataFrame, must contain columns for
                pipeline parameters and performance metrics.

        Raises:
            ValueError: If no pipeline parameter columns are found in the data.
        """
        self.data = data
        # NOTE(review): get_clean_data is defined elsewhere; presumably it
        # filters out unusable rows -- confirm against ml_grid's core module.
        self.clean_data = get_clean_data(data)

        # Define which parameters are categorical vs continuous based on the request
        self.categorical_params = [
            'resample',
            'scale',
            'param_space_size',
            'percent_missing',
        ]

        # For continuous params, we map the column name to a user-friendly title
        self.continuous_params = {
            # Using nb_size for n_features as it represents selected features
            'nb_size': 'Number of Features Used',
            'X_train_size': 'Train Set Size',
            'X_test_orig_size': 'Original Test Set Size',
            'X_test_size': 'Test Set Size',
            'n_fits': 'Number of Fits (Random Search)',
            't_fits': 'Total Fits (Grid Search)',
        }

        # Filter to only parameters present in the data
        known_columns = self.clean_data.columns
        self.available_categorical = [
            p for p in self.categorical_params if p in known_columns
        ]
        self.available_continuous = {
            k: v for k, v in self.continuous_params.items() if k in known_columns
        }

        if not self.available_categorical and not self.available_continuous:
            raise ValueError("No pipeline parameter columns (e.g., 'resample', 'scale') found in the provided data.")

        plt.style.use('default')
        sns.set_palette("viridis")

    @staticmethod
    def _prepare_categories(values: pd.Series, param: str) -> pd.Series:
        """Returns the display labels for one categorical parameter column.

        Missing values become the string ``'None'``. For the ``'scale'``
        parameter, values are stringified and ``'True'``/``'False'`` are
        renamed to ``'Scaled'``/``'Not Scaled'``; any other value keeps its
        string form instead of being turned into NaN.

        Args:
            values (pd.Series): The raw parameter column.
            param (str): Name of the parameter the column belongs to.

        Returns:
            pd.Series: The labels, in the same row order as ``values``.
        """
        if values.isnull().any():
            # Cast to object first so the string placeholder is accepted
            # regardless of the column's original dtype.
            values = values.astype(object).fillna('None')

        if param == 'scale':
            scale_labels = {'True': 'Scaled', 'False': 'Not Scaled', 'None': 'None'}
            # A callable (rather than the dict itself) keeps unmapped values
            # as-is; mapping with the dict would replace them with NaN.
            values = values.astype(str).map(lambda v: scale_labels.get(v, v))

        return values

    @staticmethod
    def _category_order(values) -> list:
        """Returns the unique values of ``values`` in plotting order.

        Values that can be compared with each other are returned in their
        natural sorted order. When strings are mixed with other types (for
        example numbers plus the ``'None'`` placeholder), the non-string
        values come first, followed by the strings.

        Args:
            values: A pandas Series or any sequence ``pd.Series`` accepts.

        Returns:
            list: The unique values in display order.
        """
        unique_values = list(pd.Series(values).unique())
        try:
            return sorted(unique_values)
        except TypeError:
            # Mixed types cannot be compared directly; order each group on
            # its own and concatenate.
            text_values = sorted(v for v in unique_values if isinstance(v, str))
            other_values = [v for v in unique_values if not isinstance(v, str)]
            try:
                other_values.sort()
            except TypeError:
                other_values.sort(key=str)
            return other_values + text_values

    def plot_categorical_parameters(
        self,
        metric: str = 'auc',
        figsize: Optional[Tuple[int, int]] = None
    ) -> None:
        """Creates box plots for categorical pipeline parameters.

        One subplot is drawn per available categorical parameter, at most
        four per row. Prints an informational message and returns without
        plotting when no categorical parameter is available.

        Args:
            metric (str, optional): The performance metric to plot.
                Defaults to 'auc'.
            figsize (Optional[Tuple[int, int]], optional): Figure size for the
                plot. Defaults to None, which sizes the figure from the grid
                dimensions.

        Raises:
            ValueError: If the specified metric is not found in the data.
        """
        if not self.available_categorical:
            print("Info: No categorical pipeline parameters found to plot.")
            return

        if metric not in self.clean_data.columns:
            raise ValueError(f"Metric '{metric}' not found in data.")

        n_params = len(self.available_categorical)
        cols = min(4, n_params)
        rows = (n_params + cols - 1) // cols  # ceiling division

        fig_size = figsize or (cols * 5, rows * 4)
        fig, axes = plt.subplots(rows, cols, figsize=fig_size, squeeze=False)
        axes = axes.flatten()

        for i, param in enumerate(self.available_categorical):
            ax = axes[i]

            # Only the two plotted columns are copied; dict.fromkeys removes
            # the duplicate if the metric happens to be the parameter itself.
            needed = list(dict.fromkeys((param, metric)))
            plot_data = self.clean_data.loc[:, needed].copy()
            plot_data[param] = self._prepare_categories(plot_data[param], param)
            order = self._category_order(plot_data[param])

            sns.boxplot(data=plot_data, x=param, y=metric, ax=ax, palette="Set3",
                        order=order, hue=param, legend=False)
            ax.set_title(f'{param.replace("_", " ").title()}', fontsize=11, fontweight='bold')
            ax.set_xlabel('Parameter Value')
            # Label the y axis only on the first column of the grid.
            ax.set_ylabel(metric.upper() if i % cols == 0 else '')
            ax.tick_params(axis='x', rotation=45)
            ax.grid(True, alpha=0.3)

        # Hide the unused cells of the grid.
        for unused_ax in axes[n_params:]:
            unused_ax.set_visible(False)

        plt.suptitle(f'Impact of Categorical Pipeline Parameters on {metric.upper()}',
                     fontsize=16, fontweight='bold')
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    def plot_continuous_parameters(
        self,
        metric: str = 'auc',
        figsize: Optional[Tuple[int, int]] = None
    ) -> None:
        """Creates scatter plots for continuous pipeline parameters.

        One subplot is drawn per available continuous parameter, at most
        three per row, each with a red regression line over the scatter.
        Prints an informational message and returns without plotting when no
        continuous parameter is available.

        Args:
            metric (str, optional): The performance metric to plot.
                Defaults to 'auc'.
            figsize (Optional[Tuple[int, int]], optional): Figure size for the
                plot. Defaults to None, which sizes the figure from the grid
                dimensions.

        Raises:
            ValueError: If the specified metric is not found in the data.
        """
        if not self.available_continuous:
            print("Info: No continuous pipeline parameters found to plot.")
            return

        if metric not in self.clean_data.columns:
            raise ValueError(f"Metric '{metric}' not found in data.")

        n_params = len(self.available_continuous)
        cols = min(3, n_params)
        rows = (n_params + cols - 1) // cols  # ceiling division

        fig_size = figsize or (cols * 6, rows * 5)
        fig, axes = plt.subplots(rows, cols, figsize=fig_size, squeeze=False)
        axes = axes.flatten()

        for i, (param, title) in enumerate(self.available_continuous.items()):
            ax = axes[i]

            # Only the two plotted columns are copied; dict.fromkeys removes
            # the duplicate if the metric happens to be the parameter itself.
            needed = list(dict.fromkeys((param, metric)))
            plot_data = self.clean_data.loc[:, needed].copy()

            # Ensure data is numeric and drop NaNs for plotting
            plot_data[param] = pd.to_numeric(plot_data[param], errors='coerce')
            plot_data = plot_data.dropna(subset=[param, metric])

            if plot_data.empty:
                ax.text(0.5, 0.5, 'No Valid Data', ha='center', va='center', transform=ax.transAxes)
                ax.set_title(title, fontsize=11, fontweight='bold')
                continue

            sns.scatterplot(data=plot_data, x=param, y=metric, ax=ax, alpha=0.5, edgecolor=None, s=15)
            sns.regplot(data=plot_data, x=param, y=metric, ax=ax, scatter=False,
                        color='red', line_kws={'linewidth': 2})

            ax.set_title(title, fontsize=11, fontweight='bold')
            ax.set_xlabel(title)
            # Label the y axis only on the first column of the grid.
            ax.set_ylabel(metric.upper() if i % cols == 0 else '')
            ax.grid(True, alpha=0.3)

        # Hide the unused cells of the grid.
        for unused_ax in axes[n_params:]:
            unused_ax.set_visible(False)

        plt.suptitle(f'Correlation of Continuous Pipeline Parameters with {metric.upper()}',
                     fontsize=16, fontweight='bold')
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()