check point file
# save checkpoint
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}, "./checkpoint.tar")
# load checkpoint
checkpoint = torch.load("./checkpoint.tar")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
Plain Text
๋ณต์ฌ
optimizer
์ตํฐ๋ง์ด์ ๋ ํ์ต ๋ฐ์ดํฐ(Train data)์
์ ์ด์ฉํ์ฌ ๋ชจ๋ธ์ ํ์ต ํ ๋ ๋ฐ์ดํฐ์ ์ค์ ๊ฒฐ๊ณผ์ ๋ชจ๋ธ์ด ์์ธกํ ๊ฒฐ๊ณผ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ ์ค์ผ ์ ์๊ฒ ๋ง๋ค์ด์ฃผ๋ ์ญํ ์ ํ๋ค.
๊ฒฐ๋ก ๋ถํฐ ์ค๋ช
ํ์๋ฉด ํ์ฌ ๊ฐ์ฅ ๋ง์ด ์ฌ์ฉํ๋ ์ตํฐ๋ง์ด์ ๋ Adam์ด๋ค. ํ์ฑํ ํจ์(Activation Function)๋ก Relu๋ฅผ ๊ฐ์ฅ ๋ง์ด ํ์ฉํ๋ ๊ฒ์ฒ๋ผ, Adam์ด ๋น ๋ฅด๊ธฐ๋ ํ๋ฉด์ ์ฑ๋ฅ๋ ์ข๊ณ ๋ฌด๋ํ๋ค. ํ์ง๋ง ๋ค๋ฅธ ์ตํฐ๋ง์ด์ ๋ฅผ ๋ฌด์กฐ๊ฑด ์ฐ์ง ๋ง๋ผ๋ ๋ฒ์ ์๋ค. ํ๋ก์ ํธ์ ๋ฐ์ดํฐ ๋ณ๋ก ๋ฏธ๋ฌํ๊ฒ ๋ค๋ฅธ ์ตํฐ๋ง์ด์ ๊ฐ ๋ ์ข์์๋ ์๊ณ , ํน์ ํ์ต ์๋์ ๋ฌธ์ ๋ฑ์ผ๋ก ์ฝ๊ฐ์ ์ฑ๋ฅ์ ํฌ๊ธฐํ ์ ์๋ค.