Compare LDA Solvers

This example shows how to compare latent Dirichlet allocation (LDA) solvers by comparing the goodness of fit and the time taken to fit the model.

To reproduce the results of this example, set rng to 'default'.
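As a minimal sketch of what resetting the generator does, the following resets it twice and confirms that the same random values are drawn each time. The variable names are illustrative only.

```matlab
% Reset the random number generator to its default settings.
rng('default')
firstDraw = rand(1,3);

% Resetting again reproduces the same sequence of random numbers.
rng('default')
secondDraw = rand(1,3);

isReproducible = isequal(firstDraw,secondDraw)
```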


Extract and Preprocess Text Data

Load the example data. The file weatherReports.csv contains weather reports, including a text description and categorical labels for each event. Extract the text data from the field event_narrative.

filename = "weatherReports.csv";
data = readtable(filename,'TextType','string');
textData = data.event_narrative;

Set aside 10% of the documents at random for validation.

numDocuments = numel(textData);
cvp = cvpartition(numDocuments,'HoldOut',0.1);
textDataTrain = textData(training(cvp));
textDataValidation = textData(test(cvp));
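As a quick check of how the hold-out partition behaves, the following sketch partitions a made-up count of 100 observations and inspects the two index vectors. The count of 100 is an assumption for illustration and is not the size of the weather reports data.

```matlab
% Illustrative document count (an assumption, not taken from weatherReports.csv).
numToyDocuments = 100;
toyPartition = cvpartition(numToyDocuments,'HoldOut',0.1);

% training and test return logical index vectors over the observations.
idxTrain = training(toyPartition);
idxTest = test(toyPartition);

% Every observation belongs to exactly one of the two sets.
numTrain = nnz(idxTrain)
numTest = nnz(idxTest)
```

Roughly 10% of the observations land in the test set, and the two sets do not overlap.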

Tokenize and preprocess the text data using the function preprocessText, which is listed at the end of this example.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

Create a bag-of-words model from the training documents. Remove the words that do not appear more than two times in total. Remove any documents containing no words.

bag = bagOfWords(documentsTrain);
bag = removeInfrequentWords(bag,2);
bag = removeEmptyDocuments(bag);
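To see the effect of the two removal steps on a small scale, the following sketch applies them to four made-up documents. The documents and the toy variable names are assumptions for illustration only.

```matlab
% Toy documents (made up for illustration).
toyDocuments = tokenizedDocument([
    "storm wind hail"
    "storm wind"
    "storm"
    "hail"]);
toyBag = bagOfWords(toyDocuments);

% "storm" appears 3 times; "wind" and "hail" appear 2 times each.
% Words that appear 2 times or fewer are removed, leaving only "storm".
toyBag = removeInfrequentWords(toyBag,2);

% The fourth document now contains no words, so it is removed.
toyBag = removeEmptyDocuments(toyBag);

toyVocabulary = toyBag.Vocabulary
numRemainingDocuments = toyBag.NumDocuments
```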

Fit and Compare Models

For each of the LDA solvers, fit an LDA model with 60 topics. To distinguish the solvers when plotting the results on the same axes, specify different line properties for each solver.

numTopics = 60;
solvers = ["cgs" "avb" "cvb0" "savb"];
lineSpecs = ["+-" "*-" "x-" "o-"];

For the validation data, create a bag-of-words model from the validation documents.

validationData = bagOfWords(documentsValidation);

For each of the LDA solvers, fit the model, set the initial topic concentration to 1, and specify not to fit the topic concentration parameter. Using the data in the FitInfo property of the fitted LDA models, plot the validation perplexity against the time elapsed. Plot the time elapsed on a logarithmic scale. This code can take up to an hour to run.

The code for removing NaNs is necessary because of a quirk of the stochastic solver "savb". For this solver, the function evaluates the validation perplexity only after each pass of the data, not after each iteration (mini-batch), and reports NaNs in the FitInfo property for the iterations it does not evaluate. To plot the validation perplexity, remove the NaNs from the reported values.
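The NaN filtering can be illustrated on its own with made-up history values. In this sketch, the numbers are invented and only the last iteration carries a perplexity value.

```matlab
% Made-up history values (not real solver output): the perplexity is
% reported only for the final iteration, and NaN elsewhere.
toyTime = [0.4; 0.9; 1.3; 1.8];
toyPerplexity = [NaN; NaN; NaN; 350];

% Drop the rows where no perplexity was evaluated, keeping the two
% vectors aligned so that they can be plotted against each other.
idxMissing = isnan(toyPerplexity);
toyTime(idxMissing) = [];
toyPerplexity(idxMissing) = [];
```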

for i = 1:numel(solvers)
    solver = solvers(i);
    lineSpec = lineSpecs(i);
    mdl = fitlda(bag,numTopics, ...
        'Solver',solver, ...
        'InitialTopicConcentration',1, ...
        'FitTopicConcentration',false, ...
        'ValidationData',validationData);
    history = mdl.FitInfo.History;
    timeElapsed = history.TimeSinceStart;
    validationPerplexity = history.ValidationPerplexity;
    % Remove NaNs.
    idx = isnan(validationPerplexity);
    timeElapsed(idx) = [];
    validationPerplexity(idx) = [];
    % Plot the validation perplexity against time on a logarithmic time axis.
    semilogx(timeElapsed,validationPerplexity,lineSpec)
    hold on
end
hold off
xlabel("Time Elapsed (s)")
ylabel("Validation Perplexity")
legend(solvers)

For the stochastic solver "savb", the function, by default, passes through the training data once. To make more passes through the data, set 'DataPassLimit' to a larger value (the default value is 1). For the batch solvers ("cgs", "avb", and "cvb0"), to reduce the number of iterations used to fit the models, set the 'IterationLimit' option to a lower value (the default value is 100).
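A minimal sketch of setting these two options follows. It uses a tiny made-up corpus and two topics so that it runs quickly. The documents, the limit values 5 and 50, and the variable names are assumptions for illustration, not recommended settings.

```matlab
% Toy corpus (made up for illustration).
toyDocuments = tokenizedDocument([
    "heavy rain flood road"
    "strong wind damage tree"
    "rain flood river road"
    "wind tree power line"]);
toyBag = bagOfWords(toyDocuments);

% Stochastic solver: make 5 passes through the data instead of 1.
mdlSavb = fitlda(toyBag,2, ...
    'Solver',"savb", ...
    'DataPassLimit',5, ...
    'Verbose',0);

% Batch solver: stop after at most 50 iterations instead of 100.
mdlCgs = fitlda(toyBag,2, ...
    'Solver',"cgs", ...
    'IterationLimit',50, ...
    'Verbose',0);
```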

A lower validation perplexity suggests a better fit. Usually, the solvers "savb" and "cgs" converge quickly to a good fit. The solver "cvb0" might converge to a better fit, but it can take much longer to converge.

For the FitInfo property, the fitlda function estimates the validation perplexity from the document probabilities at the maximum likelihood estimates of the per-document topic probabilities. This is usually quicker to compute, but can be less accurate than other methods. Alternatively, calculate the validation perplexity using the logp function. This function calculates more accurate values but can take longer to run. For an example showing how to compute the perplexity using logp, see Calculate Document Log-Probabilities from Word Count Matrix.
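As a rough sketch of the logp alternative, the following fits a small model to made-up documents and requests the perplexity as the second output of logp. The corpus, the held-out documents, and the variable names are assumptions for illustration. With the models from this example, pass the fitted model and the validation documents instead.

```matlab
% Toy training corpus (made up for illustration).
toyDocumentsTrain = tokenizedDocument([
    "heavy rain flood road"
    "strong wind damage tree"
    "rain flood river road"
    "wind tree power line"]);
toyBag = bagOfWords(toyDocumentsTrain);
toyMdl = fitlda(toyBag,2,'Verbose',0);

% Toy held-out documents that reuse words from the training vocabulary.
toyDocumentsValidation = tokenizedDocument([
    "rain flood road"
    "wind damage tree"]);

% The first output holds one log-probability per document.
% The second output is the perplexity computed from them.
[toyLogProbabilities,toyPerplexity] = logp(toyMdl,toyDocumentsValidation);
```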

Preprocessing Function

The function preprocessText performs the following steps:

  1. Tokenize the text using tokenizedDocument.

  2. Lemmatize the words using normalizeWords.

  3. Erase punctuation using erasePunctuation.

  4. Remove a list of stop words (such as "and", "of", and "the") using removeStopWords.

  5. Remove words with 2 or fewer characters using removeShortWords.

  6. Remove words with 15 or more characters using removeLongWords.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Lemmatize the words.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','lemma');

% Erase punctuation.
documents = erasePunctuation(documents);

% Remove a list of stop words.
documents = removeStopWords(documents);

% Remove words with 2 or fewer characters, and words with 15 or more
% characters.
documents = removeShortWords(documents,2);
documents = removeLongWords(documents,15);

end
