import os
import numpy as np
# credits: Ruijie Tao

def bt_lsearch2021(x, d, g, fname, p1=None, p2=None):
    """
    Find a step size along d by backtracking line search.

    Starting from a = 1, the step is halved until the sufficient-decrease
    test f(x + a*d) <= f(x) + rho * a * g^T d holds (rho = 0.1). A trial
    point whose objective value is NaN or infinite is treated as a failed
    test, so the search keeps shrinking instead of accepting it.

    Parameters:
        x (array-like): Initial point; reshaped to a column vector.
        d (array-like): Search direction; reshaped to a column vector.
        g (array-like): Gradient at x; reshaped to a column vector.
        fname (callable): Objective function to be minimized along d. It is
            called with an (n, 1) column vector, followed by p1 and p2 when
            they are supplied.
        p1 (optional): Extra argument forwarded to fname when not None.
        p2 (optional): Second extra argument; forwarded only when p1 is
            also not None.

    Returns:
        a (float): Accepted step size. If no step of at least 1e-5 passes
            the test, min(1e-5, 0.1 / ||d||) is returned instead.
    """
    rho = 0.1     # sufficient-decrease coefficient
    gma = 0.5     # factor applied to the step after each failed trial
    a_min = 1e-5  # steps below this are replaced by the floor value

    x = np.asarray(x).reshape(-1, 1)
    d = np.asarray(d).reshape(-1, 1)
    gk = np.asarray(g).reshape(-1, 1)

    # Resolve the optional arguments once; p2 alone (without p1) is ignored.
    if p1 is not None and p2 is not None:
        extra = (p1, p2)
    elif p1 is not None:
        extra = (p1,)
    else:
        extra = ()

    f0 = fname(x, *extra)
    t0 = rho * np.dot(gk.T, d)[0, 0]

    a = 1.0
    # Any step that ends below a_min is overwritten by the floor value, so
    # there is no point evaluating fname for such steps. This also bounds
    # the number of iterations when d is not a descent direction.
    while a >= a_min:
        f1 = fname(x + a * d, *extra)
        er = f1 - (f0 + a * t0)
        # A non-finite trial value must not be accepted: a NaN would make
        # `er > 0` false and look like a successful step.
        if np.isfinite(f1) and not (er > 0):
            break
        a *= gma

    if a < a_min:
        d_norm = np.linalg.norm(d)
        # Guard the division so a zero direction yields the plain floor.
        a = min(a_min, 0.1 / d_norm) if d_norm > 0 else a_min

    return float(a)


def sample_function(x, p1=None, p2=None):
    """
    Example objective: the sum of the squared entries of x.

    Parameters:
        x (numpy array): Point at which to evaluate the function.
        p1 (optional): Accepted for signature compatibility; not used.
        p2 (optional): Accepted for signature compatibility; not used.

    Returns:
        Sum of squares of the entries of x.
    """
    squared = x ** 2
    return np.sum(squared)


def main():
    """
    Run the line search on a small example and print the resulting step.

    The example uses the point [1, 2], the direction [-1, -1] and the
    gradient [2, 4], with sample_function as the objective.
    """
    start = np.array([1.0, 2.0])
    gradient = np.array([2.0, 4.0])
    direction = np.array([-1.0, -1.0])

    step = bt_lsearch2021(start, direction, gradient, sample_function)
    print(f"Computed eg. step size alpha: {step}")


# Run the example only when this file is executed directly as a script.
if __name__ == "__main__":
    main()
