diff --git a/install_dev.sh b/install_dev.sh new file mode 100755 index 0000000..5b83625 --- /dev/null +++ b/install_dev.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Copyright (c) 2019, Lawrence Livermore National Security, LLC and +# GlaxoSmithKline LLC. All rights reserved. LLNL-CODE-784597 +# +# OFFICIAL USE ONLY - EXPORT CONTROLLED INFORMATION +# +# PROTECTED CRADA INFORMATION - 7.31.19 - Authorized by: Jim Brase - +# CRADA TC02264 +# +# This work was produced at the Lawrence Livermore National Laboratory (LLNL) +# under contract no. DE-AC52-07NA27344 (Contract 44) between the U.S. Department +# of Energy (DOE) and Lawrence Livermore National Security, LLC (LLNS) for the +# operation of LLNL. See license for disclaimers, notice of U.S. Government +# Rights and license terms and conditions. + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd $DIR + +pip install -e . --user + diff --git a/moses/char_rnn/model.py b/moses/char_rnn/model.py index abd29a2..f992128 100644 --- a/moses/char_rnn/model.py +++ b/moses/char_rnn/model.py @@ -47,27 +47,28 @@ def tensor2string(self, tensor): return string - def load_lbann_weights(self,weights_dir,epoch_count=None): - - if epoch_count is None: - epoch_count = '*' - + def load_lbann_weights(self, weights_prefix): + with torch.no_grad(): #Load Embedding weights - emb_weights = np.loadtxt(glob.glob(weights_dir+"*.epoch."+str(epoch_count)+"-emb_matrix-Weights.txt")[0]) + + emb_weights = np.loadtxt(weights_prefix+"-emb_matrix-Weights.txt") self.embedding_layer.weight.data.copy_(torch.from_numpy(np.transpose(emb_weights))) #Load LSTM weights/biases param_idx = ['_ih_matrix','_hh_matrix','_ih_bias', '_hh_bias'] for l in range(self.num_layers): for idx, val in enumerate(param_idx): - param_tensor = np.loadtxt(glob.glob(weights_dir+"*.epoch."+str(epoch_count)+"*-gru"+str(l+1)+val+"-Weights.txt")[0]) + + param_tensor = np.loadtxt(weights_prefix+"-gru"+str(l+1)+val+"-Weights.txt") 
self.lstm_layer.all_weights[l][idx].copy_(torch.from_numpy(param_tensor)) #Load Linear layer weights/biases - linear_layer_weights = np.loadtxt(glob.glob(weights_dir+"*.epoch."+str(epoch_count)+"*-fcmodule"+str(2*self.num_layers+1)+"_matrix-Weights.txt")[0]) + + linear_layer_weights = np.loadtxt(weights_prefix+"-fcmodule"+str(2*self.num_layers+1)+"_matrix-Weights.txt") self.linear_layer.weight.data.copy_(torch.from_numpy(linear_layer_weights)) - linear_layer_bias = np.loadtxt(glob.glob(weights_dir+"*.epoch."+str(epoch_count)+"*-fcmodule"+str(2*self.num_layers+1)+"_bias-Weights.txt")[0]) + + linear_layer_bias = np.loadtxt(weights_prefix+"-fcmodule"+str(2*self.num_layers+1)+"_bias-Weights.txt") self.linear_layer.bias.data.copy_(torch.from_numpy(linear_layer_bias)) print("DONE loading LBANN weights ") diff --git a/moses/metrics/metrics.py b/moses/metrics/metrics.py index 088fd41..1abe7e0 100644 --- a/moses/metrics/metrics.py +++ b/moses/metrics/metrics.py @@ -8,7 +8,7 @@ get_mol, canonic_smiles, mol_passes_filters, \ logP, QED, SA, NP, weight from moses.utils import mapper -from .utils_fcd import get_predictions, calculate_frechet_distance +#from .utils_fcd import get_predictions, calculate_frechet_distance from multiprocessing import Pool from moses.utils import disable_rdkit_log, enable_rdkit_log diff --git a/moses/metrics/utils.py b/moses/metrics/utils.py index f332d7a..f5898d1 100644 --- a/moses/metrics/utils.py +++ b/moses/metrics/utils.py @@ -19,9 +19,9 @@ _base_dir = os.path.split(__file__)[0] _mcf = pd.read_csv(os.path.join(_base_dir, 'mcf.csv')) _pains = pd.read_csv(os.path.join(_base_dir, 'wehi_pains.csv'), - names=['smarts', 'names']) + names=['smarts', 'names'])[['names', 'smarts']] _filters = [Chem.MolFromSmarts(x) for x in - _mcf.append(_pains, sort=True)['smarts'].values] + _mcf.append(_pains)['smarts'].values] def get_mol(smiles_or_mol): diff --git a/moses/metrics/utils_fcd.py b/moses/metrics/utils_fcd.py index 4cded07..181de78 100644 --- 
a/moses/metrics/utils_fcd.py +++ b/moses/metrics/utils_fcd.py @@ -11,6 +11,7 @@ samples respectivly. ''' +''' import os import keras.backend as K import numpy as np @@ -200,3 +201,5 @@ def get_predictions(smiles, gpu=-1, batch_size=128): else: os.environ.pop("CUDA_DEVICE_ORDER") return smiles_act + +''' diff --git a/moses/models_storage.py b/moses/models_storage.py index 491f3d3..0661986 100644 --- a/moses/models_storage.py +++ b/moses/models_storage.py @@ -1,19 +1,19 @@ from moses.vae import VAE, VAETrainer, vae_parser -from moses.organ import ORGAN, ORGANTrainer, organ_parser -from moses.aae import AAE, AAETrainer, aae_parser -from moses.char_rnn import CharRNN, CharRNNTrainer, char_rnn_parser -from moses.junction_tree import JTNNVAE, JTreeTrainer, junction_tree_parser +#from moses.organ import ORGAN, ORGANTrainer, organ_parser +#from moses.aae import AAE, AAETrainer, aae_parser +#from moses.char_rnn import CharRNN, CharRNNTrainer, char_rnn_parser +#from moses.junction_tree import JTNNVAE, JTreeTrainer, junction_tree_parser class ModelsStorage(): def __init__(self): self._models = {} - self.add_model('aae', AAE, AAETrainer, aae_parser) - self.add_model('char_rnn', CharRNN, CharRNNTrainer, char_rnn_parser) - self.add_model('junction_tree', JTNNVAE, JTreeTrainer, junction_tree_parser) + #self.add_model('aae', AAE, AAETrainer, aae_parser) + #self.add_model('char_rnn', CharRNN, CharRNNTrainer, char_rnn_parser) + #self.add_model('junction_tree', JTNNVAE, JTreeTrainer, junction_tree_parser) self.add_model('vae', VAE, VAETrainer, vae_parser) - self.add_model('organ', ORGAN, ORGANTrainer, organ_parser) + #self.add_model('organ', ORGAN, ORGANTrainer, organ_parser) def add_model(self, name, class_, trainer_, parser_): self._models[name] = { 'class' : class_, diff --git a/moses/script_utils.py b/moses/script_utils.py index 7fd9fdf..5dbecec 100644 --- a/moses/script_utils.py +++ b/moses/script_utils.py @@ -106,10 +106,22 @@ def add_sample_args(parser): return parser -def 
read_smiles_csv(path): - return pd.read_csv(path, - usecols=['SMILES'], - squeeze=True).astype(str).tolist() +def read_smiles_csv(path, smiles_col='SMILES'): + + # need to check if the specified path even has a SMILES field, if not, just make one + df_first = pd.read_csv(path, nrows=1) + if smiles_col in df_first.columns: + + return pd.read_csv(path, + usecols=[smiles_col], + squeeze=True).astype(str).tolist() + # if the specified smiles_col is not in the columns of the csv file and there are multiple columns, then it is ambiguously defined so error out + elif len(df_first.columns) > 1: + raise RuntimeError(f"the provided value for smiles_col, {smiles_col}, is not contained in the header for this csv file, further there are multiple columns to read from, smiles_col is ambiguous.") + # we'll now assume that if the csv has a single column, then that column must be smiles...this might not be true but that's the user responsibility + else: + print(f"{smiles_col} not contained in the csv file, assuming the only column contains the smiles data") + return pd.read_csv(path, header=None, squeeze=True).astype(str).tolist() def set_seed(seed): torch.manual_seed(seed) diff --git a/moses/vae/model.py b/moses/vae/model.py index d2e9440..43fb352 100644 --- a/moses/vae/model.py +++ b/moses/vae/model.py @@ -9,16 +9,15 @@ class VAE(nn.Module): def __init__(self, vocab, config): super().__init__() - + print("loading VAE") self.vocabulary = vocab # Special symbols for ss in ('bos', 'eos', 'unk', 'pad'): setattr(self, ss, getattr(vocab, ss)) # Word embeddings layer - n_vocab, d_emb = len(vocab), vocab.vectors.size(1) + n_vocab, d_emb = len(vocab), len(vocab) self.x_emb = nn.Embedding(n_vocab, d_emb, self.pad) - self.x_emb.weight.data.copy_(vocab.vectors) if config.freeze_embeddings: self.x_emb.weight.requires_grad = False @@ -161,6 +160,7 @@ def forward_decoder(self, x, z): y = self.decoder_fc(output) return y + def compute_loss(x,y): recon_loss = F.cross_entropy( @@ -209,7 +209,7 @@ 
def sample_z_prior(self, n_batch): return torch.randn(n_batch, self.q_mu.out_features, device=self.x_emb.weight.device) - def sample(self, n_batch, max_len=100, z=None, temp=1.0): + def sample(self, n_batch, max_len=100, z=None, temp=1.0, return_latent=False): """Generating n_batch samples in eval mode (`z` could be not on same device) @@ -217,6 +217,7 @@ def sample(self, n_batch, max_len=100, z=None, temp=1.0): :param max_len: max len of samples :param z: (n_batch, d_z) of floats, latent vector z or None :param temp: temperature of softmax + :param return_latent: whether to return latent vectors as well as SMILES :return: list of tensors of strings, samples sequence x """ with torch.no_grad(): @@ -232,7 +233,9 @@ def sample(self, n_batch, max_len=100, z=None, temp=1.0): x = torch.tensor([self.pad], device=self.device).repeat(n_batch, max_len) x[:, 0] = self.bos end_pads = torch.tensor([max_len], device=self.device).repeat(n_batch) - eos_mask = torch.zeros(n_batch, dtype=torch.bool, device=self.device) + # The changes in this section are only because the version of pytorch in our standard dev + # environment (1.0) doesn't have the torch.bool datatype. 
+ eos_mask = torch.zeros(n_batch, dtype=torch.uint8, device=self.device) # Generating cycle for i in range(1, max_len): @@ -244,17 +247,23 @@ def sample(self, n_batch, max_len=100, z=None, temp=1.0): y = F.softmax(y / temp, dim=-1) w = torch.multinomial(y, 1)[:, 0] - x[~eos_mask, i] = w[~eos_mask] - i_eos_mask = ~eos_mask & (w == self.eos) + x[eos_mask==0, i] = w[eos_mask==0] + i_eos_mask = (eos_mask==0) & (w == self.eos) end_pads[i_eos_mask] = i + 1 - eos_mask = eos_mask | i_eos_mask + eos_mask = (eos_mask==1) | i_eos_mask + + # End of changes for pytorch 1.0 support # Converting `x` to list of tensors new_x = [] for i in range(x.size(0)): new_x.append(x[i, :end_pads[i]]) - - return [self.tensor2string(i_x) for i_x in new_x] + + + if return_latent: + return [self.tensor2string(i_x) for i_x in new_x], z_0.cpu().numpy() + else: + return [self.tensor2string(i_x) for i_x in new_x] def load_lbann_weights(self,weights_dir,epoch_count=-1): print("Loading LBANN Weights ") @@ -299,3 +308,36 @@ def load_lbann_weights(self,weights_dir,epoch_count=-1): self.decoder_fc.bias.data.copy_(torch.from_numpy(decoder_fc_bias)) print("DONE loading LBANN weights ") + + + + def encode_smiles(self, smiles): + """ + Encode the given SMILES strings and return the actual latent vectors as a list + of numpy arrays + """ + from tqdm import tqdm + tensor_list = [] + for smile in tqdm(smiles, desc="converting smiles to tensors"): + tensor_list.append(self.string2tensor(smile).view(1,-1)) + + + latent_list = [] + for i, input_batch in enumerate(tensor_list): + input_batch = tuple(data.to(self.device) for data in input_batch) + with torch.no_grad(): + z, _ = self.forward_encoder(input_batch) + latent_list.append(np.squeeze(np.array(z.cpu()))) + + return latent_list, smiles + + + + def decode_smiles(self, latent_list): + """ + Decode the given list of latent vectors + """ + lat_arr = np.stack(latent_list) + lat_tens = torch.from_numpy(lat_arr) + return self.sample(n_batch=len(latent_list), 
max_len=100, z=lat_tens, return_latent=True) + diff --git a/scripts/compute_latent_sample_exp.py b/scripts/compute_latent_sample_exp.py new file mode 100644 index 0000000..6fb301d --- /dev/null +++ b/scripts/compute_latent_sample_exp.py @@ -0,0 +1,115 @@ +import os +import torch +from tqdm import tqdm +import argparse +import multiprocessing as mp +import pandas as pd +from moses.models_storage import ModelsStorage +from moses.metrics.utils import average_agg_tanimoto, fingerprints, fingerprint +from rdkit import DataStructs, Chem +from scipy.spatial.distance import jaccard +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument("--model", required=True) +parser.add_argument("--lbann-weights-dir", required=True) +parser.add_argument("--lbann-load-epoch", type=int, required=True) +parser.add_argument("--lbann-load-step", type=int, required=True) +parser.add_argument( + "--vocab-path", type=str, default="", help="path to experiment vocabulary" +) +parser.add_argument("--num-layers", type=int) +parser.add_argument("--dropout", type=float) +parser.add_argument("--weight-prefix") +parser.add_argument("--n-samples", type=int, default=100) +parser.add_argument("--max-len", type=int, default=100) +parser.add_argument("--n-batch", type=int, default=10) +parser.add_argument("--gen-save", required=True) + +parser.add_argument("--test-path", required=True) +parser.add_argument("--test-scaffolds-path") +parser.add_argument("--ptest-path") +parser.add_argument("--ptest-scaffolds-path") + + +parser.add_argument("--ks", type=int, nargs="+", help="list with values for unique@k. 
Will calculate number of unique molecules in the first k molecules.") +parser.add_argument("--n-jobs", type=int, default=mp.cpu_count()-1) +parser.add_argument("--gpu", type=int, help=" index of GPU for FCD metric and internal diversity, -1 means use CPU") +parser.add_argument("--batch-size", type=int, help="batch size for FCD metric") +parser.add_argument("--hidden", type=int) +parser.add_argument("--metrics", help="output path to store metrics") + +parser.add_argument("--model-config", help="path to model configuration dict") + +###################################### +# These are things specific to the VAE +###################################### + +#parser.add_argument("--freeze-embeddings", action="store_true") # this turns off grad accumulation for embedding layer (see https://github.com/samadejacobs/moses/blob/master/moses/vae/model.py#L22) +#parser.add_argument("--q-cell", default="gru") + + +parser.add_argument("--seed-molecules", help="points to a file with molecules to use as the reference points in the experiment", required=True) +parser.add_argument("--k-neighbor-samples", help="number of neighbors to draw from the gaussian ball", type=int, required=True) +parser.add_argument("--scale-factor", help="scale factor (std) for gaussian", type=float, required=True) +parser.add_argument("--output", help="path to save output results", required=True) +model_config = parser.parse_args() + +moses_config_dict = torch.load(model_config.model_config) + + +def load_model(): + MODELS = ModelsStorage() + model_vocab = torch.load(model_config.vocab_path) + model = MODELS.get_model_class(model_config.model)(model_vocab, moses_config_dict) + # load the model + assert os.path.exists(model_config.lbann_weights_dir) is not None + + weights_prefix = f"{model_config.lbann_weights_dir}/{model_config.weight_prefix}" + model.load_lbann_weights(model_config.lbann_weights_dir, epoch_count=model_config.lbann_load_epoch) + + model.cuda() + model.eval() + + return model + + +def 
sample_noise_add_to_vec(latent_vec, scale_factor=model_config.scale_factor): + noise = torch.normal(mean=0, std=torch.ones(latent_vec.shape)*scale_factor).numpy() + + return latent_vec + noise + + +def main(k=model_config.k_neighbor_samples): + model = load_model() + + + input_smiles_list = pd.read_csv(model_config.seed_molecules, header=None)[0].to_list() + + + reference_latent_vec_list, reference_smiles_list = model.encode_smiles(input_smiles_list) + + result_list = [] + + + for reference_latent_vec, reference_smiles in tqdm(zip(reference_latent_vec_list, reference_smiles_list), desc="sampling neighbors for reference vec and decoding", total=len(reference_latent_vec_list)): + + neighbor_smiles_list = [model.decode_smiles(sample_noise_add_to_vec(reference_latent_vec).reshape(1,-1))[0][0] for i in range(k)] + + neighbor_fps = [fingerprint(neighbor_smiles, fp_type='morgan') for neighbor_smiles in neighbor_smiles_list] #here is a bug in fingerprints function that references first_fp before assignment... 
+ + reference_fp = fingerprint(reference_smiles, fp_type='morgan') + + neighbor_tani_list = [jaccard(reference_fp, neighbor_fp) for neighbor_fp in neighbor_fps] + neighbor_valid_list = [x for x in [Chem.MolFromSmiles(smiles) for smiles in neighbor_smiles_list] if x is not None] + + + + result_list.append({"reference_smiles": reference_smiles, "mean_tani_sim": np.mean(neighbor_tani_list), "min_tani_sim": np.min(neighbor_tani_list), "max_tani_sim": np.max(neighbor_tani_list), "valid_rate": len(neighbor_valid_list)/k }) + + pd.DataFrame(result_list).to_csv(model_config.output) + + +if __name__ == "__main__": + main() + diff --git a/scripts/compute_latent_sample_exp_1BEnamine.sh b/scripts/compute_latent_sample_exp_1BEnamine.sh new file mode 100755 index 0000000..6904161 --- /dev/null +++ b/scripts/compute_latent_sample_exp_1BEnamine.sh @@ -0,0 +1,18 @@ +#!/usr/bin/bash + +#python compute_latent_sample_exp.py --model vae --lbann-weights-dir /usr/workspace/atom/lbann/1BEnamine/weights/ --lbann-load-epoch 76 --lbann-load-step 284772 --gen-save foo_test --test-path /p/lustre1/jones289/lbann/data/newEnamineFrom2020q1-2/newEnamineFrom2020q1-2_test100kSMILES.csv --vocab-path /usr/workspace/atom/lbann/1BEnamine/newEnamineFrom2020q1-2.pt --model-config zinc10Kckpt/vae_config.pt --weight-prefix sgd.training --seed-molecules newEnamineFrom2020q1-2_test100kSMILES_subsample1k.csv --k-neighbor-samples 1000 --scale-factor 0.5 --output 1BEnamine_scale_factor_0.5_results.csv + + +#python compute_latent_sample_exp.py --model vae --lbann-weights-dir /usr/workspace/atom/lbann/1BEnamine/weights/ --lbann-load-epoch 76 --lbann-load-step 284772 --gen-save foo_test --test-path /p/lustre1/jones289/lbann/data/newEnamineFrom2020q1-2/newEnamineFrom2020q1-2_test100kSMILES.csv --vocab-path /usr/workspace/atom/lbann/1BEnamine/newEnamineFrom2020q1-2.pt --model-config zinc10Kckpt/vae_config.pt --weight-prefix sgd.training --seed-molecules newEnamineFrom2020q1-2_test100kSMILES_subsample1k.csv 
--k-neighbor-samples 1000 --scale-factor 1.0 --output 1BEnamine_scale_factor_1.0_results.csv + + +#python compute_latent_sample_exp.py --model vae --lbann-weights-dir /usr/workspace/atom/lbann/1BEnamine/weights/ --lbann-load-epoch 76 --lbann-load-step 284772 --gen-save foo_test --test-path /p/lustre1/jones289/lbann/data/newEnamineFrom2020q1-2/newEnamineFrom2020q1-2_test100kSMILES.csv --vocab-path /usr/workspace/atom/lbann/1BEnamine/newEnamineFrom2020q1-2.pt --model-config zinc10Kckpt/vae_config.pt --weight-prefix sgd.training --seed-molecules newEnamineFrom2020q1-2_test100kSMILES_subsample1k.csv --k-neighbor-samples 1000 --scale-factor 1.5 --output 1BEnamine_scale_factor_1.5_results.csv + +python compute_latent_sample_exp.py --model vae --lbann-weights-dir /usr/workspace/atom/lbann/1BEnamine/weights/ --lbann-load-epoch 76 --lbann-load-step 284772 --gen-save foo_test --test-path /p/lustre1/jones289/lbann/data/newEnamineFrom2020q1-2/newEnamineFrom2020q1-2_test100kSMILES.csv --vocab-path /usr/workspace/atom/lbann/1BEnamine/newEnamineFrom2020q1-2.pt --model-config zinc10Kckpt/vae_config.pt --weight-prefix sgd.training --seed-molecules newEnamineFrom2020q1-2_test100kSMILES_subsample1k.csv --k-neighbor-samples 1000 --scale-factor 2.0 --output 1BEnamine_scale_factor_2.0_results.csv + +python compute_latent_sample_exp.py --model vae --lbann-weights-dir /usr/workspace/atom/lbann/1BEnamine/weights/ --lbann-load-epoch 76 --lbann-load-step 284772 --gen-save foo_test --test-path /p/lustre1/jones289/lbann/data/newEnamineFrom2020q1-2/newEnamineFrom2020q1-2_test100kSMILES.csv --vocab-path /usr/workspace/atom/lbann/1BEnamine/newEnamineFrom2020q1-2.pt --model-config zinc10Kckpt/vae_config.pt --weight-prefix sgd.training --seed-molecules newEnamineFrom2020q1-2_test100kSMILES_subsample1k.csv --k-neighbor-samples 1000 --scale-factor 2.5 --output 1BEnamine_scale_factor_2.5_results.csv + +python compute_latent_sample_exp.py --model vae --lbann-weights-dir 
/usr/workspace/atom/lbann/1BEnamine/weights/ --lbann-load-epoch 76 --lbann-load-step 284772 --gen-save foo_test --test-path /p/lustre1/jones289/lbann/data/newEnamineFrom2020q1-2/newEnamineFrom2020q1-2_test100kSMILES.csv --vocab-path /usr/workspace/atom/lbann/1BEnamine/newEnamineFrom2020q1-2.pt --model-config zinc10Kckpt/vae_config.pt --weight-prefix sgd.training --seed-molecules newEnamineFrom2020q1-2_test100kSMILES_subsample1k.csv --k-neighbor-samples 1000 --scale-factor 3.0 --output 1BEnamine_scale_factor_3.0_results.csv + + + diff --git a/scripts/compute_vocab_main.py b/scripts/compute_vocab_main.py index 07e7044..60685d0 100644 --- a/scripts/compute_vocab_main.py +++ b/scripts/compute_vocab_main.py @@ -1,43 +1,53 @@ from char_vocab_utils import compute_vocab + def main(): import os - import torch + import torch from sklearn.model_selection import train_test_split from argparse import ArgumentParser - + parser = ArgumentParser() parser.add_argument("--smiles-path", help="path to csv of smiles strings") - parser.add_argument("--smiles-col", help="column name that contains smiles strings", default=None) - parser.add_argument("--smiles-sep", help="delimiter used to seperate smiles strings, default is set to pandas default for csv", default=",") - parser.add_argument("--n-jobs", type=int, help="number of processes to use for parallel computations") - - parser.add_argument("--output-dir", help="path to output directory to store vocab and numpy arrays") + parser.add_argument( + "--smiles-col", help="column name that contains smiles strings", default=None + ) + parser.add_argument( + "--smiles-sep", + help="delimiter used to seperate smiles strings, default is set to pandas default for csv", + default=",", + ) + parser.add_argument( + "--n-jobs", + type=int, + help="number of processes to use for parallel computations", + ) + + parser.add_argument( + "--output-dir", help="path to output directory to store vocab and numpy arrays" + ) args = parser.parse_args() - # read 
the smiles strings from the csv path + # read the smiles strings from the csv path import modin.pandas as pd - if args.smiles_col is None: - smiles_df = pd.read_csv(args.smiles_path, header=None, sep=args.smiles_sep) + smiles_df = pd.read_csv(args.smiles_path, header=None, sep=args.smiles_sep) smiles_list = smiles_df[0].values - + else: smiles_df = pd.read_csv(args.smiles_path, sep=args.smiles_sep) smiles_list = smiles_df[args.smiles_col].values - - # if output directory does not exist, create it + # if output directory does not exist, create it if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - - # extract the vocab + + # extract the vocab print("extracting the vocab...") vocab = compute_vocab(smiles_list, n_jobs=args.n_jobs) - torch.save(vocab, args.output_dir+"/vocab.pt") - + torch.save(vocab, args.output_dir + "/vocab.pt") + if __name__ == "__main__": main() - diff --git a/scripts/junction_tree/generate_vocab.py b/scripts/junction_tree/generate_vocab.py index e9ef3fb..f3c9f93 100644 --- a/scripts/junction_tree/generate_vocab.py +++ b/scripts/junction_tree/generate_vocab.py @@ -9,14 +9,23 @@ lg = rdkit.RDLogger.logger() lg.setLevel(rdkit.RDLogger.CRITICAL) + def get_parser(): parser = argparse.ArgumentParser() - parser.add_argument('--train_load', type=str, required=True, help='Input data in csv format to train') - parser.add_argument('--vocab_save', type=str, default='vocab.pt', help='Where to save the vocab') - parser.add_argument('--n_jobs', type=int, default=1, help='Number of jobs') + parser.add_argument( + "--train_load", + type=str, + required=True, + help="Input data in csv format to train", + ) + parser.add_argument( + "--vocab_save", type=str, default="vocab.pt", help="Where to save the vocab" + ) + parser.add_argument("--n_jobs", type=int, default=1, help="Number of jobs") return parser + def main(config): data = read_smiles_csv(config.train_load) @@ -25,7 +34,8 @@ def main(config): torch.save(vocab, config.vocab_save) -if __name__ 
== '__main__': + +if __name__ == "__main__": parser = get_parser() config = parser.parse_known_args()[0] main(config) diff --git a/scripts/lbann_sample.py b/scripts/lbann_sample.py new file mode 100644 index 0000000..a894512 --- /dev/null +++ b/scripts/lbann_sample.py @@ -0,0 +1,144 @@ +import os +import argparse +import rdkit +import torch +import pandas as pd +import multiprocessing as mp +from tqdm import tqdm +from run import load_module +from moses.models_storage import ModelsStorage +from moses.script_utils import add_sample_args, set_seed, read_smiles_csv +from moses.metrics.metrics import get_all_metrics + +lg = rdkit.RDLogger.logger() +lg.setLevel(rdkit.RDLogger.CRITICAL) + +parser = argparse.ArgumentParser() +parser.add_argument("--model", required=True) +parser.add_argument("--lbann-weights-dir", required=True) +parser.add_argument("--lbann-load-epoch", type=int, required=True) +parser.add_argument("--lbann-load-step", type=int, required=True) +parser.add_argument( + "--vocab-path", type=str, default="", help="path to experiment vocabulary" +) +parser.add_argument("--num-layers", type=int) +parser.add_argument("--dropout", type=float) +parser.add_argument("--weight-prefix") +parser.add_argument("--n-samples", type=int, default=100) +parser.add_argument("--max-len", type=int, default=100) +parser.add_argument("--n-batch", type=int, default=10) +parser.add_argument("--gen-save", required=True) + +parser.add_argument("--test-path", required=True) +parser.add_argument("--test-scaffolds-path") +parser.add_argument("--ptest-path") +parser.add_argument("--ptest-scaffolds-path") + +parser.add_argument("--ks", type=int, nargs="+", help="list with values for unique@k. 
Will calculate number of unique molecules in the first k molecules.") +parser.add_argument("--n-jobs", type=int, default=mp.cpu_count()-1) +parser.add_argument("--gpu", type=int, help=" index of GPU for FCD metric and internal diversity, -1 means use CPU") +parser.add_argument("--batch-size", type=int, help="batch size for FCD metric") +parser.add_argument("--hidden", type=int) +parser.add_argument("--metrics", help="output path to store metrics") + +model_config = parser.parse_args() + + +def eval_metrics(eval_config, print_metrics=True): + + # need to detect if file has the header or not + test = read_smiles_csv(model_config.test_path) + test_scaffolds = None + ptest = None + ptest_scaffolds = None + if model_config.test_scaffolds_path is not None: + test_scaffolds = read_smiles_csv(model_config.test_scaffolds_path) + if model_config.ptest_path is not None: + if not os.path.exists(model_config.ptest_path): + warnings.warn(f"{model_config.ptest_path} does not exist") + ptest = None + else: + ptest = np.load(model_config.ptest_path)["stats"].item() + if model_config.ptest_scaffolds_path is not None: + if not os.path.exists(model_config.ptest_scaffolds_path): + warnings.warn(f"{model_config.ptest_scaffolds_path} does not exist") + ptest_scaffolds = None + else: + ptest_scaffolds = np.load(model_config.ptest_scaffolds_path)["stats"].item() + gen = read_smiles_csv(model_config.gen_save) + metrics = get_all_metrics( + test, + gen, + k=model_config.ks, + n_jobs=model_config.n_jobs, + gpu=model_config.gpu, + test_scaffolds=test_scaffolds, + ptest=ptest, + ptest_scaffolds=ptest_scaffolds, + ) + + if print_metrics: + print("Metrics:") + for name, value in metrics.items(): + print("\t" + name + " = {}".format(value)) + return metrics + else: + return metrics + + + +def sample(): + MODELS = ModelsStorage() + model_vocab = torch.load(model_config.vocab_path) + model = MODELS.get_model_class(model_config.model)(model_vocab, model_config) + # load the model + assert 
os.path.exists(model_config.lbann_weights_dir) is not None + + weights_prefix = f"{model_config.lbann_weights_dir}/{model_config.weight_prefix}.epoch.{model_config.lbann_load_epoch}.step.{model_config.lbann_load_step}" + model.load_lbann_weights( + weights_prefix, + ) + + + # here we should try to wrap model in a dataparallel layer or something? + model.cuda() + model.eval() + + samples = [] + n = model_config.n_samples + print("Generating Samples") + with tqdm(total=model_config.n_samples, desc="Generating samples") as T: + while n > 0: + current_samples = model.sample( + min(n, model_config.n_batch), model_config.max_len + ) + samples.extend(current_samples) + + n -= len(current_samples) + T.update(len(current_samples)) + + samples = pd.DataFrame(samples, columns=["SMILES"]) + print("Save generated samples to ", model_config.gen_save) + samples.to_csv(model_config.gen_save, index=False) + return samples + +def compute_metrics(): + metrics = [] + model_metrics = eval_metrics(model_config) + #model_metrics.update({"model": model}) + metrics.append(model_metrics) + + table = pd.DataFrame(metrics) + print("Saving computed metrics to ", model_config.metrics) + table.to_csv(model_config.metrics, index=False) + + +def compute_reconstruction(model, test): + pass + + +if __name__ == "__main__": + sample() + compute_metrics() + + diff --git a/scripts/metrics/eval.py b/scripts/metrics/eval.py index dd3b866..ee92e22 100644 --- a/scripts/metrics/eval.py +++ b/scripts/metrics/eval.py @@ -19,61 +19,75 @@ def main(config, print_metrics=True): test_scaffolds = read_smiles_csv(config.test_scaffolds_path) if config.ptest_path is not None: if not os.path.exists(config.ptest_path): - warnings.warn(f'{config.ptest_path} does not exist') + warnings.warn(f"{config.ptest_path} does not exist") ptest = None else: - ptest = np.load(config.ptest_path)['stats'].item() + ptest = np.load(config.ptest_path)["stats"].item() if config.ptest_scaffolds_path is not None: if not 
os.path.exists(config.ptest_scaffolds_path): - warnings.warn(f'{config.ptest_scaffolds_path} does not exist') + warnings.warn(f"{config.ptest_scaffolds_path} does not exist") ptest_scaffolds = None else: - ptest_scaffolds = np.load(config.ptest_scaffolds_path)['stats'].item() + ptest_scaffolds = np.load(config.ptest_scaffolds_path)["stats"].item() gen = read_smiles_csv(config.gen_path) - metrics = get_all_metrics(test, gen, k=config.ks, n_jobs=config.n_jobs, - gpu=config.gpu, test_scaffolds=test_scaffolds, - ptest=ptest, ptest_scaffolds=ptest_scaffolds) - + metrics = get_all_metrics( + test, + gen, + k=config.ks, + n_jobs=config.n_jobs, + gpu=config.gpu, + test_scaffolds=test_scaffolds, + ptest=ptest, + ptest_scaffolds=ptest_scaffolds, + ) + if print_metrics: - print('Metrics:') + print("Metrics:") for name, value in metrics.items(): - print('\t' + name + ' = {}'.format(value)) + print("\t" + name + " = {}".format(value)) else: return metrics def get_parser(): parser = argparse.ArgumentParser() - parser.add_argument('--test_path', - type=str, required=True, - help='Path to test molecules csv') - parser.add_argument('--test_scaffolds_path', - type=str, required=False, - help='Path to scaffold test molecules csv') - parser.add_argument('--ptest_path', - type=str, required=False, - help='Path to precalculated test molecules npz') - parser.add_argument('--ptest_scaffolds_path', - type=str, required=False, - help='Path to precalculated scaffold test molecules npz') + parser.add_argument( + "--test_path", type=str, required=True, help="Path to test molecules csv" + ) + parser.add_argument( + "--test_scaffolds_path", + type=str, + required=False, + help="Path to scaffold test molecules csv", + ) + parser.add_argument( + "--ptest_path", + type=str, + required=False, + help="Path to precalculated test molecules npz", + ) + parser.add_argument( + "--ptest_scaffolds_path", + type=str, + required=False, + help="Path to precalculated scaffold test molecules npz", + ) - 
parser.add_argument('--gen_path', - type=str, required=True, - help='Path to generated molecules csv') - parser.add_argument('--ks', - nargs='+', default=[1000, 10000], - help='Prefixes to calc uniqueness at') - parser.add_argument('--n_jobs', - type=int, default=1, - help='Number of processes to run metrics') - parser.add_argument('--gpu', - type=int, default=-1, - help='GPU index (-1 for cpu)') + parser.add_argument( + "--gen_path", type=str, required=True, help="Path to generated molecules csv" + ) + parser.add_argument( + "--ks", nargs="+", default=[1000, 10000], help="Prefixes to calc uniqueness at" + ) + parser.add_argument( + "--n_jobs", type=int, default=1, help="Number of processes to run metrics" + ) + parser.add_argument("--gpu", type=int, default=-1, help="GPU index (-1 for cpu)") return parser -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() config = parser.parse_known_args()[0] main(config) diff --git a/scripts/metrics/test.py b/scripts/metrics/test.py index b48d42b..0aec909 100755 --- a/scripts/metrics/test.py +++ b/scripts/metrics/test.py @@ -7,38 +7,47 @@ class test_metrics(unittest.TestCase): def setUp(self): - self.test = ['Oc1ccccc1-c1cccc2cnccc12', - 'COc1cccc(NC(=O)Cc2coc3ccc(OC)cc23)c1'] - self.test_sf = ['COCc1nnc(NC(=O)COc2ccc(C(C)(C)C)cc2)s1', - 'O=C(C1CC2C=CC1C2)N1CCOc2ccccc21', - 'Nc1c(Br)cccc1C(=O)Nc1ccncn1'] - self.gen = ['CNC', 'Oc1ccccc1-c1cccc2cnccc12', - 'INVALID', 'CCCP', - 'Cc1noc(C)c1CN(C)C(=O)Nc1cc(F)cc(F)c1', - 'Cc1nc(NCc2ccccc2)no1-c1ccccc1'] - self.target = {'valid': 2/3, - 'unique@3': 1.0, - 'FCD/Test': 52.58371754126664, - 'SNN/Test': 0.3152585653588176, - 'Frag/Test': 0.3, - 'Scaf/Test': 0.5, - 'IntDiv': 0.7189187309761661, - 'Filters': 0.75, - 'logP': 4.9581881764518005, - 'SA': 0.5086898026154574, - 'QED': 0.045033731661603064, - 'NP': 0.2902816615644048, - 'weight': 14761.927533455337} + self.test = ["Oc1ccccc1-c1cccc2cnccc12", "COc1cccc(NC(=O)Cc2coc3ccc(OC)cc23)c1"] + self.test_sf = [ + 
"COCc1nnc(NC(=O)COc2ccc(C(C)(C)C)cc2)s1", + "O=C(C1CC2C=CC1C2)N1CCOc2ccccc21", + "Nc1c(Br)cccc1C(=O)Nc1ccncn1", + ] + self.gen = [ + "CNC", + "Oc1ccccc1-c1cccc2cnccc12", + "INVALID", + "CCCP", + "Cc1noc(C)c1CN(C)C(=O)Nc1cc(F)cc(F)c1", + "Cc1nc(NCc2ccccc2)no1-c1ccccc1", + ] + self.target = { + "valid": 2 / 3, + "unique@3": 1.0, + "FCD/Test": 52.58371754126664, + "SNN/Test": 0.3152585653588176, + "Frag/Test": 0.3, + "Scaf/Test": 0.5, + "IntDiv": 0.7189187309761661, + "Filters": 0.75, + "logP": 4.9581881764518005, + "SA": 0.5086898026154574, + "QED": 0.045033731661603064, + "NP": 0.2902816615644048, + "weight": 14761.927533455337, + } def test_get_all_metrics(self): metrics = get_all_metrics(self.test, self.gen, k=3) fail = set() for metric in self.target: if not np.allclose(metrics[metric], self.target[metric]): - warnings.warn("Metric `{}` value does not match expected " - "value. Got {}, expected {}".format(metric, - metrics[metric], - self.target[metric])) + warnings.warn( + "Metric `{}` value does not match expected " + "value. Got {}, expected {}".format( + metric, metrics[metric], self.target[metric] + ) + ) fail.add(metric) assert len(fail) == 0, f"Some metrics didn't pass tests: {fail}" @@ -47,29 +56,35 @@ def test_get_all_metrics_multiprocess(self): fail = set() for metric in self.target: if not np.allclose(metrics[metric], self.target[metric]): - warnings.warn("Metric `{}` value does not match expected " - "value. Got {}, expected {}".format(metric, - metrics[metric], - self.target[metric])) + warnings.warn( + "Metric `{}` value does not match expected " + "value. 
Got {}, expected {}".format( + metric, metrics[metric], self.target[metric] + ) + ) fail.add(metric) assert len(fail) == 0, f"Some metrics didn't pass tests: {fail}" - + def test_get_all_metrics_scaffold(self): - metrics = get_all_metrics(self.test, self.gen, test_scaffolds=self.test_sf, k=3, n_jobs=2) + metrics = get_all_metrics( + self.test, self.gen, test_scaffolds=self.test_sf, k=3, n_jobs=2 + ) print(metrics) def test_valid_unique(self): - mols = ['CCNC', 'CCC', 'INVALID', 'CCC'] + mols = ["CCNC", "CCC", "INVALID", "CCC"] assert np.allclose(fraction_valid(mols), 3 / 4), "Failed valid" - assert np.allclose(fraction_unique(mols, check_validity=False), - 3 / 4), "Failed unique" + assert np.allclose( + fraction_unique(mols, check_validity=False), 3 / 4 + ), "Failed unique" assert np.allclose(fraction_unique(mols, k=2), 1), "Failed unique" mols = [Chem.MolFromSmiles(x) for x in mols] assert np.allclose(fraction_valid(mols), 3 / 4), "Failed valid" - assert np.allclose(fraction_unique(mols, check_validity=False), - 3 / 4), "Failed unique" + assert np.allclose( + fraction_unique(mols, check_validity=False), 3 / 4 + ), "Failed unique" assert np.allclose(fraction_unique(mols, k=2), 1), "Failed unique" -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/scripts/prepare_dataset.py b/scripts/prepare_dataset.py index 32985d4..bf91f42 100644 --- a/scripts/prepare_dataset.py +++ b/scripts/prepare_dataset.py @@ -17,34 +17,40 @@ def get_parser(): parser = argparse.ArgumentParser() - parser.add_argument('--output_file', type=str, default='dataset_v1.csv', - help='Path for constructed dataset') - parser.add_argument('--seed', type=int, default=0, - help='Random state') - parser.add_argument('--url', type=str, - default='http://zinc.docking.org/db/bysubset/11/11_p0.smi.gz', - help='url to .smi.gz file with smiles') - parser.add_argument('--n_jobs', type=int, - default=1, - help='number of processes to use') - parser.add_argument('--keep_ids', 
action='store_true', - help='Keep ZINC ids in the final csv file') + parser.add_argument( + "--output_file", + type=str, + default="dataset_v1.csv", + help="Path for constructed dataset", + ) + parser.add_argument("--seed", type=int, default=0, help="Random state") + parser.add_argument( + "--url", + type=str, + default="http://zinc.docking.org/db/bysubset/11/11_p0.smi.gz", + help="url to .smi.gz file with smiles", + ) + parser.add_argument( + "--n_jobs", type=int, default=1, help="number of processes to use" + ) + parser.add_argument( + "--keep_ids", action="store_true", help="Keep ZINC ids in the final csv file" + ) return parser def process_molecule(mol_row): - mol_row = mol_row.decode('utf-8') + mol_row = mol_row.decode("utf-8") smiles, _id = mol_row.split() if not mol_passes_filters(smiles): return None - smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), - isomericSmiles=False) + smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=False) return _id, smiles def download_dataset(url): - logger.info('Downloading from {}'.format(url)) + logger.info("Downloading from {}".format(url)) req = requests.get(url) with gzip.open(BytesIO(req.content)) as smi: lines = smi.readlines() @@ -52,32 +58,41 @@ def download_dataset(url): def filter_lines(lines, n_jobs): - logger.info('Filtering SMILES') + logger.info("Filtering SMILES") with Pool(n_jobs) as pool: - dataset = [x for x in tqdm.tqdm(pool.imap_unordered(process_molecule, lines), - total=len(lines), - miniters=1000) if x is not None] - dataset = pd.DataFrame(dataset, columns=['ID', 'SMILES']) - dataset = dataset.sort_values(by=['ID', 'SMILES']) - dataset = dataset.drop_duplicates('ID') - dataset = dataset.sort_values(by='ID') - dataset = dataset.drop_duplicates('SMILES') - dataset['scaffold'] = pool.map(compute_scaffold, dataset['SMILES'].values) + dataset = [ + x + for x in tqdm.tqdm( + pool.imap_unordered(process_molecule, lines), + total=len(lines), + miniters=1000, + ) + if x is not None + ] + 
dataset = pd.DataFrame(dataset, columns=["ID", "SMILES"]) + dataset = dataset.sort_values(by=["ID", "SMILES"]) + dataset = dataset.drop_duplicates("ID") + dataset = dataset.sort_values(by="ID") + dataset = dataset.drop_duplicates("SMILES") + dataset["scaffold"] = pool.map(compute_scaffold, dataset["SMILES"].values) return dataset def split_dataset(dataset, seed): - logger.info('Splitting the dataset') - scaffolds = pd.value_counts(dataset['scaffold']) + logger.info("Splitting the dataset") + scaffolds = pd.value_counts(dataset["scaffold"]) scaffolds = sorted(scaffolds.items(), key=lambda x: (-x[1], x[0])) test_scaffolds = set([x[0] for x in scaffolds[9::10]]) - dataset['SPLIT'] = 'train' - test_scaf_idx = [x in test_scaffolds for x in dataset['scaffold']] - dataset.loc[test_scaf_idx, 'SPLIT'] = 'test_scaffolds' - test_idx = dataset.loc[dataset['SPLIT'] == 'train'].sample(frac=0.1, - random_state=seed).index - dataset.loc[test_idx, 'SPLIT'] = 'test' - dataset.drop('scaffold', axis=1, inplace=True) + dataset["SPLIT"] = "train" + test_scaf_idx = [x in test_scaffolds for x in dataset["scaffold"]] + dataset.loc[test_scaf_idx, "SPLIT"] = "test_scaffolds" + test_idx = ( + dataset.loc[dataset["SPLIT"] == "train"] + .sample(frac=0.1, random_state=seed) + .index + ) + dataset.loc[test_idx, "SPLIT"] = "test" + dataset.drop("scaffold", axis=1, inplace=True) return dataset @@ -86,13 +101,13 @@ def main(config): dataset = filter_lines(lines, config.n_jobs) dataset = split_dataset(dataset, config.seed) if not config.keep_ids: - dataset.drop('ID', 1, inplace=True) + dataset.drop("ID", 1, inplace=True) dataset.to_csv(config.output_file, index=None) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() config, unknown = parser.parse_known_args() if len(unknown) != 0: - raise ValueError("Unknown argument "+unknown[0]) + raise ValueError("Unknown argument " + unknown[0]) main(config) diff --git a/scripts/run.py b/scripts/run.py index 08c9c1e..72b741c 100644 --- 
a/scripts/run.py +++ b/scripts/run.py @@ -257,6 +257,7 @@ def main(config): models = MODELS.get_model_names() if config.model == "all" else [config.model] for model in models: + print(f"lbann weights dir: {config.lbann_weights_dir}") if not os.path.exists(config.lbann_weights_dir): # LBANN is inference only train_model(config, model, train_path) sample_from_model(config, model,test_path) diff --git a/scripts/sample.py b/scripts/sample.py index 2182af6..8fd2a4f 100644 --- a/scripts/sample.py +++ b/scripts/sample.py @@ -14,9 +14,12 @@ MODELS = ModelsStorage() + def get_parser(): parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(title='Models sampler script', description='available models') + subparsers = parser.add_subparsers( + title="Models sampler script", description="available models" + ) for model in MODELS.get_model_names(): add_sample_args(subparsers.add_parser(model)) return parser @@ -26,22 +29,22 @@ def main(model, config): set_seed(config.seed) device = torch.device(config.device) # For CUDNN to work properly: - if device.type.startswith('cuda'): + if device.type.startswith("cuda"): torch.cuda.set_device(device.index or 0) - if(config.lbann_weights_dir): - assert os.path.exists(config.lbann_weights_dir), ("LBANN inference mode is specified but directory " - " to load weights does not exist: '{}'".format(config.lbann_weights_dir)) - - + if config.lbann_weights_dir: + assert os.path.exists(config.lbann_weights_dir), ( + "LBANN inference mode is specified but directory " + " to load weights does not exist: '{}'".format(config.lbann_weights_dir) + ) model_config = torch.load(config.config_load) trainer = MODELS.get_model_trainer(model)(model_config) model_vocab = torch.load(config.vocab_load) model_state = torch.load(config.model_load) - model = MODELS.get_model_class(model)(model_vocab, model_config) - if os.path.exists(config.lbann_weights_dir): - model.load_lbann_weights(config.lbann_weights_dir,config.lbann_epoch_counts) + model = 
MODELS.get_model_class(model)(model_vocab, model_config) + if os.path.exists(config.lbann_weights_dir): + model.load_lbann_weights(config.lbann_weights_dir, config.lbann_epoch_counts) else: # assume that a non-LBANN model is being loaded model.load_state_dict(model_state) @@ -60,7 +63,7 @@ def main(model, config): samples = [] n = config.n_samples print("Generating Samples") - with tqdm(total=config.n_samples, desc='Generating samples') as T: + with tqdm(total=config.n_samples, desc="Generating samples") as T: while n > 0: current_samples = model.sample(min(n, config.n_batch), config.max_len) samples.extend(current_samples) @@ -68,11 +71,12 @@ def main(model, config): n -= len(current_samples) T.update(len(current_samples)) - samples = pd.DataFrame(samples, columns=['SMILES']) + samples = pd.DataFrame(samples, columns=["SMILES"]) print("Save generated samples to ", config.gen_save) samples.to_csv(config.gen_save, index=False) -if __name__ == '__main__': + +if __name__ == "__main__": parser = get_parser() config = parser.parse_args() model = sys.argv[1] diff --git a/scripts/split_dataset.py b/scripts/split_dataset.py index 0ec29f5..4d43ced 100644 --- a/scripts/split_dataset.py +++ b/scripts/split_dataset.py @@ -6,77 +6,102 @@ def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): + elif v.lower() in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") + def get_parser(): parser = argparse.ArgumentParser() - parser.add_argument('--dir', type=str, default='./data', - help='Directory for splitted dataset') - parser.add_argument('--no_subset', action='store_true', - help='Do not create subsets for training and testing') - parser.add_argument('--train_size', type=int, default=250000, - help='Size of training dataset') - 
parser.add_argument('--test_size', type=int, default=10000, - help='Size of testing dataset') - parser.add_argument('--seed', type=int, default=0, - help='Random state') - parser.add_argument('--precompute', type=str2bool, default=True, - help='Precompute intermediate statistics') - parser.add_argument('--n_jobs', type=int, default=1, - help='Number of workers') - parser.add_argument('--gpu', type=int, default=-1, - help='GPU id') - parser.add_argument('--batch_size', type=int, default=512, - help='Batch size for FCD calculation') + parser.add_argument( + "--dir", type=str, default="./data", help="Directory for splitted dataset" + ) + parser.add_argument( + "--no_subset", + action="store_true", + help="Do not create subsets for training and testing", + ) + parser.add_argument( + "--train_size", type=int, default=250000, help="Size of training dataset" + ) + parser.add_argument( + "--test_size", type=int, default=10000, help="Size of testing dataset" + ) + parser.add_argument("--seed", type=int, default=0, help="Random state") + parser.add_argument( + "--precompute", + type=str2bool, + default=True, + help="Precompute intermediate statistics", + ) + parser.add_argument("--n_jobs", type=int, default=1, help="Number of workers") + parser.add_argument("--gpu", type=int, default=-1, help="GPU id") + parser.add_argument( + "--batch_size", type=int, default=512, help="Batch size for FCD calculation" + ) return parser def main(config): - dataset_path = os.path.join(config.dir, 'dataset_v1.csv') - download_url = 'https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv' + dataset_path = os.path.join(config.dir, "dataset_v1.csv") + download_url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv" if not os.path.exists(dataset_path): - raise ValueError(f"Missing dataset_v1.csv in {config.dir}; " - f"Please, use 'git lfs pull' or download it manually from {download_url}") + raise ValueError( + f"Missing 
dataset_v1.csv in {config.dir}; " + f"Please, use 'git lfs pull' or download it manually from {download_url}" + ) if config.no_subset: return data = pd.read_csv(dataset_path) - train_data = data[data['SPLIT'] == 'train'] - test_data = data[data['SPLIT'] == 'test'] - test_scaffolds_data = data[data['SPLIT'] == 'test_scaffolds'] + train_data = data[data["SPLIT"] == "train"] + test_data = data[data["SPLIT"] == "test"] + test_scaffolds_data = data[data["SPLIT"] == "test_scaffolds"] if config.train_size is not None: train_data = train_data.sample(config.train_size, random_state=config.seed) if config.test_size is not None: test_data = test_data.sample(config.test_size, random_state=config.seed) - test_scaffolds_data = test_scaffolds_data.sample(config.test_size, random_state=config.seed) + test_scaffolds_data = test_scaffolds_data.sample( + config.test_size, random_state=config.seed + ) - train_data.to_csv(os.path.join(config.dir, 'train.csv'), index=False) - test_data.to_csv(os.path.join(config.dir, 'test.csv'), index=False) - test_scaffolds_data.to_csv(os.path.join(config.dir, 'test_scaffolds.csv'), index=False) + train_data.to_csv(os.path.join(config.dir, "train.csv"), index=False) + test_data.to_csv(os.path.join(config.dir, "test.csv"), index=False) + test_scaffolds_data.to_csv( + os.path.join(config.dir, "test_scaffolds.csv"), index=False + ) if config.precompute: test_stats = compute_intermediate_statistics( - test_data['SMILES'].values, n_jobs=config.n_jobs, gpu=config.gpu, batch_size=config.batch_size) + test_data["SMILES"].values, + n_jobs=config.n_jobs, + gpu=config.gpu, + batch_size=config.batch_size, + ) test_sf_stats = compute_intermediate_statistics( - test_scaffolds_data['SMILES'].values, n_jobs=config.n_jobs, gpu=config.gpu, batch_size=config.batch_size) + test_scaffolds_data["SMILES"].values, + n_jobs=config.n_jobs, + gpu=config.gpu, + batch_size=config.batch_size, + ) - np.savez(os.path.join(config.dir, 'test_stats.npz'), stats=test_stats) - 
np.savez(os.path.join(config.dir, 'test_scaffolds_stats.npz'), stats=test_sf_stats) + np.savez(os.path.join(config.dir, "test_stats.npz"), stats=test_stats) + np.savez( + os.path.join(config.dir, "test_scaffolds_stats.npz"), stats=test_sf_stats + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() config, unknown = parser.parse_known_args() if len(unknown) != 0: - raise ValueError("Unknown argument "+unknown[0]) + raise ValueError("Unknown argument " + unknown[0]) main(config) diff --git a/scripts/train.py b/scripts/train.py index f0d1b9f..2263da3 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -12,11 +12,16 @@ MODELS = ModelsStorage() + def get_parser(): parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(title='Models trainer script', description='available models') + subparsers = parser.add_subparsers( + title="Models trainer script", description="available models" + ) for model in MODELS.get_model_names(): - add_train_args(MODELS.get_model_train_parser(model)(subparsers.add_parser(model))) + add_train_args( + MODELS.get_model_train_parser(model)(subparsers.add_parser(model)) + ) return parser @@ -25,7 +30,7 @@ def main(model, config): device = torch.device(config.device) # For CUDNN to work properly - if device.type.startswith('cuda'): + if device.type.startswith("cuda"): torch.cuda.set_device(device.index or 0) train_data = read_smiles_csv(config.train_load) @@ -33,7 +38,7 @@ def main(model, config): trainer = MODELS.get_model_trainer(model)(config) if config.vocab_load is not None: - assert os.path.exists(config.vocab_load), 'vocab_load path doesn\'t exist!' + assert os.path.exists(config.vocab_load), "vocab_load path doesn't exist!" 
vocab = torch.load(config.vocab_load) else: vocab = trainer.get_vocabulary(train_data) @@ -41,14 +46,15 @@ def main(model, config): model = MODELS.get_model_class(model)(vocab, config).to(device) trainer.fit(model, train_data, val_data) - model = model.to('cpu') + model = model.to("cpu") torch.save(model.state_dict(), config.model_save) if config.config_save is not None: torch.save(config, config.config_save) if config.vocab_save is not None: torch.save(vocab, config.vocab_save) -if __name__ == '__main__': + +if __name__ == "__main__": parser = get_parser() config = parser.parse_args() model = sys.argv[1] diff --git a/setup.py b/setup.py index 93fe796..9384833 100644 --- a/setup.py +++ b/setup.py @@ -12,10 +12,9 @@ 'keras>=2.2', 'matplotlib>=3.0.0', 'numpy>=1.15', - 'pandas>=0.23', + 'pandas>=0.22', 'scipy>=1.1.0', - 'tensorflow>=1.0', - 'torch>=0.4.1', + 'torch==1.7.0', ], description='MOSES: A benchmarking platform for molecular generation models', author='Neuromation & Insilico Medicine Teams',