Converting a PyTorch .pth model into an ONNX model
Try changing your code to this:
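import torch
import torch.onnx

# my_model.pth holds only a state_dict (parameter name -> tensor), so the
# model object must be created from its class definition first.
# MyModel is a placeholder for the class the weights were trained with.
model = MyModel()
state_dict = torch.load('./my_model.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # export in inference mode (affects dropout and batch norm)

# torch.autograd.Variable is deprecated; a plain tensor works as the example input.
dummy_input = torch.randn(1, 3, 256, 256)
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")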
Author: Urvish
Updated on June 04, 2022
Comments
-
Urvish over 1 year
I have a pre-trained model saved as a .pth file, and I want to convert it into a TensorFlow protobuf, but I can't find a way to do that. I have seen that ONNX can be used to convert models from PyTorch to ONNX and then from ONNX to TensorFlow, but with that approach I get the following error in the first stage of the conversion.
from torch.autograd import Variable
import torch.onnx
import torchvision
import torch

dummy_input = Variable(torch.randn(1, 3, 256, 256))
model = torch.load('./my_model.pth')
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")
It gives this error:
  File "t.py", line 9, in <module>
    torch.onnx.export(model, dummy_input, "moment-in-time.onnx")
  File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 75, in export
    _export(model, args, f, export_params, verbose, training)
  File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 108, in _export
    orig_state_dict_keys = model.state_dict().keys()
AttributeError: 'dict' object has no attribute 'state_dict'
What is a possible solution?
-
layog over 5 years
Your .pth file is a state dictionary, not the complete model. You first need to create the model, then load that state dictionary into it, and only then start the conversion. Check this answer.
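The AttributeError in the question arises because torch.load returned a plain dict rather than a model object, and torch.onnx.export then called .state_dict() on it. A quick way to see what a file actually contains is to classify the loaded object. This is a minimal, duck-typed sketch (it needs no torch import); the function name describe_loaded and the "checkpoint" convention of nesting weights under a "state_dict" key are assumptions for illustration, not part of PyTorch's API.

```python
from collections.abc import Mapping
from typing import Any


def describe_loaded(obj: Any) -> str:
    """Classify the object returned by torch.load (duck-typed, no torch import).

    - "module": a full nn.Module was pickled (it has a callable state_dict()).
    - "checkpoint": a dict wrapping a state_dict under the "state_dict" key
      (a common training convention, not a PyTorch rule).
    - "state_dict": a mapping of parameter names to tensor-like values.
    - "unknown": anything else.
    """
    if callable(getattr(obj, "state_dict", None)):
        return "module"
    if isinstance(obj, Mapping):
        if isinstance(obj.get("state_dict"), Mapping):
            return "checkpoint"
        # Tensors expose .shape; plain metadata (ints, strings) does not.
        if obj and all(hasattr(v, "shape") for v in obj.values()):
            return "state_dict"
    return "unknown"
```

Typical use would be describe_loaded(torch.load('./my_model.pth', map_location='cpu')). On recent PyTorch versions, loading a fully pickled module may also require weights_only=False, which should only be used for files you trust.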
Urvish over 5 years
The approach shown there requires writing the model definition, but I have a pre-trained model and I do not know its exact architecture, so I cannot define the model the way that answer does. What should I do?
-
layog over 5 years
Then it becomes really hard to determine the architecture. You can guess it from the parameter sizes, but guessing the correct architecture is very difficult even then, because residual networks have parameters of the same size as non-residual ones. Your best bet is to get the architecture definition from the source of your pre-trained weights.
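Listing every parameter's name and shape is the usual starting point for the kind of guessing described above, even though, as noted, shapes alone cannot reveal skip connections. The helper below is a hypothetical sketch that only relies on each value having a .shape, as torch tensors do. Dotted names such as "layer1.0.conv1.weight" often mirror the module hierarchy, 4-D weights are typically convolution kernels (out, in, kH, kW) and 2-D weights are typically linear layers (out, in).

```python
from collections.abc import Mapping
from typing import Any


def summarize_state_dict(state_dict: Mapping[str, Any]) -> list[str]:
    """Return one "name: shape (n params)" line per entry, in stored order."""
    lines = []
    for name, value in state_dict.items():
        shape = tuple(value.shape)  # torch.Size converts to a tuple of ints
        count = 1
        for dim in shape:
            count *= dim  # a 0-d (scalar) entry keeps count == 1
        lines.append(f"{name}: {shape} ({count} params)")
    return lines
```

For example: print("\n".join(summarize_state_dict(torch.load('./my_model.pth', map_location='cpu')))).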
-
Urvish over 5 years
Okay, let's see if I can get that. Thank you for the help. Also, if I have a .pth.tar file, will the process be the same or different?
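A .pth.tar file is usually still ordinary torch.save output despite the name, so torch.load reads it the same way; the difference is that such files are often training checkpoints that wrap the weights in a larger dict. A minimal sketch, assuming the weights sit under a "state_dict" key and may carry the "module." prefix that nn.DataParallel adds to parameter names; extract_state_dict is a hypothetical helper, not a PyTorch function.

```python
from collections.abc import Mapping
from typing import Any


def extract_state_dict(obj: Any, key: str = "state_dict") -> dict[str, Any]:
    """Get a plain state_dict out of whatever torch.load returned.

    Assumes checkpoints store weights under `key` (a convention, not a rule)
    and strips the "module." prefix left by nn.DataParallel.
    """
    if callable(getattr(obj, "state_dict", None)):  # a full pickled nn.Module
        obj = obj.state_dict()
    elif isinstance(obj, Mapping) and isinstance(obj.get(key), Mapping):
        obj = obj[key]  # unwrap a training checkpoint
    if not isinstance(obj, Mapping):
        raise TypeError(f"cannot find a state_dict in {type(obj).__name__}")
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in obj.items()
    }
```

With a hypothetical file name, usage would look like model.load_state_dict(extract_state_dict(torch.load('checkpoint.pth.tar', map_location='cpu'))); the model class still has to be defined first.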
-
Urvish over 5 years
How do I do that? Please provide some guidance.
-
user129916 over 5 years
When you define your class, edit the header to: class ModelName(nn.Module)
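A sketch of that pattern: subclass nn.Module, create the layers in __init__, and define forward. The layers below are invented for illustration; for load_state_dict to succeed, the real class must reproduce exactly the attribute names and shapes stored in the .pth file. The round trip saves only the state_dict to an in-memory buffer (the same kind of content a .pth file usually holds) and loads it into a fresh instance.

```python
import io

import torch
import torch.nn as nn


# Hypothetical architecture: the real class must match the one that
# produced my_model.pth (same layer names and tensor shapes).
class ModelName(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)


def round_trip(model: nn.Module) -> nn.Module:
    """Save only the state_dict, then load it into a fresh ModelName."""
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    buf.seek(0)
    fresh = ModelName()
    fresh.load_state_dict(torch.load(buf, map_location="cpu"))
    return fresh.eval()


if __name__ == "__main__":
    original = ModelName().eval()
    restored = round_trip(original)
    x = torch.randn(1, 3, 256, 256)
    print(torch.allclose(original(x), restored(x)))  # True
```

Once the weights load without missing or unexpected keys, the restored model can be passed to torch.onnx.export together with a dummy input, as in the answer above.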