Unable to call a trained model with an arbitrary number of samples since there is a fixed batch-size tf.Variable input to the recurrent decoder of a VRAE
I am converting a Variational Recurrent Autoencoder (VRAE) from this PyTorch implementation to Keras. For the VRAE, the inputs to the recurrent decoder must be of shape (batch_size, timesteps, features). In PyTorch this is achieved via
self.decoder_inputs = torch.zeros(self.sequence_length, self.batch_size, 1, requires_grad=True).type(self.dtype)
Please note that the PyTorch GRU by default expects the batch size to be the second dimension of its input; this is not the case for the TensorFlow GRU, which expects the batch size to be the first dimension.
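(For concreteness, the two default shape conventions side by side; the 300 hidden units, 20 timesteps, and batch of 128 below are just placeholder numbers.)
import torch
import tensorflow as tf

# PyTorch GRU defaults to batch_first=False: inputs are (seq_len, batch, features)
pt_gru = torch.nn.GRU(input_size=1, hidden_size=300)
pt_out, pt_h_n = pt_gru(torch.zeros(20, 128, 1))

# TensorFlow GRU expects (batch, timesteps, features)
tf_gru = tf.keras.layers.GRU(units=300)
tf_out = tf_gru(tf.zeros((128, 20, 1)))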
When I implement this model in TensorFlow and attempt to pass a single sample of shape (1, timesteps, features) to it, I get this error:
InvalidArgumentError: Invalid input_h shape: [1,1,300] [1,128,300] [Op:CudnnRNN]
However, if I pass a sample whose first dimension matches the batch size that the model was originally trained on, i.e. a sample of shape (batch_size=128, timesteps, features), I get no error. When I trained the model, I also used the drop_remainder=True kwarg of tf.data.Dataset.batch(batch_size, drop_remainder=True) to avoid this same problem during model.fit.
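(Roughly, the input pipeline looked like this; x_train and its shape (1000, 20, 5) are placeholders, not my real data.)
import numpy as np
import tensorflow as tf

# Placeholder training data of shape (num_samples, timesteps, features)
x_train = np.random.randn(1000, 20, 5).astype("float32")
# drop_remainder=True guarantees every batch has exactly 128 samples
dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(128, drop_remainder=True)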
Is there a way to pass an arbitrary number of samples to such a network that was trained with a fixed-size tf.Variable? It seems really inefficient to always have to place my inputs into the first elements of a zeros tensor such as this
dummy_zeros = np.zeros((batch_size, timesteps, features))
a_single_sample = np.random.randn(timesteps, features)
dummy_zeros[0, :, :] = a_single_sample
and then pass that zeros tensor to the model call like so
# I index the first element of the results of the forward pass through the model because
# the other elements are not samples.
a_single_sample_reconstruction = model(dummy_zeros)[0]
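(For reference, the same padding trick written out as a small helper; predict_with_padding is just a hypothetical name and 128 is the fixed batch size from training.)
import numpy as np

def predict_with_padding(model, samples, batch_size=128):
    """Hypothetical helper: zero-pad (n, timesteps, features) samples up to the fixed
    batch size, call the model once, and keep only the first n reconstructions.
    Assumes n <= batch_size."""
    n = samples.shape[0]
    padded = np.zeros((batch_size,) + samples.shape[1:], dtype=np.float32)
    padded[:n] = samples
    return model(padded)[:n]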
Below is the model that I have made.
import tensorflow as tf


class VRAE(tf.keras.Model):
    """Variational Recurrent Autoencoder"""

    def __init__(self, batch_size, timesteps, rnn_units, latent_size, features, **kwargs):
        """Defines architecture layers.
        :param batch_size: <class 'int'> Self-explanatory.
        :param timesteps: <class 'int'> The number of timesteps (aka the second dimension) of the input.
        :param rnn_units: <class 'int'> The number of hidden units in the RNN layers.
        :param latent_size: <class 'int'> Dimensions of latent space.
        :param features: <class 'int'> The number of features (aka the last dimension) of the input.
        """
        super().__init__(**kwargs)
        # The first step of the encoding process
        self.recurrent_encoder_part = tf.keras.layers.GRU(units=rnn_units, return_sequences=False, return_state=True)
        # Variational part of the encoding process
        self.dense_mus = tf.keras.layers.Dense(units=latent_size)
        self.dense_log_vars = tf.keras.layers.Dense(units=latent_size)
        self.sampling = tf.keras.layers.Lambda(function=self.reparametrize)
        # Converts outputs from the sampling layer to initial hidden state for decoder rnn
        self.z_to_init_hidden_state = tf.keras.layers.Dense(units=rnn_units)
        ########################################################################################################
        # BEGIN: THE PROBLEM IS THAT THIS VARIABLE HAS A FIXED NUMBER OF PARAMETERS DUE TO BATCH SIZE CONSTRAINT
        ########################################################################################################
        # Define the trainable zeros tensor that is input (not hidden state) of recurrent decoder
        self.decoder_input = tf.Variable(tf.zeros(shape=(batch_size, timesteps, 1)), trainable=True)
        ########################################################################################################
        # END: THE PROBLEM IS THAT THIS VARIABLE HAS A FIXED NUMBER OF PARAMETERS DUE TO BATCH SIZE CONSTRAINT
        ########################################################################################################
        # Decoding layer
        self.recurrent_decoder = tf.keras.layers.GRU(units=rnn_units, return_sequences=True, return_state=False)
        # Reconstruct original inputs (softmax used because inputs belong to multiple classes)
        self.reconstruction = tf.keras.layers.Dense(units=features, activation='softmax')

    def call(self, inputs):
        """Forward computation that reconstructs original inputs.
        :param inputs: Rank-3 tensor of inputs (batch_size, timesteps, features).
        :return: Rank-3 tensor of probabilities (batch_size, timesteps, features).
        """
        # Encode the input; with return_state=True the GRU returns (last_output, final_state),
        # and only the final hidden state is needed here
        _, h_t = self.recurrent_encoder_part(inputs)
        # Get latent variables
        mus = self.dense_mus(h_t)
        log_vars = self.dense_log_vars(h_t)
        z = self.sampling([mus, log_vars])
        # Project z to correct shape for initial hidden state of decoder
        decoder_h_0 = self.z_to_init_hidden_state(z)
        # Decode -- note how the 'decoder' input is always (batch_size, timesteps, 1) shape
        outputs = self.recurrent_decoder(self.decoder_input, initial_state=decoder_h_0)
        # Unnormalized log probabilities -> normalized probabilities
        reconstructions = self.reconstruction(outputs)
        # Add KL loss term
        self.add_loss(self.compute_kl_loss(mus, log_vars))
        # Return the reconstructions of the original input for the main loss function (CategoricalCrossentropy) to handle
        return reconstructions

    def reparametrize(self, inputs):
        """Reparametrization trick that returns latent space z given means and log variances."""
        mus, log_vars = inputs
        # Extract the shapes from one of the input tensors.
        # Note, dimensions(mus) = dimensions(log_vars)
        batch_size = tf.shape(mus)[0]
        dims = tf.shape(mus)[1]
        # Compute epsilon
        epsilon = tf.random.normal(shape=(batch_size, dims))
        # Sample the latent space
        return mus + tf.exp(log_vars / 2) * epsilon

    def compute_kl_loss(self, mus, log_vars):
        """Computes the Kullback-Leibler loss using the means and log variances of the approximate posterior."""
        return -0.5 * tf.reduce_mean(1. + log_vars - tf.exp(log_vars) - tf.pow(mus, 2))
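A minimal sketch of one possible direction, assuming TF 2.x (the TiledDecoderInput name, timesteps=20, features=5, and 300 units are all made up for illustration): keep a single trainable (1, timesteps, 1) tensor and tile it to the runtime batch size inside the forward pass, so the decoder input no longer hard-codes batch_size.
import tensorflow as tf

class TiledDecoderInput(tf.keras.layers.Layer):
    """Hypothetical layer: one trainable (1, timesteps, 1) tensor, tiled to the runtime batch size."""
    def __init__(self, timesteps, **kwargs):
        super().__init__(**kwargs)
        self.decoder_input = tf.Variable(tf.zeros(shape=(1, timesteps, 1)), trainable=True)

    def call(self, inputs):
        # inputs is only used to read the dynamic batch size
        batch_size = tf.shape(inputs)[0]
        return tf.tile(self.decoder_input, multiples=[batch_size, 1, 1])

# The same GRU now accepts any batch size, including a single sample
tiled = TiledDecoderInput(timesteps=20)
decoder = tf.keras.layers.GRU(units=300, return_sequences=True)
print(decoder(tiled(tf.zeros((128, 20, 5)))).shape)  # (128, 20, 300)
print(decoder(tiled(tf.zeros((1, 20, 5)))).shape)    # (1, 20, 300)
Note that with this sketch all samples in a batch share the same learned decoder-input sequence, whereas the (batch_size, timesteps, 1) tf.Variable above (like the PyTorch tensor) learns a separate sequence per batch slot.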