Source code for ml_grid.pipeline.column_names

from typing import Any, Dict, List, Tuple

from fuzzysearch import find_near_matches

from ml_grid.pipeline.data_plot_split import (
    plot_candidate_feature_category_lists,
    plot_dict_values,
)
from ml_grid.util.global_params import global_parameters


[docs] def filter_substring_list(string: List[str], substr: List[str]) -> List[str]: """Filters a list of strings based on a list of substrings. Args: string (List[str]): The list of strings to filter. substr (List[str]): The list of substrings to search for. Returns: List[str]: A new list containing strings from the input list that contain any of the specified substrings, excluding those containing "bmi". """ return [ s for s in string if any(sub in s for sub in substr) and "bmi" not in s ]
[docs] def get_pertubation_columns( all_df_columns: List[str], local_param_dict: Dict[str, Any], drop_term_list: List[str], ) -> Tuple[List[str], List[str]]: """Identifies and categorizes columns for perturbation and dropping. This function processes a list of all DataFrame columns, categorizing them into groups like blood tests, diagnostic orders, etc. It also identifies columns to be dropped based on specific keywords. The selection of columns for 'perturbation' is determined by flags within `local_param_dict`. Args: all_df_columns (List[str]): A list of all column names in the DataFrame. local_param_dict (Dict[str, Any]): A dictionary containing local parameters, including 'outcome_var_n' and a 'data' sub-dictionary that specifies which column categories to include for perturbation (e.g., 'age', 'sex', 'bmi', 'bloods'). drop_term_list (List[str]): A list of strings. Any column name containing these strings (case-insensitive) will be added to the `drop_list`. Returns: Tuple[List[str], List[str]]: A tuple containing two lists: - pertubation_columns: A list of column names selected for perturbation based on the `local_param_dict` settings. - drop_list: A list of column names identified to be dropped from the DataFrame. """ global_params = global_parameters verbose = global_params.verbose orignal_feature_names = all_df_columns drop_list = [] index_level_list = list(filter(lambda k: "__index_level" in k, all_df_columns)) drop_list.extend(index_level_list) Unnamed_list = list(filter(lambda k: "Unnamed:" in k, all_df_columns)) drop_list.extend(Unnamed_list) Unnamed_list = list(filter(lambda k: "client_idcode:" in k, all_df_columns)) drop_list.extend(Unnamed_list) outcome_variable = f'outcome_var_{local_param_dict.get("outcome_var_n")}' for i in range(0, len(drop_term_list)): drop_term_string = drop_term_list[i] for elem in all_df_columns: res = find_near_matches(drop_term_string, elem.lower(), max_l_dist=0) if len(res) > 0: drop_list.append(elem) blood_test_substrings = [ "_mean", "_median", "_mode", "_std", "_num-tests", "_days-since-last-test", "_max", "_min", "_most-recent", "_earliest-test", "_days-between-first-last", "_contains-extreme-low", "_contains-extreme-high", "_basic-obs-feature", ] diagnostic_test_substrings = [ "_num-diagnostic-order", "_days-since-last-diagnostic-order", "_days-between-first-last-diagnostic", ] drug_order_substrings = [ "_num-drug-order", "_days-since-last-drug-order", "_days-between-first-last-drug", ] annotation_count_list = list( filter( lambda k: "_count" in k and "_count_subject_present" not in k, all_df_columns, ) ) appointments_substrings = ["ConsultantCode_", "ClinicCode_", "AppointmentType_"] meta_sp_annotation_count_list = list( filter(lambda k: "_count_subject_present" in k, all_df_columns) ) not_meta_sp_annotation_count_list = list( filter(lambda k: "_count_subject_not_present" in k, all_df_columns) ) meta_rp_annotation_count_list = list( filter(lambda k: "_count_relative_present" in k, all_df_columns) ) not_meta_rp_annotation_count_list = list( filter(lambda k: "_count_relative_not_present" in k, all_df_columns) ) meta_sp_annotation_count_list.extend(not_meta_sp_annotation_count_list) meta_sp_annotation_count_list.extend(meta_rp_annotation_count_list) meta_sp_annotation_count_list.extend(not_meta_rp_annotation_count_list) diagnostic_order_list = [] diagnostic_list = filter_substring_list(all_df_columns, diagnostic_test_substrings) diagnostic_order_list.extend(diagnostic_list) drug_order_list = [] drug_list = filter_substring_list(all_df_columns, drug_order_substrings) drug_order_list.extend(drug_list) appointments_list = [] appointments = filter_substring_list(all_df_columns, appointments_substrings) appointments_list.extend(appointments) bmi_list = list(filter(lambda k: "bmi_" in k, all_df_columns)) ethnicity_list = list(filter(lambda k: "census_" in k, all_df_columns)) annotation_mrc_count_list = list( filter(lambda k: "_count_mrc_cs" in k, all_df_columns) ) meta_sp_annotation_mrc_count_list = list( filter(lambda k: "_count_subject_present_mrc_cs" in k, all_df_columns) ) not_meta_sp_annotation_mrc_count_list = list( filter(lambda k: "_count_subject_not_present_mrc_cs" in k, all_df_columns) ) relative_meta_rp_annotation_mrc_count_list = list( filter(lambda k: "_count_relative_present_mrc_cs" in k, all_df_columns) ) not_relative_meta_rp_annotation_mrc_count_list = list( filter(lambda k: "_count_relative_not_present_mrc_cs" in k, all_df_columns) ) meta_sp_annotation_mrc_count_list.extend(not_meta_sp_annotation_mrc_count_list) meta_sp_annotation_mrc_count_list.extend(relative_meta_rp_annotation_mrc_count_list) meta_sp_annotation_mrc_count_list.extend( not_relative_meta_rp_annotation_mrc_count_list ) core_02_list = list(filter(lambda k: "core_02_" in k, all_df_columns)) bed_list = list(filter(lambda k: "bed_" in k, all_df_columns)) vte_status_list = list(filter(lambda k: "vte_status_" in k, all_df_columns)) hosp_site_list = list(filter(lambda k: "hosp_site_" in k, all_df_columns)) core_resus_list = list(filter(lambda k: "core_resus_" in k, all_df_columns)) news_list = list(filter(lambda k: "news_resus_" in k, all_df_columns)) bloods_list = filter_substring_list(all_df_columns, blood_test_substrings) date_time_stamp_list = list( filter(lambda k: "date_time_stamp" in k, all_df_columns) ) # Combine these into a single conceptual list for overlap check later meta_sp_annotation_all_counts = ( meta_sp_annotation_count_list + not_meta_sp_annotation_count_list + meta_rp_annotation_count_list + not_meta_rp_annotation_count_list ) # Combine these into a single conceptual list for overlap check later meta_sp_annotation_mrc_all_counts = ( meta_sp_annotation_mrc_count_list + not_meta_sp_annotation_mrc_count_list + relative_meta_rp_annotation_mrc_count_list + not_relative_meta_rp_annotation_mrc_count_list ) # --- Post-Processing: Remove overlaps from bloods_list --- # Create a set of all columns in other categories all_other_categorized_cols = set() # Add all columns from other specific lists to this set all_other_categorized_cols.update(annotation_count_list) all_other_categorized_cols.update(meta_sp_annotation_all_counts) # Use the combined list all_other_categorized_cols.update(diagnostic_order_list) all_other_categorized_cols.update(drug_order_list) all_other_categorized_cols.update(bmi_list) all_other_categorized_cols.update(ethnicity_list) all_other_categorized_cols.update(annotation_mrc_count_list) all_other_categorized_cols.update(meta_sp_annotation_mrc_all_counts) # Use the combined list all_other_categorized_cols.update(core_02_list) all_other_categorized_cols.update(bed_list) all_other_categorized_cols.update(vte_status_list) all_other_categorized_cols.update(hosp_site_list) all_other_categorized_cols.update(core_resus_list) all_other_categorized_cols.update(news_list) all_other_categorized_cols.update(date_time_stamp_list) all_other_categorized_cols.update(appointments_list) # Filter bloods_list: keep only elements NOT found in any other category to avoid vte status and others being added to bloods. bloods_list = [col for col in bloods_list if col not in all_other_categorized_cols] candidate_feature_category_lists = [ meta_sp_annotation_count_list, annotation_count_list, diagnostic_order_list, drug_order_list, bmi_list, ethnicity_list, annotation_mrc_count_list, meta_sp_annotation_mrc_count_list, core_02_list, bed_list, vte_status_list, hosp_site_list, core_resus_list, news_list, bloods_list, date_time_stamp_list, appointments_list ] if verbose >= 2: data = {} for i, lst in enumerate(candidate_feature_category_lists, start=1): var_name = [name for name, var in locals().items() if var is lst][0] data[var_name] = len(lst) plot_candidate_feature_category_lists(data) elif verbose >= 1: for i, lst in enumerate(candidate_feature_category_lists, start=1): var_name = [name for name, var in locals().items() if var is lst][0] print(f"{var_name}: {len(lst)}") pertubation_columns = [] if local_param_dict.get("data").get("age") == True: pertubation_columns.append("age") if local_param_dict.get("data").get("sex") == True: pertubation_columns.append("male") if local_param_dict.get("data").get("bmi") == True: pertubation_columns.extend(bmi_list) if local_param_dict.get("data").get("ethnicity") == True: pertubation_columns.extend(ethnicity_list) if local_param_dict.get("data").get("bloods") == True: pertubation_columns.extend(bloods_list) if local_param_dict.get("data").get("diagnostic_order") == True: pertubation_columns.extend(diagnostic_order_list) if local_param_dict.get("data").get("drug_order") == True: pertubation_columns.extend(drug_order_list) if local_param_dict.get("data").get("annotation_n") == True: pertubation_columns.extend(annotation_count_list) if local_param_dict.get("data").get("meta_sp_annotation_n") == True: pertubation_columns.extend(meta_sp_annotation_count_list) if local_param_dict.get("data").get("annotation_mrc_n") == True: pertubation_columns.extend(annotation_mrc_count_list) if local_param_dict.get("data").get("meta_sp_annotation_mrc_n") == True: pertubation_columns.extend(meta_sp_annotation_mrc_count_list) if local_param_dict.get("data").get("core_02") == True: pertubation_columns.extend(core_02_list) if local_param_dict.get("data").get("bed") == True: pertubation_columns.extend(bed_list) if local_param_dict.get("data").get("vte_status") == True: pertubation_columns.extend(vte_status_list) if local_param_dict.get("data").get("hosp_site") == True: pertubation_columns.extend(hosp_site_list) if local_param_dict.get("data").get("core_resus") == True: pertubation_columns.extend(core_resus_list) if local_param_dict.get("data").get("news") == True: pertubation_columns.extend(news_list) if local_param_dict.get("data").get("date_time_stamp") == True: pertubation_columns.extend(date_time_stamp_list) if local_param_dict.get("data").get("appointments") == True: pertubation_columns.extend(appointments_list) print(f"local_param_dict data perturbation: \n {local_param_dict.get('data')}") if verbose >= 2: plot_dict_values(local_param_dict.get("data")) return pertubation_columns, drop_list