From 9e0fb75ff232c766af8b141b063259ba473233e4 Mon Sep 17 00:00:00 2001 From: Julius Simonelli Date: Sun, 9 Aug 2020 11:37:32 -0400 Subject: [PATCH] Update train.py --- PT-BOSS/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/PT-BOSS/train.py b/PT-BOSS/train.py index 3d81754..c7d0eb7 100644 --- a/PT-BOSS/train.py +++ b/PT-BOSS/train.py @@ -104,7 +104,7 @@ def train_one_epoch( ims_x_strong = ims_x_strong.cuda() ims_x_weak = ims_x_weak.cuda() - lbs_x = lbs_x.cuda() + lbs_x = lbs_x.type(torch.LongTensor).cuda() ims_u_weak = ims_u_weak.cuda() ims_u_strong = ims_u_strong.cuda() @@ -127,7 +127,7 @@ def train_one_epoch( with torch.no_grad(): lbs_u_real = lbs_u_real[valid_u].cuda() corr_lb = lbs_u_real == lbs_u - loss_u_real = F.cross_entropy(logits_u, lbs_u_real) + loss_u_real = F.cross_entropy(logits_u, lbs_u_real.type(torch.LongTensor).cuda()) else: logits_x = model(ims_x_weak) loss_x = criteria_x(logits_x, lbs_x)