# plot_master.py
"""
Master plotting module that provides a single entry point to generate a
comprehensive set of visualizations for ML results analysis.
"""
import os
import pandas as pd
# Import all the individual plotter classes
from ml_grid.results_processing.plot_algorithms import AlgorithmComparisonPlotter
from ml_grid.results_processing.plot_distributions import DistributionPlotter
from ml_grid.results_processing.plot_features import FeatureAnalysisPlotter
from ml_grid.results_processing.plot_timeline import TimelineAnalysisPlotter
from ml_grid.results_processing.plot_hyperparameters import HyperparameterAnalysisPlotter
from ml_grid.results_processing.plot_feature_categories import FeatureCategoryPlotter
from ml_grid.results_processing.plot_pipeline_parameters import PipelineParameterPlotter
from ml_grid.results_processing.plot_global_importance import GlobalImportancePlotter
from ml_grid.results_processing.plot_interactions import InteractionPlotter
from ml_grid.results_processing.plot_best_model import BestModelAnalyzerPlotter
from ml_grid.results_processing.summarize_results import ResultsSummarizer
class MasterPlotter:
"""
A facade that orchestrates specialized plotters to generate analysis plots.
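
    Example:
        A minimal usage sketch, assuming ``results_df`` is a non-empty
        DataFrame of aggregated experiment results::

            plotter = MasterPlotter(results_df, output_dir='plots')
            plotter.plot_all(metric='auc_m')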
"""
def __init__(self, data: pd.DataFrame, output_dir: str = '.'):
"""Initializes the MasterPlotter with aggregated results data.
This class acts as a facade, instantiating various specialized plotters
to generate a comprehensive suite of analysis visualizations from the
provided results DataFrame.
Args:
data (pd.DataFrame): A DataFrame containing the aggregated ML
experiment results. Must be non-empty.
output_dir (str, optional): The directory where output files (like
CSVs) will be saved. Defaults to '.'.
Raises:
ValueError: If the input `data` is not a valid, non-empty
pandas DataFrame.
"""
if not isinstance(data, pd.DataFrame) or data.empty:
raise ValueError("Input data must be a non-empty pandas DataFrame.")
        self.data = data
self.output_dir = output_dir
# Instantiate all the specialized plotters
self.algo_plotter = AlgorithmComparisonPlotter(self.data)
self.dist_plotter = DistributionPlotter(self.data)
self.timeline_plotter = TimelineAnalysisPlotter(self.data)
self.interaction_plotter = InteractionPlotter(self.data)
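        # The remaining plotters validate their required columns at construction
        # time; each is wrapped in try/except so that a missing column downgrades
        # to a warning rather than aborting the whole analysis.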
# Global importance plotter
try:
self.global_importance_plotter = GlobalImportancePlotter(self.data)
except ValueError as e:
self.global_importance_plotter = None
print(f"Warning: Could not initialize GlobalImportancePlotter. Reason: {e}")
# Pipeline parameter plotter
try:
self.pipeline_plotter = PipelineParameterPlotter(self.data)
except ValueError as e:
self.pipeline_plotter = None
print(f"Warning: Could not initialize PipelineParameterPlotter. Reason: {e}")
# Feature category plotter
try:
self.feature_cat_plotter = FeatureCategoryPlotter(self.data)
except ValueError as e:
self.feature_cat_plotter = None
print(f"Warning: Could not initialize FeatureCategoryPlotter. Reason: {e}")
# Feature plotter requires 'decoded_features' column
if 'decoded_features' in self.data.columns:
self.feature_plotter = FeatureAnalysisPlotter(self.data)
else:
self.feature_plotter = None
print("Warning: 'decoded_features' column not found. Feature-related plots will be skipped.")
# Hyperparameter plotter requires 'algorithm_implementation' column
if 'algorithm_implementation' in self.data.columns:
self.hyperparam_plotter = HyperparameterAnalysisPlotter(self.data)
else:
self.hyperparam_plotter = None
print("Warning: 'algorithm_implementation' column not found. Hyperparameter-related plots will be skipped.")
# Best model plotter
try:
self.best_model_plotter = BestModelAnalyzerPlotter(self.data)
except ValueError as e:
self.best_model_plotter = None
print(f"Warning: Could not initialize BestModelAnalyzerPlotter. Reason: {e}")
# Results summarizer
try:
self.summarizer = ResultsSummarizer(self.data)
except ValueError as e:
self.summarizer = None
print(f"Warning: Could not initialize ResultsSummarizer. Reason: {e}")
def plot_all(self,
metric: str = 'auc_m',
stratify_by_outcome: bool = True,
top_n_features: int = 20,
top_n_algorithms: int = 10,
save_best_results: bool = True) -> None:
"""Generates a comprehensive set of standard plots from all plotters.
This method calls the main plotting functions from each specialized
plotter to provide a full overview of the results, including algorithm
comparisons, metric distributions, timeline trends, and feature
importance. It also handles saving a summary of the best models.
Args:
metric (str, optional): The primary performance metric to use for
plotting (e.g., 'auc', 'f1'). Defaults to 'auc_m'.
stratify_by_outcome (bool, optional): If True, creates plots
stratified by the 'outcome_variable' column. Defaults to True.
top_n_features (int, optional): The number of top features to show
in feature-related plots. Defaults to 20.
top_n_algorithms (int, optional): The number of top algorithms to
show in ranking plots. Defaults to 10.
save_best_results (bool, optional): If True, saves a CSV summary of
the best model per outcome. Defaults to True.
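
        Example:
            A minimal call sketch (argument values are illustrative)::

                plotter.plot_all(metric='auc', top_n_features=15)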
"""
print(f"--- Starting MasterPlotter.plot_all() ---", flush=True)
print(f"Parameters: metric='{metric}', stratify_by_outcome={stratify_by_outcome}, save_best_results={save_best_results}", flush=True)
# --- Step 1: Generate and Save Summary Table First ---
# This is done first to avoid being blocked by interactive plot windows.
if self.summarizer and save_best_results:
print("\n>>> 1. Generating and Saving Best Model Summary Table...", flush=True)
try:
                # Prefer 'auc' for the summary table when available, as it is the most standard metric
summary_metric = 'auc' if 'auc' in self.data.columns else metric
print(f" - Using metric '{summary_metric}' for summary table.", flush=True)
if 'outcome_variable' not in self.data.columns:
raise ValueError("'outcome_variable' column is required to find the best model per outcome.")
if 'decoded_features' not in self.data.columns:
raise ValueError("'decoded_features' column is required to expand feature names.")
best_models_df = self.summarizer.get_best_model_per_outcome(metric=summary_metric)
# Ensure the output directory exists before saving
os.makedirs(self.output_dir, exist_ok=True)
output_path = os.path.join(self.output_dir, "best_models_summary.csv")
best_models_df.to_csv(output_path, index=False)
print(f"✅ Best models summary successfully saved to: {os.path.abspath(output_path)}", flush=True)
except Exception as e:
print(f"❌ Warning: Could not generate or save best models summary table. Reason: {e}", flush=True)
elif not self.summarizer:
print("\n>>> 1. Skipping Best Model Summary Table: ResultsSummarizer was not initialized.", flush=True)
elif not save_best_results:
print("\n>>> 1. Skipping Best Model Summary Table: 'save_best_results' is False.", flush=True)
# --- Step 2: Generate Plots ---
print("\n>>> 2. Generating Algorithm Comparison Plots...")
try:
self.algo_plotter.plot_algorithm_boxplots(metric=metric, stratify_by_outcome=stratify_by_outcome)
self.algo_plotter.plot_algorithm_performance_heatmap(metric=metric, aggregation='mean')
self.algo_plotter.plot_algorithm_ranking(metric=metric, stratify_by_outcome=stratify_by_outcome, top_n=top_n_algorithms)
self.algo_plotter.plot_algorithm_stability(metric=metric, top_n=top_n_algorithms)
self.algo_plotter.plot_performance_tradeoff(metric_y=metric, metric_x='run_time', top_n_algos=top_n_algorithms)
self.algo_plotter.plot_pareto_front(metric_y=metric, metric_x='run_time')
self.algo_plotter.plot_statistical_significance_heatmap(metric=metric)
except Exception as e:
print(f"Warning: Could not generate algorithm plots. Reason: {e}")
print("\n>>> 3. Generating Distribution Plots...")
try:
self.dist_plotter.plot_metric_distributions(metrics=[metric, 'f1', 'mcc'], stratify_by_outcome=stratify_by_outcome)
self.dist_plotter.plot_comparative_distributions(metric=metric, plot_type='violin')
except Exception as e:
print(f"Warning: Could not generate distribution plots. Reason: {e}")
print("\n>>> 4. Generating Timeline Plots...")
try:
self.timeline_plotter.plot_performance_timeline(metric=metric, stratify_by_outcome=stratify_by_outcome)
self.timeline_plotter.plot_improvement_trends(metric=metric, stratify_by_outcome=stratify_by_outcome)
self.timeline_plotter.plot_computational_cost_timeline(stratify_by_outcome=stratify_by_outcome)
except Exception as e:
print(f"Warning: Could not generate timeline plots. Reason: {e}")
if self.feature_plotter:
print("\n>>> 5. Generating Feature Analysis Plots...")
try:
self.feature_plotter.plot_feature_usage_frequency(top_n=top_n_features, stratify_by_outcome=stratify_by_outcome)
self.feature_plotter.plot_feature_performance_impact(metric=metric, top_n=top_n_features // 2)
self.feature_plotter.plot_feature_metric_correlation(metric=metric, top_n=top_n_features // 2)
self.feature_plotter.plot_feature_set_intersections(top_n_sets=15, stratify_by_outcome=stratify_by_outcome)
except Exception as e:
print(f"Warning: Could not generate feature plots. Reason: {e}")
if self.hyperparam_plotter:
print("\n>>> 6. Generating Hyperparameter Analysis Plots...")
try:
# Get algorithms that have parsable hyperparameters
available_algos = self.hyperparam_plotter.get_available_algorithms()
if not available_algos:
print("Info: No algorithms with parsable hyperparameters found. Skipping hyperparameter plots.")
else:
# Find the top overall algorithm *that has parsable hyperparameters*
plottable_data = self.data[self.data['method_name'].isin(available_algos)]
if plottable_data.empty:
print("Info: No data found for algorithms with parsable hyperparameters. Skipping.")
else:
top_algo = plottable_data.groupby('method_name')[metric].mean().idxmax()
print(f"Analyzing hyperparameters for top algorithm: {top_algo}")
self.hyperparam_plotter.plot_hyperparameter_importance(algorithm_name=top_algo, metric=metric)
# For the second plot, find a hyperparameter of the top algorithm to analyze
# This uses the parsed data within the hyperparam_plotter instance
algo_data = self.hyperparam_plotter.clean_data[self.hyperparam_plotter.clean_data['algorithm_name'] == top_algo]
# Get all hyperparameters for this algorithm from the first valid param dict
params_list = [p for p in algo_data['params_dict'] if p]
if params_list:
# Find all hyperparameters that have more than one unique value to make interesting plots
hyperparameters_to_plot = []
all_keys = set().union(*(d.keys() for d in params_list))
for param in sorted(list(all_keys)):
if algo_data['params_dict'].apply(lambda p: p.get(param) if p else None).nunique() > 1:
hyperparameters_to_plot.append(param)
if hyperparameters_to_plot:
self.hyperparam_plotter.plot_performance_by_hyperparameter(
algorithm_name=top_algo,
hyperparameters=hyperparameters_to_plot,
metric=metric
)
else:
print(f"Info: No hyperparameters with multiple values found for {top_algo} to generate performance plot.")
except Exception as e:
print(f"Warning: Could not generate hyperparameter plots. Reason: {e}")
if self.feature_cat_plotter:
print("\n>>> 7. Generating Feature Category Analysis Plots...")
try:
                # Prefer 'auc' when available, falling back to the main metric otherwise
category_metric = 'auc' if 'auc' in self.data.columns else metric
print(f"Analyzing feature category impact on metric: {category_metric.upper()}")
self.feature_cat_plotter.plot_category_performance_boxplots(metric=category_metric)
self.feature_cat_plotter.plot_category_impact_on_metric(metric=category_metric)
except Exception as e:
print(f"Warning: Could not generate feature category plots. Reason: {e}")
if self.pipeline_plotter:
print("\n>>> 8. Generating Pipeline Parameter Analysis Plots...")
try:
                # Prefer 'auc' when available, falling back to the main metric otherwise
pipeline_metric = 'auc' if 'auc' in self.data.columns else metric
print(f"Analyzing pipeline parameter impact on metric: {pipeline_metric.upper()}")
self.pipeline_plotter.plot_categorical_parameters(metric=pipeline_metric)
self.pipeline_plotter.plot_continuous_parameters(metric=pipeline_metric)
except Exception as e:
print(f"Warning: Could not generate pipeline parameter plots. Reason: {e}")
if self.global_importance_plotter:
print("\n>>> 9. Generating Global Importance Analysis Plot...")
try:
                # Prefer 'auc' when available, falling back to the main metric otherwise
global_metric = 'auc' if 'auc' in self.data.columns else metric
self.global_importance_plotter.plot_global_importance(metric=global_metric, top_n=40)
except Exception as e:
print(f"Warning: Could not generate global importance plot. Reason: {e}")
if self.interaction_plotter:
print("\n>>> 10. Generating Interaction Analysis Plots...")
try:
                # Example interaction plot; prefer 'auc' when available, consistent with the other sections
                interaction_metric = 'auc' if 'auc' in self.data.columns else metric
                self.interaction_plotter.plot_categorical_interaction(param1='resample', param2='scale', metric=interaction_metric)
except Exception as e:
print(f"Warning: Could not generate interaction plots. Reason: {e}")
if self.best_model_plotter:
print("\n>>> 11. Generating Best Model Analysis Plots...")
try:
                # Prefer 'auc' when available, falling back to the main metric otherwise
best_model_metric = 'auc' if 'auc' in self.data.columns else metric
self.best_model_plotter.plot_best_model_summary(metric=best_model_metric)
except Exception as e:
print(f"Warning: Could not generate best model plots. Reason: {e}")
print("\n--- All Plot Generation Complete ---", flush=True)