From 3ea134f861f719efce2e314400eb581373c1783c Mon Sep 17 00:00:00 2001
From: Mahdi
Date: Sun, 25 May 2025 20:38:40 -0500
Subject: [PATCH 1/8] feat: pdb to token code

---
 foldtoken/pdb_to_token.py | 81 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 foldtoken/pdb_to_token.py

diff --git a/foldtoken/pdb_to_token.py b/foldtoken/pdb_to_token.py
new file mode 100644
index 0000000..4989302
--- /dev/null
+++ b/foldtoken/pdb_to_token.py
@@ -0,0 +1,81 @@
+# inference.py: Encode a PDB structure into tokens using FoldToken4
+import os
+import torch
+import argparse
+import tqdm
+from omegaconf import OmegaConf
+from model_interface import MInterface
+from src.chroma.data import Protein
+
+
+def load_model(config_path, checkpoint_path, device='cpu'):
+    """Load the FoldToken4 model from checkpoint on given device"""
+    print(f"Loading config from {config_path}")
+    config = OmegaConf.load(config_path)
+    config = OmegaConf.to_container(config, resolve=True)
+
+    print("Initializing model...")
+    model = MInterface(**config)
+
+    print(f"Loading checkpoint from {checkpoint_path}")
+    print(f"Using device: {device}")
+    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+
+    model.load_state_dict(checkpoint, strict=False)
+    model = model.to(device)
+    print("Model loaded successfully")
+    return model
+
+
+def encode_pdb_to_tokens(model, pdb_path, device='cpu', level=8):
+    """Encode a PDB file into tokens using FoldToken4 on given device"""
+    protein = Protein(pdb_path, device=device)
+    # determine number of residues via XCS representation
+    X, C, S = protein.sys.to_XCS()
+
+    with torch.no_grad():
+        vq_code = model.encode_protein(protein, level=level)[1]
+    return vq_code
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Encode all PDBs in a directory to tokens CSV.")
+    parser.add_argument('--pdb_dir', required=True, help="Directory containing PDB files.")
+    parser.add_argument('--output', default="output.csv", help="Output CSV file path.")
+    parser.add_argument('--device', choices=['cpu','cuda'], default='cpu', help="Compute device to use.")
+    args = parser.parse_args()
+
+    # Validate device
+    device = args.device
+    if device == 'cuda' and not torch.cuda.is_available():
+        print("CUDA not available, falling back to CPU")
+        device = 'cpu'
+
+    # Config and checkpoint paths
+    config_path = os.path.join(os.path.dirname(__file__), 'checkpoint', 'FT4', 'config.yaml')
+    checkpoint_path = os.path.join(os.path.dirname(__file__), 'checkpoint', 'FT4', 'ckpt.pth')
+    if not os.path.exists(config_path): raise FileNotFoundError(f"Config not found: {config_path}")
+    if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+    model = load_model(config_path, checkpoint_path, device)
+
+    # Gather PDB files
+    pdb_files = [os.path.join(args.pdb_dir, f) for f in os.listdir(args.pdb_dir) if f.lower().endswith('.pdb')]
+    results = []
+    for pdb_path in tqdm.tqdm(pdb_files, desc="Encoding PDBs", unit="file"):
+        try:
+            codes = encode_pdb_to_tokens(model, pdb_path, device)
+            token_list = codes.cpu().numpy().tolist() if hasattr(codes, 'cpu') else list(codes)
+            results.append((os.path.basename(pdb_path), token_list))
+        except Exception as e:
+            print(f"Skipped {os.path.basename(pdb_path)}: {e}")
+
+    # Write CSV
+    import csv
+    with open(args.output, 'w', newline='') as csvfile:
+        writer = csv.writer(csvfile)
+        writer.writerow(['pdb_file', 'tokens'])
+        for name, tlist in results:
+            writer.writerow([name, ' '.join(map(str, tlist))])
+    print(f"Processed {len(results)}/{len(pdb_files)} files. Tokens saved to {args.output}")
+

From f20db9acbc131f18c55f30723343736ef0612aac Mon Sep 17 00:00:00 2001
From: Mahdi
Date: Sun, 25 May 2025 21:56:28 -0500
Subject: [PATCH 2/8] feat: token to pdb code

---
 foldtoken/token_to_pdb.py | 81 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 foldtoken/token_to_pdb.py

diff --git a/foldtoken/token_to_pdb.py b/foldtoken/token_to_pdb.py
new file mode 100644
index 0000000..f108944
--- /dev/null
+++ b/foldtoken/token_to_pdb.py
@@ -0,0 +1,81 @@
+import os
+import torch
+import argparse
+import csv
+from omegaconf import OmegaConf
+from model_interface import MInterface
+
+# Load FoldToken4 model for decoding
+
+def load_model(device='cpu'):
+    cwd = os.path.dirname(__file__)
+    config_path = os.path.join(cwd, 'checkpoint', 'FT4', 'config.yaml')
+    checkpoint_path = os.path.join(cwd, 'checkpoint', 'FT4', 'ckpt.pth')
+    config = OmegaConf.load(config_path)
+    config = OmegaConf.to_container(config, resolve=True)
+
+    model = MInterface(**config)
+    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+    # adjust keys if wrapped in DataParallel
+    for key in list(checkpoint.keys()):
+        if '_forward_module.' in key:
+            checkpoint[key.replace('_forward_module.', '')] = checkpoint[key]
+            del checkpoint[key]
+    model.load_state_dict(checkpoint, strict=False)
+    model = model.to(device)
+    model.eval()
+    return model
+
+# Convert token list to PDB and save
+
+def tokens_to_pdb(model, tokens, output_path, level=8):
+    device = next(model.parameters()).device
+    vq_codes = torch.tensor(tokens, dtype=torch.long, device=device)
+    # get latent embeddings
+    h_V = model.model.vq.embed_id(vq_codes, level)
+    # simple chain encoding
+    chain_encoding = torch.ones_like(vq_codes, device=device)
+    # decode to protein object
+    protein = model.model.decoding(h_V, chain_encoding)
+    # save PDB
+    protein.to(output_path)
+
+# CLI
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Decode tokens CSV to PDB files.")
+    parser.add_argument('--token_csv', required=True, help="CSV file with columns ['pdb_file','tokens'].")
+    parser.add_argument('--output_dir', required=True, help="Directory to save reconstructed PDBs.")
+    parser.add_argument('--device', choices=['cpu','cuda'], default='cpu', help="Compute device.")
+    parser.add_argument('--level', type=int, default=8, help="Quantization level.")
+    args = parser.parse_args()
+
+    # Decoding currently requires CUDA-enabled build
+    if not torch.cuda.is_available():
+        print("Error: decoding tokens requires CUDA-enabled PyTorch. Please install the GPU version and run with --device cuda.")
+        exit(1)
+    device = 'cuda'
+    if args.device == 'cuda' and device == 'cpu':
+        print("CUDA not available, using CPU which may not support decoding.")
+    elif args.device == 'cpu' and device == 'cuda':
+        print("CUDA available, overriding to use CUDA for decoding.")
+
+    model = load_model(device)
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    with open(args.token_csv, newline='') as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            name = row.get('pdb_file') or row.get('pdb')
+            tokens = [int(x) for x in row['tokens'].split()]
+            out_name = os.path.splitext(name)[0] + '_pred.pdb'
+            output_path = os.path.join(args.output_dir, out_name)
+            if os.path.exists(output_path):
+                print(f"{output_path} exists, skip")
+                continue
+            try:
+                tokens_to_pdb(model, tokens, output_path, level=args.level)
+                print(f"Saved {output_path}")
+            except Exception as e:
+                print(f"Failed {name}: {e}")
+

From 69edd764f4629b3c0f2b83c7c1e82277cbd0835f Mon Sep 17 00:00:00 2001
From: mahdi
Date: Mon, 26 May 2025 01:32:12 -0500
Subject: [PATCH 3/8] fix: issues

---
 foldtoken/pdb_to_token.py | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/foldtoken/pdb_to_token.py b/foldtoken/pdb_to_token.py
index 4989302..32c22b0 100644
--- a/foldtoken/pdb_to_token.py
+++ b/foldtoken/pdb_to_token.py
@@ -3,6 +3,7 @@
 import torch
 import argparse
 import tqdm
+import csv
 from omegaconf import OmegaConf
 from model_interface import MInterface
 from src.chroma.data import Protein
@@ -20,6 +21,11 @@ def load_model(config_path, checkpoint_path, device='cpu'):
     print(f"Loading checkpoint from {checkpoint_path}")
     print(f"Using device: {device}")
     checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+    # adjust keys if wrapped in DataParallel
+    for key in list(checkpoint.keys()):
+        if '_forward_module.' in key:
+            checkpoint[key.replace('_forward_module.', '')] = checkpoint[key]
+            del checkpoint[key]
 
     model.load_state_dict(checkpoint, strict=False)
     model = model.to(device)
@@ -31,7 +37,7 @@ def encode_pdb_to_tokens(model, pdb_path, device='cpu', level=8):
     """Encode a PDB file into tokens using FoldToken4 on given device"""
     protein = Protein(pdb_path, device=device)
     # determine number of residues via XCS representation
-    X, C, S = protein.sys.to_XCS()
+    X, C, S = protein.to_XCS()
 
     with torch.no_grad():
         vq_code = model.encode_protein(protein, level=level)[1]
@@ -63,15 +69,23 @@ def encode_pdb_to_tokens(model, pdb_path, device='cpu', level=8):
     pdb_files = [os.path.join(args.pdb_dir, f) for f in os.listdir(args.pdb_dir) if f.lower().endswith('.pdb')]
     results = []
     for pdb_path in tqdm.tqdm(pdb_files, desc="Encoding PDBs", unit="file"):
+        basename = os.path.basename(pdb_path)
         try:
-            codes = encode_pdb_to_tokens(model, pdb_path, device)
-            token_list = codes.cpu().numpy().tolist() if hasattr(codes, 'cpu') else list(codes)
-            results.append((os.path.basename(pdb_path), token_list))
+            protein = Protein(pdb_path, device=device)
+            X, C, S = protein.to_XCS()
         except Exception as e:
-            print(f"Skipped {os.path.basename(pdb_path)}: {e}")
-
-    # Write CSV
-    import csv
+            print(f"Skipped {basename}: load error {e}")
+            continue
+        # filter short or highly masked proteins
+        if X.shape[1] < 5 or (C != -1).sum() < 5:
+            print(f"Skipped {basename}: too few residues")
+            continue
+        with torch.no_grad():
+            vq_code = model.encode_protein(protein, level=8)[1]
+        token_list = vq_code.cpu().numpy().tolist() if hasattr(vq_code, 'cpu') else list(vq_code)
+        results.append((basename, token_list))
+
+    # write results
     with open(args.output, 'w', newline='') as csvfile:
         writer = csv.writer(csvfile)
         writer.writerow(['pdb_file', 'tokens'])

From 3c4ca505cf8a5fc674f69f1d4e259dcd74a60ce6 Mon Sep 17 00:00:00 2001
From: mahdi
Date: Mon, 26 May 2025 01:48:47 -0500
Subject: [PATCH 4/8] doc: installation.sh on python env

---
 foldtoken/installation.sh | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)
 create mode 100644 foldtoken/installation.sh

diff --git a/foldtoken/installation.sh b/foldtoken/installation.sh
new file mode 100644
index 0000000..13313f7
--- /dev/null
+++ b/foldtoken/installation.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Update build basics
+python -m pip install -U pip setuptools wheel packaging
+
+# ---------------------------------------------------------------------
+# 1) PyTorch 2.0.1 + CUDA 11.7 (official wheels)
+# ---------------------------------------------------------------------
+pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 torchaudio==2.0.2+cu117 \
+    --index-url https://download.pytorch.org/whl/cu117
+
+# ---------------------------------------------------------------------
+# 2) Core numerics that some wheels depend on
+# ---------------------------------------------------------------------
+pip install "scipy>=1.11,<1.12"
+
+# ---------------------------------------------------------------------
+# 3) PyTorch-Geometric CUDA wheels (Torch 2.0.1 + CU117)
+#    Hosted on data.pyg.org — disable PyPI so we don’t grab CPU builds.
+# ---------------------------------------------------------------------
+PYG_INDEX="https://data.pyg.org/whl/torch-2.0.1+cu117.html"
+
+pip install --no-index --find-links "$PYG_INDEX" torch-scatter==2.1.2+pt20cu117
+pip install --no-index --find-links "$PYG_INDEX" torch-sparse==0.6.17+pt20cu117
+pip install --no-index --find-links "$PYG_INDEX" torch-cluster==1.6.3+pt20cu117
+pip install --no-index --find-links "$PYG_INDEX" torch-spline-conv==1.2.2+pt20cu117
+pip install --no-index --find-links "$PYG_INDEX" pyg-lib==0.2.0+pt20cu117
+pip install --find-links https://data.pyg.org/whl/torch-2.0.1+cu117.html torch-geometric==2.3.1
+
+# ---------------------------------------------------------------------
+# 4) Optional extras and utilities
+# ---------------------------------------------------------------------
+pip install flash-attn --no-build-isolation  # Ampere+ GPUs
+pip install pytorch-lightning==1.9.0  # training framework
+pip install omegaconf==2.3.0  # config library
+pip install tqdm

From d059f744daa4ded7ef98abd82703d79ac0962335 Mon Sep 17 00:00:00 2001
From: mahdi
Date: Wed, 28 May 2025 14:19:31 -0500
Subject: [PATCH 5/8] feat: progress bar

---
 foldtoken/token_to_pdb.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/foldtoken/token_to_pdb.py b/foldtoken/token_to_pdb.py
index f108944..0ab95fb 100644
--- a/foldtoken/token_to_pdb.py
+++ b/foldtoken/token_to_pdb.py
@@ -2,6 +2,8 @@
 import torch
 import argparse
 import csv
+
+import tqdm
 from omegaconf import OmegaConf
 from model_interface import MInterface
 
@@ -65,7 +67,7 @@ def tokens_to_pdb(model, tokens, output_path, level=8):
 
     with open(args.token_csv, newline='') as f:
         reader = csv.DictReader(f)
-        for row in reader:
+        for row in tqdm.tqdm(reader, total=sum(1 for _ in open(args.token_csv))):
             name = row.get('pdb_file') or row.get('pdb')
             tokens = [int(x) for x in row['tokens'].split()]
             out_name = os.path.splitext(name)[0] + '_pred.pdb'
@@ -75,7 +77,7 @@ def tokens_to_pdb(model, tokens, output_path, level=8):
                 continue
             try:
                 tokens_to_pdb(model, tokens, output_path, level=args.level)
-                print(f"Saved {output_path}")
+                # print(f"Saved {output_path}")
             except Exception as e:
                 print(f"Failed {name}: {e}")
 

From f7bb64880a711e6448282c2694c7ea7b6f3bf841 Mon Sep 17 00:00:00 2001
From: mahdi
Date: Wed, 28 May 2025 14:23:27 -0500
Subject: [PATCH 6/8] ref: make the pdb names identical

---
 foldtoken/token_to_pdb.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/foldtoken/token_to_pdb.py b/foldtoken/token_to_pdb.py
index 0ab95fb..0704e5f 100644
--- a/foldtoken/token_to_pdb.py
+++ b/foldtoken/token_to_pdb.py
@@ -70,7 +70,7 @@ def tokens_to_pdb(model, tokens, output_path, level=8):
         for row in tqdm.tqdm(reader, total=sum(1 for _ in open(args.token_csv))):
             name = row.get('pdb_file') or row.get('pdb')
             tokens = [int(x) for x in row['tokens'].split()]
-            out_name = os.path.splitext(name)[0] + '_pred.pdb'
+            out_name = os.path.splitext(name)[0] + '.pdb'
             output_path = os.path.join(args.output_dir, out_name)
             if os.path.exists(output_path):
                 print(f"{output_path} exists, skip")

From d85f125681dca2b11cc432ae1e0dd7de9bff4ec8 Mon Sep 17 00:00:00 2001
From: mahdi
Date: Sun, 24 Aug 2025 01:22:14 -0500
Subject: [PATCH 7/8] feat: recursively collect PDB files from subdirectories

---
 foldtoken/pdb_to_token.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/foldtoken/pdb_to_token.py b/foldtoken/pdb_to_token.py
index 32c22b0..742938b 100644
--- a/foldtoken/pdb_to_token.py
+++ b/foldtoken/pdb_to_token.py
@@ -66,7 +66,12 @@ def encode_pdb_to_tokens(model, pdb_path, device='cpu', level=8):
     model = load_model(config_path, checkpoint_path, device)
 
     # Gather PDB files
-    pdb_files = [os.path.join(args.pdb_dir, f) for f in os.listdir(args.pdb_dir) if f.lower().endswith('.pdb')]
+    # Recursively collect .pdb files from the given directory and all subdirectories
+    pdb_files = []
+    for root, _, files in os.walk(args.pdb_dir):
+        for f in files:
+            if f.lower().endswith('.pdb'):
+                pdb_files.append(os.path.join(root, f))
     results = []
     for pdb_path in tqdm.tqdm(pdb_files, desc="Encoding PDBs", unit="file"):
         basename = os.path.basename(pdb_path)
@@ -92,4 +97,3 @@ def encode_pdb_to_tokens(model, pdb_path, device='cpu', level=8):
         for name, tlist in results:
             writer.writerow([name, ' '.join(map(str, tlist))])
     print(f"Processed {len(results)}/{len(pdb_files)} files. Tokens saved to {args.output}")
-

From c185da555e7771c09567faf61f75de5a4d19b50f Mon Sep 17 00:00:00 2001
From: mahdi
Date: Sun, 24 Aug 2025 01:28:16 -0500
Subject: [PATCH 8/8] feat: ensure unique output filenames for PDB files if find already identical names

---
 foldtoken/token_to_pdb.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/foldtoken/token_to_pdb.py b/foldtoken/token_to_pdb.py
index 0704e5f..61bbb45 100644
--- a/foldtoken/token_to_pdb.py
+++ b/foldtoken/token_to_pdb.py
@@ -71,13 +71,16 @@ def tokens_to_pdb(model, tokens, output_path, level=8):
             name = row.get('pdb_file') or row.get('pdb')
             tokens = [int(x) for x in row['tokens'].split()]
             out_name = os.path.splitext(name)[0] + '.pdb'
-            output_path = os.path.join(args.output_dir, out_name)
-            if os.path.exists(output_path):
-                print(f"{output_path} exists, skip")
-                continue
+            # Ensure unique filename by appending a counter if needed
+            candidate = out_name
+            root, ext = os.path.splitext(out_name)
+            counter = 1
+            while os.path.exists(os.path.join(args.output_dir, candidate)):
+                candidate = f"{root}_{counter}{ext}"
+                counter += 1
+            output_path = os.path.join(args.output_dir, candidate)
             try:
                 tokens_to_pdb(model, tokens, output_path, level=args.level)
                 # print(f"Saved {output_path}")
             except Exception as e:
                 print(f"Failed {name}: {e}")
-