# plot_feature_categories.py
"""
Feature category analysis plotting module for ML results analysis.
Focuses on visualizing the impact of including different data source categories on model performance.
"""
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Optional
import warnings
import ast
from ml_grid.results_processing.core import get_clean_data
class FeatureCategoryPlotter:
    """Visualizes the impact of feature categories on model performance.

    These categories correspond to the boolean flags that control which data
    sources are included at the start of the data pipeline.

    Attributes:
        clean_data: The value returned by ``get_clean_data(data)``.
        feature_categories: Names of all feature category columns this
            plotter knows about.
        available_categories: The entries of ``feature_categories`` that are
            columns of ``clean_data``.
    """

    def __init__(self, data: pd.DataFrame):
        """Initializes the FeatureCategoryPlotter.

        Also resets the matplotlib style to 'default' and sets the seaborn
        palette to "viridis" (process-wide plotting state).

        Args:
            data (pd.DataFrame): Results DataFrame, must contain boolean columns
                for feature categories and performance metrics.

        Raises:
            ValueError: If no feature category columns are found in the data.
        """
        self.clean_data = get_clean_data(data)
        self.feature_categories = [
            'age', 'sex', 'bmi', 'ethnicity', 'bloods', 'diagnostic_order', 'drug_order',
            'annotation_n', 'meta_sp_annotation_n', 'meta_sp_annotation_mrc_n',
            'annotation_mrc_n', 'core_02', 'bed', 'vte_status', 'hosp_site',
            'core_resus', 'news', 'date_time_stamp'
        ]
        # Filter to only categories present in the data
        self.available_categories = [
            cat for cat in self.feature_categories if cat in self.clean_data.columns
        ]
        if not self.available_categories:
            raise ValueError("No feature category columns (e.g., 'bloods', 'age') found in the provided data.")
        plt.style.use('default')
        sns.set_palette("viridis")

    @staticmethod
    def _parse_flag(value) -> bool:
        """Converts one non-missing cell of a category column to a bool.

        Strings are parsed rather than tested for truthiness, because any
        non-empty string (including "False") is truthy.
        """
        if isinstance(value, str):
            text = value.strip()
            lowered = text.lower()
            # ast.literal_eval only accepts the exact spellings True/False.
            if lowered in ('true', 'false'):
                return lowered == 'true'
            return bool(ast.literal_eval(text))
        return bool(value)

    @classmethod
    def _coerce_to_bool(cls, flags: pd.Series) -> pd.Series:
        """Returns ``flags`` (which must hold no missing values) as a bool Series."""
        if flags.dtype == bool:
            return flags
        return flags.map(cls._parse_flag).astype(bool)

    def _category_impact(self, category: str, scores: pd.Series) -> Optional[float]:
        """Computes (mean score with category) - (mean score without category).

        Rows whose flag is missing are left out of both groups. Returns None
        when the flags cannot be interpreted, when the category was not both
        included and excluded at least once, or when the difference is NaN.
        """
        flags = self.clean_data[category]
        # Positional mask: avoids label alignment problems with non-unique indexes.
        known = flags.notna().to_numpy()
        try:
            included = self._coerce_to_bool(flags[known]).to_numpy()
        except (ValueError, SyntaxError, TypeError) as exc:
            warnings.warn(
                f"Skipping feature category '{category}': "
                f"values could not be interpreted as booleans ({exc}).",
                stacklevel=3,
            )
            return None

        # Need at least one included and one excluded run to compare.
        if included.all() or not included.any():
            return None

        known_scores = scores[known]
        impact = known_scores[included].mean() - known_scores[~included].mean()
        return None if pd.isna(impact) else impact

    def plot_category_impact_on_metric(
        self, metric: str = 'auc', figsize: Tuple[int, int] = (10, 8)
    ) -> None:
        """Plots the impact of including each feature category on a metric.

        Impact is calculated as:
            (Mean metric with category) - (Mean metric without category)

        Categories whose flag column cannot be read as booleans are skipped
        with a warning; categories that were never both included and excluded
        are skipped silently. If nothing remains, a message is printed and no
        figure is created.

        Args:
            metric (str, optional): The performance metric to evaluate.
                Defaults to 'auc'.
            figsize (Tuple[int, int], optional): The figure size for the plot.
                Defaults to (10, 8).

        Raises:
            ValueError: If the specified metric is not found in the data.
        """
        if metric not in self.clean_data.columns:
            raise ValueError(f"Metric '{metric}' not found in data.")

        scores = self.clean_data[metric]
        impact_data = []
        for category in self.available_categories:
            impact = self._category_impact(category, scores)
            if impact is not None:
                impact_data.append({
                    'category': category.replace("_", " ").title(),
                    'impact': impact,
                })

        if not impact_data:
            print("Could not calculate impact for any feature categories. This may be because no categories had both included and excluded runs.")
            return

        impact_df = pd.DataFrame(impact_data).sort_values('impact', ascending=False)

        plt.figure(figsize=figsize)
        # Green for a positive change, red otherwise; one colour per bar, in bar order.
        colors = ['#3a923a' if x > 0 else '#c14242' for x in impact_df['impact']]
        ax = sns.barplot(x='impact', y='category', data=impact_df, orient='h', palette=colors, hue='category', legend=False)
        ax.set_title(f'Impact of Including Feature Category on {metric.upper()}', fontsize=14, fontweight='bold')
        ax.set_xlabel(f'Change in Mean {metric.upper()} (Included vs. Excluded)', fontsize=12)
        ax.set_ylabel('Feature Category', fontsize=12)
        ax.axvline(0, color='black', linewidth=0.8, linestyle='--')

        # Add value labels
        for container in ax.containers:
            ax.bar_label(container, fmt='%.4f', padding=3, fontsize=9)

        plt.tight_layout()
        plt.show()