Automatically remove hot/dead pixels from an image in python

Basically, I think that the fastest way to deal with hot pixels is just to use a size=2 median filter. Then, poof, your hot pixels are gone and you also kill all sorts of other high-frequency sensor noise from your camera.
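
As a minimal sketch of that idea (the 5x5 test array and the variable names are made up for illustration and are not from the original post), a single isolated hot pixel disappears after one pass of `scipy.ndimage.median_filter` with size=2:

```python
import numpy as np
from scipy.ndimage import median_filter

# Hypothetical test image: flat background of 100 with one hot pixel.
image = np.full((5, 5), 100.0)
image[2, 2] = 255.0  # isolated hot pixel

# size=2 replaces each pixel by the median of a 2x2 neighbourhood.
# Every 2x2 window holds at most one hot pixel here, so the median
# always comes from the background.
smoothed = median_filter(image, size=2)

print(smoothed.max())
```

Keep in mind that the filter touches every pixel, so genuine fine detail is smoothed along with the hot pixels.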

If you really want to remove ONLY the hot pixels, you can subtract the median-filtered image from the original image, as I did in the question, and replace only the pixels where the difference is large with the values from the median-filtered image. This doesn't work well at the edges, so if you can ignore the pixels along the edge, things become a lot easier.
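
A vectorized sketch of this selective replacement might look like the following. The helper name `replace_outliers`, the noisy test image, and the default `tolerance=10` are assumptions made for illustration, and the edges get no special treatment here:

```python
import numpy as np
from scipy.ndimage import median_filter

def replace_outliers(data, tolerance=10):
    """Hypothetical helper: use the median only where a pixel deviates strongly."""
    blurred = median_filter(data, size=2)
    difference = data - blurred
    # Threshold expressed as a number of standard deviations of the difference.
    threshold = tolerance * np.std(difference)
    outliers = np.abs(difference) > threshold
    # Keep the original value everywhere except at the flagged pixels.
    return np.where(outliers, blurred, data), outliers

# Made-up test image: noisy flat background, one hot and one dead pixel.
rng = np.random.default_rng(0)
image = 100.0 + rng.normal(0.0, 1.0, size=(50, 50))
image[20, 30] = 255.0  # hot pixel
image[10, 5] = 0.0     # dead pixel

fixed, mask = replace_outliers(image)
print(np.argwhere(mask))
```

Because `np.where` only swaps in the filtered value at flagged positions, all other pixels keep their original values, noise included.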

If you want to deal with the edges, you can use the code below. However, it is not the fastest:

import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage

plt.figure(figsize=(10,5))
ax1 = plt.subplot(121)
ax2 = plt.subplot(122)

#make some sample data
x = np.linspace(-5,5,200)
X,Y = np.meshgrid(x,x)
Z = 100*np.cos(np.sqrt(X**2 + Y**2))**2 + 50

np.random.seed(1)
for i in range(0,11):
    #Add some hot pixels
    Z[np.random.randint(low=0,high=199),np.random.randint(low=0,high=199)]= np.random.randint(low=200,high=255)
    #and dead pixels
    Z[np.random.randint(low=0,high=199),np.random.randint(low=0,high=199)]= np.random.randint(low=0,high=10)

#And some hot pixels in the corners and edges
Z[0,0]   =255
Z[-1,-1] =255
Z[-1,0]  =255
Z[0,-1]  =255
Z[0,100] =255
Z[-1,100]=255
Z[100,0] =255
Z[100,-1]=255

#Then plot it
ax1.set_title('Raw data with hot pixels')
ax1.imshow(Z,interpolation='nearest',origin='lower')

def find_outlier_pixels(data,tolerance=10,worry_about_edges=True):
    #This function finds the hot or dead pixels in a 2D dataset. 
    #tolerance is the number of standard deviations used to cut off the hot pixels
    #If you want to ignore the edges and greatly speed up the code, then set
    #worry_about_edges to False.
    #
    #The function returns a list of hot pixels and also an image with the hot pixels removed

    from scipy.ndimage import median_filter
    blurred = median_filter(data, size=2)
    difference = data - blurred
    threshold = tolerance*np.std(difference)

    #find the hot pixels, but ignore the edges
    hot_pixels = np.nonzero((np.abs(difference[1:-1,1:-1])>threshold) )
    hot_pixels = np.array(hot_pixels) + 1 #because we ignored the first row and first column

    fixed_image = np.copy(data) #This is the image with the hot pixels removed
    for y,x in zip(hot_pixels[0],hot_pixels[1]):
        fixed_image[y,x]=blurred[y,x]

    if worry_about_edges:
        height,width = np.shape(data)

        ###Now get the pixels on the edges (but not the corners)###

        #left and right sides
        for index in range(1,height-1):
            #left side:
            med  = np.median(data[index-1:index+2,0:2])
            diff = np.abs(data[index,0] - med)
            if diff>threshold: 
                hot_pixels = np.hstack(( hot_pixels, [[index],[0]]  ))
                fixed_image[index,0] = med

            #right side:
            med  = np.median(data[index-1:index+2,-2:])
            diff = np.abs(data[index,-1] - med)
            if diff>threshold: 
                hot_pixels = np.hstack(( hot_pixels, [[index],[width-1]]  ))
                fixed_image[index,-1] = med

        #Then the top and bottom
        for index in range(1,width-1):
            #bottom:
            med  = np.median(data[0:2,index-1:index+2])
            diff = np.abs(data[0,index] - med)
            if diff>threshold: 
                hot_pixels = np.hstack(( hot_pixels, [[0],[index]]  ))
                fixed_image[0,index] = med

            #top:
            med  = np.median(data[-2:,index-1:index+2])
            diff = np.abs(data[-1,index] - med)
            if diff>threshold: 
                hot_pixels = np.hstack(( hot_pixels, [[height-1],[index]]  ))
                fixed_image[-1,index] = med

        ###Then the corners###

        #bottom left
        med  = np.median(data[0:2,0:2])
        diff = np.abs(data[0,0] - med)
        if diff>threshold: 
            hot_pixels = np.hstack(( hot_pixels, [[0],[0]]  ))
            fixed_image[0,0] = med

        #bottom right
        med  = np.median(data[0:2,-2:])
        diff = np.abs(data[0,-1] - med)
        if diff>threshold: 
            hot_pixels = np.hstack(( hot_pixels, [[0],[width-1]]  ))
            fixed_image[0,-1] = med

        #top left
        med  = np.median(data[-2:,0:2])
        diff = np.abs(data[-1,0] - med)
        if diff>threshold: 
            hot_pixels = np.hstack(( hot_pixels, [[height-1],[0]]  ))
            fixed_image[-1,0] = med

        #top right
        med  = np.median(data[-2:,-2:])
        diff = np.abs(data[-1,-1] - med)
        if diff>threshold: 
            hot_pixels = np.hstack(( hot_pixels, [[height-1],[width-1]]  ))
            fixed_image[-1,-1] = med

    return hot_pixels,fixed_image


hot_pixels,fixed_image = find_outlier_pixels(Z)

for y,x in zip(hot_pixels[0],hot_pixels[1]):
    ax1.plot(x,y,'ro',mfc='none',mec='r',ms=10)

ax1.set_xlim(0,200)
ax1.set_ylim(0,200)

ax2.set_title('Image with hot pixels removed')
ax2.imshow(fixed_image,interpolation='nearest',origin='lower',clim=(0,255))

plt.show()

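Output: [figure: left panel "Raw data with hot pixels" with the detected pixels circled in red; right panel "Image with hot pixels removed"]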

Author: DanHickstein

I am currently a postdoc studying laser physics at the University of Colorado in Boulder, CO.

Updated on June 17, 2022

Comments

  • DanHickstein
    DanHickstein almost 2 years

    I am using numpy and scipy to process a number of images taken with a CCD camera. These images have a number of hot (and dead) pixels with very large (or small) values. These interfere with other image processing, so they need to be removed. Unfortunately, though a few of the pixels are stuck at either 0 or 255 and are always at the same value in all of the images, there are some pixels that are temporarily stuck at other values for a period of a few minutes (the data spans many hours).

    I am wondering if there is a method for identifying (and removing) the hot pixels already implemented in python. If not, I am wondering what would be an efficient method for doing so. The hot/dead pixels are relatively easy to identify by comparing them with neighboring pixels. I could see writing a loop that looks at each pixel, compares its value to that of its 8 nearest neighbors. Or, it seems nicer to use some kind of convolution to produce a smoother image and then subtract this from the image containing the hot pixels, making them easier to identify.

    I have tried this "blurring method" in the code below, and it works okay, but I doubt that it is the fastest. Also, it gets confused at the edge of the image (probably since the gaussian_filter function is taking a convolution and the convolution gets weird near the edge). So, is there a better way to go about this?

    Example code:

    import numpy as np
    import matplotlib.pyplot as plt
    import scipy.ndimage
    
    plt.figure(figsize=(8,4))
    ax1 = plt.subplot(121)
    ax2 = plt.subplot(122)
    
    #make a sample image
    x = np.linspace(-5,5,200)
    X,Y = np.meshgrid(x,x)
    Z = 255*np.cos(np.sqrt(X**2 + Y**2))**2
    
    
    for i in range(0,11):
        #Add some hot pixels
        Z[np.random.randint(low=0,high=199),np.random.randint(low=0,high=199)]= np.random.randint(low=200,high=255)
        #and dead pixels
        Z[np.random.randint(low=0,high=199),np.random.randint(low=0,high=199)]= np.random.randint(low=0,high=10)
    
    #Then plot it
    ax1.set_title('Raw data with hot pixels')
    ax1.imshow(Z,interpolation='nearest',origin='lower')
    
    #Now we try to find the hot pixels
    blurred_Z = scipy.ndimage.gaussian_filter(Z, sigma=2)
    difference = Z - blurred_Z
    
    ax2.set_title('Difference with hot pixels identified')
    ax2.imshow(difference,interpolation='nearest',origin='lower')
    
    threshold = 15
    hot_pixels = np.nonzero((difference>threshold) | (difference<-threshold))
    
    #Don't include the hot pixels that we found near the edge:
    count = 0
    for y,x in zip(hot_pixels[0],hot_pixels[1]):
        if (x != 0) and (x != 199) and (y != 0) and (y != 199):
            ax2.plot(x,y,'ro')
            count += 1
    
    print('Detected %i hot/dead pixels out of 20.'%count)
    ax2.set_xlim(0,200); ax2.set_ylim(0,200)
    
    
    plt.show()
    

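    And the output: [figure: left panel "Raw data with hot pixels"; right panel "Difference with hot pixels identified" with the detected pixels marked in red]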

    • Eddy_Em
      Eddy_Em over 10 years
      Try a simpler case: make another image with median filtering (for example, with a 3x3 pattern) and compute the absolute value of the difference between your image and the filtered image. Replace the pixels of the original image where that difference is large (let's say, above 100) with the filtered values. You can get the threshold value automatically from the statistics of the difference.
  • packoman
    packoman over 3 years
    Can you use edge-mirroring here, to avoid the complex handling of the edges and corners? I.e.: blurred = median_filter(image_data, size=2, mode="mirror")
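
packoman's edge-mirroring suggestion could be sketched roughly as below. The helper name `fix_outliers_mirrored`, the flat 200x200 test image, and the default `tolerance=10` are assumptions for illustration. The idea is that `mode="mirror"` reflects about the centre of the border pixel, so a hot pixel on the border is not duplicated into its own neighbourhood, whereas the default `mode="reflect"` repeats the border pixel itself:

```python
import numpy as np
from scipy.ndimage import median_filter

def fix_outliers_mirrored(data, tolerance=10):
    """Hypothetical helper: same thresholding idea, with mirrored borders."""
    # mode="mirror" pads with the pixels next to the border, not the border itself.
    blurred = median_filter(data, size=2, mode="mirror")
    difference = data - blurred
    threshold = tolerance * np.std(difference)
    outliers = np.abs(difference) > threshold
    return np.where(outliers, blurred, data), np.nonzero(outliers)

# Made-up test image: flat background with hot pixels in corners and on edges.
image = np.full((200, 200), 100.0)
for y, x in [(0, 0), (0, -1), (-1, 0), (-1, -1), (0, 100), (100, 0)]:
    image[y, x] = 255.0

fixed, hot_pixels = fix_outliers_mirrored(image)
print(len(hot_pixels[0]), fixed.max())
```

This avoids the explicit loops over edges and corners, at the cost of relying on the filter's boundary handling instead of a hand-picked neighbourhood.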
