Classify Time Series Using Wavelet Analysis and Deep Learning

This example shows how to classify human electrocardiogram (ECG) signals using the continuous wavelet transform (CWT) and a deep convolutional neural network (CNN).

Training a deep CNN from scratch is computationally expensive and requires a large amount of training data. In various applications, a sufficient amount of training data is not available, and synthesizing new realistic training examples is not feasible. In these cases, leveraging existing neural networks that have been trained on large data sets for conceptually similar tasks is desirable. This leveraging of existing neural networks is called transfer learning. In this example we adapt two deep CNNs, GoogLeNet and SqueezeNet, pretrained for image recognition to classify ECG waveforms based on a time-frequency representation.

GoogLeNet and SqueezeNet are deep CNNs originally designed to classify images into 1000 categories. We reuse the network architecture of each CNN to classify ECG signals based on images from the CWT of the time series data. The data used in this example are publicly available from PhysioNet.

Data Description

In this example, you use ECG data obtained from three groups of people: persons with cardiac arrhythmia (ARR), persons with congestive heart failure (CHF), and persons with normal sinus rhythms (NSR). In total you use 162 ECG recordings from three PhysioNet databases: MIT-BIH Arrhythmia Database [3][7], MIT-BIH Normal Sinus Rhythm Database [3], and the BIDMC Congestive Heart Failure Database [1][3]. More specifically, you use 96 recordings from persons with arrhythmia, 30 recordings from persons with congestive heart failure, and 36 recordings from persons with normal sinus rhythms. The goal is to train a classifier to distinguish between ARR, CHF, and NSR.

Download Data

The first step is to download the data from the GitHub® repository. To download the data from the website, click Code and select Download ZIP. Save the file physionet_ECG_data-main.zip in a folder where you have write permission. The instructions for this example assume you have downloaded the file to your temporary directory, tempdir, in MATLAB®. If you download the data to a folder other than tempdir, modify the subsequent instructions for unzipping and loading the data accordingly.

After downloading the data from GitHub, unzip the file in your temporary directory.

unzip(fullfile(tempdir,"physionet_ECG_data-main.zip"),tempdir)

Unzipping creates the folder physionet_ECG_data-main in your temporary directory. This folder contains the text file README.md and ECGData.zip. The ECGData.zip file contains:

  • ECGData.mat

  • Modified_physionet_data.txt

  • License.txt

ECGData.mat holds the data used in this example. The text file, Modified_physionet_data.txt, is required by PhysioNet's copying policy and provides the source attributions for the data as well as a description of the preprocessing steps applied to each ECG recording.

Unzip ECGData.zip in physionet_ECG_data-main. Load the data file into your MATLAB workspace.

unzip(fullfile(tempdir,"physionet_ECG_data-main","ECGData.zip"), ...
    fullfile(tempdir,"physionet_ECG_data-main"))
load(fullfile(tempdir,"physionet_ECG_data-main","ECGData.mat"))

ECGData is a structure array with two fields: Data and Labels. The Data field is a 162-by-65536 matrix where each row is an ECG recording sampled at 128 hertz. Labels is a 162-by-1 cell array of diagnostic labels, one for each row of Data. The three diagnostic categories are: 'ARR', 'CHF', and 'NSR'.
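As a quick check of the label layout described above, the following sketch tallies the number of recordings per class. So that it runs without the downloaded file, it builds a stand-in variable, labels, with the class counts stated in the Data Description section. The variable name and the ordering of the entries are assumptions made for illustration. With the real data, you would use ECGData.Labels instead.

```matlab
% Stand-in for ECGData.Labels: a 162-by-1 cell array with the class counts
% stated in the Data Description section (ordering assumed for illustration).
labels = [repmat({'ARR'},96,1); repmat({'CHF'},30,1); repmat({'NSR'},36,1)];

% Tally the recordings in each diagnostic category.
[classNames,~,idx] = unique(labels);
counts = accumarray(idx,1);

for k = 1:numel(classNames)
    fprintf("%s: %d recordings\n",classNames{k},counts(k));
end
```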

To store the preprocessed data of each category, first create an ECG data directory dataDir inside tempdir. Then create three subdirectories in dataDir named after each ECG category. The helper function helperCreateECGDirectories does this. helperCreateECGDirectories accepts ECGData, the name of a parent directory, and the name of an ECG data directory as input arguments. You can replace tempdir with another directory where you have write permission. You can find the source code for this helper function in the Supporting Functions section at the end of this example.

parentDir = tempdir;
dataDir = "data";
helperCreateECGDirectories(ECGData,parentDir,dataDir)

Plot a representative of each ECG category. The helper function helperPlotReps does this. helperPlotReps accepts ECGData as input. You can find the source code for this helper function in the Supporting Functions section at the end of this example.

helperPlotReps(ECGData)

Create Time-Frequency Representations

After making the folders, create time-frequency representations of the ECG signals. These representations are called scalograms. A scalogram is the absolute value of the CWT coefficients of a signal.

To create the scalograms, precompute a CWT filter bank. Precomputing the CWT filter bank is the preferred method when obtaining the CWT of many signals using the same parameters.
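To illustrate the reuse of a precomputed filter bank, the following sketch builds one filter bank with cwtfilterbank and applies it to two signals of the same length. The two sinusoids are synthetic stand-ins chosen for illustration, not ECG data. The filter bank parameters mirror the ones used in this example.

```matlab
Fs = 128;                      % sampling frequency in hertz, as in this example
N = 1000;                      % number of samples per signal
t = (0:N-1)/Fs;

% Two synthetic test signals (hypothetical stand-ins for ECG segments).
x = [sin(2*pi*5*t); sin(2*pi*20*t)];

% Build the filter bank once ...
fb = cwtfilterbank(SignalLength=N,SamplingFrequency=Fs,VoicesPerOctave=12);

% ... and reuse it for every signal of the same length.
scalograms = cell(size(x,1),1);
for k = 1:size(x,1)
    cfs = wt(fb,x(k,:));       % complex CWT coefficients
    scalograms{k} = abs(cfs);  % scalogram: magnitude of the coefficients
end
size(scalograms{1})            % number of wavelet filters by N
```

Because the filters depend only on the signal length and the filter bank parameters, they do not have to be recomputed for each signal.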

Before generating the scalograms, examine one of them. Create a CWT filter bank using cwtfilterbank for a signal with 1000 samples. Use the filter bank to take the CWT of the first 1000 samples of the signal and obtain the scalogram from the coefficients.

Fs = 128;
fb = cwtfilterbank(SignalLength=1000, ...
    SamplingFrequency=Fs, ...
    VoicesPerOctave=12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;
figure
pcolor(t,frq,abs(cfs))
set(gca,"yscale","log")
shading interp
axis tight
title("Scalogram")
xlabel("Time (s)")
ylabel("Frequency (Hz)")

Use the helper function helperCreateRGBfromTF to create the scalograms as RGB images and write them to the appropriate subdirectory in dataDir. The source code for this helper function is in the Supporting Functions section at the end of this example. To be compatible with the GoogLeNet architecture, each RGB image is an array of size 224-by-224-by-3.

helperCreateRGBfromTF(ECGData,parentDir,dataDir)
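The conversion from a scalogram matrix to a 224-by-224-by-3 RGB array can be sketched as follows. This is a simplified illustration, not the helper function itself: the input is a random stand-in matrix, and the magnitudes are mapped to the index range 1 to 128 so that each entry of the 128-color jet colormap can be used. The scaling in helperCreateRGBfromTF differs in detail.

```matlab
% Stand-in for a scalogram: a nonnegative matrix (frequencies by time samples).
cfs = abs(randn(100,1000));

% Map the magnitudes to colormap indices 1..128, then look up the RGB colors.
idx = round(rescale(cfs,1,128));
im = ind2rgb(idx,jet(128));    % 100-by-1000-by-3 array of RGB values

% Resize to the spatial size that GoogLeNet expects.
im = imresize(im,[224 224]);
size(im)
```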

Divide into Training and Validation Data

Load the scalogram images as an image datastore. The imageDatastore function automatically labels the images based on folder names and stores the data as an ImageDatastore object. An image datastore enables you to store large image data, including data that does not fit in memory, and efficiently read batches of images during training of a CNN.

allImages = imageDatastore(fullfile(parentDir,dataDir), ...
    "IncludeSubfolders",true, ...
    "LabelSource","foldernames");

Randomly divide the images into two groups, one for training and the other for validation. Use 80% of the images for training, and the remainder for validation.

[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,"randomized");
disp("Number of training images: "+num2str(numel(imgsTrain.Files)))
Number of training images: 130
disp("Number of validation images: "+num2str(numel(imgsValidation.Files)))
Number of validation images: 32
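The displayed counts can be reproduced from the class sizes given in the Data Description section. The sketch below assumes that the per-class training count is 0.8 times the class size, rounded to the nearest integer. This assumption is consistent with the totals displayed above, but it is not taken from the splitEachLabel documentation.

```matlab
classNames = ["ARR" "CHF" "NSR"];
counts = [96 30 36];               % recordings per class, from the text

% Assumption: per-class training count is 80% of the class size, rounded.
nTrain = round(0.8*counts);
nValidation = counts - nTrain;

disp(table(classNames',nTrain',nValidation', ...
    VariableNames=["Class" "Training" "Validation"]))
fprintf("Total training: %d, total validation: %d\n", ...
    sum(nTrain),sum(nValidation))
```

Because the split is made per label, each class keeps roughly the same proportion in both collections even though the classes are imbalanced.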

GoogLeNet

Load

Load the pretrained GoogLeNet neural network. If Deep Learning Toolbox™ Model for GoogLeNet Network support package is not installed, the software provides a link to the required support package in the Add-On Explorer. To install the support package, click the link, and then click Install.

net = googlenet;

Extract and display the layer graph from the network.

lgraph = layerGraph(net);
numberOfLayers = numel(lgraph.Layers);
figure("Units","normalized","Position",[0.1 0.1 0.8 0.8])
plot(lgraph)
title("GoogLeNet Layer Graph: "+num2str(numberOfLayers)+" Layers")

Inspect the first element of the network Layers property. Confirm that GoogLeNet requires RGB images of size 224-by-224-by-3.

lgraph.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [224 224 3]
        SplitComplexInputs: 0

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [224×224×3 single]

Modify GoogLeNet Network Parameters

Each layer in the network architecture can be considered a filter. The earlier layers identify more common features of images, such as blobs, edges, and colors. Subsequent layers focus on more specific features in order to differentiate categories. GoogLeNet is pretrained to classify images into 1000 object categories. You must retrain GoogLeNet for our ECG classification problem.

Inspect the last five layers of the network.

lgraph.Layers(end-4:end)
ans = 
  5×1 Layer array with layers:

     1   'pool5-7x7_s1'        2-D Global Average Pooling   2-D global average pooling
     2   'pool5-drop_7x7_s1'   Dropout                      40% dropout
     3   'loss3-classifier'    Fully Connected              1000 fully connected layer
     4   'prob'                Softmax                      softmax
     5   'output'              Classification Output        crossentropyex with 'tench' and 999 other classes

To prevent overfitting, a dropout layer is used. A dropout layer randomly sets input elements to zero with a given probability. See dropoutLayer (Deep Learning Toolbox) for more information. The default probability is 0.5. Replace the final dropout layer in the network, pool5-drop_7x7_s1, with a dropout layer of probability 0.6.

newDropoutLayer = dropoutLayer(0.6,"Name","new_Dropout");
lgraph = replaceLayer(lgraph,"pool5-drop_7x7_s1",newDropoutLayer);
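The effect of a dropout probability of 0.6 can be sketched without a network. The code below assumes the commonly used training-time behavior in which dropped elements are set to zero and the surviving elements are scaled by 1/(1-p). The input vector is a stand-in for layer activations, and the variable names are chosen for illustration.

```matlab
rng(0)                          % fix the random seed for repeatability
p = 0.6;                        % dropout probability
X = ones(1,10000);              % stand-in for layer input activations

dropMask = rand(size(X)) < p;   % elements to set to zero
Y = X;
Y(dropMask) = 0;
Y = Y/(1-p);                    % rescale the surviving elements

fractionZeroed = mean(Y == 0)   % close to p
meanOutput = mean(Y)            % close to mean(X) on average
```

The rescaling keeps the expected value of the output equal to that of the input, so no adjustment is needed when dropout is disabled at prediction time.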

The convolutional layers of the network extract image features that the last learnable layer and final classification layer use to classify the input image. These two layers, loss3-classifier and output in GoogLeNet, contain information on how to combine the features that the network extracts into class probabilities, a loss value, and predicted labels. To retrain GoogLeNet to classify the RGB images, replace these two layers with new layers adapted to the data.

Replace the fully connected layer loss3-classifier with a new fully connected layer whose output size equals the number of classes. To learn faster in the new layers than in the transferred layers, increase the learning rate factors of the fully connected layer.

numClasses = numel(categories(imgsTrain.Labels));
newConnectedLayer = fullyConnectedLayer(numClasses,"Name","new_fc", ...
    "WeightLearnRateFactor",5,"BiasLearnRateFactor",5);
lgraph = replaceLayer(lgraph,"loss3-classifier",newConnectedLayer);
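The learning rate factor multiplies the global learning rate for that layer only. As a small worked example, the sketch below computes the effective rate for the new layer, assuming the initial learning rate of 0.0001 that this example sets in the training options. The variable names globalLearnRate and effectiveWeightRate are chosen for illustration.

```matlab
% Recreate the replacement layer with three outputs (ARR, CHF, NSR).
newLayer = fullyConnectedLayer(3,"Name","new_fc", ...
    "WeightLearnRateFactor",5,"BiasLearnRateFactor",5);

globalLearnRate = 1e-4;   % InitialLearnRate used in this example

% Effective learning rate for the weights of the new layer.
effectiveWeightRate = globalLearnRate*newLayer.WeightLearnRateFactor
```

The transferred layers keep their original factors, so the new layer adapts to the ECG classes faster than the pretrained layers change.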

The classification layer specifies the output classes of the network. Replace the classification layer with a new one without class labels. trainNetwork automatically sets the output classes of the layer at training time.

newClassLayer = classificationLayer("Name","new_classoutput");
lgraph = replaceLayer(lgraph,"output",newClassLayer);

Inspect the last five layers. Confirm you have replaced the dropout, fully connected, and output layers.

lgraph.Layers(end-4:end)
ans = 
  5×1 Layer array with layers:

     1   'pool5-7x7_s1'      2-D Global Average Pooling   2-D global average pooling
     2   'new_Dropout'       Dropout                      60% dropout
     3   'new_fc'            Fully Connected              3 fully connected layer
     4   'prob'              Softmax                      softmax
     5   'new_classoutput'   Classification Output        crossentropyex

Set Training Options and Train GoogLeNet

Training a neural network is an iterative process that involves minimizing a loss function. To minimize the loss function, a gradient descent algorithm is used. In each iteration, the gradient of the loss function is evaluated and the descent algorithm weights are updated.
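The update step can be illustrated on a toy problem. The sketch below applies gradient descent with momentum to a one-parameter quadratic loss. The loss, the learning rate, and the momentum value are hypothetical choices for illustration. Unlike network training, the gradient here is exact rather than estimated from a mini-batch.

```matlab
% Toy loss f(w) = (w - 3)^2 with gradient 2*(w - 3); the minimum is at w = 3.
learnRate = 0.1;     % step size (hypothetical value for this toy problem)
momentum = 0.9;      % contribution of the previous update
w = 0;               % initial weight
v = 0;               % previous update

for iter = 1:200
    grad = 2*(w - 3);                   % evaluate the gradient of the loss
    v = momentum*v - learnRate*grad;    % combine previous update and new step
    w = w + v;                          % update the weight
end
w                                       % approaches the minimizer, 3
```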

Training can be tuned by setting various options. InitialLearnRate specifies the initial step size in the direction of the negative gradient of the loss function. MiniBatchSize specifies how large a subset of the training set to use in each iteration. One epoch is a full pass of the training algorithm over the entire training set. MaxEpochs specifies the maximum number of epochs to use for training. Choosing the right number of epochs is not a trivial task. Too few epochs can underfit the model, and too many epochs can result in overfitting.

Use the trainingOptions (Deep Learning Toolbox) function to specify the training options. Set MiniBatchSize to 15, MaxEpochs to 20, and InitialLearnRate to 0.0001. Visualize training progress by setting Plots to training-progress. Use the stochastic gradient descent with momentum optimizer. By default, training is done on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox™. To see which GPUs are supported, see GPU Computing Requirements (Parallel Computing Toolbox).

options = trainingOptions("sgdm", ...
    MiniBatchSize=15, ...
    MaxEpochs=20, ...
    InitialLearnRate=1e-4, ...
    ValidationData=imgsValidation, ...
    ValidationFrequency=10, ...
    Verbose=1, ...
    Plots="training-progress");
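These options determine how many iterations the training runs. The following arithmetic sketch assumes that a final partial mini-batch in each epoch is discarded, which is consistent with the last iteration number shown in the training output of this example. The variable names are chosen for illustration.

```matlab
numTrain = 130;        % number of training images in this example
miniBatchSize = 15;
maxEpochs = 20;

% Assumption: only full mini-batches are used in each epoch.
itersPerEpoch = floor(numTrain/miniBatchSize)
totalIters = itersPerEpoch*maxEpochs

% With ValidationFrequency=10, validation metrics are computed at the first
% iteration and then at every tenth iteration.
validationIters = [1 10:10:totalIters];
numel(validationIters)
```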

Train the network. The training process usually takes 1-5 minutes on a desktop CPU. Run times will be faster if you are able to use a GPU. The command window displays training information during the run. Results include epoch number, iteration number, time elapsed, mini-batch accuracy, validation accuracy, mini-batch loss, validation loss, and base learning rate.

trainedGN = trainNetwork(imgsTrain,lgraph,options);
Training on single GPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |       33.33% |       59.38% |       2.5151 |       1.1989 |      1.0000e-04 |
|       2 |          10 |       00:00:08 |       53.33% |       62.50% |       1.4702 |       0.8987 |      1.0000e-04 |
|       3 |          20 |       00:00:12 |       53.33% |       84.38% |       1.0575 |       0.5156 |      1.0000e-04 |
|       4 |          30 |       00:00:17 |       60.00% |       78.12% |       0.7547 |       0.4531 |      1.0000e-04 |
|       5 |          40 |       00:00:22 |       60.00% |       93.75% |       0.7951 |       0.3317 |      1.0000e-04 |
|       7 |          50 |       00:00:26 |      100.00% |       78.12% |       0.1318 |       0.3585 |      1.0000e-04 |
|       8 |          60 |       00:00:31 |       80.00% |       93.75% |       0.4344 |       0.2507 |      1.0000e-04 |
|       9 |          70 |       00:00:36 |      100.00% |       87.50% |       0.1106 |       0.2495 |      1.0000e-04 |
|      10 |          80 |       00:00:41 |       80.00% |      100.00% |       0.5102 |       0.1971 |      1.0000e-04 |
|      12 |          90 |       00:00:46 |      100.00% |      100.00% |       0.1180 |       0.1714 |      1.0000e-04 |
|      13 |         100 |       00:00:51 |       93.33% |      100.00% |       0.3106 |       0.1463 |      1.0000e-04 |
|      14 |         110 |       00:00:56 |      100.00% |      100.00% |       0.1004 |       0.1145 |      1.0000e-04 |
|      15 |         120 |       00:01:00 |       93.33% |      100.00% |       0.1732 |       0.1132 |      1.0000e-04 |
|      17 |         130 |       00:01:05 |       93.33% |       96.88% |       0.1391 |       0.1294 |      1.0000e-04 |
|      18 |         140 |       00:01:10 |      100.00% |      100.00% |       0.1139 |       0.1001 |      1.0000e-04 |
|      19 |         150 |       00:01:15 |      100.00% |      100.00% |       0.0320 |       0.0914 |      1.0000e-04 |
|      20 |         160 |       00:01:19 |      100.00% |       93.75% |       0.0611 |       0.1264 |      1.0000e-04 |
|======================================================================================================================|
Training finished: Max epochs completed.

Inspect the last layer of the trained network. Confirm the classification output layer includes the three classes.

trainedGN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
    ClassWeights: 'none'
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

Evaluate GoogLeNet Accuracy

Evaluate the network using the validation data.

[YPred,~] = classify(trainedGN,imgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp("GoogLeNet Accuracy: "+num2str(100*accuracy)+"%")
GoogLeNet Accuracy: 93.75%
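An overall accuracy of 93.75% on 32 validation images corresponds to 30 correct predictions. Overall accuracy does not show which classes are confused, so a per-class breakdown is often useful. The sketch below uses synthetic labels: the per-class validation counts (19, 6, and 7) assume a rounded 80% split of each class, and the two misclassifications are hypothetical, not results from the trained network. With real results, you would use imgsValidation.Labels and YPred instead.

```matlab
% Synthetic validation labels (assumed per-class counts: 19 ARR, 6 CHF, 7 NSR).
trueLabels = categorical([repmat("ARR",19,1); repmat("CHF",6,1); ...
    repmat("NSR",7,1)]);

% Hypothetical predictions: all correct except two planted errors.
predLabels = trueLabels;
predLabels(1) = "NSR";     % an ARR recording predicted as NSR
predLabels(20) = "ARR";    % a CHF recording predicted as ARR

accuracy = mean(predLabels == trueLabels);
fprintf("Overall accuracy: %.2f%%\n",100*accuracy)

% Per-class breakdown.
classNames = categories(trueLabels);
for k = 1:numel(classNames)
    inClass = trueLabels == classNames{k};
    fprintf("%s: %d of %d correct\n",classNames{k}, ...
        sum(predLabels(inClass) == classNames{k}),sum(inClass));
end
```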

The accuracy is identical to the validation accuracy reported on the training visualization figure. The scalograms were split into training and validation collections, and both collections were used during training: the training collection to update the network weights and the validation collection to monitor progress. The ideal way to evaluate the result of the training is to have the network classify data it has not seen. Because there is not enough data to divide into training, validation, and testing sets, we treat the computed validation accuracy as the network accuracy.

Explore GoogLeNet Activations

Each layer of a CNN produces a response, or activation, to an input image. However, there are only a few layers within a CNN that are suitable for image feature extraction. Inspect the first five layers of the trained network.

trainedGN.Layers(1:5)
ans = 
  5×1 Layer array with layers:

     1   'data'             Image Input                   224×224×3 images with 'zerocenter' normalization
     2   'conv1-7x7_s2'     2-D Convolution               64 7×7×3 convolutions with stride [2  2] and padding [3  3  3  3]
     3   'conv1-relu_7x7'   ReLU                          ReLU
     4   'pool1-3x3_s2'     2-D Max Pooling               3×3 max pooling with stride [2  2] and padding [0  1  0  1]
     5   'pool1-norm1'      Cross Channel Normalization   cross channel normalization with 5 channels per element

The layers at the beginning of the network capture basic image features, such as edges and blobs. To see this, visualize the network filter weights from the first convolutional layer. There are 64 individual sets of weights in the first layer.

wghts = trainedGN.Layers(2).Weights;
wghts = rescale(wghts);
wghts = imresize(wghts,8);
figure
I = imtile(wghts,GridSize=[8 8]);
imshow(I)
title("First Convolutional Layer Weights")

You can examine the activations and discover which features GoogLeNet learns by comparing areas of activation with the original image. For more information, see Visualize Activations of a Convolutional Neural Network (Deep Learning Toolbox) and Visualize Features of a Convolutional Neural Network (Deep Learning Toolbox).

Examine which areas in the convolutional layers activate on an image from the ARR class. Compare with the corresponding areas in the original image. Each layer of a convolutional neural network consists of many 2-D arrays called channels. Pass the image through the network and examine the output activations of the first convolutional layer, conv1-7x7_s2.

convLayer = "conv1-7x7_s2";

imgClass = "ARR";
imgName = "ARR_10.jpg";
imarr = imread(fullfile(parentDir,dataDir,imgClass,imgName));

trainingFeaturesARR = activations(trainedGN,imarr,convLayer);
sz = size(trainingFeaturesARR);
trainingFeaturesARR = reshape(trainingFeaturesARR,[sz(1) sz(2) 1 sz(3)]);
figure
I = imtile(rescale(trainingFeaturesARR),GridSize=[8 8]);
imshow(I)
title(imgClass+" Activations")

Find the strongest channel for this image. Compare the strongest channel with the original image.

imgSize = size(imarr);
imgSize = imgSize(1:2);
[~,maxValueIndex] = max(max(max(trainingFeaturesARR)));
arrMax = trainingFeaturesARR(:,:,:,maxValueIndex);
arrMax = rescale(arrMax);
arrMax = imresize(arrMax,imgSize);
figure
I = imtile({imarr,arrMax});
imshow(I)
title("Strongest "+imgClass+" Channel: "+num2str(maxValueIndex))
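The nested max calls work because each call reduces the first nonsingleton dimension: rows, then columns, then channels. The following sketch shows this on a small synthetic array with the same height-by-width-by-1-by-channel layout and a planted maximum. It also shows an equivalent form that names the dimensions explicitly. The array sizes and values are hypothetical.

```matlab
% Synthetic activations: 4-by-4 spatial size, 8 channels.
A = zeros(4,4,1,8);
A(1,1,1,2) = 3;
A(2,3,1,6) = 5;                     % largest value, planted in channel 6

% Nested form used in this example: the third max returns the channel index.
[~,idxNested] = max(max(max(A)));

% Equivalent form with explicit dimensions.
perChannelMax = max(A,[],[1 2]);    % 1-by-1-by-1-by-8
[~,idxExplicit] = max(perChannelMax,[],4);

disp([idxNested idxExplicit])
```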

SqueezeNet

SqueezeNet is a deep CNN whose architecture supports images of size 227-by-227-by-3. Even though these image dimensions differ from those of GoogLeNet, you do not have to generate new RGB images at the SqueezeNet dimensions. You can use the original RGB images.

Load

Load the pretrained SqueezeNet neural network. If Deep Learning Toolbox™ Model for SqueezeNet Network support package is not installed, the software provides a link to the required support package in the Add-On Explorer. To install the support package, click the link, and then click Install.

sqz = squeezenet;

Extract the layer graph from the network. Confirm SqueezeNet has fewer layers than GoogLeNet. Also confirm that SqueezeNet is configured for images of size 227-by-227-by-3.

lgraphSqz = layerGraph(sqz);
disp("Number of Layers: "+num2str(numel(lgraphSqz.Layers)))
Number of Layers: 68
lgraphSqz.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [227 227 3]
        SplitComplexInputs: 0

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [1×1×3 single]

Modify SqueezeNet Network Parameters

To retrain SqueezeNet to classify new images, make changes similar to those made for GoogLeNet.

Inspect the last six network layers.

lgraphSqz.Layers(end-5:end)
ans = 
  6×1 Layer array with layers:

     1   'drop9'                             Dropout                      50% dropout
     2   'conv10'                            2-D Convolution              1000 1×1×512 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'                       ReLU                         ReLU
     4   'pool10'                            2-D Global Average Pooling   2-D global average pooling
     5   'prob'                              Softmax                      softmax
     6   'ClassificationLayer_predictions'   Classification Output        crossentropyex with 'tench' and 999 other classes

Replace the last dropout layer in the network with a dropout layer of probability 0.6.

tmpLayer = lgraphSqz.Layers(end-5);
newDropoutLayer = dropoutLayer(0.6,"Name","new_dropout");
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newDropoutLayer);

Unlike GoogLeNet, the last learnable layer in SqueezeNet is a 1-by-1 convolutional layer, conv10, and not a fully connected layer. Replace the layer with a new convolutional layer with the number of filters equal to the number of classes. As was done with GoogLeNet, increase the learning rate factors of the new layer.

numClasses = numel(categories(imgsTrain.Labels));
tmpLayer = lgraphSqz.Layers(end-4);
newLearnableLayer = convolution2dLayer(1,numClasses, ...
        "Name","new_conv", ...
        "WeightLearnRateFactor",10, ...
        "BiasLearnRateFactor",10);
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newLearnableLayer);

Replace the classification layer with a new one without class labels.

tmpLayer = lgraphSqz.Layers(end);
newClassLayer = classificationLayer("Name","new_classoutput");
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newClassLayer);

Inspect the last six layers of the network. Confirm the dropout, convolutional, and output layers have been changed.

lgraphSqz.Layers(end-5:end)
ans = 
  6×1 Layer array with layers:

     1   'new_dropout'       Dropout                      60% dropout
     2   'new_conv'          2-D Convolution              3 1×1 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'       ReLU                         ReLU
     4   'pool10'            2-D Global Average Pooling   2-D global average pooling
     5   'prob'              Softmax                      softmax
     6   'new_classoutput'   Classification Output        crossentropyex

Prepare RGB Data for SqueezeNet

The RGB images have dimensions appropriate for the GoogLeNet architecture. Create augmented image datastores that automatically resize the existing RGB images for the SqueezeNet architecture. For more information, see augmentedImageDatastore (Deep Learning Toolbox).

augimgsTrain = augmentedImageDatastore([227 227],imgsTrain);
augimgsValidation = augmentedImageDatastore([227 227],imgsValidation);

Set Training Options and Train SqueezeNet

Create a new set of training options to use with SqueezeNet, and train the network.

ilr = 3e-4;
miniBatchSize = 10;
maxEpochs = 15;
valFreq = floor(numel(augimgsTrain.Files)/miniBatchSize);
opts = trainingOptions("sgdm", ...
    MiniBatchSize=miniBatchSize, ...
    MaxEpochs=maxEpochs, ...
    InitialLearnRate=ilr, ...
    ValidationData=augimgsValidation, ...
    ValidationFrequency=valFreq, ...
    Verbose=1, ...
    Plots="training-progress");
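With these settings, validation runs once per epoch. The arithmetic sketch below uses the 130 training images of this example and assumes that only full mini-batches are used. It also assumes the default verbose display interval of 50 iterations, which would explain command-window rows that report mini-batch metrics without validation metrics. The variable names are chosen for illustration.

```matlab
numTrain = 130;        % number of training images in this example
miniBatchSize = 10;
maxEpochs = 15;

valFreq = floor(numTrain/miniBatchSize)     % iterations per epoch
totalIters = valFreq*maxEpochs

% Iterations at which validation metrics are computed after the first one.
validationIters = valFreq:valFreq:totalIters;

% Assumption: verbose output is also printed every 50 iterations.
verboseOnlyIters = setdiff(50:50:totalIters,validationIters)
```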

trainedSN = trainNetwork(augimgsTrain,lgraphSqz,opts);
Training on single GPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |       10.00% |       56.25% |       4.0078 |       2.3776 |          0.0003 |
|       1 |          13 |       00:00:03 |       50.00% |       68.75% |       0.9363 |       0.9029 |          0.0003 |
|       2 |          26 |       00:00:05 |       70.00% |       78.12% |       0.8365 |       0.7137 |          0.0003 |
|       3 |          39 |       00:00:07 |       70.00% |       81.25% |       0.7844 |       0.5915 |          0.0003 |
|       4 |          50 |       00:00:08 |       70.00% |              |       0.5947 |              |          0.0003 |
|       4 |          52 |       00:00:09 |       70.00% |       84.38% |       0.5064 |       0.5000 |          0.0003 |
|       5 |          65 |       00:00:10 |       90.00% |       81.25% |       0.3023 |       0.3732 |          0.0003 |
|       6 |          78 |       00:00:12 |      100.00% |       87.50% |       0.0815 |       0.2651 |          0.0003 |
|       7 |          91 |       00:00:14 |      100.00% |       90.62% |       0.0644 |       0.2409 |          0.0003 |
|       8 |         100 |       00:00:15 |       80.00% |              |       0.2182 |              |          0.0003 |
|       8 |         104 |       00:00:16 |      100.00% |       96.88% |       0.1349 |       0.1818 |          0.0003 |
|       9 |         117 |       00:00:18 |       90.00% |       93.75% |       0.1663 |       0.1920 |          0.0003 |
|      10 |         130 |       00:00:19 |       90.00% |       90.62% |       0.1899 |       0.2258 |          0.0003 |
|      11 |         143 |       00:00:21 |      100.00% |       90.62% |       0.0962 |       0.1869 |          0.0003 |
|      12 |         150 |       00:00:22 |      100.00% |              |       0.0075 |              |          0.0003 |
|      12 |         156 |       00:00:23 |       90.00% |       93.75% |       0.2934 |       0.2201 |          0.0003 |
|      13 |         169 |       00:00:25 |      100.00% |       90.62% |       0.0958 |       0.3039 |          0.0003 |
|      14 |         182 |       00:00:27 |      100.00% |       93.75% |       0.0089 |       0.1430 |          0.0003 |
|      15 |         195 |       00:00:28 |      100.00% |       93.75% |       0.0068 |       0.2061 |          0.0003 |
|======================================================================================================================|
Training finished: Max epochs completed.

Inspect the last layer of the network. Confirm the classification output layer includes the three classes.

trainedSN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
    ClassWeights: 'none'
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

Evaluate SqueezeNet Accuracy

Evaluate the network using the validation data.

[YPred,probs] = classify(trainedSN,augimgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp("SqueezeNet Accuracy: "+num2str(100*accuracy)+"%")
SqueezeNet Accuracy: 93.75%

Conclusion

This example shows how to use transfer learning and continuous wavelet analysis to classify three classes of ECG signals by leveraging the pretrained CNNs GoogLeNet and SqueezeNet. Wavelet-based time-frequency representations of the ECG signals are used to create scalograms, RGB images of the scalograms are generated, and the images are used to fine-tune both deep CNNs. The example also explores the first-layer weights and activations of the trained GoogLeNet.

This example illustrates one possible workflow you can use for classifying signals using pretrained CNN models. Other workflows are possible. Deploy Signal Classifier on NVIDIA Jetson Using Wavelet Analysis and Deep Learning and Deploy Signal Classifier Using Wavelets and Deep Learning on Raspberry Pi show how to deploy code onto hardware for signal classification. GoogLeNet and SqueezeNet are models pretrained on a subset of the ImageNet database [10], which is used in the ImageNet Large-Scale Visual Recognition Challenge (ILSVRC) [8]. The ImageNet collection contains images of real-world objects such as fish, birds, appliances, and fungi. Scalograms fall outside the class of real-world objects. In order to fit into the GoogLeNet and SqueezeNet architecture, the scalograms also underwent data reduction. Instead of fine-tuning pretrained CNNs to distinguish different classes of scalograms, training a CNN from scratch at the original scalogram dimensions is an option.

References

  1. Baim, D. S., W. S. Colucci, E. S. Monrad, H. S. Smith, R. F. Wright, A. Lanoue, D. F. Gauthier, B. J. Ransil, W. Grossman, and E. Braunwald. "Survival of patients with severe congestive heart failure treated with oral milrinone." Journal of the American College of Cardiology. Vol. 7, Number 3, 1986, pp. 661–670.

  2. Engin, M. "ECG beat classification using neuro-fuzzy network." Pattern Recognition Letters. Vol. 25, Number 15, 2004, pp. 1715–1722.

  3. Goldberger, A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals." Circulation. Vol. 101, Number 23: e215–e220. [Circulation Electronic Pages; http://circ.ahajournals.org/content/101/23/e215.full]; 2000 (June 13). doi: 10.1161/01.CIR.101.23.e215.

  4. Leonarduzzi, R. F., G. Schlotthauer, and M. E. Torres. "Wavelet leader based multifractal analysis of heart rate variability during myocardial ischaemia." In Engineering in Medicine and Biology Society (EMBC), Annual International Conference of the IEEE, 110–113. Buenos Aires, Argentina: IEEE, 2010.

  5. Li, T., and M. Zhou. "ECG classification using wavelet packet entropy and random forests." Entropy. Vol. 18, Number 8, 2016, p. 285.

  6. Maharaj, E. A., and A. M. Alonso. "Discriminant analysis of multivariate time series: Application to diagnosis based on ECG signals." Computational Statistics and Data Analysis. Vol. 70, 2014, pp. 67–87.

  7. Moody, G. B., and R. G. Mark. "The impact of the MIT-BIH Arrhythmia Database." IEEE Engineering in Medicine and Biology Magazine. Vol. 20. Number 3, May-June 2001, pp. 45–50. (PMID: 11446209)

  8. Russakovsky, O., J. Deng, and H. Su et al. "ImageNet Large Scale Visual Recognition Challenge." International Journal of Computer Vision. Vol. 115, Number 3, 2015, pp. 211–252.

  9. Zhao, Q., and L. Zhang. "ECG feature extraction and classification using wavelet transform and support vector machines." In IEEE International Conference on Neural Networks and Brain, 1089–1092. Beijing, China: IEEE, 2005.

  10. ImageNet. http://www.image-net.org

Supporting Functions

helperCreateECGDirectories creates a data directory inside a parent directory, then creates three subdirectories inside the data directory. The subdirectories are named after each class of ECG signal found in ECGData.

function helperCreateECGDirectories(ECGData,parentFolder,dataFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

rootFolder = parentFolder;
localFolder = dataFolder;
mkdir(fullfile(rootFolder,localFolder))

folderLabels = unique(ECGData.Labels);
for i = 1:numel(folderLabels)
    mkdir(fullfile(rootFolder,localFolder,char(folderLabels(i))));
end
end

helperPlotReps plots the first thousand samples of a representative of each class of ECG signal found in ECGData.

function helperPlotReps(ECGData)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

folderLabels = unique(ECGData.Labels);

for k=1:3
    ecgType = folderLabels{k};
    ind = find(ismember(ECGData.Labels,ecgType));
    subplot(3,1,k)
    plot(ECGData.Data(ind(1),1:1000));
    grid on
    title(ecgType)
end
end

helperCreateRGBfromTF uses cwtfilterbank to obtain the continuous wavelet transform of the ECG signals and generates the scalograms from the wavelet coefficients. The helper function resizes the scalograms and writes them to disk as JPEG images.

function helperCreateRGBfromTF(ECGData,parentFolder,childFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

imageRoot = fullfile(parentFolder,childFolder);

data = ECGData.Data;
labels = ECGData.Labels;

[~,signalLength] = size(data);

fb = cwtfilterbank(SignalLength=signalLength,VoicesPerOctave=12);
r = size(data,1);

for ii = 1:r
    cfs = abs(fb.wt(data(ii,:)));
    im = ind2rgb(round(rescale(cfs,0,255)),jet(128));
    
    imgLoc = fullfile(imageRoot,char(labels(ii)));
    imFileName = char(labels(ii))+"_"+num2str(ii)+".jpg";
    imwrite(imresize(im,[224 224]),fullfile(imgLoc,imFileName));
end
end
