import numpy as np
import matplotlib.pyplot as plt
# credits: Lin Fang 

def f_rlog(w, X, y, mu):
    """
    Regularized logistic regression objective function value.
    - w: weight vector (including bias), shape (dim,)
    - X: augmented data matrix, shape (dim, n_samples)
    - y: labels, shape (n_samples,)
    - mu: regularization parameter.
    Returns the scalar
      Σ log(1 + exp(-z_i)) + 0.5 * mu * ||w||^2,  with z_i = y_i * (w^T x_i).
    The loss is summed (not averaged) over the samples.
    """
    z = y * np.dot(w, X)
    # logaddexp(0, -z) = log(exp(0) + exp(-z)) = log(1 + exp(-z)), but it is
    # evaluated without ever forming exp(-z) on its own. The naive
    # log1p(exp(-z)) overflows to inf once -z exceeds ~709 (badly
    # misclassified samples), which turned the whole objective into inf.
    loss = np.logaddexp(0, -z)
    return np.sum(loss) + 0.5 * mu * np.sum(w**2)


def g_rlog(w, X, y, mu):
    """
    Compute the gradient of the regularized logistic regression objective.
    - w: weight vector (including bias), shape (dim,)
    - X: augmented data matrix, shape (dim, n_samples)
    - y: labels, shape (n_samples,)
    - mu: regularization parameter.
    The gradient is computed as:
      -Σ [ y_i * σ(z_i) * x_i ] + mu * w,
    where σ(z) = 1/(1+exp(z)) and z = y*(w^T x).
    Returns an array of shape (dim,). The data term is a sum (not an
    average) over the columns of X.
    """
    z = y * np.dot(w, X)
    # σ(z) = 1/(1+exp(z)) = exp(-log(1+exp(z))). Going through logaddexp
    # avoids forming exp(z) directly, which overflows (RuntimeWarning, or an
    # error under np.errstate(over='raise')) for confidently classified
    # samples with large positive z.
    sigma = np.exp(-np.logaddexp(0, z))
    grad = -np.dot(X, y * sigma) + mu * w
    return grad

def _confusion_and_error(w, A1, A2):
    """
    Score the two augmented test sets with weights w.
    - A1: augmented test set expected to score > 0, shape (dim, t1)
    - A2: augmented test set expected to score < 0, shape (dim, t2)
    Returns (C, rate): the 2x2 integer confusion matrix and the
    misclassification rate in percent. A score of exactly 0 counts as a
    misclassification for either set.
    """
    t1 = A1.shape[1]
    t2 = A2.shape[1]
    C = np.zeros((2, 2), dtype=int)
    C[0, 0] = np.sum(np.dot(w, A1) > 0)   # set-1 samples scored positive
    C[1, 0] = t1 - C[0, 0]                # set-1 samples missed
    C[1, 1] = np.sum(np.dot(w, A2) < 0)   # set-2 samples scored negative
    C[0, 1] = t2 - C[1, 1]                # set-2 samples missed
    rate = (1 - ((C[0, 0] + C[1, 1]) / C.sum())) * 100
    return C, rate


def SVRG_rlog(X, y, Xte1, Xte2, mu, a, T, st, K):
    """
    SVRG algorithm for logistic regression.
    - X: training data, shape (n_features, n_samples)
    - y: labels, shape (n_samples,), with values +1 or -1.
    - Xte1, Xte2: two test sets, with shapes (n_features, n_test1) and (n_features, n_test2).
      Xte1 samples are counted as correct when their score is > 0, Xte2
      samples when their score is < 0.
    - mu: regularization parameter (used in the gradients only).
    - a: step size.
    - T: number of inner iterations.
    - st: initial random seed; inner step t of outer pass k uses seed
      st + (k-1)*T + t.
    - K: number of outer iterations.

    Returns (ws, f_vals, rtm):
    - ws: final weight vector including bias, shape (n_features + 1,)
    - f_vals: list of K+1 objective values, evaluated with mu=0
    - rtm: list of K+1 misclassification rates in percent

    Side effects: prints progress and final statistics, and shows a
    semilog plot of f_vals (plt.show() is called).
    """
    # Accept list-like labels; fancy indexing below needs an ndarray.
    y = np.asarray(y)
    n_features, P = X.shape

    # Augment every data matrix with a row of ones for the bias term.
    Xh = np.vstack((X, np.ones((1, P))))
    A1 = np.vstack((Xte1, np.ones((1, Xte1.shape[1]))))
    A2 = np.vstack((Xte2, np.ones((1, Xte2.shape[1]))))

    wt = np.zeros(n_features + 1)

    f0 = f_rlog(wt, Xh, y, 0)  # Objective computed with mu=0
    f_vals = [f0]

    # Initial prediction and confusion matrix
    C, r0 = _confusion_and_error(wt, A1, A2)
    rtm = [r0]

    print(f'iter 0: obj = {f0:.2e}')

    tt = 1  # mini-batch size (selecting one sample per inner iteration)
    for k in range(1, K+1):
        # Full gradient at the snapshot point wt.
        # NOTE(review): g_rlog sums over all P samples here, while gik/git
        # below are single-sample gradients; textbook SVRG uses the average
        # gradient. Confirm the step size `a` is chosen with this in mind.
        gt = g_rlog(wt, Xh, y, mu)
        wk = wt.copy()
        for t in range(1, T+1):
            # A fresh, locally seeded generator yields the same indices as
            # np.random.seed(seed) followed by np.random.permutation(P), but
            # leaves the caller's global NumPy random state untouched.
            rng = np.random.RandomState(st + (k-1)*T + t)
            it = rng.permutation(P)[:tt]
            xi = Xh[:, it]   # shape: (n_features+1, tt)
            yi = y[it]       # shape: (tt,)
            gik = g_rlog(wk, xi, yi, mu)
            git = g_rlog(wt, xi, yi, mu)
            # Variance-reduced direction: sample gradient at the iterate,
            # corrected by the same sample's gradient at the snapshot.
            gk = gik - git + gt
            wk = wk - a * gk
        wt = wk.copy()
        fk = f_rlog(wt, Xh, y, 0)
        f_vals.append(fk)

        C, rk = _confusion_and_error(wt, A1, A2)
        rtm.append(rk)

        print(f'iter {k}: obj = {fk:.2e}')

    ws = wt
    print('Objective function at the solution point:')
    fs = f_vals[-1]
    print(fs)
    print('Confusion matrix:')
    print(C)
    print('Rate of misclassification in percentage:')
    print(rtm[-1])

    # Plot the objective function values (semilog plot).
    # The y-axis is fixed to [1.5e-2, 1]; values outside it are not visible.
    plt.figure(1)
    plt.semilogy(range(0, K+1), f_vals, 'b-', linewidth=1.5)
    plt.xlabel('Iterations k', fontsize=15, fontname='times')
    plt.ylabel('Objective E_L(w_k)', fontsize=15, fontname='times')
    plt.axis([0, K, 1.5e-2, 1])
    plt.grid(True)
    plt.show()

    return ws, f_vals, rtm
