From 37b536e31b49e65f736afc91c4d322f5882a9e2b Mon Sep 17 00:00:00 2001 From: Sergiodiaz53 Date: Wed, 3 Nov 2021 10:29:41 +0100 Subject: [PATCH] some bugs fixed --- train.py | 396 +++++++++++++++++++++++++++---------------------------- 1 file changed, 198 insertions(+), 198 deletions(-) diff --git a/train.py b/train.py index 81938aa..05130e2 100644 --- a/train.py +++ b/train.py @@ -57,211 +57,211 @@ parser.add_argument('--lam3', default=10, type=float) parser.add_argument('--final_mlp_style', default='sep', type=str,choices = ['common','sep']) - -opt = parser.parse_args() -modelsave_path = os.path.join(os.getcwd(),opt.savemodelroot,opt.task,str(opt.dset_id),opt.run_name) -if opt.task == 'regression': - opt.dtask = 'reg' -else: - opt.dtask = 'clf' +if __name__ == "__main__": + opt = parser.parse_args() + modelsave_path = os.path.join(os.getcwd(),opt.savemodelroot,opt.task,str(opt.dset_id),opt.run_name) + if opt.task == 'regression': + opt.dtask = 'reg' + else: + opt.dtask = 'clf' -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -print(f"Device is {device}.") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(f"Device is {device}.") -torch.manual_seed(opt.set_seed) -os.makedirs(modelsave_path, exist_ok=True) + torch.manual_seed(opt.set_seed) + os.makedirs(modelsave_path, exist_ok=True) -if opt.active_log: - import wandb - if opt.pretrain: - wandb.init(project="saint_v2_all", group =opt.run_name ,name = f'pretrain_{opt.task}_{str(opt.attentiontype)}_{str(opt.dset_id)}_{str(opt.set_seed)}') - else: - if opt.task=='multiclass': - wandb.init(project="saint_v2_all_kamal", group =opt.run_name ,name = f'{opt.task}_{str(opt.attentiontype)}_{str(opt.dset_id)}_{str(opt.set_seed)}') - else: - wandb.init(project="saint_v2_all", group =opt.run_name ,name = f'{opt.task}_{str(opt.attentiontype)}_{str(opt.dset_id)}_{str(opt.set_seed)}') - - - -print('Downloading and processing the dataset, it might take some time.') -cat_dims, cat_idxs, con_idxs, X_train, y_train, X_valid, y_valid, X_test, y_test, train_mean, train_std = data_prep_openml(opt.dset_id, opt.dset_seed,opt.task, datasplit=[.65, .15, .2]) -continuous_mean_std = np.array([train_mean,train_std]).astype(np.float32) - -##### Setting some hyperparams based on inputs and dataset -_,nfeat = X_train['data'].shape -if nfeat > 100: - opt.embedding_size = min(4,opt.embedding_size) - opt.batchsize = min(64, opt.batchsize) -if opt.attentiontype != 'col': - opt.transformer_depth = 1 - opt.attention_heads = min(4,opt.attention_heads) - opt.attention_dropout = 0.8 - opt.embedding_size = min(32,opt.embedding_size) - opt.ff_dropout = 0.8 - -print(nfeat,opt.batchsize) -print(opt) - -if opt.active_log: - wandb.config.update(opt) -train_ds = DataSetCatCon(X_train, y_train, cat_idxs,opt.dtask,continuous_mean_std) -trainloader = DataLoader(train_ds, batch_size=opt.batchsize, shuffle=True,num_workers=4) - -valid_ds = DataSetCatCon(X_valid, y_valid, cat_idxs,opt.dtask, continuous_mean_std) -validloader = DataLoader(valid_ds, batch_size=opt.batchsize, shuffle=False,num_workers=4) - -test_ds = DataSetCatCon(X_test, y_test, cat_idxs,opt.dtask, continuous_mean_std) -testloader = DataLoader(test_ds, batch_size=opt.batchsize, shuffle=False,num_workers=4) -if opt.task == 'regression': - y_dim = 1 -else: - y_dim = len(np.unique(y_train['data'][:,0])) - -cat_dims = np.append(np.array([1]),np.array(cat_dims)).astype(int) #Appending 1 for CLS token, this is later used to generate embeddings. - - - -model = SAINT( -categories = tuple(cat_dims), -num_continuous = len(con_idxs), -dim = opt.embedding_size, -dim_out = 1, -depth = opt.transformer_depth, -heads = opt.attention_heads, -attn_dropout = opt.attention_dropout, -ff_dropout = opt.ff_dropout, -mlp_hidden_mults = (4, 2), -cont_embeddings = opt.cont_embeddings, -attentiontype = opt.attentiontype, -final_mlp_style = opt.final_mlp_style, -y_dim = y_dim -) -vision_dset = opt.vision_dset - -if y_dim == 2 and opt.task == 'binary': - # opt.task = 'binary' - criterion = nn.CrossEntropyLoss().to(device) -elif y_dim > 2 and opt.task == 'multiclass': - # opt.task = 'multiclass' - criterion = nn.CrossEntropyLoss().to(device) -elif opt.task == 'regression': - criterion = nn.MSELoss().to(device) -else: - raise'case not written yet' - -model.to(device) - - -if opt.pretrain: - from pretraining import SAINT_pretrain - model = SAINT_pretrain(model, cat_idxs,X_train,y_train, continuous_mean_std, opt,device) - -## Choosing the optimizer - -if opt.optimizer == 'SGD': - optimizer = optim.SGD(model.parameters(), lr=opt.lr, - momentum=0.9, weight_decay=5e-4) - from utils import get_scheduler - scheduler = get_scheduler(opt, optimizer) -elif opt.optimizer == 'Adam': - optimizer = optim.Adam(model.parameters(),lr=opt.lr) -elif opt.optimizer == 'AdamW': - optimizer = optim.AdamW(model.parameters(),lr=opt.lr) -best_valid_auroc = 0 -best_valid_accuracy = 0 -best_test_auroc = 0 -best_test_accuracy = 0 -best_valid_rmse = 100000 -print('Training begins now.') -for epoch in range(opt.epochs): - model.train() - running_loss = 0.0 - for i, data in enumerate(trainloader, 0): - optimizer.zero_grad() - # x_categ is the the categorical data, with y appended as last feature. x_cont has continuous data. cat_mask is an array of ones same shape as x_categ except for last column(corresponding to y's) set to 0s. con_mask is an array of ones same shape as x_cont. - x_categ, x_cont, y_gts, cat_mask, con_mask = data[0].to(device), data[1].to(device),data[2].to(device),data[3].to(device),data[4].to(device) - - # We are converting the data to embeddings in the next step - _ , x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask,model,vision_dset) - reps = model.transformer(x_categ_enc, x_cont_enc) - # select only the representations corresponding to y and apply mlp on it in the next step to get the predictions. - y_reps = reps[:,0,:] - - y_outs = model.mlpfory(y_reps) - if opt.task == 'regression': - loss = criterion(y_outs,y_gts) + if opt.active_log: + import wandb + if opt.pretrain: + wandb.init(project="saint_v2_all", group =opt.run_name ,name = f'pretrain_{opt.task}_{str(opt.attentiontype)}_{str(opt.dset_id)}_{str(opt.set_seed)}') else: - loss = criterion(y_outs,y_gts.squeeze()) - loss.backward() - optimizer.step() - if opt.optimizer == 'SGD': - scheduler.step() - running_loss += loss.item() - # print(running_loss) + if opt.task=='multiclass': + wandb.init(project="saint_v2_all_kamal", group =opt.run_name ,name = f'{opt.task}_{str(opt.attentiontype)}_{str(opt.dset_id)}_{str(opt.set_seed)}') + else: + wandb.init(project="saint_v2_all", group =opt.run_name ,name = f'{opt.task}_{str(opt.attentiontype)}_{str(opt.dset_id)}_{str(opt.set_seed)}') + + + + print('Downloading and processing the dataset, it might take some time.') + cat_dims, cat_idxs, con_idxs, X_train, y_train, X_valid, y_valid, X_test, y_test, train_mean, train_std = data_prep_openml(opt.dset_id, opt.dset_seed,opt.task, datasplit=[.65, .15, .2]) + continuous_mean_std = np.array([train_mean,train_std]).astype(np.float32) + + ##### Setting some hyperparams based on inputs and dataset + _,nfeat = X_train['data'].shape + if nfeat > 100: + opt.embedding_size = min(4,opt.embedding_size) + opt.batchsize = min(64, opt.batchsize) + if opt.attentiontype != 'col': + opt.transformer_depth = 1 + opt.attention_heads = min(4,opt.attention_heads) + opt.attention_dropout = 0.8 + opt.embedding_size = min(32,opt.embedding_size) + opt.ff_dropout = 0.8 + + print(nfeat,opt.batchsize) + print(opt) + if opt.active_log: - wandb.log({'epoch': epoch ,'train_epoch_loss': running_loss, - 'loss': loss.item() - }) - if epoch%5==0: - model.eval() - with torch.no_grad(): - if opt.task in ['binary','multiclass']: - accuracy, auroc = classification_scores(model, validloader, device, opt.task,vision_dset) - test_accuracy, test_auroc = classification_scores(model, testloader, device, opt.task,vision_dset) - - print('[EPOCH %d] VALID ACCURACY: %.3f, VALID AUROC: %.3f' % - (epoch + 1, accuracy,auroc )) - print('[EPOCH %d] TEST ACCURACY: %.3f, TEST AUROC: %.3f' % - (epoch + 1, test_accuracy,test_auroc )) - if opt.active_log: - wandb.log({'valid_accuracy': accuracy ,'valid_auroc': auroc }) - wandb.log({'test_accuracy': test_accuracy ,'test_auroc': test_auroc }) - if opt.task =='multiclass': - if accuracy > best_valid_accuracy: - best_valid_accuracy = accuracy - best_test_auroc = test_auroc - best_test_accuracy = test_accuracy - torch.save(model.state_dict(),'%s/bestmodel.pth' % (modelsave_path)) + wandb.config.update(opt) + train_ds = DataSetCatCon(X_train, y_train, cat_idxs,opt.dtask,continuous_mean_std) + trainloader = DataLoader(train_ds, batch_size=opt.batchsize, shuffle=True,num_workers=4) + + valid_ds = DataSetCatCon(X_valid, y_valid, cat_idxs,opt.dtask, continuous_mean_std) + validloader = DataLoader(valid_ds, batch_size=opt.batchsize, shuffle=False,num_workers=4) + + test_ds = DataSetCatCon(X_test, y_test, cat_idxs,opt.dtask, continuous_mean_std) + testloader = DataLoader(test_ds, batch_size=opt.batchsize, shuffle=False,num_workers=4) + if opt.task == 'regression': + y_dim = 1 + else: + y_dim = len(np.unique(y_train['data'][:,0])) + + cat_dims = np.append(np.array([1]),np.array(cat_dims)).astype(int) #Appending 1 for CLS token, this is later used to generate embeddings. + + + + model = SAINT( + categories = tuple(cat_dims), + num_continuous = len(con_idxs), + dim = opt.embedding_size, + dim_out = 1, + depth = opt.transformer_depth, + heads = opt.attention_heads, + attn_dropout = opt.attention_dropout, + ff_dropout = opt.ff_dropout, + mlp_hidden_mults = (4, 2), + cont_embeddings = opt.cont_embeddings, + attentiontype = opt.attentiontype, + final_mlp_style = opt.final_mlp_style, + y_dim = y_dim + ) + vision_dset = opt.vision_dset + + if y_dim == 2 and opt.task == 'binary': + # opt.task = 'binary' + criterion = nn.CrossEntropyLoss().to(device) + elif y_dim > 2 and opt.task == 'multiclass': + # opt.task = 'multiclass' + criterion = nn.CrossEntropyLoss().to(device) + elif opt.task == 'regression': + criterion = nn.MSELoss().to(device) + else: + raise'case not written yet' + + model.to(device) + + + if opt.pretrain: + from pretraining import SAINT_pretrain + model = SAINT_pretrain(model, cat_idxs,X_train,y_train, continuous_mean_std, opt,device) + + ## Choosing the optimizer + + if opt.optimizer == 'SGD': + optimizer = optim.SGD(model.parameters(), lr=opt.lr, + momentum=0.9, weight_decay=5e-4) + from utils import get_scheduler + scheduler = get_scheduler(opt, optimizer) + elif opt.optimizer == 'Adam': + optimizer = optim.Adam(model.parameters(),lr=opt.lr) + elif opt.optimizer == 'AdamW': + optimizer = optim.AdamW(model.parameters(),lr=opt.lr) + best_valid_auroc = 0 + best_valid_accuracy = 0 + best_test_auroc = 0 + best_test_accuracy = 0 + best_valid_rmse = 100000 + print('Training begins now.') + for epoch in range(opt.epochs): + model.train() + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + optimizer.zero_grad() + # x_categ is the the categorical data, with y appended as last feature. x_cont has continuous data. cat_mask is an array of ones same shape as x_categ except for last column(corresponding to y's) set to 0s. con_mask is an array of ones same shape as x_cont. + x_categ, x_cont, y_gts, cat_mask, con_mask = data[0].to(device), data[1].to(device),data[2].to(device),data[3].to(device),data[4].to(device) + + # We are converting the data to embeddings in the next step + _ , x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask,model,vision_dset) + reps = model.transformer(x_categ_enc, x_cont_enc) + # select only the representations corresponding to y and apply mlp on it in the next step to get the predictions. + y_reps = reps[:,0,:] + + y_outs = model.mlpfory(y_reps) + if opt.task == 'regression': + loss = criterion(y_outs,y_gts) + else: + loss = criterion(y_outs,y_gts.squeeze().type(torch.LongTensor)) + loss.backward() + optimizer.step() + if opt.optimizer == 'SGD': + scheduler.step() + running_loss += loss.item() + # print(running_loss) + if opt.active_log: + wandb.log({'epoch': epoch ,'train_epoch_loss': running_loss, + 'loss': loss.item() + }) + if epoch%5==0: + model.eval() + with torch.no_grad(): + if opt.task in ['binary','multiclass']: + accuracy, auroc = classification_scores(model, validloader, device, opt.task,vision_dset) + test_accuracy, test_auroc = classification_scores(model, testloader, device, opt.task,vision_dset) + + print('[EPOCH %d] VALID ACCURACY: %.3f, VALID AUROC: %.3f' % + (epoch + 1, accuracy,auroc )) + print('[EPOCH %d] TEST ACCURACY: %.3f, TEST AUROC: %.3f' % + (epoch + 1, test_accuracy,test_auroc )) + if opt.active_log: + wandb.log({'valid_accuracy': accuracy ,'valid_auroc': auroc }) + wandb.log({'test_accuracy': test_accuracy ,'test_auroc': test_auroc }) + if opt.task =='multiclass': + if accuracy > best_valid_accuracy: + best_valid_accuracy = accuracy + best_test_auroc = test_auroc + best_test_accuracy = test_accuracy + torch.save(model.state_dict(),'%s/bestmodel.pth' % (modelsave_path)) + else: + if accuracy > best_valid_accuracy: + best_valid_accuracy = accuracy + # if auroc > best_valid_auroc: + # best_valid_auroc = auroc + best_test_auroc = test_auroc + best_test_accuracy = test_accuracy + torch.save(model.state_dict(),'%s/bestmodel.pth' % (modelsave_path)) + else: - if accuracy > best_valid_accuracy: - best_valid_accuracy = accuracy - # if auroc > best_valid_auroc: - # best_valid_auroc = auroc - best_test_auroc = test_auroc - best_test_accuracy = test_accuracy + valid_rmse = mean_sq_error(model, validloader, device,vision_dset) + test_rmse = mean_sq_error(model, testloader, device,vision_dset) + print('[EPOCH %d] VALID RMSE: %.3f' % + (epoch + 1, valid_rmse )) + print('[EPOCH %d] TEST RMSE: %.3f' % + (epoch + 1, test_rmse )) + if opt.active_log: + wandb.log({'valid_rmse': valid_rmse ,'test_rmse': test_rmse }) + if valid_rmse < best_valid_rmse: + best_valid_rmse = valid_rmse + best_test_rmse = test_rmse torch.save(model.state_dict(),'%s/bestmodel.pth' % (modelsave_path)) + model.train() + - else: - valid_rmse = mean_sq_error(model, validloader, device,vision_dset) - test_rmse = mean_sq_error(model, testloader, device,vision_dset) - print('[EPOCH %d] VALID RMSE: %.3f' % - (epoch + 1, valid_rmse )) - print('[EPOCH %d] TEST RMSE: %.3f' % - (epoch + 1, test_rmse )) - if opt.active_log: - wandb.log({'valid_rmse': valid_rmse ,'test_rmse': test_rmse }) - if valid_rmse < best_valid_rmse: - best_valid_rmse = valid_rmse - best_test_rmse = test_rmse - torch.save(model.state_dict(),'%s/bestmodel.pth' % (modelsave_path)) - model.train() - - - -total_parameters = count_parameters(model) -print('TOTAL NUMBER OF PARAMS: %d' %(total_parameters)) -if opt.task =='binary': - print('AUROC on best model: %.3f' %(best_test_auroc)) -elif opt.task =='multiclass': - print('Accuracy on best model: %.3f' %(best_test_accuracy)) -else: - print('RMSE on best model: %.3f' %(best_test_rmse)) - -if opt.active_log: - if opt.task == 'regression': - wandb.log({'total_parameters': total_parameters, 'test_rmse_bestep':best_test_rmse , - 'cat_dims':len(cat_idxs) , 'con_dims':len(con_idxs) }) + + total_parameters = count_parameters(model) + print('TOTAL NUMBER OF PARAMS: %d' %(total_parameters)) + if opt.task =='binary': + print('AUROC on best model: %.3f' %(best_test_auroc)) + elif opt.task =='multiclass': + print('Accuracy on best model: %.3f' %(best_test_accuracy)) else: - wandb.log({'total_parameters': total_parameters, 'test_auroc_bestep':best_test_auroc , - 'test_accuracy_bestep':best_test_accuracy,'cat_dims':len(cat_idxs) , 'con_dims':len(con_idxs) }) + print('RMSE on best model: %.3f' %(best_test_rmse)) + + if opt.active_log: + if opt.task == 'regression': + wandb.log({'total_parameters': total_parameters, 'test_rmse_bestep':best_test_rmse , + 'cat_dims':len(cat_idxs) , 'con_dims':len(con_idxs) }) + else: + wandb.log({'total_parameters': total_parameters, 'test_auroc_bestep':best_test_auroc , + 'test_accuracy_bestep':best_test_accuracy,'cat_dims':len(cat_idxs) , 'con_dims':len(con_idxs) })