(03๊ฐ) Recurrent Neural Network and Language Modeling
210907
1. Basics of Recurrent Neural Networks (RNNs)

์ด์ time step ์์ ๊ณ์ฐํ ์ ์ ๋ ฅ์ผ๋ก ๋ฐ์์ ํ์ฌ time step์ ๋ฅผ ์ถ๋ ฅ์ผ๋ก ๋ด์ด์ฃผ๋ ๊ตฌ์กฐ์ด๋ค. ์ด ๋ ๋งค time step์์ ๋์ผํ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋์ผํ ๋ชจ๋์ ์ฌ์ฉํ๋ค๋ ๋ป์์ Recurrent ๊ฐ ๋์๋ค.
RNN์ ๊ธฐํธ๋ค์ด ์๋ฏธํ๋ ๋ฐ๋ ๋ค์๊ณผ ๊ฐ๋ค.

ํนํ y๋ h์์ ์์ฑ๋๋ ๊ฐ์ผ๋ก ๋งค time step๋ง๋ค ์์ฑ๋ ์๋ ์๊ณ ๋ง์ง๋ง์๋ง ์์ฑ๋ ์๋ ์๋ค.
machine translation์ ๋งค๋ฒ ์์ฑ๋๊ณ ๋ฌธ์ฅ์ ๊ธ์ ํํ์ ๋ง์ง๋ง์๋ง ์์ฑ๋๋ค.
๋ณดํต, h๋ฅผ ๊ตฌํ ๋ hyper tangent๋ฅผ ์ฌ์ฉํ๋ค.

์ฌ๊ธฐ์ ์ค์ ๋ก๋ W๊ฐ x์ h(t-1)์ ๋ํด์ ๋ ๊ฐ ์๋ค๊ณ ๊ฐ์ ํ ์๋ ์์ง๋ง, ๊ฒฐ๊ตญ์ ๋ด์ ํด์ ๋ํ๋ฏ๋ก ์ค์ ๋ก๋ ํ ๊ฐ์ W๊ฐ ์๋ค๊ณ ์๊ฐํ๊ณ x์ h(t-1)์ ์ธ๋ก๋ก ๊ฒฐํฉํ ์ํ์์ ๊ณฑํ ์๋ ์๋ค.

2. Types of RNNs
One-to-One

์ฌ๋์ ํค๋ฅผ ์ ๋ ฅ ๋ฐ์ ๋ชธ๋ฌด๊ฒ๋ฅผ ์์ธกํ๋ ๋ชจ๋ธ
time step์ด๋ sequence๊ฐ ์๋ ์ผ๋ฐ์ ์ธ ํํ๋ฅผ ๋์ํํ ๋ชจ์ต
One-to-Many

Image Captioning ๊ฐ์ Task์ ๋ง์ด ์ฌ์ฉ๋๋ค.
time step์ผ๋ก ์ด๋ฃจ์ด์ง์ง ์์ ์ ๋ ฅ์ ์ ๊ณตํ๋ฉฐ ์ด๋ฏธ์ง์ ํ์ํ ๋จ์ด๋ค์ ์์ฐจ์ ์ผ๋ก ์์ฑํ๋ค.
์ ๊ทธ๋ฆผ์์๋ ์ ๋ ฅ์ด ์ฒซ๋ฒ์งธ time step์์๋ง ๋ค์ด๊ฐ๋ฏ๋ก ๊ทธ๋ฆผ์ ์ ๋ ๊ฒ ๊ทธ๋ ธ์ง๋ง ์ค์ ๋ก๋ RNN ๋ชจ๋ธ์ ์ด์ฉํ๋ฏ๋ก ๋๋ฒ์งธ time step๋ถํฐ๋ 0์ผ๋ก ์ฑ์์ง ํ๋ ฌ ๋๋ ํ ์๊ฐ ์ ๋ ฅ๋๋ค.
Many-to-One

Sentiment Classification Task์ ์ด์ฉ
๊ฐ ์ ๋ ฅ๋ง๋ค Sequence Words๋ฅผ ๋ฐ๊ณ , ๊ฐ ๋จ์ด๋ฅผ ํตํด ์ ์ฒด ๋ฌธ์ฅ์ ๊ฐ์ ์ ๋ถ์ํ๊ฒ ๋๋ค.
Many-to-Many

Machine Translation์ ์ด์ฉ
๋ง์ง๋ง sequence๊น์ง ์ ๋ ฅ์ ๋ฐ๊ณ ์ด ๋ ๊น์ง๋ ์ถ๋ ฅ์ ๋ด๋์ง ์๋ค๊ฐ ๋ง์ง๋ง step ์์ ์ ๋ ฅ์ ๋ฐ์ ๋ค ์ถ๋ ฅ์ ์ค๋ค.
๋ ๋ค๋ฅธ ๊ตฌ์กฐ๋ก์, ์ ๋ ฅ์ด ์ฃผ์ด์ง ๋ ๋ง๋ค ์ถ๋ ฅ์ด ๋๋ ๋๋ ์ด๊ฐ ์กด์ฌํ์ง ์๋ Task๊ฐ ์กด์ฌํ๋ค.
Video classification on frame level ์ด๋ POS tagging Task์ ์ด์ฉํ๋ค.
๋น๋์ค ๋ถ๋ฅ๋ ๊ฐ ๋น๋์ค์ ์ด๋ฏธ์ง ํ์ฅ ํ์ฅ์ด ์ด๋ค ์๋ฏธ๋ฅผ ๊ฐ๋ ์ง ๋ถ์ํ๋ค. ์๋ฅผ ๋ค๋ฉด ๊ฐ ์ ์ด ์ ์์ด ์ผ์ด๋๋ ์ ์ด๋ค ๋ผ๋๊ฐ. ์ฃผ์ธ๊ณต์ด ๋ฑ์ฅํ์ง ์๋ ์ ์ด๋ค ๋ผ๋๊ฐ. ๋ฑ๋ฑ
3. Character-level Language Model
์ธ์ด ๋ชจ๋ธ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ฃผ์ด์ง ๋ฌธ์์ด์ ์์๋ฅผ ๋ฐํ์ผ๋ก ๋ค์ ๋จ์ด๊ฐ ๋ฌด์์ธ์ง ์์๋ด๋ Task์ด๋ค. ๋ณด๋ค ์ฌํํ ์์ ๋ก์ Character๋ก ๋ค๋ฃฌ๋ค.
์ฒ์์๋ ์ค๋ณต์ ์ ๊ฑฐํด์ ์ฌ์ ์ ๊ตฌ์ถํ๋ค.
์ ์ฒด ๊ธธ์ด๋งํผ์ ์ฐจ์์ ๊ฐ์ง๋ ์ํซ๋ฒกํฐ๋ก ์ํ๋ฒณ์ ๋ํ๋ธ๋ค.

์ด ๋, bias๋ ์ฌ์ฉํ๋ฉฐ h0 ๋ ์๋ฒกํฐ๋ก ์ด๊ธฐํํ๋ค.
Output ๋ฒกํฐ๋ ๋ค์๊ณผ ๊ฐ์ด ๊ตฌํ ์ ์๋ค.
์ดํ, ์ํํธ ๋งฅ์ค๋ฅผ ๊ฑฐ์ณ ๊ฐ์ฅ ํฐ ๊ฐ์ผ๋ก output์ ๊ฒฐ์ ํ๊ฒ ๋๋ฉฐ ํน์ ๋ฌธ์์ 1์ ๋ชฐ์์ค Ground Truth์์ ์ค์ฐจ๋ฅผ ํตํด back propagation์ด ์ด๋ฃจ์ด์ง๊ฒ ๋๋ค.

์ด ๋ inference ํ๋ ๊ณผ์ ์ ๋ค์๊ณผ ๊ฐ๋ค.
์ฒซ๋ฒ์งธ ๊ธ์๋ฅผ ์ฃผ๊ณ ์ป์ output์ ๋๋ฒ์งธ ์ ๋ ฅ์ผ๋ก ์ค์ ํ๋ค. ์ด๋ฅผ ๋ฐ๋ณตํด์ ๋ชจ๋ ๋จ์ด๋ฅผ ์ป๋๋ค.

์ ๊ธ์ ์ ฐ์ต์คํผ์ด์ ํฌ๊ณก ์ค ํ๋์ด๋ค. ์์ธํ ๋ณด๋ฉด ๋จ์ํ ์ํ๋ฒณ ๋ฟ๋ง ์๋๋ผ ๊ณต๋ฐฑ๊ณผ ๋ฌธ์ฅ๋ถํธ๊น์ง๋ ์ด์ด์ง๋ ๊ฒ์ ์ ์ ์๋ค. ์ค์ ๋ก ์ด๋ฌํ ๊ฒ๊น์ง ๊ณ ๋ คํด์ผํ๋ค.

์ฒ์์๋ ์ ํ์ตํ์ง ๋ชปํด ๋ง๋ ๋์ง ์๋ ์ํฐ๋ฆฌ ๋จ์ด๋ค์ ๋ด๋๋ค๊ฐ ํ์ต์ ๊ฑฐ๋ญํ ์๋ก ๋ง์ด ๋๋ ๋ฌธ์ฅ์ด ๋๋ ๊ฒ์ ์ ์ ์๋ค.
RNN์ผ๋ก ๋ ผ๋ฌธ์ ์์ฑํ๊ฑฐ๋, ์ฐ๊ทน ๋๋ณธ ๋๋ ํ๋ก๊ทธ๋๋ฐ ์ฝ๋๊น์ง๋ ์์ฑํ ์ ์๋ค.
BackPropagation through time, BPTT

์ ์ฒ, ์ ๋ง ์ด์์ ๋ฐ์ดํฐ๋ฅผ ํ์ตํ๋ค๋ณด๋ฉด ์ ๋ ฅ์ผ๋ก ์ ๊ณต๋๋ Sequence๊ฐ ๋งค์ฐ ๊ธธ ์ ์๊ณ ์ด์ ๋ฐ๋ผ ๋ชจ๋ output์ ์ข ํฉํด์ Loss ๊ฐ์ ์ป๊ณ BackPropagtaion์ ์งํํด์ผ ํ๋ค. ํ์ค์ ์ผ๋ก ์ด ๊ธธ์ด๊ฐ ๊ธธ์ด์ง๋ฉด ํ๊บผ๋ฒ์ ์ฒ๋ฆฌํ ์ ์๋ ์ ๋ณด๋ ๋ฐ์ดํฐ์ ์์ด ํ์ ๋ ๋ฆฌ์์ค ์์ ๋ด๊ธฐ์ง ๋ชปํ๊ธฐ ๋๋ฌธ์ ์ ํ๋ ๊ธธ์ด์ Sequence ๋ง์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ ์ฑํํ๋ค.

๋ค์ ์ด๋ฏธ์ง๋ Hidden state์ ํน์ ํ dimension์ ๊ด์ฐฐํด์ ํฌ๊ธฐ๊ฐ ์ปค์ง๋ฉด ํธ๋ฅธ์์ผ๋ก ํฌ๊ธฐ๊ฐ ์์๋ก ์์์ง๋ฉด ๋ถ์์์ผ๋ก ํํํ๋ค.

์ด๋ ๊ฒ ํ ๊ฐ์ dimension์ ์ฌ๋ฌ๊ฐ ๊ด์ฐฐํ๋ค๊ฐ ํน์ ์์น์์ ํฅ๋ฏธ๋ก์ด ํจํด์ ๋ณด๊ฒ๋๋ค ํฐ ๋ฐ์ดํ๋ถํฐ ๋ค์ ํฐ ๋ฐ์ดํ๊น์ง ํธ๋ฅธ์์ ์ ์งํ๋ค๊ฐ ๋ฐ์ดํ๊ฐ ๋ซํ ์ดํ๋ถํฐ๋ ๋ถ์์์ ์ ์งํ๊ณ ๋ค์ ๋ฐ์ดํ๋ฅผ ๋ง๋๋ฉด ํธ๋ฅธ์์ผ๋ก ๋ฐ๋๋ ๋ชจ์ต์ ๋ณผ ์ ์๋ค.

์ฆ, ์ด ์ฐจ์์์๋ ํฐ ๋ฐ์ดํ์ ์์๊ณผ ๋์ ๊ธฐ์ตํ๋ ์ฉ๋๋ก ์ฌ์ฉ๋์์์ ์ ์ ์๋ค.
๋, ๋ค์ ์ด๋ฏธ์ง๋ ํ๋ก๊ทธ๋จ ์ฝ๋๋ฅผ ๋ํ๋ธ๋ค.

์ด ์ ์ ํด๋น ๊ตฌ๋ฌธ์ด if๋ฌธ์ด๋ผ๋ ๊ฒ์ ๊ธฐ์ตํ๋ค๋ ๊ฒ์ ์ ์ ์๋ค.
์ฌ์ค ์ด๋ฌํ ํน์ง์ ๋ฐ๋๋ผ RNN์ด ์๋ LSTM์ด๋ GRU๋ฅผ ์ฌ์ฉํ์ ๋์ ๊ฒฐ๊ณผ์ด๋ค. ๊ฐ๋จํ ๊ตฌ์กฐ์ RNN์ ์ ์ ๋ง์ด ์ฌ์ฉํ์ง ์๋๋ฐ, ๋ค์๊ณผ ๊ฐ์ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๊ธฐ ๋๋ฌธ์ด๋ค.

๋งค ํ์์คํ ๋ง๋ค ํ๋ ์คํ ์ดํธ์ ๋์ผํ W๊ฐ ๊ณฑํด์ง๋ค๋ณด๋ ๋ฑ๋น์์ด์ ๊ผด๋ก ๋ํ๋์ง๊ฒ ๋๊ณ ์ฌ๊ธฐ์ ๊ณต๋น๊ฐ 1๋ณด๋ค ์์ผ๋ฉด Vanishing Gradient ๋ฌธ์ ๊ฐ, 1๋ณด๋ค ํฌ๋ฉด Exploding Gradient ๋ฌธ์ ๊ฐ ๋ฐ์ํ๊ฒ ๋๋ค.

์์ ์ซ์๋ backpropagtaion ๋ ๋ W์ gradient๊ฐ์ ๋ํ๋ด๋ฉฐ ์ ์ ์์์ง๊ณ ์๋ค. 0์ ๊ฐ๊น์์ง์๋ก ์ ์๋ฏธํ signal์ ๋ค์ชฝ์ผ๋ก ์ ๋ฌํ ์ ์๊ฒ๋๋ค. ํ์์ 0์ ์๋ฏธํ๊ณ ๊ทธ ์ธ์ ๊ฐ์ 0๋ณด๋ค ํฌ๊ฑฐ๋ ์์ ๊ฐ์ ์๋ฏธํ๋ค. RNN์ ์ฝ๊ฒ gradient๊ฐ 0์ด ๋๋ ๋ฐ๋ฉด์ LSTM์ ๊ฝค ๊ธด ํ์ ์คํ ๊น์ง๋ gradient๊ฐ ์ด์์๋ ๋ชจ์ต์ด๋ค.
Long Term Dependency๋ฅผ ํด๊ฒฐํด์ฃผ๋ ๋ชจ์ต.
์ค์ต
ํ์ ํจํค์ง
๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
ํ์ฌ๋ ๋ฐ์ดํฐ์ ๊ธธ์ด๊ฐ ๋ชจ๋ ๋ค๋ฅธ ๋ชจ์ต
๊ฐ์ฅ ๊ธด ๋ฐ์ดํฐ์ ๊ธธ์ด๋ก ํต์ผํด์ค๋ค.
์ด ๋ ์ด ๊ธธ์ด๋ณด๋ค ์งง์ ๋ฐ์ดํฐ๋
pad_id==0์ผ๋ก ์ฑ์์ค๋ค.
๊ธธ์ด๊ฐ ๋ชจ๋ 20์ผ๋ก ํต์ผ๋ ๋ชจ์ต
๋, ์๋ ๊ธธ์ด๋ฅผ ์๊ธฐ์ํ valid_lens ๋ฅผ ์ ์ธํ๋ค.
๋ฐ์ดํฐ๋ฅผ ํ๋์ batch๋ก ๋ง๋ ๋ค.
๋จ์ํ Tensorํ ํด์ฃผ๋ ๊ณผ์ ์ผ๋ก ์๊ฐํ๋ฉด ๋๋ค.
RNN ์ฌ์ฉ
RNN์ ๋ฃ๊ธฐ์ ์ ์๋ ์๋ฒ ๋ฉ์ ํด์ผํ๋ค.
์ฒ์์ vocab_size๋ 100์ผ๋ก ์ ํด์ฃผ์๋ค.
์ค์ ๋ก ๊ฐ๊ฐ์ ๋ฐ์ดํฐ์ ์์๋ 0๋ถํฐ 99๊น์ง์ ์๋ก ์ด๋ฃจ์ด์ ธ์๋ค.
embedding ์ฐจ์์ 256์ผ๋ก ์์๋ก ์ ํ๋ค.
hidden_laye๋ฅผ ์ ์ํ๋ค.
๋จ์ด๋ 100์ฐจ์->256์ฐจ์ -> 512์ฐจ์ -> 256์ฐจ์->100์ฐจ์์ผ๋ก ๋ณํ๋๋ค.
100์ฐจ์์ ์ํซ ์ธ์ฝ๋ฉ
256์ฐจ์์ ์๋ ์๋ฒ ๋ฉ
512์ฐจ์์ RNN network๋ฅผ ํตํด ๋ณํ
์ด๊ธฐ h0์ 0์ผ๋ก ์ด๊ธฐํ๋๋ค. ํฌ๊ธฐ๋ (1, 10, 512) ์ด๋ค.
์ดํ, RNN์ batch data๋ฅผ ๋ฃ๋๋ค. ๋ ๊ฐ์ง output์ ์ป๋๋ค.
transpose๋ ์ ์นํ๋ ฌ์ด๋ฉฐ ์ธ์๋ก ๋ฐ์ ์ฐจ์๋ผ๋ฆฌ ๋ณ๊ฒฝ์์ผ์ค๋ค.ํ์ฌ๋ 0๊ณผ 1์ด๋ฏ๋ก X_ij ์์ i๊ฐ j๋ก, j๊ฐ i๋ก ๋ฐ๋๋ค.
hidden_states: ๊ฐ time step์ ํด๋นํ๋ hidden state๋ค์ ๋ฌถ์.h_n: ๋ชจ๋ sequence๋ฅผ ๊ฑฐ์น๊ณ ๋์จ ๋ง์ง๋ง hidden state.
RNN ํ์ฉ
๋ง์ง๋ง hidden state๋ฅผ ์ด์ฉํ๋ฉด text classification task์ ์ ์ฉํ ์ ์๋ค.
๊ฐ time step์ ๋ํ hidden state๋ฅผ ์ด์ฉํ๋ฉด token-level์ task๋ฅผ ์ํํ ์๋ ์๋ค.
PackedSequence ์ฌ์ฉ
์ฃผ์ด์ง data์์ ๋ถํ์ํ pad ๊ณ์ฐ์ด ํฌํจ๋๊ธฐ ๋๋ฌธ์ ์ ๋ ฌ์ ํด์ผํ๋ค.
์ ๋ถํ์ํ๋๋ฉด 0์๋ค๊ฐ ์ด๋ค ์๋ฅผ ๊ณฑํด๋ ๋ 0์ด๋ฏ๋ก ์ด ๋ถ๋ถ์ ๊ณ์ฐํ ํ์๊ฐ ์๋ ๊ฒ์ด๋ค.
์ด๊ฒ ๋ฌธ์ ๊ฐ ๋ผ? ๋ฌธ์ ๊ฐ ๋๋ค. ๊ฐ ํ์์คํ ๋ง๋ค ๊ณ์ฐ์ด ํ์ํ๋ฐ ์ด ๋ ๋ง์ 0์ ๊ณ์ฐํ๋ ๊ฒ๋ณด๋ค ์๋ตํ๋ ๊ฒ์ด ๋ ๋น ๋ฅธ ์ฐ์ฐ์ ์ฒ๋ฆฌํ ์ ์๋ค.
์ ๋ ฌ์ ํ๋ฉด ํด๊ฒฐ๋ผ? ํด๊ฒฐ์ด ๋๋ค. ์ ๋ ฌ์ ํ๊ณ ๊ฐ ํ์์คํ ๋ณ๋ก ๋ฌธ์ฅ์ ์ต๋ ๊ธธ์ด๋ฅผ ๊ธฐ์ตํ๊ณ ์์ผ๋ฉด ๋๋ค.
์ด ๋ ์ด ๊ธฐ๋ฅ์
torch.nn.utils.rnn์์ ์ง์ํ๋pack_padded_sequence์pad_packed_sequence๋ฅผ ์ด์ฉํ๋ฉด ๋๋ค.์ฌ๊ธฐ๋ฅผ ๋ณด๋ฉด ์ดํด๊ฐ ์ฝ๋ค


์ฃผ์ด์ง data๋ฅผ ์ ๋ ฌํ๊ณ embeddingํ ๋ค
์ ๋ ฌํ data๋ฅผ ์๋ฒ ๋ฉํ๊ณ PackSequence ๋ชจ์์ผ๋ก ๋ฐ๊ฟ์ rnn์ ์ ๋ ฅํ๋ค.
์ดํ ์ป์ output์ ์๋ outputํํ์ ๋ค๋ฅด๋ฏ๋ก
pad_packed_sequence๋ฅผ ์ด์ฉํ์ฌ ์๋ ํํ๋ก ๋๋๋ ค ์ค๋ค.
Last updated
Was this helpful?