Hi there,

I am trying to train a GAN. While exploring MATLAB's official GAN example, I noticed the following calls:

gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables,'RetainData',true);

gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);

After reading the dlgradient documentation, I have the following questions:

1. What is the derivative trace in the dlgradient function? Consider a two-layer dlnetwork defined by:

z=W*input+B;

output = sigmoid(z);

targetOutput = 1 * ones(size(z));

Cost = 0.5*mean((targetOutput-output).^2);

My guess is that the derivative trace consists of del(Cost)/del(z) = -(targetOutput-output).*sigmoid(z).*(1-sigmoid(z)), del(Cost)/del(input) = W'*del(Cost)/del(z), and so on. Is that correct, or does it mean something else? Could anyone explain?

2. If my guess is correct, when I train a GAN and call dlgradient for both the discriminator and the generator inside the same dlfeval, will the results be the same if I compute the discriminator's derivatives first? For example:

gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables,'RetainData',true);

gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator.Learnables);

My reasoning is that while the generator's gradients are being calculated, the W's and B's in the discriminator remain unchanged.

3. In many GAN papers, the key to successfully training a GAN is that the generator and the discriminator are trained separately. First, a set of synthetic (fake) images goes through both the generator and the discriminator, and the discriminator is trained on this cost together with the cost from real images, so that the W's and B's in the discriminator are updated. Then a second set of synthetic images goes through both networks, and the generator is trained on its cost so that ONLY the W's and B's in the generator are updated. In Keras (Python), a model's parameters can be explicitly marked as non-trainable. In MATLAB, how can I make sure that this is EXACTLY what happens?

Thanks a lot.

Gautam Pendse
on 14 Jan 2020

Hi Theron,

Re: 1. What is derivative trace in dlgradient function?

** The derivative trace is essentially a record of the sequence of operations that were executed when computing a given set of values. See this doc page for more information (middle of the page): https://www.mathworks.com/help/deeplearning/ug/include-automatic-differentiation.html
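One way to see what the trace is used for is to rebuild the example from question 1 with plain dlarray operations and compare dlgradient against hand-written backpropagation. This is a minimal sketch under stated assumptions: the cost is taken to be the mean squared error 0.5*mean((targetOutput-output).^2) over the N = numel(z) outputs, the sizes (3 outputs, 5 inputs, one sample) are arbitrary, and costGradients is a hypothetical helper name. With targetOutput = 1, the hand-derived gradients are dCost/dz = -(targetOutput-output).*sigmoid(z).*(1-sigmoid(z))/N, dCost/dW = (dCost/dz)*input', dCost/dB = dCost/dz and dCost/dinput = W'*(dCost/dz).

```matlab
% Arbitrary toy sizes (assumption): 3 outputs, 5 inputs, one sample.
rng(0);
W = dlarray(randn(3, 5));
B = dlarray(randn(3, 1));
x = dlarray(randn(5, 1));

% dlfeval traces the dlarray inputs while costGradients runs.
[gradW, gradB, gradX] = dlfeval(@costGradients, W, B, x);

% Hand-derived backpropagation, evaluated on plain numeric arrays.
Wv = extractdata(W); Bv = extractdata(B); xv = extractdata(x);
z = Wv*xv + Bv;
s = 1 ./ (1 + exp(-z));                  % sigmoid(z)
N = numel(z);
deltaZ = -(1 - s) .* s .* (1 - s) / N;   % dCost/dz with targetOutput = 1

% Largest absolute differences between dlgradient and the manual formulas.
errW = max(abs(extractdata(gradW) - deltaZ*xv'), [], 'all');
errB = max(abs(extractdata(gradB) - deltaZ));
errX = max(abs(extractdata(gradX) - Wv'*deltaZ));
fprintf('max abs errors: W %g, B %g, input %g\n', errW, errB, errX);

function [gradW, gradB, gradX] = costGradients(W, B, x)
    % Each operation on the traced inputs below is recorded in the trace.
    z = W*x + B;
    output = sigmoid(z);
    targetOutput = ones(size(z));
    cost = 0.5*mean((targetOutput - output).^2);
    % dlgradient works backwards through the recorded operations from cost.
    [gradW, gradB, gradX] = dlgradient(cost, W, B, x);
end
```

If the printed differences are at floating-point round-off level, the derivatives dlgradient computes agree with the guess in question 1, apart from the 1/N factor contributed by mean (with 0.5*sum(...) instead, the guess would be exact).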

Re: 2. If my guess is correct, when I train a GAN and perform dlgradient for the discriminator and the generator in the same dlfeval, will it be the same if I calculate derivatives of the discriminator first?

** Yes, switching the order of the two dlgradient calls should give the same gradients.
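A small experiment can illustrate this. The sketch below is only an assumption-laden stand-in for a real GAN: both networks are replaced by hypothetical parameter vectors wG and wD, the GAN-style losses are made up for illustration, and toyLosses, generatorFirst and discriminatorFirst are illustrative names. It computes the same two gradients in both orders inside dlfeval, passing 'RetainData',true to the first dlgradient call as in the question.

```matlab
% Toy stand-ins for the two networks (hypothetical, for illustration only).
rng(0);
wG = dlarray(randn(4, 1));   % "generator" parameters
wD = dlarray(randn(4, 1));   % "discriminator" parameters
xReal = randn(4, 1);         % one "real" sample (not traced)

[gG1, gD1] = dlfeval(@generatorFirst, wG, wD, xReal);
[gG2, gD2] = dlfeval(@discriminatorFirst, wG, wD, xReal);

diffG = max(abs(extractdata(gG1) - extractdata(gG2)));
diffD = max(abs(extractdata(gD1) - extractdata(gD2)));
fprintf('max difference: generator %g, discriminator %g\n', diffG, diffD);

function [lossG, lossD] = toyLosses(wG, wD, xReal)
    xFake = tanh(wG);                    % "generated" sample
    pReal = sigmoid(sum(wD .* xReal));   % toy probability that xReal is real
    pFake = sigmoid(sum(wD .* xFake));   % toy probability that xFake is real
    lossD = -log(pReal) - log(1 - pFake);
    lossG = -log(pFake);
end

function [gG, gD] = generatorFirst(wG, wD, xReal)
    [lossG, lossD] = toyLosses(wG, wD, xReal);
    gG = dlgradient(lossG, wG, 'RetainData', true);
    gD = dlgradient(lossD, wD);
end

function [gG, gD] = discriminatorFirst(wG, wD, xReal)
    [lossG, lossD] = toyLosses(wG, wD, xReal);
    gD = dlgradient(lossD, wD, 'RetainData', true);
    gG = dlgradient(lossG, wG);
end
```

The differences should be zero or at round-off level: nothing is updated between the two dlgradient calls, so each call differentiates the same recorded computation with respect to the variables it is given.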

Re: 3. As I can see in many GAN papers, the key to successfully training a GAN is that the generator and the discriminator are trained separately, that is, the first set of synthetic (fake) images goes through both the discriminator and the generator, and the discriminator is trained by its cost together with the cost caused by real images so that W's and B's in the discriminator get updated. Then the second set of synthetic images goes through both, and the generator is trained by its cost so that ONLY W's and B's in the generator are updated.

** The MATLAB GAN example uses simultaneous gradient descent for optimization. I think your description above refers to alternating gradient descent - another optimization method for GANs. Both methods are described in this paper: https://arxiv.org/abs/1705.10461.
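In the simplest one-step-each form, written with plain gradient steps of size $\eta$, discriminator parameters $\theta_D$, generator parameters $\theta_G$ and losses $L_D$, $L_G$, the two schemes differ only in which discriminator parameters the generator step sees. The MATLAB example applies Adam rather than a plain gradient step, so this is a simplification that keeps only the ordering:

```latex
\begin{aligned}
&\text{Simultaneous:} &
\theta_D^{k+1} &= \theta_D^{k} - \eta\,\nabla_{\theta_D} L_D(\theta_D^{k}, \theta_G^{k}), &
\theta_G^{k+1} &= \theta_G^{k} - \eta\,\nabla_{\theta_G} L_G(\theta_D^{k}, \theta_G^{k}), \\
&\text{Alternating:} &
\theta_D^{k+1} &= \theta_D^{k} - \eta\,\nabla_{\theta_D} L_D(\theta_D^{k}, \theta_G^{k}), &
\theta_G^{k+1} &= \theta_G^{k} - \eta\,\nabla_{\theta_G} L_G(\theta_D^{k+1}, \theta_G^{k}).
\end{aligned}
```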

To implement alternating gradient descent, the modelGradients function in the MATLAB GAN example can be split into two functions: one that computes the loss/gradient for the Discriminator only and another that computes the loss/gradient for the Generator only. Then the following gradient calculation/update sequence can be used:

- Compute loss/gradient for the Discriminator
- Update the Discriminator
- Compute loss/gradient for the Generator (using updated Discriminator)
- Update the Generator
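The sequence above could look roughly like the following sketch. It is not taken from the MATLAB example: it trains a hypothetical toy GAN on 1-D data (samples from a normal distribution with mean 2 and standard deviation 0.5) using small fully connected networks built with featureInputLayer, so it assumes a release that includes that layer. The helper names discriminatorGradients and generatorGradients, the layer sizes, the iteration count and the Adam settings are all illustrative assumptions. The losses follow the usual GAN form, with eps added inside log for numerical safety.

```matlab
% Hypothetical toy setup: all sizes and hyperparameters are illustrative.
rng(0);
latentDim = 4; batchSize = 64; numIterations = 200;
learnRate = 2e-4; gradDecay = 0.5; sqGradDecay = 0.999;

layersG = [
    featureInputLayer(latentDim, 'Normalization', 'none')
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(1)];
layersD = [
    featureInputLayer(1, 'Normalization', 'none')
    fullyConnectedLayer(16)
    leakyReluLayer(0.2)
    fullyConnectedLayer(1)];
netG = dlnetwork(layerGraph(layersG));
netD = dlnetwork(layerGraph(layersD));

% Separate Adam state for each network.
avgGradD = []; avgSqGradD = []; avgGradG = []; avgSqGradG = [];

for iteration = 1:numIterations
    % Toy "real" data: samples from N(2, 0.5^2).
    XReal = dlarray(single(2 + 0.5*randn(1, batchSize)), 'CB');

    % 1) Discriminator loss/gradient, then update netD only.
    Z = dlarray(randn(latentDim, batchSize, 'single'), 'CB');
    [gradD, lossD] = dlfeval(@discriminatorGradients, netD, netG, XReal, Z);
    [netD, avgGradD, avgSqGradD] = adamupdate(netD, gradD, ...
        avgGradD, avgSqGradD, iteration, learnRate, gradDecay, sqGradDecay);

    % 2) Generator loss/gradient with fresh noise and the UPDATED netD,
    %    then update netG only (netD is not passed to adamupdate here).
    Z = dlarray(randn(latentDim, batchSize, 'single'), 'CB');
    [gradG, lossG] = dlfeval(@generatorGradients, netG, netD, Z);
    [netG, avgGradG, avgSqGradG] = adamupdate(netG, gradG, ...
        avgGradG, avgSqGradG, iteration, learnRate, gradDecay, sqGradDecay);
end
fprintf('final lossD = %.4f, lossG = %.4f\n', extractdata(lossD), extractdata(lossG));

function [gradD, lossD] = discriminatorGradients(netD, netG, XReal, Z)
    XFake = forward(netG, Z);
    probReal = sigmoid(forward(netD, XReal));
    probFake = sigmoid(forward(netD, XFake));
    lossD = -mean(log(probReal + eps)) - mean(log(1 - probFake + eps));
    % Gradients are requested only for the discriminator's learnables.
    gradD = dlgradient(lossD, netD.Learnables);
end

function [gradG, lossG] = generatorGradients(netG, netD, Z)
    XFake = forward(netG, Z);
    probFake = sigmoid(forward(netD, XFake));
    lossG = -mean(log(probFake + eps));
    % Gradients flow back through netD, but only netG's learnables are requested.
    gradG = dlgradient(lossG, netG.Learnables);
end
```

Nothing has to be marked as non-trainable here, which relates to the Keras comparison in question 3: the generator step requests gradients only for netG.Learnables and passes only netG to adamupdate, so the discriminator's weights stay exactly as the discriminator step left them, even though the generator's gradients are backpropagated through netD.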

Hope that helps,

Gautam
