From 7b2f38955f86341538c4ab9f6f7f7997ec8ae08f Mon Sep 17 00:00:00 2001 From: OleguerCanal Date: Sun, 24 Jan 2021 23:19:26 +0100 Subject: [PATCH] Fixed working with doubles bug --- experiment_params/train_config_default.yaml | 2 +- hamiltonian_generative_network.py | 6 +++--- utilities/hgn_result.py | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/experiment_params/train_config_default.yaml b/experiment_params/train_config_default.yaml index 365367b..326b08a 100644 --- a/experiment_params/train_config_default.yaml +++ b/experiment_params/train_config_default.yaml @@ -6,7 +6,7 @@ device: 'cuda:0' # Will use this device if available, otherwise wil use cpu # Define networks architectures networks: variational: True - dtype : "float" + dtype : "double" encoder: hidden_conv_layers: 6 n_filters: [32, 64, 64, 64, 64, 64, 64] # first + hidden diff --git a/hamiltonian_generative_network.py b/hamiltonian_generative_network.py index e2037b2..acf734d 100644 --- a/hamiltonian_generative_network.py +++ b/hamiltonian_generative_network.py @@ -74,7 +74,7 @@ def forward(self, rollout_batch, n_steps=None, variational=True): prediction_shape = list(rollout_batch.shape) prediction_shape[1] = n_steps + 1 # Count the first one prediction = HgnResult(batch_shape=torch.Size(prediction_shape), - device=self.device) + device=self.device, dtype=self.dtype) prediction.set_input(rollout_batch) # Concat along channel dimension @@ -165,13 +165,13 @@ def get_random_sample(self, n_steps, img_shape=(32, 32)): # Sample from a normal distribution the latent representation of the rollout latent_shape = (1, self.encoder.out_mean.out_channels, img_shape[0], img_shape[1]) - latent_representation = torch.randn(latent_shape).to(self.device) + latent_representation = torch.randn(latent_shape).to(self.device).type(self.dtype) # Instantiate prediction object prediction_shape = (1, n_steps, self.channels, img_shape[0], img_shape[1]) prediction = HgnResult(batch_shape=torch.Size(prediction_shape), - device=self.device) + device=self.device, dtype=self.dtype) prediction.set_z(z_sample=latent_representation) diff --git a/utilities/hgn_result.py b/utilities/hgn_result.py index 00edec9..9eba295 100644 --- a/utilities/hgn_result.py +++ b/utilities/hgn_result.py @@ -14,7 +14,7 @@ class HgnResult(): """Class to bundle HGN guessed output information. """ - def __init__(self, batch_shape, device): + def __init__(self, batch_shape, device, dtype): """Instantiate the HgnResult that will contain all the information of the forward pass over a single batch of rollouts. @@ -22,6 +22,7 @@ def __init__(self, batch_shape, device): batch_shape (torch.Size): Shape of a batch of reconstructed rollouts, returned by batch.shape. device (str): String with the device to use. E.g. 'cuda:0', 'cpu'. + dtype (torch.dtype): Data type """ self.input = None self.z_mean = None @@ -30,7 +31,7 @@ def __init__(self, batch_shape, device): self.q_s = [] self.p_s = [] self.energies = [] # Estimated energy of the system by the Hamiltonian network - self.reconstructed_rollout = torch.empty(batch_shape).to(device) + self.reconstructed_rollout = torch.empty(batch_shape, dtype=dtype).to(device) self.reconstruction_ptr = 0 def set_input(self, rollout):