import numpy as np
import matplotlib.pyplot as plt
# credits: Michael Yang

def sgd_rlog(X, y, Xte1, Xte2, mu, bt, gm, m, st, iter, f_rlog, g_rlog,
             show_plot=True):
    """Mini-batch SGD for regularized logistic regression with a bias term.

    Parameters
    ----------
    X : (N, P) array
        Training samples stored column-wise (one sample per column).
    y : (P,) array
        +/-1 labels for the columns of X.
    Xte1 : (N, t1) array
        Test samples of the positive class (score > 0 counts as correct).
    Xte2 : (N, t2) array
        Test samples of the negative class (score < 0 counts as correct).
    mu : float
        Regularization weight passed to g_rlog for the updates.
    bt, gm : float
        Step-size schedule parameters; the step at iteration k (1-based)
        is bt / (1 + gm * k).
    m : int
        Mini-batch size.
    st : int
        Seed for NumPy's RNG (reproducible mini-batch sampling).
    iter : int
        Number of SGD iterations.  (Name kept for backward compatibility
        with existing callers even though it shadows the builtin `iter`.)
    f_rlog : callable
        Loss function, called as f_rlog(w, Xh, y, mu).
    g_rlog : callable
        Gradient of the loss, called as g_rlog(w, Xw, yw, mu).
    show_plot : bool, optional
        If True (default, original behavior) draw the objective-vs-iteration
        semilog plot.  Pass False for headless/batch use.

    Returns
    -------
    ws : (N+1,) array
        Final weight vector; the last entry is the bias.
    f : list of float
        iter + 1 objective values (tracked with mu = 0, i.e. unregularized).
    rtm : list of float
        iter + 1 test misclassification rates, in percent.

    Usage:
    >>> ws, f, rtm = sgd_rlog(Xhog,y,T2_hog,T7_hog,0.002,13,0.01,8,9,1176,f_rlog, g_rlog)
    """
    N, P = X.shape
    # Append a constant-1 row so the last weight component acts as a bias.
    Xh = np.vstack([X, np.ones(P)])
    t1, t2 = Xte1.shape[1], Xte2.shape[1]
    A1 = np.vstack([Xte1, np.ones(t1)])
    A2 = np.vstack([Xte2, np.ones(t2)])

    wk = np.zeros(N + 1)
    # Track the unregularized objective (mu = 0) on the full training set.
    # Computing fk/rk here also keeps the summary prints below valid when
    # iter == 0 (the original crashed with NameError in that case).
    fk = f_rlog(wk, Xh, y, 0)
    f = [fk]
    C, rk = _confusion(wk, A1, A2, t1, t2)
    rtm = [rk]

    # Diminishing step sizes bt / (1 + gm*k) for k = 1..iter, vectorized.
    aw = bt / (1.0 + gm * np.arange(1, iter + 1))

    np.random.seed(st)
    for k in range(iter):
        # Sample a mini-batch of m distinct columns.
        r = np.random.permutation(P)
        Xw = Xh[:, r[:m]]
        yw = y[r[:m]]

        # Gradient step with the regularized gradient and scheduled step.
        wk = wk - aw[k] * g_rlog(wk, Xw, yw, mu)

        fk = f_rlog(wk, Xh, y, 0)
        f.append(fk)
        C, rk = _confusion(wk, A1, A2, t1, t2)
        rtm.append(rk)

    print(fk)
    print(C)
    print("rate of misclassification in percentage: ", rk)

    if show_plot:
        plt.figure(figsize=(8, 6))
        plt.semilogy(range(iter + 1), f, 'b-', linewidth=1.4)
        plt.xlabel('Iterations $k$', fontsize=13)
        plt.ylabel('Objective $E_L(w_k)$', fontsize=13)
        plt.grid(True)
        plt.show()

    return wk, f, rtm


def _confusion(wk, A1, A2, t1, t2):
    """Confusion matrix and misclassification rate (%) of sign(wk @ x).

    A1/A2 are the bias-augmented positive/negative test sets with t1/t2
    columns.  A score of exactly 0 is counted as misclassified for both
    classes (matches the strict > 0 / < 0 tests of the original code).
    """
    y1 = np.dot(wk.T, A1)  # scores on the positive-class test set
    y2 = np.dot(wk.T, A2)  # scores on the negative-class test set
    C = np.zeros((2, 2))
    C[0, 0] = np.sum(y1 > 0)   # true positives
    C[1, 0] = t1 - C[0, 0]     # false negatives
    C[1, 1] = np.sum(y2 < 0)   # true negatives
    C[0, 1] = t2 - C[1, 1]     # false positives
    rate = (1 - (C[0, 0] + C[1, 1]) / np.sum(C)) * 100
    return C, rate
