
RNN, LSTM, GRU


RNN

ref : http://cs231n.stanford.edu/slides/2020/lecture_10.pdf

1) sequence data

    - Sequence data is data whose elements are closely related to one another through their order.

    - A plain MLP or CNN cannot handle sequence data well, since it assumes a fixed-size input with no notion of order.

 

2) advantages

    - Can process any length input

    - Computation for step t can (in theory) use information from many steps back

    - Model size doesn't increase for longer input

    - Same weights applied on every timestep, so there is symmetry in how inputs are processed.

 

3) disadvantages

    - Recurrent computation is slow, since timesteps must be processed one after another and cannot be parallelized.

    - In practice, it is difficult to access information from many steps back.

 

4) types of RNN

    - one-to-one, one-to-many, many-to-one, many-to-many (see the diagram in the referenced slides)

ref : http://cs231n.stanford.edu/slides/2020/lecture_10.pdf

 

Structure of RNN

1) Overview

    - A basic feed-forward network simply passes information from the input to the hidden layer and then to the output. A recurrent network's hidden layer, in contrast, receives information both from the current input and from the previous time step's hidden layer.

 

2) Calculating output & gradient

    - forward

$$W_{xh}\: : \: input \rightarrow hidden$$

$$W_{hh}\: : \: hidden \rightarrow hidden \: (recurrent)$$

$$W_{hy}\: : \: hidden \rightarrow output$$

$$h_{t} = \phi_{h}(W_{xh}x_{t} + W_{hh}h_{t-1}+b_{h})$$

$$h_{t} = \phi_{h}(\begin{bmatrix} W_{xh} & W_{hh} \end{bmatrix} \begin{bmatrix} x_{t} \\ h_{t-1} \end{bmatrix})$$

$$y_{t} = \phi_{y}(W_{hy}h_{t} + b_{y})$$
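
A minimal NumPy sketch of one forward step, directly following the equations above (with tanh as the hidden activation and an identity output activation); the sizes and random weights are illustrative assumptions only.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, W_hy, b_h, b_y):
    """One vanilla-RNN step: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h), y_t = W_hy h_t + b_y."""
    h_t = np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)   # new hidden state
    y_t = W_hy @ h_t + b_y                            # output at time t
    return h_t, y_t

# illustrative sizes: D-dim input, H-dim hidden, O-dim output, T timesteps
D, H, O, T = 4, 8, 3, 5
rng = np.random.default_rng(0)
W_xh = rng.normal(size=(H, D))
W_hh = rng.normal(size=(H, H))
W_hy = rng.normal(size=(O, H))
b_h, b_y = np.zeros(H), np.zeros(O)

h = np.zeros(H)                        # h_0
for x_t in rng.normal(size=(T, D)):    # the same weights are reused at every timestep
    h, y_t = rnn_step(x_t, h, W_xh, W_hh, W_hy, b_h, b_y)
```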

    - backward

$$L = \sum_{t=1}^{T}L_{t}$$

$$\frac{\partial L_{t}}{\partial W_{hh}} = \frac{\partial L_{t}}{\partial y_{t}}\frac{\partial y_{t}}{\partial h_{t}}\left(\sum_{k=1}^{t}\frac{\partial h_{t}}{\partial h_{k}}\frac{\partial h_{k}}{\partial W_{hh}}\right)$$

$$\frac{\partial h_{t}}{\partial h_{k}} = \prod_{i=k+1}^{t}\frac{\partial h_{i}}{\partial h_{i-1}}$$

    - The same recurrent weight is multiplied t - k times, once for every step between k and t, as illustrated below.

    - |w| < 1 : vanishing gradient

    - |w| > 1 : exploding gradient
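
A toy scalar illustration of this: the factor from each step is multiplied t - k times, so a recurrent weight slightly below or above 1 in magnitude shrinks or blows up the gradient exponentially (the tanh derivative, ignored here, only shrinks it further).

```python
# scalar stand-in for dh_t/dh_k = product of (t - k) identical factors
for w in (0.9, 1.1):
    grad = 1.0
    for _ in range(50):       # t - k = 50 steps back
        grad *= w
    print(w, grad)            # 0.9 -> ~0.005 (vanishing), 1.1 -> ~117 (exploding)
```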


LSTM

ref : http://colah.github.io/posts/2015-08-Understanding-LSTMs/

1) Overview

$$\begin{pmatrix} i \\ f \\ o \\ g \end{pmatrix} = \begin{pmatrix} \sigma \\ \sigma \\ \sigma \\ tanh \end{pmatrix} W \begin{pmatrix} h_{t-1} \\ x_{t} \end{pmatrix}$$

    - The LSTM was introduced to overcome the vanishing gradient problem.

    - Its basic component is the cell state, which runs along the whole sequence.

    - Four gate values (input i, forget f, output o, and candidate g) control whether information is saved to, kept in, or read from the cell state, as sketched below.
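
A minimal NumPy sketch of how the four gate values in the equation above can be computed from one stacked weight matrix; the shapes and the helper name lstm_gates are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_gates(x_t, h_prev, W, b):
    """Compute i, f, o, g from a single stacked weight matrix W of shape (4H, H + D)."""
    H = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x_t]) + b
    i = sigmoid(z[0*H:1*H])   # input gate: whether to write to the cell
    f = sigmoid(z[1*H:2*H])   # forget gate: whether to erase the cell
    o = sigmoid(z[2*H:3*H])   # output gate: how much of the cell to reveal
    g = np.tanh(z[3*H:4*H])   # candidate values to be written
    return i, f, o, g
```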

 

2) forget gate

    - whether to erase the cell

    - The sigmoid output lies in (0, 1): a value of 1 keeps all of the information, while a value of 0 erases it.

ref : http://colah.github.io/posts/2015-08-Understanding-LSTMs/

 

3) input gate

    - whether to write to the cell

    - i_t : a sigmoid layer decides which values will be updated.

    - C̃_t : a tanh layer builds a vector of candidate values from the same input as i_t.

    - The value written to the cell is the element-wise product of i_t and C̃_t.

ref : http://colah.github.io/posts/2015-08-Understanding-LSTMs/

 

4) cell state update

    - how much to write to the cell

    - Multiplying by f_t erases (forgets) old information.

    - The input gate's contribution i_t ⊙ C̃_t is then added: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t.

ref : http://colah.github.io/posts/2015-08-Understanding-LSTMs/

 

5) output gate

    - how much of the cell to reveal

    - A sigmoid layer decides which parts of the cell state will be output.

    - The updated cell state is passed through tanh, so its range becomes (-1, 1).

    - This is multiplied element-wise with the sigmoid output, so only the information we want to keep is passed on as h_t (see the sketch below).

ref : http://colah.github.io/posts/2015-08-Understanding-LSTMs/
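
Putting the gates together, a sketch of one full LSTM step (reusing sigmoid and lstm_gates from the sketch in the overview); the function name lstm_step is an illustrative assumption.

```python
def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step: c_t = f * c_{t-1} + i * g, then h_t = o * tanh(c_t)."""
    i, f, o, g = lstm_gates(x_t, h_prev, W, b)   # gate values from the sketch above
    c_t = f * c_prev + i * g                     # forget old info, add new candidate info
    h_t = o * np.tanh(c_t)                       # output gate reveals part of the cell
    return h_t, c_t
```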

 

6) Backpropagation

ref : http://cs231n.stanford.edu/slides/2020/lecture_10.pdf

    - The LSTM architecture makes it easier for the network to preserve information over many timesteps, though it does not guarantee that vanishing or exploding gradients are avoided.

    - In contrast, it is much harder for a vanilla RNN to learn weights that preserve information in its hidden state.

    - The gradient flowing backward through the cell state is multiplied element-wise by the forget gate's activations rather than repeatedly by a weight matrix, so suitable parameter updates of the forget gate give better control of the gradient (see the toy illustration below).
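
A toy illustration of the last point: going backward through the cell state, the gradient is only multiplied element-wise by the forget gate activations, so if the network learns to keep f_t close to 1 the gradient survives many steps. The sizes and the constant forget value below are assumptions for illustration.

```python
import numpy as np

H, steps = 8, 50
dc = np.ones(H)              # gradient w.r.t. the cell state at the last timestep
f = np.full(H, 0.98)         # forget gate activations, assumed constant and near 1
for _ in range(steps):
    dc = dc * f              # dC_{t-1} = dC_t * f_t (element-wise, no weight matrix)
print(dc[0])                 # ~0.36 after 50 steps: decays slowly instead of vanishing
```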


GRU

ref : http://colah.github.io/posts/2015-08-Understanding-LSTMs/

1) Overview

    - A simplified version of the LSTM.

    - There is no separate cell state, only a hidden state.

    - The forget and input gates are combined into a single update gate.

    - A reset gate is added.

 

2) update gate (forget & input gates combined)

    - Plays the role of the LSTM's forget gate and input gate at the same time.

    - The cell state and hidden state are merged into a single hidden state h_t.

    - z_t controls both forgetting and inputting: h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t, so the more of the previous state that is kept, the less new information is written (see the sketch below).
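
A minimal sketch of one GRU step in the convention of the referenced colah post (h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t); the weight shapes are assumptions and biases are omitted for brevity.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W_z, W_r, W_h):
    """One GRU step; each W_* has shape (H, H + D)."""
    hx = np.concatenate([h_prev, x_t])
    z = sigmoid(W_z @ hx)                        # update gate (forget + input combined)
    r = sigmoid(W_r @ hx)                        # reset gate
    h_tilde = np.tanh(W_h @ np.concatenate([r * h_prev, x_t]))  # candidate hidden state
    return (1.0 - z) * h_prev + z * h_tilde      # keep old state vs. write new candidate
```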


ref.

heung-bae-lee.github.io/2020/01/12/deep_learning_08/

머신러닝 교과서 with 파이썬 (길벗)

cs231n.stanford.edu/slides/2020/lecture_10.pdf

colah.github.io/posts/2015-08-Understanding-LSTMs/

ratsgo.github.io/natural%20language%20processing/2017/03/09/rnnlstm/
