Dataset/Loader
์ค๋์ ๋ฏธ์
์ ๋ฐ์ดํฐ์
๊ณผ ๋ฐ์ดํฐ๋ก๋๋ฅผ ๊ตฌํํ๋ ๊ฒ. ์ฌ์ค ํ๊ธฐ๋ฅผ ์์ฑํ๋ ์์ ์์๋ ์ด๋ ค์ด ๋ถ๋ถ์ด ์์ง๋ง, ์ ๋ง๋ก ๋ง์ ์ฒ์ ๊ตฌํํ ๋๋ ๋๋ฌด ๋ง๋งํ๋ค. ๊ทธ๋งํผ ๋ด๊ฐ ์ ๋๋ก ์ดํดํ์ง ๋ชปํ๊ฑฐ๊ฒ ์ง.
์ผ๋จ ์ฒ์์๋ ์บ๊ธ์ ์ฐธ๊ณ ํ๋ค. ์บ๊ธ ์ฝ๋๋ฅผ ๊ทธ๋๋ก ์ด ๊ฒ์ ์๋๊ณ ์ด๋ ํ ํ๋ฆ์ผ๋ก ์จ์ง๋ ๊ตฌ๋๋ฅผ ์ฐธ๊ณ ํ๋ค. ๊ทธ๋ฌ๋ฉด์ ์๊ฒ๋ ๋ถ๋ถ์ CFG ๋ผ๋ ๋์
๋๋ฆฌ ๋ณ์์ ํ์ดํผ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ชจ๋ ์ ์ธํด๋๋ ๋ฐฉ๋ฒ์ด์๋ค. ๋ ์์๋ณด๋ ์ด๋ฅผ ํด๋์ค๋ก ์ ์ธํ๋ ์ฌ๋๋ ์์๋ค.
Config
DATA_DIR = './input/data/train/images/'
CFG = {
'fold_num': 5,
'seed': 719,
'epochs': 30,
'train_bs': 30,
'valid_bs': 60,
'T_0': 10,
'lr': 1e-5,
'max_lr': 1e-3,
'weight_decay':1e-6,
'num_workers': 8,
'accum_iter': 2, # suppoprt to do batch accumulation for backprop with effectively larger batch size
'verbose_step': 1,
'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
}
print(f'{CFG["device"]} is using!')
Train_test_split
์ดํ, Train data์ Valid data๋ก ๋๋์๋ค. ์ด ๋๋ sklearn
์ train_test_split
๋ฅผ ์ฌ์ฉํ๋ค.
data = pd.read_csv('./train_face.csv')
train_df, valid_df = train_test_split(data, test_size=0.35, shuffle=True, stratify=data['label'], random_state=2021)
train_df.shape, valid_df.shape
์ธ์๋ก data ํ๋๋ง ์ฃผ์ด์ง๋ฉด train_data ์ valid_data๋ก ๋๋์ด ์ฃผ๋ฉฐ ์ธ์๋ก data์ label์ด ์ฃผ์ด์ง๋ฉด trainX, trainY, validX, validY๋ก ๋๋์ด์ก๋ค. ๋๋ ์ ์๊ฐ ํ์ํด์ data ์ธ์ ํ๋๋ง ์ฃผ์๋ค.
transform = transforms.Compose([
transforms.CenterCrop((380, 380)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomChoice([transforms.ColorJitter(brightness=(0.2, 3)),
transforms.ColorJitter(contrast=(0.2, 3)),
transforms.ColorJitter(saturation=(0.2, 3)),
transforms.ColorJitter(hue=(-0.3, 0.3))
]),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
์ด๋ฏธ์ง์ ํฌ๊ธฐ๊ฐ 384 * 512์ด๋ค. ์ผ๊ตด์ด ๋๋ถ๋ถ ์ค์์ ์์นํ๋ฏ๋ก CenterCrop์ ์ฌ์ฉํ๋ฉด ๋๋ค๊ณ ์๊ฐ์ด ๋ค์๋ค. CenterCrop์ ์ฌ์ฉํ๋ ์ด์ ๋ ๋ค์๊ณผ ๊ฐ๋ค.
์ด๋ฏธ์ง ์ฌ์ด์ฆ๊ฐ ์์์๋ก GPU์ฌ์ฉ์ด ์ค์ด๋ ๋ค. ์ค์ ๋ก ์๋ณธ ์ด๋ฏธ์ง๋ฅผ ์
๋ ฅํ์ ๋๋ batch size๋ฅผ ๋งค์ฐ ์๊ฒ ํด์ผ๋ง ๋์๊ฐ๋ค.
์ฌ๋์ ์ผ๊ตด ์ ๋ณด๋ง ํ์ํ๋ค๊ณ ์๊ฐํ๋ค. ๊ทธ ์ธ์๋ ๋ฒฝ์ด๋ ์ท๋ฑ์ ๋ฐฐ๊ฒฝ ์ด๋ฏธ์ง๊ฐ ํ์ต์ ์คํ๋ ค ๋ฐฉํด๊ฐ ๋๋ค๊ณ ์๊ฐํ๋ค.
๋ค์ ์ด๋ฏธ์ง๋ฅผ ์ฐธ๊ณ ํ๋ฉด ์ ์ ์๋ฏ์ด b4๊ฐ ํ์ตํ ์ด๋ฏธ์ง์ ํฌ๊ธฐ๋ 380์ด๋ค.
๊ทธ ์ธ์ ๋ง์ Transform์ ํด์ฃผ์ง๋ ์์๋ค. RandomChoice๋ฅผ ์ฌ์ฉํด์ 4๊ฐ์ trsf ๋ฅผ ์์๋ก ์ ์ฉ๋๋๋ก ํด์ฃผ์๊ณ ์ด ์์ ์๋ ๋ณํ์ ๋ฐ๊ธฐ, ์ฑ๋ ๋ฑ์ ํฝ์
๊ฐ ๋ณํ์ด๋ค.
1/2 ํ๋ฅ ๋ก ์ข์ฐ๋ฐ์ ์ด ์ผ์ด๋๋๋ก ํ๋ค.
ToTensor
๋ฅผ ์ด์ฉํด ํ
์๋ก ๋ณํ๋๊ฒ ํ๊ณ ์ ๊ทํ๋ฅผ ํ๋ค. ์ด ๋ ํ๊ท ๊ณผ ํ์คํธ์ฐจ๊ฐ์ train image์ ํ๊ท ๊ณผ ํ์คํธ์ฐจ ๊ฐ์ ์๋์ ๊ฐ์ด ๊ตฌํ๊ณ ์ด๋ฅผ ์์๋ก ๊ณ์ ์ฐ๋๋ก ํ๋ค.
seed๊ฐ ๊ณ ์ ๋์ด ์์ด์ ๊ณ์ ๋๊ฐ์ train image๋ก ๊ณ ์ ๋๋ค.
def get_img_stats(img_paths):
img_info = dict(means=[], stds=[])
for img_path in tqdm(img_paths):
img = np.array(Image.open(glob(img_path)[0]))
img_info['means'].append(img.mean(axis=(0,1)))
img_info['stds'].append(img.std(axis=(0,1)))
return img_info
img_stats = get_img_stats(train_df.path.values)
mean = np.mean(img_stats["means"], axis=0) / 255.
std = np.mean(img_stats["stds"], axis=0) / 255.
print(f'RGB Mean: {mean}')
print(f'RGB Standard Deviation: {std}')
๋จ์ํ ์ด๋ฏธ์ง๋ฅผ ๋ถ๋ฌ์์ ๋ชจ๋ ํฝ์
๊ฐ์ ํฉํ๊ณ ์ด์๋ํ ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ ๊ตฌํ๋ ๊ณผ์ ์ด๋ค.
Dataset
class MaskDataset(Dataset):
def __init__(self, df, transform=None):
self.path = df['path']
self.transform = transform
self.label = df['label']
def __len__(self):
return len(self.path)
def __getitem__(self, index):
image = Image.open(self.path.iloc[index])
if self.transform:
image = self.transform(image)
label = self.label.iloc[index]
return image, torch.tensor(label)
class TestMaskDataset(Dataset):
def __init__(self, df, transform=None):
self.path = df['path']
self.transform = transform
def __len__(self):
return len(self.path)
def __getitem__(self, index):
image = Image.open(self.path.iloc[index])
if self.transform:
image = self.transform(image)
return image
init
dataframe์ ์ธ์๋ก ๋ฐ๊ณ ๊ทธ ์์ ์๋ ํน์ ์ปฌ๋ผ์ X์ y๋ก ์ ํ๋ค. ์ฌ๊ธฐ์๋ path
์ label
์ด๋ค.
getitem
PIL
ํจํค์ง์ Image
๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํด์ ์ด๋ฏธ์ง๋ฅผ ๋ถ๋ฌ์๋ค. cv2
๋ฅผ ์ฌ์ฉํ ๊น ํ์ง๋ง BGR๋ก ์ฝ์ด์ง๊ณ ์ด๋ฅผ ๋งค๋ฒ convert ํด์ค์ผ ํด์ ์ฌ์ฉํ์ง ์์๋ค.
dataframe์์ index๋ก ์ ๊ทผํ๋ ค๋ฉด iloc
๋ฅผ ์ฌ์ฉํด์ผํ๋ค.
image๋ transform์์ ToTensor
๋ฅผ ๊ฑฐ์น๋ฉด์ tensor ํํ๊ฐ ๋๋๊น ๊ทธ๋๋ก ๋ฐํํด์ฃผ๊ณ , label์ tensor๋ก ์บ์คํ
ํด์ค๋ค.
train_dataset = MaskDataset(df=train_df, transform=transform)
valid_dataset = MaskDataset(df=valid_df, transform=transform)
๋ฐ์ดํฐ์
์ ์์ฑํ๋ค.
DataLoader
train_loader = DataLoader(dataset = train_dataset,
batch_size=CFG['train_bs'],
shuffle=True,
num_workers=CFG['num_workers'],
)
valid_loader = DataLoader(dataset = valid_dataset,
batch_size=CFG['valid_bs'],
shuffle=False,
num_workers=CFG['num_workers'],
)
ํ๋ จ ๋ฐ์ดํฐ์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ ์๊ฒ ํ๊ณ ๊ฒ์ฆ ๋ฐ์ดํฐ์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ 2๋ฐฐ๋ก ์ค์ ํ๋ค,
ํ๋ จ ๋ฐ์ดํฐ์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ 30 ๋๋ 60์ผ๋ก ๊ฒฐ์ ํ๋ค.
EfficientNet
์ฌ์ค, ์ฌ๋ฌ ๋ชจ๋ธ์ ์ฐพ์๋ณด๊ณ ์คํ์ ํตํด ๊ฒฐ์ ํ๋ ๊ฒ์ด ๋ง์ง๋ง, ์ฌ๋ฌ ์ด์ ๋ฅผ ํตํด EfficientNet์ ์ ์ผ ๋จผ์ ์ฌ์ฉํ๊ฒ ๋์๋ค.
์ด์ 1๊ธฐ ๋ฉค๋ฒ์ ํฌ์คํ
์ ์ฐธ๊ณ ํ๋ EfficientNet ์ฌ์ฉ
์ด๋ฏธ์ง๋ท ๋ฆฌ๋๋ณด๋์์ EfficientNet์ด 3๋ฑ์ด๋ค.
๊ทธ๋์ ๋ด์ผ 1, 2๋ฑ ๋ชจ๋ธ์ธ ViT๋ ์ฌ์ฉํด๋ณผ ์์
์๊ฐ๋ณด๋ค ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๋ ๊ฒ์ ์ฌ์ ๋ค.
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b4')
๋, ๋ชจ๋ธ๋ค์ ๋ชจ์๋์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ timm
์ ์ฌ์ฉํด์ ๋ถ๋ฌ์ฌ ์๋ ์์๋ค. efficientnet ๋ชจ๋ธ์ ์ข
๋ฅ๋ ๊ต์ฅํ ๋ง๋ค.
import timm
timm.list_models('*eff*')
['eca_efficientnet_b0',
'efficientnet_b0',
'efficientnet_b1',
'efficientnet_b1_pruned',
'efficientnet_b2',
'efficientnet_b2_pruned',
'efficientnet_b2a',
'efficientnet_b3',
'efficientnet_b3_pruned',
'efficientnet_b3a',
'efficientnet_b4',
'efficientnet_b5',
'efficientnet_b6',
'efficientnet_b7',
'efficientnet_b8',
'efficientnet_cc_b0_4e',
'efficientnet_cc_b0_8e',
'efficientnet_cc_b1_8e',
'efficientnet_el',
'efficientnet_el_pruned',
'efficientnet_em',
'efficientnet_es',
'efficientnet_es_pruned',
'efficientnet_l2',
'efficientnet_lite0',
'efficientnet_lite1',
'efficientnet_lite2',
'efficientnet_lite3',
'efficientnet_lite4',
'efficientnetv2_l',
'efficientnetv2_m',
'efficientnetv2_rw_m',
'efficientnetv2_rw_s',
'efficientnetv2_s',
'gc_efficientnet_b0',
'tf_efficientnet_b0',
'tf_efficientnet_b0_ap',
'tf_efficientnet_b0_ns',
'tf_efficientnet_b1',
'tf_efficientnet_b1_ap',
'tf_efficientnet_b1_ns',
'tf_efficientnet_b2',
'tf_efficientnet_b2_ap',
'tf_efficientnet_b2_ns',
'tf_efficientnet_b3',
'tf_efficientnet_b3_ap',
'tf_efficientnet_b3_ns',
'tf_efficientnet_b4',
'tf_efficientnet_b4_ap',
'tf_efficientnet_b4_ns',
'tf_efficientnet_b5',
'tf_efficientnet_b5_ap',
'tf_efficientnet_b5_ns',
'tf_efficientnet_b6',
'tf_efficientnet_b6_ap',
'tf_efficientnet_b6_ns',
'tf_efficientnet_b7',
'tf_efficientnet_b7_ap',
'tf_efficientnet_b7_ns',
'tf_efficientnet_b8',
'tf_efficientnet_b8_ap',
'tf_efficientnet_cc_b0_4e',
'tf_efficientnet_cc_b0_8e',
'tf_efficientnet_cc_b1_8e',
'tf_efficientnet_el',
'tf_efficientnet_em',
'tf_efficientnet_es',
'tf_efficientnet_l2_ns',
'tf_efficientnet_l2_ns_475',
'tf_efficientnet_lite0',
'tf_efficientnet_lite1',
'tf_efficientnet_lite2',
'tf_efficientnet_lite3',
'tf_efficientnet_lite4',
'tf_efficientnetv2_b0',
'tf_efficientnetv2_b1',
'tf_efficientnetv2_b2',
'tf_efficientnetv2_b3',
'tf_efficientnetv2_l',
'tf_efficientnetv2_l_in21ft1k',
'tf_efficientnetv2_l_in21k',
'tf_efficientnetv2_m',
'tf_efficientnetv2_m_in21ft1k',
'tf_efficientnetv2_m_in21k',
'tf_efficientnetv2_s',
'tf_efficientnetv2_s_in21ft1k',
'tf_efficientnetv2_s_in21k']
๋๋ ์ด ์ค์์ efficientnet_b4
๋ชจ๋ธ์ ์ ํํ๋ค. ์ด์ ๋ ๋ค์๊ณผ ๊ฐ๋ค.
b7
์๋ฆฌ์ฆ ๋ถํฐ๋ V100 ์ผ๋ก ๋๋ฆด ์๊ฐ ์์๋ค.
์ด ๋น์ batch_size๋ 30์ด๋ค. ๋ ์๊ฒํ์ผ๋ฉด ๊ฐ๋ฅํ ์๋ ์์๊ฒ ์ง๋ง, 15,000 ์ฅ์ ๋ฐ์ดํฐ์
์ ์ฌ๋ฌ epoch๋ฅผ ๋๋ ค๊ฐ๋ฉฐ ํ์ธํ๋ค๊ณ ํ๋ฉด ์์ฒญ๋ ์๊ฐ์ด ์์๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ์, ์ ์ด๋ ์ ์ผ ๋ง์ง๋ง์ ์ฌ์ฉํ ์ ์๋ ๋ฐฉ๋ฒ์ด๋ค.
b5
์๋ฆฌ์ฆ ๋ถํฐ๋ pretrained=True
์ธ ๋ชจ๋ธ์ด ์๋ค. ์ฆ b4
๊น์ง๋ง pretrained
๋ ๋ชจ๋ธ๋ก ์ฌ์ฉ ๊ฐ๋ฅํ๋ค.
์ค์ ๋ก pretrained์ ํ์ ์์ฒญ๋ฌ๋๋ฐ, ์ฅ์ ์ ๋ค์๊ณผ ๊ฐ๋ค.
7๋ฒ ์ดํ์ ์ ์ epoch ์๋ก๋ train data๊ฐ ์์ ํ ํ์ต๋์๋ค. ๊ทธ๋งํผ ์ ์ ์๊ฐ์ด ์์๋๋ค.
๋ฐ๋ผ์, ์ ์ epoch ์์์ ์ฌ๋ฌ๊ฐ์ง ์คํ์ด ๊ฐ๋ฅํ๋ค. Loss ํจ์๋ ์ด๋ค๊ฒ์ด ์ข๊ณ , Lr Scheduler๋ ์ด๋ค ๊ฒ์ด ์ข๊ณ ๋ฑ์ ์คํ.
๋ฐ๋๋ก ๋งํ๋ฉด, pretrained ๋ ๋ชจ๋ธ์ด ์๋๋ผ๋ฉด ์ค๋ฒํผํ
๋๋ ์์ ์ epoch๋ฅผ ์๊ธฐ๊ฐ ์ด๋ ค์ ๊ณ , ๋งค๋ฒ ๋ฐ๋ ๊ฐ๋ฅ์ฑ๋ ์๋ค. ๋, ์๊ฐ๋ ๋ง์ด ์์๋์ด ์ด๋ค ์คํ์ ํ๊ธฐ๊ฐ ์ด๋ ต๊ณ , early stopping ๊ฐ์ด ๋ถ์์ ์ธ ๋ถ๋ถ์ ์ถ๊ฐํ์ด์ผ ํ๋ค.
์ฌ์ค, ํ๋ฒ 50~100 epoch ์ฉ ๋๋ฆฌ๋ฉด์ ํ์ต ํด๋ณด๊ณ ์ถ์์ง๋ง, ์ ์ถ์ ๋ํ ์๋ฐ๋ ์์๊ณ , ์ฌ๋ฌ ์คํ๋ ํด๋ณด๊ณ ์ถ์๋ค. ๋, ์ฑ๋ฅ์ ๋์ผ ์ ์๋ ๋ค์ํ ํ
ํฌ๋์ ํด๋ณผ๊ฒ ๋ง๋ค๊ณ ์๊ฐํ๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก๋ ์์ฝ์ง๋ง ๊ทธ๋๋ pretrained ๋ ๋ชจ๋ธ์ ์ฌ์ฉํด์ ๋ ๋ค์ํ ์คํ๊ณผ ํ
ํฌ๋์ ์ ์ฉํ๊ธฐ์ ํํ๋ ์๋ค.
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b4', num_classes=18)
model = model.to(CFG['device'])
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CFG['lr'])
torch.cuda.empty_cache()
lrs = []
# for epoch in range(7):
for epoch in range(CFG['epochs']):
model.train()
train_batch_f1 = 0
train_batch_accuracy = []
train_batch_loss = []
train_pbar = tqdm(train_loader)
for n, (X, y) in enumerate(train_pbar):
X, y = X.to(CFG['device']), y.to(CFG['device'])
y_hat = model(X)
loss = criterion(y_hat, y)
pred = torch.argmax(y_hat, axis=1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lrs.append(optimizer.param_groups[0]["lr"])
train_batch_accuracy.append(
torch.sum(pred == y).cpu().numpy() / CFG['train_bs']
)
train_batch_loss.append(
loss.item()
)
f1 = f1_score(y.cpu().numpy(), pred.cpu().numpy(), average='macro')
train_batch_f1 += f1
train_pbar.set_description(f'train : {n} / {len(train_loader)} | f1 : {f1:.5f} | accuracy : {train_batch_accuracy[-1]:.5f} | loss : {train_batch_loss[-1]:.5f}')
model.eval()
valid_batch_f1 = 0
valid_batch_accuracy = []
valid_batch_loss = []
valid_pbar = tqdm(valid_loader)
with torch.no_grad():
for n, (X, y) in enumerate(valid_pbar):
X, y = X.to(CFG['device']), y.to(CFG['device'])
y_hat = model(X)
loss = criterion(y_hat, y)
pred = torch.argmax(y_hat, axis=1)
valid_batch_accuracy.append(
torch.sum(pred == y).cpu().numpy() / CFG['valid_bs']
)
valid_batch_loss.append(
loss.item()
)
f1 = f1_score(y.cpu().numpy(), pred.cpu().numpy(), average='macro')
valid_batch_f1 += f1
valid_pbar.set_description(f'valid : {n} / {len(valid_loader)} | f1 : {f1:.5f} | accuracy : {valid_batch_accuracy[-1]:.5f} | loss : {valid_batch_loss[-1]:.5f}')
print(f"""
epoch : {epoch+1:02d}
[train] f1 : {train_batch_f1/len(train_loader):.5f} | accuracy : {np.sum(train_batch_accuracy) / len(train_loader):.5f} | loss : {np.sum(train_batch_loss) / len(train_loader):.5f}
[valid] f1 : {valid_batch_f1/len(valid_loader):.5f} | accuracy : {np.sum(valid_batch_accuracy) / len(valid_loader):.5f} | loss : {np.sum(valid_batch_loss) / len(valid_loader):.5f}
""")
if valid_batch_f1/len(valid_loader) >= 0.9:
torch.save(model.state_dict(), f'v:f1_{valid_batch_f1/len(valid_loader):.3f}_t:f1_{train_batch_f1/len(train_loader):.5f}_efficientnet_b4_state_dict.pt') # ๋ชจ๋ธ ๊ฐ์ฒด์ state_dict ์ ์ฅ
์๋ฌด๋ฐ ํ
ํฌ๋์ ์ ์ฉํ์ง ์๊ณ ๋๋ ธ์ ๋์ f1 ์ ์๋ 60์ ์ค๋ฐ ์ ๋๊ฐ ๋์๋ค. ์๊ฐ๋ณด๋ค ์ ์๊ฐ ๋ฎ๋ค ์ถ์์ง๋ง, ์ฌ๋ฌ ํ
ํฌ๋์ ๊ณ ๋ฏผํด๋ณด๊ณ ์๋ค.