Unable to call trained model with an arbitrary number of samples since there is a fixed batch size tf.Variable input to recurrent decoder of VRAE

I am converting a variational recurrent autoencoder (VRAE) from this PyTorch implementation to Keras. For the VRAE, the inputs to the recurrent decoder must be of shape (batch_size, timesteps, features). In PyTorch this is achieved via

self.decoder_inputs = torch.zeros(self.sequence_length, self.batch_size, 1, requires_grad=True).type(self.dtype)

Note that the PyTorch GRU by default expects the batch size to be the second dimension of its input; this is not the case for the TensorFlow GRU, which expects the batch size to be the first dimension.
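
To make the layout difference concrete, here is a minimal sketch (the dimension values are only illustrative):

import tensorflow as tf
import torch

batch_size, timesteps, features, units = 128, 50, 1, 300  # illustrative values

# PyTorch GRU (batch_first=False, the default) is time-major: (seq_len, batch, features)
pt_gru = torch.nn.GRU(input_size=features, hidden_size=units)
pt_output, pt_h_n = pt_gru(torch.zeros(timesteps, batch_size, features))

# TensorFlow GRU is batch-major: (batch, timesteps, features)
tf_gru = tf.keras.layers.GRU(units=units)
tf_output = tf_gru(tf.zeros((batch_size, timesteps, features)))
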
When I implement this model in TensorFlow and attempt to pass a single sample of shape (1, timesteps, features) to it, I get this error

InvalidArgumentError: Invalid input_h shape: [1,1,300] [1,128,300] [Op:CudnnRNN]

However, if I pass a batch whose first dimension matches the batch size that the model was originally trained on, I get no error. Such a batch in this case would be of shape (batch_size=128, timesteps, features). When I trained the model, I also used the drop_remainder=True kwarg in tf.data.Dataset.batch(batch_size, drop_remainder=True) to avoid this problem during model.fit. Is there a way to pass an arbitrary number of samples to such a network that was trained with a fixed-size tf.Variable? It seems really inefficient to always have to place my input in the first row of a zeros tensor such as this

dummy_zeros = np.zeros((batch_size, timesteps, features))
a_single_sample = np.random.randn(timesteps, features)
dummy_zeros[0, :, :] = a_single_sample

and then pass that zeros tensor to the model call like so

# I index the first element of the result of the forward pass because the other
# elements are just reconstructions of the zero padding, not real samples.
a_single_sample_reconstruction = model(dummy_zeros)[0]

Below is the model that I have made.

import tensorflow as tf


class VRAE(tf.keras.Model):
    """Variational Recurrent Autoencoder"""

    def __init__(self, batch_size, timesteps, rnn_units, latent_size, features, **kwargs):
        """Defines architecture layers.

        :param batch_size: <class 'int'> Self-explanatory.
        :param timesteps: <class 'int'> The number of timesteps (aka the second dimension) of the input.
        :param rnn_units: <class 'int'> The number of hidden units in the RNN layers.
        :param latent_size: <class 'int'> Dimensions of latent space.
        :param features: <class 'int'> The number of features (aka the last dimension) of the input.
        """

        super().__init__(**kwargs)

        # The first step of the encoding process
        self.recurrent_encoder_part = tf.keras.layers.GRU(units=rnn_units, return_sequences=False, return_state=True)

        # Variational part of the encoding process
        self.dense_mus = tf.keras.layers.Dense(units=latent_size)
        self.dense_log_vars = tf.keras.layers.Dense(units=latent_size)
        self.sampling = tf.keras.layers.Lambda(function=self.reparametrize)

        # Converts outputs from the sampling layer to initial hidden state for decoder rnn
        self.z_to_init_hidden_state = tf.keras.layers.Dense(units=rnn_units)

        ########################################################################################################
        # BEGIN: THE PROBLEM IS THAT THIS VARIABLE HAS A FIXED NUMBER OF PARAMETERS DUE TO BATCH SIZE CONSTRAINT
        ########################################################################################################
    
        # Define the trainable zeros tensor that is input (not hidden state) of recurrent decoder
        self.decoder_input = tf.Variable(tf.zeros(shape=(batch_size, timesteps, 1)), trainable=True)

        ########################################################################################################
        # END: THE PROBLEM IS THAT THIS VARIABLE HAS A FIXED NUMBER OF PARAMETERS DUE TO BATCH SIZE CONSTRAINT
        ########################################################################################################

        # Decoding layer
        self.recurrent_decoder = tf.keras.layers.GRU(units=rnn_units, return_sequences=True, return_state=False)

        # Reconstruct original inputs (softmax used because inputs belong to multiple classes)
        self.reconstruction = tf.keras.layers.Dense(units=features, activation='softmax')

    def call(self, inputs):
        """Forward computation that reconstructs original inputs.
        
        :param inputs: Rank-3 tensor of inputs (batch_size, timesteps, features).

        :return: Rank-3 tensor of probabilities (batch_size, timesteps, features).
        """

        # Encode the input (return_state=True means the GRU returns [last_output, final_state];
        # keep the final hidden state)
        _, h_t = self.recurrent_encoder_part(inputs)

        # Get latent variables
        mus = self.dense_mus(h_t)
        log_vars = self.dense_log_vars(h_t)
        z = self.sampling([mus, log_vars]) 

        # Project z to correct shape for initial hidden state of decoder
        decoder_h_0 = self.z_to_init_hidden_state(z)

        # Decode -- note how the 'decoder' input is always (batch_size, timesteps, 1) shape
        outputs = self.recurrent_decoder(self.decoder_input, initial_state=decoder_h_0)

        # Unnormalized log probabilities -> normalized probabilities
        reconstructions = self.reconstruction(outputs)

        # Add KL loss term
        self.add_loss(self.compute_kl_loss(mus, log_vars))

        # Return the reconstructions of the original input for the main loss function (CategoricalCrossentropy) to handle
        return reconstructions

    def reparametrize(self, inputs):
       """Reparametrization trick that returns latent space z given means and log variances."""

        mus, log_vars = inputs

        # Extract the shapes from one of the input tensors.
        # Note, dimensions(mus) = dimensions(log_vars)
        batch_size = tf.shape(mus)[0]
        dims = tf.shape(mus)[1]

        # Compute epsilon
        epsilon = tf.random.normal(shape=(batch_size, dims))

        # Sample the latent space
        return mus + tf.exp(log_vars/2) * epsilon
 
    def compute_kl_loss(self, mus, log_vars):
        """Computes the Kullback-Leibler loss using means and log variances of prior distribution."""

        return -0.5 * tf.reduce_mean(1. + log_vars - tf.exp(log_vars) - tf.pow(mus, 2))
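
For reference, here is a rough sketch of one possible alternative that I have not verified end-to-end (the names and dimension values below are illustrative): keep a single trainable row of shape (1, timesteps, 1) and tile it to the runtime batch size at call time, so the decoder input no longer hard-codes batch_size.

import tensorflow as tf

timesteps, rnn_units = 50, 300  # illustrative values

# A single trainable "row" of decoder inputs instead of a (batch_size, timesteps, 1) variable
decoder_input_row = tf.Variable(tf.zeros(shape=(1, timesteps, 1)), trainable=True)
recurrent_decoder = tf.keras.layers.GRU(units=rnn_units, return_sequences=True, return_state=False)

def decode(decoder_h_0):
    # Tile the single trainable row to whatever batch size arrives at call time
    dynamic_batch_size = tf.shape(decoder_h_0)[0]
    decoder_input = tf.tile(decoder_input_row, multiples=[dynamic_batch_size, 1, 1])
    return recurrent_decoder(decoder_input, initial_state=decoder_h_0)

# Works for a full batch or a single sample alike
print(decode(tf.zeros((128, rnn_units))).shape)  # (128, 50, 300)
print(decode(tf.zeros((1, rnn_units))).shape)    # (1, 50, 300)

Since tf.tile is differentiable, gradients from the reconstruction loss would still flow back into the single trainable row during training.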