Importing PyTorch models in MATLAB using importNetworkFromPyTorch
Mohammed Saifur
on 10 May 2023
Answered: MathWorks Deep Learning Toolbox Team
on 16 Apr 2024
Hello,
I am trying to import a pretrained PyTorch model into MATLAB using the importNetworkFromPyTorch function from Deep Learning Toolbox. However, I get the following error:
Error using pytorchmex
Traced model failed to load. Trace the model in the fully supported version of PyTorch as described in Deep
Learning Toolbox Converter for PyTorch Models.
Error in nnet.internal.cnn.pytorch_importer.architecture.ModuleManager/loadModule (line 28)
PropertyCell = nnet.internal.cnn.pytorch_importer.architecture.pytorchmex(this.Filename);
Error in nnet.internal.cnn.pytorch_importer.architecture.ModuleManager (line 16)
PropertyCell = loadModule(this);
Error in nnet.internal.cnn.pytorch_importer.architecture.util.translatePyTorchFile (line 9)
nnet.internal.cnn.pytorch_importer.architecture.ModuleManager(filename);
Error in nnet.internal.cnn.pytorch_importer.architecture.importNetworkFromPyTorch (line 18)
importManager = nnet.internal.cnn.pytorch_importer.architecture.util.translatePyTorchFile(filename,
customLayerPath);
Error in importNetworkFromPyTorch (line 36)
Network = nnet.internal.cnn.pytorch_importer.architecture.importNetworkFromPyTorch(modelfile, varargin{:});
Error in mnist2mat (line 1)
net = importNetworkFromPyTorch("mnist_cnn.pt");
Accepted Answer
MathWorks Deep Learning Toolbox Team
on 16 Apr 2024
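The model must be traced in PyTorch before it can be imported into MATLAB. The "Traced model failed to load" error suggests that mnist_cnn.pt was either not saved as a traced model or was traced in a PyTorch version the importer does not fully support. See the PyTorch documentation for details on tracing: https://pytorch.org/docs/stable/generated/torch.jit.trace.html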
You can also read this blog post for additional information: https://blogs.mathworks.com/deep-learning/2022/10/04/whats-new-in-interoperability-with-tensorflow-and-pytorch/
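Before re-exporting, it can help to check whether an existing .pt file is a traced (TorchScript) model at all. The following is a minimal sketch, not part of the MathWorks workflow: the helper name is_torchscript_file and the file names are made up for illustration, and it only assumes that torch.jit.load rejects a file written with torch.save. Passing this check does not guarantee that the MATLAB importer supports the PyTorch version used for tracing.

```python
import os
import tempfile

import torch
import torch.nn as nn


def is_torchscript_file(path):
    """Return True if torch.jit.load can open the file, else False.

    Hypothetical helper for illustration only.
    """
    try:
        torch.jit.load(path, map_location="cpu")
        return True
    except Exception:
        # The exact exception type may differ between PyTorch versions.
        return False


# Tiny stand-in model so the sketch runs without downloading anything.
tiny = nn.Linear(4, 2).eval()

workdir = tempfile.mkdtemp()
weights_path = os.path.join(workdir, "weights_only.pt")
traced_path = os.path.join(workdir, "traced.pt")

# A plain torch.save of the weights is NOT a traced model.
torch.save(tiny.state_dict(), weights_path)

# A traced model saved with .save() is a TorchScript archive.
torch.jit.trace(tiny, torch.rand(1, 4)).save(traced_path)

weights_ok = is_torchscript_file(weights_path)
traced_ok = is_torchscript_file(traced_path)
print(weights_ok, traced_ok)
```

If the check fails for your own file, the file most likely holds only weights or a pickled model, and it needs to be traced and saved again as described in this answer.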
As a simple example, try something similar to the following in PyTorch:
# This example loads a pretrained PyTorch model from torchvision,
# traces it with example inputs, and saves the trace as a .pt file.
import torch
from torchvision import models
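# Load the model with pretrained weights
# (torchvision 0.13 and later deprecate pretrained=True in favor of
# weights=models.MobileNet_V2_Weights.DEFAULT)
model = models.mobilenet_v2(pretrained=True)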
# Call "eval" to ensure that layers like batch norm and dropout are set to
# inference mode
model.eval()
# Move the model to the CPU
model.to("cpu")
# Create example inputs
X = torch.rand(1, 3, 224, 224)
# Trace model with the example input
traced_model = torch.jit.trace(model.forward, X)
# Save the traced model to a .pt file
traced_model.save('traced_mobilenet_v2.pt')
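The same steps apply to a custom MNIST model like the one in the question. The sketch below is only an illustration under stated assumptions: the MnistNet class is a hypothetical stand-in, because the architecture behind mnist_cnn.pt is not shown in the question, and the 1x1x28x28 example input assumes standard grayscale MNIST images. It traces the model, saves it, and reloads the file to confirm that the traced copy gives the same output.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MnistNet(nn.Module):
    """Hypothetical stand-in for the network behind mnist_cnn.pt."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)  # 28x28 -> 26x26
        self.pool = nn.MaxPool2d(2)                  # 26x26 -> 13x13
        self.fc = nn.Linear(8 * 13 * 13, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        return self.fc(torch.flatten(x, 1))


model = MnistNet()
# If mnist_cnn.pt holds a state_dict for your real architecture,
# restore it here before tracing, for example:
# model.load_state_dict(torch.load("mnist_cnn.pt", map_location="cpu"))

# Inference mode for layers such as dropout and batch norm
model.eval()

# Example input: one grayscale 28x28 image (assumed MNIST shape)
X = torch.rand(1, 1, 28, 28)

# Trace and save to a temporary folder (any writable path works)
traced_path = os.path.join(tempfile.mkdtemp(), "traced_mnist_cnn.pt")
traced_model = torch.jit.trace(model, X)
traced_model.save(traced_path)

# Reload the file and compare outputs as a sanity check
reloaded = torch.jit.load(traced_path)
with torch.no_grad():
    same = torch.allclose(model(X), reloaded(X))
print(same)
```

The saved file is the one to pass to importNetworkFromPyTorch in MATLAB in place of the original mnist_cnn.pt.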