Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmf/trainers/core/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def prediction_loop(self, dataset_type: str) -> None:

for batch in tqdm.tqdm(dataloader):
prepared_batch = reporter.prepare_batch(batch)
prepared_batch = to_device(prepared_batch, torch.device("cuda"))
prepared_batch = to_device(prepared_batch, self.device)
with torch.cuda.amp.autocast(enabled=self.training_config.fp16):
model_output = self.model(prepared_batch)
report = Report(prepared_batch, model_output)
Expand Down
16 changes: 8 additions & 8 deletions mmf/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch
from torch import distributed as dist
try:
try:
import torch_xla.core.xla_model as xm
except ImportError:
xm = None
Expand Down Expand Up @@ -82,10 +82,10 @@ def broadcast_tensor(tensor, src=0):
if is_xla():
tensor = xm.all_to_all(
tensor.repeat([world_size,1]),
split_dimension=0,
concat_dimension=0,
split_count=world_size)[0]
else:
split_dimension=0,
concat_dimension=0,
split_count=world_size)[0]
else:
dist.broadcast(tensor, src=0)

return tensor
Expand Down Expand Up @@ -128,9 +128,9 @@ def gather_tensor(tensor):
if is_xla():
tensor_list = xm.all_gather(tensor)
tensor_list = tensor_list.view(world_size, *tensor.size())
else:
else:
dist.all_gather(tensor_list, tensor)
tensor_list = torch.stack(tensor_list, dim=0)
tensor_list = torch.stack(tuple(tensor_list), dim=0)
return tensor_list


Expand All @@ -151,7 +151,7 @@ def reduce_dict(dictionary):
[values],
scale=1.0/world_size
)[0]
else:
else:
dist.reduce(values, dst=0)
if dist.get_rank() == 0:
# only main process gets accumulated, so only divide by
Expand Down