ECG Classification Using LSTMs | AI Techniques for ECG Classification, Part 3 - MATLAB
    Video length is 15:30

    ECG Classification Using LSTMs | AI Techniques for ECG Classification, Part 3

    From the series: AI Techniques for ECG Classification

    See how you can automatically extract features to reduce the dimensionality of ECG signals using wavelet scattering networks. Then learn how to use these features to train recurrent neural networks such as LSTM networks for classification.

    Published: 28 Jan 2021

    Hello, and welcome to this webinar on ECG classification using LSTMs. My name is Kirthy Devleker, and I will be your presenter for this video. This video is part of a larger series called Developing AI Models for Biomedical ECG signals. So with that, let's get started.

    So the goal of the video is to develop a classifier that can classify ECG signals into three distinct categories. The approach we are going to use is long short-term memory networks, or LSTM networks, to build a predictive model that can classify these signals. The three classes, as you can see here, are ARR (arrhythmia), CHF (congestive heart failure), and NSR (normal sinus rhythm).

    In the first class, I have 96 signals. In the second class, I have 30 signals. And in the third class, which is the normal sinus rhythm, I have 36 signals. Each signal is roughly 65,000 samples long. The goal of this example is to build a model using LSTM networks that can classify ECG signals into one of these three categories.

    So typically, for simple problems, when you have a signal and you want to build a model using an LSTM, you can just feed the signal directly to the LSTM network. While this approach may work for simple signals, in many real situations, directly feeding the data to LSTM networks may not work. And there are some reasons for that.

    One example I can cite: if you're looking at an ECG for arrhythmia, doctors typically look at the signal over a longer duration, like 30 seconds or so, in order to make a diagnosis. Similarly, there are situations where you need a long signal sample, and just looking at shorter signal chunks is not going to help. That is where you will probably run into challenges with feeding these signals directly to an LSTM.

    So in such situations, instead of feeding the signals directly to LSTM networks, you can do a little bit of feature extraction on your signals first. In this video, I'm going to show you how you can use a wavelet scattering network to reduce the dimensionality of your signal and obtain features automatically. Once you have those features, you can train your LSTM network on them to build your model. This arrangement works very well in situations where feeding raw data directly into LSTMs does not work, where you have little data to begin with, which is typically the case for many machine learning or deep learning problems, or where data augmentation is very challenging.

    So what is wavelet scattering, and how can it help? I can give you an explanation of wavelet scattering in the context of deep neural networks. A deep neural network takes some data and passes it through a set of layers, as you can see here, typically multiple layers. These layers apply some kind of convolution followed by some kind of non-linearity and pooling.

    And this kind of pattern repeats. Depending on the kind of network, you can have different patterns, but the main idea is that there is certainly going to be some kind of convolution. Now, when you train on enough data, say 15 million images or so, the filters in these convolutional networks are initialized to random values. During the training process, the filter weights are learned, and they keep changing over time.

    So what people found is that, once the network is fully trained, there are cases where the filters start resembling wavelet-like filters. So the idea is, instead of using such deep networks with learned convolution filters, why not just use filters that are fixed? That's where wavelet filters come in, and that gives rise to the wavelet scattering network.

    So you have this wavelet network, which applies a wavelet convolution followed by a non-linearity and averaging. The weights are fixed in this network. With just a couple of layers, you can get all sorts of features from your signal, and you can feed those features directly into a classifier.

    So this is actually a compact network for extracting all your features. There are basically just two layers to start off with. This network can help you relax the requirements on the amount of data and on the model complexity, and you can use it to automatically extract relevant and compact features. So to summarize, wavelet scattering is a nice technique that you can use to reduce the dimensionality of your signal and automatically extract features without losing information about the signal itself, OK?

    So before I proceed, I just want to give you a quick overview of the wavelet scattering framework and how it works. Typically, you have three layers. You start with a signal. For the first layer, you average the signal, that is, you pass it through a low-pass filter, and you get a set of coefficients. You downsample those, and those are your scattering coefficients.

    Now, since you have applied a scaling filter, or smoothing filter, to your signal, you've lost all the fine details. You capture those details by applying wavelet filters to the signal. So on the right, you see, basically, a continuous wavelet transform-like operation. You take those scalogram coefficients and convolve them with the scaling filter, which again smooths them out. And then, you get the layer 2 features.

    So you repeat this process for the third layer, and you get another set of features here. Now, the good thing is that the features you see on your left are all put together, and those are the features that you get out of your wavelet scattering network. That's the output of your network. It's just one function, and I'm going to show you an example of how you can use this functionality.
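
    As a rough sketch of the three layers just described, the standard scattering formulation can be written as below. The symbols are assumptions introduced here and do not appear in the video: x is the signal, phi is the low-pass scaling filter, psi with index lambda is a wavelet filter, and * denotes convolution. The speaker's layers 1, 2, and 3 correspond to S_0, S_1, and S_2.

```latex
\begin{aligned}
S_0 x &= x * \phi \\
U_1 x(\lambda_1) &= \left| x * \psi_{\lambda_1} \right|,
  & S_1 x(\lambda_1) &= U_1 x(\lambda_1) * \phi \\
U_2 x(\lambda_1,\lambda_2) &= \left| U_1 x(\lambda_1) * \psi_{\lambda_2} \right|,
  & S_2 x(\lambda_1,\lambda_2) &= U_2 x(\lambda_1,\lambda_2) * \phi
\end{aligned}
```

    Here the U terms play the role of the scalogram coefficients and the S terms play the role of the scattering coefficients, obtained by smoothing the U terms with the scaling filter.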

    Now, there are benefits of using this network in, let's say, ECG classification or, as a matter of fact, any classification problem. On the left, I tried to train an LSTM on my raw ECG data. On the right, I added the wavelet scattering network, extracted the features using this network, and trained the LSTM network on those features.

    So as you can see, on the left, my network had a hard time learning. It was training for about 10 minutes and still didn't go anywhere. You can see that the training accuracy hits 100%, and then it goes to 0%. After some time, this pattern repeats, so we have spent 10 minutes and haven't gotten anywhere yet.

    Now, in the case of the wavelet scattering network, in just a matter of 20 seconds, I was able to train this network very well. I also got good testing accuracy, which I will show in the example. So the main idea here is that it is always useful to think about how to reduce your signal dimensionality while preserving the information. That can be a key aspect when you're working with LSTM networks.

    OK, so now, let's jump into MATLAB and look at the example in action. I have this example here, which is ECG classification using wavelet scattering and LSTM. I have the ECG signals here, with all three classes that I mentioned. Overall, I have 113 training records and 49 test records, so I've already separated the data out in that manner.

    So if I plot the first few thousand samples of some randomly selected records, this is what my data looks like. You can see that I have normal sinus rhythm, congestive heart failure, and arrhythmia here, right? I'm just showing what the signals look like.

    So the first step is to extract features automatically from the signal. To use the wavelet scattering network, you can just use this command here, waveletScattering. You provide the signal length as an input, and all your signals have to be this length. The signal length is the only required parameter; every other parameter here is optional. What you get back is a filter bank.

    Now, you take the filter bank, you apply it to your signal using the featureMatrix function, and you get a set of features. When I execute this, you will notice that my sample signal that I used earlier has roughly 65,000 samples, but the features that were extracted automatically form a 499-by-8 matrix, so roughly 4,000 features. This represents roughly a 94% reduction in size compared to the original signal.
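
    The two calls described above can be sketched roughly as follows. This is a minimal sketch, not the code from the video: it assumes Wavelet Toolbox is installed, uses a random signal as a stand-in for one ECG record, and assumes a 128 Hz sampling frequency. The size of the resulting matrix depends on the filter bank parameters, so it is printed and not hard-coded.

```matlab
% Minimal sketch; requires Wavelet Toolbox.
% The random signal is a placeholder for one ECG record, and the
% 128 Hz sampling frequency is an assumption, not stated in the video.
N = 2^16;                 % 65,536 samples, close to the length quoted above
x = randn(N,1);           % placeholder signal (not real ECG data)

% Build the scattering filter bank; all other parameters keep their defaults.
sf = waveletScattering('SignalLength',N,'SamplingFrequency',128);

% Apply the filter bank: rows are scattering paths, columns are time windows.
feat = featureMatrix(sf,x);

fprintf('Feature matrix size: %d-by-%d\n',size(feat,1),size(feat,2));
fprintf('Size reduction: %.1f%%\n',100*(1 - numel(feat)/N));
```

    The arithmetic behind the reduction quoted in the video is 499 x 8 = 3,992 values against roughly 65,000 samples.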

    Now, the good thing is we can actually look at those features. Let me go back to the PowerPoint slide here. You can look at the features in every layer and visualize them, to see if you can gain some insight into how the spectral components evolve as a function of time.

    That is something you can do with this function called scatteringTransform. You take the same filter bank, provide the signal, and you get the scalogram coefficients, in U, and the scattering coefficients. The scalogram and scattering coefficients are related: you take the scalogram coefficients, apply the scaling filter, or low-pass filter, and you get the scattering coefficients.
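
    A hedged sketch of that call is shown below. It assumes Wavelet Toolbox and again uses a random placeholder signal. The output names S and U follow the description above; the loop that counts paths per order is an illustration added here, not something shown in the video.

```matlab
% Minimal sketch; requires Wavelet Toolbox. Random data stands in for an ECG record.
N = 2^16;
x = randn(N,1);
sf = waveletScattering('SignalLength',N);

% S holds the scattering coefficients, U the scalogram coefficients.
% Each is a cell array with one table per scattering order.
[S,U] = scatteringTransform(sf,x);

for k = 1:numel(S)
    % Each table row is one scattering path; 'signals' stores its coefficients.
    fprintf('Order %d: %d paths, %d coefficients per path\n', ...
        k-1,height(S{k}),numel(S{k}.signals{1}));
end
```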

    So here is one example of what the level 2 scattering coefficients look like for a signal in class I. And here is what a signal that belongs to class II looks like at the second layer. Notice that these start looking a little different, or at least some differences start showing up.

    We are not using just level 2 here, though. We are taking all the level 1, level 2, and level 3 data, meaning all the features that we've got, which is 499 by 8. We take that whole matrix as is, and then we train the LSTM network on it.

    So now, I have some code here that loops through the whole data set and extracts the features for the 113 training signals and the 49 test signals. With that, I have the training and the test features ready, OK?
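
    The feature extraction loop is not shown in the transcript, but it might look something like the sketch below. Everything here is an assumption for illustration: the data matrix X is random, the record count and signal length are shortened to keep the sketch fast, and the variable names are hypothetical.

```matlab
% Minimal sketch; requires Wavelet Toolbox. All data and names are hypothetical.
N = 2^12;                 % shortened signal length to keep the sketch fast
numRecords = 6;           % stand-in for the 113 training or 49 test records
X = randn(N,numRecords);  % one placeholder record per column

sf = waveletScattering('SignalLength',N);

% Store one paths-by-time-windows feature matrix per record in a cell array,
% a layout that sequence networks such as LSTMs commonly accept.
features = cell(numRecords,1);
for k = 1:numRecords
    features{k} = featureMatrix(sf,X(:,k));
end

fprintf('Extracted %d feature matrices of size %d-by-%d\n', ...
    numRecords,size(features{1},1),size(features{1},2));
```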

    So now, since I have my features ready, the next step is to take an LSTM and train it on the scattering features. In this case, I have built a very small LSTM network. As you can see, there are only five layers. The input size is going to be 499, because that is the number of rows I have in the feature matrix for every signal.

    So what I'm going to do is create this small LSTM network and set some training options. It's not trivial to pick the right options; I had to go through some trial and error. But once I figured them out, I was ready to train my network with the scattering features that I extracted for all the signals. For each signal, as you can see here, I have 499 by 8. This is the reduced dimension of my signal.
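
    A five-layer LSTM network of the kind described might be set up as in the sketch below. It assumes Deep Learning Toolbox. The number of hidden units and every training option are illustrative guesses, not the values used in the video, and random matrices stand in for the 499-by-8 scattering features. It uses trainNetwork, which matches the era of the video; newer MATLAB releases also offer trainnet.

```matlab
% Minimal sketch; requires Deep Learning Toolbox.
% numHiddenUnits and all training options are assumptions for illustration.
inputSize      = 499;   % rows of each feature matrix (scattering paths)
numClasses     = 3;     % ARR, CHF, NSR
numHiddenUnits = 64;    % assumed value

layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,'OutputMode','last')   % one label per sequence
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

options = trainingOptions('adam', ...
    'MaxEpochs',5, ...
    'MiniBatchSize',8, ...
    'Shuffle','every-epoch', ...
    'Verbose',false);

% Placeholder training set: 12 random 499-by-8 sequences with made-up labels.
XTrain = arrayfun(@(k) randn(inputSize,8),(1:12)','UniformOutput',false);
YTrain = categorical(repmat(["ARR";"CHF";"NSR"],4,1));

net = trainNetwork(XTrain,YTrain,layers,options);
```

    With random inputs the trained network is meaningless; the sketch only shows how the pieces fit together.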

    So one question you may ask is, how do you know if this 499-by-8 matrix actually has all the information contained in the 65,000-sample vector? The answer is, I don't know. Unless I train a network and test it out, I wouldn't know if I have all the right features here.

    So what I'm going to do first is train the network here. As you can see, my LSTM network is fully trained in roughly 12 seconds. And I was able to do that right on a CPU, in my case.

    The next step is to evaluate the LSTM model. To do that, I extract all the features from the test signals and then use the model that I trained earlier, and it yields about 96% accuracy. Just to compare, I also have some code here for feeding the raw data directly to an LSTM. You can take this code, run it, and see what you get. You'll get the same plot that I showed in my slides.
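
    The accuracy figure is simply the fraction of test records whose predicted label matches the true label. The sketch below shows that computation on made-up labels. In practice the predictions would come from the trained network, for example via classify(net,testFeatures), where the variable names are hypothetical. The labels here are invented and do not reproduce the 96% result.

```matlab
% Minimal sketch using base MATLAB only. Labels are invented for illustration.
trueLabels = categorical(["ARR";"ARR";"CHF";"NSR";"NSR";"ARR"]);
predLabels = categorical(["ARR";"CHF";"CHF";"NSR";"NSR";"ARR"]);

% Overall accuracy: share of records where prediction and truth agree.
accuracy = 100*mean(predLabels == trueLabels);
fprintf('Test accuracy: %.1f%%\n',accuracy);

% Per-class breakdown, useful because the classes are imbalanced.
classes = categories(trueLabels);
for k = 1:numel(classes)
    idx = trueLabels == classes{k};
    fprintf('%s: %d of %d correct\n',classes{k}, ...
        sum(predLabels(idx) == classes{k}),sum(idx));
end
```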

    So the idea here is, once you have trained the model, it looks like the model is 96% accurate, which is probably good enough for some people. If it's not, or if you want to increase the classification accuracy, there are some optional parameters of the scattering filter bank that you can look up in the documentation. You can fine-tune those parameters to achieve the desired results, right? One parameter to look up is the invariance scale. You can tweak that and see if you get the results you're looking for.
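
    As a hedged illustration of that tuning step, the sketch below builds two filter banks that differ only in the InvarianceScale option and compares the size of the resulting feature matrices. It assumes Wavelet Toolbox. The 128 Hz sampling frequency and the 150-second scale are assumptions chosen for illustration, and the signal is random.

```matlab
% Minimal sketch; requires Wavelet Toolbox.
% Fs and the 150 s invariance scale are assumptions chosen for illustration.
N  = 2^16;
Fs = 128;
x  = randn(N,1);          % placeholder signal (not real ECG data)

sfDefault = waveletScattering('SignalLength',N,'SamplingFrequency',Fs);
sfShort   = waveletScattering('SignalLength',N,'SamplingFrequency',Fs, ...
    'InvarianceScale',150);   % in seconds, because Fs is specified

featDefault = featureMatrix(sfDefault,x);
featShort   = featureMatrix(sfShort,x);

fprintf('Default scale (%g s): %d-by-%d features\n', ...
    sfDefault.InvarianceScale,size(featDefault,1),size(featDefault,2));
fprintf('Shorter scale (%g s): %d-by-%d features\n', ...
    sfShort.InvarianceScale,size(featShort,1),size(featShort,2));
```

    A shorter invariance scale averages over shorter windows, so it generally gives more time steps per signal at the cost of less smoothing. Whether that helps accuracy has to be checked by retraining and testing.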

    So that concludes this example. So in summary, we've used the wavelet time scattering network and LSTM network to classify ECG waveforms. So I hope you enjoyed this video. And thank you very much, and I will see you in another video. Bye bye.
