-
Notifications
You must be signed in to change notification settings - Fork 25
Open
Description
0batch [00:00, ?batch/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[21], line 5
3 for i in range(runs):
4 start_time = time.time()
----> 5 result, stats, model = train_run(1e-3)
6 results.append(result)
7 print("result:", result)
Cell In[20], line 20, in train_run(lr)
17 stats = defaultdict(list)
19 start_time = time.time()
---> 20 train_loop(autoencoder, train_dataloader, optimizer, 10, stats)
21 result = test_loop(autoencoder, test_dataloader)
22 return result, stats, autoencoder
Cell In[17], line 24, in train_loop(model, dataloader, optimizer, epochs, stats)
19 image = data.cuda()
20 # print(image.shape)
21 # break
22
23 # loss = train_batch(image, model, optimizer, autoencoder.spectral_loss)
---> 24 loss, losses = train_batch(model, image, image, optimizer, F.mse_loss)
26 if isinstance(losses, tuple):
27 stats['raw_losses'].append(loss)
Cell In[16], line 19, in train_batch(model, inputs, targets, optimizer, loss_func)
16 loss = F.mse_loss(inputs, outputs)
17 loss.backward()
---> 19 optimizer.step()
21 return loss.detach().item(), None
File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
482 else:
483 raise RuntimeError(
484 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
485 )
--> 487 out = func(*args, **kwargs)
488 self._optimizer_step_code()
490 # call optimizer step post hooks
File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:601, in StatefulOptimizer.step(self, closure)
599 for top_group in self.param_groups:
600 for group in self.get_groups(top_group):
--> 601 self._step(group)
602 if self.use_ema:
603 self.ema_update(group)
File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/precond_schedule_palm_foreach_soap.py:64, in PrecondSchedulePaLMForeachSOAP._step(self, group)
62 state["exp_avg_sq"] = torch.zeros_like(g, dtype=torch.float32)
63 init_preconditioner(g, state, max_precond_dim, precondition_1d)
---> 64 update_preconditioner(g, state, max_precond_dim, precondition_1d, 0, True)
65 continue # first step is skipped so that we never use the current gradients in the projection.
67 # Projecting gradients to the eigenbases of Shampoo's preconditioner
68 # i.e. projecting to the eigenbases of matrices in state['GG']
File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:423, in update_preconditioner(grad, state, max_precond_dim, precondition_1d, beta, update_precond)
421 compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
422 if state['Q'] is None:
--> 423 state['Q'] = get_orthogonal_matrix(state['GG'])
424 if update_precond:
425 get_orthogonal_matrix_QR(state['GG'], state['Q'], state['exp_avg_sq'])
File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:320, in get_orthogonal_matrix(mat)
318 for modifier in (None, torch.double, 'cpu'):
319 if modifier is not None:
--> 320 m = m.to(modifier)
321 try:
322 Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))[1].to(device=device,
323 dtype=dtype)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
The error occurs with the following optimizer:
optimizer = heavyball.PrecondSchedulePaLMForeachSOAP(autoencoder.parameters(), lr=lr)
Training runs fine with PyTorch's standard SGD optimizer (`torch.optim.SGD`), so I assume the problem is not in my model.
Let me know if you need more information.
Metadata
Metadata
Assignees
Labels
No labels