class STEFunction(torch.autograd.Function):
    """Binary step activation with a straight-through gradient.

    The forward pass thresholds the input at zero. The backward pass does not
    differentiate the step; it passes the incoming gradient through
    ``hardtanh`` instead, so gradients still reach the layers that feed the
    step.
    """

    @staticmethod
    def forward(ctx, input):
        """Return 1 where ``input > 0`` and 0 elsewhere.

        Parameters
        ----------
        ctx : autograd context (unused, nothing is saved for backward).
        input : torch.Tensor
            Values to binarise.

        Returns
        -------
        torch.Tensor
            Tensor of zeros and ones with the same shape as ``input``.
        """
        # Keep the floating dtype of the input (float64, float16, ...) rather
        # than always producing float32; non-floating inputs still map to
        # float32, which is what the plain ``.float()`` cast produced.
        out_dtype = (
            input.dtype if input.is_floating_point() else torch.float32
        )
        return (input > 0).to(out_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the upstream gradient through, limited by ``hardtanh``.

        Parameters
        ----------
        ctx : autograd context (unused).
        grad_output : torch.Tensor
            Gradient of the loss with respect to the forward output.

        Returns
        -------
        torch.Tensor
            ``hardtanh(grad_output)``, used as the gradient for ``input``.
        """
        return torch.nn.functional.hardtanh(grad_output)


class StraightThroughEstimator(torch.nn.Module):
    """Module wrapper around :class:`STEFunction`.

    Lets the binary step be placed inside ``torch.nn.Sequential``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the straight-through binary step to ``x``."""
        return STEFunction.apply(x)
use_final_convolution ): x = self._decoder(x) - return x + + return x, x_before_conv diff --git a/avae/evaluate.py b/avae/evaluate.py index 297671e1..85062b05 100644 --- a/avae/evaluate.py +++ b/avae/evaluate.py @@ -129,7 +129,7 @@ def evaluate( t = t.to(torch.float32) # forward - t_hat, t_mu, t_logvar, tlat, tlat_pose = vae(t) + t_hat, t_before_conv, t_mu, t_logvar, tlat, tlat_pose = vae(t) x_test.extend(t_mu.cpu().detach().numpy()) # store latents c_test.extend(t_logvar.cpu().detach().numpy()) diff --git a/avae/models.py b/avae/models.py index 0c7a2277..317eed17 100644 --- a/avae/models.py +++ b/avae/models.py @@ -210,9 +210,21 @@ def forward(self, x): # reparametrise latent = self.reparametrise(latent_mu, latent_logvar) # decode - x_recon = self.decoder(latent, latent_pose) # pose set to None if pd=0 - - return x_recon, latent_mu, latent_logvar, latent, latent_pose + if self.decoder.__class__.__name__ == 'GaussianSplatDecoder': + x_recon, x_before_conv = self.decoder( + latent, latent_pose + ) # pose set to None if pd=0 + else: + x_recon = self.decoder(latent, latent_pose) + + return ( + x_recon, + x_before_conv, + latent_mu, + latent_logvar, + latent, + latent_pose, + ) def reparametrise(self, mu, log_var): if self.training: diff --git a/avae/train.py b/avae/train.py index 88766eb4..4666d12d 100644 --- a/avae/train.py +++ b/avae/train.py @@ -321,7 +321,12 @@ def train( x = x.to(torch.float32) # forward - x_hat, lat_mu, lat_logvar, lat, lat_pose = vae(x) + if vae.decoder.__class__.__name__ == "GaussianSplatDecoder": + x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose = vae( + x + ) + else: + x_hat, lat_mu, lat_logvar, lat, lat_pose = vae(x) history_loss = loss( x, x_hat, lat_mu, lat_logvar, epoch, batch_aff=aff ) @@ -387,7 +392,10 @@ def train( v = v.to(torch.float32) # forward - v_hat, v_mu, v_logvar, vlat, vlat_pos = vae(v) + if vae.decoder.__class__.__name__ == "GaussianSplatDecoder": + v_hat, v_before_conv, v_mu, v_logvar, vlat, vlat_pos = vae(v) + 
else: + v_hat, v_mu, v_logvar, vlat, vlat_pos = vae(v) v_history_loss = loss( v, v_hat, v_mu, v_logvar, epoch, batch_aff=aff ) @@ -454,7 +462,17 @@ def train( t = t.to(torch.float32) # forward - t_hat, t_mu, t_logvar, tlat, tlat_pose = vae(t) + if vae.decoder.__class__.__name__ == "GaussianSplatDecoder": + ( + t_hat, + t_before_conv, + t_mu, + t_logvar, + tlat, + tlat_pose, + ) = vae(t) + else: + t_hat, t_mu, t_logvar, tlat, tlat_pose = vae(t) x_test.extend(t_mu.cpu().detach().numpy()) # store latents c_test.extend(t_logvar.cpu().detach().numpy()) @@ -546,6 +564,37 @@ def train( epoch=epoch, writer=writer, ) + + xx = x_before_conv.detach().cpu().numpy() + vis.plot_array_distribution_tool( + (xx - np.min(xx)) / (np.max(xx) - np.min(xx)), "xx_normalised" + ) + vis.plot_array_distribution_tool( + x_hat.detach().cpu().numpy(), "x_hat" + ) + vis.plot_array_distribution_tool(xx, "xx") + + vis.recon_plot( + x, + x_before_conv, + y_train, + data_dim, + mode="trn_before_conv", + epoch=epoch, + writer=writer, + ) + + vis.recon_plot( + x, + (x_before_conv - torch.min(x_before_conv)) + / (torch.max(x_before_conv) - torch.min(x_before_conv)), + y_train, + data_dim, + mode="trn_before_conv_normalised", + epoch=epoch, + writer=writer, + ) + vis.recon_plot( v, v_hat, diff --git a/avae/utils.py b/avae/utils.py index 85b9e25d..f67b881f 100644 --- a/avae/utils.py +++ b/avae/utils.py @@ -302,7 +302,7 @@ def pose_interpolation( # Decode interpolated vectors with torch.no_grad(): - decoded_img = vae.decoder(lat, pos) + decoded_img, x_before_conv = vae.decoder(lat, pos) decoded_grid.append(decoded_img.cpu().squeeze().numpy()) diff --git a/avae/utils_learning.py b/avae/utils_learning.py index 929bcb82..0f9cd6c7 100644 --- a/avae/utils_learning.py +++ b/avae/utils_learning.py @@ -72,6 +72,7 @@ def pass_batch( torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, list, ]: """Passes a batch through the affinity VAE model epoch and computes the loss. 
def plot_array_distribution_tool(
    data,
    array_name,
    display: bool = False,
):
    """
    This is a tool for developers
    Plot histogram, boxplot, and violin plot of the data on a single figure.

    All values of ``data`` are pooled into one distribution, whatever the
    shape of the input.

    Parameters:
        data (torch.Tensor, np.ndarray or array-like): The input data.
        array_name (str): Label inserted into the name of the saved file.
        display (bool): If True the figure is shown with ``plt.show()``;
            otherwise it is saved as
            ``plots/array_<array_name>_stats.<settings.VIS_FORMAT>``.

    Returns:
        None
    """
    import logging

    if isinstance(data, torch.Tensor):
        # If data is a torch tensor, detach and move it to CPU
        data = data.detach().cpu().numpy()
    elif not isinstance(data, np.ndarray):
        # If data is not a numpy array or a torch tensor, convert it to numpy array
        data = np.array(data)

    # Always pool the values into a single 1-D sample. A 2-D array would
    # otherwise be handed to matplotlib as one dataset per column, which does
    # not fit the single histogram colour used below.
    data = data.ravel()

    # Non-finite values (e.g. the 0/0 produced by min-max normalising a
    # constant array) make the histogram range undefined, so drop them.
    data = data[np.isfinite(data)]
    if data.size == 0:
        logging.getLogger(__name__).warning(
            "No finite values to plot for array '%s'; skipping.", array_name
        )
        return

    fig, axes = plt.subplots(3, 1, figsize=(8, 18))

    axes[0].hist(data, bins=10, density=True, alpha=0.6, color='b')
    axes[0].set_title('Histogram of Data')
    axes[0].set_xlabel('Value')
    axes[0].set_ylabel('Frequency')

    axes[1].boxplot(data)
    axes[1].set_title('Boxplot of Data')
    axes[1].set_ylabel('Value')

    axes[2].violinplot(data)
    axes[2].set_title('Violin Plot of Data')
    axes[2].set_ylabel('Value')

    plt.tight_layout()

    if not display:
        # exist_ok avoids the check-then-create race of exists() + mkdir()
        os.makedirs("plots", exist_ok=True)
        plt.savefig(f"plots/array_{array_name}_stats.{settings.VIS_FORMAT}")
        plt.close(fig)
    else:
        plt.show()