(05๊ฐ) Sequence to Sequence with Attention
210908
Last updated
Was this helpful?
210908
Last updated
Was this helpful?
์์ ๋ฐฐ์ด RNN์ ๊ตฌ์กฐ ์ค Many to Many์ ํด๋นํ๋ ๋ชจ๋ธ์ด๋ค. ๋ณดํต ์ ๋ ฅ์ word ๋จ์์ ๋ฌธ์ฅ์ด๊ณ ์ถ๋ ฅ๋ ๋์ผํ๋ค.
์ด ๋, ์ ๋ ฅ ๋ฌธ์ฅ์ ๋ฐ๋ ๋ชจ๋ธ์ ์ธ์ฝ๋๋ผ๊ณ ํ๊ณ ํ๋ํ๋ ๋ต์ ๋ด๋๋ ๋ถ๋ถ์ ๋์ฝ๋๋ผ๊ณ ํ๋ค. ์ธ์ฝ๋์ ๋์ฝ๋๋ ์๋ก ๋ค๋ฅธ RNN ๋ชจ๋ธ์ด๋ค. ๊ทธ๋์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ณต์ ํ๊ฑฐ๋ ํ์ง ์๋๋ค. (์ธ์ฝ๋์ ๋์ฝ๋ ๊ฐ๊ฐ์ ๋ด๋ถ์ ์ผ๋ก ๊ณต์ ํ๋ค)
๋ํ, ๋ด๋ถ ๊ตฌ์กฐ๋ฅผ ์์ธํ ๋ณด๋ฉด LSTM์ ์ฑ์ฉํ ๊ฒ์ ์ ์ ์๋ค. ์ธ์ฝ๋์ ๋ง์ง๋ง ๋จ์ด๊น์ง ์ฝ์ ํ ์์ฑ๋๋ ๋ง์ง๋ง ์คํ ์ Hidden state๋ ๋์ฝ๋์ h0๋ก์์ ์ญํ ์ ํ๋ค. ์ด hidden state๋ ์ ๋ ฅ์ ๋ํ ์ ๋ณด๋ฅผ ์ ๊ฐ์ง๊ณ ์๋ค๊ณ ๋ณผ ์ ์๊ณ ์ด๋ฅผ ๋ฐํ์ผ๋ก ๋์ฝ๋์์ ์ฌ์ฉํ๋ค๊ณ ๋ณผ ์ ์๋ค.
<Start> ํ ํฐ ๋๋ <SoS> (Start of Sentence) ํ ํฐ์ด ์ ๋ ฅ๋๋ฉด์ ๋์ฝ๋๊ฐ ์๋๋๊ธฐ ์์ํ๋ฉฐ <End> ํ ํฐ ๋๋ <EoS> (End of Sentence) ํ ํฐ์ด ๋์ฌ ๋ ๊น์ง ๋์ฝ๋ RNN์ ๊ตฌ๋ํ๋ค.
Hidden state์ ํฌ๊ธฐ๋ ์ฒ์์ ๊ณ ์ ํ๊ธฐ ๋๋ฌธ์ ์๋ฌด๋ฆฌ ์งง์ ๋ฌธ์ฅ์ด๋ผ๋ hidden dimension๋งํผ์ ์ ๋ณด๋ฅผ ์ ์ฅํด์ผ ํ๊ณ , ์๋ฌด๋ฆฌ ๊ธด ๋ฌธ์ฅ์ด๋ผ๋ hidden dimnesion ๋งํผ์ผ๋ก ์ ๋ณด๋ฅผ ์์ถํด์ผ ํ๋ค.
๋, LSTM์ด Long Term Dependency๋ฅผ ํด๊ฒฐํ๋ค๊ณ ํ๋๋ผ๋ ํจ์ฌ ์ด์ ์ ๋ํ๋ ์ ๋ณด๋ ๋ณ์ง๋๊ฑฐ๋ ์์ค๋๋ค. ๊ทธ๋์ ๋ฌธ์ฅ์ด ๊ธธ๋ค๋ณด๋ฉด ์ฒซ๋ฒ์งธ ๋จ์ด์ ๋ํ ์ ๋ณด๊ฐ ์ ๊ธฐ ๋๋ฌธ์ ๋์ฝ๋์ ์์๋ถํฐ ํ์ง์ด ๋๋น ์ง๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ค. ์ด์ ๋ํ ํ ํฌ๋์ผ๋ก "I go home" ์ผ๋ก ์ ๋ ฅํ๋ ๊ฒ์ด ์๋ "home go I"๋ก ์ ๋ ฅํด์ ๋ฌธ์ฅ์ ์ด๋ฐ ์ ๋ณด๋ฅผ ์ ์ ์งํ ์ ์๋๋ก ํ๋ค.
๋์ฝ๋๋ ์ธ์ฝ๋์์ ๋ง์ง๋ง์ผ๋ก ๋์จ hIdden state๋ฅผ h0์ผ๋ก ์ฌ์ฉํ์ง๋ง ์ด๊ฒ๋ง์ ์ฌ์ฉํ์ง ์๋๋ค. ์ธ์ฝ๋์ ๊ฐ time step์์ ๋์จ hidden state๋ฅผ ๋ชจ๋ ์ ๊ณต๋ฐ๊ณ ์ด ์ค ์ ๋ณ์ ์ผ๋ก ์ฌ์ฉํด์ ์์ธก์ ๋์์ ์ฃผ๋ ํํ๋ก ํ์ฉํ๋ค. ์ด๊ฒ์ด attention ๋ชจ๋์ ๊ธฐ๋ณธ์ ์ธ ์์ด๋์ด์ด๋ค.
hidden state๊ฐ 4๊ฐ์ ์ฐจ์์ผ๋ก ๊ตฌ์ฑ๋์๊ณ ํ๋์ค์ด๋ฅผ ์์ด๋ก ๋ณํํ๋ ๊ณผ์ ์ ์์๋ก ๋ ์ด๋ฏธ์ง์ด๋ค. ๋ค์๊ณผ ๊ฐ์ ์์๋ก ๊ตฌ์ฑ๋๋ค.
์ธ์ฝ๋์์ ์ ๋ ฅ๋ณ๋ก hidden state๊ฐ ์์ฑ๋๋ฉฐ ์ต์ข hidden state๊ฐ ๋์ฝ๋์ ์ ๊ณต๋๋ค.
๋์ฝ๋๋ h0์ <sos> ํ ํฐ์ ๊ฐ์ง๊ณ ์ฒซ๋ฒ์งธ h state๋ฅผ ์์ฑํ๋ค.
์ฒซ๋ฒ์งธ h state๋ ์ธ์ฝ๋์ ๊ฐ๊ฐ์ h state์ ๋ด์ ์ ํ๊ฒ ๋๋ค.
๋ด์ ์ ํ๋ค๋ ๊ฒ์ ์ ์ฌ๋๋ฅผ ๋น๊ตํ๊ฒ ๋ค๋ ์๋ฏธ.
์ดํ, ๊ฐ ์ ์ฌ๋๋ฅผ sofrmaxํ ๊ฐ์ ๊ฐ์ค์น๋ก ์ป๊ฒ๋๋ค.
์ด ๋ attention output ๋ฒกํฐ๋ ๊ฐ์คํ๊ท ๋ ๋ฒกํฐ์ด๋ฉฐ context ๋ฒกํฐ๋ผ๊ณ ๋ ๋ถ๋ฅธ๋ค.
์ดํ ๋์ฝ๋๋ ๋์ฝ๋์ h state์ attention output ์ concat ํ๋ฉฐ ์์ธก๊ฐ์ ๋ฐํํ๊ฒ๋๋ค.
๋ง์ฐฌ๊ฐ์ง๋ก, ๋์ฝ๋์ ๋๋ฒ์งธ step์์๋ ๋์ผํ ๋ฉ์ปค๋์ฆ์ด ์ ์ฉ๋๋ค.
<eos> ํ ํฐ์ด ๋์ฌ๋๊น์ง ์๋๋๋ค.
์ ๋ฆฌํ๋ฉด RNN์ ๋์ฝ๋๋ 1) ๋ค์ ๋จ์ด๋ฅผ ์์ธกํ๊ณ 2) ์ธ์ฝ๋๋ก๋ถํฐ ํ์๋ก ํ๋ ์ ๋ณด๋ฅผ ์ทจ์ฌ์ ํํ๋๋ก, ํ์ต์ด ์งํ๋๋ค. ์ญ์ ํ์ ๊ด์ ์์๋, Attention ๋ฒกํฐ๊ฐ ๋ค์ ์ ํ๋ ์ ์๋๋ก ์ธ์ฝ๋์ hidden state๊ฐ ๊ฐฑ์ ๋๋ค. ์ธ์ฝ๋์ h state๊ฐ ๊ฐฑ์ ๋๋ฏ๋ก ๋น์ฐํ ๋์ฝ๋์ h state๋ ๊ฐฑ์ ๋๋ค.
ํ์ต์ ํ ๋์๋ ๋์ฝ๋์ ๊ฐ ํ์์คํ ์ ์์ธก๊ฐ์ด ๋ฌด์์ด๋ ๊ฐ์ Ground Truth ๊ฐ์ ๋ฃ์ด์ฃผ๊ฒ ๋์ง๋ง ์ถ๋ก ์ ํ ๋์๋ ์ด์ ํ์์คํ ์ ์์ธก๊ฐ์ ๋ค์ ํ์์คํ ์ ์ ๋ ฅ๊ฐ์ผ๋ก ๋ฃ์ด์ฃผ๊ฒ ๋๋ค.
์ด๋ ๊ฒ ํ์ต ์ค์ ์
๋ ฅ์ Ground Truth๋ก ๋ฃ์ด์ฃผ๋ ๋ฐฉ๋ฒ์ Teacher Forcing
์ด๋ผ๊ณ ํ๋ค.
๋ฌผ๋ก , ํ์ต์ ์ ๋์ง๋ง ์ค์ ๋ก ์ฐ๋ฆฌ๊ฐ ์ ์ฉํด์ผ ํ๋ ๋ฌธ์ ๋ Teacher Forcing
๊ณผ๋ ๊ดด๋ฆฌ๊ฐ ์๋ค. ๊ทธ๋์ ์ด๋ฅผ ์์ด์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ด ๋์๋๋ฐ, ํ์ต ์ด๋ฐ์๋ ๋น ๋ฅธ ํ์ต์ ์ํด์ ์ด๋ฅผ ์ ์ฉํ๋ค๊ฐ, ํ์ต์ด ์ด๋ ์ ๋ ๋๊ณ ๋์๋ ์ ์ฉํ์ง ์๋๋ก ํ๋ ๋ฐฉ๋ฒ๋ ์กด์ฌํ๋ค.
์ด์ ์๋ ์ ์ฌ๋๋ฅผ ๊ตฌํ๊ธฐ ์ํด ๋ด์ ์ ์ฌ์ฉํ๋๋ฐ, ๋ด์ ์ด์ธ์๋ ๋ค์ํ ๋ฐฉ๋ฒ์ผ๋ก attention์ ๊ตฌ์ฑํ๋ ๋ฐฉ๋ฒ์ ์์๋ณด๋๋ก ํ๋ค.
h_t : ๋์ฝ๋์์ ์ฃผ์ด์ง๋ ํ๋ ๋ฒกํฐ
h_s : ์ธ์ฝ๋์์ ๊ฐ ์๋๋ณ๋ก์ ํ๋ ๋ฒกํฐ
๊ทธ๋ฅ ๋ด์ ์ ํ ์๋ ์์ง๋ง generalized dot product
๋ผ๋ attention ๋ฐฉ๋ฒ๋ ์๋ค.
W๋ ๋๊ฐํ๋ ฌ์ ๋ชจ์์ด๋ค. ๊ฐ dimension ๋ณ๋ก ์ ์ฉํ๋ ๊ฐ์ค์น์ ์ญํ ์ ํ๋ค.
๋, concat
ํ๋ ๋ฐฉ๋ฒ์ด ์๋๋ฐ, ์ด์ ์ ๋ด์ ๋ค๊ณผ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ด๋ค. ์ ์ฌ๋๋ฅผ ๋ด์ ์ด ์๋๋ผ ์ ๊ฒฝ๋ง์ ํตํด์ ๊ตฌํ๋ ๋ฐฉ๋ฒ์ด๋ค.
2-layer์ ์ ๊ฒฝ๋ง์ผ๋ก ๊ตฌ์ฑํ ์ ์๋ค.
์ด์ ์ attention์ ํ๋ผ๋ฏธํฐ๊ฐ ํ์์๋ ๋ด์ ์ฐ์ฐ์ ๋ชจ๋์ด์๋๋ฐ, ํ๋ผ๋ฏธํฐ๊ฐ ํ์ํ ํ์ต์ด ๋๋ฉด์ ์ข ๋ ์ต์ ํ ํ ์ ์๊ฒ๋๋ค.
๋์ฝ๋์ ๋งค ์คํ ๋ง๋ค ํน์ ์ ๋ณด๋ฅผ ์ ๊ณตํ๋ฉด์ ์ฑ๋ฅ์ด ๋งค์ฐ ํฅ์๋์๋ค.
attention์ ํ๋ฉด์ ๊ธด ๋ฌธ์ฅ์ ๋ฒ์ญ์ด ์ด๋ ค์ด ์ , bottleneck problem์ ํด๊ฒฐํ๋ค.
์ญ์ ํ ๊ณผ์ ์์ ๋์ฝ๋ ์คํ ๊ณผ ์ธ์ฝ๋ ์คํ ์ ๊ฑฐ์ณ๊ฐ๋ฉด์ ๋งค์ฐ ๊ธด ํ์์คํ ์ ์ง๋๊ฒ๋๊ณ ์ด ๋ gradient ์์ค ๋๋ ์ฆํญ ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์๊ฒ๋๋๋ฐ attention์ ์ฌ์ฉํ๋ฉด์ gradient๊ฐ ์ง์ ์ ์ผ๋ก ์ ๋ฌํ ์ ์๋ ๋ฐฉ๋ฒ์ด ์ถ๊ฐ๋๋ฉด์ gradient๊ฐ ๋ณ์ง์์ด ์ ๋ฌ๋ ์ ์๊ฒ๋์๋ค.
ํฅ๋ฏธ๋ก์ด ํด์๊ฐ๋ฅ์ฑ์ ์ ๊ณตํด์ค๋ค.
attention์ ์กฐ์ฌํด์ h state๊ฐ ๊ฐ ๋จ์ด์ ์ด๋ค ๋ถ๋ถ์ ์ง์คํ๋์ง ๊ด์ฐฐํ ์ ์๊ฒ๋์๋ค.
๋งค ์ค์ต๋ง๋ค ๋์ผํ ๋ถ๋ถ์ด ํต์ฌ ํด๋์ค๋ง ๋ค๋ฃน๋๋ค.
3, 4๊ฐ์ ๋ฑ์ฅํ ์ธ์ฝ๋์ ๋์ผํ๋ค. ๋ค๋ง, ๋ด๋ถ ์ธ์๊ฐ ์ด์ง ๋ฌ๋ผ์ ์ถ๊ฐ๋ ์ฝ๋๊ฐ ์๋ค. ์ฌ๊ธฐ์๋, layer์ ์๊ฐ 2๊ฐ์ด๊ณ ๋ฐฉํฅ๋ ์๋ฐฉํฅ์ด๋ค.
๊ทธ๋์ hidden state์ 3์ฐจ์ ๊ฐ์๊ฐ 1์์ 4๋ก ์ฆ๊ฐํ๋ค.
๋ํ, layer๊ฐ 2๊ฐ์ด๋ฏ๋ก forward_hidden
์ ์ฒซ๋ฒ์งธ layer๋ก, backward_hidden
์ ๋๋ฒ์งธ layer๋ก ์ ํ๊ณ ์ค์ hidden state๋ฅผ ๋ฐํํ ๋๋ ์ด ๋์ cat
ํด์ ๋ฐํํ๋ค.
๋์ฝ๋๋ ์ด์ ๊ณผ ๋์ผํ๋ฏ๋ก ์๋ตํ๋ค.
encoder์ output์ ์ฌ์ฉํ์ง ์๋ ๋ชจ์ต.
๋ํ, decoder์ output์ encoder์ฒ๋ผ ํ๋ฒ์ ๋์ค์ง ์์ผ๋ฏ๋ก for๋ฌธ์ผ๋ก ์๋์ํจ๋ค. ๊ทธ๋์ ์ด๋ฅผ ๋ด์์ฃผ๊ธฐ ์ํ outputs๋ฅผ ์ ์ธํด์ค๋ค.
์ฌ๊ธฐ์ W2์ ํด๋นํ๋ ๋ถ๋ถ์ด ๊ฐ ๋๋ค.