import numpy as np
import matplotlib.pyplot as plt

# credits: Yupeng Liu

def data_semi_circle(r, thk, sep, N, st1, st2, *, plot=True):
    """
    Generate two half-ring ("double semi-circle") point clouds.

    Each class is produced by rejection sampling: candidates are drawn from a
    2-D Gaussian scaled by 12 and kept only when their norm lies strictly
    between ``r`` and ``r + thk`` and they fall on the requested side of the
    horizontal axis (x2 > 0 for the positive class, x2 < 0 for the negative
    class). The negative class is then shifted right by the mean ring radius
    ``r + thk / 2`` and down by ``sep``.

    Parameters:
        r (float): Inner radius of the semi-circle.
        thk (float): Thickness of the semi-circle. Must be positive.
        sep (float): Amount subtracted from the x2 coordinate of the negative
            semi-circle (a negative value moves it up).
        N (int): Number of points per semi-circle. Must be non-negative.
        st1 (int): Random seed for positive semi-circle.
        st2 (int): Random seed for negative semi-circle.
        plot (bool): Keyword-only. When True (the default) both classes are
            drawn with matplotlib and ``plt.show()`` is called. Pass False to
            only generate the data.

    Returns:
        xp (numpy array): Points in the positive semi-circle, shape (2, N).
        xn (numpy array): Points in the negative semi-circle, shape (2, N).

    Raises:
        ValueError: If ``N`` is negative, or if ``thk <= 0`` or
            ``r + thk <= 0`` (no candidate could ever be accepted, so the
            sampling loop would never terminate).
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")
    if thk <= 0 or r + thk <= 0:
        raise ValueError(
            f"the ring r < ||x|| < r + thk is empty for r={r}, thk={thk}"
        )

    r1 = r
    r2 = r + thk

    def generate_semi_circle(count, inner, outer, seed, sign):
        # A private RandomState yields the same stream as seeding the global
        # generator with `seed`, without clobbering the caller's RNG state.
        rng = np.random.RandomState(seed)
        points = []
        while len(points) < count:
            xw = 12 * rng.randn(2)
            xm = np.linalg.norm(xw)
            on_side = xw[1] > 0 if sign > 0 else xw[1] < 0
            if inner < xm < outer and on_side:
                points.append(xw)
        if not points:
            # np.array([]).T would be 1-D; keep the documented (2, N) layout.
            return np.empty((2, 0))
        return np.array(points).T

    xp = generate_semi_circle(N, r1, r2, st1, 1)
    xn = generate_semi_circle(N, r1, r2, st2, -1)

    # Offset the negative class: right by the mean radius, down by `sep`.
    ra = 0.5 * (r1 + r2)
    xn[0, :] += ra
    xn[1, :] -= sep

    if plot:
        plt.figure()
        plt.plot(xp[0, :], xp[1, :], 'ko', label='Positive Semi-Circle')
        plt.plot(xn[0, :], xn[1, :], 'k+', label='Negative Semi-Circle')
        plt.grid()
        plt.xlabel('$x_1$')
        plt.ylabel('$x_2$')
        plt.axis([-20, 30, -20, 20])
        plt.axis('square')
        plt.legend()
        plt.show()

    return xp, xn


def main():
    """Build the demo data set and print the shape of each class array."""
    # Demo configuration: ring geometry, sample count and the two RNG seeds.
    settings = {
        "r": 10,
        "thk": 5,
        "sep": -1,
        "N": 400,
        "st1": 42,
        "st2": 4,
    }
    positive, negative = data_semi_circle(**settings)

    print("Generated positive semi-circle points:", positive.shape)
    print("Generated negative semi-circle points:", negative.shape)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
