import numpy as np
import matplotlib.pyplot as plt
# credits: Michael Yang

def hog20new(Im, d, B):
    """Compute the HOG (histogram of oriented gradients) feature vector of an image.

    Python port of the MATLAB ``hog20new`` routine. The transposed image is
    mirror-padded on its right/bottom edges and split into overlapping
    ``d x d`` patches with stride ``d // 2``. For each patch, a ``B``-bin
    histogram of gradient orientations weighted by gradient magnitude is
    built. Each histogram is divided by ``(its L2 norm + 0.01)``, and the
    histograms are concatenated.

    Input:
        Im: input image (2-D array). It is transposed before processing.
        d: signifies size of basic image patch to perform HOG.
           Must be an integer >= 2 so that the stride ``d // 2`` is positive.
           Typically d is between 3 and 7.
        B: number of histogram bins (integer >= 1). Typically B is set
           between 7 and 9.
    Output:
        h: HOG feature vector of input image, a 1-D float array of length
           ``Lx * Ly * B``, where ``Lx``/``Ly`` are the numbers of patch
           positions along the two axes.

    Raises:
        ValueError: if ``d < 2`` or ``B < 1``.

    Usage (needs scipy and an ``X2.mat`` file, so it is not a doctest)::

        X2 = scipy.io.loadmat('X2.mat')['X2']
        hog20new(X2[:, 0].reshape(28, 28).T, 7, 7).sum()

    Verify on MATLAB::

        hog20new(reshape(X2(:, 1), 28, 28), 7, 7)

    The result should be exactly the same.
    """
    if d < 2:
        # d == 1 would give a zero stride (division by zero / zero arange step).
        raise ValueError(f"patch size d must be >= 2, got {d!r}")
    if B < 1:
        raise ValueError(f"number of bins B must be >= 1, got {B!r}")

    t = d // 2  # patch stride; consecutive patches overlap by d - t pixels
    Im = _hog_reflect_pad(Im.T, d, t)
    N, M = Im.shape

    # 0-based top-left corners of every d x d patch along columns (x) and rows (y).
    x_starts = np.arange(0, M - d + 1, t)
    y_starts = np.arange(0, N - d + 1, t)

    mag, bins = _hog_orientation_bins(Im, B)

    h = np.zeros(len(x_starts) * len(y_starts) * B)
    k = 0
    # Column position is the outer loop and row position the inner loop;
    # this fixes the order in which patch histograms appear in h.
    for x0 in x_starts:
        for y0 in y_starts:
            # bincount adds the weights in the same (row-major) element order
            # as a per-pixel accumulation loop would.
            ht = np.bincount(
                bins[y0:y0 + d, x0:x0 + d].ravel(),
                weights=mag[y0:y0 + d, x0:x0 + d].ravel(),
                minlength=B,
            )
            # An all-zero histogram stays all-zero, so no special case is needed.
            h[k * B:(k + 1) * B] = ht / (np.linalg.norm(ht) + 0.01)
            k += 1
    return h


def _hog_reflect_pad(Im, d, t):
    """Mirror-extend the right and bottom edges so that (size - d) is a multiple of t."""
    N, M = Im.shape
    k1 = (M - d) / t
    c1 = np.ceil(k1)
    if c1 - k1 > 0:
        M1 = int(d + t * c1)  # padded width
        # NOTE(review): for images narrower than d, 2 * M - M1 can go negative
        # and Python slicing then wraps around instead of failing as MATLAB would.
        Im = np.hstack((Im, np.fliplr(Im[:, (2 * M - M1):M])))
    k2 = (N - d) / t
    c2 = np.ceil(k2)
    if c2 - k2 > 0:
        N1 = int(d + t * c2)  # padded height
        Im = np.vstack((Im, np.flipud(Im[(2 * N - N1):N, :])))
    return Im


def _hog_orientation_bins(Im, B):
    """Return per-pixel gradient magnitude and 0-based orientation bin index (0..B-1)."""
    Im = Im.astype(float)
    N, M = Im.shape
    # Unscaled central differences with zero padding at the borders:
    # Gx = next column minus previous column, Gy = previous row minus next row.
    Gx = np.hstack((Im[:, 1:], np.zeros((N, 1)))) - np.hstack((np.zeros((N, 1)), Im[:, :-1]))
    Gy = np.vstack((np.zeros((1, M)), Im[:-1, :])) - np.vstack((Im[1:, :], np.zeros((1, M))))

    mag = np.sqrt(Gx**2 + Gy**2)
    ang = np.arctan2(Gy, Gx)  # in [-pi, pi]

    c3 = (B - 1e-6) / (2 * np.pi)
    y = (ang + np.pi) * c3 + 0.5  # always >= 0.5
    # MATLAB's round() rounds halves away from zero, whereas np.round rounds
    # halves to even (np.round(0.5) == 0). For y >= 0, floor(y + 0.5) equals
    # MATLAB's round, giving 1-based bins in 1..B; the clip is purely defensive.
    bins = np.clip(np.floor(y + 0.5).astype(int), 1, B)
    return mag, bins - 1

