Return State and Return Sequences of LSTM in Keras
LSTM looks easy to use, but what it actually returns is hard to keep track of.
Using LSTM in Keras is easy, I mean:
LSTM(units, return_sequences=False, return_state=False).
This Medium post covers only return_sequences and return_state, since those are what cause the confusion.
What does an LSTM look like?
return_sequences=True:
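Do you know what LSTM(dim_number)(input) gives us? It gives us the final hidden state value (ht in the figure above) from the LSTM. Note that dim_number is the number of units, i.e. the length of each hidden state vector. It does not set how many steps the LSTM runs; that comes from the length of the input sequence. So suppose the input has 40 timesteps. The LSTM reads the first input x0 and produces the output y0. y0 is then fed back into the same LSTM cell, together with the next input x1, to produce y1, and so on. Continuing this way, it goes up to y39. So, without return_sequences, the output is only the final value, y39.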
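To make the step-by-step loop concrete, here is a toy LSTM forward pass in plain Python. It is a from-scratch sketch, not Keras's implementation: the helper name `lstm_forward` and the random weight initialisation are hypothetical, but the gate equations follow the standard LSTM cell (sigmoid input, forget and output gates, tanh candidate and output activation). It shows that the number of y values equals the number of timesteps, while the number of units sets the length of each y.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_forward(xs, units, seed=0):
    """Toy LSTM over xs, a list of per-timestep feature vectors (hypothetical helper).

    Returns (all_h, h_T, c_T): the hidden state at every timestep,
    the final hidden state, and the final cell state.
    """
    rng = random.Random(seed)
    n_in = len(xs[0])

    def init():
        # One weight row per unit, over the concatenation [x_t, h_{t-1}, 1 (bias)].
        return [[rng.uniform(-0.5, 0.5) for _ in range(n_in + units + 1)]
                for _ in range(units)]

    W_i, W_f, W_g, W_o = init(), init(), init(), init()

    def affine(W, v):
        return [sum(w * a for w, a in zip(row, v)) for row in W]

    h = [0.0] * units  # hidden state, starts at zero
    c = [0.0] * units  # cell state, starts at zero
    all_h = []
    for x_t in xs:  # one iteration per timestep
        v = list(x_t) + h + [1.0]
        i = [sigmoid(z) for z in affine(W_i, v)]    # input gate
        f = [sigmoid(z) for z in affine(W_f, v)]    # forget gate
        g = [math.tanh(z) for z in affine(W_g, v)]  # candidate cell values
        o = [sigmoid(z) for z in affine(W_o, v)]    # output gate
        c = [f_k * c_k + i_k * g_k for f_k, c_k, i_k, g_k in zip(f, c, i, g)]
        h = [o_k * math.tanh(c_k) for o_k, c_k in zip(o, c)]
        all_h.append(h)  # y_t for this timestep
    return all_h, h, c


rng = random.Random(42)
xs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(40)]  # 40 timesteps, 3 features
all_h, h_T, c_T = lstm_forward(xs, units=8)

print(len(all_h), len(all_h[0]))  # 40 8
print(all_h[-1] == h_T)           # True
```

The default Keras output corresponds to `h_T` here (y39); the other two variables correspond to what the flags discussed below expose.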
What if we set return_sequences=True, as in LSTM(dim_number, return_sequences=True)(input)?
It gives us every y0, y1, y2 and so on up to y39, i.e. the hidden state from each timestep.
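A quick way to see the difference is to compare output shapes. This is a minimal sketch assuming Keras 2.x or 3.x with the TensorFlow backend; the batch of 2 random sequences with 40 timesteps and 3 features, and the 8 units, are arbitrary illustration values.

```python
import numpy as np
import keras

# Toy batch: 2 sequences, 40 timesteps each, 3 features per timestep.
x = np.random.rand(2, 40, 3).astype("float32")

# Default: only the final hidden state (y39) of each sequence.
last_only = keras.layers.LSTM(8)(x)
print(tuple(last_only.shape))  # (2, 8)

# return_sequences=True: one hidden state per timestep (y0 up to y39).
all_steps = keras.layers.LSTM(8, return_sequences=True)(x)
print(tuple(all_steps.shape))  # (2, 40, 8)
```

The two layers here have independent random weights, so only the shapes are comparable, not the values.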
return_state=True:
So what is return_state? Calling LSTM(dim_number, return_state=True)(input) returns 3 values.
The first is the final hidden state. Yes, remember the y39 value? That is the first value returned by an LSTM with return_state=True.
The second is also the final hidden state. Oh, hold on, isn't this the same as the first value returned? Yes, it is. I will explain why both exist right afterwards!
The third value returned is the cell state. In the figure above, the line running from c(t-1) to c(t) across the cell is the cell state. The value returned here is the cell state after the final timestep has been processed.
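Here is a minimal sketch of unpacking the three values, again assuming Keras with the TensorFlow backend and arbitrary toy sizes (2 sequences, 40 timesteps, 3 features, 8 units).

```python
import numpy as np
import keras

x = np.random.rand(2, 40, 3).astype("float32")

# return_state=True: the layer returns three tensors instead of one.
output, state_h, state_c = keras.layers.LSTM(8, return_state=True)(x)

print(tuple(output.shape), tuple(state_h.shape), tuple(state_c.shape))
# (2, 8) (2, 8) (2, 8)

# Without return_sequences, the first two values are the same final hidden state (y39).
print(np.allclose(np.asarray(output), np.asarray(state_h)))  # True
```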
So why does return_state give us the final hidden state twice? Well, the first value is the layer's regular output, and that changes when we ask for the whole sequence. What if we want all the hidden state values and also the cell state (the third value returned by an LSTM with return_state=True)?
We can get the sequence of hidden states, the final hidden state and the cell state by doing something like this:
LSTM(dim_number, return_state=True, return_sequences=True)(input).
The first value returned here is the hidden state at each timestep.
The second value is the hidden state at the final timestep, so it equals the last element of the sequence returned as the first value.
The third value is the cell state, as usual.
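Combining both flags can be checked the same way. This sketch assumes Keras with the TensorFlow backend and the same arbitrary toy sizes; the last line confirms that the second value is just the final timestep of the first.

```python
import numpy as np
import keras

x = np.random.rand(2, 40, 3).astype("float32")

layer = keras.layers.LSTM(8, return_sequences=True, return_state=True)
seq, state_h, state_c = layer(x)

print(tuple(seq.shape))      # (2, 40, 8): hidden state at every timestep
print(tuple(state_h.shape))  # (2, 8): hidden state at the final timestep
print(tuple(state_c.shape))  # (2, 8): cell state after the final timestep

# The final hidden state is the last entry of the full sequence.
print(np.allclose(np.asarray(seq)[:, -1, :], np.asarray(state_h)))  # True
```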