Why are CIFAR-10 images not displayed properly using matplotlib?

27,117

Solution 1

The following code plots a 5x5 grid of random CIFAR-10 images. It isn't blurry, though not perfect either. Any suggestions are welcome.

%matplotlib inline
import pickle
import numpy as np
import matplotlib.pyplot as plt

# Load one training batch of the Python version of CIFAR-10
with open('data/cifar10/cifar-10-batches-py/data_batch_1', 'rb') as f:
    datadict = pickle.load(f, encoding='latin1')
X = datadict["data"]
Y = datadict['labels']
# Each row stores the R, G and B 32x32 planes one after another (channels first);
# reshape to (N, 3, 32, 32), then move the channel axis last as imshow expects
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("uint8")
Y = np.array(Y)

# Visualizing CIFAR-10
fig, axes1 = plt.subplots(5, 5, figsize=(3, 3))
for j in range(5):
    for k in range(5):
        i = np.random.choice(range(len(X)))
        axes1[j][k].set_axis_off()
        axes1[j][k].imshow(X[i])
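The snippet loads the labels into Y but never uses them, and it may pick the same image twice. Below is a sketch that titles each cell with its class name and samples indices without replacement. It assumes X is already in (N, 32, 32, 3) uint8 form as above; the function name show_labeled_grid and the CLASS_NAMES list (the standard CIFAR-10 class order as stored in batches.meta) are my own additions, and interpolation='nearest' follows the suggestion in the comments.

```python
import numpy as np
import matplotlib.pyplot as plt

# Standard CIFAR-10 class order, as stored in batches.meta['label_names']
CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


def show_labeled_grid(X, Y, rows=5, cols=5, seed=None):
    """Plot a rows x cols grid of distinct random images with class titles.

    X: uint8 array of shape (N, 32, 32, 3); Y: integer labels of shape (N,).
    """
    rng = np.random.default_rng(seed)
    # Draw rows*cols distinct indices so no image appears twice
    idx = rng.choice(len(X), size=rows * cols, replace=False)
    fig, axes = plt.subplots(rows, cols, figsize=(1.5 * cols, 1.5 * rows))
    for ax, i in zip(axes.ravel(), idx):
        # 'nearest' shows each of the 32x32 pixels as a crisp square
        ax.imshow(X[i], interpolation='nearest')
        ax.set_title(CLASS_NAMES[Y[i]], fontsize=8)
        ax.set_axis_off()
    fig.tight_layout()
    return fig, idx
```

Usage: `fig, _ = show_labeled_grid(X, Y, seed=0)` followed by `plt.show()`. The larger figsize gives each 32x32 image more screen pixels than the original 3x3-inch figure.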

Solution 2

Make sure you don't normalize your dataset when you want to display the image.

Example:

The loader...

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                        #  transforms.Normalize(
                        #      (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                     ])),
    batch_size=64, shuffle=True)

The code that shows the image...

img = next(iter(train_loader))[0][0]
plt.imshow(transforms.ToPILImage()(img))

Normalized:

[Image: sample displayed after normalization]

Without normalization:

[Image: sample displayed without normalization]
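If you want to keep transforms.Normalize in the training pipeline, you can undo it just for display. Here is a minimal NumPy sketch. It assumes the tensor has been converted with img.numpy() to a (3, 32, 32) float array and reuses the mean/std values from the commented-out Normalize call above. The helper name denormalize_for_display is hypothetical.

```python
import numpy as np

# Per-channel mean/std from the commented-out transforms.Normalize call above
MEAN = np.array([0.4914, 0.4822, 0.4465])
STD = np.array([0.247, 0.243, 0.261])


def denormalize_for_display(chw, mean=MEAN, std=STD):
    """Invert (x - mean) / std channel-wise and return an HWC image in [0, 1]."""
    chw = np.asarray(chw, dtype=np.float64)
    # Broadcast the per-channel statistics over the 32x32 spatial axes
    restored = chw * std[:, None, None] + mean[:, None, None]
    # imshow expects floats in [0, 1] with the channel axis last
    return np.clip(restored, 0.0, 1.0).transpose(1, 2, 0)
```

Usage: with Normalize enabled in the loader, `plt.imshow(denormalize_for_display(img.numpy()))`.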

Solution 3

This script reads the CIFAR-10 dataset and plots individual images using matplotlib. Pass the image index (0 to 49999) with -i/--image.

import _pickle as pickle
import argparse
import numpy as np
import os
import matplotlib.pyplot as plt

cifar10 = "./cifar-10-batches-py/"

parser = argparse.ArgumentParser("Plot training images in cifar10 dataset")
parser.add_argument("-i", "--image", type=int, default=0, 
                    help="Index of the image in cifar10. In range [0, 49999]")
args = parser.parse_args()


def unpickle(file):
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')
    return data

def cifar10_plot(data, meta, im_idx=0):
    im = data[b'data'][im_idx, :]
    
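    # Each row holds the 32x32 red, green and blue planes one after another
    im_r = im[0:1024].reshape(32, 32)
    im_g = im[1024:2048].reshape(32, 32)
    im_b = im[2048:].reshape(32, 32)

    # Stack the planes along a new last axis -> shape (32, 32, 3) for imshow
    img = np.dstack((im_r, im_g, im_b))

    print("shape: ", img.shape)
    print("label: ", data[b'labels'][im_idx])
    # Label names are stored as bytes (e.g. b'frog'); decode for readable output
    print("category:", meta[b'label_names'][data[b'labels'][im_idx]].decode())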
    
    plt.imshow(img) 
    plt.show()


def main():
    batch = (args.image // 10000) + 1
    idx = args.image - (batch-1)*10000

    data = unpickle(os.path.join(cifar10, "data_batch_" + str(batch)))
    meta = unpickle(os.path.join(cifar10, "batches.meta"))

    cifar10_plot(data, meta, im_idx=idx)

    
if __name__ == "__main__":
    main()
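The np.dstack of the three 1024-value planes in cifar10_plot produces the same (32, 32, 3) array as the reshape-and-transpose used in Solution 1. The batch/index arithmetic in main is just a divmod. The sketch below checks both on synthetic data. The function names row_to_image_dstack, row_to_image_transpose and locate are illustrative and not part of the script.

```python
import numpy as np


def row_to_image_dstack(row):
    # As in cifar10_plot: split into R, G, B planes and stack them last
    r, g, b = (row[k * 1024:(k + 1) * 1024].reshape(32, 32) for k in range(3))
    return np.dstack((r, g, b))


def row_to_image_transpose(row):
    # As in Solution 1: view as (channels, height, width), then move channels last
    return row.reshape(3, 32, 32).transpose(1, 2, 0)


def locate(image_index, per_batch=10000):
    """Map a global training index in [0, 49999] to (batch number, index in batch)."""
    batch, idx = divmod(image_index, per_batch)
    return batch + 1, idx  # batch files are named data_batch_1 ... data_batch_5
```

For example, `locate(12345)` returns `(2, 2345)`, matching what main computes for `-i 12345`.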

Solution 4

I have used the following code to show all CIFAR data as one big image. The code shows the image, but if you want to save it without it being blurry, I suggest using plt.savefig(fname, format='png', dpi=1000).

import numpy as np
import matplotlib.pyplot as plt

def reshape_and_print(cifar_data):
    # number of images in rows and columns (the image count must be a perfect square)
    rows = cols = np.sqrt(cifar_data.shape[0]).astype(np.int32)
    # Image height and width. Divide by 3 because of 3 color channels
    imh = imw = np.sqrt(cifar_data.shape[1] // 3).astype(np.int32)
    # reshape to number of images X color channels X image size
    # transpose to color channels X number of images X image size
    timg = cifar_data.reshape(rows * cols, 3, imh * imw).transpose(1, 0, 2)
    # reshape to color channels X rows X cols X image height X image width
    # swap axes to color channels X rows X image height X cols X image width
    timg = timg.reshape(3, rows, cols, imh, imw).swapaxes(2, 3)
    # reshape to color channels X combined image height X combined image width
    # transpose to combined image height X combined image width X color channels
    timg = timg.reshape(3, rows * imh, cols * imw).transpose(1, 2, 0)

    plt.imshow(timg)
    plt.show()
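To see what the reshape/swapaxes sequence in reshape_and_print does, here is a sketch that performs the same steps without plotting. It lets you check that grid cell (r, c) holds image r * cols + c. It assumes the input is an (N, 3072) array with N a perfect square, as the original does. The name tile_cifar is hypothetical.

```python
import numpy as np


def tile_cifar(cifar_data):
    """Arrange N CIFAR rows (N a perfect square) into one (rows*32, cols*32, 3) image."""
    n, d = cifar_data.shape
    rows = cols = int(round(np.sqrt(n)))
    imh = imw = int(round(np.sqrt(d // 3)))
    assert rows * cols == n, "number of images must be a perfect square"
    # (N, 3, H*W) -> (3, N, H*W)
    timg = cifar_data.reshape(n, 3, imh * imw).transpose(1, 0, 2)
    # (3, rows, cols, H, W) -> (3, rows, H, cols, W)
    timg = timg.reshape(3, rows, cols, imh, imw).swapaxes(2, 3)
    # (3, rows*H, cols*W) -> (rows*H, cols*W, 3)
    return timg.reshape(3, rows * imh, cols * imw).transpose(1, 2, 0)
```

The swapaxes step is the key: it puts each image's pixel rows next to the grid-row index, so the final reshape lays images side by side instead of interleaving them.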

I made a quick data helper class that I used for a small test project; I hope it can be useful:

import gzip
import pickle
import numpy as np
import matplotlib.pyplot as plt


class DataSet(object):

    def __init__(self, seed=42, setsize=10000):
        self.seed = seed
        # set the seed for reproducibility
        np.random.seed(seed)
        # load the data
        train_set, test_set = self.load_data()
        # self.split_data(train_set, valid_set, test_set)
        self.split_data(train_set, test_set, setsize)

    def split_data(self, data_set, test_set, split_size):
        permutation = np.random.permutation(data_set.shape[0])
        self.train = data_set[permutation[:split_size]]
        self.valid = data_set[permutation[split_size:split_size * 2]]
        self.test = test_set[:split_size]

    def reshape_for_print(self, data):
        raise NotImplementedError

    def load_data(self):
        raise NotImplementedError

    def show_all_imgs(self, data):
        raise NotImplementedError


class CIFAR(DataSet):

    def load_data(self):
        # try to load data (CIFAR-100 files; each row uses the same 3072-value layout as CIFAR-10)
        with open('./data/cifar-100-python/train', 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        train_set = data['data'].astype(np.float32) / 255.0

        with open('./data/cifar-100-python/test', 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        test_set = data['data'].astype(np.float32) / 255.0

        return train_set, test_set

    def reshape_for_print(self, data):
        gh = gw = np.sqrt(data.shape[0]).astype(np.int32)
        imh = imw = np.sqrt(data.shape[1] // 3).astype(np.int32)
        timg = data.reshape(gh * gw, 3, imh * imw).transpose(1, 0, 2)
        timg = timg.reshape(3, gh, gw, imh, imw).swapaxes(2, 3)
        timg = timg.reshape(3, gh * imh, gw * imw).transpose(1, 2, 0)
        return timg

    def show_all_imgs(self, data):
        timg = self.reshape_for_print(data)
        plt.imshow(timg)
        plt.show()


class MNIST(DataSet):

    def load_data(self):
        # try to load data
        with gzip.open('./data/mnist.pkl.gz', 'rb') as f:
            train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
        return train_set[0], test_set[0]

    def reshape_for_print(self, data):
        gh = gw = np.sqrt(data.shape[0]).astype(np.int32)
        imh = imw = np.sqrt(data.shape[1]).astype(np.int32)
        timg = data.reshape(gh, gw, imh, imw).swapaxes(1, 2)
        timg = timg.reshape(gh * imh, gw * imw)
        return timg

    def show_all_imgs(self, data):
        timg = self.reshape_for_print(data)
        plt.imshow(timg, cmap=plt.cm.gray)
        plt.show()
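On the savefig suggestion at the top of this answer: savefig renders the whole figure (axes, margins and imshow's interpolation) at the requested dpi. By contrast, plt.imsave writes the array itself, one output pixel per array element, so the saved grid is exactly as sharp as the data. The sketch below assumes timg is the (H, W, 3) array returned by reshape_for_print; the helper name save_grid_native and the default file name are placeholders.

```python
import numpy as np
import matplotlib.pyplot as plt


def save_grid_native(timg, fname='cifar_grid.png'):
    """Save an (H, W, 3) image array pixel-for-pixel, without figure resampling."""
    # imsave maps each array element to one output pixel; its dpi argument
    # only sets file metadata, so no interpolation blur is introduced
    plt.imsave(fname, timg)
    return fname
```

Float input must lie in [0, 1], which holds for the CIFAR class above because it divides by 255.0. uint8 input is also accepted.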

Solution 5

I made a function to plot the RGB image from a row in the CIFAR-10 dataset. The image will be blurry at best, since the original image is very small (32 x 32 px).

[Image: sample output]

import pickle
import numpy as np
import matplotlib.pyplot as plt

def unpickle(file):
    with open(file, 'rb') as fo:
        dict1 = pickle.load(fo, encoding='bytes')
    return dict1

# Stack the five training batches into one (50000, 3072) array
# (DataFrame.append, used previously, was removed in pandas 2.0)
batches = [unpickle('data/data_batch_' + str(i)) for i in range(1, 6)]
tr_x = np.concatenate([b[b'data'] for b in batches])
tr_y = np.concatenate([b[b'labels'] for b in batches])

test_batch = unpickle('data/test_batch')
ts_x = np.asarray(test_batch[b'data'])
ts_y = np.asarray(test_batch[b'labels'])
labels = unpickle('data/batches.meta')[b'label_names']

def plot_CIFAR(ind):
    arr = tr_x[ind]
    # Split the row into its R, G and B planes and scale to [0, 1]
    R = arr[0:1024].reshape(32, 32) / 255.0
    G = arr[1024:2048].reshape(32, 32) / 255.0
    B = arr[2048:].reshape(32, 32) / 255.0

    img = np.dstack((R, G, B))
    # Label names are bytes (e.g. b'bird'); decode them rather than stripping
    # characters, which would also remove every 'b' from names like 'bird'
    title = labels[tr_y[ind]].decode('utf-8')
    fig = plt.figure(figsize=(3, 3))
    ax = fig.add_subplot(111)
    ax.imshow(img, interpolation='bicubic')
    ax.set_title('Category = ' + title, fontsize=15)

plot_CIFAR(4)
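Much of the blur in plot_CIFAR comes from interpolation='bicubic', which smooths the 32x32 pixels as they are enlarged to fill the figure. interpolation='nearest' (suggested in the comments) keeps every pixel as a sharp square. The sketch below shows both side by side for one row of tr_x; compare_interpolation is a hypothetical helper.

```python
import numpy as np
import matplotlib.pyplot as plt


def compare_interpolation(row, modes=('nearest', 'bicubic')):
    """Show one CIFAR row (3072 uint8 values) once per interpolation mode."""
    img = row.reshape(3, 32, 32).transpose(1, 2, 0)  # channels last for imshow
    fig, axes = plt.subplots(1, len(modes), figsize=(3 * len(modes), 3), squeeze=False)
    for ax, mode in zip(axes[0], modes):
        ax.imshow(img, interpolation=mode)
        ax.set_title(mode)
        ax.set_axis_off()
    return fig
```

Usage: `compare_interpolation(tr_x[4])` followed by `plt.show()`.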
Author: Siddharth

Updated on July 09, 2022

Comments

  • Siddharth
    Siddharth almost 2 years

    From the training set I took an image (img) of shape (3, 32, 32) and displayed it with plt.imshow(img.T). The image is not clear. What changes do I have to make to img to make it more clearly visible? Thanks.

    This is the image I got: [image not included]

  • Siddharth
    Siddharth about 8 years
    It is still not working. Just to verify: img in the last line has shape (3, 32, 32).
  • bogatron
    bogatron about 8 years
    If the dimension with length 3 is the RGB axis, then you want it at the end. If your array has shape (3, R, C), then np.transpose(img, (1, 2, 0)) will have shape (R, C, 3).
  • Siddharth
    Siddharth about 8 years
    Fine, but I am still not able to see a good image. Any suggestions?
  • bogatron
    bogatron about 8 years
    If the imshow image still appears blurry as in your linked image, then it is not using "nearest" interpolation. You should be able to use the object returned by imshow to verify what interpolation it is using. Depending on your matplotlib version, you could also try "none" for the interpolation value.
  • Wolf
    Wolf about 7 years
    Please elaborate your suggestion.
  • Hizqeel
    Hizqeel about 7 years
    Please elaborate what this is trying to do.
  • saurabh kumar
    saurabh kumar about 7 years
    I added a little explanation, hope it helps
  • John
    John about 7 years
    This is good information, but astype("float") shows the image like a negative; setting it to astype("uint8") displays it normally.
  • Pietro Marchesi
    Pietro Marchesi over 6 years
    The indexing in the last line could just be X[i, :] right?
  • ceaserg
    ceaserg over 3 years
    I liked this approach :)
  • jared3412341
    jared3412341 about 3 years
    That was exactly my problem, I was looking at the images after normalizing them. Thanks!
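On the original question: for an array of shape (3, 32, 32), img.T reverses all three axes. That gives (32, 32, 3), but with height and width exchanged, so the picture comes out transposed (mirrored along its main diagonal). np.transpose(img, (1, 2, 0)), as suggested in the comments, moves only the channel axis. Here is a quick check on synthetic data, assuming a uint8 image:

```python
import numpy as np

img = np.random.default_rng(0).integers(0, 256, size=(3, 32, 32), dtype=np.uint8)

wrong = img.T                          # all axes reversed: (width, height, channels)
right = np.transpose(img, (1, 2, 0))   # only channels moved last: (height, width, channels)

print(wrong.shape, right.shape)        # (32, 32, 3) (32, 32, 3)
# img.T is the correct image with rows and columns exchanged
print(np.array_equal(wrong, right.transpose(1, 0, 2)))  # True
```

Both shapes are accepted by imshow, which is why img.T raises no error even though the displayed image is wrong.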