Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions experiment_params/train_config_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ networks:
n_filters: [64, 64, 64]
kernel_sizes: [3, 3, 3, 3]

# Define HGN Integrator
integrator:
method: "Leapfrog"

# Define optimization
optimization:
epochs: 5
Expand Down
19 changes: 7 additions & 12 deletions hamiltonian_generative_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@ def __init__(self,
transformer,
hnn,
decoder,
integrator,
device,
dtype,
seq_len,
channels):
channels,
delta_t=0.125):
"""Instantiate a Hamiltonian Generative Network.

Args:
encoder (networks.encoder_net.EncoderNet): Encoder neural network.
transformer (networks.transformer_net.TransformerNet): Transformer neural network.
hnn (networks.hamiltonian_net.HamiltonianNet): Hamiltonian neural network.
decoder (networks.decoder_net.DecoderNet): Decoder neural network.
integrator (Integrator): HGN integrator.
device (str): String with the device to use. E.g. 'cuda:0', 'cpu'.
dtype (torch.dtype): Data type used for the networks.
seq_len (int): Number of frames in each rollout.
Expand All @@ -52,8 +51,8 @@ def __init__(self,
self.transformer = transformer
self.hnn = hnn
self.decoder = decoder
self.integrator = integrator
self.delta_t = delta_t

def forward(self, rollout_batch, n_steps=None, variational=True):
"""Get the prediction of the HGN for a given rollout_batch of n_steps.

Expand Down Expand Up @@ -95,18 +94,14 @@ def forward(self, rollout_batch, n_steps=None, variational=True):
# Estimate predictions
for _ in range(n_steps - 1):
# Compute next state
q, p = self.integrator.step(q=q, p=p, hnn=self.hnn)
delta_q, delta_p = self.hnn(q=q, p=p)
q = q + delta_q * self.delta_t
p = p + delta_p * self.delta_t
prediction.append_state(q=q, p=p)
prediction.append_energy(self.integrator.energy) # This is the energy of previous timestep

# Compute state reconstruction
x_reconstructed = self.decoder(q)
prediction.append_reconstruction(x_reconstructed)

# We need to add the energy of the system at the last time-step
with torch.no_grad():
last_energy = self.hnn(q=q, p=p).detach().cpu().numpy()
prediction.append_energy(last_energy) # This is the energy of previous timestep
return prediction

def load(self, directory):
Expand Down
19 changes: 5 additions & 14 deletions networks/hamiltonian_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self,
raise ValueError(
'Args hidden_conv_layers, n_filters, kernel_sizes, and strides'
'can only be either all None, or all defined by the user.')
self.in_shape = in_shape
in_channels = in_shape[0] * 2
paddings = [int(k / 2) for k in kernel_sizes]
self.in_conv = nn.Conv2d(in_channels=in_channels,
Expand All @@ -83,13 +84,9 @@ def __init__(self,
out_size = int((out_size - kernel_sizes[i] + 2 * paddings[i]) /
strides[i]) + 1
self.out_conv = nn.Conv2d(in_channels=n_filters[-1],
out_channels=n_filters[-1],
out_channels=in_channels,
kernel_size=kernel_sizes[-1],
padding=paddings[-1])
out_size = int(
(out_size - kernel_sizes[-1] + 2 * paddings[-1]) / strides[-1]) + 1
self.n_flat = (out_size**2) * n_filters[-1]
self.linear = nn.Linear(in_features=self.n_flat, out_features=1)
self.activation = act_func
self.type(dtype)

Expand All @@ -111,13 +108,7 @@ def forward(self, q, p):
x = self.activation(self.in_conv(x))
for layer in self.hidden_layers:
x = self.activation(layer(x))
x = self.activation(self.out_conv(x))
x = x.view(-1, self.n_flat)
x = self.linear(x)
return x

x = self.out_conv(x)
q, p = x[:, :self.in_shape[0]], x[:, self.in_shape[0]:]
return q, p

if __name__ == '__main__':
hamiltonian_net = HamiltonianNet(in_shape=(16, 4, 4))
q, p = torch.randn((2, 128, 16, 4, 4))
h = hamiltonian_net(q, p)
3 changes: 1 addition & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, params, resume=False):
# Load hgn from parameters to deice
self.hgn = load_hgn(params=self.params,
device=self.device,
dtype=self.dtype)
dtype=self.dtype,)

# Either generate data on-the-fly or load the data from disk
if "train_data" in self.params["dataset"]:
Expand Down Expand Up @@ -162,7 +162,6 @@ def training_step(self, rollouts):

train_loss.backward()
self.optimizer.step()

return losses, hgn_output

def fit(self):
Expand Down
8 changes: 2 additions & 6 deletions utilities/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,16 @@ def load_hgn(params, device, dtype):
**params["networks"]["decoder"],
dtype=dtype).to(device)

# Define HGN integrator
integrator = Integrator(delta_t=params["dataset"]["rollout"]["delta_time"],
method=params["integrator"]["method"])

# Instantiate Hamiltonian Generative Network
hgn = HGN(encoder=encoder,
transformer=transformer,
hnn=hnn,
decoder=decoder,
integrator=integrator,
device=device,
dtype=dtype,
seq_len=params["dataset"]["rollout"]["seq_length"],
channels=params["dataset"]["rollout"]["n_channels"])
channels=params["dataset"]["rollout"]["n_channels"],
delta_t=params["dataset"]["rollout"]["delta_time"])
return hgn


Expand Down
3 changes: 0 additions & 3 deletions utilities/training_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ def step(self, losses, rollout_batch, prediction, model):
for loss_name, loss_value in losses.items():
if loss_value is not None:
self.writer.add_scalar(f'{loss_name}', loss_value, self.iteration)
enery_mean, energy_std = prediction.get_energy()
self.writer.add_scalar(f'energy/mean', enery_mean, self.iteration)
self.writer.add_scalar(f'energy/std', energy_std, self.iteration)

if self.iteration % self.rollout_freq == 0:
self.writer.add_video('data/input',
Expand Down