GAN (Generative Adversarial Networks)

- 저자: 이안 굿펠로우 (이름이 특이.. 좋은친구)

  • 천재임.
  • 지도교수가 요수아 벤지오.

- ref

- 최근 10년간 머신러닝 분야에서 가장 혁신적인 아이디어이다. (얀 르쿤)

- 무슨내용? 생성모형

생성모형이란? (쉬운설명)

만들 수 없다면 이해하지 못한 것이다, 리처드 파인만 (천재물리학자)

- 사진속에 들어 있는 동물이 개인지 고양이 인지 맞추는 기계와, 개와 고양이 그림을 그려주는 기계중 어떤것이 더 시각정보에 대한 이해가 높다고 볼 수 있을까?

- 진정으로 인공지능이 이미지에 대한 이해를 했다면 이미지를 만들수도 있어야 한다 $\to$ 이미지를 생성하는 모형을 고려하자 $\to$ 성공

GAN의 응용분야

- 실제사진 + 반고흐의 화풍

- 1920년대 서울의 모습이 칼라로 복원된다면? 퀸의 라이브에이드가 4k로 복원된다면?

- 딥페이크: 유명배우들의 가짜 포르노, 가짜뉴스, 협박(거짓기소)

- 게임영상 (파이널판타지)

- 추억의 거북이(가수) 소환

- 너무 많아요..

생성모형?

제한된 정보만으로 어떤 문제를 풀 때, 그 과정에서 원래의 문제보다 일반적인 문제를 풀지말고, 가능한 원래의 문제를 직접 풀어야 한다. 배프닉 (SVM 창시자)

- 이미지를 ${\boldsymbol x}$라고 하고 라벨을 $y$라고 하자.

- 이미지를 보고 라벨을 맞추는 일은 $p(y|\boldsymbol{x})$에 관심이 있다.

- 이미지를 생성하는 일은 $p(\boldsymbol{x}, y)$에 관심이 있다.

- 데이터의 생성확률은 $p(\boldsymbol{x}, y)$을 알면 클래스의 사후확률 $p(y|\boldsymbol{x})$를 알 수 있음. 하지만 역은 불가능.

$$p(y|\boldsymbol{x}) = \frac{p(\boldsymbol{x},y)}{p(\boldsymbol{x})} = \frac{p(\boldsymbol{x},y)}{\sum_y p(\boldsymbol{x},y)}$$

  • 즉 이미지를 생성하는 일은 이미지를 분류하는 문제보다 더 어려운 일이라 해석가능

- 따라서 배프닉의 원리에 따르면 식별을 하고 싶다면 생성모형이 그렇게 매력적이지 않음.

- 하지만 다양한 현실문제에서 생성모형이 유용할 때가 많이 있음.

GAN의 원리

- GAN은 생성모형중 하나임

- GAN의 원리: 경찰과 위조지폐범이 서로 선의의(?) 경쟁을 통하여 서로 발전

  • 아래는 위에서 언급한 이안굿펠로우의 GAN 논문의 일부

The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistiguishable from the genuine articles

- 서로 적대적인(adversarial) 네트워크들(networks)를 동시에 학습시켜 가짜이미지를 만든다(generate)

- 무식한 상황극..

위조범:가짜돈을 만들어서 부자가 되어야지!! (가짜돈을 그림)
경찰: (위조범이 만든 돈을 보고) 이건 가짜다!!
위조범: 걸렸군.. 더 정교하게 만들어서 속여야겠다.
경찰: 이게 진짜가?... --> 상사한테 혼남 (그것도 구분못하나?!) --> (실력이 업그레이드되어서) 이건 가짜다!!
위조범: 더 더.. 정교하게 만들자.
경찰: 판별능력 더 업그레이드..
반복..

- 굉장히 우수한 경찰조차도 진짜와 가짜를 구분하지 못할때 (= 진짜 이미지를 0.5의 확률로만 진짜라고 말할때 = 가짜이미지를 0.5의 확률로만 가짜라고 말할때) 학습을 멈춘다.

GAN의 구현

import

import torch 
from fastai.vision.all import *

data

path = untar_data(URLs.MNIST_SAMPLE)
(path/'train').ls()
(#2) [Path('/home/csy/.fastai/data/mnist_sample/train/7'),Path('/home/csy/.fastai/data/mnist_sample/train/3')]
threes = (path/'train'/'3').ls()
X = torch.stack([tensor(Image.open(i)) for i in threes]).float()/255
X.shape
torch.Size([6131, 28, 28])
plt.imshow(X[0])
<matplotlib.image.AxesImage at 0x7fd53199b790>

- MLP를 이용해 학습하기 위해 X의 차원을 변경

X=X.reshape(6131,28*28)
X.shape
torch.Size([6131, 784])

위조지폐범의 설계: noise $\to$ 가짜이미지를 만들어 내는 네트워크를 만들자.

- 네트워크의 입력? 적당한 벡터, 혹은 매트릭스에 노이즈(랜덤으로 채운 어떠한 숫자들)를 채운것

- 네트워크의 출력? (28,28)의 텐서 혹은 784의 벡터

net1 = torch.nn.Sequential(torch.nn.Linear(in_features=28, out_features=64),
                           torch.nn.ReLU(),
                           torch.nn.Linear(in_features=64, out_features=64), 
                           torch.nn.ReLU(),
                           torch.nn.Linear(in_features=64, out_features=784),
                           torch.nn.Sigmoid()) ## 마지막의 시그모이드는 출력이 0~1사이로 나오게 하기 위함 
counterfeiter = net1 

경찰의 설계: 진짜이미지는 1, 가짜이미지는 0으로 판별하는 DNN을 만들자.

- 네트워크의 입력? (28,28)의 텐서, 혹은 784의 벡터

- 네트워크의 출력? yhat (y는 0 or 1)

net2 = torch.nn.Sequential(torch.nn.Linear(in_features=784,out_features=64),
                           torch.nn.ReLU(),
                           torch.nn.Linear(in_features=64,out_features=28),
                           torch.nn.ReLU(),
                           torch.nn.Linear(in_features=28,out_features=1),
                           torch.nn.Sigmoid())
police = net2 

스토리전개

- 아래는 진짜이미지

realimage=X[0].reshape(28,28)
plt.imshow(realimage)
<matplotlib.image.AxesImage at 0x7fd53197f130>

- 위와 같은 진짜 이미지를 경찰이 봤음 $\to$ yhat이 나오겠죠?

policehat_from_realimage = police(realimage.reshape(-1))
policehat_from_realimage
tensor([0.4959], grad_fn=<SigmoidBackward0>)
  • 진짜 이미지일수록 policehat_from_realimage $\approx$ 1 이어야 함
  • 하지만 그렇지 못함 (배운것이 없는 무능한 경찰)

- 이번에는 가짜이미지를 경찰이 봤다고 생각해보자.

(step1) 랜덤으로 아무숫자나 28개를 생성한다.

err= torch.randn(28)
err
tensor([ 1.4210, -0.8841, -2.0476, -1.3590,  0.1983,  1.8354,  0.0223, -0.0560,
        -1.0690, -0.4255,  2.1154, -0.9462,  0.1073, -0.6490,  1.0822,  1.5894,
        -1.0199,  1.2513,  1.5294, -1.4069,  0.7065,  0.9657,  0.6484,  1.8417,
         1.2033,  0.7760, -2.2023, -0.4770])

(step2) 위조범은 err를 입력으로 받고 가짜이미지를 만든다.

couterfeiter_output는 network의 출력이다보니 미분에 대한 정보가 있어 fakeimage에도 미분 정보가 포함되어 있다. detach를 통해 제거해주자

couterfeiter_output = counterfeiter(err)
fakeimage=couterfeiter_output.reshape(28,28)
plt.imshow(fakeimage.detach())
<matplotlib.image.AxesImage at 0x7fd5318e5a00>
  • 누가봐도 가짜자료임
  • 위조범의 실력이 형편없음

(step3) 위조범이 생성한 이미지를 경찰한테 넘긴다.

아래는 shape을 바꾸냐 안 바꾸냐의 차이일 뿐

policehat_from_fakeimage = police(couterfeiter_output)
#policehat_from_fakeimage = police(fakeimage.detach().reshape(-1))
policehat_from_fakeimage
tensor([0.5085], grad_fn=<SigmoidBackward0>)

- 경찰의 실력도 형편없고 위조범의 실력도 형편없다.

경찰네트워크의 실력을 향상시키자.

- 데이터 정리

  • 원래 $n=6131$개의 이미지 자료가 있음. 이를 ${\bf X}$라고 하자. 따라서 ${\bf X}$의 차원은 (6131,784).
  • 위조범이 만든 가짜자료를 원래 자료와 같은 숫자인 6131개 만듬. 이 가짜자료를 $\tilde{\bf X}$라고 하자. 따라서 $\tilde{\bf X}$의 차원은 (6131,784).
  • 진짜자료는 1, 가짜자료는 0으로 라벨링.
X.shape
torch.Size([6131, 784])
err= torch.randn(6131,28)
counterfeiter_output = counterfeiter(err) # counterfeiter_output를 Xtilde로 생각하면 된다. 
real_label=torch.tensor([[1.0]]*6131) ## y=1 
fake_label=torch.tensor([[0.0]]*6131) ## y=0

- step1: yhat, 경찰의 예측

policehat_from_realimage = police(X) 
policehat_from_fakeimage = police(counterfeiter_output) 

- step2: 손실함수? 경찰의 미덕은 (1) 가짜를 가짜라고 하고 (2) 진짜를 진짜라 해야한다.

sigmoid는 앞에서 추가해준 단계니까 빼고

loss_fn = torch.nn.BCELoss() 
loss_of_police =\
loss_fn(policehat_from_fakeimage,fake_label)+\
loss_fn(policehat_from_realimage,real_label)

loss_of_police
tensor(1.4011, grad_fn=<AddBackward0>)

- step3~4는 미분이후 업데이트

- 옵티마이저를 설계하자.

optimizer_of_police = torch.optim.Adam(police.parameters())

- for 문을 돌리자.

for i in range(50): 
    ## 1 yhat 
    policehat_from_realimage = police(X) 
    
    #policehat_from_fakeimage = police(Xitlde)
    err= torch.randn(6131,28)
    counterfeiter_output = counterfeiter(err) # counterfeiter_output를 Xtilde로 생각하면 된다. 
    policehat_from_fakeimage= police(counterfeiter_output)
    
    ## 2 loss 
    loss_of_police =\
    loss_fn(policehat_from_fakeimage,fake_label)+\
    loss_fn(policehat_from_realimage,real_label)
    
    ## 3 back propagation 
    loss_of_police.backward()
    
    ## 4 update
    optimizer_of_police.step()
    optimizer_of_police.zero_grad()

- 훈련된 경찰의 성능을 살펴보자.

police(counterfeiter_output)
tensor([[0.0009],
        [0.0009],
        [0.0009],
        ...,
        [0.0009],
        [0.0009],
        [0.0008]], grad_fn=<SigmoidBackward0>)
police(X)
tensor([[0.9998],
        [0.9995],
        [0.9997],
        ...,
        [0.9996],
        [0.9932],
        [0.9946]], grad_fn=<SigmoidBackward0>)

- 우수한 경찰 (비록 위조범의 수준이 낮긴하지만)

위조범네트워크의 성능을 향상시키자.

- 자료구조: X는 임의의 에러이미지, net(X)는 fakeimage

err=torch.randn(6131,28) 
counterfeiter_output= counterfeiter(err) 

- 손실함수: 잘 훈련된 경찰조차도 잘못된 판단을 내릴만큼 가짜지폐를 잘 만들면 위조범의 실력이 우수하다 볼 수 있음

policehat_from_fakeimage = police(counterfeiter_output) 
loss_of_counterfeiter = loss_fn(policehat_from_fakeimage,real_label) ## 가짜이미지를 보고 경찰이 진짜라고 믿으면 위조범의 실력이 좋은것임  

- 옵티마이저

optimizer_of_counterfeiter = torch.optim.Adam(counterfeiter.parameters())

- 학습

for i in range(50): 
    ## 1 
    err=torch.randn(6131,28) 
    counterfeiter_output= counterfeiter(err)  
    policehat_from_fakeimage = police(counterfeiter_output) 
    ## 2 
    loss_of_counterfeiter = loss_fn(policehat_from_fakeimage,real_label)
    ## 3 
    loss_of_counterfeiter.backward()
    ## 4 
    optimizer_of_counterfeiter.step()
    optimizer_of_counterfeiter.zero_grad()

- 위조범의 실력향상을 감상해보자.

plt.imshow(counterfeiter_output[0].reshape(28,28).data)
<matplotlib.image.AxesImage at 0x7fd531027e50>

두 적대적 네트워크를 경쟁시키자.

for k in range(100): 
    for i in range(50): 
        ## 1 yhat 
        policehat_from_realimage = police(X) 
    
        #policehat_from_fakeimage = police(Xitlde)
        err= torch.randn(6131,28)
        counterfeiter_output = counterfeiter(err) # counterfeiter_output를 Xtilde로 생각하면 된다. 
        policehat_from_fakeimage= police(counterfeiter_output)
    
        ## 2 loss 
        loss_of_police =\
        loss_fn(policehat_from_fakeimage,fake_label)+\
        loss_fn(policehat_from_realimage,real_label)
    
        ## 3 back propagation 
        loss_of_police.backward()
    
        ## 4 update
        optimizer_of_police.step()
        optimizer_of_police.zero_grad()
        
    for i in range(50): 
        ## 1 
        err=torch.randn(6131,28) 
        counterfeiter_output= counterfeiter(err)  
        policehat_from_fakeimage = police(counterfeiter_output) 
        ## 2 
        loss_of_counterfeiter = loss_fn(policehat_from_fakeimage,real_label)
        ## 3 
        loss_of_counterfeiter.backward()
        ## 4 
        optimizer_of_counterfeiter.step()
        optimizer_of_counterfeiter.zero_grad()        

- 위조범의 최종적 실력향상감상

plt.imshow(counterfeiter_output[0].reshape(28,28).data)
<matplotlib.image.AxesImage at 0x7fd530fa1ca0>
police(counterfeiter_output[0])
tensor([0.4974], grad_fn=<SigmoidBackward0>)