
Visualize Correlations Between LDA Topics and Document Labels

This example shows how to fit a Latent Dirichlet Allocation (LDA) topic model and visualize correlations between the LDA topics and document labels.

A Latent Dirichlet Allocation (LDA) model is a topic model that discovers underlying topics in a collection of documents and infers the word probabilities in those topics. Fitting an LDA model does not require labeled data. However, if your documents have labels, you can visualize correlations between the fitted LDA topics and the document labels using a parallel coordinates plot.

This example fits an LDA model to the Factory Reports data set, a collection of factory reports that detail different failure events, and identifies correlations between the LDA topics and the report categories.

Load and Extract Text Data

Load the example data. The file factoryReports.csv contains factory reports, including a text description and categorical labels for each event.

data = readtable("factoryReports.csv",TextType="string");
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

Extract the text data from the field Description.

textData = data.Description;
textData(1:10)
ans = 10×1 string
    "Items are occasionally getting stuck in the scanner spools."
    "Loud rattling and banging sounds are coming from assembler pistons."
    "There are cuts to the power when starting the plant."
    "Fried capacitors in the assembler."
    "Mixer tripped the fuses."
    "Burst pipe in the constructing agent is spraying coolant."
    "A fuse is blown in the mixer."
    "Things continue to tumble off of the belt."
    "Falling items from the conveyor belt."
    "The scanner reel is split, it will soon begin to curve."

Extract the labels from the field Category.

labels = data.Category;

Prepare Text Data for Analysis

Create a function that tokenizes and preprocesses the text data for analysis. The function preprocessText, listed in the Preprocessing Function section of the example, performs the following steps in order:

  1. Tokenize the text using tokenizedDocument.

  2. Add part-of-speech details using addPartOfSpeechDetails, then lemmatize the words using normalizeWords.

  3. Erase punctuation using erasePunctuation.

  4. Remove a list of stop words (such as "and", "of", and "the") using removeStopWords.

  5. Remove words with 2 or fewer characters using removeShortWords.

  6. Remove words with 15 or more characters using removeLongWords.
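The length thresholds in steps 5 and 6 are easy to misread. The following base-MATLAB sketch applies the same rule to a plain string array using strlength instead of the Text Analytics Toolbox functions. It assumes, as the steps state, that words with 2 or fewer characters and words with 15 or more characters are removed. The word list is hypothetical and chosen to sit on both boundaries.

```matlab
% Hypothetical word list, for illustration only.
words = ["a" "on" "fan" "overtightening" "electromagnetic"];

% Mirror steps 5 and 6: keep words with more than 2 characters
% and fewer than 15 characters.
keep = strlength(words) > 2 & strlength(words) < 15;
keptWords = words(keep)
```

Here "on" (2 characters) and "electromagnetic" (15 characters) are dropped, while "fan" (3 characters) and "overtightening" (14 characters) are kept.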

Prepare the text data for analysis using the preprocessText function.

documents = preprocessText(textData);
documents(1:5)
ans = 5×1 tokenizedDocument array with properties:
    6 tokens: item occasionally get stuck scanner spool
    7 tokens: loud rattling bang sound come assembler piston
    4 tokens: cut power start plant
    3 tokens: fry capacitor assembler
    3 tokens: mixer trip fuse

Create a bag-of-words model from the tokenized documents.

bag = bagOfWords(documents)
bag = bagOfWords with properties:
          Counts: [480×338 double]
      Vocabulary: [1×338 string]
        NumWords: 338
    NumDocuments: 480

Remove words that appear two or fewer times in total from the bag-of-words model. Then remove any documents that no longer contain any words.

bag = removeInfrequentWords(bag,2);
bag = removeEmptyDocuments(bag)
bag = bagOfWords with properties:
          Counts: [480×158 double]
      Vocabulary: [1×158 string]
        NumWords: 158
    NumDocuments: 480
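The two filtering steps interact: removing infrequent words can leave some documents empty, and removing those documents shifts the rows that later line up with the labels. The following base-MATLAB sketch reproduces the logic on a small hypothetical count matrix, assuming that removeInfrequentWords(bag,2) keeps only words whose total count exceeds 2. The names toyCounts and toyLabels are illustrative and do not overwrite the example's variables.

```matlab
% Hypothetical word counts: 4 documents (rows) by 4 words (columns).
toyCounts = [1 0 1 0
             0 1 2 0
             0 0 1 1
             0 0 0 1];
toyLabels = ["Leak"; "Mechanical Failure"; "Leak"; "Electronic Failure"];

% Keep only words whose total count over all documents exceeds 2.
wordTotals = sum(toyCounts,1);
toyCounts = toyCounts(:,wordTotals > 2);

% Remove documents left with no words, and drop their labels too
% so that each remaining row still matches its label.
isEmptyDoc = sum(toyCounts,2) == 0;
toyCounts(isEmptyDoc,:) = [];
toyLabels(isEmptyDoc) = [];
```

With the real model, removeEmptyDocuments can also return the indices of the removed documents, as in [bag,idx] = removeEmptyDocuments(bag), which you can use to drop the matching labels with labels(idx) = []. In this example, NumDocuments stays at 480, so no documents are removed and the labels already align with the model.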

Fit LDA Model

Fit an LDA model with 7 topics. For an example showing how to choose the number of topics, see Choose Number of Topics for LDA Model. To suppress verbose output, set the Verbose option to 0. For reproducibility, reset the random number generator using rng("default").

rng("default")
numTopics = 7;
mdl = fitlda(bag,numTopics,Verbose=0);

If you have a large data set, then the stochastic approximate variational Bayes solver is usually better suited because it can fit a good model in fewer passes through the data. The default solver for fitlda, collapsed Gibbs sampling, can be more accurate at the cost of a longer run time. To use stochastic approximate variational Bayes, set the Solver option to "savb". For an example showing how to compare LDA solvers, see Compare LDA Solvers.
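As a minimal sketch of switching solvers, the following code fits a two-topic model with the "savb" solver on a tiny hypothetical corpus. It requires Text Analytics Toolbox, and the corpus and the names toyBag and mdlSAVB are made up for illustration; with the factory reports data, you would pass bag and numTopics instead.

```matlab
% Hypothetical toy corpus, for illustration only.
str = [ ...
    "fuse blown mixer"
    "mixer trip fuse"
    "coolant leak pipe"
    "pipe burst coolant"];
toyBag = bagOfWords(tokenizedDocument(str));

% Fit with stochastic approximate variational Bayes instead of
% the default collapsed Gibbs sampling solver.
rng("default")
mdlSAVB = fitlda(toyBag,2,Solver="savb",Verbose=0);
```

On a corpus this small, the choice of solver makes little practical difference; the benefit of "savb" shows up on large data sets, where fewer passes through the data matter.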

Visualize the topics using word clouds.

figure
t = tiledlayout("flow");
title(t,"LDA Topics")

for i = 1:numTopics
    nexttile
    wordcloud(mdl,i);
    title("Topic " + i)
end

Visualize Correlations Between Topics and Document Labels

Visualize the correlations between the LDA topics and the document labels by plotting the mean topic probabilities against each document label.

Extract the document topic mixtures from the DocumentTopicProbabilities property of the LDA model.

topicMixtures = mdl.DocumentTopicProbabilities;

For each label, calculate the mean topic probabilities of the documents with that label.

[groups,groupNames] = findgroups(labels);
numGroups = numel(groupNames);

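% Preallocate one row per label and one column per topic.
meanTopicProbabilities = zeros(numGroups,size(topicMixtures,2));

for i = 1:numGroups
    idx = groups == i;
    % Average over documents (dimension 1), even if a group has one document.
    meanTopicProbabilities(i,:) = mean(topicMixtures(idx,:),1);
end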
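The loop computes a column-wise mean over the rows that belong to each group. Specifying dimension 1 matters: without it, a group containing a single document would yield a 1-by-K row, and mean would average across topics instead of across documents. As a sketch, splitapply expresses the same computation in one call; the mixtures and group numbers below are hypothetical.

```matlab
% Hypothetical topic mixtures: 4 documents by 3 topics (rows sum to 1).
toyMixtures = [0.7 0.2 0.1
               0.5 0.3 0.2
               0.1 0.1 0.8
               0.2 0.2 0.6];
toyGroups = [1; 1; 2; 2];

% Mean over the documents of each group, one row per group.
toyMeans = splitapply(@(x) mean(x,1), toyMixtures, toyGroups)
```

Applied to this example, splitapply(@(x) mean(x,1),topicMixtures,groups) should give the same matrix as meanTopicProbabilities.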

For each topic, find the top three words.

numTopics = mdl.NumTopics;
topWords = strings(1,numTopics);

for i = 1:numTopics
    top = topkwords(mdl,3,i);
    topWords(i) = join(top.Word,", ");
end
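Each tick label is a single string built by joining the Word column of the topkwords output. The sketch below assumes that output is a table with a string Word variable and a Score variable; the table contents and the names toyTop and toyLabel are hypothetical.

```matlab
% Hypothetical table in the shape assumed for the topkwords output.
toyTop = table(["fuse";"mixer";"trip"],[0.31;0.22;0.09], ...
    VariableNames={'Word','Score'});

% join combines the 3-by-1 string column into one comma-separated label.
toyLabel = join(toyTop.Word,", ")
```

Because Word is a column, join combines along the first dimension and returns a scalar string, which is why each topWords(i) holds one label.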

To plot the per-category mean topic probabilities, first create a figure and, for readability, double its width using the Position property.

f = figure;
f.Position(3) = 2*f.Position(3);

Plot the per-category mean topic probabilities using the parallelplot function. Do not normalize the input data and specify the categories as the groups. Set the coordinate tick labels to the top three words of each topic.

p = parallelplot(meanTopicProbabilities, ...
    GroupData=groupNames, ...
    DataNormalization="none");

p.CoordinateTickLabels = topWords;

xlabel("LDA Topic")
ylabel("Mean Topic Probability")
title("LDA Topic and Document Label Correlations")

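The parallel plot highlights the correlations between the LDA topics and the document labels. A high peak for a category indicates that documents with that label have, on average, a high probability for the corresponding topic, which suggests a strong correlation between that topic and label.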
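To read the peaks numerically rather than visually, you can take, for each label, the topic with the highest mean probability. The following sketch uses hypothetical per-label means, labels, and tick labels; with this example's data, you would use meanTopicProbabilities, groupNames, and topWords instead.

```matlab
% Hypothetical per-label means (rows: labels, columns: topics).
toyMeans = [0.60 0.25 0.15
            0.15 0.15 0.70];
toyGroupNames = ["Electronic Failure"; "Leak"];
toyTopWords = ["fuse, mixer, trip" "belt, item, fall" "pipe, leak, coolant"];

% For each label, find the topic with the highest mean probability.
[peakProb,peakTopic] = max(toyMeans,[],2);
dominantTopics = table(toyGroupNames,toyTopWords(peakTopic)',peakProb, ...
    VariableNames={'Label','TopTopic','MeanProbability'})
```

A single maximum hides labels that spread their probability over several topics, so the parallel plot remains the better view of the overall pattern.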

Preprocessing Function

The function preprocessText performs the following steps in order:

  1. Tokenize the text using tokenizedDocument.

  2. Add part-of-speech details using addPartOfSpeechDetails, then lemmatize the words using normalizeWords.

  3. Erase punctuation using erasePunctuation.

  4. Remove a list of stop words (such as "and", "of", and "the") using removeStopWords.

  5. Remove words with 2 or fewer characters using removeShortWords.

  6. Remove words with 15 or more characters using removeLongWords.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Add part-of-speech details, then lemmatize the words.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,Style="lemma");

% Erase punctuation.
documents = erasePunctuation(documents);

% Remove a list of stop words.
documents = removeStopWords(documents);

% Remove words with 2 or fewer characters, and words with 15 or more
% characters.
documents = removeShortWords(documents,2);
documents = removeLongWords(documents,15);

end


Related Topics