-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
42 lines (33 loc) · 1.3 KB
/
utils.py
File metadata and controls
42 lines (33 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from cyclical_learning_rate import CyclicLR
from config import NUM_EPOCHS, EARLY_STOPPING_PATIENCE
class WorkingModelCheckpoint(Callback):
    """Keras callback that saves the full model whenever validation loss improves.

    Holds an explicit reference to the model passed at construction time and
    calls ``model.save(save_file)`` at the end of any epoch whose ``val_loss``
    is the lowest seen so far in the current training run.

    NOTE(review): Keras also assigns ``self.model`` via ``set_model()`` during
    ``fit()``; passing the model explicitly here pins which object gets saved —
    confirm this is intentional (presumably a workaround for wrapped models).

    Args:
        model: Model instance to save on improvement.
        save_file: Destination path; defaults to ``"models/model.hdf5"``.
    """
    def __init__(self, model=None, save_file=None, **kwargs):
        super(WorkingModelCheckpoint, self).__init__(**kwargs)
        self.model = model
        self.save_file = "models/model.hdf5" if save_file is None else save_file
        # Initialize here too, so the callback is well-defined even if
        # on_train_begin is never invoked.
        self.minloss = float('inf')
    def on_train_begin(self, logs=None):
        # Reset the best-seen loss at the start of every training run.
        # (Mutable default `logs={}` replaced with the conventional None.)
        self.minloss = float('inf')
        return
    # Save model if the validation loss is at its lowest point
    def on_epoch_end(self, epoch, logs=None):
        val_loss = (logs or {}).get('val_loss')
        # Guard: val_loss is absent when fit() runs without validation data;
        # the unguarded `self.minloss > None` comparison raised TypeError.
        if val_loss is not None and val_loss < self.minloss:
            self.minloss = val_loss
            self.model.save(self.save_file)
            #print('\t\tModel saved!')
        return
def get_cb_early_stopping(patience):
    """Build an EarlyStopping callback that monitors validation loss.

    Args:
        patience: Epochs with no val_loss improvement before training stops.

    Returns:
        A configured ``EarlyStopping`` instance (silent, zero min_delta).
    """
    return EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=patience,
        verbose=0,
    )
def get_cb_checkpoint(model, save_file):
    """Build a WorkingModelCheckpoint that saves `model` to `save_file`."""
    checkpoint = WorkingModelCheckpoint(model, save_file=save_file)
    return checkpoint
def get_cb_cyclic_lr():
    """Build the project's default cyclical-learning-rate callback.

    Returns:
        A ``CyclicLR`` using the triangular policy, cycling the learning
        rate between 0.001 and 0.006 over a half-cycle of 2000 iterations.
    """
    clr_config = {
        'base_lr': 0.001,
        'max_lr': 0.006,
        'step_size': 2000.,
        'mode': 'triangular',
        'gamma': 1.,
        'scale_fn': None,
        'scale_mode': 'cycle',
    }
    return CyclicLR(**clr_config)