
Classify Out-of-Memory Text Data Using Deep Learning

This example shows how to classify out-of-memory text data with a deep learning network using a transformed datastore.

A transformed datastore transforms or processes data read from an underlying datastore. You can use a transformed datastore as a source of training, validation, test, and prediction data sets for deep learning applications. Use transformed datastores to read out-of-memory data or to perform specific preprocessing operations when reading batches of data.

When training the network, the software creates mini-batches of sequences of the same length by padding, truncating, or splitting the input data. The trainingOptions function provides options to pad and truncate input sequences; however, these options are not well suited for sequences of word vectors. Furthermore, trainingOptions does not support padding data in a custom datastore, so you must pad and truncate the sequences manually. Left-padding and truncating the sequences of word vectors might improve training.

The Classify Text Data Using Deep Learning (Text Analytics Toolbox) example manually truncates and pads all the documents to the same length. This process adds a lot of padding to very short documents and discards a lot of data from very long documents.

Alternatively, to prevent adding too much padding or discarding too much data, create a transformed datastore that inputs mini-batches into the network. The datastore created in this example converts mini-batches of documents to sequences of word vectors and left-pads each mini-batch to the length of the longest document in that mini-batch.
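To see how much padding the per-mini-batch approach can save, consider a small toy case. The following sketch uses only base MATLAB functions to left-pad a hypothetical mini-batch of three sequences to the length of its longest sequence, storing the result as a numFeatures-by-sequence-length-by-batch (CTB) array. It then compares the number of padded time steps with padding every sequence to one global length. The feature dimension, the sequence lengths, and the global length of 40 are made-up values. In the example itself you do not write this loop: with its default options, doc2sequence (called inside transformText) left-pads each mini-batch to its longest document.

```matlab
% Toy mini-batch: three sequences of word vectors with numFeatures = 2
% and lengths 3, 5, and 2 (hypothetical values).
toyNumFeatures = 2;
toySequences = {ones(toyNumFeatures,3), 2*ones(toyNumFeatures,5), 3*ones(toyNumFeatures,2)};

% Left-pad each sequence with zeros to the longest length in the mini-batch,
% so the real words always occupy the final time steps.
toyLengths = cellfun(@(s) size(s,2), toySequences);
maxLen = max(toyLengths);
numObs = numel(toySequences);
X = zeros(toyNumFeatures,maxLen,numObs);   % channel-by-time-by-batch ("CTB")
for k = 1:numObs
    X(:,maxLen-toyLengths(k)+1:end,k) = toySequences{k};
end

% Padded time steps: per-mini-batch padding versus padding every sequence
% to a hypothetical global length of 40.
paddingBatch = sum(maxLen - toyLengths)    % 5
paddingGlobal = sum(40 - toyLengths)       % 110
```

In this toy case, padding to the longest sequence in the mini-batch adds 5 padded time steps, while padding to the global length adds 110.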

Load Pretrained Word Embedding

The datastore requires a word embedding to convert documents to sequences of vectors. Load a pretrained word embedding using fastTextWordEmbedding. This function requires the Text Analytics Toolbox™ Model for fastText English 16 Billion Token Word Embedding support package. If this support package is not installed, then the function provides a download link.

emb = fastTextWordEmbedding;
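A word embedding maps each word to a fixed-length vector, so a document becomes a matrix with one column per word. As a toy illustration only, the following sketch uses a containers.Map object holding three made-up 3-element vectors in place of emb. Skipping words that are not in the vocabulary is a choice made for this sketch; it is not a description of how doc2sequence handles such words.

```matlab
% Toy 3-dimensional "embedding" (hypothetical vectors, not fastText values).
toyVocab = containers.Map({'mixer','tripped','fuses'}, ...
    {[1;0;0], [0;1;0], [0;0;1]});
toyWords = {'mixer','tripped','the','fuses'};

% This sketch skips words that are not in the vocabulary, then stacks the
% vectors of the remaining words as columns: numFeatures-by-lenSequence.
inVocab = isKey(toyVocab,toyWords);
knownWords = toyWords(inVocab);
toySeq = zeros(3,numel(knownWords));
for t = 1:numel(knownWords)
    toySeq(:,t) = toyVocab(knownWords{t});
end
toySeq   % 3-by-3, one column per known word
```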

Load Data

Create a tabular text datastore from the data in factoryReports.csv. Specify to read the data from the "Description" and "Category" columns only.

filenameTrain = "factoryReports.csv";
textName = "Description";
labelName = "Category";
ttdsTrain = tabularTextDatastore(filenameTrain,'SelectedVariableNames',[textName labelName]);

View a preview of the datastore.

preview(ttdsTrain)

ans=8×2 table
                                  Description                                         Category       
    _______________________________________________________________________    ______________________

    {'Items are occasionally getting stuck in the scanner spools.'        }    {'Mechanical Failure'}
    {'Loud rattling and banging sounds are coming from assembler pistons.'}    {'Mechanical Failure'}
    {'There are cuts to the power when starting the plant.'               }    {'Electronic Failure'}
    {'Fried capacitors in the assembler.'                                 }    {'Electronic Failure'}
    {'Mixer tripped the fuses.'                                           }    {'Electronic Failure'}
    {'Burst pipe in the constructing agent is spraying coolant.'          }    {'Leak'              }
    {'A fuse is blown in the mixer.'                                      }    {'Electronic Failure'}
    {'Things continue to tumble off of the belt.'                         }    {'Mechanical Failure'}

Transform Datastore

Create a custom transform function that converts data read from the datastore to a table containing the predictors and the responses. The transformText function takes the data read from a tabularTextDatastore object and returns a table of predictors and responses. The predictors are numFeatures-by-lenSequence arrays of word vectors given by the word embedding emb, where numFeatures is the embedding dimension and lenSequence is the sequence length. The responses are categorical labels over the classes.

To get the class names, read the labels from the training data using the readLabels function, listed at the end of the example, and find the unique class names.

labels = readLabels(ttdsTrain,labelName);
classNames = unique(labels);
numObservations = numel(labels);
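The order of classNames matters: transformText passes it to categorical as the list of categories, and scores2label uses it to map scores back to labels. Because unique sorts its output, the class order does not depend on the order of the rows in the file. The following sketch shows this on a few hypothetical labels stored in a cell array, the form in which the preview displays the Category column.

```matlab
% Toy labels (hypothetical), stored as a cell array of character vectors.
toyLabels = {'Mechanical Failure';'Electronic Failure';'Leak';'Electronic Failure'};

% unique removes duplicates and sorts the names, which fixes the class order.
toyClassNames = unique(toyLabels)   % {'Electronic Failure';'Leak';'Mechanical Failure'}

% Position of each label in the sorted list of class names.
[~,classIdx] = ismember(toyLabels,toyClassNames)   % [3;1;2;1]
```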

Because tabular text datastores can read multiple rows of data in a single read, you can process a full mini-batch of data in the transform function. To ensure that the transform function processes a full mini-batch of data, set the read size of the tabular text datastore to the mini-batch size that will be used for training.

miniBatchSize = 64;
ttdsTrain.ReadSize = miniBatchSize;
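Setting ReadSize to the mini-batch size means that each read returns at most miniBatchSize rows. Unless the number of observations is a multiple of the mini-batch size, the last read of each pass through the data is smaller. The following sketch computes the read sizes for a hypothetical data set of 200 observations. The floor used later to compute numIterationsPerEpoch counts only the full mini-batches.

```matlab
% Hypothetical data set size; the value is for illustration only.
toyNumObservations = 200;
toyMiniBatchSize = 64;

% Each read returns at most ReadSize rows, so the last read can be smaller.
numReads = ceil(toyNumObservations/toyMiniBatchSize);
readSizes = [repmat(toyMiniBatchSize,1,numReads-1), ...
    toyNumObservations - (numReads-1)*toyMiniBatchSize]   % [64 64 64 8]

% Number of full mini-batches, computed the same way as numIterationsPerEpoch.
numFullBatches = floor(toyNumObservations/toyMiniBatchSize)   % 3
```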

To convert the output of the tabular text datastore to sequences for training, transform the datastore using the transform function.

tdsTrain = transform(ttdsTrain, @(data) transformText(data,emb,classNames))
tdsTrain = 
  TransformedDatastore with properties:

      UnderlyingDatastores: {}
    SupportedOutputFormats: ["txt"    "csv"    "dat"    "asc"    "xlsx"    "xls"    "parquet"    "parq"    "png"    "jpg"    "jpeg"    "tif"    "tiff"    "wav"    "flac"    "ogg"    "opus"    "mp3"    "mp4"    "m4a"]
                Transforms: {@(data)transformText(data,emb,classNames)}
               IncludeInfo: 0
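Calling transform does not read or convert any data. Instead, the transformed datastore applies the function handle to each block of data when that block is read. The following base MATLAB sketch imitates this pattern with cellfun: each cell of batches stands in for one read of the tabular text datastore, and the hypothetical countWords function stands in for transformText.

```matlab
% Each cell stands in for one read of ReadSize rows (hypothetical text).
batches = {{'Mixer tripped the fuses.'; 'Fried capacitors in the assembler.'}, ...
           {'A fuse is blown in the mixer.'}};

% Hypothetical per-batch function, standing in for transformText:
% count the whitespace-separated tokens in each row.
countWords = @(batch) cellfun(@(s) numel(strsplit(strtrim(s))), batch);

% Apply the function to every batch, as a transformed datastore does on each read.
transformed = cellfun(countWords, batches, 'UniformOutput', false);
transformed{1}   % [4; 5]
transformed{2}   % 7
```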

Preview the transformed datastore. The predictors are numFeatures-by-lenSequence arrays, where lenSequence is the sequence length and numFeatures is the number of features (the embedding dimension). The responses are the categorical labels.

preview(tdsTrain)

ans=8×2 table
      predictors           responses     
    _______________    __________________

    {300×11 single}    Mechanical Failure
    {300×11 single}    Mechanical Failure
    {300×11 single}    Electronic Failure
    {300×11 single}    Electronic Failure
    {300×11 single}    Electronic Failure
    {300×11 single}    Leak              
    {300×11 single}    Electronic Failure
    {300×11 single}    Mechanical Failure

Create and Train LSTM Network

Define the LSTM network architecture. To input sequence data into the network, include a sequence input layer and set the input size to the embedding dimension. Next, include an LSTM layer with 180 hidden units. To use the LSTM layer for a sequence-to-label classification problem, set the output mode to 'last'. Finally, add a fully connected layer with output size equal to the number of classes, and a softmax layer.
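Setting the output mode to 'last' pairs well with left-padded sequences: the LSTM layer outputs only the final time step, and with left-padding the final time step of every sequence holds a real word rather than padding. The following sketch traces the array shapes through this part of the network, using random numbers in place of the LSTM hidden states and the fully connected weights. The sizes are hypothetical, and the sketch does not model how lstmLayer computes its hidden states.

```matlab
% Hypothetical sizes for a toy forward pass (random values, not a trained network).
hiddenSize = 4; classCount = 3; seqLength = 5; batchSize = 2;
H = rand(hiddenSize,seqLength,batchSize);   % stand-in for LSTM hidden states, "CTB"

% 'OutputMode','last' keeps only the final time step. With left-padding,
% that step belongs to the last real word of every document.
hLast = reshape(H(:,end,:),hiddenSize,batchSize);   % hiddenSize-by-batchSize

% Fully connected layer followed by softmax over the classes.
W = rand(classCount,hiddenSize);
b = rand(classCount,1);
Z = W*hLast + b;
P = exp(Z - max(Z,[],1));
P = P ./ sum(P,1);                           % each column sums to 1
```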

numFeatures = emb.Dimension;
numHiddenUnits = 180;
numClasses = numel(classNames);
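layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer];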

Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.

  • Train using the Adam optimizer.

  • Set the input data format to 'CTB' (channel, time, batch).

  • Specify the mini-batch size.

  • Set the gradient threshold to 2.

  • The datastore does not support shuffling, so set 'Shuffle' to 'never'.

  • Display the training progress in a plot and monitor the accuracy.

  • Disable the verbose output.
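By default, trainingOptions uses the 'l2norm' gradient threshold method, which rescales the gradient of a learnable parameter when its L2 norm exceeds the gradient threshold. The following sketch applies that rule to a single hypothetical gradient vector, assuming the default method.

```matlab
% Hypothetical gradient of one learnable parameter, with L2 norm 5.
gradientThreshold = 2;
g = [3; 4];

% If the norm exceeds the threshold, rescale the gradient so that its
% norm equals the threshold; the direction is unchanged.
gNorm = norm(g);
if gNorm > gradientThreshold
    g = g * (gradientThreshold/gNorm);
end
g   % [1.2; 1.6]
```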

By default, trainnet uses a GPU if one is available. To specify the execution environment manually, use the 'ExecutionEnvironment' name-value pair argument of trainingOptions. Training on a CPU can take significantly longer than training on a GPU. Training using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

numIterationsPerEpoch = floor(numObservations / miniBatchSize);

options = trainingOptions('adam', ...
    'MaxEpochs',15, ...
    'InputDataFormats','CTB', ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',2, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Metrics','accuracy', ...
    'Verbose',false);

Train the neural network using the trainnet function. For classification, use cross-entropy loss.

net = trainnet(tdsTrain,layers,"crossentropy",options);
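For a classification problem with a softmax output, cross-entropy loss penalizes a low predicted probability for the true class. The following sketch computes the loss for two hypothetical observations with one-hot targets. Averaging over the observations in the mini-batch is an assumption of this sketch; check the trainnet documentation for the exact normalization it uses.

```matlab
% Hypothetical softmax outputs for two observations (classes-by-observations)
% and the matching one-hot targets.
P = [0.7 0.1; 0.2 0.3; 0.1 0.6];
T = [1 0; 0 0; 0 1];

% Cross-entropy: minus the log-probability of the true class, averaged over
% the observations in the mini-batch.
loss = -sum(sum(T .* log(P))) / size(P,2)   % about 0.4338
```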

Predict Using New Data

Classify the event type of three new reports. Create a string array containing the new reports.

reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

Preprocess the text data using the same preprocessing steps as the training documents.

documentsNew = preprocessText(reportsNew);

Convert the text data to sequences of embedding vectors using doc2sequence.

XNew = doc2sequence(emb,documentsNew);

Classify the new sequences using the trained LSTM network.

scores = minibatchpredict(net,XNew,InputDataFormats="CTB");
Y = scores2label(scores,classNames)
Y = 3×1 categorical
     Electronic Failure 
     Mechanical Failure 
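scores2label converts the scores to labels by selecting the class with the highest score for each observation. The following sketch reproduces that step with max on hypothetical scores, assuming one row per observation and one column per class, in the order of the class names.

```matlab
% Hypothetical scores, one row per report and one column per class,
% with the columns in the same order as the class names.
toyClassNames = {'Electronic Failure';'Leak';'Mechanical Failure'};
toyScores = [0.2 0.1 0.7; 0.6 0.3 0.1];

% The predicted label is the class with the highest score in each row.
[~,idx] = max(toyScores,[],2);
toyY = toyClassNames(idx)   % {'Mechanical Failure';'Electronic Failure'}
```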

Transform Text Function

The transformText function takes the data read from a tabularTextDatastore object and returns a table of predictors and responses. The predictors are numFeatures-by-lenSequence arrays of word vectors given by the word embedding emb, where numFeatures is the embedding dimension and lenSequence is the sequence length. The responses are categorical labels over the classes in classNames.

function dataTransformed = transformText(data,emb,classNames)

% Preprocess documents.
textData = data{:,1};
documents = preprocessText(textData);

% Convert to sequences.
predictors = doc2sequence(emb,documents);

% Read labels.
labels = data{:,2};
responses = categorical(labels,classNames);

% Convert data to table.
dataTransformed = table(predictors,responses);

end


Preprocessing Function

The function preprocessText performs these steps:

  1. Tokenize the text using tokenizedDocument.

  2. Convert the text to lowercase using lower.

  3. Erase the punctuation using erasePunctuation.
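To make these steps concrete without Text Analytics Toolbox, the following sketch approximates them with base MATLAB string functions on one of the new reports. It is a rough stand-in, not a reimplementation of tokenizedDocument: it lowercases and removes punctuation before splitting on whitespace, and it does not apply any of the additional tokenization rules of tokenizedDocument.

```matlab
% Base MATLAB approximation of the three steps (not tokenizedDocument itself).
str = 'Sorter blows fuses at start up.';

s = lower(str);                        % lowercase
s = regexprep(s,'[^\w\s]','');         % erase punctuation characters
tokens = strsplit(strtrim(s))          % {'sorter','blows','fuses','at','start','up'}
```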

function documents = preprocessText(textData)

documents = tokenizedDocument(textData);
documents = lower(documents);
documents = erasePunctuation(documents);

end


Read Labels Function

The readLabels function creates a copy of the tabularTextDatastore object ttds and reads the labels from the labelName column.

function labels = readLabels(ttds,labelName)

ttdsNew = copy(ttds);
ttdsNew.SelectedVariableNames = labelName;
tbl = readall(ttdsNew);
labels = tbl.(labelName);

end



Related Topics