via Target Propagation instead of reparameterization trick [WIP]

Posted by one_meets_seven on Fri, 10 February 2017

I trained a VAE whose latent variables are discrete (categorically distributed) using Target Propagation.

Variational AutoEncoder for Discrete Random Variables

Prepare

dataset: MNIST (binarized)

In [1]:
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline
In [2]:
import gzip
# MNIST pickle with (train, valid, test) splits
f = gzip.open('mnist.pkl.gz', 'rb')
In [3]:
# np.load falls back to pickle here; on Python 3 use pickle.load(f, encoding='latin1')
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = np.load(f)
In [4]:
# binarize pixel intensities to {0, 1} with a 0.5 threshold
train_x[train_x > 0.5] = 1.0
train_x[train_x <= 0.5] = 0.0
valid_x[valid_x > 0.5] = 1.0
valid_x[valid_x <= 0.5] = 0.0
test_x[test_x > 0.5] = 1.0
test_x[test_x <= 0.5] = 0.0

Target Propagation

Trained via Difference Target Propagation (instead of backpropagation), using an approximate inverse (instead of the reparameterization trick) to carry targets through the non-differentiable categorical sampling step (https://arxiv.org/abs/1412.7525). A rough sketch of the idea follows.
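
A minimal NumPy sketch of the target step (illustrative only; the helper names below are assumptions and not the actual gumbel_dtp code): the reconstruction loss defines a target for the decoder output, an approximate inverse maps that target back to the latent layer, and the difference correction cancels the inverse's systematic error, so the encoder can be trained with a purely local loss even though sampling a categorical latent is not differentiable.

import numpy as np

def dtp_encoder_target(x, h, x_hat, approx_inverse, step=0.1):
    """Sketch of the Difference Target Propagation target for the encoder output.

    x              : input batch, shape (B, D)
    h              : encoder output (categorical probabilities), shape (B, K)
    x_hat          : decoder reconstruction of x computed from h, shape (B, D)
    approx_inverse : learned approximate inverse of the decoder, (B, D) -> (B, K)
    """
    # 1. Nudge the reconstruction toward the data: the target for the decoder output.
    x_target = x_hat - step * (x_hat - x)
    # 2. Difference correction: the inverse of the target, minus the inverse's error
    #    measured at x_hat, gives the target for the encoder output.
    h_target = h - approx_inverse(x_hat) + approx_inverse(x_target)
    return h_target

# The encoder is then trained on the local loss ||h - h_target||^2, and the inverse on a
# reconstruction loss, instead of backpropagating through the discrete sample.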

In [5]:
from gumbel_dtp import VAE_DTP
model = VAE_DTP()
model.fit(x_train=train_x, x_valid=valid_x, n_epochs=50)
q_y_x1/weights:0
q_y_x1/biases:0
q_y_x2/weights:0
q_y_x2/biases:0
q_y_x3/weights:0
q_y_x3/biases:0
p_x_y1/weights:0
p_x_y1/biases:0
p_x_y2/weights:0
p_x_y2/biases:0
p_x_y3/weights:0
p_x_y3/biases:0
Epoch 1, loss: -181.528, enc_loss: 1.633
val_loss: -176.631, time: 20.475
Epoch 2, loss: -175.079, enc_loss: 1.330
val_loss: -170.894, time: 20.484
Epoch 3, loss: -179.198, enc_loss: 10.287
val_loss: -172.048, time: 21.530
Epoch 4, loss: -155.014, enc_loss: 5.160
val_loss: -156.495, time: 21.882
Epoch 5, loss: -155.618, enc_loss: 3.194
val_loss: -152.740, time: 22.102
Epoch 6, loss: -153.468, enc_loss: 2.618
val_loss: -151.374, time: 22.731
Epoch 7, loss: -151.451, enc_loss: 1.718
val_loss: -150.879, time: 23.232
Epoch 8, loss: -158.401, enc_loss: 1.484
val_loss: -150.903, time: 22.670
Epoch 9, loss: -150.197, enc_loss: 3.062
val_loss: -150.773, time: 21.988
Epoch 10, loss: -151.462, enc_loss: 1.891
val_loss: -150.780, time: 22.088
Epoch 11, loss: -167.011, enc_loss: 2.591
val_loss: -150.096, time: 21.908
Epoch 12, loss: -152.649, enc_loss: 6.321
val_loss: -148.808, time: 21.946
Epoch 13, loss: -147.282, enc_loss: 5.881
val_loss: -146.008, time: 22.020
Epoch 14, loss: -153.295, enc_loss: 6.640
val_loss: -144.781, time: 22.920
Epoch 15, loss: -144.526, enc_loss: 9.148
val_loss: -144.214, time: 22.264
Epoch 16, loss: -148.974, enc_loss: 7.499
val_loss: -144.254, time: 22.335
Epoch 17, loss: -135.863, enc_loss: 5.920
val_loss: -143.738, time: 22.453
Epoch 18, loss: -141.195, enc_loss: 4.995
val_loss: -144.296, time: 22.267
Epoch 19, loss: -145.764, enc_loss: 7.427
val_loss: -143.850, time: 22.650
Epoch 20, loss: -156.815, enc_loss: 4.429
val_loss: -144.182, time: 22.276
Epoch 21, loss: -158.898, enc_loss: 5.451
val_loss: -144.194, time: 22.423
Epoch 22, loss: -147.974, enc_loss: 6.557
val_loss: -143.945, time: 22.221
Epoch 23, loss: -151.237, enc_loss: 4.112
val_loss: -144.710, time: 22.324
Epoch 24, loss: -143.208, enc_loss: 7.075
val_loss: -144.454, time: 22.564
Epoch 25, loss: -142.938, enc_loss: 5.326
val_loss: -144.323, time: 22.246
Epoch 26, loss: -148.902, enc_loss: 6.524
val_loss: -144.296, time: 22.292
Epoch 27, loss: -148.302, enc_loss: 4.372
val_loss: -143.879, time: 22.553
Epoch 28, loss: -150.249, enc_loss: 3.592
val_loss: -143.972, time: 22.856
Epoch 29, loss: -147.510, enc_loss: 6.287
val_loss: -144.053, time: 22.972
Epoch 30, loss: -143.519, enc_loss: 5.758
val_loss: -144.133, time: 22.825
Epoch 31, loss: -141.336, enc_loss: 5.589
val_loss: -143.727, time: 22.774
Epoch 32, loss: -148.377, enc_loss: 4.134
val_loss: -143.957, time: 22.924
Epoch 33, loss: -132.986, enc_loss: 6.796
val_loss: -143.880, time: 22.824
Epoch 34, loss: -143.904, enc_loss: 6.006
val_loss: -143.851, time: 22.914
Epoch 35, loss: -138.371, enc_loss: 3.220
val_loss: -143.742, time: 23.176
Epoch 36, loss: -141.178, enc_loss: 4.968
val_loss: -143.789, time: 22.725
Epoch 37, loss: -134.987, enc_loss: 6.365
val_loss: -143.776, time: 22.925
Epoch 38, loss: -145.562, enc_loss: 6.158
val_loss: -143.343, time: 22.823
Epoch 39, loss: -144.444, enc_loss: 6.671
val_loss: -143.688, time: 22.949
Epoch 40, loss: -137.429, enc_loss: 5.739
val_loss: -143.283, time: 23.107
Epoch 41, loss: -137.841, enc_loss: 4.334
val_loss: -143.159, time: 23.036
Epoch 42, loss: -150.200, enc_loss: 4.650
val_loss: -143.334, time: 22.880
Epoch 43, loss: -144.802, enc_loss: 9.214
val_loss: -143.240, time: 23.392
Epoch 44, loss: -138.497, enc_loss: 5.076
val_loss: -143.079, time: 23.363
Epoch 45, loss: -126.026, enc_loss: 4.992
val_loss: -143.476, time: 23.278
Epoch 46, loss: -144.195, enc_loss: 4.194
val_loss: -143.219, time: 23.425
Epoch 47, loss: -143.523, enc_loss: 4.875
val_loss: -143.492, time: 23.332
Epoch 48, loss: -146.495, enc_loss: 2.853
val_loss: -143.280, time: 23.529
Epoch 49, loss: -138.208, enc_loss: 6.174
val_loss: -143.406, time: 24.947
Epoch 50, loss: -130.100, enc_loss: 6.287
val_loss: -143.535, time: 25.024
In [6]:
targets = test_x[0:5]

# predicted reconstructions of the inputs (shown in the bottom row below)
latents = model.predict(x_data=targets)

fig = plt.figure(figsize=(8, 6))
for i, target in enumerate(targets):
    ax = fig.add_subplot(2, 5, i+1, xticks=[], yticks=[])
    ax.imshow(target.reshape(28, 28), 'gray')
for i, latent in enumerate(latents):
    ax = fig.add_subplot(2, 5, 6+i, xticks=[], yticks=[])
    ax.imshow(latent.reshape(28, 28), 'gray')
In [7]:
# visualize the encoder output for one test image as a heatmap
encode = model.encoder(x_data=test_x[0:1])
import seaborn as sns
sns.heatmap(encode)
Out[7]:
<matplotlib.axes._subplots.AxesSubplot at 0x12935fd50>

Back Propagation

Using the Gumbel-Softmax trick (https://arxiv.org/abs/1611.01144), also known as the Concrete distribution (https://arxiv.org/abs/1611.00712).
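
For reference, a minimal NumPy sketch of Gumbel-Softmax sampling (illustrative only, not the gumbel_dtp implementation): the categorical sample is replaced by adding Gumbel noise to the logits and taking a temperature-scaled softmax, which is differentiable and approaches a one-hot sample as the temperature goes to 0.

import numpy as np

def gumbel_softmax_sample(logits, temperature=1.0, rng=np.random):
    """Draw a differentiable, approximately one-hot sample from categorical logits (B, K)."""
    # Gumbel(0, 1) noise via inverse transform sampling
    u = rng.uniform(low=1e-20, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    # temperature-scaled softmax over the perturbed logits
    y = (logits + g) / temperature
    y = np.exp(y - y.max(axis=-1, keepdims=True))
    return y / y.sum(axis=-1, keepdims=True)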

