Source code for ml_grid.util.db_backend

import sqlite3
import json
import logging
from typing import Any, Dict
import numpy as np

[docs] logger = logging.getLogger("ml_grid")
[docs] class DatabaseBackend: """Simple SQLite backend for storing ML experiment results.""" def __init__(self, db_path: str = "ml_results.db"):
[docs] self.db_path = db_path
self._init_db() def _get_connection(self): # Use a 30s timeout to handle concurrent writes during parallel runs return sqlite3.connect(self.db_path, timeout=30.0) def _init_db(self): """Initializes the results table if it doesn't exist.""" with self._get_connection() as conn: conn.execute( """ CREATE TABLE IF NOT EXISTS ml_results ( id INTEGER PRIMARY KEY AUTOINCREMENT, run_timestamp TEXT, date_time_stamp TEXT, method_name TEXT, outcome_variable TEXT, auc REAL, f1 REAL, accuracy REAL, precision REAL, recall REAL, mcc REAL, support REAL, fit_time REAL, score_time REAL, run_time REAL, nb_size INTEGER, n_features INTEGER, X_train_size INTEGER, X_test_orig_size INTEGER, X_test_size INTEGER, n_fits INTEGER, t_fits INTEGER, param_space_index INTEGER, parameters TEXT, f_list TEXT, resample TEXT, scale TEXT, percent_missing REAL, corr REAL, feature_selection_method TEXT, use_embedding TEXT, embedding_method TEXT, embedding_dim INTEGER, age TEXT, sex TEXT, bmi TEXT, ethnicity TEXT, bloods TEXT, diagnostic_order TEXT, drug_order TEXT, annotation_n TEXT, meta_sp_annotation_n TEXT, meta_sp_annotation_mrc_n TEXT, annotation_mrc_n TEXT, core_02 TEXT, bed TEXT, vte_status TEXT, hosp_site TEXT, core_resus TEXT, news TEXT, failed TEXT, timeout TEXT, run_path TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) """ ) conn.commit() def _sanitize_value(self, value: Any) -> Any: """Converts numpy types and search space objects to native Python types for SQLite.""" if isinstance(value, (np.float32, np.float64)): return float(value) if isinstance(value, (np.int32, np.int64)): return int(value) if isinstance(value, (bool, np.bool_)): return bool(value) if isinstance(value, np.ndarray): return value.tolist() if isinstance(value, dict): return { self._sanitize_value(k): self._sanitize_value(v) for k, v in value.items() } if isinstance(value, list): return [self._sanitize_value(item) for item in value] # Handle non-serializable search space objects (skopt Integer, Real, Categorical) if value.__class__.__name__ in ["Integer", "Real", "Categorical"]: return str(value) return value
[docs] def insert_result(self, result_data: Dict[str, Any]): """Inserts a single result record into the database. Args: result_data: Dictionary containing metric names and values. """ # Define all expected columns in the database table columns = [ "run_timestamp", "date_time_stamp", "method_name", "outcome_variable", "auc", "f1", "accuracy", "precision", "recall", "mcc", "support", "fit_time", "score_time", "run_time", "nb_size", "n_features", "X_train_size", "X_test_orig_size", "X_test_size", "n_fits", "t_fits", "param_space_index", "parameters", "f_list", "resample", "scale", "percent_missing", "corr", "feature_selection_method", "use_embedding", "embedding_method", "embedding_dim", "age", "sex", "bmi", "ethnicity", "bloods", "diagnostic_order", "drug_order", "annotation_n", "meta_sp_annotation_n", "meta_sp_annotation_mrc_n", "annotation_mrc_n", "core_02", "bed", "vte_status", "hosp_site", "core_resus", "news", "failed", "timeout", "run_path", ] # Map keys from result_data to columns, handling common aliases values = [] for col in columns: val = result_data.get(col) # Key mapping for aliases used in project_score_save if val is None: if col == "run_timestamp": val = result_data.get("date_time_stamp") elif col == "date_time_stamp": val = result_data.get("run_timestamp") elif col == "parameters": val = result_data.get("parameter_sample") elif col == "fit_time": val = result_data.get("fit_time_m") elif col == "score_time": val = result_data.get("score_time_m") elif col == "param_space_index": val = result_data.get("i") elif col == "nb_size": val = result_data.get("nb_size") elif col == "run_path": val = "." # Apply recursive sanitization for numpy and search space types val = self._sanitize_value(val) # Explicitly handle boolean and flag columns to ensure they are stored as # "True"/"False" or "None" strings to match column type (TEXT). # This prevents sqlite3 from defaulting to 1/0 integers for bool types. if col in [ "failed", "timeout", "age", "sex", "bmi", "ethnicity", "bloods", "diagnostic_order", "drug_order", "annotation_n", "meta_sp_annotation_n", "meta_sp_annotation_mrc_n", "annotation_mrc_n", "core_02", "bed", "vte_status", "hosp_site", "core_resus", "news", "resample", "scale", "use_embedding", "scale_features_before_embedding", ]: val = str(val) # Robust numeric casting for metric columns to prevent string pollution if col in [ "auc", "f1", "accuracy", "precision", "recall", "mcc", "support", "fit_time", "score_time", "run_time", "percent_missing", "corr", ]: try: val = float(val) if val is not None else 0.0 except (ValueError, TypeError): val = 0.0 elif col in [ "nb_size", "n_features", "X_train_size", "X_test_orig_size", "X_test_size", "n_fits", "t_fits", "param_space_index", "embedding_dim", ]: try: val = int(val) if val is not None else 0 except (ValueError, TypeError): val = 0 # Ensure JSON columns have valid defaults if data is missing if col == "parameters" and val is None: val = {} if col == "f_list" and val is None: val = [] if col in ["parameters", "f_list"] and isinstance(val, (dict, list)): val = json.dumps(val) values.append(val) placeholders = ", ".join(["?"] * len(columns)) col_string = ", ".join(columns) query = f"INSERT INTO ml_results ({col_string}) VALUES ({placeholders})" try: with self._get_connection() as conn: conn.execute(query, values) conn.commit() except Exception as e: logger.error(f"Failed to insert result into database: {e}")