diff --git a/vde/vde.py b/vde/vde.py
index a36ebbb..11490f9 100644
--- a/vde/vde.py
+++ b/vde/vde.py
@@ -7,6 +7,8 @@
 import numpy as np
 
+from collections import defaultdict
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
 
@@ -111,8 +113,38 @@ def forward(self, x):
         out = self.output_layer(out)
         return out
 
 
+class Model(nn.Module):
+    """Full VDE model"""
+    def __init__(self, input_size, encoder_size=1,
+                 hidden_layer_depth=3, hidden_size=2048, scale=1E-3,
+                 dropout_rate=0., activation='Swish', cuda=False):
-class VDE(BaseEstimator, nn.Module):
+        super(Model, self).__init__()
+
+        self.encoder = Encoder(input_size, output_size=encoder_size,
+                               hidden_layer_depth=hidden_layer_depth,
+                               hidden_size=hidden_size, activation=activation,
+                               dropout_rate=dropout_rate)
+        self.lmbd = Lambda(encoder_size, encoder_size, scale=scale)
+        self.decoder = Decoder(input_size, input_size=encoder_size,
+                               hidden_layer_depth=hidden_layer_depth,
+                               hidden_size=hidden_size, activation=activation,
+                               dropout_rate=dropout_rate)
+
+        self.dtype = torch.FloatTensor
+        if cuda:
+            self.cuda()
+            self.dtype = torch.cuda.FloatTensor
+
+        self.apply(initialize_weights)
+
+    def forward(self, x):
+        u = self.encoder(x)
+        u_p = self.lmbd(u)
+        out = self.decoder(u_p)
+        return out, u
+
+class VDE(BaseEstimator):
     """Variational Dynamical Encoder (VDE)
 
     Non-linear dimensionality reduction using a time-lagged variational
@@ -159,38 +191,39 @@ def __init__(self, input_size, lag_time=1, encoder_size=1, batch_size=100,
                  verbose=True):
         super(VDE, self).__init__()
 
-        self.encoder = Encoder(input_size, output_size=encoder_size,
-                               hidden_layer_depth=hidden_layer_depth,
-                               hidden_size=hidden_size, activation=activation,
-                               dropout_rate=dropout_rate)
-        self.lmbd = Lambda(encoder_size, encoder_size, scale=scale)
-        self.decoder = Decoder(input_size, input_size=encoder_size,
-                               hidden_layer_depth=hidden_layer_depth,
-                               hidden_size=hidden_size, activation=activation,
-                               dropout_rate=dropout_rate)
-
-        self.verbose = verbose
         self.input_size = input_size
+        self.lag_time = lag_time
         self.encoder_size = encoder_size
-        self.n_epochs = n_epochs
         self.batch_size = batch_size
-        self.lag_time = lag_time
+        self.hidden_layer_depth = hidden_layer_depth
+        self.hidden_size = hidden_size
+        self.scale = scale
+        self.dropout_rate = dropout_rate
+        self.learning_rate = learning_rate
+        self.n_epochs = n_epochs
+        self.optimizer = optimizer
+        self.activation = activation
+        self.loss = loss
         self.sliding_window = sliding_window
        self.autocorr = autocorr
-
-        self.use_cuda = cuda
-        self.dtype = torch.FloatTensor
-        if self.use_cuda:
-            self.cuda()
-            self.dtype = torch.cuda.FloatTensor
-        self.apply(initialize_weights)
-
-        self.learning_rate = learning_rate
+        self.cuda = cuda
+        self.verbose = verbose
+
+        self._init_model()
+
+    def _init_model(self):
+        # (Re)build the underlying torch Model and its optimizer from the
+        # hyperparameters stored on the estimator.
+        self._model = Model(self.input_size, encoder_size=self.encoder_size,
+                            hidden_layer_depth=self.hidden_layer_depth,
+                            hidden_size=self.hidden_size, scale=self.scale,
+                            dropout_rate=self.dropout_rate,
+                            activation=self.activation, cuda=self.cuda)
+
-        if optimizer == 'Adam':
-            self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
+        if self.optimizer == 'Adam':
+            self._optimizer = optim.Adam(self._model.parameters(),
+                                         lr=self.learning_rate)
-        elif optimizer == 'SGD':
-            self.optimizer = optim.SGD(self.parameters(), lr=learning_rate)
+        elif self.optimizer == 'SGD':
+            self._optimizer = optim.SGD(self._model.parameters(),
+                                        lr=self.learning_rate)
         else:
             raise ValueError('Not a recognized optimizer')
@@ -202,6 +235,42 @@ def __init__(self, input_size, lag_time=1, encoder_size=1, batch_size=100,
             raise ValueError('Not a recognized loss function')
 
         self.is_fitted = False
+
+        return self
+
+    def set_params(self, **params):
+        """Set the parameters of this estimator.
+
+        The method works on simple estimators as well as on nested objects
+        (such as pipelines). The latter have parameters of the form
+        ``<component>__<parameter>`` so that it's possible to update each
+        component of a nested object.
+
+        Returns
+        -------
+        self
+        """
+        if not params:
+            # Simple optimization to gain speed (inspect is slow)
+            return self
+        valid_params = self.get_params(deep=True)
+
+        nested_params = defaultdict(dict)  # grouped by prefix
+        for key, value in params.items():
+            key, delim, sub_key = key.partition('__')
+            if key not in valid_params:
+                raise ValueError('Invalid parameter %s for estimator %s. '
+                                 'Check the list of available parameters '
+                                 'with `estimator.get_params().keys()`.' %
+                                 (key, self))
+
+            if delim:
+                nested_params[key][sub_key] = value
+            else:
+                setattr(self, key, value)
+
+        for key, sub_params in nested_params.items():
+            valid_params[key].set_params(**sub_params)
+
+        return self._init_model()
 
     def __repr__(self):
         return """VDE(input_size={input_size}, encoder_size={encoder_size}, n_epochs={n_epochs},
@@ -214,17 +283,11 @@ def __repr__(self):
             lag_time=self.lag_time,
             sliding_window=self.sliding_window,
             autocorr=self.autocorr,
-            cuda=self.use_cuda
+            cuda=self.cuda
         )
 
-    def forward(self, x):
-        u = self.encoder(x)
-        u_p = self.lmbd(u)
-        out = self.decoder(u_p)
-        return out, u
-
     def _rec(self, x_decoded_mean, x, loss_fn):
-        z_mean, z_log_var = self.lmbd.mu, self.lmbd.log_v
+        z_mean, z_log_var = self._model.lmbd.mu, self._model.lmbd.log_v
 
         loss = loss_fn(x_decoded_mean, x)
 
         kl_loss = -0.5 * torch.mean(1. + z_log_var - z_mean ** 2. -
                                     torch.exp(z_log_var))
@@ -243,26 +306,26 @@ def _corr(self, x, y):
         return r_val
 
     def compute_loss(self, X):
-        x = Variable(X[:, :, 0].type(self.dtype), requires_grad=True)
-        y = Variable(X[:, :, 1].type(self.dtype), requires_grad=True)
+        x = Variable(X[:, :, 0].type(self._model.dtype), requires_grad=True)
+        y = Variable(X[:, :, 1].type(self._model.dtype), requires_grad=True)
 
-        o, u = self(x)
+        o, u = self._model(x)
 
         autocorr_loss = 0.
         rec_loss = self._rec(o, y.detach(), self.loss_fn)
 
         loss = rec_loss
         if self.autocorr:
-            v = self.encoder(y)
+            v = self._model.encoder(y)
             autocorr_loss = (1 - self._corr(u, v))
             loss = rec_loss + autocorr_loss
 
-        self.optimizer.zero_grad()
+        self._optimizer.zero_grad()
         loss.backward()
 
         return loss, rec_loss, autocorr_loss, x
 
     def _train(self, data, print_every=100):
-        self.train()
+        self._model.train()
 
         for t, X in enumerate(data):
             loss, rec_loss, autocorr_loss, _ = self.compute_loss(X)
@@ -274,7 +337,7 @@ def _train(self, data, print_every=100):
                     print('rec_loss = %.4f, '
                           'autocorr_loss = %.4f' % (rec_loss.data[0],
                                                     autocorr_loss.data[0]))
-            self.optimizer.step()
+            self._optimizer.step()
 
     def _create_dataset(self, data):
         slide = self.lag_time if self.sliding_window else 1
@@ -302,29 +365,29 @@ def fit(self, X):
     def _batch_transform(self, x):
         y = []
         for arr in np.array_split(x, x.shape[0] // self.batch_size):
-            out = self.encoder(Variable(
-                torch.from_numpy(arr).type(self.dtype))
+            out = self._model.encoder(Variable(
+                torch.from_numpy(arr).type(self._model.dtype))
             ).cpu().data.numpy()
 
             y.append(out.reshape(-1, self.encoder_size))
 
         return np.concatenate(y, axis=0)
 
     def propagate(self, X, scale=None):
-        self.eval()
+        self._model.eval()
         if self.is_fitted:
-            out = self.encoder(Variable(
+            out = self._model.encoder(Variable(
                 torch.from_numpy(X.reshape(-1, self.input_size)
-                                 ).type(self.dtype)))
+                                 ).type(self._model.dtype)))
             if scale is not None:
-                old_scale = self.lmbd.scale
-                self.lmbd.scale = scale
-                out = self.lmbd(out)
-                self.lmbd.scale = old_scale
-            return self.decoder(out).cpu().data.numpy()
+                old_scale = self._model.lmbd.scale
+                self._model.lmbd.scale = scale
+                out = self._model.lmbd(out)
+                self._model.lmbd.scale = old_scale
+            return self._model.decoder(out).cpu().data.numpy()
 
         raise RuntimeError('Model needs to be fit.')
 
     def transform(self, X):
-        self.eval()
+        self._model.eval()
         if self.is_fitted:
             out = [self._batch_transform(x) for x in X]
             return out
@@ -335,10 +398,10 @@ def fit_transform(self, X):
         return self.transform(X)
 
     def compute_saliency(self, data, add_n_lag_zeros=True):
-        self.eval()
+        self._model.eval()
         saliency_list = []
-        scale = self.lmbd.scale
-        self.lmbd.scale = 0.
+        scale = self._model.lmbd.scale
+        self._model.lmbd.scale = 0.
 
         for t, X in enumerate(data):
             _, _, _, x0 = self.compute_loss(X)
@@ -346,7 +409,7 @@ def compute_saliency(self, data, add_n_lag_zeros=True):
             saliency = saliency.squeeze()
             saliency_list.append(saliency)
 
-        self.lmbd.scale = scale
+        self._model.lmbd.scale = scale
 
         if not add_n_lag_zeros:
             return np.vstack([i.numpy() for i in saliency_list])
         else:
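
For reviewers, a minimal usage sketch of the refactored estimator follows. It assumes the rest of `vde/vde.py` (`Encoder`, `Lambda`, `Decoder`, `initialize_weights`) is unchanged, that the package still exposes `VDE` at the top level, and that `fit`/`transform` accept a list of per-trajectory arrays as in the existing code; the data and hyperparameter values are purely illustrative.

```python
import numpy as np

from vde import VDE

# Illustrative input: a list of trajectories, each of shape (n_frames, n_features).
trajs = [np.random.randn(1000, 30).astype(np.float32) for _ in range(5)]

# Every constructor argument is now stored as a plain attribute; the torch
# Model and its optimizer are built by _init_model().
est = VDE(input_size=30, lag_time=10, encoder_size=2, n_epochs=5, cuda=False)

# Because VDE is now a plain sklearn BaseEstimator, parameters can be changed
# after construction; set_params() rebuilds the internal Model so the new
# settings take effect.
est.set_params(encoder_size=1, learning_rate=5e-4)

est.fit(trajs)                  # train the time-lagged variational autoencoder
latent = est.transform(trajs)   # per-trajectory projections onto the latent coordinate
```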