diff --git a/foldtoken/installation.sh b/foldtoken/installation.sh new file mode 100644 index 0000000..13313f7 --- /dev/null +++ b/foldtoken/installation.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Update build basics +python -m pip install -U pip setuptools wheel packaging + +# --------------------------------------------------------------------- +# 1) PyTorch 2.0.1 + CUDA 11.7 (official wheels) +# --------------------------------------------------------------------- +pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 torchaudio==2.0.2+cu117 \ + --index-url https://download.pytorch.org/whl/cu117 + +# --------------------------------------------------------------------- +# 2) Core numerics that some wheels depend on +# --------------------------------------------------------------------- +pip install "scipy>=1.11,<1.12" + +# --------------------------------------------------------------------- +# 3) PyTorch-Geometric CUDA wheels (Torch 2.0.1 + CU117) +# Hosted on data.pyg.org — disable PyPI so we don’t grab CPU builds. +# --------------------------------------------------------------------- +PYG_INDEX="https://data.pyg.org/whl/torch-2.0.1+cu117.html" + +pip install --no-index --find-links "$PYG_INDEX" torch-scatter==2.1.2+pt20cu117 +pip install --no-index --find-links "$PYG_INDEX" torch-sparse==0.6.17+pt20cu117 +pip install --no-index --find-links "$PYG_INDEX" torch-cluster==1.6.3+pt20cu117 +pip install --no-index --find-links "$PYG_INDEX" torch-spline-conv==1.2.2+pt20cu117 +pip install --no-index --find-links "$PYG_INDEX" pyg-lib==0.2.0+pt20cu117 +pip install --find-links https://data.pyg.org/whl/torch-2.0.1+cu117.html torch-geometric==2.3.1 + +# --------------------------------------------------------------------- +# 4) Optional extras and utilities +# --------------------------------------------------------------------- +pip install flash-attn --no-build-isolation # Ampere+ GPUs +pip install pytorch-lightning==1.9.0 # training framework +pip install omegaconf==2.3.0 # config library +pip install tqdm diff --git a/foldtoken/pdb_to_token.py b/foldtoken/pdb_to_token.py new file mode 100644 index 0000000..742938b --- /dev/null +++ b/foldtoken/pdb_to_token.py @@ -0,0 +1,99 @@ +# inference.py: Encode a PDB structure into tokens using FoldToken4 +import os +import torch +import argparse +import tqdm +import csv +from omegaconf import OmegaConf +from model_interface import MInterface +from src.chroma.data import Protein + + +def load_model(config_path, checkpoint_path, device='cpu'): + """Load the FoldToken4 model from checkpoint on given device""" + print(f"Loading config from {config_path}") + config = OmegaConf.load(config_path) + config = OmegaConf.to_container(config, resolve=True) + + print("Initializing model...") + model = MInterface(**config) + + print(f"Loading checkpoint from {checkpoint_path}") + print(f"Using device: {device}") + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + # adjust keys if wrapped in DataParallel + for key in list(checkpoint.keys()): + if '_forward_module.' in key: + checkpoint[key.replace('_forward_module.', '')] = checkpoint[key] + del checkpoint[key] + + model.load_state_dict(checkpoint, strict=False) + model = model.to(device) + print("Model loaded successfully") + return model + + +def encode_pdb_to_tokens(model, pdb_path, device='cpu', level=8): + """Encode a PDB file into tokens using FoldToken4 on given device""" + protein = Protein(pdb_path, device=device) + # determine number of residues via XCS representation + X, C, S = protein.to_XCS() + + with torch.no_grad(): + vq_code = model.encode_protein(protein, level=level)[1] + return vq_code + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Encode all PDBs in a directory to tokens CSV.") + parser.add_argument('--pdb_dir', required=True, help="Directory containing PDB files.") + parser.add_argument('--output', default="output.csv", help="Output CSV file path.") + parser.add_argument('--device', choices=['cpu','cuda'], default='cpu', help="Compute device to use.") + args = parser.parse_args() + + # Validate device + device = args.device + if device == 'cuda' and not torch.cuda.is_available(): + print("CUDA not available, falling back to CPU") + device = 'cpu' + + # Config and checkpoint paths + config_path = os.path.join(os.path.dirname(__file__), 'checkpoint', 'FT4', 'config.yaml') + checkpoint_path = os.path.join(os.path.dirname(__file__), 'checkpoint', 'FT4', 'ckpt.pth') + if not os.path.exists(config_path): raise FileNotFoundError(f"Config not found: {config_path}") + if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + model = load_model(config_path, checkpoint_path, device) + + # Gather PDB files + # Recursively collect .pdb files from the given directory and all subdirectories + pdb_files = [] + for root, _, files in os.walk(args.pdb_dir): + for f in files: + if f.lower().endswith('.pdb'): + pdb_files.append(os.path.join(root, f)) + results = [] + for pdb_path in tqdm.tqdm(pdb_files, desc="Encoding PDBs", unit="file"): + basename = os.path.basename(pdb_path) + try: + protein = Protein(pdb_path, device=device) + X, C, S = protein.to_XCS() + except Exception as e: + print(f"Skipped {basename}: load error {e}") + continue + # filter short or highly masked proteins + if X.shape[1] < 5 or (C != -1).sum() < 5: + print(f"Skipped {basename}: too few residues") + continue + with torch.no_grad(): + vq_code = model.encode_protein(protein, level=8)[1] + token_list = vq_code.cpu().numpy().tolist() if hasattr(vq_code, 'cpu') else list(vq_code) + results.append((basename, token_list)) + + # write results + with open(args.output, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['pdb_file', 'tokens']) + for name, tlist in results: + writer.writerow([name, ' '.join(map(str, tlist))]) + print(f"Processed {len(results)}/{len(pdb_files)} files. Tokens saved to {args.output}") diff --git a/foldtoken/token_to_pdb.py b/foldtoken/token_to_pdb.py new file mode 100644 index 0000000..61bbb45 --- /dev/null +++ b/foldtoken/token_to_pdb.py @@ -0,0 +1,86 @@ +import os +import torch +import argparse +import csv + +import tqdm +from omegaconf import OmegaConf +from model_interface import MInterface + +# Load FoldToken4 model for decoding + +def load_model(device='cpu'): + cwd = os.path.dirname(__file__) + config_path = os.path.join(cwd, 'checkpoint', 'FT4', 'config.yaml') + checkpoint_path = os.path.join(cwd, 'checkpoint', 'FT4', 'ckpt.pth') + config = OmegaConf.load(config_path) + config = OmegaConf.to_container(config, resolve=True) + + model = MInterface(**config) + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + # adjust keys if wrapped in DataParallel + for key in list(checkpoint.keys()): + if '_forward_module.' in key: + checkpoint[key.replace('_forward_module.', '')] = checkpoint[key] + del checkpoint[key] + model.load_state_dict(checkpoint, strict=False) + model = model.to(device) + model.eval() + return model + +# Convert token list to PDB and save + +def tokens_to_pdb(model, tokens, output_path, level=8): + device = next(model.parameters()).device + vq_codes = torch.tensor(tokens, dtype=torch.long, device=device) + # get latent embeddings + h_V = model.model.vq.embed_id(vq_codes, level) + # simple chain encoding + chain_encoding = torch.ones_like(vq_codes, device=device) + # decode to protein object + protein = model.model.decoding(h_V, chain_encoding) + # save PDB + protein.to(output_path) + +# CLI + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Decode tokens CSV to PDB files.") + parser.add_argument('--token_csv', required=True, help="CSV file with columns ['pdb_file','tokens'].") + parser.add_argument('--output_dir', required=True, help="Directory to save reconstructed PDBs.") + parser.add_argument('--device', choices=['cpu','cuda'], default='cpu', help="Compute device.") + parser.add_argument('--level', type=int, default=8, help="Quantization level.") + args = parser.parse_args() + + # Decoding currently requires CUDA-enabled build + if not torch.cuda.is_available(): + print("Error: decoding tokens requires CUDA-enabled PyTorch. Please install the GPU version and run with --device cuda.") + exit(1) + device = 'cuda' + if args.device == 'cuda' and device == 'cpu': + print("CUDA not available, using CPU which may not support decoding.") + elif args.device == 'cpu' and device == 'cuda': + print("CUDA available, overriding to use CUDA for decoding.") + + model = load_model(device) + os.makedirs(args.output_dir, exist_ok=True) + + with open(args.token_csv, newline='') as f: + reader = csv.DictReader(f) + for row in tqdm.tqdm(reader, total=sum(1 for _ in open(args.token_csv))): + name = row.get('pdb_file') or row.get('pdb') + tokens = [int(x) for x in row['tokens'].split()] + out_name = os.path.splitext(name)[0] + '.pdb' + # Ensure unique filename by appending a counter if needed + candidate = out_name + root, ext = os.path.splitext(out_name) + counter = 1 + while os.path.exists(os.path.join(args.output_dir, candidate)): + candidate = f"{root}_{counter}{ext}" + counter += 1 + output_path = os.path.join(args.output_dir, candidate) + try: + tokens_to_pdb(model, tokens, output_path, level=args.level) + # print(f"Saved {output_path}") + except Exception as e: + print(f"Failed {name}: {e}")