From 274a1a6f9981d21fb8fe4888d215158d0c759bef Mon Sep 17 00:00:00 2001
From: Sagar Gupta
Date: Thu, 21 Oct 2021 11:12:43 -0400
Subject: [PATCH] Fix errors in model_training.py

---
 pathflowai/model_training.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pathflowai/model_training.py b/pathflowai/model_training.py
index 9b833ad..2ceeddf 100644
--- a/pathflowai/model_training.py
+++ b/pathflowai/model_training.py
@@ -81,7 +81,7 @@ def return_transformer(training_opts):
 	transform_opts=dict(patch_size = training_opts['patch_resize'], mean=norm_dict['mean'], std=norm_dict['std'], resize=True, transform_platform=training_opts['transform_platform'] if not training_opts['segmentation'] else 'albumentations', user_transforms=training_opts['user_transforms'])
 	transformers = get_data_transforms(**transform_opts)
 
-	return dataset_df,dataset_opts,transformers
+	return dataset_df,dataset_opts,transformers,transform_opts
 
 def return_datasets(training_opts,dataset_df,transformers):
 	datasets= {set: DynamicImageDataset(dataset_df, set, training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][0] if set=='train' else -1, target_threshold=training_opts['target_threshold'][0], oversampling_factor=training_opts['oversampling_factor'][0] if set=='train' else 1, n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'], classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter'] if set == 'train' else {}) for set in ['train','val','test']}
@@ -124,7 +124,7 @@ def return_datasets(training_opts,dataset_df,transformers):
 	if training_opts['external_test_db'] and training_opts['external_test_dir']:
 		datasets['test'].update_dataset(input_dir=training_opts['external_test_dir'],new_db=training_opts['external_test_db'],prediction_basename=training_opts['prediction_basename'])
 
-	return datasets,training_opts,transform_opts
+	return datasets,training_opts
 
 #@pysnooper.snoop('train_model.log')
 def train_model_(training_opts):
@@ -141,14 +141,14 @@ def train_model_(training_opts):
 
 	model = return_model(training_opts)
 
-	dataset_df,dataset_opts,transformers=return_transformer(training_opts)
+	dataset_df,dataset_opts,transformers,transform_opts=return_transformer(training_opts)
 
 	if training_opts['extract_embedding'] and training_opts['npy_file']:
 		dataset=NPYDataset(training_opts['patch_info_file'],training_opts['patch_size'],training_opts['npy_file'],transformers["test"])
 		dataset.embed(model,training_opts['batch_size'],training_opts['prediction_output_dir'])
 		exit()
 
-	datasets,training_opts,transform_opts=return_datasets(training_opts,dataset_df,transformers)
+	datasets,training_opts=return_datasets(training_opts,dataset_df,transformers)
 
 	if training_opts['num_training_images_epoch']>0:
 		num_train_batches = min(training_opts['num_training_images_epoch'],len(datasets['train']))//training_opts['batch_size']
@@ -162,7 +162,7 @@ def train_model_(training_opts):
 
 	if training_opts['run_test']:
 		run_test(dataloaders['train'])
 
-	model_trainer_opts=return_trainer_opts(model,training_opts,dataloders,num_train_batches)
+	model_trainer_opts=return_trainer_opts(model,training_opts,dataloaders,num_train_batches)
 
 	if not training_opts['predict']:
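
Note: with this patch, transform_opts is returned by return_transformer rather than return_datasets. A minimal sketch of the updated call pattern inside train_model_, using only names that appear in the diff (a populated training_opts dict is assumed):

	# transform_opts now comes from return_transformer, not return_datasets
	dataset_df, dataset_opts, transformers, transform_opts = return_transformer(training_opts)
	datasets, training_opts = return_datasets(training_opts, dataset_df, transformers)
	# 'dataloaders' is now spelled consistently when building the trainer options
	model_trainer_opts = return_trainer_opts(model, training_opts, dataloaders, num_train_batches)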