(04๊ฐ•) LSTM and GRU

210907

4. Long Short-Term Memory (LSTM) | Gated Recurrent Unit (GRU)

Long Short-Term Memory (LSTM)

Vanila Model์˜ Gradient Exploding/Vanishing ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ณ  Long Term Dependency ๋ฌธ์ œ๋ฅผ ๊ฐœ์„ ํ•œ ๋ฌธ์ œ์ด๋‹ค.

๊ธฐ์กด์˜ RNN ์‹์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

ht=fw(xt,ย htโˆ’1) h_t = f_w(x_t,\ h_{t-1})

LSTM์—๋Š” Cell state๋ผ๋Š” ๊ฐ’์ด ์ถ”๊ฐ€๋˜๋ฉฐ ์‹์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

{Ct,ย ht}=LSTM(xt,ย Ctโˆ’1,ย htโˆ’1) \{C_t,\ h_t\} = LSTM(x_t,\ C_{t-1},\ h_{t-1})

Cell state๊ฐ€ hidden state๋ณด๋‹ค ์ข€ ๋” ์™„์„ฑ๋œ, ํ•„์š”๋กœ ํ•˜๋Š” ์ •๋ณด๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋Š” ๋ฒกํ„ฐ์ด๋ฉฐ ์ด cell state ๋ฒกํ„ฐ๋ฅผ ํ•œ๋ฒˆ ๋” ๊ฐ€๊ณตํ•ด์„œ ํ•ด๋‹น time step์—์„œ ๋…ธ์ถœํ•  ํ•„์š”๊ฐ€ ์žˆ๋Š” ์ •๋ณด๋ฅผ ํ•„ํ„ฐ๋ง ํ•  ์ˆ˜ ์žˆ๋Š” ๋ฒกํ„ฐ๋กœ๋„ ์ƒ๊ฐํ•  ์ˆ˜ ์žˆ๋‹ค.

cell state๋ฅผ ํ•œ๋ฒˆ ๋” ๊ฐ€๊ณตํ•œ hidden state ๋ฒกํ„ฐ๋Š” ํ˜„์žฌ timestep ์—์„œ ์˜ˆ์ธก๊ฐ’์„ ๊ณ„์‚ฐํ•˜๋Š” output layer์˜ ์ž…๋ ฅ๋ฒกํ„ฐ๋กœ ์‚ฌ์šฉํ•œ๋‹ค.

  • ์—ฌ๊ธฐ์„œ x๋Š” x_t ์ด๊ณ  h๋Š” h_(t-1) ์ด๋‹ค.

Forget gate

์ด์ „ ํƒ€์ž„์Šคํ…์—์„œ ์–ป์€ ์ •๋ณด ์ค‘ ์ผ๋ถ€๋งŒ์„ ๋ฐ˜์˜ํ•˜๊ฒ ๋‹ค.

= ์ด์ „ ํƒ€์ž„์Šคํ…์—์„œ ์–ป์€ ์ •๋ณด ์ผ๋ถ€๋ฅผ ๊นŒ๋จน๊ฒ ๋‹ค = forget

Input gate

์ด๋ฒˆ ์…€์—์„œ ์–ป์€ C tilda ๊ฐ’์„ input gate์™€ ๊ณฑํ•ด์ฃผ๋Š” ์ด์œ ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

  • ํ•œ๋ฒˆ์˜ ์„ ํ˜•๋ณ€ํ™˜๋งŒ์œผ๋กœ Ctโˆ’1 C_{t-1} ์— ๋”ํ•ด์ฃผ๋Š” ์ •๋ณด๋ฅผ ๋งŒ๋“ค๊ธฐ๊ฐ€ ์–ด๋ ต๋‹ค. ๋”ฐ๋ผ์„œ ์ด ๋”ํ•ด์ฃผ๋Š” ์ •๋ณด๋ฅผ ์ผ๋‹จ ํฌ๊ฒŒ ๋งŒ๋“  ํ›„์— ๊ฐ ์ฐจ์›๋ณ„๋กœ ํŠน์ • ๋น„์œจ๋งŒํผ ๋œ์–ด๋‚ด์„œ ๋”ํ•ด์ฃผ๋Š” ์ •๋ณด๋ฅผ ๋งŒ๋“ค๊ฒ ๋‹ค ๋ผ๋Š” ๋ชฉ์ ์ด๋‹ค.

  • ์ด ๋•Œ, ๋”ํ•ด์ฃผ๋Š” ์ •๋ณด๋ณด๋‹ค ํฌ๊ฒŒ ๋งŒ๋“  ์ •๋ณด๊ฐ€ C tilda ์ด๋ฉฐ ํŠน์ • ๋น„์œจ๋งŒํผ ๋œ์–ด๋‚ด๋Š” ์ž‘์—…์ด input gate์™€ ๊ณฑํ•ด์ฃผ๋Š” ์ž‘์—…์ด๋‹ค.

Output gate

  • "He said, 'I love you.' " ๋ผ๋Š” ๋ฌธ์žฅ์ด ์žˆ๋‹ค๊ณ  ํ•˜์ž. ํ˜„์žฌ sequence๊ฐ€ love y ๊นŒ์ง€ ๋“ค์–ด์™”๊ณ  y๋‹ค์Œ์— o๋ฅผ ์ถœ๋ ฅ์œผ๋กœ ์ค˜์•ผ ํ•  ์ฐจ๋ก€์ด๋‹ค. ์ด ๋•Œ y์˜ ์ž…์žฅ์—์„œ๋Š” ๋‹น์žฅ์˜ ์ž‘์€ ๋”ฐ์˜ดํ‘œ๊ฐ€ ์—ด๋ฆฐ ์‚ฌ์‹ค์€ ์ค‘์š”ํ•˜์ง€ ์•Š์ง€๋งŒ, ๊ณ„์† ์ „๋‹ฌํ•ด์ค˜์•ผํ•˜๋Š” ์ •๋ณด์ด๋‹ค. ๊ทธ๋ž˜์„œ Ct์˜ activate function์„ ๊ฑฐ์นœ๊ฐ’์„ o_t์— ๊ณฑํ•ด์ฃผ๋Š” ๊ฒƒ์œผ๋กœ ํ•ด์„ํ•  ์ˆ˜ ์žˆ๋‹ค.

Gated Recurrent Unit (GRU)

LSTM์˜ ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๊ฒฝ๋Ÿ‰ํ™”ํ•ด์„œ ์ ์€ ๋ฉ”๋ชจ๋ฆฌ ์š”๊ตฌ๋Ÿ‰๊ณผ ๋น ๋ฅธ ๊ณ„์‚ฐ์ด ๊ฐ€๋Šฅํ•˜๋„๋ก ๋งŒ๋“  ๋ชจ๋ธ์ด๋‹ค. ๊ฐ€์žฅ ํฐ ํŠน์ง•์€ LSTM์€ Cell๊ณผ Hidden์ด ์žˆ๋Š” ๋ฐ˜๋ฉด์— GRU์—์„œ๋Š” Hidden๋งŒ ์กด์žฌํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ GRU์˜ ๋™์ž‘์›๋ฆฌ๋Š” LSTM๊ณผ ๊ต‰์žฅํžˆ ๋™์ผํ•˜๋‹ค.

  • LSTM์˜ Cell์˜ ์—ญํ• ์„ GRU์—์„œ๋Š” Hidden์ด ํ•ด์ฃผ๊ณ  ์žˆ๋‹ค๊ณ  ๋ณด๋ฉด๋œ๋‹ค.

  • GRU ์—์„œ๋Š” Input Gate๋งŒ์„ ์‚ฌ์šฉํ•˜๋ฉฐ Forget Gate ์ž๋ฆฌ์—๋Š” 1 - Input Gate ๊ฐ’์„ ์‚ฌ์šฉํ•œ๋‹ค.

์‹ค์Šต

ํ•„์š” ํŒจํ‚ค์ง€ import

๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

LSTM ์‚ฌ์šฉ

LSTM์€ Cell state๊ฐ€ ์ถ”๊ฐ€๋œ๋‹ค. shape๋Š” hidden state์™€ ๋™์ผํ•˜๋‹ค.

  • hidden state์™€ cell state๋Š” 0์œผ๋กœ ์ดˆ๊ธฐํ™”ํ•œ๋‹ค.

  • hidden state์™€ cell state์˜ ํฌ๊ธฐ๊ฐ€ ๊ฐ™์€๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค.

  • packed_outputs ์˜ ์‚ฌ์ด์ฆˆ๊ฐ€ 123์ธ ์ด์œ ๋ฅผ ์•„๋Š”๊ฐ€? ์‚ฌ์‹ค์€ 200์ด์–ด์•ผ ํ•œ๋‹ค. ์—ฌ๊ธฐ์„œ 0์˜ ๊ฐœ์ˆ˜๋ฅผ ๋นผ๋ฉด 123์ด๋œ๋‹ค!

GPU ์‚ฌ์šฉ

GPU๋Š” Cell state๊ฐ€ ์—†๋‹ค. ๊ทธ ์™ธ์—๋Š” ๋™์ผํ•˜๋‹ค.

Teacher forcing ์—†์ด ์ด์ „์— ์–ป์€ ๊ฒฐ๊ณผ๋ฅผ ๋‹ค์Œ input์œผ๋กœ ์ด์šฉํ•œ๋‹ค.

  • Teacher forcing์ด๋ž€, Seq2seq(Encoder-Decoder)๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•œ ๋ชจ๋ธ๋“ค์—์„œ ๋งŽ์ด ์‚ฌ์šฉ๋˜๋Š” ๊ธฐ๋ฒ•์ด๋‹ค. ์•„๋ž˜ ์„ค๋ช…๊ณผ ์ด๋ฏธ์ง€๋Š” ์—ฌ๊ธฐ๋ฅผ ์ฐธ๊ณ ํ–ˆ๋‹ค.

  • t-1๋ฒˆ์งธ์˜ ๋””์ฝ”๋” ์…€์ด ์˜ˆ์ธกํ•œ ๊ฐ’์„ t๋ฒˆ์งธ ๋””์ฝ”๋”์˜ ์ž…๋ ฅ์œผ๋กœ ๋„ฃ์–ด์ค€๋‹ค. t-1๋ฒˆ์งธ์—์„œ ์ •ํ™•ํ•œ ์˜ˆ์ธก์ด ์ด๋ฃจ์–ด์ง„๋‹ค๋ฉด ์—„์ฒญ๋‚œ ์žฅ์ ์„ ๊ฐ€์ง€๋Š” ๊ตฌ์กฐ์ง€๋งŒ, ์ž˜๋ชป๋œ ์˜ˆ์ธก ์•ž์—์„œ๋Š” ์—„์ฒญ๋‚œ ๋‹จ์ ์ด ๋˜์–ด๋ฒ„๋ฆฐ๋‹ค.

  • ๋‹ค์Œ์€ ๋‹จ์ ์ด ๋˜์–ด๋ฒ„๋ฆฐ RNN์˜ ์ž˜๋ชป๋œ ์˜ˆ์ธก์ด ์„ ํ–‰๋œ ๊ฒฝ์šฐ

  • ์ด๋Ÿฌํ•œ ๋‹จ์ ์€ ํ•™์Šต ์ดˆ๊ธฐ์— ํ•™์Šต ์†๋„ ์ €ํ•˜์˜ ์š”์ธ์ด ๋˜๋ฉฐ ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ๋‚˜์˜จ ๊ธฐ๋ฒ•์ด ํ‹ฐ์ณํฌ์‹ฑ์ด๋‹ค.

  • ์œ„์™€ ๊ฐ™์ด ์ž…๋ ฅ์„ Ground Truth๋กœ ๋„ฃ์–ด์ฃผ๊ฒŒ ๋˜๋ฉด, ํ•™์Šต์‹œ ๋” ์ •ํ™•ํ•œ ์˜ˆ์ธก์ด ๊ฐ€๋Šฅํ•˜๊ฒŒ ๋˜์–ด ์ดˆ๊ธฐ ํ•™์Šต ์†๋„๋ฅผ ๋น ๋ฅด๊ฒŒ ์˜ฌ๋ฆด ์ˆ˜ ์žˆ๋‹ค.

  • ๊ทธ๋Ÿฌ๋‚˜ ๋‹จ์ ์œผ๋กœ๋Š” ๋…ธ์ถœ ํŽธํ–ฅ ๋ฌธ์ œ๊ฐ€ ์žˆ๋‹ค. ์ถ”๋ก  ๊ณผ์ •์—์„œ๋Š” Ground Truth๋ฅผ ์ œ๊ณตํ•  ์ˆ˜ ์—†๊ธฐ ๋•Œ๋ฌธ์— ํ•™์Šต๊ณผ ์ถ”๋ก  ๋‹จ๊ณ„์—์„œ์˜ ์ฐจ์ด๊ฐ€ ์กด์žฌํ•˜๊ฒŒ ๋˜๊ณ  ์ด๋Š” ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ๊ณผ ์•ˆ์ •์„ฑ์„ ๋–จ์–ด๋œจ๋ฆด ์ˆ˜ ์žˆ๋‹ค.

  • ๋‹ค๋งŒ ๋…ธ์ถœ ํŽธํ–ฅ ๋ฌธ์ œ๊ฐ€ ์ƒ๊ฐ๋งŒํผ ํฐ ์˜ํ–ฅ์„ ๋ฏธ์น˜์ง€ ์•Š๋Š”๋‹ค๋Š” ์—ฐ๊ตฌ๊ฒฐ๊ณผ๊ฐ€ ์žˆ๋‹ค.

(T. He, J. Zhang, Z. Zhou, and J. Glass. Quantifying Exposure Bias for Neural Language Generation (2019), arXiv.)

์–‘๋ฐฉํ–ฅ ๋ฐ ์—ฌ๋Ÿฌ layer ์‚ฌ์šฉ

  • ์—ฌ๊ธฐ์„œ๋Š” 2๊ฐœ์˜ ๋ ˆ์ด์–ด ๋ฐ ์–‘๋ฐฉํ–ฅ์„ ์‚ฌ์šฉํ•œ๋‹ค. ๊ทธ๋ž˜์„œ hidden state์˜ ํฌ๊ธฐ๋„ (4, Batchsize, hidden dimension) ์ด ๋œ๋‹ค.

  • ์‹ค์ œ๋กœ ํžˆ๋“  ์Šคํ…Œ์ดํŠธ์˜ ํฌ๊ธฐ๊ฐ€ 4๋กœ ์‹œ์ž‘ํ•˜๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค. ๋˜ํ•œ, packed_outputs ์—ญ์‹œ 256๊ฐœ๊ฐ€ ์•„๋‹ˆ๋ผ 1024๊ฐœ์˜ ์ฐจ์›์œผ๋กœ ์ด๋ฃจ์–ด์ง„ ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค.

Last updated

Was this helpful?