# Source code for plot_best_model

# plot_best_model.py
"""
Module for analyzing and visualizing the single best performing model for each outcome.
"""

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Optional, Tuple, Dict, Any
import warnings
import textwrap
import ast

from ml_grid.results_processing.core import get_clean_data
from ml_grid.results_processing.plot_hyperparameters import HyperparameterAnalysisPlotter # To reuse parsing logic

# Upper bound on how many outcomes are plotted automatically (when the caller
# does not pass an explicit `outcomes_to_plot` list), to avoid generating an
# excessive number of figures.
MAX_OUTCOMES_TO_PLOT = 10
class BestModelAnalyzerPlotter:
    """
    Analyzes and plots the characteristics of the best performing model for each outcome variable.
    """

    def __init__(self, data: pd.DataFrame):
        """
        Initialize the plotter.

        Args:
            data: Aggregated results DataFrame. Must contain 'outcome_variable'.

        Raises:
            ValueError: If ``data`` has no 'outcome_variable' column.
        """
        if 'outcome_variable' not in data.columns:
            raise ValueError("Data must contain an 'outcome_variable' column for this analysis.")

        self.data = data
        self.clean_data = get_clean_data(data)

        # Define feature categories and pipeline parameters from other modules for consistency
        self.feature_categories = [
            'age', 'sex', 'bmi', 'ethnicity', 'bloods', 'diagnostic_order', 'drug_order',
            'annotation_n', 'meta_sp_annotation_n', 'meta_sp_annotation_mrc_n',
            'annotation_mrc_n', 'core_02', 'bed', 'vte_status', 'hosp_site',
            'core_resus', 'news', 'date_time_stamp'
        ]
        self.pipeline_params = ['resample', 'scale', 'param_space_size', 'percent_missing']

        # NOTE: these calls change global matplotlib/seaborn styling for the session.
        plt.style.use('default')
        sns.set_palette("muted")

    @staticmethod
    def _format_metric(value: Any) -> str:
        """Format a metric value with 4 decimal places, falling back to ``str``.

        Values that ``float()`` cannot convert (e.g. the 'N/A' placeholder used
        for a missing column, or None) are returned as their string form
        instead of raising, so a missing metric never aborts the plot.
        """
        try:
            return f"{float(value):.4f}"
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _is_flag_enabled(val: Any) -> bool:
        """Interpret a feature-category cell as an on/off flag.

        Strings are capitalised and parsed as Python literals, so 'true' -> True,
        'false' -> False, '1' -> truthy. Non-strings use ``bool()``. Anything that
        cannot be interpreted is treated as "not used".
        """
        try:
            if isinstance(val, str):
                return bool(ast.literal_eval(val.capitalize()))
            return bool(val)
        except (ValueError, SyntaxError, TypeError):
            return False

    def _get_best_models(self, metric: str) -> pd.DataFrame:
        """Finds the single best model for each outcome variable.

        Args:
            metric (str): The performance metric to use for determining the best model.

        Returns:
            pd.DataFrame: A DataFrame containing the single best run for each outcome,
            sorted by the specified metric in descending order. Rows with a missing
            metric value are ignored; if none remain, an empty DataFrame is returned.

        Raises:
            ValueError: If ``metric`` is not a column of the cleaned data.
        """
        if metric not in self.clean_data.columns:
            raise ValueError(f"Metric '{metric}' not found in the data.")

        # Unscored runs cannot be ranked. Dropping them first also prevents
        # idxmax() from yielding NaN for an all-NaN outcome group, which would
        # make the .loc lookup below fail.
        scored = self.clean_data.dropna(subset=[metric])
        if scored.empty:
            return scored

        # Index label of the maximum metric value within each outcome group
        best_idx = scored.groupby('outcome_variable')[metric].idxmax()
        best_models = scored.loc[best_idx]
        return best_models.sort_values(by=metric, ascending=False)

    def plot_best_model_summary(self, metric: str = 'auc', outcomes_to_plot: Optional[List[str]] = None,
                                figsize: Tuple[int, int] = (14, 9)):
        """Generates a summary plot for the best model of each outcome.

        This method finds the best performing model for each outcome and creates
        a detailed 2x2 plot summarizing its algorithm, performance,
        hyperparameters, and pipeline settings.

        Args:
            metric (str, optional): The metric to determine the "best" model.
                Defaults to 'auc'.
            outcomes_to_plot (Optional[List[str]], optional): A specific list of
                outcomes to analyze. If None, analyzes all outcomes up to
                ``MAX_OUTCOMES_TO_PLOT``. Defaults to None.
            figsize (Tuple[int, int], optional): The figure size for each
                summary plot. Defaults to (14, 9).

        Raises:
            ValueError: If ``metric`` is not a column of the cleaned data.
        """
        best_models_df = self._get_best_models(metric)

        if outcomes_to_plot:
            # Filter to only the requested outcomes
            best_models_df = best_models_df[best_models_df['outcome_variable'].isin(outcomes_to_plot)]
            if best_models_df.empty:
                warnings.warn(
                    f"No data found for the specified outcomes: {outcomes_to_plot}",
                    stacklevel=2
                )
                return
        elif best_models_df.empty:
            warnings.warn(
                f"No models with a valid '{metric}' score were found; nothing to plot.",
                stacklevel=2
            )
            return
        elif len(best_models_df) > MAX_OUTCOMES_TO_PLOT:
            warnings.warn(
                f"Found {len(best_models_df)} unique outcomes. To avoid excessive plotting, "
                f"showing summaries for the top {MAX_OUTCOMES_TO_PLOT} outcomes based on the '{metric}' metric. "
                "Use the 'outcomes_to_plot' parameter to specify which outcomes to analyze.",
                stacklevel=2
            )
            # Rows are already sorted best-first, so head() keeps the top outcomes.
            best_models_df = best_models_df.head(MAX_OUTCOMES_TO_PLOT)

        print(f"--- Generating Best Model Summaries (Metric: {metric.upper()}) ---")

        for _, model_series in best_models_df.iterrows():
            self._plot_single_model_summary(model_series, metric, figsize)

    def _plot_single_model_summary(self, model_series: pd.Series, metric: str, figsize: Tuple[int, int]):
        """Generates a single 2x2 summary plot for one model series.

        Args:
            model_series (pd.Series): A row from the DataFrame representing the
                best model for a single outcome.
            metric (str): The primary performance metric being used.
            figsize (Tuple[int, int]): The figure size for the plot.
        """
        fig, axes = plt.subplots(2, 2, figsize=figsize)
        fig.suptitle(f"Best Model Analysis for: {model_series['outcome_variable']}",
                     fontsize=16, fontweight='bold')

        # Subplot 1: Key Information (Text)
        self._plot_key_info(axes[0, 0], model_series, metric)
        # Subplot 2: Hyperparameters
        self._plot_hyperparameters(axes[0, 1], model_series)
        # Subplot 3: Feature Categories Used
        self._plot_feature_categories(axes[1, 0], model_series)
        # Subplot 4: Pipeline Parameters
        self._plot_pipeline_parameters(axes[1, 1], model_series)

        # Leave headroom at the top for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    def _plot_key_info(self, ax: plt.Axes, model_series: pd.Series, metric: str):
        """Plots key model and performance info on a given axis.

        Args:
            ax (plt.Axes): The matplotlib axis to plot on.
            model_series (pd.Series): The data for the best model.
            metric (str): The name of the primary metric.
        """
        ax.set_title("Model & Performance Summary", fontsize=12, fontweight='bold')
        ax.axis('off')

        # Missing metrics fall back to 'N/A' and are rendered verbatim rather
        # than crashing the ':.4f' format.
        score_str = self._format_metric(model_series.get(metric, 'N/A'))
        f1_str = self._format_metric(model_series.get('f1', 'N/A'))
        mcc_str = self._format_metric(model_series.get('mcc', 'N/A'))
        acc_str = self._format_metric(model_series.get('accuracy', 'N/A'))

        info_text = (
            f"Algorithm: {model_series.get('method_name', 'N/A')}\n"
            f"Best Score ({metric.upper()}): {score_str}\n"
            f"Number of Features: {model_series.get('nb_size', 'N/A')}\n"
            f"Run Timestamp: {model_series.get('run_timestamp', 'N/A')}\n\n"
            f"Other Metrics:\n"
            f" - F1: {f1_str}\n"
            f" - MCC: {mcc_str}\n"
            f" - Accuracy: {acc_str}\n"
        )

        ax.text(0.05, 0.95, info_text, transform=ax.transAxes, ha='left', va='top', fontsize=11,
                bbox=dict(boxstyle='round,pad=0.5', fc='aliceblue', ec='grey', lw=1))

    def _plot_hyperparameters(self, ax: plt.Axes, model_series: pd.Series):
        """Plots the hyperparameters of the model on a given axis.

        Args:
            ax (plt.Axes): The matplotlib axis to plot on.
            model_series (pd.Series): The data for the best model.
        """
        ax.set_title("Hyperparameters", fontsize=12, fontweight='bold')
        ax.axis('off')

        params = {}
        if 'algorithm_implementation' in model_series and pd.notna(model_series['algorithm_implementation']):
            # Reuse parsing logic from HyperparameterAnalysisPlotter
            params = HyperparameterAnalysisPlotter._parse_model_string_to_params(
                model_series['algorithm_implementation']
            )

        if not params:
            ax.text(0.5, 0.5, "Hyperparameters not available\nor could not be parsed.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            return

        lines = []
        for key, val in params.items():
            val_str = str(val)
            # Wrap long values so they fit inside the subplot.
            if len(val_str) > 40:
                val_str = textwrap.fill(val_str, width=40, subsequent_indent=' ')
            lines.append(f"{key}: {val_str}")
        param_str = "\n".join(lines).strip()

        ax.text(0.05, 0.95, param_str, transform=ax.transAxes, ha='left', va='top', fontsize=9,
                family='monospace',
                bbox=dict(boxstyle='round,pad=0.5', fc='lightyellow', ec='grey', lw=1))

    def _plot_feature_categories(self, ax: plt.Axes, model_series: pd.Series):
        """Plots which feature categories were used on a given axis.

        Args:
            ax (plt.Axes): The matplotlib axis to plot on.
            model_series (pd.Series): The data for the best model.
        """
        ax.set_title("Feature Categories Used", fontsize=12, fontweight='bold')

        used_categories = {}
        for cat in self.feature_categories:
            if cat in model_series and pd.notna(model_series[cat]):
                if self._is_flag_enabled(model_series[cat]):
                    used_categories[cat.replace("_", " ").title()] = 1

        if not used_categories:
            ax.text(0.5, 0.5, "No feature category information available.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            ax.set_xticks([])
            ax.set_yticks([])
            return

        cat_df = pd.DataFrame.from_dict(used_categories, orient='index', columns=['Used']).sort_index()
        sns.barplot(x=cat_df.index, y=cat_df['Used'], ax=ax, palette='viridis', hue=cat_df.index, legend=False)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
        ax.set_xlabel("")
        ax.set_ylabel("Enabled")
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['', 'Yes'])
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    def _plot_pipeline_parameters(self, ax: plt.Axes, model_series: pd.Series):
        """Plots the pipeline settings on a given axis.

        Args:
            ax (plt.Axes): The matplotlib axis to plot on.
            model_series (pd.Series): The data for the best model.
        """
        ax.set_title("Pipeline Settings", fontsize=12, fontweight='bold')
        ax.axis('off')

        pipeline_settings = {
            param.replace("_", " ").title(): model_series[param]
            for param in self.pipeline_params
            if param in model_series and pd.notna(model_series[param])
        }

        if not pipeline_settings:
            ax.text(0.5, 0.5, "No pipeline setting information available.",
                    transform=ax.transAxes, ha='center', va='center', fontsize=10)
            return

        settings_str = "\n".join(f"{key}: {val}" for key, val in pipeline_settings.items()).strip()

        ax.text(0.05, 0.95, settings_str, transform=ax.transAxes, ha='left', va='top', fontsize=11,
                bbox=dict(boxstyle='round,pad=0.5', fc='honeydew', ec='grey', lw=1))