Converting a PyTorch .pth model into an ONNX model


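Your .pth file contains only a state_dict, not the full model, so torch.load returns a dict. Create the model first, load the weights into it, and then export. Try changing your code to this: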

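import torch
import torch.onnx
import torchvision

# The .pth file holds only a state_dict (parameter name -> tensor), not a
# model object, so the architecture has to be instantiated first.
# resnet18 is only a placeholder: use the class the weights were trained with.
model = torchvision.models.resnet18()
state_dict = torch.load('./my_model.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # export in inference mode (dropout off, batch norm uses running stats)

# torch.autograd.Variable is deprecated; a plain tensor works as the example input.
dummy_input = torch.randn(1, 3, 256, 256)
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")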


Author: Urvish

Updated on June 04, 2022

Comments

  • Urvish
    Urvish over 1 year

    I have a pre-trained model in .pth format and I want to convert it into a TensorFlow protobuf, but I can't find a way to do that. I have seen that ONNX can convert models from PyTorch into ONNX and then from ONNX into TensorFlow. However, with that approach I got the following error in the first stage of the conversion.

    from torch.autograd import Variable
    import torch.onnx
    import torchvision
    import torch 
    
    dummy_input = Variable(torch.randn(1, 3, 256, 256))
    model = torch.load('./my_model.pth')
    torch.onnx.export(model, dummy_input, "moment-in-time.onnx")
    

    It gives the following error:

    File "t.py", line 9, in <module>
        torch.onnx.export(model, dummy_input, "moment-in-time.onnx")
      File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 75, in export
        _export(model, args, f, export_params, verbose, training)
      File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 108, in _export
        orig_state_dict_keys = model.state_dict().keys()
    AttributeError: 'dict' object has no attribute 'state_dict'
    

    What is a possible solution?

    • layog
      layog over 5 years
      Your .pth file is a state dictionary, not the complete model. You first need to create the model, then load that state dictionary into it, and only then start the conversion. Check this answer.
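      A minimal sketch of that order of operations, assuming the file holds a plain state_dict saved with torch.save(model.state_dict(), ...). The helper name load_state_dict_into is hypothetical, and nn.Linear(3, 2) only stands in for whatever architecture the weights actually belong to. The type check replaces the confusing 'dict' object has no attribute 'state_dict' error with an explicit message.

```python
import io

import torch
import torch.nn as nn


def load_state_dict_into(model: nn.Module, f) -> nn.Module:
    """Load a saved state_dict from `f` (a path or file-like object) into `model`.

    Hypothetical helper: assumes `f` was written with torch.save(model.state_dict(), f).
    """
    state = torch.load(f, map_location="cpu")
    if not isinstance(state, dict):
        raise TypeError(
            f"expected a state_dict (dict of tensors), got {type(state).__name__}"
        )
    model.load_state_dict(state)  # raises if key names or shapes do not match
    return model.eval()  # eval() returns the module itself


# Demo: an in-memory buffer stands in for './my_model.pth'.
src = nn.Linear(3, 2)
buf = io.BytesIO()
torch.save(src.state_dict(), buf)
buf.seek(0)
model = load_state_dict_into(nn.Linear(3, 2), buf)
# torch.onnx.export(model, torch.randn(1, 3), "model.onnx") would follow here.
```

      The freshly constructed model only supplies the structure; every parameter value comes from the file.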
    • Urvish
      Urvish over 5 years
      The approach shown there requires writing the model definition, but I have a pre-trained model and don't know its exact architecture, so I can't define the model as done in that answer. What should I do?
    • layog
      layog over 5 years
      Then it gets really hard to determine the architecture. You can guess it from the parameter sizes, but guessing the correct architecture is difficult even after looking at the sizes, since residual networks have the same-sized parameters as non-residual ones. Your best bet is to get the architecture definition from the source of your pretrained weights.
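      To illustrate the point about residual networks, here is a sketch with two made-up blocks (PlainBlock and ResidualBlock are hypothetical, not taken from the original model). The residual version only adds a skip connection in forward, so both have exactly the same state_dict keys and shapes even though they compute different outputs. Listing the shapes with a helper such as param_shapes (also hypothetical) is still a useful first step when matching weights to a candidate architecture.

```python
import torch
import torch.nn as nn


def param_shapes(state_dict):
    """Map each parameter/buffer name to its shape as a plain tuple."""
    return {name: tuple(t.shape) for name, t in state_dict.items()}


class PlainBlock(nn.Module):
    def __init__(self, channels: int = 8):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return self.conv2(torch.relu(self.conv1(x)))


class ResidualBlock(PlainBlock):
    def forward(self, x):
        # Same layers, plus a skip connection; the skip adds no parameters.
        return x + super().forward(x)


# Printing name -> shape is how you would inspect a loaded state_dict.
for name, shape in param_shapes(ResidualBlock().state_dict()).items():
    print(name, shape)
```

      Because the shapes match, ResidualBlock would load PlainBlock's weights without any error, which is why shape inspection alone cannot confirm the architecture.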
    • Urvish
      Urvish over 5 years
      Okay, let's see if I can get that. Thank you for the help. Also, if I have a .pth.tar file, will the process be the same or different?
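      A hedged sketch for the .pth.tar case: torch.load reads it the same way, since the extension is only a name, but such files are often full training checkpoints. The PyTorch ImageNet example, for instance, saves a dict with entries such as 'epoch', 'arch', 'state_dict' and 'optimizer'. Assuming that layout, the weights live under the 'state_dict' key, and weights saved from a model wrapped in nn.DataParallel carry a 'module.' prefix on every key. The helper extract_state_dict is hypothetical and needs no torch itself.

```python
from collections import OrderedDict
from typing import Any, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any]) -> "OrderedDict[str, Any]":
    """Return the model weights from either a bare state_dict or a training checkpoint."""
    # Training checkpoints commonly nest the weights under 'state_dict'
    # (a convention of many training scripts, not a fixed format).
    state = checkpoint.get("state_dict", checkpoint)
    prefix = "module."  # added to every key when the model was wrapped in nn.DataParallel
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state.items()
    )


# Usage with a real file (requires torch and the matching model class):
#   ckpt = torch.load("model.pth.tar", map_location="cpu")
#   model.load_state_dict(extract_state_dict(ckpt))
```

      After the weights are extracted, the conversion proceeds as for a plain .pth file; you still need the model definition.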
  • Urvish
    Urvish over 5 years
    How do I do that? Please provide some guidance.
  • user129916
    user129916 over 5 years
    When you define your class, change the header to: class ModelName(nn.Module)
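    A minimal sketch of such a class, assuming a small fully connected network. The layer names and sizes are hypothetical; to load real weights, the attribute names and shapes must match the keys in your state_dict (for example, fc1.weight and fc1.bias come from an attribute named fc1).

```python
import torch
import torch.nn as nn


class ModelName(nn.Module):
    def __init__(self):
        super().__init__()  # required so nn.Module can register the layers below
        # Hypothetical layers: replace with the real architecture.
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = ModelName()
# With the real weights:
#   model.load_state_dict(torch.load('./my_model.pth', map_location='cpu'))
model.eval()
out = model(torch.randn(1, 16))
print(out.shape)
# torch.onnx.export(model, torch.randn(1, 16), "model.onnx") would follow here.
```

    Subclassing nn.Module is what gives the class state_dict(), load_state_dict() and eval(), which torch.onnx.export relies on.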