🍕
AI Paper Study
  • AI Paper Study
  • Computer Vision
    • SRCNN(2015)
      • Introduction
      • CNN for SR
      • Experiment
      • 구현해보기
    • DnCNN(2016)
      • Introduction
      • Related Work
      • DnCNN Model
      • Experiment
      • 구현해보기
    • CycleGAN(2017)
      • Introduction
      • Formulation
      • Results
      • 구현해보기
  • Language Computation
    • Attention is All You Need(2017)
      • Introduction & Background
      • Model Architecture
      • Appendix - Positional Encoding 거리 증명
  • ML Statistics
    • VAE(2013)
      • Introduction
      • Problem Setting
      • Method
      • Variational Auto-Encoder
      • 구현해보기
      • Appendix - KL Divergence 적분
  • 직관적 이해
    • Seq2Seq
      • Ko-En Translation
Powered by GitBook
On this page
  • Model
  • Dataset
  • Result

Was this helpful?

  1. Computer Vision
  2. SRCNN(2015)

구현해보기

PyTorch

전체 코드와 결과는 여기에 저장되어 있다.

Model

class SRCNN(nn.Module):
    def __init__(self,f1=9,f2=5,f3=5,n1=64,n2=32):
        super(SRCNN,self).__init__()
        self.conv1 = nn.Conv2d(3,n1,f1,padding=f1//2)
        self.conv2 = nn.Conv2d(n1,n2,f2,padding=f2//2)
        self.conv3 = nn.Conv2d(n2,3,f3,bias=False,padding=f3//2)

    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x

간단한 Model이므로 forward propagation을 쉽게 구현할 수 있다.

Dataset

#학습 데이터셋
class TrainDataset(Dataset):
    def __init__(self,path):
        self.paths = glob.glob(path)
        self.trans = transforms.Compose([transforms.Resize((11,11)),
                            transforms.Resize((33,33), interpolation=InterpolationMode.BICUBIC),
                            transforms.ToTensor()])

    def __getitem__(self, index):
        x = Image.open(self.paths[index])
        y = Image.open(self.paths[index])
        x = self.trans(x)
        y = transforms.ToTensor()(y)
        return x,y

    def __len__(self):
        return len(self.paths)

#테스트 데이터셋
class TestDataset(Dataset):
    def __init__(self,path):
        self.paths = glob.glob(path)

    def __getitem__(self, index):
        x = Image.open(self.paths[index])
        y = Image.open(self.paths[index])
        w = x.width//3 *3
        h = x.height//3 *3

        x = x.resize([w//3,h//3],resample=Image.BICUBIC)
        x = x.resize([w,h],resample=Image.BICUBIC)
        x = transforms.ToTensor()(x)
        y = transforms.ToTensor()(y.resize([w,h],resample=Image.BICUBIC))
        return x,y

    def __len__(self):
        return len(self.paths)

Train Dataset은 미리 crop해둔 이미지를 불러와, Resize한다.

Test Dataset은 Set5 Dataset을 한 장씩 불러와, Scale Factor=3으로 Resize 가능하도록 이미지를 조정한다.

Result

9-5-5, n1=64, n2=32의 SRCNN (3배율)으로 91 Images에서 batch size=16으로 100회 학습했다.

RGB 3 channel에서 모두 학습해 parameter의 수는 약 20000개이다.

Set5의 이미지 중 하나를 실행해 보았다. Bicubic에 비해 다른 부분 사이의 경계가 더욱 sharp해진 것을 확인할 수 있다. 원본 논문에서도, Edge detection에 해당하는 filter가 나타났다고 한다.

MSE loss와 PSNR은 구현 결과다음과 같이 보여진다. Bicubic(29.21 dB)보다 높은 PSNR을 보인다. 최고 PSNR은 30.25 dB이다.

PreviousExperimentNextDnCNN(2016)

Last updated 3 years ago

Was this helpful?

첫 두 layer는 10−410^{-4}10−4, 마지막 layer는 10−510^{-5}10−5의 learning rate로 학습하며, 논문과는 다르게 Adam Optimizer를 사용한다.

Original Image
Bicubic x3
SRCNN(9-5-5) x3
loss
PSNR