Lagrange interpolation in Python


Solution 1

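Check the indices. Wikipedia says "k+1 data points", but you're setting k = len(x_values), where it should be k = len(x_values) - 1 if you follow the formula exactly. As written, xrange(k + 1) runs one step past the last valid index of x_values, which is what raises the IndexError.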
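Following Wikipedia's indexing literally (k + 1 points numbered 0..k), a minimal sketch could look like the one below. The function name `lagrange` is hypothetical, and besides the index fix it also adds the `x` argument and the `y_values[j]` weight, both of which the question's version is missing; those two changes are assumptions beyond what this answer states.

```python
def lagrange(x, x_values, y_values):
    # Wikipedia numbers the k + 1 data points x_0 .. x_k,
    # so k is one less than the number of points.
    k = len(x_values) - 1
    total = 0.0
    for j in range(k + 1):          # j = 0 .. k
        basis = 1.0
        for m in range(k + 1):      # m = 0 .. k, skipping m == j
            if m != j:
                basis *= (x - x_values[m]) / (x_values[j] - x_values[m])
        # Weight the j-th basis polynomial by its data value y_j.
        total += y_values[j] * basis
    return total


print(lagrange(3, [1, 2, 4], [1, 0, 2]))
```

With k defined this way, range(k + 1) visits exactly the valid indices 0..len(x_values) - 1.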

Solution 2

Try

from functools import reduce
import operator


def interpolate(x, x_values, y_values):
    def _basis(j):
        # l_j(x): product over all m != j of (x - x_m) / (x_j - x_m)
        p = [(x - x_values[m])/(x_values[j] - x_values[m]) for m in range(k) if m != j]
        # Initial value 1 so a single data point yields the constant basis l_0(x) = 1
        return reduce(operator.mul, p, 1)
    assert len(x_values) != 0 and (len(x_values) == len(y_values)), 'x and y cannot be empty and must have the same length'
    k = len(x_values)
    # Interpolated value: sum of y_j * l_j(x)
    return sum(_basis(j)*y_values[j] for j in range(k))

You can confirm it as follows:

>>> interpolate(1,[1,2,4],[1,0,2])
1.0
>>> interpolate(2,[1,2,4],[1,0,2])
0.0
>>> interpolate(4,[1,2,4],[1,0,2])
2.0
>>> interpolate(3,[1,2,4],[1,0,2])
0.3333333333333333

The result is the value at x of the polynomial (of degree at most len(x_values) - 1) that passes through all the given points. Here the 3 points define a parabola, and the first three tests show that each given x_value returns its stated y_value.
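To check that 0.3333... really is 1/3 and not a rounding artifact, the same Lagrange formula can be evaluated with exact rational arithmetic. This is a sketch using the standard library's fractions.Fraction; the name interpolate_exact is hypothetical. For these three points the parabola works out to (2/3)x^2 - 3x + 10/3, whose value at x = 3 is 6 - 9 + 10/3 = 1/3.

```python
from fractions import Fraction


def interpolate_exact(x, x_values, y_values):
    # Same Lagrange formula as above, but every step stays a Fraction.
    x = Fraction(x)
    total = Fraction(0)
    for j, (xj, yj) in enumerate(zip(x_values, y_values)):
        basis = Fraction(1)
        for m, xm in enumerate(x_values):
            if m != j:
                basis *= (x - xm) / Fraction(xj - xm)
        total += yj * basis
    return total


print(interpolate_exact(3, [1, 2, 4], [1, 0, 2]))  # 1/3
```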

Solution 3

I'm almost a decade late to the party, but I found this searching for a simple implementation of Lagrange interpolation. @smichr's answer is great, but the Python is a little outdated, and I also wanted something that would work nicely with np.ndarrays so I could do easy plotting. Maybe others will find this useful:

import numpy as np
import matplotlib.pyplot as plt


class LagrangePoly:

    def __init__(self, X, Y):
        self.n = len(X)
        self.X = np.array(X)
        self.Y = np.array(Y)

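    def basis(self, x, j):
        # Returns y_j * l_j(x), where l_j(x) = prod over m != j of (x - x_m) / (x_j - x_m).
        # x may be a scalar or an ndarray; axis=0 multiplies the terms elementwise.
        b = [(x - self.X[m]) / (self.X[j] - self.X[m])
             for m in range(self.n) if m != j]
        return np.prod(b, axis=0) * self.Y[j]

    def interpolate(self, x):
        # P(x) = sum over j of y_j * l_j(x)
        b = [self.basis(x, j) for j in range(self.n)]
        return np.sum(b, axis=0)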


X  = [-9, -4, -1, 7]
Y  = [5, 2, -2, 9]

plt.scatter(X, Y, c='k')

lp = LagrangePoly(X, Y)

xx = np.arange(-100, 100) / 10

plt.plot(xx, lp.basis(xx, 0))
plt.plot(xx, lp.basis(xx, 1))
plt.plot(xx, lp.basis(xx, 2))
plt.plot(xx, lp.basis(xx, 3))
plt.plot(xx, lp.interpolate(xx), linestyle=':')
plt.show()
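Because basis() multiplies by Y[j], the four solid curves are scaled basis polynomials: each one passes through (X[j], Y[j]), is zero at the other nodes, and together they add up to the dotted interpolant. The short sketch below (the helper unscaled_basis is hypothetical, not part of the class) checks the two defining properties of the unscaled basis l_j for the same X: l_j(X[i]) is 1 when i == j and 0 otherwise, and the l_j sum to 1 at every x.

```python
import numpy as np

X = np.array([-9, -4, -1, 7], dtype=float)


def unscaled_basis(x, j):
    # l_j(x) = prod over m != j of (x - X[m]) / (X[j] - X[m]); no Y[j] factor.
    terms = [(x - X[m]) / (X[j] - X[m]) for m in range(len(X)) if m != j]
    return np.prod(terms, axis=0)


# Row j holds l_j evaluated at every node, which should be the identity matrix.
table = np.array([unscaled_basis(X, j) for j in range(len(X))])
print(np.allclose(table, np.eye(len(X))))

# The basis polynomials interpolate the constant 1, so they sum to 1 everywhere.
xx = np.linspace(-10, 10, 5)
total = sum(unscaled_basis(xx, j) for j in range(len(X)))
print(np.allclose(total, 1.0))
```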
Author: rubik


Updated on July 28, 2022

Comments

  • rubik
    rubik almost 2 years

    I want to interpolate a polynomial with the Lagrange method, but this code doesn't work:

    def interpolate(x_values, y_values):
        def _basis(j):
            p = [(x - x_values[m])/(x_values[j] - x_values[m]) for m in xrange(k + 1) if m != j]
            return reduce(operator.mul, p)
    
        assert len(x_values) != 0 and (len(x_values) == len(y_values)), 'x and y cannot be empty and must have the same length'
    
        k = len(x_values)
        return sum(_basis(j) for j in xrange(k))
    

    I followed Wikipedia, but when I run it I receive an IndexError at line 3!

    Thanks

  • rubik
    rubik over 13 years
    OK, but why does interpolate([1, 2, 3], [1, 4, 9]) return -0.5x^2 + 1.5x? Take a look at this: i.imgur.com/MkATz.gif
  • AndiDog
    AndiDog over 13 years
    @rubik: Sorry, but I can't help you with such a specific problem without knowing the interpolation algorithm (and I won't read up on it). Check your logic again or search for an existing implementation. If you post more code on how you apply the interpolation (e.g. the definition/initial value of x is missing in your question), then somebody might be able to help you further.
  • Devansh Mishra
    Devansh Mishra about 4 years
    Can someone please explain what the plots of xx vs lp.basis(xx, 0), lp.basis(xx, 1), ..., lp.basis(xx, 3) signify?
  • jds
    jds about 4 years
    It's plotting the basis functions of the Lagrange polynomial (each scaled by its Y value), reconstructing a figure like this: en.wikipedia.org/wiki/File:Lagrange_polynomial.svg
  • SurpriseDog
    SurpriseDog about 3 years
    To get this to work in Python 3, run pip3 install future, then add: from past.builtins import reduce, xrange; import operator
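On Python 3 the third-party future package isn't strictly required: xrange became range, reduce moved to functools, and on Python 3.8+ math.prod can replace reduce(operator.mul, ...). A minimal sketch assuming Python 3.8+, checked against the points from rubik's follow-up comment, which lie on y = x**2 (so the interpolant should reproduce x**2):

```python
import math


def interpolate(x, x_values, y_values):
    # Python 3.8+ only: range replaces xrange, math.prod replaces reduce(operator.mul, ...).
    if not x_values or len(x_values) != len(y_values):
        raise ValueError('x and y cannot be empty and must have the same length')
    total = 0.0
    for j, (xj, yj) in enumerate(zip(x_values, y_values)):
        # y_j * l_j(x), with l_j(x) = prod over m != j of (x - x_m) / (x_j - x_m)
        total += yj * math.prod((x - xm) / (xj - xm)
                                for m, xm in enumerate(x_values) if m != j)
    return total


print(interpolate(5, [1, 2, 3], [1, 4, 9]))  # 25.0
```

Raising ValueError instead of using assert keeps the input check active even when Python runs with -O, which strips assert statements.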