pathflowai/model_training.py (10 changes: 5 additions & 5 deletions)
@@ -81,7 +81,7 @@ def return_transformer(training_opts):
     transform_opts=dict(patch_size = training_opts['patch_resize'], mean=norm_dict['mean'], std=norm_dict['std'], resize=True, transform_platform=training_opts['transform_platform'] if not training_opts['segmentation'] else 'albumentations', user_transforms=training_opts['user_transforms'])
 
     transformers = get_data_transforms(**transform_opts)
-    return dataset_df,dataset_opts,transformers
+    return dataset_df,dataset_opts,transformers,transform_opts
 
 def return_datasets(training_opts,dataset_df,transformers):
     datasets= {set: DynamicImageDataset(dataset_df, set, training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][0] if set=='train' else -1, target_threshold=training_opts['target_threshold'][0], oversampling_factor=training_opts['oversampling_factor'][0] if set=='train' else 1, n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'], classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter'] if set == 'train' else {}) for set in ['train','val','test']}
@@ -124,7 +124,7 @@ def return_datasets(training_opts,dataset_df,transformers):

     if training_opts['external_test_db'] and training_opts['external_test_dir']:
         datasets['test'].update_dataset(input_dir=training_opts['external_test_dir'],new_db=training_opts['external_test_db'],prediction_basename=training_opts['prediction_basename'])
-    return datasets,training_opts,transform_opts
+    return datasets,training_opts
 
 #@pysnooper.snoop('train_model.log')
 def train_model_(training_opts):
@@ -141,14 +141,14 @@ def train_model_(training_opts):

     model = return_model(training_opts)
 
-    dataset_df,dataset_opts,transformers=return_transformer(training_opts)
+    dataset_df,dataset_opts,transformers,transform_opts=return_transformer(training_opts)
 
     if training_opts['extract_embedding'] and training_opts['npy_file']:
         dataset=NPYDataset(training_opts['patch_info_file'],training_opts['patch_size'],training_opts['npy_file'],transformers["test"])
         dataset.embed(model,training_opts['batch_size'],training_opts['prediction_output_dir'])
         exit()
 
-    datasets,training_opts,transform_opts=return_datasets(training_opts,dataset_df,transformers)
+    datasets,training_opts=return_datasets(training_opts,dataset_df,transformers)
 
     if training_opts['num_training_images_epoch']>0:
         num_train_batches = min(training_opts['num_training_images_epoch'],len(datasets['train']))//training_opts['batch_size']
@@ -162,7 +162,7 @@

     if training_opts['run_test']: run_test(dataloaders['train'])
 
-    model_trainer_opts=return_trainer_opts(model,training_opts,dataloders,num_train_batches)
+    model_trainer_opts=return_trainer_opts(model,training_opts,dataloaders,num_train_batches)
 
     if not training_opts['predict']:

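In short, this diff moves the transform_opts return value from return_datasets (where the name was never defined, so returning it would raise a NameError) to return_transformer (where it is actually built), and fixes the dataloders typo in the return_trainer_opts call. Below is a minimal, self-contained sketch of the corrected call chain; the function names mirror pathflowai/model_training.py, but the stub bodies are invented here purely for illustration and are not the library's real implementations.

```python
# Hedged sketch of the data flow this diff fixes; runnable standalone.
# Stubs stand in for the real PathFlowAI helpers.

def return_transformer(training_opts):
    # transform_opts is built in this scope, so it is returned from here now.
    transform_opts = dict(patch_size=training_opts['patch_resize'])
    transformers = {s: None for s in ('train', 'val', 'test')}  # stand-in for get_data_transforms(**transform_opts)
    dataset_df, dataset_opts = None, {}
    return dataset_df, dataset_opts, transformers, transform_opts

def return_datasets(training_opts, dataset_df, transformers):
    datasets = {s: [] for s in ('train', 'val', 'test')}  # stand-in for the DynamicImageDataset dict
    # Before this diff, the function also returned transform_opts, but that
    # name was never defined in this scope, so the return raised NameError.
    return datasets, training_opts

# Corrected call chain, as in train_model_:
training_opts = {'patch_resize': 256}
dataset_df, dataset_opts, transformers, transform_opts = return_transformer(training_opts)
datasets, training_opts = return_datasets(training_opts, dataset_df, transformers)
print(transform_opts)  # -> {'patch_size': 256}
```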