Physics-Informed Neural Network - Identify coefficient of loss function

Is it possible in MATLAB to train a PINN that also finds unknown parameters of the physics loss function?
In this case, a PINN for projectile motion is presented, and the drag coefficient is defined as a trainable variable even though it is a coefficient of the loss function.

Answers (1)

Ben on 18 Sep 2023
Yes, this is possible. You can make the coefficient μ a dlarray and train it alongside the dlnetwork or other dlarray objects, as in https://uk.mathworks.com/help/deeplearning/ug/solve-partial-differential-equations-using-deep-learning.html
There is also some discussion here, where I previously gave some details on this: https://uk.mathworks.com/matlabcentral/answers/1783690-physics-informed-nn-for-parameter-identification
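Written out as a formula, the inverse problem treats the network weights θ and the coefficient μ as one set of trainable parameters and minimizes a single loss over both. This is a summary of what the modelLoss function in the example computes for d2x/dt2 = mu*x, with N sample times t_i, network prediction x̂_θ and observed data x_i; the data term is left generic because it is whatever l2loss returns:

```latex
\begin{aligned}
r_i(\theta,\mu) &= \frac{d^2\hat{x}_\theta}{dt^2}(t_i) - \mu\,\hat{x}_\theta(t_i), \\
\mathcal{L}(\theta,\mu) &= \frac{1}{N}\sum_{i=1}^{N} r_i(\theta,\mu)^2
  + \mathcal{L}_{\mathrm{data}}\bigl(\hat{x}_\theta(t_1),\dots,\hat{x}_\theta(t_N);\,x_1,\dots,x_N\bigr), \\
\frac{\partial \mathcal{L}}{\partial \mu} &= -\frac{2}{N}\sum_{i=1}^{N} r_i(\theta,\mu)\,\hat{x}_\theta(t_i).
\end{aligned}
```

Because μ enters only the residual term, it receives a nonzero gradient whenever the network output does not yet satisfy the ODE, so the same optimizer (adamupdate here) can update θ and μ in one step.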
Here's a simple example that uses an inverse PINN to find μ from solution data.
```matlab
% Inverse PINN for d2x/dt2 = mu*x
%
% For mu<0 it's known the solution is
% x(t) = a*cos(sqrt(-mu)*t) + b*sin(sqrt(-mu)*t);
%
% where a, b are determined by initial conditions.
%
% Let's fix a = 1, b = 0 and train to learn mu from the solution data x.

% Set some value for mu and specify the true solution function to generate data to train on.
% In practice these values are unknown.
muActual = -rand;
x = @(t) cos(sqrt(-muActual)*t);

% Create training data - in practice you might get this data elsewhere,
% e.g. from sensors on a physical system or from a model.
% Evaluate x(t) at batchSize uniformly spaced t in [0,maxT].
% Choose maxT such that x(t) covers one full period.
batchSize = 100;   % must be defined before it is used by linspace
maxT = 2*pi/sqrt(-muActual);
tData = linspace(0,maxT,batchSize);
t = dlarray(tData,"CB");
xactual = dlarray(x(tData),"CB");

% Specify a network and an initial guess for mu as the parameters to train
net = [
    featureInputLayer(1)
    fullyConnectedLayer(100)
    tanhLayer
    fullyConnectedLayer(100)
    tanhLayer
    fullyConnectedLayer(1)];
params.net = dlnetwork(net);
params.mu = dlarray(-0.5);

% Specify the training configuration
numEpochs = 5000;
avgG = [];
avgSqG = [];
lossFcn = dlaccelerate(@modelLoss);
clearCache(lossFcn);
lr = 1e-3;

% Train
for i = 1:numEpochs
    [loss,grad] = dlfeval(lossFcn,t,xactual,params);
    [params,avgG,avgSqG] = adamupdate(params,grad,avgG,avgSqG,i,lr);
    if mod(i,1000)==0
        fprintf("Epoch: %d, Predicted mu: %.3f, Actual mu: %.3f\n",i,extractdata(params.mu),muActual);
    end
end

function [loss,grad] = modelLoss(t,x,params)
% Implement the PINN loss by predicting x(t) via params.net and computing the derivatives with dlgradient.
xpred = forward(params.net,t);
% Here we sum xpred over the batch dimension to get a scalar.
% This is required since dlgradient only takes gradients of scalars.
% This works because xpred(i) depends only on t(i),
% i.e. the network's forward pass vectorizes over the batch dimension.
dxdt = dlgradient(sum(xpred),t,EnableHigherDerivatives=true);
% EnableHigherDerivatives=true is also needed here so that the final dlgradient
% call below can differentiate the ODE residual through d2xdt2.
d2xdt2 = dlgradient(sum(dxdt),t,EnableHigherDerivatives=true);
odeResidual = d2xdt2 - params.mu*xpred;
% Compute the mean square error of the ODE residual.
odeLoss = mean(odeResidual.^2);
% Compute the L2 loss between the predicted xpred and the true x.
dataLoss = l2loss(xpred,x);
% Sum the losses and take gradients.
% Note that by creating grad as a struct with fields matching params we can
% easily pass grad to adamupdate in the training loop.
loss = odeLoss + dataLoss;
[grad.net,grad.mu] = dlgradient(loss,params.net.Learnables,params.mu);
end
```
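As a quick sanity check on the test problem (not part of the original answer), the closed-form solution from the header comments can be verified numerically with a central second difference, using plain MATLAB and no toolboxes. The value mu = -0.3 is an arbitrary stand-in for muActual:

```matlab
% Check numerically that x(t) = cos(sqrt(-mu)*t) satisfies d2x/dt2 = mu*x.
mu = -0.3;                    % arbitrary negative value standing in for muActual
x = @(t) cos(sqrt(-mu)*t);    % closed-form solution with a = 1, b = 0
maxT = 2*pi/sqrt(-mu);        % one full period, as in the script
t = linspace(0,maxT,1001);
h = t(2) - t(1);
xt = x(t);
% Central second difference on the interior points
d2x = (xt(3:end) - 2*xt(2:end-1) + xt(1:end-2))/h^2;
% ODE residual d2x/dt2 - mu*x; only the O(h^2) discretization error should remain
residual = d2x - mu*xt(2:end-1);
fprintf("Max |ODE residual|: %.2e\n", max(abs(residual)));
```

The residual shrinks as the grid is refined, which is the behaviour the PINN's ode loss term rewards when μ matches the data.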
This script should run reasonably quickly and usually approximates μ well.
In general you will need to:
  1. Modify the PINN loss in the modelLoss function for your particular ODE or PDE.
  2. Modify the params to include both the dlnetwork and any additional coefficients.
  3. Tweak the net design and training configuration to achieve a good loss.
Ben on 21 Nov 2023
@Giulio I don't get an error when running that script.
The error stack in your comment mentions complex values, which adamupdate does not support. I'm not sure how complex values would appear in your script, though. Is this reproducible on every training run? If so, you could remove the dlaccelerate calls and place a breakpoint in modelLoss to try to identify where the values become complex.
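For illustration only (an assumption about one possible cause, not a diagnosis of Giulio's script): complex values typically appear when a quantity that goes negative during training is passed to sqrt, log, or a non-integer power. The plain-MATLAB snippet below shows the effect, plus a hypothetical helper, checkReal, that could be called on intermediate values (after extractdata for dlarray inputs) while stepping through modelLoss:

```matlab
% A coefficient that has become negative, e.g. after a bad update (hypothetical value).
c = -0.5;
y = sqrt(c);                 % complex result, approximately 0.7071i
fprintf("isreal(sqrt(%.1f)) = %d\n", c, isreal(y));

% Hypothetical helper: raise an error naming the first value found to be complex.
% For a dlarray, pass extractdata(value).
checkReal = @(name,value) assert(isreal(value), '%s is complex', name);

checkReal('sqrt(0.5)', sqrt(0.5));   % real, so this passes silently
```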
Giulio on 21 Nov 2023
Thank you for your quick reply.
Could you confirm that a complex gradient is not expected? I think it may be related to an incorrect setup of the loss function.
