# L-9 MCS 507 Mon 11 Sep 2023 : laguerre.py
"""
Plots the basins of attraction of the method of Laguerre,
a method to compute a root of a polynomial.
"""
import numpy as np
from numpy.polynomial.polynomial import polyval, polyder
from numpy.random import random
from numpy.lib.scimath import sqrt

def laguerre(p,d1p,d2p,z0,dxtol=1.0e-8,pxtol=1.0e-8,maxit=20,verbose=True):
    """
    Applies the method of Laguerre to compute a root of p.

    ON ENTRY :
      p        coefficients of a polynomial in one variable,
               in increasing order of degree (as used by polyval)
      d1p      coefficients of the first derivative of p
      d2p      coefficients of the second derivative of p
      z0       an approximation for the root
      dxtol    the tolerance on the forward error
      pxtol    the tolerance on the backward error
      maxit    the maximum number of iterations
      verbose  the verbose flag, if true, writes one line each step

    ON RETURN :
      root     an approximation for the root
      absdx    the estimated forward error, the size of the last step
               (1.0 if no step was taken)
      abspx    the estimated backward error |p(root)|
      nbrit    the number of iterations
      fail     true if tolerances not reached,
               false otherwise.

    The iteration stops with success as soon as |dx| < dxtol or
    |p(root)| < pxtol, including right after the last allowed step.
    If the Laguerre step is undefined because both candidate
    denominators vanish (for example z0 = 0 for x^n - 1), the
    iteration stops with fail true and returns the last point,
    instead of continuing with inf/nan values.
    """
    root = z0
    dx = 1.0
    deg = len(p)-1
    degm1 = deg-1
    pval = polyval(root, p)   # residual at the current point
    if verbose:
        title = "        real(root)               imag(root)"
        print("step :" + title + "           |dx|     |p(x)|")
        stri = '%3d' % 0
        strx = '%23.16e  %23.16e' % (root.real, root.imag)
        print(stri, " : ", strx)
    if abs(pval) < pxtol:
        if verbose:
            print('succeeded after', 0, 'step(s)')
        return (root, abs(dx), abs(pval), 0, False)
    for i in range(1, maxit+1):
        d1val = polyval(root, d1p)
        d2val = polyval(root, d2p)
        # G = p'/p and H = G^2 - p''/p
        Lroot = d1val/pval
        Mroot = Lroot**2 - d2val/pval
        # complex square root, valid also for negative real arguments
        valsqrt = sqrt(degm1*(deg*Mroot - Lroot**2))
        yplus = Lroot + valsqrt
        yminus = Lroot - valsqrt
        # choose the sign giving the larger denominator (smaller step)
        denom = yplus if abs(yplus) > abs(yminus) else yminus
        if denom == 0:
            # the Laguerre step n/denom is undefined at this point
            if verbose:
                print('failed at step', i, ': zero denominator')
            return (root, abs(dx), abs(pval), i-1, True)
        dx = deg/denom
        root = root - dx
        pval = polyval(root, p)
        if verbose:
            stri = '%3d' % i
            strx = '%23.16e  %23.16e' % (root.real, root.imag)
            strdx = ' %.2e' % abs(dx)
            strpx = ' %.2e' % abs(pval)
            print(stri, " : ", strx, strdx, strpx)
        # test both tolerances on the new iterate, also at i == maxit
        if abs(dx) < dxtol or abs(pval) < pxtol:
            if verbose:
                print('succeeded after', i, 'step(s)')
            return (root, abs(dx), abs(pval), i, False)
    return (root, abs(dx), abs(pval), maxit, True)

def rank(roots, z, tol=1.0e-4):
    """
    Returns the position of z in the list roots.

    Two complex numbers x and y are considered equal
    if abs(x - y) < tol.
    The index of the first entry of roots equal to z is returned.
    If z does not appear in roots, then -1 is returned.
    roots may be any iterable of numbers (list, tuple, array, generator).
    """
    for idx, root in enumerate(roots):
        if abs(root - z) < tol:
            return idx
    return -1

def test():
    """
    Runs a simple test on the method of Laguerre.

    The polynomial is (x-1)*(x-2) = 2 - 3x + x^2, with roots 1 and 2.
    The real and imaginary parts of the start point are drawn at
    random in [0, 1), and both are doubled with probability one half.
    Depending on the size of the start point, the method may converge
    to 1 or to 2.
    Prints the tuple returned by laguerre and the rank of the computed
    root among the two known roots.
    """
    known_roots = [complex(1), complex(2)]
    poly = np.array([2.0, -3.0, 1.0])
    first_derivative = polyder(poly)
    second_derivative = polyder(first_derivative)
    scale = 2.0 if random() > 0.5 else 1.0
    start = scale*complex(random(), random())
    outcome = laguerre(poly, first_derivative, second_derivative, start)
    print(outcome)
    print('rank :', rank(known_roots, outcome[0]))

def matrixplot(A, cmap='Paired'):
    """
    Uses matshow to make a plot of the matrix A.

    ON ENTRY :
      A      a two-dimensional array-like of values to display
      cmap   the name of the matplotlib colormap, 'Paired' by default;
             a qualitative map such as 'Set1' is an alternative
    """
    import matplotlib.pyplot as plt
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.matshow(A, cmap=cmap)
    plt.show()

def _basins(cff, roots, dim, left, right):
    """
    Returns a dim-by-dim integer matrix of the attraction basins
    of the method of Laguerre for the polynomial with coefficients cff.

    Entry [i, j] corresponds to the start point
    z0 = (left + i*dz) + (left + j*dz)*I, where dz = (right-left)/(dim-1).
    Its value is len(roots) if z0 lies within 0.01 of a root. Otherwise
    it is the index in roots of the computed approximation, or -1 if that
    approximation is not within 1.0e-4 of any entry of roots.
    """
    dp1 = polyder(cff)
    dp2 = polyder(dp1)
    dz = (right - left)/(dim-1)
    mat = np.zeros((dim, dim), dtype=int)
    for i in range(dim):
        for j in range(dim):
            z0 = complex(left + i*dz, left + j*dz)
            if rank(roots, z0, tol=0.01) != -1:
                # mark the immediate neighborhood of each root
                mat[i, j] = len(roots)
            else:
                result = laguerre(cff, dp1, dp2, z0, verbose=False)
                mat[i, j] = rank(roots, result[0])
    return mat


def main(dim=501, left=-1.1, right=1.1):
    """
    For p = x^8 - 1, makes a matrix plot of the attraction basins
    of the method of Laguerre.

    ON ENTRY :
      dim    number of grid points along each axis, at least 2
      left   lower bound for real and imaginary parts of the start points
      right  upper bound for real and imaginary parts of the start points

    Prints the user cpu, system and elapsed time of the computation,
    then prints and plots the matrix of basins.
    """
    from os import times
    from time import perf_counter
    if dim < 2:
        raise ValueError('dim must be at least 2')
    cff = np.array([-1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
    sq2 = sqrt(2)/2
    # the eight 8-th roots of unity, counterclockwise from 1
    roots = [complex(1), complex(sq2, sq2), complex(0, 1),
             complex(-sq2, sq2), complex(-1), complex(-sq2, -sq2),
             complex(0, -1), complex(sq2, -sq2)]
    start = times()
    # os.times() reports elapsed time as 0 on Windows, so use a wall clock
    wall_start = perf_counter()
    mat = _basins(cff, roots, dim, left, right)
    wall_stop = perf_counter()
    stop = times()
    print('user cpu time : %.4f' % (stop[0] - start[0]))
    print('  system time : %.4f' % (stop[1] - start[1]))
    print(' elapsed time : %.4f' % (wall_stop - wall_start))
    print(mat)
    matrixplot(mat)

if __name__ == "__main__":
    # Run as a script: compute and plot the Laguerre basins for x^8 - 1.
    main()
