Implementing Otsu binarization from scratch in Python

Solution 1

I don't know if my implementation is right, but this is what I got:

import numpy as np

def otsu(gray):
    pixel_number = gray.shape[0] * gray.shape[1]
    mean_weight = 1.0 / pixel_number
    # 257 bin edges give 256 bins, one per uint8 intensity level
    his, bins = np.histogram(gray, np.arange(0, 257))
    final_thresh = -1
    final_value = -1
    for t in bins[1:-1]:  # candidate thresholds 1..255
        Wb = np.sum(his[:t]) * mean_weight  # background weight (intensities < t)
        Wf = np.sum(his[t:]) * mean_weight  # foreground weight (intensities >= t)

        # Caveat: these average the histogram counts, not the pixel intensities
        # (Solution 2 addresses this)
        mub = np.mean(his[:t])
        muf = np.mean(his[t:])

        value = Wb * Wf * (mub - muf) ** 2

        print("Wb", Wb, "Wf", Wf)
        print("t", t, "value", value)

        if value > final_value:
            final_thresh = t
            final_value = value
    final_img = gray.copy()
    print(final_thresh)
    # >= so that pixels equal to the threshold are assigned as well
    final_img[gray >= final_thresh] = 255
    final_img[gray < final_thresh] = 0
    return final_img

Otsu image
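
One detail worth checking when building the histogram: np.histogram takes bin edges rather than bin centers, and its last bin is closed on both sides. The sketch below uses a made-up three-pixel array (not data from the question) to show that 256 edges produce only 255 bins, so intensities 254 and 255 are counted together, while 257 edges give one bin per intensity level.

```python
import numpy as np

# Hypothetical test data: one dark pixel and one pixel at each of the two brightest levels.
gray = np.array([[0, 254, 255]], dtype=np.uint8)

# 256 edges (0..255) -> 255 bins; the last bin is [254, 255], so 254 and 255 share it.
counts_255, _ = np.histogram(gray, bins=np.arange(256))

# 257 edges (0..256) -> 256 bins, one per intensity level.
counts_256, _ = np.histogram(gray, bins=np.arange(257))

print(len(counts_255), counts_255[-1])                   # 255 2
print(len(counts_256), counts_256[-2], counts_256[-1])   # 256 1 1
```

With 256 bins, index i of the histogram corresponds directly to intensity i, which keeps slices such as his[:t] and his[t:] easy to reason about.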

Solution 2

I used the implementation that @Jose A posted in his answer, which tries to maximize the inter-class (between-class) variance. It looks like Jose forgot to multiply each intensity level by its pixel count when calculating the class means, so I corrected the calculation of the background mean mub and the foreground mean muf. I am posting this as an answer and also trying to edit the accepted answer.

import numpy as np

def otsu(gray):
    pixel_number = gray.shape[0] * gray.shape[1]
    mean_weight = 1.0 / pixel_number
    his, bins = np.histogram(gray, np.arange(0, 257))
    final_thresh = -1
    final_value = -1
    intensity_arr = np.arange(256)
    for t in bins[1:-1]:  # candidate thresholds 1..255
        pcb = np.sum(his[:t])  # background pixel count (intensities < t)
        pcf = np.sum(his[t:])  # foreground pixel count (intensities >= t)
        if pcb == 0 or pcf == 0:
            continue  # one class is empty, so its mean is undefined
        Wb = pcb * mean_weight
        Wf = pcf * mean_weight

        # intensity-weighted class means
        mub = np.sum(intensity_arr[:t] * his[:t]) / float(pcb)
        muf = np.sum(intensity_arr[t:] * his[t:]) / float(pcf)
        value = Wb * Wf * (mub - muf) ** 2

        if value > final_value:
            final_thresh = t
            final_value = value
    final_img = gray.copy()
    print(final_thresh)
    # >= so that pixels equal to the threshold are assigned as well
    final_img[gray >= final_thresh] = 255
    final_img[gray < final_thresh] = 0
    return final_img
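
The same between-class criterion, Wb * Wf * (mub - muf) ** 2, can also be evaluated for all thresholds at once with cumulative sums instead of a Python loop. The following is only a sketch under a few assumptions: the input is a uint8 array, the function name otsu_threshold and the tiny demo image are made up for illustration, and ties are resolved by taking the lowest threshold.

```python
import numpy as np

def otsu_threshold(gray):
    """Return t such that background = gray < t and foreground = gray >= t."""
    hist = np.bincount(gray.ravel(), minlength=256)   # one count per intensity 0..255
    levels = np.arange(256)

    n_b = np.cumsum(hist)                 # pixels with intensity <= k
    n_f = hist.sum() - n_b                # pixels with intensity > k
    s_b = np.cumsum(hist * levels)        # sum of intensities <= k
    s_total = s_b[-1]

    valid = (n_b > 0) & (n_f > 0)         # skip splits where one class is empty
    mu_b = s_b[valid] / n_b[valid]
    mu_f = (s_total - s_b[valid]) / n_f[valid]

    sigma_b = np.zeros(256)
    # Proportional to Wb * Wf * (mub - muf)**2; the constant 1/N**2 does not move the argmax.
    sigma_b[valid] = n_b[valid] * n_f[valid] * (mu_b - mu_f) ** 2

    k = int(np.argmax(sigma_b))           # background is intensity <= k
    return k + 1                          # equivalent to background = gray < k + 1

# Hypothetical bimodal demo image.
gray_demo = np.array([[50, 50, 200, 200]], dtype=np.uint8)
t = otsu_threshold(gray_demo)
binary = np.where(gray_demo >= t, 255, 0).astype(np.uint8)
print(t, binary.tolist())
```

For a perfectly bimodal input like this one, every threshold between the two modes scores the same, so any of them separates the classes; np.argmax simply reports the first.
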
Author: Moondra

Updated on June 22, 2022

Comments

  • Moondra
    Moondra almost 2 years

    It seems my implementation is incorrect, and I'm not sure what exactly I'm doing wrong:

    Here is the histogram of my image: [image]

    So the threshold should be around 170 or so? I'm getting a threshold of 130.

    Here is my code:

    #Otsu in Python
    
    import numpy as np
    from PIL import Image
    import matplotlib.pyplot as plt  
    
    
    def load_image(file_name):
        img = Image.open(file_name)
        img.load()
        bw = img.convert('L')
        bw_data = np.array(bw).astype('int32')
        BINS = np.array(range(0,257))
        counts, pixels =np.histogram(bw_data, BINS)
        pixels = pixels[:-1]
        plt.bar(pixels, counts, align='center')
        plt.savefig('histogram.png')
        plt.xlim(-1, 256)
        plt.show()
    
        total_counts = np.sum(counts)
        assert total_counts == bw_data.shape[0]*bw_data.shape[1]
    
        return BINS, counts, pixels, bw_data, total_counts
    
    def within_class_variance():
        ''' Here we will implement the algorithm and find the lowest Within-Class Variance:
    
            Refer to this page for more details: http://www.labbookpages.co.uk/software/imgProc/otsuThreshold.html'''
    
        for i in range(1,len(BINS), 1):         #from one to 257 = 256 iterations
           prob_1 =    np.sum(counts[:i])/total_counts
           prob_2 = np.sum(counts[i:])/total_counts
           assert (np.sum(prob_1 + prob_2)) == 1.0
    
    
    
           mean_1 = np.sum(counts[:i] * pixels[:i])/np.sum(counts[:i])
           mean_2 = np.sum(counts[i:] * pixels[i:] )/np.sum(counts[i:])
           var_1 = np.sum(((pixels[:i] - mean_1)**2 ) * counts[:i])/np.sum(counts[:i])
           var_2 = np.sum(((pixels[i:] - mean_2)**2 ) * counts[i:])/np.sum(counts[i:])
    
    
           if i == 1:
             cost = (prob_1 * var_1) + (prob_2 * var_2)
             keys = {'cost': cost, 'mean_1': mean_1, 'mean_2': mean_2, 'var_1': var_1, 'var_2': var_2, 'pixel': i-1}
             print('first_cost',cost)
    
    
           if (prob_1 * var_1) +(prob_2 * var_2) < cost:
             cost =(prob_1 * var_1) +(prob_2 * var_2)
             keys = {'cost': cost, 'mean_1': mean_1, 'mean_2': mean_2, 'var_1': var_1, 'var_2': var_2, 'pixel': i-1}  #pixels is i-1 because BINS is starting from one
    
        return keys
    
    
    if __name__ == "__main__":
    
        file_name = 'fish.jpg'
        BINS, counts, pixels, bw_data, total_counts =load_image(file_name)
        keys =within_class_variance()
        print(keys['pixel'])
        otsu_img = np.copy(bw_data).astype('uint8')
        otsu_img[otsu_img > keys['pixel']]=1
        otsu_img[otsu_img < keys['pixel']]=0
        #print(otsu_img.dtype)
        plt.imshow(otsu_img)
        plt.savefig('otsu.png')
        plt.show()
    

    The resulting Otsu image looks like this:

    [image]

    Here is the fish image (it shows a shirtless guy holding a fish, so it may not be safe for work):

    Link : https://i.stack.imgur.com/EDTem.jpg

    EDIT:

    It turns out that setting the above-threshold pixels to 255 instead of 1 makes the differences much more pronounced:

    [image]
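
A possible explanation for why the output value matters, offered as a reading of the code rather than something confirmed in the post: when the two assignments are applied in place, the second mask is computed on the already modified array, so pixels just set to 1 fall below the threshold and are zeroed again. The sketch below uses a made-up four-pixel array and a hypothetical threshold of 130 to show this, along with a single-pass np.where alternative that also covers pixels equal to the threshold.

```python
import numpy as np

# Hypothetical data and threshold, chosen only for illustration.
gray = np.array([[10, 120, 130, 200]], dtype=np.uint8)
thresh = 130

# Two in-place steps with foreground value 1: the new 1s are < thresh, so step two wipes them.
with_one = gray.copy()
with_one[with_one > thresh] = 1
with_one[with_one < thresh] = 0
print(with_one.tolist())    # [[0, 0, 130, 0]]

# Same steps with foreground value 255: 255 is not < thresh, so the foreground survives,
# but the pixel equal to the threshold is still left untouched.
with_255 = gray.copy()
with_255[with_255 > thresh] = 255
with_255[with_255 < thresh] = 0
print(with_255.tolist())    # [[0, 0, 130, 255]]

# Single pass on the original values: no ordering issue, and every pixel is assigned.
binary = np.where(gray >= thresh, 255, 0).astype(np.uint8)
print(binary.tolist())      # [[0, 0, 255, 255]]
```

Computing both masks from the untouched grayscale array, as np.where does here, avoids depending on the order of the assignments or on the chosen output values.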