# L-8 MCS 507 Fri 8 Sep 2023 : differential_numbers.py

"""
This script is based on section (c), "Computing with differential numbers",
on pages 7 to 10 of "Introduction to Numerical Analysis"
by Arnold Neumaier, Cambridge University Press, 2001.
The operator overloading encodes the elementary derivative rules
for the addition, subtraction, multiplication, and division of
any two real numbers.  As we evaluate any expression written in
these four elementary operations, we also compute its derivative.
The script illustrates the forward mode of algorithmic differentiation.
"""

class DifferentialNumber(object):
    """
    A differential number is a tuple df = (f, f') of two doubles
    where f' is the evaluated derivative of f.

    Plain int and float operands are treated as constants,
    that is: as differential numbers with derivative zero,
    and may appear on either side of +, -, *, and /.
    Operands of any other type make the operator return
    NotImplemented, so Python raises a TypeError.
    """
    # Scalar operand types that act as constants (derivative zero).
    # Note that bool is a subclass of int and is accepted as well.
    _CONSTANTS = (int, float)

    def __init__(self, xv, xp=1):
        """
        Initializes a differential number
        to the tuple (xv, xp).
        With the default xp = 1 the number represents the
        independent variable x itself, as dx/dx = 1.
        """
        self.val = float(xv)
        self.der = float(xp)

    def __str__(self):
        """
        Returns the tuple representation of the differential number.
        """
        return str((self.val, self.der))

    def __repr__(self):
        """
        Returns a string which reconstructs the differential number.
        """
        return 'DifferentialNumber(%r, %r)' % (self.val, self.der)

    def __neg__(self):
        """
        Negation: -(f, f') = (-f, -f').
        """
        return DifferentialNumber(-self.val, -self.der)

    def __add__(self, other):
        """
        Defines the addition of two differential numbers,
        or of a differential number and a constant c:
        (f, f') + (g, g') = (f + g, f' + g'),
        (f, f') + c = (f + c, f').
        """
        if isinstance(other, self._CONSTANTS):
            return DifferentialNumber(self.val + other, self.der)
        if isinstance(other, DifferentialNumber):
            return DifferentialNumber(self.val + other.val,
                                      self.der + other.der)
        return NotImplemented

    def __radd__(self, other):
        """
        Addition when operand is not a differential number
        as in 2.7 + x or 3 + x (reflected operand).
        Addition is commutative, so this delegates to __add__.
        """
        return self.__add__(other)

    def __sub__(self, other):
        """
        Defines the subtraction of two differential numbers,
        or of a constant c from a differential number:
        (f, f') - (g, g') = (f - g, f' - g'),
        (f, f') - c = (f - c, f').
        """
        if isinstance(other, self._CONSTANTS):
            return DifferentialNumber(self.val - other, self.der)
        if isinstance(other, DifferentialNumber):
            return DifferentialNumber(self.val - other.val,
                                      self.der - other.der)
        return NotImplemented

    def __rsub__(self, other):
        """
        Subtraction from a constant c, as in 1 - x (reflected operand):
        c - (f, f') = (c - f, -f').
        """
        if isinstance(other, self._CONSTANTS):
            return DifferentialNumber(other - self.val, -self.der)
        return NotImplemented

    def __mul__(self, other):
        """
        Defines the product of two differential numbers (product rule),
        or of a differential number and a constant c:
        (f, f') * (g, g') = (f*g, f'*g + f*g'),
        (f, f') * c = (f*c, f'*c).
        """
        if isinstance(other, self._CONSTANTS):
            return DifferentialNumber(self.val*other, self.der*other)
        if isinstance(other, DifferentialNumber):
            return DifferentialNumber(self.val*other.val,
                                      self.der*other.val + self.val*other.der)
        return NotImplemented

    def __rmul__(self, other):
        """
        Multiplication by a constant on the left, as in 2*x
        (reflected operand).  Multiplication is commutative,
        so this delegates to __mul__.
        """
        return self.__mul__(other)

    def __truediv__(self, other):
        """
        Defines the division of two differential numbers (quotient rule),
        or of a differential number by a constant c:
        (f, f') / (g, g') = (f/g, (f' - (f/g)*g')/g),
        (f, f') / c = (f/c, f'/c).
        Raises ZeroDivisionError when the divisor value is zero.
        """
        if isinstance(other, self._CONSTANTS):
            return DifferentialNumber(self.val/other, self.der/other)
        if isinstance(other, DifferentialNumber):
            # The quotient rule (f'g - fg')/g^2 is rewritten as
            # (f' - (f/g)*g')/g so that f/g is computed only once.
            val = self.val/other.val
            return DifferentialNumber(val,
                                      (self.der - val*other.der)/other.val)
        return NotImplemented

    def __rtruediv__(self, other):
        """
        Division of a constant c by a differential number,
        as in 6/x (reflected operand):
        c / (f, f') = (c/f, -(c/f)*f'/f).
        Raises ZeroDivisionError when the divisor value is zero.
        """
        if isinstance(other, self._CONSTANTS):
            val = other/self.val
            return DifferentialNumber(val, -val*self.der/self.val)
        return NotImplemented

    # Python 2 dispatches the / operator to __div__ and __rdiv__
    # (without "from __future__ import division"),
    # so they share the rules of true division.
    __div__ = __truediv__
    __rdiv__ = __rtruediv__

def main():
    """
    Evaluates four expressions at the differential number (3, 1),
    that is: at x = 3 as the independent variable, and prints the
    value and derivative of each.  The first two check that
    x + 1.0 and the reflected 1.0 + x agree; the last one is
    the expression ((x-1)*(x+3))/(x+2).
    """
    point = DifferentialNumber(3, 1)

    def plus_right(x):
        return x + 1.0

    def plus_left(x):
        # reflected addition: float on the left of a differential number
        return 1.0 + x

    def product(x):
        return (x-1.0)*(x+3.0)

    def quotient(x):
        return ((x-1.0)*(x+3.0))/(x+2.0)

    checks = [
        ('x + 1.0 at ', plus_right),
        ('1.0 + x at ', plus_left),
        ('(x-1)*(x+3) at', product),
        ('((x-1)*(x+3))/(x+2) at', quotient),
    ]
    for label, func in checks:
        print(label, point, ':', func(point))

if "__main__" == __name__:
    # Run the demonstration only when executed as a script,
    # not when imported as a module.
    main()
