Why does Matlab work with any spatial input dimensions in an ONNX network, but other programming languages do not?

Hello,
I work with an ONNX network to segment nerves in biological volume images. The model uses a 2.5-dimensional approach: to detect nerves in a given image, additional information from the 4 images above and the 4 images below it in the volume is used. In other words, a stack of 9 images is used to segment the nerves in the middle image.
The images in our volumes often have more than 2000x2000 pixels. The model was trained by an external person using 512x512x9 blocks (spatial x spatial x channel), and this person told us that the input must also be this size. So we would first have to split our images into 512x512 tiles, segment them with the ONNX model, and then recombine them.
However, this procedure usually results in artifacts at the edges of the tiles in the nerve segmentation.
While trying a few things to solve this problem, I once accidentally gave the model a 1024x512x9 block as input. Surprisingly, this did not produce an error message, and an output with segmented nerves was generated. It also didn't look like the input had been split into two blocks automatically, because there was no artifact in the middle.
Then, I also tried it on the full-size images, and it worked without any problems. However, I get an error if I use blocks with a channel size other than 9 as input.
It also doesn't look like an automatic resize, as some nerve structures are way too fine and would be lost when resizing from full size to 512x512.
We have also tested this ONNX model in Mathematica and Python, but both produce an error if we deviate from the 512x512 input size.
I'm also not sure if it's this ONNX model specifically, because even in the Matlab example from LINK with the peppers, I don't get any errors if I don't resize the image. (Funny note: If you resize the pepper image to 1024x1024, the predicted label will be "velvet" instead of "bell pepper.")
So my question is: What is Matlab doing there? Or is there any possibility to look at what is going on there?

Accepted Answer

Conor Daly on 20 Sep 2024
Thanks Sebastian for this great question!
Here's how I think about it. There are two reasons why we might want to restrict the inputs of a neural network to a certain size, say, S:
  1. The network will error for anything other than inputs of size S, so we restrict to S to avoid a hard error. For example, a multi-layer perceptron configured for a specific numChannels will typically error for any other numChannels, because the first fully connected layer expects inputs of that size. The same is normally true for the numChannels of image inputs.
  2. The network was trained on inputs of size S -- so we expect the network to behave well for test inputs of the same size. We restrict to S to ensure the network performs well. This is related to keeping the test data in-distribution. It sounds like this is the situation you are describing. It could be that networks are mathematically capable of processing inputs with different sizes... but we expect that such inputs are essentially out-of-distribution wrt the network's training set, so we don't assume the model will behave well for these inputs. This is also probably what's happening when you resize the peppers image to 1024x1024 -- the resized image is at too high a resolution to be classified correctly by a network that was trained with a lower resolution.
The issue with 2) is that it's just not always the case. For example, FCNs for semantic segmentation are specifically designed to handle arbitrary spatial input sizes: https://paperswithcode.com/method/fcn.
This is why we don't enforce any restrictions on the image input size in MATLAB. If 1) is true, the network will throw an error from wherever it fails to process the input. But the network itself can't really know whether 2) is true without a bit more information.
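To illustrate, here's a toy fully convolutional dlnetwork I'm sketching (not your ONNX model): the spatial size at prediction time is free, but the channel count is baked into the first convolution's weights, so only the channel mismatch produces a hard error.

layers = [ imageInputLayer([512 512 9], Normalization="none")
           convolution2dLayer(3, 16, Padding="same")
           reluLayer
           convolution2dLayer(1, 2)
           softmaxLayer ];
net = dlnetwork(layers);
% Any spatial size flows through the convolutions:
y1 = predict(net, dlarray(rand(512, 512, 9, "single"), "SSC"));    % 512x512x2
y2 = predict(net, dlarray(rand(2048, 1024, 9, "single"), "SSC"));  % 2048x1024x2
% But a different channel count hits reason 1) and errors, because the
% first convolution's learned weights are 3x3x9x16:
y3 = predict(net, dlarray(rand(512, 512, 5, "single"), "SSC"));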
It could be that the model errors in Mathematica and Python because those frameworks apply a strict check on the image input size. We can always do this in MATLAB too by adding a custom layer that throws an error when the input is not a compatible size. For example:
>> layers = [ imageInputLayer([512 512 3])
              functionLayer(@(x) verifySpatialDimensions(x, [512 512]), Formattable=true) ];
>> net = dlnetwork(layers);
>> xgood = rand([512 512 3]);
>> xbad = rand([1000 512 3]);
>> y = predict(net, xgood);   % compatible spatial size, runs fine
>> y = predict(net, xbad);    % incompatible spatial size, errors
Error using dlnetwork/predict (line 710)
Execution failed during layer(s) "layer".

Error in
    y = predict(net, xbad);
        ^^^^^^^^^^^^^^^^^^
Caused by:
    Error using verifySpatialDimensions (line 14)
    Bad luck! Please insert an image with appropriate size.

where verifySpatialDimensions is:

function x = verifySpatialDimensions(x, expectedSpatialDimensions)
% Compare the spatial ('S') dimensions of the input against the expected size.
spatialDims = size(x, finddim(x, 'S'));
if ~isequal(spatialDims, expectedSpatialDimensions)
    error('Bad luck! Please insert an image with appropriate size.')
end
end

More Answers (1)

Alex Taylor on 21 Sep 2024
Edited: Alex Taylor on 21 Sep 2024
To add to @Conor Daly's answer, I just want to note that in the original U-Net paper (https://arxiv.org/pdf/1505.04597), a tiled training, tiled inference strategy was described. In the paper, they describe using multiple inference calls and a "valid" convolution, overlap-tile strategy to avoid introducing seams at tile boundaries. This allows U-Net to perform segmentation on arbitrarily large images.
Because FCNs like U-Net are defined entirely by convolution and elementwise operations, you can also do what you described with your segmentation network and run everything in a single inference call, since those operations are well defined on a larger spatial domain. Computationally, the only restriction for segmentation with FCNs is whether the entire inference call fits in CPU/GPU memory at once; otherwise, you need multiple inference calls, as you described.
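If memory does force you to tile, a common way to avoid the seams you saw is to read each tile with some extra surrounding context (a "halo") and keep only the central region of each prediction. Here is a rough sketch of the idea; the tile size, halo width, and segmentTile call are placeholders I'm making up, not your actual model interface:

tileSize = 512;          % core tile size used at training time
halo     = 64;           % extra context read around each tile
[H, W, ~] = size(vol);   % vol is the HxWx9 block of adjacent slices
labels = zeros(H, W, 'uint8');

for r = 1:tileSize:H
    for c = 1:tileSize:W
        % Tile bounds including the halo, clipped at the image border
        r0 = max(1, r - halo);  r1 = min(H, r + tileSize - 1 + halo);
        c0 = max(1, c - halo);  c1 = min(W, c + tileSize - 1 + halo);
        % Placeholder inference call, assumed to return a label map with
        % the same spatial size as its input
        pred = segmentTile(vol(r0:r1, c0:c1, :));

        % Keep only the core region of the prediction, so the tile
        % boundaries never appear in the assembled result
        rows = r:min(H, r + tileSize - 1);
        cols = c:min(W, c + tileSize - 1);
        labels(rows, cols) = pred(rows - r0 + 1, cols - c0 + 1);
    end
end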
Tiled training tends to work best when a local amount of information is sufficient to define the segmentation (e.g., cell segmentation in microscopy), as opposed to situations where knowledge of the global scene during training is useful (e.g., in driving scenes, the sky is generally up and the road is generally closer to the bottom of the frame).
It sounds to me like, for your segmentation network, it may be valid to rely on a tiled training, full-sized inference approach. I just wanted to note that this kind of workflow is well known and often useful.
