(8๊ฐ•) Training & Inference 2

210826

Training Process

model.train()

model.train(True) ์™€ ๋™์ผํ•˜๋‹ค. ๋ชจ๋ธ์„ ํ•™์Šตํ•˜๊ฒ ๋‹ค๋ผ๋Š” ๋œป. ์ด ๋ถ€๋ถ„์„ ํ•ด์ฃผ์ง€ ์•Š์œผ๋ฉด ํ•™์Šตํ•  ๋•Œ ์žˆ์–ด์„œ Batch Normalization๊ณผ Dropout์ด ์ ์šฉ๋˜์ง€ ์•Š๋Š”๋‹ค. ์ด๋Ÿฌํ•œ ํ…Œํฌ๋‹‰์€ CNN์— ์žˆ์–ด์„œ ๊ฑฐ์˜ ํ•„์ˆ˜์ ์œผ๋กœ ํฌํ•จ๋˜๋Š” ๊ธฐ์ˆ ์ด๊ธฐ ๋•Œ๋ฌธ์— ์ด๊ฒƒ์„ ์ ์šฉํ•˜๋ ค๋ฉด ์ด ๋ถ€๋ถ„์„ ์„ ์–ธํ•ด์ค˜์•ผ ํ•œ๋‹ค.

optimizer.zero_grad()

์ด๋ฅผ ํ•ด์ฃผ์ง€ ์•Š์œผ๋ฉด ์ด์ „ ๋ฐฐ์น˜์—์„œ ๊ณ„์‚ฐ๋œ grad์— ํ˜„์žฌ ๋ฐฐ์น˜์—์„œ ๊ณ„์‚ฐ๋œ gread๊ฐ€ ๋”ํ•ด์ง€๊ฒŒ๋œ๋‹ค.

loss = criterion(outpus, labels)

loss๋ฅผ ์ปจํŠธ๋กค ํ•˜๋ฉด์„œ ์ „์ฒด์ ์ธ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์ปจํŠธ๋กคํ•  ์ˆ˜ ์žˆ๊ฒŒ๋œ๋‹ค.

๊ณ„์†์ ์œผ๋กœ next_function์ด ์—ฐ๊ฒฐ๋˜์–ด ์ฒด์ธ๊ตฌ์„ฑ์œผ๋กœ ์ด๋ฃจ์–ด์ ธ์žˆ๋Š” ๋ชจ์Šต. loss.backward()๊ฐ€ ์ด๋ฃจ์–ด์ง€๋ฉด ๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ์˜ grad ๊ฐ’์ด ๊ฐฑ์‹ ์ด ๋œ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์‹ค์งˆ์ ์œผ๋กœ ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ๊ฐฑ์‹ ๋œ ๊ฒƒ์€ ์•„๋‹ˆ๋‹ค. ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๊ฐฑ์‹ ํ•˜๋ ค๋ฉด ๋‹ค์Œ์„ ์‹คํ–‰ํ•ด์•ผํ•œ๋‹ค.

optimizer.step()

์ด ํ•จ์ˆ˜๋ฅผ ๊ฑฐ์น˜๋ฉด, ํŒŒ๋ผ๋ฏธํ„ฐ์˜ grad๊ฐ’์„ ๊ฐ€์ง€๊ณ  ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ’์„ ๊ฐฑ์‹ ํ•˜๊ฒŒ๋œ๋‹ค.

Gradient Accumulation

๋งŒ์•ฝ์— ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ๋ฅผ ํฌ๊ฒŒ ํ•˜๊ณ  ์‹ถ์€๋ฐ, GPU ๋ฆฌ์†Œ์Šค๊ฐ€ ๋ถ€์กฑํ•˜๋‹ค๋ฉด ์–ด๋–ป๊ฒŒ ํ• ๊นŒ? ์–ด์ฉ” ์ˆ˜ ์—†์ด ์ž‘์€ ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ๋กœ ๋Œ๋ฆด ๊ฒƒ์ด๋‹ค. ์ด ๋•Œ ๋งค๋ฒˆ ์ž‘๊ฒŒ ์„ค์ •ํ•œ ๋ฐฐ์น˜๋งˆ๋‹ค optimizer.step() ์ด ์ด๋ฃจ์–ด์ง€๊ฒŒ ๋˜๋Š”๋ฐ,

Inference Process

model.eval()

model.train(False) ์™€ ๋™์ผํ•˜๋‹ค. ๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•˜๊ฒ ๋‹ค๋Š” ๋œป. ์—ฌ๊ธฐ์„œ๋Š” ๋“œ๋ž์•„์›ƒ์ด๋‚˜ ๋ฐฐ์น˜์ •๊ทœํ™” ๊ธฐ๋Šฅ์ด ๊บผ์ง€๊ฒŒ๋œ๋‹ค.

with torch.no_grad()

์ด ์˜์—ญ์—์„œ๋ถ€ํ„ฐ ๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ์˜ grad๋Š” False๊ฐ’์„ ๊ฐ€์ง„๋‹ค๋Š” ๊ฒƒ์ด๋‹ค.

Validation

์ถ”๋ก  ๊ณผ์ •์— Validation ์…‹์ด ๋“ค์–ด๊ฐ€๋ฉด ์ด๊ฒƒ์ด ๊ฒ€์ฆ๊ณผ์ •์ด๋‹ค. Test ์…‹๊ณผ ํฐ ์ฐจ์ด์ ์€ ์—†๋‹ค.

Checkpoint

๋ณดํ†ต Validation์…‹์˜ ์„ฑ๋Šฅ์„ ๋ณด๊ณ  ๋ชจ๋ธ์„ ์ €์žฅํ• ์ง€ ๋ง์ง€ ๊ฒฐ์ •ํ•˜๊ฒŒ ๋œ๋‹ค. ์™œ๋ƒํ•˜๋ฉด Validation ๋ฐ์ดํ„ฐ๋Š” Model์— feed๋˜์ง€ ์•Š์•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.

Last updated

Was this helpful?