import numpy as np
import matplotlib.pyplot as plt
# credits: Ruijie Tao

def _rotation_matrix(theta):
    """Return the 2x2 matrix rotating column vectors counter-clockwise by ``theta`` radians."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def _generate_cluster(P, a2, b2, R, seed, shift):
    """Draw ``P`` points inside the ellipse x**2/a2 + y**2/b2 <= 1, then rotate and shift them.

    Candidates are drawn from an isotropic normal distribution with standard
    deviation 2 and rejected when they fall outside the axis-aligned ellipse.
    A private ``RandomState`` seeded with ``seed`` reproduces exactly the
    stream that ``np.random.seed(seed)`` + ``np.random.randn`` would give,
    without mutating NumPy's global RNG.

    Returns:
        numpy array of shape (2, P): the transformed points, one per column.
    """
    rng = np.random.RandomState(seed)
    points = []
    while len(points) < P:
        xw = 2 * rng.randn(2)
        if xw[0] ** 2 / a2 + xw[1] ** 2 / b2 <= 1:
            points.append(xw)
    # reshape keeps a (2, 0) shape when P == 0 instead of a 1-D empty array,
    # so the matrix product below stays well-defined.
    points = np.array(points, dtype=float).reshape(-1, 2).T
    return R @ points + np.asarray(shift, dtype=float).reshape(2, 1)


def data_ex31_3(P, a, b, th1, th2, th3, st1, st2, st3, *, plot=True):
    """
    Generate three 2D point clusters, each sampled inside a rotated, shifted ellipse.

    Each cluster is built by rejection-sampling ``P`` points (candidates drawn
    from N(0, 2**2) per coordinate) inside the axis-aligned ellipse
    x**2/a**2 + y**2/b**2 <= 1. The points are then rotated by the cluster's
    angle and translated. The translations are (0.8, 0.7), (2.8, 2.9) and
    (-2.4, -2) for clusters 1, 2 and 3 respectively.

    Parameters:
        P (int): Number of points per cluster.
        a (float): Semi-axis of the ellipse along x, before rotation
            (not necessarily the major axis).
        b (float): Semi-axis of the ellipse along y, before rotation.
        th1, th2, th3 (float): Rotation angles (radians) for each cluster.
        st1, st2, st3 (int): Seeds for each cluster's private random
            generator. NumPy's global random state is not modified.
        plot (bool): Keyword-only. When True (default), draw the clusters
            with Matplotlib and call ``plt.show()``.

    Returns:
        X (numpy array): Combined dataset of shape (2, 3 * P); columns
            [0, P) are cluster 1, [P, 2P) cluster 2, [2P, 3P) cluster 3.
        C1, C2, C3 (numpy arrays): Centroids (mean point, shape (2,)) of the
            three clusters. These are NaN when P == 0.
    """
    a2 = a ** 2
    b2 = b ** 2

    X1 = _generate_cluster(P, a2, b2, _rotation_matrix(th1), st1, [0.8, 0.7])
    X2 = _generate_cluster(P, a2, b2, _rotation_matrix(th2), st2, [2.8, 2.9])
    X3 = _generate_cluster(P, a2, b2, _rotation_matrix(th3), st3, [-2.4, -2])

    C1 = np.mean(X1, axis=1)
    C2 = np.mean(X2, axis=1)
    C3 = np.mean(X3, axis=1)
    X = np.hstack((X1, X2, X3))

    if plot:
        plt.figure()
        plt.plot(X1[0, :], X1[1, :], '+k', label='Cluster 1')
        plt.plot(X2[0, :], X2[1, :], 'ok', label='Cluster 2')
        plt.plot(X3[0, :], X3[1, :], '.k', label='Cluster 3')
        plt.xlabel('X-axis')
        plt.ylabel('Y-axis')
        plt.grid()
        plt.axis([-5, 6, -5, 5])
        plt.legend()
        plt.show()

    return X, C1, C2, C3

def main():
    """Build the three-cluster demo dataset and print each cluster's centroid."""
    angles = (np.pi / 6, np.pi / 2.5, 2.5 * np.pi / 3)
    seeds = (6, 19, 28)
    _, c1, c2, c3 = data_ex31_3(100, 1, 3, *angles, *seeds)

    print("Centroids:")
    for label, centroid in (("C1:", c1), ("C2:", c2), ("C3:", c3)):
        print(label, centroid)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
