(03๊ฐ) Recurrent Neural Network and Language Modeling
210907
Last updated
Was this helpful?
210907
Last updated
Was this helpful?
์ด์ time step ์์ ๊ณ์ฐํ ์ ์ ๋ ฅ์ผ๋ก ๋ฐ์์ ํ์ฌ time step์ ๋ฅผ ์ถ๋ ฅ์ผ๋ก ๋ด์ด์ฃผ๋ ๊ตฌ์กฐ์ด๋ค. ์ด ๋ ๋งค time step์์ ๋์ผํ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋์ผํ ๋ชจ๋์ ์ฌ์ฉํ๋ค๋ ๋ป์์ Recurrent ๊ฐ ๋์๋ค.
RNN์ ๊ธฐํธ๋ค์ด ์๋ฏธํ๋ ๋ฐ๋ ๋ค์๊ณผ ๊ฐ๋ค.
ํนํ y๋ h์์ ์์ฑ๋๋ ๊ฐ์ผ๋ก ๋งค time step๋ง๋ค ์์ฑ๋ ์๋ ์๊ณ ๋ง์ง๋ง์๋ง ์์ฑ๋ ์๋ ์๋ค.
machine translation์ ๋งค๋ฒ ์์ฑ๋๊ณ ๋ฌธ์ฅ์ ๊ธ์ ํํ์ ๋ง์ง๋ง์๋ง ์์ฑ๋๋ค.
๋ณดํต, h๋ฅผ ๊ตฌํ ๋ hyper tangent๋ฅผ ์ฌ์ฉํ๋ค.
์ฌ๊ธฐ์ ์ค์ ๋ก๋ W๊ฐ x์ h(t-1)์ ๋ํด์ ๋ ๊ฐ ์๋ค๊ณ ๊ฐ์ ํ ์๋ ์์ง๋ง, ๊ฒฐ๊ตญ์ ๋ด์ ํด์ ๋ํ๋ฏ๋ก ์ค์ ๋ก๋ ํ ๊ฐ์ W๊ฐ ์๋ค๊ณ ์๊ฐํ๊ณ x์ h(t-1)์ ์ธ๋ก๋ก ๊ฒฐํฉํ ์ํ์์ ๊ณฑํ ์๋ ์๋ค.
์ฌ๋์ ํค๋ฅผ ์ ๋ ฅ ๋ฐ์ ๋ชธ๋ฌด๊ฒ๋ฅผ ์์ธกํ๋ ๋ชจ๋ธ
time step์ด๋ sequence๊ฐ ์๋ ์ผ๋ฐ์ ์ธ ํํ๋ฅผ ๋์ํํ ๋ชจ์ต
Image Captioning ๊ฐ์ Task์ ๋ง์ด ์ฌ์ฉ๋๋ค.
time step์ผ๋ก ์ด๋ฃจ์ด์ง์ง ์์ ์ ๋ ฅ์ ์ ๊ณตํ๋ฉฐ ์ด๋ฏธ์ง์ ํ์ํ ๋จ์ด๋ค์ ์์ฐจ์ ์ผ๋ก ์์ฑํ๋ค.
์ ๊ทธ๋ฆผ์์๋ ์ ๋ ฅ์ด ์ฒซ๋ฒ์งธ time step์์๋ง ๋ค์ด๊ฐ๋ฏ๋ก ๊ทธ๋ฆผ์ ์ ๋ ๊ฒ ๊ทธ๋ ธ์ง๋ง ์ค์ ๋ก๋ RNN ๋ชจ๋ธ์ ์ด์ฉํ๋ฏ๋ก ๋๋ฒ์งธ time step๋ถํฐ๋ 0์ผ๋ก ์ฑ์์ง ํ๋ ฌ ๋๋ ํ ์๊ฐ ์ ๋ ฅ๋๋ค.
Sentiment Classification Task์ ์ด์ฉ
๊ฐ ์ ๋ ฅ๋ง๋ค Sequence Words๋ฅผ ๋ฐ๊ณ , ๊ฐ ๋จ์ด๋ฅผ ํตํด ์ ์ฒด ๋ฌธ์ฅ์ ๊ฐ์ ์ ๋ถ์ํ๊ฒ ๋๋ค.
Machine Translation์ ์ด์ฉ
๋ง์ง๋ง sequence๊น์ง ์ ๋ ฅ์ ๋ฐ๊ณ ์ด ๋ ๊น์ง๋ ์ถ๋ ฅ์ ๋ด๋์ง ์๋ค๊ฐ ๋ง์ง๋ง step ์์ ์ ๋ ฅ์ ๋ฐ์ ๋ค ์ถ๋ ฅ์ ์ค๋ค.
๋ ๋ค๋ฅธ ๊ตฌ์กฐ๋ก์, ์ ๋ ฅ์ด ์ฃผ์ด์ง ๋ ๋ง๋ค ์ถ๋ ฅ์ด ๋๋ ๋๋ ์ด๊ฐ ์กด์ฌํ์ง ์๋ Task๊ฐ ์กด์ฌํ๋ค.
Video classification on frame level ์ด๋ POS tagging Task์ ์ด์ฉํ๋ค.
๋น๋์ค ๋ถ๋ฅ๋ ๊ฐ ๋น๋์ค์ ์ด๋ฏธ์ง ํ์ฅ ํ์ฅ์ด ์ด๋ค ์๋ฏธ๋ฅผ ๊ฐ๋ ์ง ๋ถ์ํ๋ค. ์๋ฅผ ๋ค๋ฉด ๊ฐ ์ ์ด ์ ์์ด ์ผ์ด๋๋ ์ ์ด๋ค ๋ผ๋๊ฐ. ์ฃผ์ธ๊ณต์ด ๋ฑ์ฅํ์ง ์๋ ์ ์ด๋ค ๋ผ๋๊ฐ. ๋ฑ๋ฑ
์ธ์ด ๋ชจ๋ธ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ฃผ์ด์ง ๋ฌธ์์ด์ ์์๋ฅผ ๋ฐํ์ผ๋ก ๋ค์ ๋จ์ด๊ฐ ๋ฌด์์ธ์ง ์์๋ด๋ Task์ด๋ค. ๋ณด๋ค ์ฌํํ ์์ ๋ก์ Character๋ก ๋ค๋ฃฌ๋ค.
์ฒ์์๋ ์ค๋ณต์ ์ ๊ฑฐํด์ ์ฌ์ ์ ๊ตฌ์ถํ๋ค.
์ ์ฒด ๊ธธ์ด๋งํผ์ ์ฐจ์์ ๊ฐ์ง๋ ์ํซ๋ฒกํฐ๋ก ์ํ๋ฒณ์ ๋ํ๋ธ๋ค.
์ด ๋, bias๋ ์ฌ์ฉํ๋ฉฐ h0 ๋ ์๋ฒกํฐ๋ก ์ด๊ธฐํํ๋ค.
Output ๋ฒกํฐ๋ ๋ค์๊ณผ ๊ฐ์ด ๊ตฌํ ์ ์๋ค.
์ดํ, ์ํํธ ๋งฅ์ค๋ฅผ ๊ฑฐ์ณ ๊ฐ์ฅ ํฐ ๊ฐ์ผ๋ก output์ ๊ฒฐ์ ํ๊ฒ ๋๋ฉฐ ํน์ ๋ฌธ์์ 1์ ๋ชฐ์์ค Ground Truth์์ ์ค์ฐจ๋ฅผ ํตํด back propagation์ด ์ด๋ฃจ์ด์ง๊ฒ ๋๋ค.
์ด ๋ inference ํ๋ ๊ณผ์ ์ ๋ค์๊ณผ ๊ฐ๋ค.
์ฒซ๋ฒ์งธ ๊ธ์๋ฅผ ์ฃผ๊ณ ์ป์ output์ ๋๋ฒ์งธ ์ ๋ ฅ์ผ๋ก ์ค์ ํ๋ค. ์ด๋ฅผ ๋ฐ๋ณตํด์ ๋ชจ๋ ๋จ์ด๋ฅผ ์ป๋๋ค.
์ ๊ธ์ ์ ฐ์ต์คํผ์ด์ ํฌ๊ณก ์ค ํ๋์ด๋ค. ์์ธํ ๋ณด๋ฉด ๋จ์ํ ์ํ๋ฒณ ๋ฟ๋ง ์๋๋ผ ๊ณต๋ฐฑ๊ณผ ๋ฌธ์ฅ๋ถํธ๊น์ง๋ ์ด์ด์ง๋ ๊ฒ์ ์ ์ ์๋ค. ์ค์ ๋ก ์ด๋ฌํ ๊ฒ๊น์ง ๊ณ ๋ คํด์ผํ๋ค.
์ฒ์์๋ ์ ํ์ตํ์ง ๋ชปํด ๋ง๋ ๋์ง ์๋ ์ํฐ๋ฆฌ ๋จ์ด๋ค์ ๋ด๋๋ค๊ฐ ํ์ต์ ๊ฑฐ๋ญํ ์๋ก ๋ง์ด ๋๋ ๋ฌธ์ฅ์ด ๋๋ ๊ฒ์ ์ ์ ์๋ค.
RNN์ผ๋ก ๋ ผ๋ฌธ์ ์์ฑํ๊ฑฐ๋, ์ฐ๊ทน ๋๋ณธ ๋๋ ํ๋ก๊ทธ๋๋ฐ ์ฝ๋๊น์ง๋ ์์ฑํ ์ ์๋ค.
์ ์ฒ, ์ ๋ง ์ด์์ ๋ฐ์ดํฐ๋ฅผ ํ์ตํ๋ค๋ณด๋ฉด ์ ๋ ฅ์ผ๋ก ์ ๊ณต๋๋ Sequence๊ฐ ๋งค์ฐ ๊ธธ ์ ์๊ณ ์ด์ ๋ฐ๋ผ ๋ชจ๋ output์ ์ข ํฉํด์ Loss ๊ฐ์ ์ป๊ณ BackPropagtaion์ ์งํํด์ผ ํ๋ค. ํ์ค์ ์ผ๋ก ์ด ๊ธธ์ด๊ฐ ๊ธธ์ด์ง๋ฉด ํ๊บผ๋ฒ์ ์ฒ๋ฆฌํ ์ ์๋ ์ ๋ณด๋ ๋ฐ์ดํฐ์ ์์ด ํ์ ๋ ๋ฆฌ์์ค ์์ ๋ด๊ธฐ์ง ๋ชปํ๊ธฐ ๋๋ฌธ์ ์ ํ๋ ๊ธธ์ด์ Sequence ๋ง์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ ์ฑํํ๋ค.
๋ค์ ์ด๋ฏธ์ง๋ Hidden state์ ํน์ ํ dimension์ ๊ด์ฐฐํด์ ํฌ๊ธฐ๊ฐ ์ปค์ง๋ฉด ํธ๋ฅธ์์ผ๋ก ํฌ๊ธฐ๊ฐ ์์๋ก ์์์ง๋ฉด ๋ถ์์์ผ๋ก ํํํ๋ค.
์ด๋ ๊ฒ ํ ๊ฐ์ dimension์ ์ฌ๋ฌ๊ฐ ๊ด์ฐฐํ๋ค๊ฐ ํน์ ์์น์์ ํฅ๋ฏธ๋ก์ด ํจํด์ ๋ณด๊ฒ๋๋ค ํฐ ๋ฐ์ดํ๋ถํฐ ๋ค์ ํฐ ๋ฐ์ดํ๊น์ง ํธ๋ฅธ์์ ์ ์งํ๋ค๊ฐ ๋ฐ์ดํ๊ฐ ๋ซํ ์ดํ๋ถํฐ๋ ๋ถ์์์ ์ ์งํ๊ณ ๋ค์ ๋ฐ์ดํ๋ฅผ ๋ง๋๋ฉด ํธ๋ฅธ์์ผ๋ก ๋ฐ๋๋ ๋ชจ์ต์ ๋ณผ ์ ์๋ค.
์ฆ, ์ด ์ฐจ์์์๋ ํฐ ๋ฐ์ดํ์ ์์๊ณผ ๋์ ๊ธฐ์ตํ๋ ์ฉ๋๋ก ์ฌ์ฉ๋์์์ ์ ์ ์๋ค.
๋, ๋ค์ ์ด๋ฏธ์ง๋ ํ๋ก๊ทธ๋จ ์ฝ๋๋ฅผ ๋ํ๋ธ๋ค.
์ด ์ ์ ํด๋น ๊ตฌ๋ฌธ์ด if๋ฌธ์ด๋ผ๋ ๊ฒ์ ๊ธฐ์ตํ๋ค๋ ๊ฒ์ ์ ์ ์๋ค.
์ฌ์ค ์ด๋ฌํ ํน์ง์ ๋ฐ๋๋ผ RNN์ด ์๋ LSTM์ด๋ GRU๋ฅผ ์ฌ์ฉํ์ ๋์ ๊ฒฐ๊ณผ์ด๋ค. ๊ฐ๋จํ ๊ตฌ์กฐ์ RNN์ ์ ์ ๋ง์ด ์ฌ์ฉํ์ง ์๋๋ฐ, ๋ค์๊ณผ ๊ฐ์ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๊ธฐ ๋๋ฌธ์ด๋ค.
๋งค ํ์์คํ ๋ง๋ค ํ๋ ์คํ ์ดํธ์ ๋์ผํ W๊ฐ ๊ณฑํด์ง๋ค๋ณด๋ ๋ฑ๋น์์ด์ ๊ผด๋ก ๋ํ๋์ง๊ฒ ๋๊ณ ์ฌ๊ธฐ์ ๊ณต๋น๊ฐ 1๋ณด๋ค ์์ผ๋ฉด Vanishing Gradient ๋ฌธ์ ๊ฐ, 1๋ณด๋ค ํฌ๋ฉด Exploding Gradient ๋ฌธ์ ๊ฐ ๋ฐ์ํ๊ฒ ๋๋ค.
์์ ์ซ์๋ backpropagtaion ๋ ๋ W์ gradient๊ฐ์ ๋ํ๋ด๋ฉฐ ์ ์ ์์์ง๊ณ ์๋ค. 0์ ๊ฐ๊น์์ง์๋ก ์ ์๋ฏธํ signal์ ๋ค์ชฝ์ผ๋ก ์ ๋ฌํ ์ ์๊ฒ๋๋ค. ํ์์ 0์ ์๋ฏธํ๊ณ ๊ทธ ์ธ์ ๊ฐ์ 0๋ณด๋ค ํฌ๊ฑฐ๋ ์์ ๊ฐ์ ์๋ฏธํ๋ค. RNN์ ์ฝ๊ฒ gradient๊ฐ 0์ด ๋๋ ๋ฐ๋ฉด์ LSTM์ ๊ฝค ๊ธด ํ์ ์คํ ๊น์ง๋ gradient๊ฐ ์ด์์๋ ๋ชจ์ต์ด๋ค.
Long Term Dependency๋ฅผ ํด๊ฒฐํด์ฃผ๋ ๋ชจ์ต.
ํ์ฌ๋ ๋ฐ์ดํฐ์ ๊ธธ์ด๊ฐ ๋ชจ๋ ๋ค๋ฅธ ๋ชจ์ต
๊ฐ์ฅ ๊ธด ๋ฐ์ดํฐ์ ๊ธธ์ด๋ก ํต์ผํด์ค๋ค.
์ด ๋ ์ด ๊ธธ์ด๋ณด๋ค ์งง์ ๋ฐ์ดํฐ๋ pad_id==0
์ผ๋ก ์ฑ์์ค๋ค.
๊ธธ์ด๊ฐ ๋ชจ๋ 20์ผ๋ก ํต์ผ๋ ๋ชจ์ต
๋, ์๋ ๊ธธ์ด๋ฅผ ์๊ธฐ์ํ valid_lens ๋ฅผ ์ ์ธํ๋ค.
๋ฐ์ดํฐ๋ฅผ ํ๋์ batch๋ก ๋ง๋ ๋ค.
๋จ์ํ Tensorํ ํด์ฃผ๋ ๊ณผ์ ์ผ๋ก ์๊ฐํ๋ฉด ๋๋ค.
RNN์ ๋ฃ๊ธฐ์ ์ ์๋ ์๋ฒ ๋ฉ์ ํด์ผํ๋ค.
์ฒ์์ vocab_size๋ 100์ผ๋ก ์ ํด์ฃผ์๋ค.
์ค์ ๋ก ๊ฐ๊ฐ์ ๋ฐ์ดํฐ์ ์์๋ 0๋ถํฐ 99๊น์ง์ ์๋ก ์ด๋ฃจ์ด์ ธ์๋ค.
embedding ์ฐจ์์ 256์ผ๋ก ์์๋ก ์ ํ๋ค.
hidden_laye๋ฅผ ์ ์ํ๋ค.
๋จ์ด๋ 100์ฐจ์->256์ฐจ์ -> 512์ฐจ์ -> 256์ฐจ์->100์ฐจ์์ผ๋ก ๋ณํ๋๋ค.
100์ฐจ์์ ์ํซ ์ธ์ฝ๋ฉ
256์ฐจ์์ ์๋ ์๋ฒ ๋ฉ
512์ฐจ์์ RNN network๋ฅผ ํตํด ๋ณํ
์ด๊ธฐ h0์ 0์ผ๋ก ์ด๊ธฐํ๋๋ค. ํฌ๊ธฐ๋ (1, 10, 512) ์ด๋ค.
์ดํ, RNN์ batch data๋ฅผ ๋ฃ๋๋ค. ๋ ๊ฐ์ง output์ ์ป๋๋ค.
transpose
๋ ์ ์นํ๋ ฌ์ด๋ฉฐ ์ธ์๋ก ๋ฐ์ ์ฐจ์๋ผ๋ฆฌ ๋ณ๊ฒฝ์์ผ์ค๋ค.
ํ์ฌ๋ 0๊ณผ 1์ด๋ฏ๋ก X_ij ์์ i๊ฐ j๋ก, j๊ฐ i๋ก ๋ฐ๋๋ค.
hidden_states
: ๊ฐ time step์ ํด๋นํ๋ hidden state๋ค์ ๋ฌถ์.
h_n
: ๋ชจ๋ sequence๋ฅผ ๊ฑฐ์น๊ณ ๋์จ ๋ง์ง๋ง hidden state.
๋ง์ง๋ง hidden state๋ฅผ ์ด์ฉํ๋ฉด text classification task์ ์ ์ฉํ ์ ์๋ค.
๊ฐ time step์ ๋ํ hidden state๋ฅผ ์ด์ฉํ๋ฉด token-level์ task๋ฅผ ์ํํ ์๋ ์๋ค.
์ฃผ์ด์ง data์์ ๋ถํ์ํ pad ๊ณ์ฐ์ด ํฌํจ๋๊ธฐ ๋๋ฌธ์ ์ ๋ ฌ์ ํด์ผํ๋ค.
์ ๋ถํ์ํ๋๋ฉด 0์๋ค๊ฐ ์ด๋ค ์๋ฅผ ๊ณฑํด๋ ๋ 0์ด๋ฏ๋ก ์ด ๋ถ๋ถ์ ๊ณ์ฐํ ํ์๊ฐ ์๋ ๊ฒ์ด๋ค.
์ด๊ฒ ๋ฌธ์ ๊ฐ ๋ผ? ๋ฌธ์ ๊ฐ ๋๋ค. ๊ฐ ํ์์คํ ๋ง๋ค ๊ณ์ฐ์ด ํ์ํ๋ฐ ์ด ๋ ๋ง์ 0์ ๊ณ์ฐํ๋ ๊ฒ๋ณด๋ค ์๋ตํ๋ ๊ฒ์ด ๋ ๋น ๋ฅธ ์ฐ์ฐ์ ์ฒ๋ฆฌํ ์ ์๋ค.
์ ๋ ฌ์ ํ๋ฉด ํด๊ฒฐ๋ผ? ํด๊ฒฐ์ด ๋๋ค. ์ ๋ ฌ์ ํ๊ณ ๊ฐ ํ์์คํ ๋ณ๋ก ๋ฌธ์ฅ์ ์ต๋ ๊ธธ์ด๋ฅผ ๊ธฐ์ตํ๊ณ ์์ผ๋ฉด ๋๋ค.
์ด ๋ ์ด ๊ธฐ๋ฅ์ torch.nn.utils.rnn
์์ ์ง์ํ๋ pack_padded_sequence
์ pad_packed_sequence
๋ฅผ ์ด์ฉํ๋ฉด ๋๋ค.
์ฃผ์ด์ง data๋ฅผ ์ ๋ ฌํ๊ณ embeddingํ ๋ค
์ ๋ ฌํ data๋ฅผ ์๋ฒ ๋ฉํ๊ณ PackSequence ๋ชจ์์ผ๋ก ๋ฐ๊ฟ์ rnn์ ์ ๋ ฅํ๋ค.
์ดํ ์ป์ output์ ์๋ outputํํ์ ๋ค๋ฅด๋ฏ๋ก pad_packed_sequence
๋ฅผ ์ด์ฉํ์ฌ ์๋ ํํ๋ก ๋๋๋ ค ์ค๋ค.
๋ฅผ ๋ณด๋ฉด ์ดํด๊ฐ ์ฝ๋ค