
Tips on Importing Models from TensorFlow, PyTorch, and ONNX

This topic provides tips on how to overcome common hurdles in importing a model from TensorFlow™, PyTorch®, or ONNX™ as a MATLAB® network. You can read each section of this topic independently. For a high-level overview of the import and export functions in Deep Learning Toolbox™, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.

Import Functions of Deep Learning Toolbox

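This table lists the Deep Learning Toolbox import functions. Use these functions to import networks from TensorFlow, PyTorch, and ONNX.

Platform              Import Function
TensorFlow            importNetworkFromTensorFlow
PyTorch               importNetworkFromPyTorch
ONNX                  importNetworkFromONNX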

You must have the relevant support package to run these import functions. If the support package is not installed, each function provides a download link to the corresponding support package in the Add-On Explorer. A recommended practice is to download the support package to the default location for the version of MATLAB you are running. You can also directly download the support packages from File Exchange.

Autogenerated Custom Layers

  • The importNetworkFromTensorFlow, importNetworkFromPyTorch, and importNetworkFromONNX functions can automatically generate custom layers, or custom layers with placeholder functions, when you import TensorFlow, PyTorch, or ONNX layers that the software cannot convert into equivalent built-in MATLAB functions or layers.

  • The importNetworkFromTensorFlow, importNetworkFromPyTorch, and importNetworkFromONNX functions import an external platform layer into MATLAB by trying these steps in order:

    1. The function imports the external layer as a built-in MATLAB layer.

    2. The function imports the external layer as a built-in MATLAB function (for TensorFlow and PyTorch only).

    3. The function imports the external layer as a custom layer.

    4. The function imports the external layer as a custom layer with a placeholder function.

For more information about custom layer generation, see the Algorithms section of each function: Algorithms (TensorFlow), Algorithms (PyTorch), and Algorithms (ONNX).

Input Dimension Ordering

The dimension ordering of the input data differs between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. This table compares input dimension ordering between platforms for different input types.

Input Type            Dimension Ordering
                      MATLAB    TensorFlow    PyTorch    ONNX
Features              CN        NC            NC         NC
2-D image             HWCN      NHWC          NCHW       NCHW
3-D image             HWDCN     NHWDC         NCDHW      NCHWD
Vector sequence       CSN       NSC           SNC        NSC
2-D image sequence    HWCSN     NSWHC         NCSHW      NSCHW
3-D image sequence    HWDCSN    NSWHDC        NCSDHW     NSCHWD

Variable names in the table:

  • N — Number of observations

  • C — Number of features or channels

  • H — Height of images

  • W — Width of images

  • D — Depth of images

  • S — Sequence length
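Converting an array from another platform's ordering to the MATLAB ordering comes down to a single call to permute. The sketch below assumes the data is already a MATLAB numeric array whose dimensions follow the PyTorch or ONNX order (NCHW) or the TensorFlow order (NHWC) for 2-D images. The sizes (2 observations, 3 channels, 4-by-5 pixels) and the variable names are illustrative only, and the random values stand in for real images.

```matlab
% Illustrative sizes: 2 observations, 3 channels, 4-by-5 pixels.
N = 2; C = 3; H = 4; W = 5;

% 2-D image batch in PyTorch/ONNX order (NCHW).
xNCHW = rand(N, C, H, W);
% MATLAB expects HWCN: take source dims 3 (H), 4 (W), 2 (C), 1 (N).
xHWCN = permute(xNCHW, [3 4 2 1]);

% 2-D image batch in TensorFlow order (NHWC).
xNHWC = rand(N, H, W, C);
% MATLAB expects HWCN: take source dims 2 (H), 3 (W), 4 (C), 1 (N).
xHWCN_tf = permute(xNHWC, [2 3 4 1]);

% Elements map one to one: xNCHW(n,c,h,w) equals xHWCN(h,w,c,n).
disp(size(xHWCN))      % 4 5 3 2
disp(size(xHWCN_tf))   % 4 5 3 2

% ipermute with the same vector restores the original order.
xBack = ipermute(xHWCN, [3 4 2 1]);
assert(isequal(xBack, xNCHW))
```

The permutation vector lists, for each MATLAB dimension in turn, which source dimension supplies it. The same pattern applies to the other rows of the table; only the vector changes.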

Data Formats for Prediction with dlnetwork

The importNetworkFromTensorFlow function imports a TensorFlow network as an initialized dlnetwork object. For an example, see Import TensorFlow Network and Classify Image. If the network does not have a fixed input size, the function imports the model as an uninitialized dlnetwork object without an input layer. For an example that shows how to initialize this network, see Import and Initialize TensorFlow Network.

The importNetworkFromPyTorch function imports a PyTorch network as an uninitialized or initialized dlnetwork object. If the imported network is uninitialized, before you use the network, do one of the following:

A PyTorch network can be imported as an initialized dlnetwork object by using the PyTorchInputSizes name-value argument. For an example, see Import Network from PyTorch using PyTorchInputSizes.

The importNetworkFromONNX function imports an ONNX network as an initialized dlnetwork object. For an example, see Import ONNX Network and Classify Image.
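When an import function returns an uninitialized dlnetwork object, one possible approach is to check the Initialized property and, if needed, call initialize with an example input whose size and format match the data you plan to use. Because this sketch does not load a real model file, it uses a small hypothetical stand-in network (a fullyConnectedLayer followed by a reluLayer, created with Initialize=false) and an assumed input of 5 features. With an imported network, replace the stand-in with the output of the import function and use an example input that matches your model.

```matlab
% Hypothetical stand-in for an imported, uninitialized network. With a real
% model, use the import function output instead, for example:
%   net = importNetworkFromTensorFlow("myModelFolder");  % hypothetical folder
layers = [
    fullyConnectedLayer(10)
    reluLayer];
net = dlnetwork(layers, Initialize=false);

if ~net.Initialized
    % Example input with an assumed size: 5 features, 1 observation ("CB").
    X = dlarray(rand(5, 1, "single"), "CB");
    net = initialize(net, X);
end

% The initialized network accepts inputs with the same format and
% feature size; here, a batch of 4 observations.
Y = predict(net, dlarray(rand(5, 4, "single"), "CB"));
size(Y)   % 10 4
```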

To predict using a dlnetwork object, you must convert the input data to a dlarray object with the appropriate data format. For an example, see Import TensorFlow Network and Classify Image. Use this table to choose the right data format for each input type and layer.

Input Type            Input Layer **        Input Format *
Features              featureInputLayer     CB
2-D image             imageInputLayer       SSCB
3-D image             image3dInputLayer     SSSCB
Vector sequence       sequenceInputLayer    CBT
2-D image sequence    sequenceInputLayer    SSCBT
3-D image sequence    sequenceInputLayer    SSSCBT

* In Deep Learning Toolbox, each dimension label in a data format must be one of these:

  • S — Spatial

  • C — Channel

  • B — Batch observations

  • T — Time or sequence

  • U — Unspecified

** A dlnetwork object does not require an input layer. The network can infer the input layer type from the input data format.

For more information on data formats, see dlarray.
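As a sketch of the formats in the table, the code below wraps random data in dlarray objects; the array sizes are illustrative assumptions. It also shows that dlarray stores labeled dimensions in the fixed order S, C, B, T, U, so data labeled "CTB" (features by time steps by observations) is permuted to "CBT".

```matlab
% Features: 10 features, 8 observations -> "CB".
xFeat = dlarray(rand(10, 8, "single"), "CB");

% 2-D images: 227-by-227 RGB, 4 observations -> "SSCB".
xImg = dlarray(rand(227, 227, 3, 4, "single"), "SSCB");

% A single image also uses "SSCB"; the batch dimension has size 1.
xOne = dlarray(rand(227, 227, 3, "single"), "SSCB");

% Vector sequences: 12 features, 8 observations, 50 time steps -> "CBT".
xSeq = dlarray(rand(12, 8, 50, "single"), "CBT");

% Data stored as features-by-time-by-observations can be labeled "CTB";
% dlarray permutes it into the canonical order "CBT".
xSeq2 = dlarray(rand(12, 50, 8, "single"), "CTB");
dims(xSeq2)    % 'CBT'
size(xSeq2)    % 12 8 50

% With a network net (hypothetical here), predict on formatted input:
%   scores = predict(net, xImg);
```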

Input Data Preprocessing

To use a pretrained network for prediction or transfer learning on new images, you must preprocess your images in the same way as the images that were used to train the imported model. The most common preprocessing steps are resizing images, subtracting image average values, and converting images between RGB and BGR format.

  • To resize images, use imresize. For example, imresize(img,[227 227]) resizes img to 227-by-227 pixels and keeps all of its color channels.

  • To convert images between RGB and BGR format, use flip to reverse the order of the color channels. For example, flip(img,3).

For more information about preprocessing images for training and prediction, see Preprocess Images for Deep Learning.
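The preprocessing steps above can be sketched as follows. The input image is random data, and the 227-by-227 target size and the per-channel average values are placeholders (assumptions); use the size and statistics that were used to train your imported model, and flip the channels only if that model expects BGR input.

```matlab
% Hypothetical input: a 300-by-400 RGB image with values in [0, 255].
img = 255 * rand(300, 400, 3, "single");

% 1. Resize to the input size the model expects (227-by-227 is assumed).
img = imresize(img, [227 227]);

% 2. Subtract per-channel averages, given in RGB order. These numbers are
%    placeholders; use the values that were used to train your model.
avgRGB = reshape(single([123.68 116.78 103.94]), 1, 1, 3);
img = img - avgRGB;        % implicit expansion over rows and columns

% 3. Reverse the channel order (RGB to BGR) if the model expects BGR.
img = flip(img, 3);

size(img)   % 227 227 3
```

Subtracting the averages before flipping keeps them in RGB order; if you flip first, reorder the averages to match.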
