import re
from typing import List
import pandas as pd
from ml_grid.util.global_params import global_parameters
class clean_up_class:
    """A class for cleaning and preparing DataFrame columns.

    The behaviour is driven by settings read from ``global_parameters``
    when the instance is created: ``verbose`` gates diagnostic printing and
    ``rename_cols`` enables column-name sanitising.
    """

    # Characters replaced by "_" in column names ("[", "]" and "<").
    # Compiled once so it is not rebuilt on every call.
    _UNSUPPORTED_NAME_CHARS = re.compile(r"[\[\]<]")

    def __init__(self):
        """Initializes the clean_up_class from the global parameters."""
        # Shared project-wide settings object.
        self.global_params = global_parameters

        # Verbosity level; values above 1 enable diagnostic printing.
        self.verbose = self.global_params.verbose

        # When truthy, handle_column_names() sanitises column names.
        self.rename_cols = self.global_params.rename_cols

    def handle_duplicated_columns(self, X: pd.DataFrame) -> pd.DataFrame:
        """Drops duplicated columns from a DataFrame.

        Columns are considered duplicated when they share the same label;
        the first occurrence of each label is kept.

        Args:
            X (pd.DataFrame): DataFrame to drop duplicated columns from.

        Returns:
            pd.DataFrame: A copy of X with duplicated columns dropped.

        Raises:
            AssertionError: If X is None.
        """
        if self.verbose > 1:
            print("dropping duplicated columns")

        if X is None:
            # Raised explicitly (rather than via ``assert``) so the check
            # still runs under ``python -O``; the exception type and the
            # printed message are unchanged for existing callers.
            message = "Null pointer exception: X cannot be None."
            print(message)
            raise AssertionError(message)

        try:
            return X.loc[:, ~X.columns.duplicated()].copy()
        except Exception as e:
            # Report and propagate; the caller decides how to recover.
            print(f"Unhandled exception: {e}")
            raise

    def screen_non_float_types(self, X: pd.DataFrame) -> None:
        """Screens and prints columns that are not of float or int type.

        Nothing is printed unless ``self.verbose`` is greater than 1.

        Args:
            X (pd.DataFrame): The DataFrame to screen.
        """
        if self.verbose > 1:
            print("Screening for non float data types:")
            # Iterate over ``X.dtypes`` rather than ``X[col]`` so duplicated
            # column labels (where ``X[col]`` is a DataFrame without a
            # ``.dtype``) are handled, and use the pandas dtype predicates
            # so every integer/float width counts as numeric, not only the
            # platform-default ones.
            for col, dtype in X.dtypes.items():
                is_int = pd.api.types.is_integer_dtype(dtype)
                is_float = pd.api.types.is_float_dtype(dtype)
                if not (is_int or is_float):
                    print(col)

    def handle_column_names(self, X: pd.DataFrame) -> pd.DataFrame:
        """Renames columns to remove characters unsupported by some ML models.

        Column names containing '[', ']' or '<' can cause issues with models
        like XGBoost. Each such character is replaced with an underscore.
        Labels without any of these characters are left untouched; labels
        that are not strings but whose string form contains one of them are
        replaced by their sanitised string form.

        The renaming is controlled by the `self.rename_cols` attribute; when
        it is falsy, X is returned unchanged.

        Note:
            The columns are renamed in place: X itself is modified and the
            same object is returned. Distinct names may collapse to the same
            sanitised name (e.g. "a[" and "a]" both become "a_").

        Args:
            X (pd.DataFrame): DataFrame with columns to be potentially renamed.

        Returns:
            pd.DataFrame: X, with renamed columns if applicable.
        """
        if self.rename_cols:
            pattern = self._UNSUPPORTED_NAME_CHARS

            def _sanitise(col):
                # Work on the string form so non-string labels (e.g. tuples)
                # do not make ``re.sub`` raise TypeError.
                text = str(col)
                return pattern.sub("_", text) if pattern.search(text) else col

            X.columns = [_sanitise(col) for col in X.columns.values]

        return X