Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion experiment_params/train_config_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ device: 'cuda:0' # Will use this device if available, otherwise wil use cpu
# Define networks architectures
networks:
variational: True
dtype : "float"
dtype : "double"
encoder:
hidden_conv_layers: 6
n_filters: [32, 64, 64, 64, 64, 64, 64] # first + hidden
Expand Down
6 changes: 3 additions & 3 deletions hamiltonian_generative_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def forward(self, rollout_batch, n_steps=None, variational=True):
prediction_shape = list(rollout_batch.shape)
prediction_shape[1] = n_steps + 1 # Count the first one
prediction = HgnResult(batch_shape=torch.Size(prediction_shape),
device=self.device)
device=self.device, dtype=self.dtype)
prediction.set_input(rollout_batch)

# Concat along channel dimension
Expand Down Expand Up @@ -165,13 +165,13 @@ def get_random_sample(self, n_steps, img_shape=(32, 32)):
# Sample from a normal distribution the latent representation of the rollout
latent_shape = (1, self.encoder.out_mean.out_channels, img_shape[0],
img_shape[1])
latent_representation = torch.randn(latent_shape).to(self.device)
latent_representation = torch.randn(latent_shape).to(self.device).type(self.dtype)

# Instantiate prediction object
prediction_shape = (1, n_steps, self.channels, img_shape[0],
img_shape[1])
prediction = HgnResult(batch_shape=torch.Size(prediction_shape),
device=self.device)
device=self.device, dtype=self.dtype)

prediction.set_z(z_sample=latent_representation)

Expand Down
5 changes: 3 additions & 2 deletions utilities/hgn_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ class HgnResult():
"""Class to bundle HGN guessed output information.
"""

def __init__(self, batch_shape, device):
def __init__(self, batch_shape, device, dtype):
"""Instantiate the HgnResult that will contain all the information of the forward pass
over a single batch of rollouts.

Args:
batch_shape (torch.Size): Shape of a batch of reconstructed rollouts, returned by
batch.shape.
device (str): String with the device to use. E.g. 'cuda:0', 'cpu'.
dtype (torch.dtype): Data type
"""
self.input = None
self.z_mean = None
Expand All @@ -30,7 +31,7 @@ def __init__(self, batch_shape, device):
self.q_s = []
self.p_s = []
self.energies = [] # Estimated energy of the system by the Hamiltonian network
self.reconstructed_rollout = torch.empty(batch_shape).to(device)
self.reconstructed_rollout = torch.empty(batch_shape, dtype=dtype).to(device)
self.reconstruction_ptr = 0

def set_input(self, rollout):
Expand Down