Python matplotlib, invalid shape for image data


The matplotlib function 'imshow' expects 3-channel images with shape (h, w, 3), as you can see in the documentation.

It seems that you passed a "batch" containing a single image (the first dimension) with three channels (the second dimension), where h and w are the third and fourth dimensions.
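As a rough illustration of the shape rule, the helper below checks a shape against the layouts the 'imshow' documentation lists for image data: (M, N), (M, N, 3) for RGB and (M, N, 4) for RGBA. The function name is hypothetical and not part of matplotlib; it only approximates the check matplotlib performs before raising "Invalid shape ... for image data".

```python
import numpy as np

def is_valid_imshow_shape(shape):
    """Hypothetical check approximating the shapes imshow documents:
    (M, N) for scalar data, (M, N, 3) for RGB, (M, N, 4) for RGBA."""
    if len(shape) == 2:
        return True
    return len(shape) == 3 and shape[2] in (3, 4)

batch = np.zeros((1, 3, 128, 128))  # batch of one image, channels first (N, C, H, W)
image = np.zeros((128, 128, 3))     # single image, channels last (H, W, C)

print(is_valid_imshow_shape(batch.shape))  # False: 4 dimensions
print(is_valid_imshow_shape(image.shape))  # True
```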

You need to reshape or view your image so that it has this layout. After converting it to the CPU, try:

image1.squeeze().permute(1,2,0)

The result will be an image of the desired shape (128, 128, 3).

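The squeeze() function removes the first dimension (strictly speaking, it removes every dimension of size 1, and here only the batch dimension has size 1). The permute() function reorders the dimensions: the first one (the channels) moves to the third position, and the other two (height and width) shift to the beginning.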
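The same reshaping can be sketched with NumPy, where np.squeeze and np.transpose play the roles of squeeze() and permute(). The random array is only a stand-in for image1 after it has been moved to the CPU; the point is to show which axis ends up where.

```python
import numpy as np

# Stand-in for image1 after .cpu(): one image, 3 channels, 128x128 (N, C, H, W)
image1 = np.random.rand(1, 3, 128, 128)

# Drop the leading batch dimension of size 1: (3, 128, 128)
chw = np.squeeze(image1, axis=0)

# Reorder the axes to H, W, C, like permute(1, 2, 0): (128, 128, 3)
hwc = np.transpose(chw, (1, 2, 0))

print(chw.shape)  # (3, 128, 128)
print(hwc.shape)  # (128, 128, 3)

# Values are only moved: channel c at pixel (y, x) ends up at hwc[y, x, c]
assert hwc[5, 7, 2] == image1[0, 2, 5, 7]
```

Passing the axis explicitly (squeeze(0) in PyTorch) removes only the batch dimension, which is slightly safer than a bare squeeze() if another dimension could also be 1.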

Also, have a look here for further discussion of the GPU and CPU issues: link
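Putting the pieces together for the original goal (three images in one row), a minimal sketch could look like the following. The helpers to_hwc and show_in_row are hypothetical names; to_hwc assumes each image has shape (1, C, H, W) and, if it is a PyTorch tensor, calls .detach().cpu().numpy() on it first. NumPy arrays stand in for image1, image2 and image3 so the sketch runs without a GPU.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np

def to_hwc(img):
    """Convert one (1, C, H, W) image to an (H, W, C) NumPy array."""
    if hasattr(img, "detach"):  # assumed torch.Tensor: leave autograd and the GPU
        img = img.detach().cpu().numpy()
    img = np.asarray(img)
    # np.squeeze raises ValueError if the first dimension is not 1
    return np.transpose(np.squeeze(img, axis=0), (1, 2, 0))

def show_in_row(images, titles):
    """Hypothetical helper: draw the images side by side in a single row."""
    n = len(images)
    # squeeze=False keeps axes 2-D, so this also works when n == 1
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    for ax, img, title in zip(axes[0], images, titles):
        ax.imshow(to_hwc(img))
        ax.set_title(title)
        ax.axis("off")
    return fig

# Stand-ins for image1, image2, image3 with float values in [0, 1]
images = [np.random.rand(1, 3, 128, 128) for _ in range(3)]
fig = show_in_row(images, ["1", "2", "3"])
# In an interactive session, call plt.show() here.
```

If the tensors were normalized for a network, imshow clips float RGB values outside [0, 1] (and logs a warning), so undoing the normalization before plotting may be needed.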

Hope that helps.

Author: Hekes Pekes

Updated on July 19, 2022

Comments

  • Hekes Pekes, almost 2 years ago

    Currently I have this code to show three images:

    imshow(image1, title='1')
    imshow(image2, title='2')
    imshow(image3, title='3')
    

    And it works fine, but the images are shown in a column. I am trying to put all three of them in a row instead.

    Here is the code I have tried:

    f = plt.figure()
    f.add_subplot(1,3,1)
    plt.imshow(image1)
    f.add_subplot(1,3,2)
    plt.imshow(image2)
    f.add_subplot(1,3,3)
    plt.imshow(image3)
    

    It throws

    TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

    If I do

    f = plt.figure()
    f.add_subplot(1,3,1)
    plt.imshow(image1.cpu())
    f.add_subplot(1,3,2)
    plt.imshow(image2.cpu())
    f.add_subplot(1,3,3)
    plt.imshow(image3.cpu())
    

    It throws

    TypeError: Invalid shape (1, 3, 128, 128) for image data

    How should I fix this, or is there an easier way to implement it?

  • Hekes Pekes, about 4 years ago
    Thank you, that helped!