# BUA302: helper functions used in class
# Model Selector
import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge, Lasso
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold, LeaveOneOut, train_test_split
from patsy import dmatrix
# from BUA302_functions import *
from sklearn.base import clone


class ModelSelector:
    """
    Base class for shared logic across all selection types.

    Subclasses implement ``fit(X, y)``. They use ``_get_score`` to fit an OLS
    model (with intercept) on a subset of columns, and ``_is_better`` to
    compare two scores according to ``score_type``.
    """

    # Attributes of a fitted statsmodels OLS result that this class knows how
    # to compare. Only 'rsquared_adj' is higher-is-better; the rest are minimized.
    _SUPPORTED_SCORES = ('bic', 'aic', 'rsquared_adj')

    def __init__(self, score_type='bic'):
        """
        Parameters
        ----------
        score_type : {'bic', 'aic', 'rsquared_adj'}
            Criterion used to rank candidate models.

        Raises
        ------
        ValueError
            If ``score_type`` is not supported. Without this check, an attribute
            such as 'rsquared' would be silently minimized by ``_is_better``.
        """
        if score_type not in self._SUPPORTED_SCORES:
            raise ValueError(
                f"score_type must be one of {self._SUPPORTED_SCORES}, got {score_type!r}"
            )
        self.score_type = score_type
        self.selected_features = []  # column names chosen by fit()
        self.final_model = None      # OLS results for selected_features, set by fit()

    def _get_score(self, y, X_subset):
        """Fit OLS of y on X_subset plus an intercept; return (score, fitted results)."""
        # has_constant='add' forces the intercept even if X_subset already has a constant column
        X_with_const = sm.add_constant(X_subset, has_constant='add')
        model = sm.OLS(y, X_with_const).fit()
        return getattr(model, self.score_type), model

    def _is_better(self, new_score, old_score):
        """Return True if new_score strictly beats old_score (lower AIC/BIC, higher adj. R²)."""
        if self.score_type == 'rsquared_adj':
            return new_score > old_score
        return new_score < old_score


class ForwardSelector(ModelSelector):
    """
    Greedy forward selection.

    Starting from no predictors, each round tries adding every remaining
    column. It keeps the single best addition, but only if it strictly
    improves the running score. The running score starts at the worst
    possible value, so the first round always adds a column.
    """

    def fit(self, X, y):
        """Select columns of X for predicting y; sets selected_features/final_model and returns self."""
        maximize = self.score_type == 'rsquared_adj'
        chosen = []
        pool = list(X.columns)
        incumbent = float('-inf') if maximize else float('inf')

        while pool:
            # (score, column) pairs ordered best-first; ties fall back to column name order
            ranked = sorted(
                ((self._get_score(y, X[chosen + [name]])[0], name) for name in pool),
                reverse=maximize,
            )
            top_score, top_name = ranked[0]

            if not self._is_better(top_score, incumbent):
                break

            incumbent = top_score
            chosen.append(top_name)
            pool.remove(top_name)

        self.selected_features = chosen
        self.final_model = self._get_score(y, X[chosen])[1]
        return self


class BackwardSelector(ModelSelector):
    """
    Greedy backward elimination.

    Starting from all columns, each round tries dropping every column in turn.
    It removes the one whose removal gives the best score, but only if that
    strictly improves on the current model. At least one column is always kept.
    """

    def fit(self, X, y):
        """Select columns of X for predicting y; sets selected_features/final_model and returns self."""
        maximize = self.score_type == 'rsquared_adj'
        kept = list(X.columns)
        incumbent = self._get_score(y, X[kept])[0]

        while len(kept) > 1:
            # Score of the model without each column, ordered best-first
            ranked = sorted(
                (
                    (self._get_score(y, X[[f for f in kept if f != drop]])[0], drop)
                    for drop in kept
                ),
                reverse=maximize,
            )
            top_score, drop_name = ranked[0]

            if not self._is_better(top_score, incumbent):
                break

            incumbent = top_score
            kept.remove(drop_name)

        self.selected_features = kept
        self.final_model = self._get_score(y, X[kept])[1]
        return self


class StepwiseSelector(ModelSelector):
    """
    Bidirectional (stepwise) selection.

    Each round first tries the single best addition from the unused columns,
    then the single best removal from the model. A step is accepted only if it
    strictly improves the running score. Fitting stops after a round in which
    neither step was accepted.
    """

    def fit(self, X, y):
        """Select columns of X for predicting y; sets selected_features/final_model and returns self."""
        maximize = self.score_type == 'rsquared_adj'
        pool = list(X.columns)
        chosen = []
        incumbent = float('-inf') if maximize else float('inf')

        def score_of(columns):
            return self._get_score(y, X[columns])[0]

        changed = True
        while changed:
            changed = False

            # Forward move: best single addition from the pool
            if pool:
                additions = sorted(
                    ((score_of(chosen + [c]), c) for c in pool),
                    reverse=maximize,
                )
                if self._is_better(additions[0][0], incumbent):
                    incumbent, added = additions[0]
                    chosen.append(added)
                    pool.remove(added)
                    changed = True

            # Backward move: best single removal (only when 2+ columns are in the model)
            if len(chosen) > 1:
                removals = sorted(
                    ((score_of([f for f in chosen if f != c]), c) for c in chosen),
                    reverse=maximize,
                )
                if self._is_better(removals[0][0], incumbent):
                    incumbent, dropped = removals[0]
                    chosen.remove(dropped)
                    pool.append(dropped)
                    changed = True

        self.selected_features = chosen
        self.final_model = self._get_score(y, X[chosen])[1]
        return self


### Cross-validation (CV) evaluation


class CVEvaluator:
    """
    Estimate the out-of-sample MSE of an OLS model (with intercept) using
    a single validation split, K-fold CV, or leave-one-out CV.
    """

    def __init__(self, X, y):
        # Add the intercept once so every CV method fits the same design
        self.X = sm.add_constant(X, has_constant='add')
        self.y = y

    @staticmethod
    def _take(data, idx):
        """
        Return the rows of ``data`` at integer positions ``idx``.

        pandas objects are indexed positionally via ``.iloc``, because their
        labels need not be 0..n-1. Anything else goes through ``np.asarray``
        first, so plain Python lists work as well as ndarrays.
        """
        if hasattr(data, 'iloc'):
            return data.iloc[idx]
        return np.asarray(data)[idx]

    def _split_mse(self, splitter):
        """Fit OLS on each training fold from ``splitter.split`` and return the mean test MSE."""
        mses = []
        for train_idx, test_idx in splitter.split(self.X):
            X_train, X_test = self._take(self.X, train_idx), self._take(self.X, test_idx)
            y_train, y_test = self._take(self.y, train_idx), self._take(self.y, test_idx)

            model = sm.OLS(y_train, X_train).fit()
            mses.append(mean_squared_error(y_test, model.predict(X_test)))

        return np.mean(mses)

    def validation_set(self, test_size=0.2, seed=42):
        """
        Method 1: Validation Set Approach
        - test_size: the fraction of data to use for testing (e.g., 0.2 for 20%)
        - seed: the random_state for reproducibility
        Returns the test-set MSE.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            self.X, self.y, test_size=test_size, random_state=seed
        )

        model = sm.OLS(y_train, X_train).fit()
        preds = model.predict(X_test)
        return mean_squared_error(y_test, preds)

    def k_fold(self, k=5, seed=42):
        """
        Method 2: K-Fold Cross Validation
        - k: number of folds (shuffled with random_state=seed)
        Returns the mean test MSE across the k folds.
        """
        return self._split_mse(KFold(n_splits=k, shuffle=True, random_state=seed))

    def loocv(self):
        """
        Method 3: Leave-One-Out Cross Validation
        Fits one model per observation; returns the mean of the n single-point MSEs.
        """
        return self._split_mse(LeaveOneOut())


# Finding the best Lambda (Alpha) for Ridge and Lasso Regressions

def find_best_alphas(X, y, test_size=0.2, seed=42, min_alpha=0.001, max_alpha=1.0, num_alphas=100):
    """
    Tune the penalty strength of Ridge and Lasso on a single hold-out split.

    The data are split (``test_size``/``seed``) and standardized with
    statistics from the training part only. Ridge and Lasso are then fitted
    for ``num_alphas`` evenly spaced alphas in [min_alpha, max_alpha]. The
    test RMSE curves are plotted and the best values are printed.

    Returns a dict ``{"ridge_alpha": ..., "lasso_alpha": ...}`` holding the
    alpha with the lowest test RMSE for each model.
    """
    # Hold-out split, then scale with training statistics only
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=test_size, random_state=seed)
    scaler = StandardScaler()
    X_tr_s = scaler.fit_transform(X_tr)
    X_te_s = scaler.transform(X_te)

    def test_rmse(estimator):
        estimator.fit(X_tr_s, y_tr)
        return np.sqrt(mean_squared_error(y_te, estimator.predict(X_te_s)))

    alphas = np.linspace(min_alpha, max_alpha, num_alphas)
    ridge_rmses, lasso_rmses = [], []
    for alpha in alphas:
        ridge_rmses.append(test_rmse(Ridge(alpha=alpha)))
        lasso_rmses.append(test_rmse(Lasso(alpha=alpha, max_iter=10000)))

    best_ridge_a = alphas[np.argmin(ridge_rmses)]
    best_lasso_a = alphas[np.argmin(lasso_rmses)]
    min_ridge_rmse, min_lasso_rmse = min(ridge_rmses), min(lasso_rmses)

    # Error curves, with a dashed marker at each optimum
    plt.figure(figsize=(10, 5))
    plt.plot(alphas, ridge_rmses, label=f'Ridge (Best α={best_ridge_a:.3f})', color='blue')
    plt.plot(alphas, lasso_rmses, label=f'Lasso (Best α={best_lasso_a:.3f})', color='red')
    for best, colour in ((best_ridge_a, 'blue'), (best_lasso_a, 'red')):
        plt.axvline(best, color=colour, linestyle='--', alpha=0.5)

    plt.title('Alpha vs. Test RMSE')
    plt.xlabel('Alpha (λ)')
    plt.ylabel('RMSE')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    print(f"--- Results for Alpha range [{min_alpha}, {max_alpha}] ---")
    print(f"Best Ridge Alpha: {best_ridge_a:.4f} (RMSE: {min_ridge_rmse:.2f})")
    print(f"Best Lasso Alpha: {best_lasso_a:.4f} (RMSE: {min_lasso_rmse:.2f})")

    return {"ridge_alpha": best_ridge_a, "lasso_alpha": best_lasso_a}


# Piecewise regression: build the design-matrix variables
def create_piecewise_features(data, column, bins, labels):
    """
    Build the design matrix for a piecewise linear regression on one variable.

    ``column`` is cut into tiers with ``pd.cut(bins=bins, labels=labels)``.
    The first tier is the baseline. Every other tier gets two columns:
    - ``tier_<label>``: an intercept-shift dummy;
    - ``interact_tier_<label>``: a slope-shift interaction (column value x dummy).

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``column``. It is copied, never modified.
    column : str
        Name of the continuous variable to split.
    bins : sequence of numbers
        Tier edges, e.g. [0, 1000, 2000, np.inf]. Intervals are right-closed,
        and the lowest edge itself belongs to the first tier.
    labels : sequence
        One tier name per interval (len(bins) - 1 of them).

    Returns
    -------
    pandas.DataFrame (float) with columns
    ['const', column, tier dummies..., interaction terms...].
    The 'const' column is always present. A single row, or a sample lying
    entirely in one tier, therefore has the same layout as the fitting matrix.

    Warns
    -----
    UserWarning
        If some non-missing values fall outside ``bins``. Those rows get
        all-zero tier dummies and are therefore treated as the baseline tier.
    """
    import warnings

    df_work = data.copy()

    # 1. Tier categories. include_lowest only affects values equal to bins[0]:
    #    they land in the first (baseline) tier instead of NaN. The dummy
    #    encoding is identical either way, but they are no longer reported as
    #    out of range below.
    tier_col_name = f"{column}_tier"
    df_work[tier_col_name] = pd.cut(df_work[column], bins=bins, labels=labels,
                                    include_lowest=True)

    out_of_range = df_work[tier_col_name].isna() & df_work[column].notna()
    if out_of_range.any():
        warnings.warn(
            f"{int(out_of_range.sum())} value(s) of {column!r} fall outside bins "
            f"{bins}; they are encoded as the baseline tier.",
            stacklevel=2,
        )

    # 2. Dummies (drop_first=True avoids the dummy-variable trap; the first
    #    tier, e.g. 'Small', becomes the baseline)
    df_dummies = pd.get_dummies(df_work[tier_col_name], prefix="tier", drop_first=True)
    tier_cols = df_dummies.columns.tolist()

    # 3. Interaction terms (slope change within each non-baseline tier)
    for col in tier_cols:
        df_dummies[f"interact_{col}"] = df_work[column] * df_dummies[col]

    # 4. Combine with the original column, all numeric for statsmodels
    X = pd.concat([df_work[column], df_dummies], axis=1).astype(float)

    # Always prepend the intercept. sm.add_constant's default
    # has_constant='skip' silently omits it when any column is constant
    # (e.g. one-row input), which misaligns X with the fitted parameters.
    X.insert(0, "const", 1.0)
    return X


def plot_piecewise_regression(df, model, feature_col, bins, labels, target_col="Price"):
    """
    Plot a fitted piecewise linear regression over a scatter of the raw data.

    The fitted line is evaluated on a 1000-point grid spanning the observed
    range of ``feature_col``. The grid is encoded with
    ``create_piecewise_features``, so ``bins`` and ``labels`` must be the ones
    used to build the model's design matrix.

    Args:
        df: Original DataFrame
        model: The fitted statsmodels OLS object
        feature_col: String name of the column (e.g., "Sqft")
        bins: List of bin edges (e.g., [0, 1000, 2000, np.inf]). The interior
            edges are the knots and are drawn as dashed vertical lines.
        labels: List of tier names (e.g., ["Small", "Medium", "Large"])
        target_col: String name of the Y variable
    """
    # 1. Smooth grid for the prediction line
    f_min, f_max = df[feature_col].min(), df[feature_col].max()
    grid_x = np.linspace(f_min, f_max, 1000)
    grid_df = pd.DataFrame({feature_col: grid_x})

    # 2. Encode the grid exactly like the training data
    X_grid = create_piecewise_features(grid_df, feature_col, bins, labels)

    # 3. Predictions from the fitted model
    predictions = model.predict(X_grid)

    # 4. Plotting
    plt.figure(figsize=(10, 6))

    # Raw data
    plt.scatter(df[feature_col], df[target_col], alpha=0.2, color='gray', label="Actual Data")

    # Piecewise line
    plt.plot(grid_x, predictions, color='firebrick', lw=3, label=f"Piecewise {feature_col} Model")

    # Knots are the interior bin edges. The outer edges only bound the first
    # and last tiers and may be infinite, so they are not drawn, whatever their value.
    edges = list(bins)
    knots = [k for k in edges[1:-1] if np.isfinite(k)]
    for knot in knots:
        plt.axvline(knot, color='black', linestyle='--', alpha=0.4, label=f"Knot at {knot}")

    plt.title(f"Piecewise Analysis: {feature_col} vs {target_col}")
    plt.xlabel(feature_col)
    plt.ylabel(target_col)

    # De-duplicate legend entries that share a label
    handles, labels_plot = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels_plot, handles))
    plt.legend(by_label.values(), by_label.keys())

    plt.grid(axis='y', alpha=0.3)
    plt.show()


# GAM: build the spline design matrix
def create_mixed_gam_matrix(data, spline_specs, reference_data=None):
    """
    Build a GAM-style design matrix: one spline basis per variable plus one intercept.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows at which the spline bases are evaluated.
    spline_specs : dict
        Maps column name -> (spline_type, df).
        - spline_type "ls": piecewise-linear B-spline, patsy ``bs(..., degree=1)``.
        - any other spline_type: cubic regression spline, patsy ``cr``.
        ``df`` sets the basis size; knots are placed automatically by patsy.
    reference_data : pandas.DataFrame, optional
        The data the model was fitted on. When given, knot locations and other
        spline state are learned from ``reference_data``, and the bases are
        then evaluated at ``data``. New rows (prediction or plotting grids) are
        thus encoded exactly like the training rows. When omitted (the default),
        knots are computed from ``data`` itself, which is only appropriate when
        ``data`` *is* the training data.

    Returns
    -------
    pandas.DataFrame with a leading 'const' column followed by every basis column.
    """
    from patsy import build_design_matrices

    spline_list = []

    for col, (s_type, val) in spline_specs.items():
        if s_type == "ls":
            # Linear spline: degree=1 makes it piecewise linear
            formula = f"bs({col}, df={val}, degree=1) - 1"
        else:
            # Cubic regression spline
            formula = f"cr({col}, df={val}) - 1"

        if reference_data is None:
            B = dmatrix(formula, data=data, return_type="dataframe")
        else:
            # Learn the basis (knots, bounds) on the training data, then
            # evaluate that same basis at the new rows.
            design_info = dmatrix(formula, data=reference_data,
                                  return_type="dataframe").design_info
            B = build_design_matrices([design_info], data, return_type="dataframe")[0]
        spline_list.append(B)

    # Combine every variable's basis
    X = pd.concat(spline_list, axis=1)

    # One global intercept. has_constant='add' is required: with the default
    # 'skip', statsmodels drops 'const' whenever some basis column is constant
    # (e.g. a variable held fixed on a plotting grid).
    return sm.add_constant(X, has_constant='add')


# Plot the GAM results

def plot_gam_results(df, model, specs, target_col="Price"):
    """
    Plot the marginal effect of every variable included in the GAM specs.

    For each variable, predictions are made along a 200-point grid spanning its
    observed range, with all other variables held at their median value. The
    spline bases for the grid reuse the knots learned from ``df``. ``model`` is
    therefore assumed to have been fitted on ``create_mixed_gam_matrix(df, specs)``,
    i.e. the same data and specs.
    """
    vars_to_plot = list(specs.keys())
    n_vars = len(vars_to_plot)
    grid_size = 200

    # squeeze=False always yields a 2-D array of Axes, even for a single variable
    fig, axes = plt.subplots(1, n_vars, figsize=(6 * n_vars, 5), squeeze=False)

    medians = {c: df[c].median() for c in vars_to_plot}

    for ax, col in zip(axes[0], vars_to_plot):
        # 1. Grid over the observed range of the current variable
        current_grid = np.linspace(df[col].min(), df[col].max(), grid_size)

        # 2. Every other variable at its median, the current one along the grid.
        #    No jitter is needed: knots come from df, not from this constant data.
        plot_df = pd.DataFrame({c: np.full(grid_size, medians[c]) for c in vars_to_plot})
        plot_df[col] = current_grid

        # 3. Encode with the training-data basis and predict
        X_plot = create_mixed_gam_matrix(plot_df, specs, reference_data=df)
        predictions = model.predict(X_plot)

        # 4. Plotting
        ax.scatter(df[col], df[target_col], alpha=0.1, color='gray', label="Data")
        ax.plot(current_grid, predictions, color='firebrick', lw=3, label="GAM Prediction")

        ax.set_title(f"Marginal Effect: {col}")
        ax.set_xlabel(col)
        ax.set_ylabel(target_col)
        ax.legend()

    plt.tight_layout()
    plt.show()



def find_best_n_estimators(estimator,
                           X_train, X_test,
                           y_train, y_test,
                           n_min=1,
                           n_max=300,
                           metric="mse",
                           plot=True):
    """
    Choose ``n_estimators`` for an ensemble model by test-set evaluation.

    For every n in [n_min, n_max], a fresh clone of ``estimator`` is fitted
    with ``n_estimators=n`` on the training data and scored on the test data.

    Parameters
    ----------
    estimator : scikit-learn estimator with an ``n_estimators`` parameter
        It is cloned for every n and never modified itself.
    X_train, X_test, y_train, y_test : training and test splits
    n_min, n_max : int
        Inclusive range of values to try; requires 1 <= n_min <= n_max.
    metric : {"mse", "rmse", "mae"} or callable
        Test error to minimize. A callable is called as
        ``metric(y_true, y_pred)`` and must return a lower-is-better number.
    plot : bool
        If True, plot the test error against n_estimators.

    Returns
    -------
    (best_n, scores)
        ``best_n`` is the n with the lowest test error (the first one on ties).
        ``scores`` lists the test error for each n, in order.

    Raises
    ------
    ValueError
        If the n range is invalid or ``metric`` is not recognized. Previously an
        unknown metric crashed with UnboundLocalError inside the loop.
    """
    if n_min < 1 or n_max < n_min:
        raise ValueError(f"Need 1 <= n_min <= n_max, got n_min={n_min}, n_max={n_max}")

    if callable(metric):
        score_fn = metric
        metric_label = "score"
    elif metric == "mse":
        score_fn = mean_squared_error
        metric_label = "MSE"
    elif metric == "rmse":
        def score_fn(y_true, y_pred):
            return np.sqrt(mean_squared_error(y_true, y_pred))
        metric_label = "RMSE"
    elif metric == "mae":
        from sklearn.metrics import mean_absolute_error
        score_fn = mean_absolute_error
        metric_label = "MAE"
    else:
        raise ValueError(
            f"metric must be 'mse', 'rmse', 'mae' or a callable, got {metric!r}"
        )

    n_values = range(n_min, n_max + 1)
    scores = []
    for n in n_values:
        model = clone(estimator)
        model.set_params(n_estimators=n)
        model.fit(X_train, y_train)
        scores.append(score_fn(y_test, model.predict(X_test)))

    best_n = n_values[int(np.argmin(scores))]

    if plot:
        plt.figure()
        plt.plot(n_values, scores)
        plt.axvline(best_n, linestyle="--")
        plt.xlabel("Number of Trees (n_estimators)")
        plt.ylabel(f"Test {metric_label}")
        plt.title("Selecting Optimal n_estimators")
        plt.show()

    return best_n, scores


def confusion_table_report(y_true, y_prob, threshold=0.5, labels=None):
    """
    Print binary-classification metrics and plot the confusion matrix as a heatmap.

    Parameters
    ----------
    y_true    : array-like — actual binary labels, encoded 0/1
    y_prob    : array-like — predicted probabilities of class 1
    threshold : float — probabilities >= threshold are classified as 1 (default 0.5)
    labels    : list  — display names for classes 0 and 1, e.g. ['Legit', 'Spam']

    Returns
    -------
    dict with keys 'accuracy', 'precision', 'sensitivity', 'specificity'
    and 'f1', each rounded to 4 decimals. Any ratio with a zero denominator
    is reported as 0.
    """
    # Previously used without being imported (NameError); imported locally here
    from sklearn.metrics import confusion_matrix

    dec = (np.asarray(y_prob) >= threshold).astype(int)
    # Pin the label order to [0, 1] so the matrix is always 2x2. Without this,
    # a batch containing only one class gives a 1x1 matrix and the unpack fails.
    cm = confusion_matrix(y_true, dec, labels=[0, 1])
    TN, FP, FN, TP = cm.ravel()

    total       = TP + TN + FP + FN
    accuracy    = (TP + TN) / total if total > 0 else 0
    precision   = TP / (TP + FP)   if (TP + FP) > 0 else 0
    sensitivity = TP / (TP + FN)   if (TP + FN) > 0 else 0
    specificity = TN / (TN + FP)   if (TN + FP) > 0 else 0
    f1          = (2 * precision * sensitivity / (precision + sensitivity)
                   if (precision + sensitivity) > 0 else 0)

    print(f"Threshold   : {threshold}")
    print(f"Accuracy    : {accuracy:.4f}")
    print(f"Precision   : {precision:.4f}")
    print(f"Sensitivity : {sensitivity:.4f}")
    print(f"Specificity : {specificity:.4f}")
    print(f"F1 Score    : {f1:.4f}")

    tick_labels = labels if labels else ['Negative (0)', 'Positive (1)']
    fig, ax = plt.subplots(figsize=(5, 4))

    # Annotated heatmap drawn with matplotlib alone. The original called
    # seaborn's `sns`, which this module never imports.
    im = ax.imshow(cm, cmap='Blues')
    fig.colorbar(im, ax=ax)
    midpoint = cm.max() / 2
    for i in range(2):          # rows: actual class
        for j in range(2):      # columns: predicted class
            ax.text(j, i, f"{int(cm[i, j]):d}", ha='center', va='center', fontsize=13,
                    color='white' if cm[i, j] > midpoint else 'black')
    ax.set_xticks([0, 1])
    ax.set_xticklabels(tick_labels)
    ax.set_yticks([0, 1])
    ax.set_yticklabels(tick_labels)

    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(f'Confusion Matrix (threshold = {threshold})')
    plt.tight_layout()
    plt.show()

    return {'accuracy': round(accuracy, 4), 'precision': round(precision, 4),
            'sensitivity': round(sensitivity, 4), 'specificity': round(specificity, 4),
            'f1': round(f1, 4)}


def plot_roc(y_true, y_prob_dict, title="ROC Curve",
             mark_threshold=None, savepath=None):
    """
    Draw ROC curves for one or more models on a single figure.

    Parameters
    ----------
    y_true         : array — actual binary labels
    y_prob_dict    : array or dict — predicted probabilities for a single
                     model (plotted as 'Model'), or {model_name: probabilities}
    title          : str  — plot title
    mark_threshold : float — if given, each curve gets a dot at the ROC point
                     whose threshold is closest to this value
    savepath       : str  — if given, the figure is also saved here (150 dpi)
    """
    from sklearn.metrics import roc_auc_score, roc_curve

    curves = y_prob_dict if isinstance(y_prob_dict, dict) else {'Model': y_prob_dict}

    plt.figure(figsize=(7, 5))
    for name, scores in curves.items():
        fpr, tpr, cutoffs = roc_curve(y_true, scores)
        area = roc_auc_score(y_true, scores)
        plt.plot(fpr, tpr, lw=2.5, label=f"{name} (AUC = {area:.3f})")

        if mark_threshold is None:
            continue
        nearest = abs(cutoffs - mark_threshold).argmin()
        plt.scatter(fpr[nearest], tpr[nearest], s=80, zorder=5,
                    label=f"Threshold = {mark_threshold}")

    # Chance diagonal for reference
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray",
             lw=1.5, label="Random Classifier")

    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.legend()
    plt.tight_layout()

    if savepath:
        plt.savefig(savepath, dpi=150, bbox_inches="tight")
    plt.show()

