

Variational Animal Motion Embedding 0.1 Toolbox. © K. Luxem & P. Bauer, Department of Cellular Neuroscience, Leibniz Institute for Neurobiology, Magdeburg, Germany. Licensed under the GNU General Public License v3.0.


def reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor,
reduction: str) -> torch.Tensor

Compute the reconstruction loss between input and reconstructed data.


  • x torch.Tensor - Input data tensor.
  • x_tilde torch.Tensor - Reconstructed data tensor.
  • reduction str - Reduction applied to the MSE reconstruction loss (for example "mean" or "sum").


  • torch.Tensor - Reconstruction loss.
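A minimal sketch of reconstruction_loss, assuming it is a plain mean squared error that forwards `reduction` to `torch.nn.MSELoss` (the MSE assumption follows the `mse_red` description of `train`; the toolbox's implementation may differ in detail):

```python
import torch
import torch.nn as nn


def reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor,
                        reduction: str) -> torch.Tensor:
    # Assumption: MSE between the reconstruction and the input;
    # `reduction` is passed straight to nn.MSELoss ("mean", "sum" or "none").
    mse = nn.MSELoss(reduction=reduction)
    return mse(x_tilde, x)


x = torch.ones(2, 3, 4)        # hypothetical (batch, time, features) shape
x_tilde = torch.zeros(2, 3, 4)
print(reconstruction_loss(x, x_tilde, "sum"))   # tensor(24.)
print(reconstruction_loss(x, x_tilde, "mean"))  # tensor(1.)
```

With `"sum"` the value grows with batch size and sequence length, which matters when comparing losses across runs with different settings.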


def future_reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor,
reduction: str) -> torch.Tensor

Compute the future reconstruction loss between the ground-truth future data and the predicted future data.


  • x torch.Tensor - Ground-truth future data tensor.
  • x_tilde torch.Tensor - Predicted future data tensor.
  • reduction str - Reduction applied to the MSE prediction loss (for example "mean" or "sum").


  • torch.Tensor - Future reconstruction loss.
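The sketch below assumes the future loss is the same MSE, computed only on the frames that follow the input window. The window layout and the names `window`, `past`, `future` and `prediction` are illustrative assumptions, not taken from the toolbox:

```python
import torch
import torch.nn as nn


def future_reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor,
                               reduction: str) -> torch.Tensor:
    # Assumption: same MSE as the reconstruction loss, applied to future frames only.
    return nn.MSELoss(reduction=reduction)(x_tilde, x)


# Hypothetical layout (batch, time, features): the first seq_len frames form the
# input window and the next future_steps frames are the prediction target.
torch.manual_seed(0)
seq_len, future_steps = 30, 15
window = torch.randn(8, seq_len + future_steps, 12)
past, future = window[:, :seq_len], window[:, seq_len:]
prediction = torch.zeros_like(future)  # stand-in for a future decoder's output
loss = future_reconstruction_loss(future, prediction, reduction="mean")
```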


def cluster_loss(H: torch.Tensor, kloss: int, lmbda: float,
batch_size: int) -> torch.Tensor

Compute the cluster (K-means) loss on the latent representation.


  • H torch.Tensor - Latent representation tensor.
  • kloss int - Number of clusters.
  • lmbda float - Weighting factor (lambda) for the cluster loss.
  • batch_size int - Size of the batch.


  • torch.Tensor - Cluster loss.
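One plausible SVD-based form of the cluster loss, sketched as an assumption rather than the toolbox's definition: it sums the `kloss` largest singular values of `H / sqrt(batch_size)`, computed as the square roots of the top eigenvalues of the batch Gram matrix `H.T @ H / batch_size`, and scales the result by `lmbda`:

```python
import math
import torch


def cluster_loss(H: torch.Tensor, kloss: int, lmbda: float,
                 batch_size: int) -> torch.Tensor:
    # H: (batch_size, latent_dim) latent codes of one batch.
    gram = H.T @ H / batch_size              # (latent_dim, latent_dim), symmetric PSD
    eigvals = torch.linalg.svdvals(gram)     # for a PSD matrix: its eigenvalues, descending
    top_k = torch.sqrt(eigvals[:kloss])      # = top singular values of H / sqrt(batch_size)
    return lmbda * top_k.sum()


torch.manual_seed(0)
H = torch.randn(64, 10, dtype=torch.float64)
loss = cluster_loss(H, kloss=3, lmbda=0.5, batch_size=64)
# Cross-check against the singular values of H computed directly.
reference = 0.5 * torch.linalg.svdvals(H / math.sqrt(64))[:3].sum()
print(torch.allclose(loss, reference))  # True
```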


def kullback_leibler_loss(mu: torch.Tensor,
logvar: torch.Tensor) -> torch.Tensor

Compute the Kullback-Leibler divergence loss. See Appendix B of the VAE paper: Kingma and Welling, Auto-Encoding Variational Bayes, ICLR 2014.

Formula: KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), where log(sigma^2) = logvar.


  • mu torch.Tensor - Mean of the latent distribution.
  • logvar torch.Tensor - Log variance of the latent distribution.


  • torch.Tensor - Kullback-Leibler divergence loss.
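A sketch of the closed-form KL term for a diagonal Gaussian posterior against a standard normal prior. It sums over all elements; an implementation could just as well average, so treat the reduction as an assumption. The comparison with `torch.distributions` only checks the formula:

```python
import torch
from torch.distributions import Normal, kl_divergence


def kullback_leibler_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over all elements,
    # with sigma^2 = exp(logvar).
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


torch.manual_seed(0)
mu, logvar = torch.randn(4, 6), torch.randn(4, 6)
# Cross-check with torch.distributions, using std = exp(0.5 * logvar).
reference = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)),
                          Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()
print(torch.allclose(kullback_leibler_loss(mu, logvar), reference, atol=1e-5))  # True
```

The value is zero exactly when `mu = 0` and `logvar = 0`, i.e. when the posterior matches the prior.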


def kl_annealing(epoch: int, kl_start: int, annealtime: int,
function: str) -> float

Anneal the weight of the Kullback-Leibler loss so that the model first learns to reconstruct the data before the KL term is introduced.


  • epoch int - Current epoch number.
  • kl_start int - Epoch number to start annealing the loss.
  • annealtime int - Annealing time.
  • function str - Annealing function type.


  • float - Annealed weight for the KL loss term.
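A sketch of the annealing schedule, assuming the weight stays at 0 up to `kl_start` and then rises to 1. The `"linear"` branch ramps over `annealtime` epochs; the `"sigmoid"` branch, including its centre and slope, is an illustrative assumption rather than the toolbox's choice:

```python
import math


def kl_annealing(epoch: int, kl_start: int, annealtime: int, function: str) -> float:
    if epoch <= kl_start:
        return 0.0                                  # reconstruction-only warm-up
    progress = (epoch - kl_start) / annealtime      # fraction of the annealing window elapsed
    if function == "linear":
        return min(1.0, progress)
    if function == "sigmoid":
        # Assumed shape: logistic curve centred halfway through the window.
        return 1.0 / (1.0 + math.exp(-10.0 * (progress - 0.5)))
    raise ValueError(f"unknown annealing function: {function!r}")


print([kl_annealing(e, kl_start=2, annealtime=4, function="linear") for e in range(8)])
# [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

The returned value plays the role of the KL weight (`kl_weight`) that scales the KL term during training.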


def gaussian(ins: torch.Tensor,
is_training: bool,
seq_len: int,
std_n: float = 0.8) -> torch.Tensor

Add Gaussian noise to the input data.


  • ins torch.Tensor - Input data tensor.
  • is_training bool - Whether the model is in training mode.
  • seq_len int - Length of the sequence.
  • std_n float - Standard deviation for the Gaussian noise.


  • torch.Tensor - Noisy input data tensor.
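A sketch of the noise augmentation, assuming inputs shaped (batch, seq_len, features) and noise whose standard deviation is `std_n` times each feature's empirical standard deviation over the sequence. Both the layout and the scaling rule are assumptions; thanks to broadcasting, `seq_len` is only used for a shape check here:

```python
import torch


def gaussian(ins: torch.Tensor, is_training: bool, seq_len: int,
             std_n: float = 0.8) -> torch.Tensor:
    if not is_training:
        return ins                                  # evaluation: input passes through unchanged
    if ins.shape[1] != seq_len:
        raise ValueError(f"expected {seq_len} time steps on dim 1, got {ins.shape[1]}")
    # Per-sample, per-feature std over time, shape (batch, 1, features), scaled by std_n.
    scale = ins.std(dim=1, keepdim=True) * std_n
    return ins + torch.randn_like(ins) * scale


torch.manual_seed(0)
batch = torch.randn(4, 30, 12)
noisy = gaussian(batch, is_training=True, seq_len=30)
clean = gaussian(batch, is_training=False, seq_len=30)
```

Scaling by the empirical standard deviation keeps the perturbation proportional to each feature's own variability, so features on different scales are perturbed comparably.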


def train(train_loader: Data.DataLoader, epoch: int, model: nn.Module,
optimizer: torch.optim.Optimizer, anneal_function: str, BETA: float,
kl_start: int, annealtime: int, seq_len: int, future_decoder: bool,
future_steps: int, scheduler: torch.optim.lr_scheduler._LRScheduler,
mse_red: str, mse_pred: str, kloss: int, klmbda: float, bsize: int,
noise: bool) -> Tuple[float, float, float, float, float, float]

Train the model for one epoch.


  • train_loader DataLoader - Training data loader.
  • epoch int - Current epoch number.
  • model nn.Module - Model to be trained.
  • optimizer Optimizer - Optimizer for training.
  • anneal_function str - Annealing function type.
  • BETA float - Beta value for the loss.
  • kl_start int - Epoch number to start annealing the loss.
  • annealtime int - Annealing time.
  • seq_len int - Length of the sequence.
  • future_decoder bool - Whether a future decoder is used.
  • future_steps int - Number of future steps to predict.
  • scheduler _LRScheduler - Learning rate scheduler.
  • mse_red str - Reduction type for the MSE reconstruction loss.
  • mse_pred str - Reduction type for the MSE prediction loss.
  • kloss int - Number of clusters for the cluster loss.
  • klmbda float - Lambda value for the cluster loss.
  • bsize int - Batch size.
  • noise bool - Whether to add Gaussian noise to the input.


  • Tuple[float, float, float, float, float, float] - Kullback-Leibler weight, train loss, K-means loss, KL loss, MSE loss, and future loss.
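The sketch below shows the general shape of one training epoch in PyTorch under heavy simplification: `ToyVAE` and `train_one_epoch` are hypothetical stand-ins, only the reconstruction and KL terms are kept (the future-prediction and cluster terms are omitted), and the weighting `rec + beta * kl_weight * kl` is an assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class ToyVAE(nn.Module):
    """Hypothetical stand-in for the VAME model: a linear VAE over flattened windows."""

    def __init__(self, seq_len: int, n_features: int, latent_dim: int = 4):
        super().__init__()
        d = seq_len * n_features
        self.enc = nn.Linear(d, 2 * latent_dim)
        self.dec = nn.Linear(latent_dim, d)

    def forward(self, x):
        mu, logvar = self.enc(x.flatten(1)).chunk(2, dim=1)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterisation trick
        return self.dec(z).view_as(x), mu, logvar


def train_one_epoch(loader, model, optimizer, kl_weight: float, beta: float) -> float:
    model.train()
    total = 0.0
    for (x,) in loader:
        x_tilde, mu, logvar = model(x)
        rec = F.mse_loss(x_tilde, x, reduction="sum")
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = rec + beta * kl_weight * kl          # assumed weighting of the terms
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader.dataset)              # average loss per sample


torch.manual_seed(0)
seq_len, n_features = 10, 3
loader = DataLoader(TensorDataset(torch.randn(32, seq_len, n_features)),
                    batch_size=8, shuffle=True)
model = ToyVAE(seq_len, n_features)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

for epoch in range(3):
    kl_weight = min(1.0, epoch / 2)                 # simple linear warm-up (assumed)
    avg_loss = train_one_epoch(loader, model, optimizer, kl_weight, beta=1.0)
    scheduler.step()                                # one scheduler step per epoch
```

Calling `scheduler.step()` once per epoch, after the optimizer has updated the parameters, is the order PyTorch expects.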


def test(test_loader: Data.DataLoader, model: nn.Module, BETA: float,
kl_weight: float, seq_len: int, mse_red: str, kloss: str,
klmbda: float, future_decoder: bool,
bsize: int) -> Tuple[float, float, float]

Evaluate the model on the test dataset.


  • test_loader DataLoader - DataLoader for the test dataset.
  • epoch int, deprecated - Current epoch number.
  • model nn.Module - The trained model.
  • optimizer Optimizer, deprecated - The optimizer used for training.
  • BETA float - Beta value for the VAE loss.
  • kl_weight float - Weighting factor for the KL divergence loss.
  • seq_len int - Length of the sequence.
  • mse_red str - Reduction method for the MSE loss.
  • kloss str - Number of clusters for the K-means loss.
  • klmbda float - Lambda value for K-means loss.
  • future_decoder bool - Flag indicating whether to use a future decoder.
  • bsize int - Batch size.


  • Tuple[float, float, float] - MSE loss per item, total test loss per item, and K-means loss weighted by kl_weight.
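A matching evaluation sketch, with hypothetical names (`evaluate`, `ZeroModel`) and an assumed loss weighting. Gradients are disabled with `torch.no_grad()`, and the summed losses are divided by the number of samples to give per-item values; the K-means term is omitted, so only two of the three return values are produced:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class ZeroModel(nn.Module):
    """Hypothetical model: reconstructs zeros and outputs a posterior equal to the prior."""

    def forward(self, x):
        mu = torch.zeros(x.shape[0], 2)
        return torch.zeros_like(x), mu, torch.zeros_like(mu)


@torch.no_grad()
def evaluate(loader, model, beta: float, kl_weight: float):
    model.eval()
    mse_sum, loss_sum, n_items = 0.0, 0.0, 0
    for (x,) in loader:
        x_tilde, mu, logvar = model(x)
        mse = F.mse_loss(x_tilde, x, reduction="sum")
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = mse + beta * kl_weight * kl          # assumed weighting of the terms
        mse_sum += mse.item()
        loss_sum += loss.item()
        n_items += x.shape[0]
    return mse_sum / n_items, loss_sum / n_items    # per-item MSE and total loss


loader = DataLoader(TensorDataset(torch.ones(6, 3, 2)), batch_size=4)
mse_per_item, loss_per_item = evaluate(loader, ZeroModel(), beta=1.0, kl_weight=1.0)
print(mse_per_item, loss_per_item)  # 6.0 6.0
```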


def train_model(config: str, save_logs: bool = False) -> None

Train the variational autoencoder using the values in the configuration file.


  • config str - Path to the configuration file.