DAY 3 : DataSet/Lodaer | EfficientNet
210825
Dataset/Loader
์ค๋์ ๋ฏธ์ ์ ๋ฐ์ดํฐ์ ๊ณผ ๋ฐ์ดํฐ๋ก๋๋ฅผ ๊ตฌํํ๋ ๊ฒ. ์ฌ์ค ํ๊ธฐ๋ฅผ ์์ฑํ๋ ์์ ์์๋ ์ด๋ ค์ด ๋ถ๋ถ์ด ์์ง๋ง, ์ ๋ง๋ก ๋ง์ ์ฒ์ ๊ตฌํํ ๋๋ ๋๋ฌด ๋ง๋งํ๋ค. ๊ทธ๋งํผ ๋ด๊ฐ ์ ๋๋ก ์ดํดํ์ง ๋ชปํ๊ฑฐ๊ฒ ์ง.
์ผ๋จ ์ฒ์์๋ ์บ๊ธ์ ์ฐธ๊ณ ํ๋ค. ์บ๊ธ ์ฝ๋๋ฅผ ๊ทธ๋๋ก ์ด ๊ฒ์ ์๋๊ณ ์ด๋ ํ ํ๋ฆ์ผ๋ก ์จ์ง๋ ๊ตฌ๋๋ฅผ ์ฐธ๊ณ ํ๋ค. ๊ทธ๋ฌ๋ฉด์ ์๊ฒ๋ ๋ถ๋ถ์ CFG ๋ผ๋ ๋์ ๋๋ฆฌ ๋ณ์์ ํ์ดํผ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ชจ๋ ์ ์ธํด๋๋ ๋ฐฉ๋ฒ์ด์๋ค. ๋ ์์๋ณด๋ ์ด๋ฅผ ํด๋์ค๋ก ์ ์ธํ๋ ์ฌ๋๋ ์์๋ค.
Config
DATA_DIR = './input/data/train/images/'
CFG = {
'fold_num': 5,
'seed': 719,
'epochs': 30,
'train_bs': 30,
'valid_bs': 60,
'T_0': 10,
'lr': 1e-5,
'max_lr': 1e-3,
'weight_decay':1e-6,
'num_workers': 8,
'accum_iter': 2, # suppoprt to do batch accumulation for backprop with effectively larger batch size
'verbose_step': 1,
'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
}
print(f'{CFG["device"]} is using!')Train_test_split
์ดํ, Train data์ Valid data๋ก ๋๋์๋ค. ์ด ๋๋ sklearn ์ train_test_split ๋ฅผ ์ฌ์ฉํ๋ค.
์ธ์๋ก data ํ๋๋ง ์ฃผ์ด์ง๋ฉด train_data ์ valid_data๋ก ๋๋์ด ์ฃผ๋ฉฐ ์ธ์๋ก data์ label์ด ์ฃผ์ด์ง๋ฉด trainX, trainY, validX, validY๋ก ๋๋์ด์ก๋ค. ๋๋ ์ ์๊ฐ ํ์ํด์ data ์ธ์ ํ๋๋ง ์ฃผ์๋ค.
Transform
์ด๋ฏธ์ง์ ํฌ๊ธฐ๊ฐ 384 * 512์ด๋ค. ์ผ๊ตด์ด ๋๋ถ๋ถ ์ค์์ ์์นํ๋ฏ๋ก CenterCrop์ ์ฌ์ฉํ๋ฉด ๋๋ค๊ณ ์๊ฐ์ด ๋ค์๋ค. CenterCrop์ ์ฌ์ฉํ๋ ์ด์ ๋ ๋ค์๊ณผ ๊ฐ๋ค.
์ด๋ฏธ์ง ์ฌ์ด์ฆ๊ฐ ์์์๋ก GPU์ฌ์ฉ์ด ์ค์ด๋ ๋ค. ์ค์ ๋ก ์๋ณธ ์ด๋ฏธ์ง๋ฅผ ์ ๋ ฅํ์ ๋๋ batch size๋ฅผ ๋งค์ฐ ์๊ฒ ํด์ผ๋ง ๋์๊ฐ๋ค.
์ฌ๋์ ์ผ๊ตด ์ ๋ณด๋ง ํ์ํ๋ค๊ณ ์๊ฐํ๋ค. ๊ทธ ์ธ์๋ ๋ฒฝ์ด๋ ์ท๋ฑ์ ๋ฐฐ๊ฒฝ ์ด๋ฏธ์ง๊ฐ ํ์ต์ ์คํ๋ ค ๋ฐฉํด๊ฐ ๋๋ค๊ณ ์๊ฐํ๋ค.
๋ค์ ์ด๋ฏธ์ง๋ฅผ ์ฐธ๊ณ ํ๋ฉด ์ ์ ์๋ฏ์ด b4๊ฐ ํ์ตํ ์ด๋ฏธ์ง์ ํฌ๊ธฐ๋ 380์ด๋ค.

๊ทธ ์ธ์ ๋ง์ Transform์ ํด์ฃผ์ง๋ ์์๋ค. RandomChoice๋ฅผ ์ฌ์ฉํด์ 4๊ฐ์ trsf ๋ฅผ ์์๋ก ์ ์ฉ๋๋๋ก ํด์ฃผ์๊ณ ์ด ์์ ์๋ ๋ณํ์ ๋ฐ๊ธฐ, ์ฑ๋ ๋ฑ์ ํฝ์ ๊ฐ ๋ณํ์ด๋ค.
1/2 ํ๋ฅ ๋ก ์ข์ฐ๋ฐ์ ์ด ์ผ์ด๋๋๋ก ํ๋ค.
ToTensor๋ฅผ ์ด์ฉํด ํ ์๋ก ๋ณํ๋๊ฒ ํ๊ณ ์ ๊ทํ๋ฅผ ํ๋ค. ์ด ๋ ํ๊ท ๊ณผ ํ์คํธ์ฐจ๊ฐ์ train image์ ํ๊ท ๊ณผ ํ์คํธ์ฐจ ๊ฐ์ ์๋์ ๊ฐ์ด ๊ตฌํ๊ณ ์ด๋ฅผ ์์๋ก ๊ณ์ ์ฐ๋๋ก ํ๋ค.seed๊ฐ ๊ณ ์ ๋์ด ์์ด์ ๊ณ์ ๋๊ฐ์ train image๋ก ๊ณ ์ ๋๋ค.
๋จ์ํ ์ด๋ฏธ์ง๋ฅผ ๋ถ๋ฌ์์ ๋ชจ๋ ํฝ์ ๊ฐ์ ํฉํ๊ณ ์ด์๋ํ ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ ๊ตฌํ๋ ๊ณผ์ ์ด๋ค.
Dataset
initdataframe์ ์ธ์๋ก ๋ฐ๊ณ ๊ทธ ์์ ์๋ ํน์ ์ปฌ๋ผ์ X์ y๋ก ์ ํ๋ค. ์ฌ๊ธฐ์๋
path์label์ด๋ค.
getitemPILํจํค์ง์Image๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํด์ ์ด๋ฏธ์ง๋ฅผ ๋ถ๋ฌ์๋ค.cv2๋ฅผ ์ฌ์ฉํ ๊น ํ์ง๋ง BGR๋ก ์ฝ์ด์ง๊ณ ์ด๋ฅผ ๋งค๋ฒ convert ํด์ค์ผ ํด์ ์ฌ์ฉํ์ง ์์๋ค.dataframe์์ index๋ก ์ ๊ทผํ๋ ค๋ฉด
iloc๋ฅผ ์ฌ์ฉํด์ผํ๋ค.image๋ transform์์
ToTensor๋ฅผ ๊ฑฐ์น๋ฉด์ tensor ํํ๊ฐ ๋๋๊น ๊ทธ๋๋ก ๋ฐํํด์ฃผ๊ณ , label์ tensor๋ก ์บ์คํ ํด์ค๋ค.
๋ฐ์ดํฐ์ ์ ์์ฑํ๋ค.
DataLoader
ํ๋ จ ๋ฐ์ดํฐ์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ ์๊ฒ ํ๊ณ ๊ฒ์ฆ ๋ฐ์ดํฐ์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ 2๋ฐฐ๋ก ์ค์ ํ๋ค,
ํ๋ จ ๋ฐ์ดํฐ์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ 30 ๋๋ 60์ผ๋ก ๊ฒฐ์ ํ๋ค.
EfficientNet
์ฌ์ค, ์ฌ๋ฌ ๋ชจ๋ธ์ ์ฐพ์๋ณด๊ณ ์คํ์ ํตํด ๊ฒฐ์ ํ๋ ๊ฒ์ด ๋ง์ง๋ง, ์ฌ๋ฌ ์ด์ ๋ฅผ ํตํด EfficientNet์ ์ ์ผ ๋จผ์ ์ฌ์ฉํ๊ฒ ๋์๋ค.
์ด์ 1๊ธฐ ๋ฉค๋ฒ์ ํฌ์คํ ์ ์ฐธ๊ณ ํ๋ EfficientNet ์ฌ์ฉ
์ด๋ฏธ์ง๋ท ๋ฆฌ๋๋ณด๋์์ EfficientNet์ด 3๋ฑ์ด๋ค.
๊ทธ๋์ ๋ด์ผ 1, 2๋ฑ ๋ชจ๋ธ์ธ ViT๋ ์ฌ์ฉํด๋ณผ ์์
๋ฉํ ๋์ ์ถ์ฒ

์๊ฐ๋ณด๋ค ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๋ ๊ฒ์ ์ฌ์ ๋ค.
๋, ๋ชจ๋ธ๋ค์ ๋ชจ์๋์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ timm ์ ์ฌ์ฉํด์ ๋ถ๋ฌ์ฌ ์๋ ์์๋ค. efficientnet ๋ชจ๋ธ์ ์ข
๋ฅ๋ ๊ต์ฅํ ๋ง๋ค.
๋๋ ์ด ์ค์์ efficientnet_b4 ๋ชจ๋ธ์ ์ ํํ๋ค. ์ด์ ๋ ๋ค์๊ณผ ๊ฐ๋ค.
b7์๋ฆฌ์ฆ ๋ถํฐ๋ V100 ์ผ๋ก ๋๋ฆด ์๊ฐ ์์๋ค.์ด ๋น์ batch_size๋ 30์ด๋ค. ๋ ์๊ฒํ์ผ๋ฉด ๊ฐ๋ฅํ ์๋ ์์๊ฒ ์ง๋ง, 15,000 ์ฅ์ ๋ฐ์ดํฐ์ ์ ์ฌ๋ฌ epoch๋ฅผ ๋๋ ค๊ฐ๋ฉฐ ํ์ธํ๋ค๊ณ ํ๋ฉด ์์ฒญ๋ ์๊ฐ์ด ์์๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ์, ์ ์ด๋ ์ ์ผ ๋ง์ง๋ง์ ์ฌ์ฉํ ์ ์๋ ๋ฐฉ๋ฒ์ด๋ค.
b5์๋ฆฌ์ฆ ๋ถํฐ๋pretrained=True์ธ ๋ชจ๋ธ์ด ์๋ค. ์ฆb4๊น์ง๋งpretrained๋ ๋ชจ๋ธ๋ก ์ฌ์ฉ ๊ฐ๋ฅํ๋ค.
์ค์ ๋ก pretrained์ ํ์ ์์ฒญ๋ฌ๋๋ฐ, ์ฅ์ ์ ๋ค์๊ณผ ๊ฐ๋ค.
7๋ฒ ์ดํ์ ์ ์ epoch ์๋ก๋ train data๊ฐ ์์ ํ ํ์ต๋์๋ค. ๊ทธ๋งํผ ์ ์ ์๊ฐ์ด ์์๋๋ค.
๋ฐ๋ผ์, ์ ์ epoch ์์์ ์ฌ๋ฌ๊ฐ์ง ์คํ์ด ๊ฐ๋ฅํ๋ค. Loss ํจ์๋ ์ด๋ค๊ฒ์ด ์ข๊ณ , Lr Scheduler๋ ์ด๋ค ๊ฒ์ด ์ข๊ณ ๋ฑ์ ์คํ.
๋ฐ๋๋ก ๋งํ๋ฉด, pretrained ๋ ๋ชจ๋ธ์ด ์๋๋ผ๋ฉด ์ค๋ฒํผํ ๋๋ ์์ ์ epoch๋ฅผ ์๊ธฐ๊ฐ ์ด๋ ค์ ๊ณ , ๋งค๋ฒ ๋ฐ๋ ๊ฐ๋ฅ์ฑ๋ ์๋ค. ๋, ์๊ฐ๋ ๋ง์ด ์์๋์ด ์ด๋ค ์คํ์ ํ๊ธฐ๊ฐ ์ด๋ ต๊ณ , early stopping ๊ฐ์ด ๋ถ์์ ์ธ ๋ถ๋ถ์ ์ถ๊ฐํ์ด์ผ ํ๋ค.
์ฌ์ค, ํ๋ฒ 50~100 epoch ์ฉ ๋๋ฆฌ๋ฉด์ ํ์ต ํด๋ณด๊ณ ์ถ์์ง๋ง, ์ ์ถ์ ๋ํ ์๋ฐ๋ ์์๊ณ , ์ฌ๋ฌ ์คํ๋ ํด๋ณด๊ณ ์ถ์๋ค. ๋, ์ฑ๋ฅ์ ๋์ผ ์ ์๋ ๋ค์ํ ํ ํฌ๋์ ํด๋ณผ๊ฒ ๋ง๋ค๊ณ ์๊ฐํ๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก๋ ์์ฝ์ง๋ง ๊ทธ๋๋ pretrained ๋ ๋ชจ๋ธ์ ์ฌ์ฉํด์ ๋ ๋ค์ํ ์คํ๊ณผ ํ ํฌ๋์ ์ ์ฉํ๊ธฐ์ ํํ๋ ์๋ค.
์๋ฌด๋ฐ ํ ํฌ๋์ ์ ์ฉํ์ง ์๊ณ ๋๋ ธ์ ๋์ f1 ์ ์๋ 60์ ์ค๋ฐ ์ ๋๊ฐ ๋์๋ค. ์๊ฐ๋ณด๋ค ์ ์๊ฐ ๋ฎ๋ค ์ถ์์ง๋ง, ์ฌ๋ฌ ํ ํฌ๋์ ๊ณ ๋ฏผํด๋ณด๊ณ ์๋ค.
Last updated
Was this helpful?