import numpy as np
import matplotlib.pyplot as plt
# credits: Jiahui Lin 

def _rotation_matrix(theta):
    """Return the 2x2 counter-clockwise rotation matrix for *theta* (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def _sample_cluster(P, a2, b2, theta, center, seed):
    """Draw P points inside an axis-aligned ellipse, then rotate and shift them.

    Candidates are zero-mean Gaussian pairs scaled by 2; a candidate is kept
    when ``x**2/a2 + y**2/b2 <= 1``.  The global NumPy RNG is reseeded with
    *seed* so the cluster is reproducible.
    """
    np.random.seed(seed)
    accepted = []
    while len(accepted) < P:
        # One (2, 1) draw per candidate keeps the random stream identical to
        # the historical implementation, so seeded results do not change.
        xw = 2 * np.random.randn(2, 1)
        if xw[0, 0] ** 2 / a2 + xw[1, 0] ** 2 / b2 <= 1:
            accepted.append(xw)

    # np.hstack rejects an empty list, so fall back to an empty (2, 0) array.
    points = np.hstack(accepted) if accepted else np.empty((2, 0))
    # The (2, 1) centre column broadcasts over any number of points.
    offset = np.array(center, dtype=float).reshape(2, 1)
    return _rotation_matrix(theta) @ points + offset


def _plot_clusters(X1, X2, X3):
    """Scatter-plot the three clusters on a fixed [-5, 6] x [-5, 5] grid."""
    plt.figure()
    ax = plt.gca()
    ax.tick_params(labelsize=14)
    ax.xaxis.label.set_size(14)
    ax.yaxis.label.set_size(14)
    ax.title.set_size(14)

    plt.plot(X1[0, :], X1[1, :], '+')
    plt.plot(X2[0, :], X2[1, :], 'o')
    plt.plot(X3[0, :], X3[1, :], '.')
    plt.xticks(np.linspace(-5, 6, 12))
    plt.yticks(np.linspace(-5, 5, 11))
    plt.grid()
    plt.axis([-5, 6, -5, 5])
    plt.show()


def data_ex31_3(P, a, b, th1, th2, th3, st1, st2, st3, *, plot=True):
    """Generate three rotated, shifted elliptical 2-D point clusters.

    Each cluster holds P points obtained by rejection sampling inside the
    ellipse ``x**2/a**2 + y**2/b**2 <= 1``.  The points are then rotated by
    the cluster's angle and translated to its centre: (0.8, 0.7),
    (2.8, 2.9) and (-2.4, -2) respectively.

    Args:
        P: Number of points per cluster.  Any non-negative count works
            (it used to be silently limited to 100).
        a: Ellipse semi-axis along x before rotation; must be non-zero.
        b: Ellipse semi-axis along y before rotation; must be non-zero.
        th1, th2, th3: Rotation angles in radians for clusters 1, 2, 3.
        st1, st2, st3: Seeds passed to ``np.random.seed`` before sampling
            clusters 1, 2, 3.  Note this reseeds the global NumPy RNG.
        plot: When true (default), show the clusters in a Matplotlib figure;
            pass ``plot=False`` to only generate the data.

    Returns:
        ``(X, C1, C2, C3)`` where ``X`` is the list ``[X1, X2, X3]`` of
        ``(2, P)`` arrays and ``C1``..``C3`` are the per-cluster means of
        shape ``(2,)``.

    Raises:
        ValueError: If ``a`` or ``b`` is zero or NaN, since no candidate
            could ever be accepted and sampling would never terminate.
    """
    a2 = a ** 2
    b2 = b ** 2
    if not (a2 > 0 and b2 > 0):
        raise ValueError("a and b must be non-zero finite semi-axes")

    X1 = _sample_cluster(P, a2, b2, th1, (0.8, 0.7), st1)
    X2 = _sample_cluster(P, a2, b2, th2, (2.8, 2.9), st2)
    X3 = _sample_cluster(P, a2, b2, th3, (-2.4, -2), st3)

    C1 = np.mean(X1, axis=1)
    C2 = np.mean(X2, axis=1)
    C3 = np.mean(X3, axis=1)
    X = [X1, X2, X3]

    if plot:
        _plot_clusters(X1, X2, X3)

    return X, C1, C2, C3

# Demo run: 100 points per cluster, ellipse semi-axes 1 and 3, three
# rotation angles (radians) and one RNG seed per cluster.
X, C1, C2, C3 = data_ex31_3(
    100, 1, 3,
    np.pi/6, np.pi/2.5, 2.5*np.pi/3,
    6, 19, 28,
)
# Print the cluster list followed by each cluster mean, one per line.
print(X, C1, C2, C3, sep="\n")
