(04๊ฐ) LSTM and GRU
210907
4. Long Short-Term Memory (LSTM) | Gated Recurrent Unit (GRU)
Long Short-Term Memory (LSTM)
Vanila Model์ Gradient Exploding/Vanishing ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ Long Term Dependency ๋ฌธ์ ๋ฅผ ๊ฐ์ ํ ๋ฌธ์ ์ด๋ค.
๊ธฐ์กด์ RNN ์์ ๋ค์๊ณผ ๊ฐ๋ค.
LSTM์๋ Cell state๋ผ๋ ๊ฐ์ด ์ถ๊ฐ๋๋ฉฐ ์์ ๋ค์๊ณผ ๊ฐ๋ค.
Cell state๊ฐ hidden state๋ณด๋ค ์ข ๋ ์์ฑ๋, ํ์๋ก ํ๋ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ ๋ฒกํฐ์ด๋ฉฐ ์ด cell state ๋ฒกํฐ๋ฅผ ํ๋ฒ ๋ ๊ฐ๊ณตํด์ ํด๋น time step์์ ๋ ธ์ถํ ํ์๊ฐ ์๋ ์ ๋ณด๋ฅผ ํํฐ๋ง ํ ์ ์๋ ๋ฒกํฐ๋ก๋ ์๊ฐํ ์ ์๋ค.
cell state๋ฅผ ํ๋ฒ ๋ ๊ฐ๊ณตํ hidden state ๋ฒกํฐ๋ ํ์ฌ timestep ์์ ์์ธก๊ฐ์ ๊ณ์ฐํ๋ output layer์ ์ ๋ ฅ๋ฒกํฐ๋ก ์ฌ์ฉํ๋ค.

์ฌ๊ธฐ์ x๋ x_t ์ด๊ณ h๋ h_(t-1) ์ด๋ค.
Forget gate

์ด์ ํ์์คํ ์์ ์ป์ ์ ๋ณด ์ค ์ผ๋ถ๋ง์ ๋ฐ์ํ๊ฒ ๋ค.
= ์ด์ ํ์์คํ ์์ ์ป์ ์ ๋ณด ์ผ๋ถ๋ฅผ ๊น๋จน๊ฒ ๋ค = forget
Input gate

์ด๋ฒ ์ ์์ ์ป์ C tilda ๊ฐ์ input gate์ ๊ณฑํด์ฃผ๋ ์ด์ ๋ ๋ค์๊ณผ ๊ฐ๋ค.
ํ๋ฒ์ ์ ํ๋ณํ๋ง์ผ๋ก ์ ๋ํด์ฃผ๋ ์ ๋ณด๋ฅผ ๋ง๋ค๊ธฐ๊ฐ ์ด๋ ต๋ค. ๋ฐ๋ผ์ ์ด ๋ํด์ฃผ๋ ์ ๋ณด๋ฅผ ์ผ๋จ ํฌ๊ฒ ๋ง๋ ํ์ ๊ฐ ์ฐจ์๋ณ๋ก ํน์ ๋น์จ๋งํผ ๋์ด๋ด์ ๋ํด์ฃผ๋ ์ ๋ณด๋ฅผ ๋ง๋ค๊ฒ ๋ค ๋ผ๋ ๋ชฉ์ ์ด๋ค.
์ด ๋, ๋ํด์ฃผ๋ ์ ๋ณด๋ณด๋ค ํฌ๊ฒ ๋ง๋ ์ ๋ณด๊ฐ C tilda ์ด๋ฉฐ ํน์ ๋น์จ๋งํผ ๋์ด๋ด๋ ์์ ์ด input gate์ ๊ณฑํด์ฃผ๋ ์์ ์ด๋ค.
Output gate

"He said, 'I love you.' " ๋ผ๋ ๋ฌธ์ฅ์ด ์๋ค๊ณ ํ์. ํ์ฌ sequence๊ฐ love y ๊น์ง ๋ค์ด์๊ณ y๋ค์์ o๋ฅผ ์ถ๋ ฅ์ผ๋ก ์ค์ผ ํ ์ฐจ๋ก์ด๋ค. ์ด ๋ y์ ์ ์ฅ์์๋ ๋น์ฅ์ ์์ ๋ฐ์ดํ๊ฐ ์ด๋ฆฐ ์ฌ์ค์ ์ค์ํ์ง ์์ง๋ง, ๊ณ์ ์ ๋ฌํด์ค์ผํ๋ ์ ๋ณด์ด๋ค. ๊ทธ๋์ Ct์ activate function์ ๊ฑฐ์น๊ฐ์ o_t์ ๊ณฑํด์ฃผ๋ ๊ฒ์ผ๋ก ํด์ํ ์ ์๋ค.
Gated Recurrent Unit (GRU)

LSTM์ ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๊ฒฝ๋ํํด์ ์ ์ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋๊ณผ ๋น ๋ฅธ ๊ณ์ฐ์ด ๊ฐ๋ฅํ๋๋ก ๋ง๋ ๋ชจ๋ธ์ด๋ค. ๊ฐ์ฅ ํฐ ํน์ง์ LSTM์ Cell๊ณผ Hidden์ด ์๋ ๋ฐ๋ฉด์ GRU์์๋ Hidden๋ง ์กด์ฌํ๋ค๋ ๊ฒ์ด๋ค. ๊ทธ๋ฌ๋ GRU์ ๋์์๋ฆฌ๋ LSTM๊ณผ ๊ต์ฅํ ๋์ผํ๋ค.
LSTM์ Cell์ ์ญํ ์ GRU์์๋ Hidden์ด ํด์ฃผ๊ณ ์๋ค๊ณ ๋ณด๋ฉด๋๋ค.
GRU ์์๋ Input Gate๋ง์ ์ฌ์ฉํ๋ฉฐ Forget Gate ์๋ฆฌ์๋ 1 - Input Gate ๊ฐ์ ์ฌ์ฉํ๋ค.
์ค์ต
ํ์ ํจํค์ง import
๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
์ด์ ์ค์ต๊ณผ ๋์ผํ๋ค.
LSTM ์ฌ์ฉ
LSTM์ Cell state๊ฐ ์ถ๊ฐ๋๋ค. shape๋ hidden state์ ๋์ผํ๋ค.
hidden state์ cell state๋ 0์ผ๋ก ์ด๊ธฐํํ๋ค.
hidden state์ cell state์ ํฌ๊ธฐ๊ฐ ๊ฐ์๊ฒ์ ๋ณผ ์ ์๋ค.
packed_outputs ์ ์ฌ์ด์ฆ๊ฐ 123์ธ ์ด์ ๋ฅผ ์๋๊ฐ? ์ฌ์ค์ 200์ด์ด์ผ ํ๋ค. ์ฌ๊ธฐ์ 0์ ๊ฐ์๋ฅผ ๋นผ๋ฉด 123์ด๋๋ค!
GPU ์ฌ์ฉ
GPU๋ Cell state๊ฐ ์๋ค. ๊ทธ ์ธ์๋ ๋์ผํ๋ค.
Teacher forcing ์์ด ์ด์ ์ ์ป์ ๊ฒฐ๊ณผ๋ฅผ ๋ค์ input์ผ๋ก ์ด์ฉํ๋ค.
Teacher forcing์ด๋, Seq2seq(Encoder-Decoder)๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ๋ชจ๋ธ๋ค์์ ๋ง์ด ์ฌ์ฉ๋๋ ๊ธฐ๋ฒ์ด๋ค. ์๋ ์ค๋ช ๊ณผ ์ด๋ฏธ์ง๋ ์ฌ๊ธฐ๋ฅผ ์ฐธ๊ณ ํ๋ค.

t-1๋ฒ์งธ์ ๋์ฝ๋ ์ ์ด ์์ธกํ ๊ฐ์ t๋ฒ์งธ ๋์ฝ๋์ ์ ๋ ฅ์ผ๋ก ๋ฃ์ด์ค๋ค. t-1๋ฒ์งธ์์ ์ ํํ ์์ธก์ด ์ด๋ฃจ์ด์ง๋ค๋ฉด ์์ฒญ๋ ์ฅ์ ์ ๊ฐ์ง๋ ๊ตฌ์กฐ์ง๋ง, ์๋ชป๋ ์์ธก ์์์๋ ์์ฒญ๋ ๋จ์ ์ด ๋์ด๋ฒ๋ฆฐ๋ค.
๋ค์์ ๋จ์ ์ด ๋์ด๋ฒ๋ฆฐ RNN์ ์๋ชป๋ ์์ธก์ด ์ ํ๋ ๊ฒฝ์ฐ

์ด๋ฌํ ๋จ์ ์ ํ์ต ์ด๊ธฐ์ ํ์ต ์๋ ์ ํ์ ์์ธ์ด ๋๋ฉฐ ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋์จ ๊ธฐ๋ฒ์ด ํฐ์ณํฌ์ฑ์ด๋ค.

์์ ๊ฐ์ด ์ ๋ ฅ์ Ground Truth๋ก ๋ฃ์ด์ฃผ๊ฒ ๋๋ฉด, ํ์ต์ ๋ ์ ํํ ์์ธก์ด ๊ฐ๋ฅํ๊ฒ ๋์ด ์ด๊ธฐ ํ์ต ์๋๋ฅผ ๋น ๋ฅด๊ฒ ์ฌ๋ฆด ์ ์๋ค.
๊ทธ๋ฌ๋ ๋จ์ ์ผ๋ก๋ ๋ ธ์ถ ํธํฅ ๋ฌธ์ ๊ฐ ์๋ค. ์ถ๋ก ๊ณผ์ ์์๋ Ground Truth๋ฅผ ์ ๊ณตํ ์ ์๊ธฐ ๋๋ฌธ์ ํ์ต๊ณผ ์ถ๋ก ๋จ๊ณ์์์ ์ฐจ์ด๊ฐ ์กด์ฌํ๊ฒ ๋๊ณ ์ด๋ ๋ชจ๋ธ์ ์ฑ๋ฅ๊ณผ ์์ ์ฑ์ ๋จ์ด๋จ๋ฆด ์ ์๋ค.
๋ค๋ง ๋ ธ์ถ ํธํฅ ๋ฌธ์ ๊ฐ ์๊ฐ๋งํผ ํฐ ์ํฅ์ ๋ฏธ์น์ง ์๋๋ค๋ ์ฐ๊ตฌ๊ฒฐ๊ณผ๊ฐ ์๋ค.
(T. He, J. Zhang, Z. Zhou, and J. Glass. Quantifying Exposure Bias for Neural Language Generation (2019), arXiv.)
์๋ฐฉํฅ ๋ฐ ์ฌ๋ฌ layer ์ฌ์ฉ
์ฌ๊ธฐ์๋ 2๊ฐ์ ๋ ์ด์ด ๋ฐ ์๋ฐฉํฅ์ ์ฌ์ฉํ๋ค. ๊ทธ๋์ hidden state์ ํฌ๊ธฐ๋ (4, Batchsize, hidden dimension) ์ด ๋๋ค.
์ค์ ๋ก ํ๋ ์คํ ์ดํธ์ ํฌ๊ธฐ๊ฐ 4๋ก ์์ํ๋ ๊ฒ์ ์ ์ ์๋ค. ๋ํ, packed_outputs ์ญ์ 256๊ฐ๊ฐ ์๋๋ผ 1024๊ฐ์ ์ฐจ์์ผ๋ก ์ด๋ฃจ์ด์ง ๊ฒ์ ์ ์ ์๋ค.
Last updated
Was this helpful?