(7๊ฐ) Training & Inference 1
210826
Last updated
Was this helpful?
210826
Last updated
Was this helpful?
ํ์ต ํ๋ก์ธ์ค์ ํ์ํ ์์๋ ํฌ๊ฒ ์๋์ ๊ฐ์ด ๋๋ ์ ์๋ค.
Loss๋ Output๊ณผ Target์ ์ฐจ์ด๋ฅผ ์ด๋ป๊ฒ ์ ์ํ๋๋์ ๋ฐ๋ผ ๋ค๋ฅด๋ค.
Loss๋ nn.Module
์ ์์ํ๊ณ ์๊ธฐ ๋๋ฌธ์ forward
ํจ์๊ฐ ์๋ค. ๊ทธ๋ฐ๋ฐ ์ฌ๊ธฐ์ loss.backward
์ ํ ์ค ์ฝ๋๋ก ์ด๋ป๊ฒ ๋ชจ๋ธ์ ์ ์ฒด ํ๋ผ๋ฏธํฐ๊ฐ ์
๋ฐ์ดํธ๊ฐ ๋ ๊น?
์ฌ๊ธฐ์ ์์์ผ ํ ์ ์, nn.Module
์ ์์ํ๊ณ ์๋ ๋ชจ๋๋ค์ ๋ชจ๋ forward
ํจ์๊ฐ ์๊ธฐ ๋๋ฌธ์ input๋ถํฐ output๊น์ง์ ์ฐ๊ฒฐ์ด ์๊ธด๋ค๋ ๊ฒ์ด๋ค. ๋, ์ด๋ค ๋ ์ด์ด์ output์ ๋ค์ ๋ ์ด์ด์ input์ด ๋๊ณ , ๋ชจ๋ ๋ ์ด์ด๊ฐ forward
ํจ์๊ฐ ์๊ธฐ ๋๋ฌธ์ ์ฒซ ์
๋ ฅ๋จ๋ถํฐ loss๊น์ง๋ ์ฐ๊ฒฐ์ด ๋๋ค๊ณ ๋ณผ ์ ์๋ค. ๊ทธ๋์ ๋จ์ํ loss ์์ ์์ํ๋๋ผ๋ ์
๋ ฅ๋จ์ ์ฒ์๊น์ง ์ฌ ์ ์๋ ๊ฒ์ด๋ค.
loss.backward
๊ฐ ์ด๋ฃจ์ด์ง๋ฉด ๊ฐ๊ฐ์ ํ๋ผ๋ฏธํฐ์ grad
๊ฐ์ด ๊ฐฑ์ ๋๋ค. ์ด ๋ ์ด๋ฌํ ๊ฐฑ์ ์ฌ๋ถ๋ฅผ required_grad
๋ก ์ค์ ํด์ค ์ ์๊ณ , False๋ก ์ค์ ํ ๊ฒฝ์ฐ ๊ฐฑ์ ๋์ง ์๋๋ค.
Lossํจ์๋ฅผ Custom์ผ๋ก ์ ์ํ ์๋ ์๋ค.
Focal Loss : Class Imbalance ๋ฌธ์ ๊ฐ ์๋ ๊ฒฝ์ฐ, ๋ง์ถ ํ๋ฅ ์ด ๋์ Class๋ ์กฐ๊ธ์ loss๋ฅผ, ๋ง์ถ ํ๋ฅ ์ด ๋ฎ์ Class๋ loss๋ฅผ ํจ์ฌ ๋๊ฒ ๋ถ์ฌํ๋ค.
Label Smoothing Loss : Class target label์ Onehot์ด ์๋ Soft ํ๊ฒ ํํํด์ ์ผ๋ฐํ ์ฑ๋ฅ์ ๋์ธ๋ค.
0๊ณผ 1์ ๊ฐ๋ง ๊ฐ์ง๊ฒ ๋๋ฉด ๊ทน๋จ์ feature๋ง์ ๊ฐ๊ฒ๋๋๋ฐ, ์ฌ์ค class๋ง๋ค ๋น์ทํ feature๋ ์์ ์ ์๊ธฐ ๋๋ฌธ์ ์ด๋ฌํ ๋ถ๋ถ์ ์ ์ฐํ๊ฒ ์ค์ ํ๋ค
Optimizer๋ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐฑ์ ํ๋ ๋ฐฉ๋ฒ์ ์ ์ํ๋ค.
์ผ์ชฝ์ ๊ฒฝ์ฐ ํ์ต๋ฅ ์ด ๊ณ ์ ๋์ด ์๊ธฐ ๋๋ฌธ์ ์๋ ดํ๊ธฐ๊ฐ ์ด๋ ต๋ค. ์ค๋ฅธ์ชฝ์ฒ๋ผ ํ์ต๋ฅ ์ด ์ ์ ์์์ง๋ค๋ฉด ์๋ ดํ๊ธฐ๊ฐ ์ฌ์์ง ๊ฒ์ด๋ค.
ํ์ต๋ฅ ์ ๋์ ์ผ๋ก ์กฐ์ ํ๋ LR scheduler์๋ ๋ค์๊ณผ ๊ฐ์ ๊ฒ๋ค์ด ์๋ค.
step_size
๋ง๋ค ํ์ต๋ฅ ์ gamma
๋งํผ์ ๋น์จ๋ก ์ค์ ํ๋ค.
step_size
๋ batch_size
๋งํผ ํ์ต์ ํ๊ณ ๋ ๋ค ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐฑ์ ํ๋ ํ์์ด๋ค
ํ์ต๋ฅ ์ ๋ณํ๋ฅผ Cosine ํจ์์ฒ๋ผ ๋ง๋๋ ํจ์์ด๋ค. ์ต์ง์ฒ๋ผ ๋ณด์ผ ์ ์์ง๋ง, ๋๋ฆ์ ์ฅ์ ๋ ์๋ค. step์ด ๋ง๋ค๊ณ ๋ฌด์์ ๋ฎ์ถ์ง ์๋ค๋ณด๋๊น local minimum์ ์ ๋น ์ง์ง ์๋๋ค๋ ์ ์ด๋ค.
์ผ๋ฐ์ ์ผ๋ก ๊ฐ์ฅ ๋ง์ด ์ฐ๋ ์ค์ผ์ฅด๋ฌ์ด๋ค. ๋ ์ด์ ์ฑ๋ฅ ํฅ์์ด ์์ ๋ ํ์ต๋ฅ ์ด ๊ฐ์ํ๋ค.
์งํ(=์ธก์ ๋ฒ)๋ ํ์ต์ ์ง์ ์ ์ธ ์ํฅ์ ๋ฏธ์น์ง๋ ์๋๋ค. ๊ทธ๋ฌ๋ traing์ ์ค์ํ ์์๋ก ๋ด์ผ๋๋ ์ด์ ๋ ์งํ๊ฐ ์์ผ๋ฉด ๊ฐ๊ด์ ์ผ๋ก ๋ชจ๋ธ์ ์ ๋ขฐ๋๋ ๋ฒ์ฉ์ฑ์ ํ๋จํ ์ ์๊ธฐ ๋๋ฌธ์ด๋ค. ๋จ์ํ loss์ ์์น๋ก๋ง ๋ด์๋ ์ค์ ๋ก production์์ ์ ์ฉํ๊ธฐ์๋ ๋ถ์กฑํ์ ์ด ๋ง๋ค.
๋ชจ๋ธ์ ํ๊ฐ
Classification
Accuracy : ๋ณดํต ๋ง์ด ์ฐ๋ class๊ฐ imbalance๊ฐ ์์ผ๋ฉด ๋ค๋ฅธ ์งํ๋ ์ฌ์ฉํ๋ค.
F1-score : Class๋ณ ๋ฐธ๋ฐ์ค๊ฐ ์ข์ง ์์ ๋ ๊ฐ ํด๋์ค ๋ณ๋ก ์ฑ๋ฅ์ ์ ๋ผ ์ ์๋์ง์ ๋ํ ์งํ์ด๋ค.
precision
recall
ROC&AUC
Regression
MAE
MSE
Ranking : ์ถ์ฒ์์คํ ์์ ๋ง์ด ์ฐ์ด๋ ์งํ์ด๋ค. ์ถ์ฒ๋๋ ํญ๋ชฉ์ด ๊ฐ์ฅ ์์ ๋ ์ผ ํ๊ธฐ๋๋ฌธ์ ์์๊น์ง ๊ณ ๋ ค๋๋ ์งํ์ด๋ค.
MRR
NDCG
MAP