From c3593b5384b03bfa77e9bd4cd460f355687a5efc Mon Sep 17 00:00:00 2001
From: Camila Rangel Smith
Date: Mon, 4 Mar 2024 11:37:44 +0000
Subject: [PATCH 1/4] Update README.md

---
 README.md | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 7b1f94cb..805cb8fb 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,8 @@
 [![Tests](https://github.com/alan-turing-institute/affinity-vae/actions/workflows/tests.yml/badge.svg)](https://github.com/alan-turing-institute/affinity-vae/actions/workflows/tests.yml)
 [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
 
-> Note: This a development version of the code. The code in the `main` branch is
-> a more stable version of the code.
-
+> Note: This is a stable version of the code used to produce figures for the paper. The code in `develop` is in constant change. 
+>
 # Affinity-VAE
 
 **Affinity-VAE for disentanglement, clustering and classification of objects in

From 009f2352d9f774aaca8a68619fd65c458686be2b Mon Sep 17 00:00:00 2001
From: crangelsmith
Date: Mon, 4 Mar 2024 11:43:25 +0000
Subject: [PATCH 2/4] fixing precommit

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 805cb8fb..d37d8865 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,8 @@
 [![Tests](https://github.com/alan-turing-institute/affinity-vae/actions/workflows/tests.yml/badge.svg)](https://github.com/alan-turing-institute/affinity-vae/actions/workflows/tests.yml)
 [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
 
-> Note: This is a stable version of the code used to produce figures for the paper. The code in `develop` is in constant change. 
->
+> Note: This is a stable version of the code used to produce figures for the paper. The code in `develop` is in constant change.
+>
 # Affinity-VAE
 
 **Affinity-VAE for disentanglement, clustering and classification of objects in

From f8caa056b06acd531fb36338b382422f5160399c Mon Sep 17 00:00:00 2001
From: marjanfamili
Date: Fri, 12 Apr 2024 12:18:44 +0100
Subject: [PATCH 3/4] added saving before conv layers and straight through
 estimator

---
 avae/decoders/decoders.py       |  2 +-
 avae/decoders/differentiable.py | 48 ++++++++++++++++++++-----
 avae/evaluate.py                |  2 +-
 avae/models.py                  |  4 +--
 avae/train.py                   | 28 +++++++++++++++
 avae/utils.py                   |  2 +-
 avae/utils_learning.py          |  4 +--
 avae/vis.py                     | 63 ++++++++++++++++++++++++++++++---
 8 files changed, 132 insertions(+), 21 deletions(-)

diff --git a/avae/decoders/decoders.py b/avae/decoders/decoders.py
index d87a2e9f..7b3f3bbc 100644
--- a/avae/decoders/decoders.py
+++ b/avae/decoders/decoders.py
@@ -253,7 +253,7 @@ def __init__(
 
     def forward(self, x, x_pose):
         if self.pose:
-            return self.decoder(torch.cat([x_pose, x], dim=-1))
+            return self.decoder(torch.cat([x_pose, x], dim=-1)), self.decoder(torch.cat([x_pose, x], dim=-1))
         else:
             return self.decoder(x)
 
diff --git a/avae/decoders/differentiable.py b/avae/decoders/differentiable.py
index 4c52f032..29e3b059 100644
--- a/avae/decoders/differentiable.py
+++ b/avae/decoders/differentiable.py
@@ -1,6 +1,13 @@
 from typing import Optional, Tuple
+import typing
+import logging
+import torchvision
+import numpy as np
+from scipy import stats
 
 import torch
+from avae.utils import save_imshow_png
+from avae import settings, vis
 
 from avae.decoders.base import AbstractDecoder
 from avae.decoders.spatial import (
@@ -10,6 +17,24 @@
     quaternion_to_rotation_matrix,
 )
 
+class STEFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        return (input > 0).float()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return torch.nn.functional.hardtanh(grad_output)
+
+
+class StraightThroughEstimator(torch.nn.Module):
+    def __init__(self):
+        super(StraightThroughEstimator, self).__init__()
+
+    def forward(self, x):
+        x = STEFunction.apply(x)
+        return x
+
 
 class GaussianSplatRenderer(torch.nn.Module):
     """Perform gaussian splatting."""
@@ -185,13 +210,15 @@ def __init__(
             torch.nn.Linear(latent_dims, n_splats * 3),
             torch.nn.Tanh(),
         )
+
+
         # weights are effectively whether a splat is used or not
         # use a soft step function to make this `binary` (but differentiable)
         # NOTE(arl): not sure if this really makes any difference
         self.weights = torch.nn.Sequential(
             torch.nn.Linear(latent_dims, n_splats),
-            torch.nn.Tanh(),
-            SoftStep(k=10.0),
+            StraightThroughEstimator(),
+            #SoftStep(k=10.0),
         )
         # sigma ends up being scaled by `splat_sigma_range`
         self.sigmas = torch.nn.Sequential(
@@ -226,11 +253,11 @@ def __init__(
             else torch.nn.Conv3d
         )
         self._decoder = torch.nn.Sequential(
-            conv(1, 32, 3, padding="same"),
-            torch.nn.ReLU(),
-            conv(32, 32, 3, padding="same"),
-            torch.nn.ReLU(),
-            conv(32, output_channels, 3, padding="same"),
+            conv(1, 1, 9, padding="same"),
+            #torch.nn.ReLU(),
+            #conv(32, 32, 3, padding="same"),
+            #torch.nn.ReLU(),
+            #conv(32, output_channels, 3, padding="same"),
         )
 
     def configure_renderer(
@@ -332,11 +359,14 @@ def forward(
         x = self._splatter(
             splats, weights, sigmas, splat_sigma_range=self._splat_sigma_range
         )
-        # if we're doing a final convolution, do it here
+
+        x_before_conv = x
+
         if (
             self._output_channels is not None
             and self._output_channels != 0
             and use_final_convolution
         ):
             x = self._decoder(x)
-        return x
+
+        return x, x_before_conv
diff --git a/avae/evaluate.py b/avae/evaluate.py
index e9991a34..218cd7e6 100644
--- a/avae/evaluate.py
+++ b/avae/evaluate.py
@@ -121,7 +121,7 @@ def evaluate(
     vae.eval()
 
     for b, batch in enumerate(tests):
-        x, x_hat, lat_mu, lat_logvar, lat, lat_pose, _ = pass_batch(
+        x, x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose, _ = pass_batch(
             device, vae, batch, b, len(tests)
         )
         x_test.extend(lat_mu.cpu().detach().numpy())
diff --git a/avae/models.py b/avae/models.py
index c0571353..9b2711b7 100644
--- a/avae/models.py
+++ b/avae/models.py
@@ -88,9 +88,9 @@ def forward(self, x):
         # reparametrise
         latent = self.reparametrise(latent_mu, latent_logvar)
         # decode
-        x_recon = self.decoder(latent, latent_pose)  # pose set to None if pd=0
+        x_recon, x_before_conv = self.decoder(latent, latent_pose)  # pose set to None if pd=0
 
-        return x_recon, latent_mu, latent_logvar, latent, latent_pose
+        return x_recon, x_before_conv, latent_mu, latent_logvar, latent, latent_pose
 
     def reparametrise(self, mu, log_var):
         if self.training:
diff --git a/avae/train.py b/avae/train.py
index 7f412d00..fb721046 100644
--- a/avae/train.py
+++ b/avae/train.py
@@ -380,6 +380,7 @@ def train(
             (
                 x,
                 x_hat,
+                x_before_conv,
                 lat_mu,
                 lat_logvar,
                 lat,
@@ -435,6 +436,7 @@ def train(
                 (
                     v,
                     v_hat,
+                    v_before_conv,
                     v_mu,
                     v_logvar,
                     vlat,
@@ -582,6 +584,32 @@ def train(
                 epoch=epoch,
                 writer=writer,
             )
+
+
+            xx = x_before_conv.detach().cpu().numpy()
+            vis.plot_array_distribution_tool((xx - np.min(xx)) / (np.max(xx) - np.min(xx)), "xx_normalised")
+            vis.plot_array_distribution_tool(x_hat.detach().cpu().numpy(), "x_hat")
+            vis.plot_array_distribution_tool(xx, "xx")
+
+            vis.recon_plot(
+                x,
+                x_before_conv,
+                y_train,
+                data_dim,
+                mode="trn_before_conv",
+                epoch=epoch,
+                writer=writer,
+            )
+
+            vis.recon_plot(
+                x,
+                (x_before_conv - torch.min(x_before_conv)) / (torch.max(x_before_conv) - torch.min(x_before_conv)),
+                y_train,
+                data_dim,
+                mode="trn_before_conv_normalised",
+                epoch=epoch,
+                writer=writer,
+            )
+
             vis.recon_plot(
                 v,
                 v_hat,
diff --git a/avae/utils.py b/avae/utils.py
index 85b9e25d..f67b881f 100644
--- a/avae/utils.py
+++ b/avae/utils.py
@@ -302,7 +302,7 @@ def pose_interpolation(
 
     # Decode interpolated vectors
     with torch.no_grad():
-        decoded_img = vae.decoder(lat, pos)
+        decoded_img, x_before_conv = vae.decoder(lat, pos)
 
     decoded_grid.append(decoded_img.cpu().squeeze().numpy())
 
diff --git a/avae/utils_learning.py b/avae/utils_learning.py
index 767c3c8f..d7d4cbe7 100644
--- a/avae/utils_learning.py
+++ b/avae/utils_learning.py
@@ -140,7 +140,7 @@ def pass_batch(
 
     # forward
     x = x.to(torch.float32)
-    x_hat, lat_mu, lat_logvar, lat, lat_pose = vae(x)
+    x_hat, x_before_conv,lat_mu, lat_logvar, lat, lat_pose = vae(x)
 
     if loss is not None:
         history_loss = loss(x, x_hat, lat_mu, lat_logvar, e, batch_aff=aff)
@@ -164,7 +164,7 @@ def pass_batch(
         optimizer.step()
         optimizer.zero_grad()
 
-    return x, x_hat, lat_mu, lat_logvar, lat, lat_pose, history
+    return x, x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose, history
 
 
 def add_meta(
diff --git a/avae/vis.py b/avae/vis.py
index 4380268c..bca6353a 100644
--- a/avae/vis.py
+++ b/avae/vis.py
@@ -1370,7 +1370,7 @@ def latent_4enc_interpolate_plot(
 
     # Decode the interpolated encoding to generate an image
     with torch.no_grad():
-        decoded_images = vae.decoder(
+        decoded_images, x_before_conv = vae.decoder(
             interpolated_z.view(-1, latent_dim).to(device=device),
             (torch.zeros(1, poses[0].shape[0]) + pose_mean).to(
                 device=device
@@ -1474,11 +1474,11 @@ def latent_disentamglement_plot(
                 current_pos_grid = torch.from_numpy(
                     np.array([pos_means])
                 ).to(device)
-                current_recon = vae.decoder(
+                current_recon, x_before_conv = vae.decoder(
                     current_lat_grid, current_pos_grid
                 )
             else:
-                current_recon = vae.decoder(current_lat_grid, None)
+                current_recon, x_before_conv = vae.decoder(current_lat_grid, None)
 
             recon_images.append(current_recon.cpu().squeeze().numpy())
 
@@ -1808,7 +1808,7 @@ def interpolations_plot(
             )
             with torch.no_grad():
                 if poses is not None:
-                    decoded_images = vae.decoder(
+                    decoded_images, x_before_conv = vae.decoder(
                         interpolated_z.view(-1, latent_dim).to(device=device),
                         (
                             torch.zeros(1, poses[0].shape[0])
@@ -1816,7 +1816,7 @@ def interpolations_plot(
                         ).to(device=device),
                     )
                 else:
-                    decoded_images = vae.decoder(
+                    decoded_images, x_before_conv = vae.decoder(
                         interpolated_z.view(-1, latent_dim).to(device=device),
                         None,
                     )
@@ -2073,3 +2073,56 @@ def latent_space_similarity_plot(
         plt.close()
     else:
         plt.show()
+
+
+
+def plot_array_distribution_tool(data, array_name, display: bool = False,
+):
+    """
+    This is a tool for developers
+    Plot histogram, boxplot, and violin plot of the data on a single figure.
+
+    Parameters:
+        data (array-like): The input data.
+
+    Returns:
+        None
+    """
+    if isinstance(data, torch.Tensor):
+        # If data is a torch tensor, detach and move it to CPU
+        data = data.detach().cpu().numpy()
+
+    elif not isinstance(data, np.ndarray):
+        # If data is not a numpy array or a torch tensor, convert it to numpy array
+        data = np.array(data)
+
+    # Flatten the array if it has more than two dimensions
+    if data.ndim > 2:
+        data = data.flatten()
+
+    fig, axes = plt.subplots(3, 1, figsize=(8, 18))
+
+    axes[0].hist(data, bins=10, density=True, alpha=0.6, color='b')
+    axes[0].set_title('Histogram of Data')
+    axes[0].set_xlabel('Value')
+    axes[0].set_ylabel('Frequency')
+
+    axes[1].boxplot(data)
+    axes[1].set_title('Boxplot of Data')
+    axes[1].set_ylabel('Value')
+
+    axes[2].violinplot(data)
+    axes[2].set_title('Violin Plot of Data')
+    axes[2].set_ylabel('Value')
+
+    plt.tight_layout()
+
+    if not display:
+        if not os.path.exists("plots"):
+            os.mkdir("plots")
+        plt.savefig(
+            f"plots/array_{array_name}_stats.{settings.VIS_FORMAT}"
+        )
+        plt.close()
+    else:
+        plt.show()

From ef0aa18002adda0a0186543716774ef80e502efa Mon Sep 17 00:00:00 2001
From: jolaem
Date: Wed, 7 Aug 2024 10:10:48 +0000
Subject: [PATCH 4/4] Fix failing GSD branch.
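
The GaussianSplatDecoder returns a tuple of (reconstruction, pre-convolution
volume), while the other decoders still return a single tensor, so the model
forward and the train/validation/test loops now branch on the decoder class
when unpacking the output. The tiled default rotation axis is also moved onto
the decoder's device before it is concatenated with the encoded pose. The
snippet below is an illustrative sketch, not part of the diff; `vae` stands
for any model instance and `x` for an input batch.

    # Illustrative only: how callers unpack the model output after this patch.
    if vae.decoder.__class__.__name__ == "GaussianSplatDecoder":
        x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose = vae(x)
    else:
        x_hat, lat_mu, lat_logvar, lat, lat_pose = vae(x)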
---
 avae/decoders/differentiable.py |  6 +++++-
 avae/models.py                  |  9 ++++++---
 avae/train.py                   | 24 +++++++++++++++++++---
 3 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/avae/decoders/differentiable.py b/avae/decoders/differentiable.py
index bb7911fb..17395906 100644
--- a/avae/decoders/differentiable.py
+++ b/avae/decoders/differentiable.py
@@ -279,6 +279,7 @@ def configure_renderer(
             device=device,
         )
         self._splat_sigma_range = splat_sigma_range
+        self._device = device
 
     def decode_splats(
         self, z: torch.Tensor, pose: torch.Tensor
@@ -295,10 +296,13 @@ def decode_splats(
         # in the case where the encoded pose only has one dimension, we need to
         # use the pose as a rotation about the z-axis
         if pose.shape[-1] == 1:
+            tile = torch.tile(self._default_axis, (batch_size, 1)).to(
+                self._device
+            )
             pose = torch.concat(
                 [
                     pose,
-                    torch.tile(self._default_axis, (batch_size, 1)),
+                    tile,
                 ],
                 axis=-1,
             )
diff --git a/avae/models.py b/avae/models.py
index 72035a36..317eed17 100644
--- a/avae/models.py
+++ b/avae/models.py
@@ -210,9 +210,12 @@ def forward(self, x):
         # reparametrise
         latent = self.reparametrise(latent_mu, latent_logvar)
         # decode
-        x_recon, x_before_conv = self.decoder(
-            latent, latent_pose
-        )  # pose set to None if pd=0
+        if self.decoder.__class__.__name__ == 'GaussianSplatDecoder':
+            x_recon, x_before_conv = self.decoder(
+                latent, latent_pose
+            )  # pose set to None if pd=0
+        else:
+            x_recon = self.decoder(latent, latent_pose)
 
         return (
             x_recon,
diff --git a/avae/train.py b/avae/train.py
index 9ddc3436..4666d12d 100644
--- a/avae/train.py
+++ b/avae/train.py
@@ -321,7 +321,12 @@ def train(
             x = x.to(torch.float32)
 
             # forward
-            x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose = vae(x)
+            if vae.decoder.__class__.__name__ == "GaussianSplatDecoder":
+                x_hat, x_before_conv, lat_mu, lat_logvar, lat, lat_pose = vae(
+                    x
+                )
+            else:
+                x_hat, lat_mu, lat_logvar, lat, lat_pose = vae(x)
             history_loss = loss(
                 x, x_hat, lat_mu, lat_logvar, epoch, batch_aff=aff
             )
@@ -387,7 +392,10 @@ def train(
                 v = v.to(torch.float32)
 
                 # forward
-                v_hat, v_mu, v_logvar, vlat, vlat_pos = vae(v)
+                if vae.decoder.__class__.__name__ == "GaussianSplatDecoder":
+                    v_hat, v_before_conv, v_mu, v_logvar, vlat, vlat_pos = vae(v)
+                else:
+                    v_hat, v_mu, v_logvar, vlat, vlat_pos = vae(v)
                 v_history_loss = loss(
                     v, v_hat, v_mu, v_logvar, epoch, batch_aff=aff
                 )
@@ -454,7 +462,17 @@ def train(
                 t = t.to(torch.float32)
 
                 # forward
-                t_hat, t_mu, t_logvar, tlat, tlat_pose = vae(t)
+                if vae.decoder.__class__.__name__ == "GaussianSplatDecoder":
+                    (
+                        t_hat,
+                        t_before_conv,
+                        t_mu,
+                        t_logvar,
+                        tlat,
+                        tlat_pose,
+                    ) = vae(t)
+                else:
+                    t_hat, t_mu, t_logvar, tlat, tlat_pose = vae(t)
                 x_test.extend(t_mu.cpu().detach().numpy())  # store latents
                 c_test.extend(t_logvar.cpu().detach().numpy())
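
A short illustration of the StraightThroughEstimator added in PATCH 3/4 may be
useful when reviewing the series; this is an assumed standalone snippet, not
part of any patch above. The forward pass binarises its input, while the
backward pass lets gradients through clipped by hardtanh, so the splat-weights
head stays trainable despite the hard threshold.

    # Assumes the classes added to avae/decoders/differentiable.py in PATCH 3/4.
    import torch
    from avae.decoders.differentiable import StraightThroughEstimator

    ste = StraightThroughEstimator()
    x = torch.randn(4, requires_grad=True)
    y = ste(x)           # every entry is exactly 0. or 1.
    y.sum().backward()   # x.grad comes from hardtanh(grad_output), not zero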