diff --git a/DeepSequence/model.py b/DeepSequence/model.py
index 7fa30b2..93d1d08 100644
--- a/DeepSequence/model.py
+++ b/DeepSequence/model.py
@@ -1,4 +1,7 @@
 from __future__ import print_function
+from collections import OrderedDict
+import cPickle
+import os
 import numpy as np
 import theano
 
@@ -7,10 +10,6 @@
 #import theano.sandbox.linalg as T_linalg
 from scipy.special import erfinv
-import cPickle
-
-from collections import OrderedDict
-
 if theano.config.floatX == "float16":
     print ("using epsilon=1e-6")
     epsilon = 1e-6
 
@@ -716,6 +715,7 @@ def create_gradientfunctions(self):
 
     def save_parameters(self, file_prefix):
         """Saves all the parameters in a way they can be retrieved later"""
+        # TODO(Lood): Can also add a flag to save parameters in custom directory (not self.working_dir/params/)
         cPickle.dump({name: p.get_value() for name, p in self.params.items()},\
             open(self.working_dir+"/params/"+file_prefix + "_params.pkl", "wb"))
         cPickle.dump({name: m.get_value() for name, m in self.m.items()}, \
@@ -723,13 +723,30 @@ def save_parameters(self, file_prefix):
         cPickle.dump({name: v.get_value() for name, v in self.v.items()}, \
             open(self.working_dir+"/params/"+file_prefix +"_v.pkl", "wb"))
 
-    def load_parameters(self, file_prefix=""):
+    def load_parameters(self, file_prefix="", seed=None, override_params_dir=None):
         """Load the variables in a shared variable safe way"""
-        p_list = cPickle.load(open(self.working_dir+"/params/"+file_prefix \
+        if override_params_dir is not None:
+            params_dir = override_params_dir
+        else:
+            params_dir = os.path.join(self.working_dir, 'params')
+        assert os.path.isdir(params_dir), "{} is not a directory".format(params_dir)
+        # Check _params.pkl exists, then assume the others exist too
+        file_matches = [file for file in os.listdir(params_dir) if file.startswith(file_prefix)
+                        and file.endswith("_params.pkl") and "epoch" not in file]  # Ignoring the intermediate "_theta__params" checkpoints
+
+        if seed is not None:
+            print("Searching for seed {}, files matched so far: {}".format(seed, file_matches))
+            file_matches = [file for file in file_matches if "seed-" + str(seed) in file]
+
+        assert len(file_matches) == 1, "Could not find unique params file for prefix {} in {}, found {} files".format(file_prefix, params_dir, file_matches)
+        print("Matched file: {}".format(file_matches[0]))
+        full_prefix = os.path.join(params_dir, file_matches[0].replace("_params.pkl", ""))
+
+        p_list = cPickle.load(open(full_prefix \
             + "_params.pkl", "rb"))
-        m_list = cPickle.load(open(self.working_dir+"/params/"+file_prefix \
+        m_list = cPickle.load(open(full_prefix \
             + "_m.pkl", "rb"))
-        v_list = cPickle.load(open(self.working_dir+"/params/"+file_prefix \
+        v_list = cPickle.load(open(full_prefix \
             + "_v.pkl", "rb"))
 
         for name in p_list.keys():
diff --git a/DeepSequence/train.py b/DeepSequence/train.py
index 7779a35..0ec4074 100644
--- a/DeepSequence/train.py
+++ b/DeepSequence/train.py
@@ -44,6 +44,7 @@ def train(data,
 
     batch_order = np.arange(data.x_train.shape[0])
     seq_sample_probs = data.weights / np.sum(data.weights)
+    assert len(seq_sample_probs) == data.x_train.shape[0], "Length of sequence weights {} does not match number of sequences {}".format(len(seq_sample_probs), data.x_train.shape[0])
 
     update_num = 0
 
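
Usage note (not part of the patch): a minimal sketch of how the extended load_parameters signature above might be called, assuming an already-constructed DeepSequence model object (named "model" here) and checkpoint filenames carrying a "seed-<n>" tag; the file prefix and directory paths below are illustrative placeholders, not values taken from this diff.

    # Default: search <working_dir>/params/ for exactly one "<prefix>*_params.pkl",
    # skipping the intermediate checkpoints whose filenames contain "epoch".
    model.load_parameters(file_prefix="PABP_YEAST")

    # When several replicate runs share the same prefix, select one by its seed tag.
    model.load_parameters(file_prefix="PABP_YEAST", seed=2)

    # Load checkpoints from a directory other than <working_dir>/params/.
    model.load_parameters(file_prefix="PABP_YEAST", override_params_dir="/path/to/shared/params")

Because load_parameters now asserts that exactly one file matches, an ambiguous prefix fails loudly rather than silently loading the wrong checkpoint.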