๐VAE
์ด๋ฒ ํฌ์คํธ๋ ์์ฑ๋ชจ๋ธ์์ ์ ๋ช
ํ Variational Auto-Encoder(VAE)๋ฅผ ๋ค๋ฃจ๊ณ ์๋ Auto-Encoding Variational Bayes
๋ผ๋ ๋
ผ๋ฌธ ๋ฆฌ๋ทฐ์
๋๋ค. ์ด๋ฒ ํฌ์คํธ๋ฅผ ์ ๋ฆฌํ๋ฉด์ ๊ฐ์ฅ ๋ง์ด ์ธ์ฉํ๊ณ ๋์์ ๋ฐ์ ์คํ ์ธ์ฝ๋์ ๋ชจ๋ ๊ฒ๋ฅผ ๋ณด์๋ฉด ํจ์ฌ ๋ ์์ธํ๊ณ ๊น์ ์ดํด๋ฅผ ํ์ค ์ ์์ต๋๋ค. ํฌ์คํธ์ ์์๋ ์๋์ ๊ฐ์ด ์งํ๋ฉ๋๋ค.
1 Introduction
VAE๋ ์์ฑ๋ชจ๋ธ(Generative Model)์์ ์ ๋ช
ํ ๋ชจ๋ธ์
๋๋ค. ๊ทธ๋ ๋ค๋ฉด ์์ฑ ๋ชจ๋ธ์ด๋ ๋ฌด์์ ๋งํ๋ ๊ฑธ๊น์? ์๋ฅผ ๋ค์ด ์ฐ๋ฆฌ๊ฐ ์ฐ์ ์ ์ด ์๋ ๊ฐ์์ง ์ฌ์ง์ ๋ง๋ค์ด๋ด๊ณ ์ถ๋ค
๊ณ ํด๋ด
์๋ค. ๊ทธ๋ ์ง๋ง ๊ฐ์์ง ์ฌ์ง์ด ์ค์ ๊ฐ์์ง๋ค์ ์ฐ์ ์ฌ์ง๋ค๊ณผ ๋๋ฌด ๋๋จ์ด์ ธ์ ์ด์ง๊ฐ์ ๋๋ผ์ง ์์์ผ ํฉ๋๋ค. ์ด๋ฐ ๋งฅ๋ฝ์์ ์ฐ๋ฆฌ๊ฐ ์ํ๋ ๊ฒ์ train database์ ์๋ ์ฌ์ง๋ค, ์ฆ ์ค์ ๋ก ๊ฐ์์ง๋ค์ ์ฐ์ ์ฌ์ง๋ค์ ๋ถํฌ๋ฅผ ์๊ณ ์ถ์ต๋๋ค. ์ฌ๊ธฐ์ ๋ถํฌ๋ฅผ ์๊ณ ์ถ์ ์ด์ ๋ ์ฐ๋ฆฌ๊ฐ ๋ถํฌ(distribution)์ ์์์ผ ๋ถํฌ์์ data๋ฅผ ์ํ๋งํด์ ์์ฑํ ์ ์๊ธฐ ๋๋ฌธ์
๋๋ค. ๋ค์ ์ ๋ฆฌํ์๋ฉด, ํ์ฌ ๋ฐ์ดํฐ๋ค๊ณผ ๋น์ทํ ์๋ก์ด ๋ฐ์ดํฐ๋ฅผ ์์ฑ ํ๊ธฐ ์ํด ํ์ฌ train DB์ ๋ฐ์ดํฐ๋ค์ ๋ถํฌ p(x) ๋ฅผ ์๊ณ ์ถ์ต๋๋ค.
๋ฐ์ดํฐ x๋ฅผ ์์ฑํ๋ Generator๋ฅผ ์๋์ํฌ controller๊ฐ ํ์ํฉ๋๋ค. ์ด๋ค ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๋๋ก Generator๋ฅผ trigger
ํด์ฃผ๋ ๋ถ๋ถ์ด๊ธฐ ๋๋ฌธ์ ์ฐ๋ฆฌ๊ฐ ๋ค๋ฃจ๊ธฐ ์ฝ๊ฒ
๋ง๋ค์ด ์ค์ผ ์ดํ ์์ฑ๋ชจ๋ธ์ ์ฌ์ฉํ ๋ ํธ๋ฆฌํ ๊ฒ ์
๋๋ค. controller ์ญํ ์ ํ๋ latent variable z๋ p(z)์์ ์ํ๋ง๋๋ฉฐ ๋ฐ์ดํฐ x๋ณด๋ค ์ฐจ์์ด ์๊ฒ ๋ง๋ญ๋๋ค.
๋ค์ ๋ชฉํ์๋ p(x)๋ฅผ ์๊ฐํด๋ณด๋ฉด, prior probability
p(z)์ conditional probabability์ ๊ณฑ ์ ๋ถ์ผ๋ก ์๊ฐํด๋ณผ ์ ์์ต๋๋ค. ์ ๋ถ ์ด๋ ์ ๋ถ์ ๋จ์ํ samplingํ ์ฌ๋ฌ ๋ฐ์ดํฐ๋ค์ summationํด์ maximum likelihood estimation์ ๋ฐ๋ก ํ ์ ์์ง ์์๊น? ์๊ฐํ ์๋ ์์ง๋ง ์ด๋ ์ํ๋งํ๋ ๊ณผ์ ์์ ์ฐ๋ฆฌ๊ฐ ์ํ์ง ์๋ ์ํ๋ค์ด ๋ ๋ง์ด ๋ฝํ ์ ์๊ธฐ ๋๋ฌธ์ ์ด ๋ฐฉ๋ฒ์ ์ธ ์ ์์ต๋๋ค.
์ฐ๋ฆฌ๊ฐ ์ํ์ง ์๋ ์ํ๋ค์ด ๋ ๋ง์ด ๋ฝํ๋ ํ์์ ์์๋ฅผ ๋ค์ด ์ดํด๋ณด๊ฒ ์ต๋๋ค. MINST ๋ฐ์ดํฐ ์ค 2๋ฅผ ๋ํ๋ด๋ ์ด๋ฏธ์ง (a)๊ฐ ์๊ณ , (a)๋ฅผ ์ผ๋ถ ์ง์ด (b)์ (a)๋ฅผ ์ค๋ฅธ์ชฝ์ผ๋ก 1 pixel ๋งํผ ์ฎ๊ธด (c)๊ฐ ์์ต๋๋ค. ์ด๋ ์ฐ๋ฆฌ๋ (a)์ ์ ์ฌํ ๋ฐ์ดํฐ๋ฅผ ๋ ๋ง์ด ๋ฝ๊ณ ์ถ๊ณ , (b)๋ณด๋ค๋ (c)๊ฐ (a)์ ๋ ๊ฐ๊น๋ค๊ณ ์๊ฐํ๋ฏ๋ก (c)์ ๊ฐ์ ๋ฐ์ดํฐ๋ค์ด ๋ ๋ง์ด ๋ฝํ๊ธฐ๋ฆฌ ์ํฉ๋๋ค. ํ์ง๋ง ๋ณดํต Generator๊ฐ Normal distribution์ผ๋ก ๋์์ธ ๋๊ณ MSE ๊ฑฐ๋ฆฌ ๊ณ์ฐ์ ํตํด (a)์ ๋ ๊ฐ๊น์ด ๋ฐ์ดํฐ ์ํ๋ก Normal distribution์ ํ๊ท ์ ์ฎ๊ฒจ๊ฐ๋ค๊ณ ํ์ ๋, (c)๋ณด๋ค (b)๊ฐ (a)์์ MSE๊ฐ ์ ๊ธฐ ๋๋ฌธ์ (b)์ ๋น์ทํ ๊ฐ์ด ์ ๊ท๋ถํฌ์ ํ๊ท ์ด ๋๊ณ (b)์ ๋น์ทํ ์ํ๋ค์ด ๋ ๋ง์ด ๋์ค๊ฒ ๋ฉ๋๋ค. ํ์ง๋ง ๊ฒฐ๊ณผ์ ์ผ๋ก (c)๊ฐ (a)์ ๋น์ทํ ๊ฒ์ด ๋ ์ข์ ์ํ๋ง์ด ๋๋ ๊ฒ์ด๋ผ๊ณ ๋ณผ ์ ์์ต๋๋ค.
์ข์ sampling function์ด๋, ์์ ์์์์ ๋ณผ ์ ์์๋ฏ์ด train DB์ ์๋ data x์ ์ ์ฌํ ์ํ์ด ๋์ฌ ์ ์๋ ํ๋ฅ ๋ถํฌ๋ผ๊ณ ํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ๊ทธ๋ฅ ์ํ๋ง ํจ์๋ฅผ ๋ง๋ค๊ธฐ ๋ณด๋ค evidence๋ก x๋ฅผ given(์กฐ๊ฑด)์ผ๋ก ํ์ฌ z๋ฅผ ๋ฝ์๋ด๋ ํ๋ฅ ๋ถํฌ p(z\|x)๋ฅผ ๋ง๋ค์ด๋ด๋ ๊ฒ์ด ๋ชฉ์ ์ ๋๋ค. ํ์ง๋ง ์ฌ๊ธฐ์ ๋ ๋ฌธ์ ์ธ ์ ์ true distridution์ธ ํด๋น ๋ถํฌ๋ฅผ ๋ง๋ค์ด๋ด๊ธฐ์ํด Variational Inference ๋ฐฉ๋ฒ์ ์ด์ฉํฉ๋๋ค. ๋ถํฌ ์ถ์ ์ ์ํ family, ์๋ฅผ ๋ค๋ฉด guassian ๋ถํฌ๋ค์ Approximation Class๋ก ๋๊ณ true distribution์ ์ถ์ ํฉ๋๋ค. ์ด๋ gaussian ๋ถํฌ์ ํ๋ผ๋ฏธํฐ์ธ ฯ๋ mean๊ณผ std ๊ฐ์ด ๋ ๊ฒ ์ด๊ณ ์ด๋ฐ ์ฌ๋ฌ gaussian ๋ถํฌ๋ค๊ณผ true posterior ๊ฐ์ KL divergence๋ฅผ ๊ตฌํ์ฌ ์ถ์ ํด๊ฐ๋๋ค.
๋ฐ๋ผ์ ์ ๋ฆฌํด๋ณด๋ฉด ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ์์ฑ๋ชจ๋ธ์ธ Generator๋ฅผ ํ์ตํ๊ธฐ ์ํด Variational Inference ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๊ฒ ๋์๊ณ ๊ทธ๋ฌ๋ค ๋ณด๋ AutoEncoder์ ๋น์ทํ ๋ชจ๋ธ ๊ตฌ์กฐ๊ฐ ๋์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ฐ์ดํฐ ์์ถ
์ด ๋ชฉํ์ธ AutoEncoder์ ๋ฐ์ดํฐ ์์ฑ
์ด ๋ชฉํ์ธ VAE๋ ๊ฐ์์ ๋ชฉํ์ ๋ง์ถฐ ํ์ํ ๋ฐฉ๋ฒ๋ก ์ ๋ํ๊ฒ ๋๋ฉด์ ๊ทธ ๋ชจ๋ธ ๊ตฌ์กฐ๊ฐ ๋น์ทํด๋ณด์ด๊ฒ ๋ ๊ฒ์ด์ง ๊ฐ์ง์์ต๋๋ค.
VAE์ ์ ์ฒด ๊ตฌ์กฐ๋ [1] Decoder, Generator, Generation Network ๋ผ๊ณ ๋ถ๋ฅด๋ ๋ถ๋ถ๊ณผ [2] Encoder, Posterior, Inference Network๋ผ๊ณ ๋ถ๋ฅด๋ ๋ถ๋ถ, ํฌ๊ฒ 2๊ฐ์ง ํํธ๋ก ์ด๋ฃจ์ด์ ธ ์์ต๋๋ค.
2 Variational Bound
์์ ํ๋ฆ์ ์ด์ด๊ฐ๋ณด๋ฉด, ์ฒ์์ ์๊ณ ์ถ์๋ ๊ฒ์ (1) p(x)์์ผ๋ ์ฐ๋ฆฌ๊ฐ ์ํ๋ ๋ฐ์ดํฐ๋ค๋ก ์ํ๋ง(์ปจํธ๋กค)ํ๊ธฐ ์ํด (2) p(z\|x) (true posterior)๊ฐ ํ์ํด์ก๊ณ , true posterior๋ฅผ ์ ์ ์์ผ๋ ์ด๋ฅผ ์ถ์ (Variational Inference)ํ๊ธฐ ์ํด์ (3) qฯ(z\|x)๊ฐ ํ์ํ์ต๋๋ค. ๋ฐ๋ผ์ ์ฐ๋ฆฌ๋ ์ด 3๊ฐ์ ๋ถํฌ๋ค์ ๊ด๊ณ๋ฅผ ์ข ๋ ์ดํด๋ณด๊ณ ์ด๋ป๊ฒ ์์ฑ๋ชจ๋ธ์ ํ์ตํด๋๊ฐ ๊ฒ์ธ์ง ๊ณ ๋ฏผํด๋ด์ผ ํฉ๋๋ค.
์ฒ์์ ๋ชฉํ์๋ p(x) ์ log๋ฅผ ์์์ ์๋์ ๊ฐ์ ์ ๋ณํ์ ์งํํ๋ฉด 2๊ฐ์ term์ผ๋ก ๋๋ ์ง๋๋ค. ์ฒซ๋ฒ์งธ term์ ์ด๋ฒ ์ฅ์ ์ฃผ์ธ๊ณต์ธ Evidence LowerBOund๋ผ๋ ELBO์ด๊ณ ๋๋ฒ์งธ term์ Variational Inference์์ ๋ดค์๋ true posterior์ approximator ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ฅผ ๋ํ๋ด๋ KL ๊ฐ์ ๋๋ค. ์ฌ๊ธฐ์ log(p(x)) ๊ฐ ์ผ์ ํ ๋ KL ๊ฐ์ ์ค์ด๋ ๊ฒ์ด ๋ชฉํ(=true posterior๋ฅผ ์ approximationํ๋ ๊ฒ)์ด๊ณ KL์ ํญ์ ์์์ด๊ธฐ ๋๋ฌธ์, ์ญ์ผ๋ก ์๊ฐํด๋ณด๋ฉด ์ฒซ๋ฒ์งธ term์ด์๋ ELBO ๊ฐ์ ์ต๋ํํ๋ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ ์ ์์ต๋๋ค. ์ด๋ฅผ ๊ฐ๋จํ ๊ทธ๋ํ๋ก ๋ํ๋ด๋ณด๋ฉด ์ค๋ฅธ์ชฝ ๊ทธ๋ฆผ์์ ๋ณผ ์ ์๋ฏ์ด ์ฐ๋ฆฌ๊ฐ ์ํ๋ ๊ฒ์ ELBO๊ฐ์ด ์ปค์ง ์ ์๋ ฯ๋ฅผ ์ฐพ์๊ฐ๋ ๊ณผ์ ์ด๋ผ๊ณ ๋ณผ ์ ์์ต๋๋ค.
๋ฐ๋ผ์ ELBO๊ฐ์ด ์ปค์ง ์ ์๋ ฯ๋ฅผ ์ฐพ์๊ฐ๋ ์ต์ ํ๋ฅผ ์์์ ๋ณํํ์ฌ ๋ ๋ค์ 2๊ฐ์ term ์ฆ, (1) Reconstructino Error
์ (2) Regularization
์ผ๋ก ๋ณผ ์ ์์ต๋๋ค.
๋จผ์ Reconstructino Error์ ํ์ต ๋ฐ์ดํฐ์ธ x๊ฐ ๋ค์ด๊ฐ์ ๋ x๊ฐ ๋์ค๋๋ก ํ๋ ๋ณต์(Reconstruction)
์ ํ๋๋ก ํ๋ term์ด๋ฉฐ, Regularization์ prior distribution์ธ q์ ํํ๋ฅผ ์ ํํด์ฃผ๋ ์ญํ ์ ํ๋ term์
๋๋ค.
2.1 Regularization term
ELBO term์ ๋๋์์ ๋ ๋์๋ ์ฒซ๋ฒ์งธ Regularization term์ ๋ํด ๋ณด๊ฒ ์ต๋๋ค. True posterior๋ฅผ ์ถ์ ํ๊ธฐ ์ํ q_\phi(\mathrm{z} \mid \mathrm{x})์ KL ๊ฐ์ ๊ณ์ฐํ๊ธฐ ์ฝ๋๋ก ํ๊ธฐ ์ํด Multivariate gaussian distribution์ผ๋ก ์ค๊ณํฉ๋๋ค. ๋ํ ์์ ์ด์ผ๊ธฐํ๋ ๊ฒ ์ฒ๋ผ controller ๋ถ๋ถ์ธ p(z)๋ ๋ค๋ฃจ๊ธฐ ์ฌ์ด ๋ถํฌ์ด์ด์ผ ํ๊ธฐ ๋๋ฌธ์ ์ ๊ท๋ถํฌ๋ก ๋ง๋ค์ด ์ค๋๋ค. ๊ทธ๋ฌ๋ฉด ๋
ผ๋ฌธ์ Appendix F.1
์์ ๋ณผ ์ ์๋ฏ์ด ๊ฐ์ฐ์์ ๋ถํฌ๋ค ์ฌ์ด์ KL ๊ฐ์ mean๊ณผ std๋ฅผ ์ฌ์ฉํ์ฌ ๋ค์๊ณผ ๊ฐ์ด ์ฝ๊ฒ ๊ณ์ฐ๋ ์ ์์ต๋๋ค.
2.2 Reconstruction error term
ELBO์ ๋๋ฒ์งธ term์ธ Reconstruction error์ ๋ํด ์ดํด๋ณด๊ฒ ์ต๋๋ค. Reconstruction error์ expectation ํํ์ integral๋ก ํํํ๋ฉด ๋ค์๊ณผ ๊ฐ๊ณ ์ด๋ ๋ชฌํ ์นด๋ฅผ๋ก ์ํ๋ง์ ํตํด L๊ฐ์ z_{i,โl}๋ฅผ ๊ฐ์ง๊ณ ํ๊ท ์ ๋ด์ ๊ตฌํ ์ ์์ต๋๋ค. ์ฌ๊ธฐ์์ index i๋ ๋ฐ์ดํฐ x์ ๋๋ฒ๋ง์ด๊ณ index l์ generator์ distribution์์ ์ํ๋งํ๋ ํ์์ ๋ํ ๋๋ฒ๋ง์ ๋๋ค. VAE๋ ํ์ ๋ ๋ชฌํ ์นด๋ฅผ๋ก ์ํ๋ง์ ํตํด ํจ๊ณผ์ ์ผ๋ก optimization์ ์ํํฉ๋๋ค.
2.2.1 Reparametrization Trick
์์์ Reconstruction error๋ฅผ ๊ตฌํ๊ธฐ ์ํด ์ํ๋งํ๋ ๊ณผ์ ์์ backpropation์ ํ๊ธฐ ์ํด Reparametrization trick์ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค. ๋จ์ํ ์ ๊ท๋ถํฌ์์ ์ํ๋ง ํ๋ฉด random node์ธ z์ ๋ํด์ gradient๋ฅผ ๊ณ์ฐํ ์ ์๊ธฐ ๋๋ฌธ์ random์ฑ์ ์ ๊ท๋ถํฌ์์ ์ํ๋ง ๋๋ ฯต์ผ๋ก ๋ง๋ค์ด์ฃผ๊ณ ์ด๋ฅผ reparametrization์ ํด์ฃผ์ด์ deterministic node๊ฐ ๋ z๋ฅผ backpropagation ํ ์ ์๊ฒ ๋ฉ๋๋ค.
# sampling by re-parameterization technique
z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32)
z๋ฅผ ์ํ๋งํ๋ generator์ distribution์ Bernoulli๋ก ๋์์ธํ ๊ฒฝ์ฐ NLL์ด Cross Entropy๊ฐ ๋๋ฉฐ Gaussian ๋ถํฌ๋ก ๋์์ธํ ๊ฒฝ์ฐ MSE๊ฐ ๋์ด์ ๋ณดํต ๊ณ์ฐํ๊ธฐ ์ฉ์ดํ 2๊ฐ์ ๋ถํฌ ์ค ํ๋๋ฅผ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค. ๋ชจ๋ธ ๋์์ธ์ ์กฐ๊ฑด์ ๋ฐ์ดํฐ์ ๋ถํฌ์ ๋ฐ๋ผ ๊ฒฐ์ ๋๋๋ฐ ๋ฐ์ดํฐ์ ๋ถํฌ๊ฐ continuous ํ๋ค๋ฉด Gaussian ๋ถํฌ์ ๊ฐ๊น๊ธฐ ๋๋ฌธ์ Gaussian์ผ๋ก ๋์์ธํ๊ณ , ๋ฐ์ดํฐ์ ๋ถํฌ๊ฐ discrete ํ๋ค๋ฉด Bernoulli๋ถํฌ์ ๊ฐ๊น๊ธฐ ๋๋ฌธ์ Bernoulli๋ก ๋์์ธํฉ๋๋ค.
3 VAE Structure
์ง๊ธ๊น์ง ์ดํด๋ณธ VAE ๊ตฌ์กฐ๋ Encoder์ Decoder๋ฅผ ๊ฐ๊ฐ ์ด๋ค ๋ถํฌ๋ก ๋์์ธํด์ฃผ๋ ๋์ ๋ฐ๋ผ Reconstruction error์ Regularization์ ๊ณ์ฐํ๋ ์๋ง ์กฐ๊ธ์ฉ ๋ฌ๋ผ์ง๊ฒ ๋ฉ๋๋ค. Encoder ๋ถ๋ถ์ Reconstruction error์ ๊ณ์ฐ์ ์ฉ์ด์ฑ ๋๋ฌธ์ ๋ชจ๋ ์ ํ์์ ๊ฐ์ฐ์์ ๋ถํฌ๋ฅผ ์ฌ์ฉํ๊ฒ ๋๊ณ Decoder ๋ถ๋ถ๋ง ๋ณํํ์ฌ ์๋์ ์ฌ๋ฌ ์ ํ๋ค์ด ๋ํ๋๊ฒ ๋ฉ๋๋ค. ์ฐ์ ๋ชจ๋ VAE์์ ๊ณตํต์ ์ผ๋ก ์ฌ์ฉํ๊ณ ์๋ Encoder๋ฅผ ์ฝ๋๋ก ๊ตฌํํ๋ฉด ์๋์ ๊ฐ์ต๋๋ค.
# Gateway
def autoencoder(x_hat, x, dim_img, dim_z, n_hidden, keep_prob):
# encoding
mu, sigma = gaussian_MLP_encoder(x_hat, n_hidden, dim_z, keep_prob)
# sampling by re-parameterization technique
z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32)
# decoding
y = bernoulli_MLP_decoder(z, n_hidden, dim_img, keep_prob)
y = tf.clip_by_value(y, 1e-8, 1 - 1e-8)
# loss
marginal_likelihood = tf.reduce_sum(x * tf.log(y) + (1 - x) * tf.log(1 - y), 1)
KL_divergence = 0.5 * tf.reduce_sum(tf.square(mu) + tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1, 1)
marginal_likelihood = tf.reduce_mean(marginal_likelihood)
KL_divergence = tf.reduce_mean(KL_divergence)
ELBO = marginal_likelihood - KL_divergence
loss = -ELBO
return y, z, loss, -marginal_likelihood, KL_divergence
def decoder(z, dim_img, n_hidden):
y = bernoulli_MLP_decoder(z, n_hidden, dim_img, 1.0, reuse=True)
return y
(1) Encoder: Gaussian / Decoder: Bernoulli
์ ๋ชจ๋ธ์ ์ฝ๋๋ก ๊ตฌํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
# Bernoulli MLP as decoder
def bernoulli_MLP_decoder(z, n_hidden, n_output, keep_prob, reuse=False):
with tf.variable_scope("bernoulli_MLP_decoder", reuse=reuse):
# initializers
w_init = tf.contrib.layers.variance_scaling_initializer()
b_init = tf.constant_initializer(0.)
# 1st hidden layer
w0 = tf.get_variable('w0', [z.get_shape()[1], n_hidden], initializer=w_init)
b0 = tf.get_variable('b0', [n_hidden], initializer=b_init)
h0 = tf.matmul(z, w0) + b0
h0 = tf.nn.tanh(h0)
h0 = tf.nn.dropout(h0, keep_prob)
# 2nd hidden layer
w1 = tf.get_variable('w1', [h0.get_shape()[1], n_hidden], initializer=w_init)
b1 = tf.get_variable('b1', [n_hidden], initializer=b_init)
h1 = tf.matmul(h0, w1) + b1
h1 = tf.nn.elu(h1)
h1 = tf.nn.dropout(h1, keep_prob)
# output layer-mean
wo = tf.get_variable('wo', [h1.get_shape()[1], n_output], initializer=w_init)
bo = tf.get_variable('bo', [n_output], initializer=b_init)
y = tf.sigmoid(tf.matmul(h1, wo) + bo)
return
(2) Encoder: Gaussian / Decoder: Gaussian
์ ๋ชจ๋ธ์ ์ฝ๋๋ก ๊ตฌํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
# Gaussian MLP as encoder
def gaussian_MLP_encoder(x, n_hidden, n_output, keep_prob):
with tf.variable_scope("gaussian_MLP_encoder"):
# initializers
w_init = tf.contrib.layers.variance_scaling_initializer()
b_init = tf.constant_initializer(0.)
# 1st hidden layer
w0 = tf.get_variable('w0', [x.get_shape()[1], n_hidden], initializer=w_init)
b0 = tf.get_variable('b0', [n_hidden], initializer=b_init)
h0 = tf.matmul(x, w0) + b0
h0 = tf.nn.elu(h0)
h0 = tf.nn.dropout(h0, keep_prob)
# 2nd hidden layer
w1 = tf.get_variable('w1', [h0.get_shape()[1], n_hidden], initializer=w_init)
b1 = tf.get_variable('b1', [n_hidden], initializer=b_init)
h1 = tf.matmul(h0, w1) + b1
h1 = tf.nn.tanh(h1)
h1 = tf.nn.dropout(h1, keep_prob)
# output layer
# borrowed from https: // github.com / altosaar / vae / blob / master / vae.py
wo = tf.get_variable('wo', [h1.get_shape()[1], n_output * 2], initializer=w_init)
bo = tf.get_variable('bo', [n_output * 2], initializer=b_init)
gaussian_params = tf.matmul(h1, wo) + bo
# The mean parameter is unconstrained
mean = gaussian_params[:, :n_output]
# The standard deviation must be positive. Parametrize with a softplus and
# add a small epsilon for numerical stability
stddev = 1e-6 + tf.nn.softplus(gaussian_params[:, n_output:])
return mean, stddev
(3) Encoder: Gaussian / Decoder: Gaussian w/ Identity Covariance
MNIST data์ ์์๋ก ๋ค์ด์ VAE ๊ตฌ์กฐ๋ฅผ ๋ํ๋ด๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
4 Experiment
ํด๋น ๋
ผ๋ฌธ์์ ์คํ์ ์ด 2๊ฐ์ง๋ฅผ ์งํํ๋๋ฐ ์์์๋ ๊ณ์ VAE๋ก ๋ํ๋์ง๋ง ๋
ผ๋ฌธ์์๋ ํด๋น ์๊ณ ๋ฆฌ์ฆ์ AEVB
๋ก ์ง์นญํ๊ธฐ ๋๋ฌธ์ ์ด๋ฅผ VAE ์๊ณ ๋ฆฌ์ฆ์ผ๋ก ์๊ฐํ๊ณ ์คํ ๊ฒฐ๊ณผ๋ค์ ๋ณด๋ฉด ๋ฉ๋๋ค. ์ฐ์ ์ฒซ๋ฒ์งธ๋ก MNIST ๋ฐ์ดํฐ์
๊ณผ Frey Face ๋ฐ์ดํฐ์
์ ์ฌ์ฉํ์ฌ ๋ฒ ์ด์ค๋ผ์ธ์ผ๋ก wake-sleep ์๊ณ ๋ฆฌ์ฆ๊ณผ ์ฑ๋ฅ์ ๋น๊ตํ์ต๋๋ค.
ELBO๊ฐ์ ์ต๋ํํ๋ ๊ฒ์ด ๋ชฉํ์ด๋ฏ๋ก y์ถ์ ๊ฐ์ด ํด์๋ก ์ข์ ๊ฒ์ผ๋ก ํด์ํ ์ ์์ต๋๋ค. ์๋์ ๊ทธ๋ํ๋ค์์ ์ค์ ๊ณผ ์ ์ ์ ๊ฐ๊ฐ train๊ณผ test ๋ฐ์ดํฐ์ ์ ๋ํด ELBO ๊ฐ์ plottingํ ๊ฒ์ผ๋ก latent variable์ธ z์ ์ฐจ์์ ํฌ๊ธฐ์ ๋ฐ๋ผ ELBO ๊ฐ์ด ์ด๋ค ์์์ ๋ํ๋ด๋์ง ๋ณด์ฌ์ค๋๋ค. Experiment I์ ๋ณด๋ฉด ํธ๋ ์ด๋ ํฌ์ธํธ๊ฐ ๋ง์ ๋ ์ฆ x์ถ ๊ฐ์ด ํด๋ test์ training์ ELBO๊ฐ์ด ์ ์ ๋ฒ์ด์ง๋๊ฒ์ด ๊ด์ฐฐ์ด ๋๋๋ฐ ์ด๋ ์ค๋ฒํผํ ์ผ๋ก ๋ณด์ ๋๋ค. ์ ์๋ ์ค๋ฒํผํ ์ ๋ฐฉ์งํ๊ธฐ์ํด ๋ฐ์ดํฐ์ ์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์์ ํ๋ ์์ ์ ํ๋ค๊ณ ํฉ๋๋ค.
๋๋ฒ์งธ๋ก๋ MNIST ๋ฐ์ดํฐ์ ์ ๋ํด์ z์ ์ฐจ์์ด 1000, 50000์ผ๋์ ๊ฐ ์๊ณ ๋ฆฌ์ฆ๋ค์ ์ฑ๋ฅ์ ํ์ต ์ํ์์ ๋ฐ๋ผ Marginal log-likelihood๋ฅผ plottingํ์ฌ ๋ํ๋์ต๋๋ค. ์ด ์คํ์์๋ ๋ฒ ์ด์ค๋ผ์ธ์ผ๋ก Wake-Sleep๊ณผ MCEM์ ์ฌ์ฉํ์ผ๋ฉฐ ์ฌ๊ธฐ์๋ AEVB(=VAE)๊ฐ convergence speed ์ธก๋ฉด์์ ๋ฒ ์ด์ค๋ผ์ธ๋ณด๋ค ์ฐ์ํ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋๋ค.
5 Conclusion
๋ณธ ๋ ผ๋ฌธ์์๋ ์ฐ์์ ์ธ latent variable์ ํจ์จ์ ์ผ๋ก inference ํ๊ธฐ์ํด Stochastic Gradient VB๋ก variational lower bound์ estimation ํ๋ ๋ฐฉ๋ฒ๋ก ์ ์ ์ํ์์ต๋๋ค. ๊ฐ์ฐ์์ ๋ถํฌ์์ ๋๋คํ๊ฒ ์ํ๋งํ๋ ๋ฐฉ๋ฒ์ back propgataion์ด ๋ถ๊ฐ๋ฅํ๊ธฐ ๋๋ฌธ์ VAE๋ Reparametrization trick์ ์ด์ฉํ์ฌ ๊ฐ์ฐ์์ ๋ถํฌ๋ก๋ถํฐ ์ํ๋งํ estimator๋ ๋ฏธ๋ถ ๊ฐ๋ฅํ๊ณ SGD๋ก ์ต์ ํ ๋ฉ๋๋ค. ๋ํ VAE๋ i.i.d์ ๋ฐ์ดํฐ์ ๊ณผ ๊ฐ์ด ๊ฐ datapoint๊ฐ ์ฐ์์ ์ธ latent variable๋ฅผ ๊ฐ์ง๋ high dimensional ๋ฐ์ดํฐ์ ๋๋นํด Auto-Encoding VB ์๊ณ ๋ฆฌ์ฆ์ผ๋ก SGVB task๋ฅผ ํด๊ฒฐํ์์ต๋๋ค. VAE๋ ์ด๋ฏธ์ง๋ก๋ถํฐ ์ ์ฐจ์์ ๊ฐ์ฐ์์ ๋ถํฌ๋ฅผ ํ์ตํ ๋ค์ ์๋ณธ ์ด๋ฏธ์ง๋ก ๋ณต์ํ๋ ์์ฑ๋ชจ๋ธ์ ๋๋ค. ํ์ง๋ง ๋ณต์กํ ๋ถํฌ๋ก ๊ตฌ์ฑ๋ ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ฐ์์ ๋ถํฌ๋ก ์ฐจ์์ถ์ํ ํ์ต๋ฐฉ๋ฒ์ ๋ฌธ์ ์ ์ด ์์ต๋๋ค. ํต์ฌ์ ์ธ ์ ๋ณด๋ฅผ ๊ฐ์ฐ์์ ๋ถํฌ๋ชจ์์ผ๋ก ์์ถํ๋ ๋ฐฉ๋ฒ์ ๋งค์ฐ ํ๋ค๋ฉฐ function loss๋ ์กด์ฌํ ์ ๋ฐ์ ์์ต๋๋ค. ์ด๋ฅผ posterior collapse๋ผ๊ณ ํฉ๋๋ค. VAE ์ดํ Posterior collapse๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํ ๋ค์ํ ์ฐ๊ตฌ๋ค์ด ์งํ๋์๋๋ฐ์, VQ-VAE์ ๊ฐ์ด VAE๋ฅผ ์ ๊ทธ๋ ์ด๋ํ ๋ค์ํ latent variable ๋ชจ๋ธ๋ค์ ๋ง๋๋ณด์ธ์!
6 Improved Work
VAE์ ์์ฑ์ ์ ๋ ๋คํธ์ํฌ(GAN)๋ฅผ ์ฌ์ฉํ์ฌ Auto-Encoder์ posterior๋ฅผ ์์์ prior ๋ถํฌ์ matchingํจ์ผ๋ก์จ variational inference๋ฅผ ํ๋ Advarsarial AutoEncoder(AAE)๋ผ๋ ์ฐ๊ตฌ๊ฐ ์์์ต๋๋ค. GAN์ Discriminator๊ฐ ์์ฑ๋ ์ด๋ฏธ์ง์ ๋ถํฌ์ ์ง์ง ๋ฐ์ดํฐ์ ์์ ์จ ์ด๋ฏธ์ง์ ๋ถํฌ๋ฅผ ํ๋ณํ๋ฉด์ prior distribution์ ๋งค์นญํ๊ฒ ๋ฉ๋๋ค. Posterior distribution์ prior distribution๊ณผ matching ํจ์ผ๋ก์จ latent space(prior distribution)์์ ์ํ๋ค์ด ์ด๋ค ์๋ฏธ๋ฅผ ๊ฐ์ง๊ณ ๋ถํฌ๋์ด ์๋์ง ์ ์ ์์ต๋๋ค. AAE์ encoder๋ ์ํ๋ prior distribution์ data ๋ถํฌ๋ฅผ ๋ง๋ค๊ฒ ๋๊ณ decoder๋ ํด๋น prior์์ ์๋ฏธ๋ฅผ ์ฐพ์ ์ ์๋ ์ํ๋ค์ ์์ฑํ ์ ์๊ฒ ๋ฉ๋๋ค.
Reference
[1] original paper: https://arxiv.org/abs/1312.6114
[2] https://di-bigdata-study.tistory.com/5
[3] https://di-bigdata-study.tistory.com/4?category=848869
[4] https://ratsgo.github.io/generative%20model/2017/12/19/vi/
[5] https://taeu.github.io/paper/deeplearning-paper-vae/
[6] https://medium.com/humanscape-tech/paper-review-vae-ac918509a9ba
[7] https://www.youtube.com/watch?v=o_peo6U7IRM
[8] https://youtu.be/SAfJz_uzaa8
[9] https://youtu.be/GbCAwVVKaHY
[10] https://youtu.be/7t_3dNs4QK4
[11] https://arxiv.org/abs/1606.05908
[12] https://cs.stanford.edu/~sunfanyun/talks/vi_discrete.pdf