(8๊ฐ) Training & Inference 2
210826
Last updated
Was this helpful?
210826
Last updated
Was this helpful?
model.train(True)
์ ๋์ผํ๋ค. ๋ชจ๋ธ์ ํ์ตํ๊ฒ ๋ค๋ผ๋ ๋ป. ์ด ๋ถ๋ถ์ ํด์ฃผ์ง ์์ผ๋ฉด ํ์ตํ ๋ ์์ด์ Batch Normalization๊ณผ Dropout์ด ์ ์ฉ๋์ง ์๋๋ค. ์ด๋ฌํ ํ
ํฌ๋์ CNN์ ์์ด์ ๊ฑฐ์ ํ์์ ์ผ๋ก ํฌํจ๋๋ ๊ธฐ์ ์ด๊ธฐ ๋๋ฌธ์ ์ด๊ฒ์ ์ ์ฉํ๋ ค๋ฉด ์ด ๋ถ๋ถ์ ์ ์ธํด์ค์ผ ํ๋ค.
์ด๋ฅผ ํด์ฃผ์ง ์์ผ๋ฉด ์ด์ ๋ฐฐ์น์์ ๊ณ์ฐ๋ grad์ ํ์ฌ ๋ฐฐ์น์์ ๊ณ์ฐ๋ gread๊ฐ ๋ํด์ง๊ฒ๋๋ค.
loss๋ฅผ ์ปจํธ๋กค ํ๋ฉด์ ์ ์ฒด์ ์ธ ํ๋ผ๋ฏธํฐ๋ฅผ ์ปจํธ๋กคํ ์ ์๊ฒ๋๋ค.
๊ณ์์ ์ผ๋ก next_function์ด ์ฐ๊ฒฐ๋์ด ์ฒด์ธ๊ตฌ์ฑ์ผ๋ก ์ด๋ฃจ์ด์ ธ์๋ ๋ชจ์ต. loss.backward()๊ฐ ์ด๋ฃจ์ด์ง๋ฉด ๋ชจ๋ ํ๋ผ๋ฏธํฐ์ grad ๊ฐ์ด ๊ฐฑ์ ์ด ๋๋ค. ๊ทธ๋ฌ๋ ์ค์ง์ ์ผ๋ก ํ๋ผ๋ฏธํฐ๊ฐ ๊ฐฑ์ ๋ ๊ฒ์ ์๋๋ค. ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐฑ์ ํ๋ ค๋ฉด ๋ค์์ ์คํํด์ผํ๋ค.
์ด ํจ์๋ฅผ ๊ฑฐ์น๋ฉด, ํ๋ผ๋ฏธํฐ์ grad๊ฐ์ ๊ฐ์ง๊ณ ํ๋ผ๋ฏธํฐ๊ฐ์ ๊ฐฑ์ ํ๊ฒ๋๋ค.
๋ง์ฝ์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ํฌ๊ฒ ํ๊ณ ์ถ์๋ฐ, GPU ๋ฆฌ์์ค๊ฐ ๋ถ์กฑํ๋ค๋ฉด ์ด๋ป๊ฒ ํ ๊น? ์ด์ฉ ์ ์์ด ์์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ก ๋๋ฆด ๊ฒ์ด๋ค. ์ด ๋ ๋งค๋ฒ ์๊ฒ ์ค์ ํ ๋ฐฐ์น๋ง๋ค optimizer.step() ์ด ์ด๋ฃจ์ด์ง๊ฒ ๋๋๋ฐ,
model.train(False)
์ ๋์ผํ๋ค. ๋ชจ๋ธ์ ํ๊ฐํ๊ฒ ๋ค๋ ๋ป. ์ฌ๊ธฐ์๋ ๋๋์์์ด๋ ๋ฐฐ์น์ ๊ทํ ๊ธฐ๋ฅ์ด ๊บผ์ง๊ฒ๋๋ค.
์ด ์์ญ์์๋ถํฐ ๋ชจ๋ ํ๋ผ๋ฏธํฐ์ grad๋ False๊ฐ์ ๊ฐ์ง๋ค๋ ๊ฒ์ด๋ค.
์ถ๋ก ๊ณผ์ ์ Validation ์ ์ด ๋ค์ด๊ฐ๋ฉด ์ด๊ฒ์ด ๊ฒ์ฆ๊ณผ์ ์ด๋ค. Test ์ ๊ณผ ํฐ ์ฐจ์ด์ ์ ์๋ค.
๋ณดํต Validation์ ์ ์ฑ๋ฅ์ ๋ณด๊ณ ๋ชจ๋ธ์ ์ ์ฅํ ์ง ๋ง์ง ๊ฒฐ์ ํ๊ฒ ๋๋ค. ์๋ํ๋ฉด Validation ๋ฐ์ดํฐ๋ Model์ feed๋์ง ์์๊ธฐ ๋๋ฌธ์ด๋ค.