(Lecture 7) Training & Inference 1

210826

The components of the training process can be broadly divided as follows.

Loss

A loss is characterized by how it defines the difference between the model's output and the target.

Loss classes inherit from nn.Module, so they have a forward function. How, then, can the single line loss.backward() reach every parameter in the model?

The key point is that every module inheriting from nn.Module has a forward function, so the computation from input to output is chained together. Each layer's output becomes the next layer's input, and because every layer runs through forward, the path from the first input all the way to the loss is connected. During the forward pass, autograd records these operations as a computation graph (each resulting tensor keeps a grad_fn pointing to the operation that produced it). That is why backpropagation can start from the loss alone and still reach the very first layer.

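When loss.backward() runs, the gradient of each parameter is accumulated into its grad attribute. Whether a parameter receives a gradient is controlled by its requires_grad flag; if it is set to False, no gradient is computed for it. The parameter values themselves change only afterwards, when the optimizer's step() applies the update.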
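A minimal sketch of the chain described above, assuming a toy two-layer model with random data (the layer sizes, the seed, and the choice to freeze the first layer are illustrative only). The forward pass records each operation, loss.backward() walks that record back from the loss, and the frozen layer's weight.grad stays None because its requires_grad is False.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model: each module's forward output is the next module's input.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Freeze the first layer: autograd will not compute gradients for it.
for p in model[0].parameters():
    p.requires_grad = False

criterion = nn.CrossEntropyLoss()  # the loss is itself an nn.Module
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

loss = criterion(model(x), y)  # forward pass records the autograd graph
print(loss.grad_fn)            # last node of the recorded graph

loss.backward()                # walks the graph back from the loss

print(model[0].weight.grad)        # None: requires_grad=False
print(model[2].weight.grad.shape)  # gradient of the trainable layer, shape (2, 3)
```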

A loss function can also be defined as a custom module.

  • Focal Loss: useful when there is class imbalance. It assigns a small loss to examples the model already predicts correctly with high probability and a much larger loss to examples it predicts correctly only with low probability.

  • Label Smoothing Loss: represents the class target label softly instead of as a one-hot vector, to improve generalization.

    • With only 0s and 1s, the target takes an all-or-nothing view of each class, even though different classes may share similar features; label smoothing relaxes this by giving the other classes a small share of the probability.
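A minimal sketch of the two custom losses as nn.Module subclasses, assuming multi-class logits of shape (batch, C) and integer class targets. FocalLoss multiplies each sample's cross-entropy by (1 - p_t)^gamma, where p_t is the predicted probability of the true class; the class-weighting alpha term of the original focal loss is omitted for brevity. LabelSmoothingLoss puts 1 - smoothing on the true class and spreads smoothing evenly over the other C - 1 classes. The class names and defaults (gamma=2.0, smoothing=0.1) are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss: (1 - p_t)^gamma * cross-entropy, averaged over the batch."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, target):
        log_p = F.log_softmax(logits, dim=-1)
        # Log-probability of the true class for each sample.
        log_p_t = log_p.gather(1, target.unsqueeze(1)).squeeze(1)
        p_t = log_p_t.exp()
        # Confident, correct samples (p_t near 1) are down-weighted.
        return (-((1 - p_t) ** self.gamma) * log_p_t).mean()


class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a soft target: 1 - eps on the true class, eps / (C - 1) elsewhere."""

    def __init__(self, num_classes: int, smoothing: float = 0.1):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing

    def forward(self, logits, target):
        log_p = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            soft = torch.full_like(log_p, self.smoothing / (self.num_classes - 1))
            soft.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return (-soft * log_p).sum(dim=-1).mean()


torch.manual_seed(0)
logits = torch.randn(5, 3)
target = torch.tensor([0, 1, 2, 1, 0])
ce = F.cross_entropy(logits, target)
focal = FocalLoss(gamma=2.0)(logits, target)
smooth = LabelSmoothingLoss(num_classes=3, smoothing=0.1)(logits, target)
```

With gamma=0 the focal loss reduces to plain cross-entropy, and with smoothing=0 the smoothed loss does too. PyTorch 1.10 and later also accept a label_smoothing argument in nn.CrossEntropyLoss; it mixes the target with a uniform distribution over all C classes (including the true one), so its values differ slightly from this variant.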

Optimizer

The optimizer defines how the parameters are updated.
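A minimal sketch of what "updating the parameters" means, assuming plain SGD with no momentum or weight decay and a made-up loss on a single parameter vector w: loss.backward() fills w.grad, and optimizer.step() then applies w <- w - lr * grad in place. Other optimizers such as Adam keep extra per-parameter state but are driven by the same zero_grad / backward / step sequence.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)  # hypothetical parameter vector
x = torch.randn(3)

lr = 0.1
optimizer = torch.optim.SGD([w], lr=lr)

before = w.detach().clone()
optimizer.zero_grad()            # clear gradients from any previous iteration
loss = (w * x).sum() ** 2
loss.backward()                  # fills w.grad, does not change w
expected = before - lr * w.grad  # plain SGD rule: w <- w - lr * grad
optimizer.step()                 # the optimizer applies its update rule in place
```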

In the left case of the figure, the learning rate is fixed, which makes convergence difficult. If the learning rate gradually shrinks, as in the right case, convergence becomes easier.

ํ•™์Šต๋ฅ ์„ ๋™์ ์œผ๋กœ ์กฐ์ ˆํ•˜๋Š” LR scheduler์—๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ฒƒ๋“ค์ด ์žˆ๋‹ค.

StepLR

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

Every step_size steps, the learning rate is multiplied by gamma.

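  • step_size counts calls to scheduler.step(). The scheduler is usually stepped once per epoch, so step_size=2 multiplies the learning rate by gamma every 2 epochs; if it is stepped after every batch instead, step_size counts parameter updates (iterations).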
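A minimal sketch of the StepLR schedule above, assuming the scheduler is stepped once per epoch and the initial learning rate is 0.1 (the Linear model is a placeholder and no real training happens). The rate is recorded at the start of each epoch.

```python
import torch

model = torch.nn.Linear(2, 1)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]["lr"])
    # ... one epoch of training (forward, backward, optimizer.step()) would go here ...
    optimizer.step()  # PyTorch expects optimizer.step() before scheduler.step()
    scheduler.step()  # one scheduler step per epoch
```

The recorded rates follow 0.1 * 0.1 ** (epoch // 2): 0.1, 0.1, 0.01, 0.01, 0.001, 0.001, up to floating-point rounding.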

CosineAnnealingLR

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)

ํ•™์Šต๋ฅ ์˜ ๋ณ€ํ™”๋ฅผ Cosine ํ•จ์ˆ˜์ฒ˜๋Ÿผ ๋งŒ๋“œ๋Š” ํ•จ์ˆ˜์ด๋‹ค. ์–ต์ง€์ฒ˜๋Ÿผ ๋ณด์ผ ์ˆ˜ ์žˆ์ง€๋งŒ, ๋‚˜๋ฆ„์˜ ์žฅ์ ๋„ ์žˆ๋‹ค. step์ด ๋งŽ๋‹ค๊ณ  ๋ฌด์ž‘์ • ๋‚ฎ์ถ”์ง€ ์•Š๋‹ค๋ณด๋‹ˆ๊นŒ local minimum์— ์ž˜ ๋น ์ง€์ง€ ์•Š๋Š”๋‹ค๋Š” ์ ์ด๋‹ค.

ReduceLROnPlateau

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

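This is one of the most commonly used schedulers in practice. It lowers the learning rate when the monitored metric (for example, the validation loss) stops improving for patience epochs. Unlike the schedulers above, its step() takes that metric as an argument, e.g. scheduler.step(val_loss).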
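A minimal sketch of ReduceLROnPlateau, assuming it monitors a validation loss (mode="min"), with factor=0.1 and patience=2 chosen small so the effect is visible; the loss values are made up and the Linear model is a placeholder. With patience=2, the rate is cut only once the loss has failed to improve for three consecutive epochs (more than patience).

```python
import torch

model = torch.nn.Linear(2, 1)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# mode="min": a lower monitored value counts as an improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

# Hypothetical validation losses that stop improving after the third epoch.
val_losses = [1.0, 0.8, 0.6, 0.6, 0.6, 0.6, 0.5]

lrs = []
for val_loss in val_losses:
    optimizer.step()          # stands in for one epoch of training
    scheduler.step(val_loss)  # this scheduler needs the monitored metric
    lrs.append(optimizer.param_groups[0]["lr"])
```

The rate stays at 0.1 through the three non-improving 0.6 epochs and drops to 0.01 at the third one.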

Metric

Metrics do not directly affect training. They should still be treated as an essential part of training, because without them there is no objective way to judge a model's reliability or how well it generalizes. The loss value alone is rarely enough to decide whether a model is ready for production.

๋ชจ๋ธ์˜ ํ‰๊ฐ€

  • Classification

    • Accuracy: widely used, but when the classes are imbalanced, other metrics are used alongside it.

    • F1-score: the harmonic mean of precision and recall. When the classes are poorly balanced, it shows whether the model performs well on each class.

    • Precision

    • Recall

    • ROC & AUC

  • Regression

    • MAE

    • MSE

  • Ranking: metrics widely used in recommender systems. Because recommended items should appear at the top of the list, these metrics also take the order of the results into account.

    • MRR

    • NDCG

    • MAP
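The classification and regression metrics listed above can be sketched in plain Python, assuming binary labels for the classification part (1 = positive class). The data are hypothetical: 95 negatives and 5 positives with a model that always predicts the majority class, which scores 0.95 accuracy while its F1 for the positive class is 0. Returning 0.0 when a denominator is zero is a convention chosen here. macro_f1 averages the per-class F1 scores, and roc_auc uses the equivalent pairwise definition: the fraction of (positive, negative) pairs in which the positive example gets the higher score, with ties counted as half.

```python
def binary_metrics(y_true, y_pred, positive=1):
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


def macro_f1(y_true, y_pred):
    # Treat each class in turn as "positive" and average the F1 scores.
    classes = sorted(set(y_true) | set(y_pred))
    return sum(binary_metrics(y_true, y_pred, positive=c)[3] for c in classes) / len(classes)


def roc_auc(y_true, scores):
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


y_true = [0] * 95 + [1] * 5  # hypothetical imbalanced labels: 5% positives
y_pred = [0] * 100           # a model that always predicts the majority class
acc, prec, rec, f1 = binary_metrics(y_true, y_pred)
print(acc, f1)  # 0.95 0.0
```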
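The ranking metrics can be sketched the same way, assuming each query comes with a ranked list of item ids and a set of relevant ids (for MRR and MAP), or a list of graded relevance values in ranked order (for NDCG). NDCG here uses the linear gain rel / log2(position + 1); some libraries use 2^rel - 1 as the gain instead, which gives different numbers. The example queries are hypothetical.

```python
import math


def mrr(ranked_lists, relevant_sets):
    # Mean over queries of 1 / rank of the first relevant item (0 if none).
    total = 0.0
    for ranked, relevant in zip(ranked_lists, relevant_sets):
        for rank, item in enumerate(ranked, start=1):
            if item in relevant:
                total += 1.0 / rank
                break
    return total / len(ranked_lists)


def average_precision(ranked, relevant):
    hits, score = 0, 0.0
    for rank, item in enumerate(ranked, start=1):
        if item in relevant:
            hits += 1
            score += hits / rank  # precision@rank at each relevant position
    return score / len(relevant) if relevant else 0.0


def mean_average_precision(ranked_lists, relevant_sets):
    aps = [average_precision(r, rel) for r, rel in zip(ranked_lists, relevant_sets)]
    return sum(aps) / len(aps)


def ndcg_at_k(gains, k):
    # gains: graded relevance of the items, in the order the system ranked them.
    def dcg(values):
        return sum(g / math.log2(i + 2) for i, g in enumerate(values[:k]))

    ideal = dcg(sorted(gains, reverse=True))
    return dcg(gains) / ideal if ideal > 0 else 0.0


ranked = [["a", "b", "c"], ["d", "e", "f"]]  # hypothetical results for two queries
relevant = [{"b"}, {"d", "f"}]
print(mrr(ranked, relevant))  # 0.75
```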
