Chapter 4 / Deep Neural Net
Multi-Class Classification
Fasion MNIST ๋ฐ์ดํฐ์ , Deep Neural Net์ ์ด์ฉํด ํจ์ ์์ดํ (class=10)์ ๋ถ๋ฅํ๋ ๋ชจ๋ธ์ ๋ง๋ญ๋๋ค.
Dataset
Fasion MNIST ๋ฐ์ดํฐ์ ์ฌ์ฉ
28x28 ํฝ์ ์ 70000๊ฐ์ ํ๋ฐฑ ์ด๋ฏธ์ง๋ก ํจ์ ์์ดํ ์ 10๊ฐ์ง ์นดํ ๊ณ ๋ฆฌ๋ก ๋๋ ๋ฐ์ดํฐ์
MNIST๊ฐ ์ซ์ ๋ฐ์ดํฐ๋ฅผ 28x28, 10๊ฐ๋ก ๋๋ ๊ฒ์ฒ๋ผ ๋น์ทํ ํํ, torchvision์์ ๋ฐ์ดํฐ์ ์ ๋ก๋
torchvision, utils
transform = transforms.Compose([transform๋ฆฌ์คํธ])
๋ก ๋ณํ๊ธฐ๋ฅผ ๋ง๋ค๊ณ , torchvision.datasets
์์ ๋ฐ์ดํฐ์
์ ๋ก๋ํ ๋ transform ํค์๋์ transform์ ํ ๋นํด์ฃผ๋ฉด ๋ฐ์ดํฐ์
์ transform ๋ฆฌ์คํธ ์์๋๋ก ๋ณํ
์ฌ์ฉ๋ ๊ตฌ๋ฌธ๋ค
๊ตฌ๋ฌธ
์ค๋ช
transform.Compose([])
transform ์ค๋ธ์ ํธ๋ฅผ ๋ง๋ค์ด ์ผ๊ด์ ์ผ๋ก ๋ณํํ๋ ๋ณํ๊ธฐ ๋ง๋ค๊ธฐ
datasets.FasionMNIST()
fasion MNIST ๋ถ๋ฌ์ค๊ธฐ
data.DataLoader()
dataset ๊ฐ์ฒด๋ฅผ ๋ค์ํ ์ต์ ์ผ๋ก ๋ถ๋ฌ์ ๋ฐ๋ณต ๊ฐ๋ฅํ ๊ฐ์ฒด๋ก ๋ง๋ค๊ธฐ
Model
๋ค์ ์ฝ๋๋ก CUDA ํ์ธ
๋ชจ๋ธ์ ๊ตฌ์กฐ๋ ๋ค์๊ณผ ๊ฐ์ด ์ ์
Layer
dimension
view
28*28*1 -> 784
Linear
784 -> 256
ReLU
Linear
256 -> 128
ReLU
Linear
128 -.> 10
model.to(DEVICE)
๋ฅผ ํตํด CUDA๋ฅผ ์ฌ์ฉํ ์ ์์
torch.nn.Module
๋ก๋ถํฐ ์๋ธํด๋์ฑํด ๋ง๋ ๋ชจ๋ธ์ด๋ค. nn.Module์ ์ฌ์ฉํ๋ฉด ์ง๊ด์ ์ผ๋ก ์ปค์คํ
๋ชจ๋ธ์ ๋ง๋ค ์ ์๋ค.
Optimization
ํญ๋ชฉ
๊ฐ
Epochs
30
Batch Size
64
Loss Function
Cross Entropy
Optimizer
SGD
Learning Rate
0.01
๋ค์ ์ฝ๋๋ train, evaluate ๋ถ๋ถ์๋ค.
ํ์ต ๊ณผ์ ์ ๋ค์๊ณผ ๊ฐ๋ค.
data์ target์ GPU(CPU)๋ก ์ด๋
optimizer์ gradient๋ฅผ ์ด๊ธฐํ
data์ ๋ํ model์ ์ถ๋ ฅ ์ฐ์ฐ
output์ ๋ํ target๊ณผ์ loss ๊ณ์ฐ
loss๋ฅผ ์ญ๋ฏธ๋ถํด gradient ํ ๋น
optimizer์ ํ ๋น๋ parameter๋ค์ ๊ณ์ฐ๋ gradient๋ฅผ ์ ์ฉ
ํ๊ฐ ๋ฐฉ๋ฒ์ ๋ค์๊ณผ ๊ฐ๋ค.
data์ target์ GPU(CPU)๋ก ์ด๋
data์ ๋ํ model์ ์ถ๋ ฅ ์ฐ์ฐ
์ถ๋ ฅ๊ณผ target์ loss ๊ณ์ฐ, ๋ชจ๋ batch์ ๋ํด ํฉํจ
์ถ๋ ฅ์ argmax๋ฅผ ๊ตฌํด ์ผ๋ง๋ ๋ง์ด ์ ๋ต์ ๋งํ๋์ง ์ฐ์ฐ
Data Augmentation
์ด๋ฏธ์ง๋ฅผ ๋ฌด์์๋ก ๋ค์ง์ด ๋ฐ์ดํฐ์ ์ ํฌ๊ธฐ๋ฅผ ๋๋ ค ํ์ต์ ๋ ์ ํ ์ ์๊ฒ ๋ง๋ ๋ค. ๋ณธ ์ฑํฐ์์ RandomHorizontalFlip์ transform์ ์ถ๊ฐํด ๋ฐ์ดํฐ๋ฅผ ๋๋ฆฐ๋ค.
Dropout
๊ณผ์ ํฉ์ ์ ์ ๋ฐ์ดํฐ์ ๋ํด ์์ธก์ด ํ์ต ์ค์ฐจ๋ฅผ ์ค์ด๋ ๋ฐ ๊ณผํ๊ฒ ํ์ตํ๊ณ ์ค์ ์ผ๋ฐ์ ๋ฐ์ดํฐ์ ๋ํด์ ์ผ๋ฐํํ์ง ๋ชปํ๋ ๊ฒฝ์ฐ๋ฅผ ๋งํ๋ค. train loss๋ ๊ณ์ ์ค์ด๋ค์ง๋ง, validation loss๋ ์ฆ๊ฐํ๋ ์์ ์ด ์๋๋ฐ, ์ด ์์ ์์ ํ์ต์ ์ข ๋ฃํด์ผ ์ ์ ํ ์์ธก์ ์ป์ ์ ์๋ค.
dropout์ ์ ๊ฒฝ๋ง์์ ๋ค์ layer๋ก ์ด๋ํ ๋ ์ผ์ ํ๋ฅ ๋ก node๊ฐ ์๋ ๊ฒ์ฒ๋ผ ์ด๋ํ๋ค. ๊ณ์ฐํ ๊ฒฐ๊ณผ์์ ์ผ๋ถ node์ ๊ฒฐ๊ณผ๋ฅผ ์ผ์ ํ๋ฅ ๋ก ์ ๊ฑฐํด ๊ณผ์ ํฉ์ ๋ฐฉ์งํ๋ค.
๊ฐ๋จํ dropout ํจ์๋ฅผ ๊ฑฐ์นจ์ผ๋ก์จ ์ฌ์ฉํ ์ ์๋ค.
Last updated
Was this helpful?