SYDLAB_inha/Machine Learning

ML - Generative Model(2) / GAN(adversarial networks) / generator / discriminator

빈그레 2023. 9. 18. 02:05

 

 

 


Generative Adversarial Networks

 

 

 

 

 

GAN

 

: generative model중에 크게 VAE와 GAN이 있는데 GAN이 압도적으로 더 많다. GAN 중 일부는 VAE의 좋은 특성을 갖고 있는 것들도 있다. 

 

GAN이란, generator(생성자)와 discriminator(판별자) 네트워크 간의 경쟁적인 학습을 기반으로 하며, 실제 데이터와 비슷한 가짜 데이터를 생성하는 모델이다. 

 

generator는 random latent vector(잠재 벡터) 에서 데이터를 생성하려 계속해서 업그레이드 되고,

discriminator는 생성된 데이터와 실제 데이터를 구별하기 위해 계속해서 업그레이드 됨으로써

이 둘이 경쟁 구도를 가지며 생성자가 점차적으로 실제와 더욱 유사한 데이터를 생성하도록 이끈다.

 

 

GAN에는 크게 두가지의 모듈이 있다. 

 

 

Generator Network 

: 말그대로 무언가를 생성하는 것

 

Discriminator Network

: 분별하는 것 (classifier)

 

 

 

Training GANs : Two-player game (생성자 vs 판별자)

 

 

Random한 latent vector인 z를 generator network를 통과시키면 이미지랑 똑같은 차원의 가짜 이미지를 만들 수 있다. generator network는 일종의 decoder역할을 한다.

 

초기 generator network의 초기화값은 loss function이 적용되기 전이기에 이미지와 크기는 똑같지만 값은 이상할 것이다.

 

discriminator는 두개의 real 혹은 fake 이 label만 존재하는 classfier라 볼 수 있다.  

 

 

 

Training GANs : Step by Stpe 
[ 세타d 학습 과정 // maximize term ]

Step 1
: real image가 들어오면 discriminator가 1이 되도록 학습을 시킨다.




Step 2



generator가 만든 fake image에 대해서는 discriminator가 0이라는 값으로 label을 예측할 수 있게끔 loss function을 걸어주는 것이다. 여기서 핵심은 back propagation은 discriminator까지만 가고 generator까지는 가지 않는다는 것이다.  따라서 step1,2는 discriminator에 대해서만 update 하고있는 과정이라고 할 수 있다.


[ 세타g 학습 과정 // minimize term ]

Step 3
:
generator의 parameter를 학습시키기 위해서, discriminator의 parmameter는 학습이 안되게끔 freeze시키고, back propagation으로 g가 update 되도록 한다. 즉!! generator가 더 real같은 data를 만들 수 있도록 하되, 그 과정이 discriminator 모르게 일어나서 update된 데이터에 대해서도 판별자가 다시 제대로 판별해 낼 수 있는지를 보며 경쟁구도로 두 module를 업데이트 시킨다.

 

 

 

 

 

Object function

 

loss function은 objective function의 한 형태이다. loss function은 값이 작을수록 좋은 상황인 것이다.

 

괄호 안 수식에 대해서 generator 입장에서는 그 값을 작게 하려고 하고, discriminator는 그 값을 커지게끔 하고싶어한다. 

minmax game이란, 한쪽에서는 maximise되는 것을 원하고 한쪽에서는 minimise되는 것을 원하는 게임이다.