4. Long Short-Term Memory (LSTM) | Gated Recurrent Unit (GRU)
Long Short-Term Memory (LSTM)
Vanila Model์ Gradient Exploding/Vanishing ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ Long Term Dependency ๋ฌธ์ ๋ฅผ ๊ฐ์ ํ ๋ฌธ์ ์ด๋ค.
๊ธฐ์กด์ RNN ์์ ๋ค์๊ณผ ๊ฐ๋ค.
htโ=fwโ(xtโ,ย htโ1โ)
LSTM์๋ Cell state๋ผ๋ ๊ฐ์ด ์ถ๊ฐ๋๋ฉฐ ์์ ๋ค์๊ณผ ๊ฐ๋ค.
{Ctโ,ย htโ}=LSTM(xtโ,ย Ctโ1โ,ย htโ1โ)
Cell state๊ฐ hidden state๋ณด๋ค ์ข ๋ ์์ฑ๋, ํ์๋ก ํ๋ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ ๋ฒกํฐ์ด๋ฉฐ ์ด cell state ๋ฒกํฐ๋ฅผ ํ๋ฒ ๋ ๊ฐ๊ณตํด์ ํด๋น time step์์ ๋
ธ์ถํ ํ์๊ฐ ์๋ ์ ๋ณด๋ฅผ ํํฐ๋ง ํ ์ ์๋ ๋ฒกํฐ๋ก๋ ์๊ฐํ ์ ์๋ค.
cell state๋ฅผ ํ๋ฒ ๋ ๊ฐ๊ณตํ hidden state ๋ฒกํฐ๋ ํ์ฌ timestep ์์ ์์ธก๊ฐ์ ๊ณ์ฐํ๋ output layer์ ์
๋ ฅ๋ฒกํฐ๋ก ์ฌ์ฉํ๋ค.
์ฌ๊ธฐ์ x๋ x_t ์ด๊ณ h๋ h_(t-1) ์ด๋ค.
Forget gate
์ด์ ํ์์คํ
์์ ์ป์ ์ ๋ณด ์ค ์ผ๋ถ๋ง์ ๋ฐ์ํ๊ฒ ๋ค.
= ์ด์ ํ์์คํ
์์ ์ป์ ์ ๋ณด ์ผ๋ถ๋ฅผ ๊น๋จน๊ฒ ๋ค = forget
์ด๋ฒ ์
์์ ์ป์ C tilda ๊ฐ์ input gate์ ๊ณฑํด์ฃผ๋ ์ด์ ๋ ๋ค์๊ณผ ๊ฐ๋ค.
ํ๋ฒ์ ์ ํ๋ณํ๋ง์ผ๋ก Ctโ1โ์ ๋ํด์ฃผ๋ ์ ๋ณด๋ฅผ ๋ง๋ค๊ธฐ๊ฐ ์ด๋ ต๋ค. ๋ฐ๋ผ์ ์ด ๋ํด์ฃผ๋ ์ ๋ณด๋ฅผ ์ผ๋จ ํฌ๊ฒ ๋ง๋ ํ์ ๊ฐ ์ฐจ์๋ณ๋ก ํน์ ๋น์จ๋งํผ ๋์ด๋ด์ ๋ํด์ฃผ๋ ์ ๋ณด๋ฅผ ๋ง๋ค๊ฒ ๋ค ๋ผ๋ ๋ชฉ์ ์ด๋ค.
์ด ๋, ๋ํด์ฃผ๋ ์ ๋ณด๋ณด๋ค ํฌ๊ฒ ๋ง๋ ์ ๋ณด๊ฐ C tilda ์ด๋ฉฐ ํน์ ๋น์จ๋งํผ ๋์ด๋ด๋ ์์
์ด input gate์ ๊ณฑํด์ฃผ๋ ์์
์ด๋ค.
Output gate
"He said, 'I love you.' " ๋ผ๋ ๋ฌธ์ฅ์ด ์๋ค๊ณ ํ์. ํ์ฌ sequence๊ฐ love y ๊น์ง ๋ค์ด์๊ณ y๋ค์์ o๋ฅผ ์ถ๋ ฅ์ผ๋ก ์ค์ผ ํ ์ฐจ๋ก์ด๋ค. ์ด ๋ y์ ์
์ฅ์์๋ ๋น์ฅ์ ์์ ๋ฐ์ดํ๊ฐ ์ด๋ฆฐ ์ฌ์ค์ ์ค์ํ์ง ์์ง๋ง, ๊ณ์ ์ ๋ฌํด์ค์ผํ๋ ์ ๋ณด์ด๋ค. ๊ทธ๋์ Ct์ activate function์ ๊ฑฐ์น๊ฐ์ o_t์ ๊ณฑํด์ฃผ๋ ๊ฒ์ผ๋ก ํด์ํ ์ ์๋ค.
Gated Recurrent Unit (GRU)
LSTM์ ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๊ฒฝ๋ํํด์ ์ ์ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋๊ณผ ๋น ๋ฅธ ๊ณ์ฐ์ด ๊ฐ๋ฅํ๋๋ก ๋ง๋ ๋ชจ๋ธ์ด๋ค. ๊ฐ์ฅ ํฐ ํน์ง์ LSTM์ Cell๊ณผ Hidden์ด ์๋ ๋ฐ๋ฉด์ GRU์์๋ Hidden๋ง ์กด์ฌํ๋ค๋ ๊ฒ์ด๋ค. ๊ทธ๋ฌ๋ GRU์ ๋์์๋ฆฌ๋ LSTM๊ณผ ๊ต์ฅํ ๋์ผํ๋ค.
LSTM์ Cell์ ์ญํ ์ GRU์์๋ Hidden์ด ํด์ฃผ๊ณ ์๋ค๊ณ ๋ณด๋ฉด๋๋ค.
GRU ์์๋ Input Gate๋ง์ ์ฌ์ฉํ๋ฉฐ Forget Gate ์๋ฆฌ์๋ 1 - Input Gate ๊ฐ์ ์ฌ์ฉํ๋ค.
์ค์ต
ํ์ ํจํค์ง import
from tqdm import tqdm
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch
๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
vocab_size = 100
pad_id = 0
data = [
[85,14,80,34,99,20,31,65,53,86,3,58,30,4,11,6,50,71,74,13],
[62,76,79,66,32],
[93,77,16,67,46,74,24,70],
[19,83,88,22,57,40,75,82,4,46],
[70,28,30,24,76,84,92,76,77,51,7,20,82,94,57],
[58,13,40,61,88,18,92,89,8,14,61,67,49,59,45,12,47,5],
[22,5,21,84,39,6,9,84,36,59,32,30,69,70,82,56,1],
[94,21,79,24,3,86],
[80,80,33,63,34,63],
[87,32,79,65,2,96,43,80,85,20,41,52,95,50,35,96,24,80]
]
max_len = len(max(data, key=len))
valid_lens = []
for i, seq in enumerate(tqdm(data)):
valid_lens.append(len(seq))
if len(seq) < max_len:
data[i] = seq + [pad_id] * (max_len - len(seq))
# B: batch size, L: maximum sequence length
batch = torch.LongTensor(data) # (B, L)
batch_lens = torch.LongTensor(valid_lens) # (B)
batch_lens, sorted_idx = batch_lens.sort(descending=True)
batch = batch[sorted_idx]
LSTM ์ฌ์ฉ
LSTM์ Cell state๊ฐ ์ถ๊ฐ๋๋ค. shape๋ hidden state์ ๋์ผํ๋ค.
embedding_size = 256
hidden_size = 512
num_layers = 1
num_dirs = 1
embedding = nn.Embedding(vocab_size, embedding_size)
lstm = nn.LSTM(
input_size=embedding_size,
hidden_size=hidden_size,
num_layers=num_layers,
bidirectional=True if num_dirs > 1 else False
)
h_0 = torch.zeros((num_layers * num_dirs, batch.shape[0], hidden_size)) # (num_layers * num_dirs, B, d_h)
c_0 = torch.zeros((num_layers * num_dirs, batch.shape[0], hidden_size)) # (num_layers * num_dirs, B, d_h)
hidden state์ cell state๋ 0์ผ๋ก ์ด๊ธฐํํ๋ค.
# d_w: word embedding size
batch_emb = embedding(batch) # (B, L, d_w)
packed_batch = pack_padded_sequence(batch_emb.transpose(0, 1), batch_lens)
packed_outputs, (h_n, c_n) = lstm(packed_batch, (h_0, c_0))
print(packed_outputs)
print(packed_outputs[0].shape)
print(h_n.shape)
print(c_n.shape)
PackedSequence(data=tensor([[-0.0690, 0.1176, -0.0184, ..., -0.0339, -0.0347, 0.1103],
[-0.1626, 0.0038, 0.0090, ..., -0.1385, -0.0806, 0.0635],
[-0.0977, 0.1470, -0.0678, ..., 0.0203, 0.0201, 0.0175],
...,
[-0.1911, -0.1925, -0.0827, ..., 0.0491, 0.0302, -0.0149],
[ 0.0803, -0.0229, -0.0772, ..., -0.0706, -0.1711, -0.2128],
[ 0.1861, -0.1572, -0.1024, ..., -0.0090, -0.2621, -0.2803]],
grad_fn=<CatBackward>), batch_sizes=tensor([10, 10, 10, 10, 10, 9, 7, 7, 6, 6, 5, 5, 5, 5, 5, 4, 4, 3,
1, 1]), sorted_indices=None, unsorted_indices=None)
torch.Size([123, 512])
torch.Size([1, 10, 512])
torch.Size([1, 10, 512])
hidden state์ cell state์ ํฌ๊ธฐ๊ฐ ๊ฐ์๊ฒ์ ๋ณผ ์ ์๋ค.
packed_outputs ์ ์ฌ์ด์ฆ๊ฐ 123์ธ ์ด์ ๋ฅผ ์๋๊ฐ? ์ฌ์ค์ 200์ด์ด์ผ ํ๋ค. ์ฌ๊ธฐ์ 0์ ๊ฐ์๋ฅผ ๋นผ๋ฉด 123์ด๋๋ค!
outputs, output_lens = pad_packed_sequence(packed_outputs)
print(outputs.shape)
print(output_lens)
torch.Size([20, 10, 512])
tensor([20, 18, 18, 17, 15, 10, 8, 6, 6, 5])
GPU ์ฌ์ฉ
GPU๋ Cell state๊ฐ ์๋ค. ๊ทธ ์ธ์๋ ๋์ผํ๋ค.
gru = nn.GRU(
input_size=embedding_size,
hidden_size=hidden_size,
num_layers=num_layers,
bidirectional=True if num_dirs > 1 else False
)
output_layer = nn.Linear(hidden_size, vocab_size)
input_id = batch.transpose(0, 1)[0, :] # (B)
hidden = torch.zeros((num_layers * num_dirs, batch.shape[0], hidden_size)) # (1, B, d_h)
Teacher forcing ์์ด ์ด์ ์ ์ป์ ๊ฒฐ๊ณผ๋ฅผ ๋ค์ input์ผ๋ก ์ด์ฉํ๋ค.
t-1๋ฒ์งธ์ ๋์ฝ๋ ์
์ด ์์ธกํ ๊ฐ์ t๋ฒ์งธ ๋์ฝ๋์ ์
๋ ฅ์ผ๋ก ๋ฃ์ด์ค๋ค. t-1๋ฒ์งธ์์ ์ ํํ ์์ธก์ด ์ด๋ฃจ์ด์ง๋ค๋ฉด ์์ฒญ๋ ์ฅ์ ์ ๊ฐ์ง๋ ๊ตฌ์กฐ์ง๋ง, ์๋ชป๋ ์์ธก ์์์๋ ์์ฒญ๋ ๋จ์ ์ด ๋์ด๋ฒ๋ฆฐ๋ค.
๋ค์์ ๋จ์ ์ด ๋์ด๋ฒ๋ฆฐ RNN์ ์๋ชป๋ ์์ธก์ด ์ ํ๋ ๊ฒฝ์ฐ
์ด๋ฌํ ๋จ์ ์ ํ์ต ์ด๊ธฐ์ ํ์ต ์๋ ์ ํ์ ์์ธ์ด ๋๋ฉฐ ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋์จ ๊ธฐ๋ฒ์ด ํฐ์ณํฌ์ฑ์ด๋ค.
์์ ๊ฐ์ด ์
๋ ฅ์ Ground Truth๋ก ๋ฃ์ด์ฃผ๊ฒ ๋๋ฉด, ํ์ต์ ๋ ์ ํํ ์์ธก์ด ๊ฐ๋ฅํ๊ฒ ๋์ด ์ด๊ธฐ ํ์ต ์๋๋ฅผ ๋น ๋ฅด๊ฒ ์ฌ๋ฆด ์ ์๋ค.
๊ทธ๋ฌ๋ ๋จ์ ์ผ๋ก๋ ๋
ธ์ถ ํธํฅ ๋ฌธ์ ๊ฐ ์๋ค. ์ถ๋ก ๊ณผ์ ์์๋ Ground Truth๋ฅผ ์ ๊ณตํ ์ ์๊ธฐ ๋๋ฌธ์ ํ์ต๊ณผ ์ถ๋ก ๋จ๊ณ์์์ ์ฐจ์ด๊ฐ ์กด์ฌํ๊ฒ ๋๊ณ ์ด๋ ๋ชจ๋ธ์ ์ฑ๋ฅ๊ณผ ์์ ์ฑ์ ๋จ์ด๋จ๋ฆด ์ ์๋ค.
๋ค๋ง ๋
ธ์ถ ํธํฅ ๋ฌธ์ ๊ฐ ์๊ฐ๋งํผ ํฐ ์ํฅ์ ๋ฏธ์น์ง ์๋๋ค๋ ์ฐ๊ตฌ๊ฒฐ๊ณผ๊ฐ ์๋ค.
(T. He, J. Zhang, Z. Zhou, and J. Glass. Quantifying Exposure Bias for Neural Language Generation (2019), arXiv.)
for t in range(max_len):
input_emb = embedding(input_id).unsqueeze(0) # (1, B, d_w)
output, hidden = gru(input_emb, hidden) # output: (1, B, d_h), hidden: (1, B, d_h)
# V: vocab size
output = output_layer(output) # (1, B, V)
probs, top_id = torch.max(output, dim=-1) # probs: (1, B), top_id: (1, B)
print("*" * 50)
print(f"Time step: {t}")
print(output.shape)
print(probs.shape)
print(top_id.shape)
input_id = top_id.squeeze(0) # (B)
์๋ฐฉํฅ ๋ฐ ์ฌ๋ฌ layer ์ฌ์ฉ
num_layers = 2
num_dirs = 2
dropout=0.1
gru = nn.GRU(
input_size=embedding_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=True if num_dirs > 1 else False
)
์ฌ๊ธฐ์๋ 2๊ฐ์ ๋ ์ด์ด ๋ฐ ์๋ฐฉํฅ์ ์ฌ์ฉํ๋ค. ๊ทธ๋์ hidden state์ ํฌ๊ธฐ๋ (4, Batchsize, hidden dimension) ์ด ๋๋ค.
# d_w: word embedding size, num_layers: layer์ ๊ฐ์, num_dirs: ๋ฐฉํฅ์ ๊ฐ์
batch_emb = embedding(batch) # (B, L, d_w)
h_0 = torch.zeros((num_layers * num_dirs, batch.shape[0], hidden_size)) # (num_layers * num_dirs, B, d_h) = (4, B, d_h)
packed_batch = pack_padded_sequence(batch_emb.transpose(0, 1), batch_lens)
packed_outputs, h_n = gru(packed_batch, h_0)
print(packed_outputs)
print(packed_outputs[0].shape)
print(h_n.shape)
PackedSequence(data=tensor([[-0.0214, -0.0892, 0.0404, ..., -0.2017, 0.0148, 0.1133],
[-0.1170, 0.0341, 0.0420, ..., -0.1387, 0.1696, 0.2475],
[-0.1272, -0.1075, 0.0054, ..., -0.0152, -0.0856, -0.0097],
...,
[ 0.2953, 0.1022, -0.0146, ..., 0.0467, -0.0049, -0.1354],
[ 0.1570, -0.1757, -0.1698, ..., 0.0369, -0.0073, 0.0044],
[ 0.0541, 0.1023, -0.1941, ..., 0.0117, 0.0276, 0.0636]],
grad_fn=<CatBackward>), batch_sizes=tensor([10, 10, 10, 10, 10, 9, 7, 7, 6, 6, 5, 5, 5, 5, 5, 4, 4, 3,
1, 1]), sorted_indices=None, unsorted_indices=None)
torch.Size([123, 1024])
torch.Size([4, 10, 512])
์ค์ ๋ก ํ๋ ์คํ
์ดํธ์ ํฌ๊ธฐ๊ฐ 4๋ก ์์ํ๋ ๊ฒ์ ์ ์ ์๋ค. ๋ํ, packed_outputs ์ญ์ 256๊ฐ๊ฐ ์๋๋ผ 1024๊ฐ์ ์ฐจ์์ผ๋ก ์ด๋ฃจ์ด์ง ๊ฒ์ ์ ์ ์๋ค.
outputs, output_lens = pad_packed_sequence(packed_outputs)
print(outputs.shape) # (L, B, num_dirs*d_h)
print(output_lens)
torch.Size([20, 10, 1024])
tensor([20, 18, 18, 17, 15, 10, 8, 6, 6, 5])