Variational Auto-Encoder

์•ž์˜ Method ๋ถ€๋ถ„์—์„œ ์ฆ๋ช…ํ•˜๊ณ  ์ •๋‹นํ™”ํ•œ AEVB(Auto-Encoding Variational Bayes)๋ฅผ ์‚ฌ์šฉํ•œ Variational Auto-Encoder์˜ ์˜ˆ์‹œ๋ฅผ ๋“ค์–ด๋ณธ๋‹ค. ์ด ์™ธ์—๋„ ๋‹ค์–‘ํ•œ ํ˜•ํƒœ์˜ AEVB๊ฐ€ ์žˆ์„ ์ˆ˜ ์žˆ๋‹ค.

๋ถ„ํฌ ์„ค์ •ํ•˜๊ธฐ

prior์— ๋Œ€ํ•ด ๋‹ค์Œ๊ณผ ๊ฐ™์ด isotropic multivariate gaussian์œผ๋กœ ๊ฐ€์ •ํ•˜์ž.

$$p_{\bm \theta}(\bold z)=\mathcal N(\bold z;\bold 0, \bold I)$$

Choosing such a simple, single distribution leaves the prior with no learnable parameters at all. Because the decoder is a sufficiently deep MLP, it can make up for this. The decoder distribution $p_{\bm \theta}(\bold x|\bold z)$ generates $\bold x$ from $\bold z$. Depending on the type of output, it is a multivariate Gaussian, a Bernoulli, or some other distribution.

Multivariate Gaussian์€ ์ •๊ทœ๋ถ„ํฌ์ด์ง€๋งŒ ํ‰๊ท ๊ณผ ๋ถ„์‚ฐ์„ ๋ฒกํ„ฐ๋กœ ๊ฐ€์ง€๋Š” ๋‹ค๋ณ€๋Ÿ‰ ์ •๊ทœ๋ถ„ํฌ์ด๋ฉฐ ์‹ค์ˆ˜๊ฐ’์„ ๊ฐ€์ง€๋Š” ์ถœ๋ ฅ ๋ถ„ํฌ๋ฅผ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋‹ค.

Bernoulli ๋ถ„ํฌ๋Š” n=1์ธ (๋‹ค๋ณ€๋Ÿ‰)์ดํ•ญ ๋ถ„ํฌ๋กœ, ์ด์ง„๊ฐ’์„ ๊ฐ€์ง€๋Š” ์ถœ๋ ฅ ๋ถ„ํฌ๋ฅผ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋‹ค.

Unlike the intractable true posterior $p_{\bm \theta}(\bold z|\bold x)$, the approximate posterior that serves as the encoder, $q_{\bm \phi}(\bold z|\bold x)$, can be chosen freely. We assume $q_{\bm \phi}(\bold z|\bold x)$ is Gaussian as well. Because it must be a distribution conditioned on $\bold x$, its mean and variance have to be outputs of the encoder given the input $\bold x$.

๊ณ„์‚ฐ์˜ ๊ฐ„์†Œํ™”๋ฅผ ์œ„ํ•ด True posterior์˜ ๊ฐ ๋ณ€์ˆ˜๊ฐ„์˜ ์ƒ๊ด€๊ด€๊ณ„๊ฐ€ ์—†์„ ๊ฒƒ์œผ๋กœ ๊ฐ€์ •ํ•˜๊ณ , ๊ณต๋ถ„์‚ฐ์ด ๋Œ€๊ฐํ–‰๋ ฌ์ธ ์ •๊ทœ๋ถ„ํฌ๋ฅผ ์ถœ๋ ฅํ•˜๋„๋ก ์„ค์ •ํ•œ๋‹ค. ์ฆ‰,

$$q_{\bm \phi}(\bold z|\bold x^{(i)})=\mathcal N(\bold z;\bm \mu(\bold x^{(i)}),\bm \sigma^2(\bold x^{(i)}))$$

๋‹ค๋ณ€๋Ÿ‰ ์ •๊ทœ๋ถ„ํฌ N(ฮผ,ฯƒ2)\mathcal N(\bm \mu,\bm \sigma^2)์˜ ๊ณต๋ถ„์‚ฐ์ด ๋Œ€๊ฐํ–‰๋ ฌ์ด๋ผ๋Š” ๊ฒƒ์€ ๊ฐ ๋ฒกํ„ฐ ์„ฑ๋ถ„ ์‚ฌ์ด์˜ ์ƒ๊ด€๊ด€๊ณ„๊ฐ€ ์—†๋‹ค๋Š” ๋œป์ด๋‹ค.

๋ถ„ํฌ์—์„œ z ์ƒ˜ํ”Œ๋งํ•˜๊ธฐ

์ƒ˜ํ”Œ๋ง์ด๋ผ๊ณ  ํ•˜๋ฉด ์–ด๋ ต์ง€๋งŒ, NN์—์„œ๋Š” ๊ฐ„๋‹จํžˆ ํ•ด๋‹น ํ™•๋ฅ ๋ถ„ํฌ๋ฅผ ํ•™์Šตํ•œ NN์˜ feed-forward์ด๋‹ค.

Sampling directly as $\bold z^{(i,l)} \sim q_{\bm \phi}(\bold z|\bold x^{(i)})$ is a random operation that gradients cannot pass through, so $\bm\phi$ cannot be optimized with stochastic gradients. We therefore use the reparameterization trick. First draw $\bm \epsilon^{(l)}\sim \mathcal N(\bold 0,\bold I)$, then compute $\bold z^{(i,l)} = g_{\bm \phi}(\bold x^{(i)},\bm \epsilon^{(l)})$ deterministically.

์ •๊ทœ๋ถ„ํฌ qฯ•(zโˆฃx(i))q_{\bm \phi}(\bold z|\bold x^{(i)})์—์„œ ์ƒ˜ํ”Œ๋งํ•˜๋ฏ€๋กœ, gฯ•(โ‹…)g_{\bm \phi}(\cdot)๋ฅผ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๊ณ„์‚ฐํ•œ๋‹ค. โŠ™\odot์€ element-wise product์ด๋‹ค( xi=aiโ‹…bix_{i}=a_i\cdot b_i ).

$$\bold z^{(i,l)}=g_{\bm \phi}(\bold x^{(i)},\bm \epsilon^{(l)})= \bm \mu(\bold x^{(i)})+\bm \sigma(\bold x^{(i)})\odot \bm \epsilon^{(l)}$$
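The reparameterized sampling step translates almost line for line into code. The pure-Python sketch below (function names are illustrative) draws $\bm\epsilon$ from $\mathcal N(\bold 0,\bold I)$ and forms $\bold z=\bm\mu+\bm\sigma\odot\bm\epsilon$. All the randomness sits in $\bm\epsilon$, so the mapping from $(\bm\mu,\bm\sigma)$ to $\bold z$ is deterministic, with $\partial z_j/\partial\mu_j=1$ and $\partial z_j/\partial\sigma_j=\epsilon_j$. In an autodiff framework these are the gradients that flow back to the encoder.

```python
import random

def reparameterize(mu, sigma, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, I), element-wise (the reparameterization trick)."""
    eps = [rng.gauss(0.0, 1.0) for _ in mu]             # noise drawn independently of phi
    z = [m + s * e for m, s, e in zip(mu, sigma, eps)]  # deterministic given (mu, sigma, eps)
    return z, eps

rng = random.Random(0)
samples = [reparameterize([1.0, -2.0], [0.5, 2.0], rng)[0] for _ in range(20000)]
mean_z = [sum(z[j] for z in samples) / len(samples) for j in range(2)]
print(mean_z)  # close to [1.0, -2.0]
```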

We have now sampled a single $\bold z$ in a way that supports backpropagation.

SGVB๋กœ ์ตœ์ ํ™”ํ•˜๊ธฐ

The lower bound we want to compute is the SGVB estimator $\tilde{\mathcal L}(\theta,\phi;\bold x^{(i)})$. Its general form is the first line of the final result below.

We have assumed that both $q_{\bm \phi}(\bold z|\bold x^{(i)})$ and $p_{\bm \theta}(\bold z)$ are Gaussian, so its KL divergence term can be simplified analytically. The derivation takes some algebra, so it is collected in the Appendix. For brevity, write $\bm \mu(\bold x^{(i)})$ and $\bm \sigma(\bold x^{(i)})$ as $\bm \mu^{(i)}$ and $\bm \sigma^{(i)}$.
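For reference, here is a compact version of the KL computation. Under the diagonal Gaussian $q_{\bm\phi}$, $\mathbb E_q[z_j]=\mu_j^{(i)}$ and $\mathbb E_q[z_j^2]=(\mu_j^{(i)})^2+(\sigma_j^{(i)})^2$. Using these moments, the two expectations that make up the KL divergence are:

$$\int q_{\bm \phi}(\bold z|\bold x^{(i)})\log p_{\bm \theta}(\bold z)\,d\bold z=-\frac J2\log(2\pi)-\frac12\sum^J_{j=1}\left((\mu_j^{(i)})^2+(\sigma_j^{(i)})^2\right)$$

$$\int q_{\bm \phi}(\bold z|\bold x^{(i)})\log q_{\bm \phi}(\bold z|\bold x^{(i)})\,d\bold z=-\frac J2\log(2\pi)-\frac12\sum^J_{j=1}\left(1+\log((\sigma_j^{(i)})^2)\right)$$

Subtracting the second from the first gives the closed form:

$$-D_{KL}(q_{\bm \phi}(\bold z|\bold x^{(i)})||p_{\bm \theta}(\bold z))=\frac 1 2\sum^J_{j=1}\left(1 +\log((\sigma_j^{(i)})^2) -(\mu_j^{(i)})^2-(\sigma_j^{(i)})^2 \right)$$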

๊ฒฐ๋ก ์ ์œผ๋กœ z\bold z์˜_ _์ฐจ์› JJ ์— ๋Œ€ํ•ด, lower bound๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ํ‘œํ˜„๋œ๋‹ค.

$$\tilde{ \mathcal L}(\theta,\phi;\bold x^{(i)})=-D_{KL}(q_{\bm \phi}(\bold z|\bold x^{(i)})||p_{\bm \theta}(\bold z))+\frac 1 L\sum^L_{l=1}\log p_{\bm \theta}(\bold x^{(i)}|\bold z^{(i,l)})$$

$$\tilde {\mathcal L}(\theta,\phi;\bold x^{(i)})=\frac 1 2\sum^J_{j=1}\left(1 +\log((\sigma_j^{(i)})^2) -(\mu_j^{(i)})^2-(\sigma_j^{(i)})^2 \right)+\frac 1 L\sum^L_{l=1}\log p_{\bm \theta}(\bold x^{(i)}|\bold z^{(i,l)})$$
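The closed form in the second line can be sanity-checked numerically. The sketch below compares it with a Monte Carlo estimate of $D_{KL}$, obtained by averaging $\log q-\log p$ over samples $\bold z\sim q$. It is plain Python with illustrative function names, and the particular $\bm\mu$, $\bm\sigma$ values are arbitrary.

```python
import math
import random

def kl_closed_form(mu, sigma):
    """D_KL(N(mu, diag(sigma^2)) || N(0, I)) via the closed form (sign flipped from the -KL term)."""
    return -0.5 * sum(1.0 + math.log(s ** 2) - m ** 2 - s ** 2 for m, s in zip(mu, sigma))

def kl_monte_carlo(mu, sigma, num_samples=50000, seed=0):
    """Estimate the same KL as the average of log q(z) - log p(z) over z ~ q."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        for m, s in zip(mu, sigma):
            z = m + s * rng.gauss(0.0, 1.0)
            log_q = -0.5 * math.log(2.0 * math.pi * s ** 2) - (z - m) ** 2 / (2.0 * s ** 2)
            log_p = -0.5 * math.log(2.0 * math.pi) - z ** 2 / 2.0
            total += log_q - log_p  # densities factorize, so per-dimension terms just add up
    return total / num_samples

mu, sigma = [0.5, -1.0], [0.8, 1.5]
print(kl_closed_form(mu, sigma), kl_monte_carlo(mu, sigma))  # the two values should roughly agree
```

The KL term vanishes exactly when $\bm\mu^{(i)}=\bold 0$ and $\bm\sigma^{(i)}=\bold 1$, that is, when the approximate posterior matches the prior.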

์ธ์ฝ”๋”์—์„œ ์ถœ๋ ฅ๋œ z\bold z์˜ ํ‰๊ท ๊ณผ ๋ถ„์‚ฐ ๋ฒกํ„ฐ, ์ถœ๋ ฅ๋œ z\bold z์— ๋Œ€ํ•ด ๋””์ฝ”๋”์—์„œ ์˜ฌ๋ฐ”๋ฅธ ๋ฐ์ดํ„ฐ x(i)x^{(i)}๊ฐ€ ๋‚˜ํƒ€๋‚  ๊ฐ€๋Šฅ๋„(cross entropy loss)๋ฅผ ์ด์šฉํ•˜๋ฉด estimator๋ฅผ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ๋‹ค. ์ด๋ฅผ ๋ฏธ๋ถ„/์—ญ์ „ํŒŒํ•ด ์ตœ์ ํ™”ํ•˜๋Š” ๊ฒƒ์ด Variational Auto-Encoder์ด๋‹ค.

ํŒŒ๋ผ๋ฏธํ„ฐ ํ•™์Šต ํ›„์—๋Š” z\bold z ๋ฒกํ„ฐ๋ฅผ ์ž…๋ ฅํ•ด ์ƒ˜ํ”Œ์„ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค. ์ด๋กœ์จ ๋””์ฝ”๋” pฮธ(xโˆฃz)p_{\bm \theta}(\bold x|\bold z)๋ฅผ ๊ตฌํ•˜๋Š” ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•œ๋‹ค. ๋‹ค์Œ ์ฑ•ํ„ฐ์˜ ์ฝ”๋“œ๋ฅผ ๋ณด๋ฉด ๋”์šฑ ์ง๊ด€์ ์œผ๋กœ VAE๋ฅผ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋‹ค.
