Deep Neural Net for Regression

Pappu Murthy on 12 Feb 2022
Answered: Manas on 5 Oct 2023
I am training a NN with 19 inputs and two outputs. I have 1000 random observations; I use the first 800 for training, the next 100 for validation, and the remaining 100 for testing. I am using three hidden layers with 15 nodes each and tanh as the activation function. I have tried to replicate all the parameters of a Python version. The Python version quickly converges to an MSE of about 0.03, whereas the MATLAB version gives an MSE roughly 10 times larger. I have set every parameter the same (mini-batch size = 32, solver = Adam, learning rate constant at 0.001, etc.). Any help is appreciated. My question is why I am not able to train to the same degree of accuracy. I can provide the scripts and data if anyone wants them.
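A quick first check is that both versions really describe the same model and the same data split. The sketch below, written in Python and assuming fully connected (dense) layers with biases, counts the learnable parameters of a 19-15-15-15-2 network and reproduces the "first 800 / next 100 / remaining 100" split described above. `count_parameters` and `split_indices` are hypothetical helper names, not part of either framework; note the indices are 0-based, so in MATLAB the same sets are rows 1:800, 801:900 and 901:1000.

```python
# Layer widths from the question: 19 inputs, three hidden layers of 15, 2 outputs.
LAYER_SIZES = [19, 15, 15, 15, 2]


def count_parameters(layer_sizes):
    """Weights plus biases of a fully connected network with these widths."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))


def split_indices(n_obs, n_train, n_val):
    """Contiguous split: first n_train rows, next n_val rows, the rest for testing."""
    idx = list(range(n_obs))  # 0-based row indices
    train = idx[:n_train]
    val = idx[n_train:n_train + n_val]
    test = idx[n_train + n_val:]
    return train, val, test


if __name__ == "__main__":
    # 19*15+15 + 2*(15*15+15) + 15*2+2 = 300 + 480 + 32
    print("learnable parameters:", count_parameters(LAYER_SIZES))  # 812
    train, val, test = split_indices(1000, 800, 100)
    print(len(train), len(val), len(test))  # 800 100 100
```

If the totals reported by the two tools (for example Keras `model.summary()`, or the per-layer learnables listed by MATLAB's `analyzeNetwork`) differ from this number, the architectures are not actually identical.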

Answers (1)

Manas on 5 Oct 2023
Hello Pappu Murthy,
I understand that you are trying to find out why the same neural network quickly converges to a low error in Python but ends up with an error about 10 times higher in MATLAB.
There could be multiple reasons for this. Since I don't have the scripts, data, and configurations used in MATLAB and Python, I can only speculate about the cause.
One possible explanation for this behaviour is the random initialization of the network weights before training: a different initialization scheme or random seed gives each run a different starting point. Another is that this 5-layer network (input layer, three hidden layers, output layer) may not train consistently on this problem, so its results vary noticeably from run to run. The optimal size and training hyperparameters for a neural network depend on factors such as the complexity of the problem and the amount of good-quality data available. An ablation study would help arrive at a concrete conclusion.
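To separate the effect of initialization from a real difference between the two implementations, one option is to fix the seed and repeat training over several seeds in each tool, then compare the spread of final MSE values rather than a single run. The Python sketch below shows seeded Glorot (Xavier) uniform initialization, where each weight is drawn from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)) and biases start at zero. This is one widely used scheme, chosen here only for illustration; whether it matches the default of either framework is an assumption to check in their documentation. `glorot_uniform` and `init_network` are hypothetical helper names.

```python
import math
import random

LAYER_SIZES = [19, 15, 15, 15, 2]


def glorot_uniform(n_in, n_out, rng):
    """n_out x n_in weights drawn from U(-a, a), a = sqrt(6 / (n_in + n_out))."""
    limit = math.sqrt(6.0 / (n_in + n_out))
    return [[rng.uniform(-limit, limit) for _ in range(n_in)] for _ in range(n_out)]


def init_network(layer_sizes, seed):
    """Return a list of (weights, biases) per layer; biases start at zero."""
    rng = random.Random(seed)  # private generator: the seed alone fixes the result
    layers = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        layers.append((glorot_uniform(n_in, n_out, rng), [0.0] * n_out))
    return layers


if __name__ == "__main__":
    a = init_network(LAYER_SIZES, seed=1)
    b = init_network(LAYER_SIZES, seed=1)
    c = init_network(LAYER_SIZES, seed=2)
    print(a == b)  # True: same seed, same starting weights
    print(a == c)  # False: a different seed gives a different starting point
```

Looping the same training code over a handful of seeds, and then over a few hidden-layer counts and widths, is a simple form of the ablation study mentioned above.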
To identify the specific cause of the discrepancy, the code and configurations used in the MATLAB and Python implementations need to be compared. You could also import the Python network into MATLAB and check whether the issue persists. Here is the MATLAB documentation on importing networks from external platforms: https://in.mathworks.com/help/deeplearning/networks-from-external-platforms.html?s_tid=CRUX_lftnav
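To make that comparison concrete, the settings from both scripts can be written down side by side and diffed, and both sets of test predictions can be scored with one metric definition, since a reported "MSE" may be computed differently by different tools (for example halved, or reported as RMSE). A minimal Python sketch: `diff_configs`, `regression_metrics` and the configuration keys are hypothetical names, and the last two entries of each dictionary are placeholders, not values read from either framework.

```python
import math


def diff_configs(a, b):
    """Return {key: (value_in_a, value_in_b)} for keys that differ or are missing."""
    keys = sorted(set(a) | set(b))
    return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)}


def regression_metrics(y_true, y_pred):
    """MSE, half-MSE and RMSE averaged over all observations and all outputs."""
    errors = [(t - p) ** 2
              for row_t, row_p in zip(y_true, y_pred)
              for t, p in zip(row_t, row_p)]
    mse = sum(errors) / len(errors)
    return {"mse": mse, "half_mse": mse / 2, "rmse": math.sqrt(mse)}


if __name__ == "__main__":
    # First five keys are the settings stated in the question; the last two are
    # HYPOTHETICAL placeholders to replace with what each script actually uses.
    shared = {"solver": "adam", "learning_rate": 0.001, "mini_batch_size": 32,
              "hidden_layers": (15, 15, 15), "activation": "tanh"}
    python_cfg = {**shared, "l2_regularization": 0.0, "shuffle": "every-epoch"}
    matlab_cfg = {**shared, "l2_regularization": 1e-4, "shuffle": "once"}
    for key, (py_val, ml_val) in diff_configs(python_cfg, matlab_cfg).items():
        print(f"{key}: python={py_val!r} matlab={ml_val!r}")

    # Score predictions exported from each tool on the same test rows.
    y_test = [[1.0, 2.0], [3.0, 4.0]]
    print(regression_metrics(y_test, [[1.1, 2.0], [3.0, 3.8]]))
```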
Hope this helps!

Release: R2021b
