Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions foldtoken/installation.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/env bash
set -euo pipefail

# Update build basics
python -m pip install -U pip setuptools wheel packaging

# ---------------------------------------------------------------------
# 1) PyTorch 2.0.1 + CUDA 11.7 (official wheels)
# ---------------------------------------------------------------------
pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 torchaudio==2.0.2+cu117 \
--index-url https://download.pytorch.org/whl/cu117

# ---------------------------------------------------------------------
# 2) Core numerics that some wheels depend on
# ---------------------------------------------------------------------
pip install "scipy>=1.11,<1.12"

# ---------------------------------------------------------------------
# 3) PyTorch-Geometric CUDA wheels (Torch 2.0.1 + CU117)
# Hosted on data.pyg.org — disable PyPI so we don’t grab CPU builds.
# ---------------------------------------------------------------------
PYG_INDEX="https://data.pyg.org/whl/torch-2.0.1+cu117.html"

pip install --no-index --find-links "$PYG_INDEX" torch-scatter==2.1.2+pt20cu117
pip install --no-index --find-links "$PYG_INDEX" torch-sparse==0.6.17+pt20cu117
pip install --no-index --find-links "$PYG_INDEX" torch-cluster==1.6.3+pt20cu117
pip install --no-index --find-links "$PYG_INDEX" torch-spline-conv==1.2.2+pt20cu117
pip install --no-index --find-links "$PYG_INDEX" pyg-lib==0.2.0+pt20cu117
pip install --find-links https://data.pyg.org/whl/torch-2.0.1+cu117.html torch-geometric==2.3.1

# ---------------------------------------------------------------------
# 4) Optional extras and utilities
# ---------------------------------------------------------------------
pip install flash-attn --no-build-isolation # Ampere+ GPUs
pip install pytorch-lightning==1.9.0 # training framework
pip install omegaconf==2.3.0 # config library
pip install tqdm
99 changes: 99 additions & 0 deletions foldtoken/pdb_to_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# inference.py: Encode a PDB structure into tokens using FoldToken4
import os
import torch
import argparse
import tqdm
import csv
from omegaconf import OmegaConf
from model_interface import MInterface
from src.chroma.data import Protein


def load_model(config_path, checkpoint_path, device='cpu'):
"""Load the FoldToken4 model from checkpoint on given device"""
print(f"Loading config from {config_path}")
config = OmegaConf.load(config_path)
config = OmegaConf.to_container(config, resolve=True)

print("Initializing model...")
model = MInterface(**config)

print(f"Loading checkpoint from {checkpoint_path}")
print(f"Using device: {device}")
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
# adjust keys if wrapped in DataParallel
for key in list(checkpoint.keys()):
if '_forward_module.' in key:
checkpoint[key.replace('_forward_module.', '')] = checkpoint[key]
del checkpoint[key]

model.load_state_dict(checkpoint, strict=False)
model = model.to(device)
print("Model loaded successfully")
return model


def encode_pdb_to_tokens(model, pdb_path, device='cpu', level=8):
"""Encode a PDB file into tokens using FoldToken4 on given device"""
protein = Protein(pdb_path, device=device)
# determine number of residues via XCS representation
X, C, S = protein.to_XCS()

with torch.no_grad():
vq_code = model.encode_protein(protein, level=level)[1]
return vq_code


if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Encode all PDBs in a directory to tokens CSV.")
parser.add_argument('--pdb_dir', required=True, help="Directory containing PDB files.")
parser.add_argument('--output', default="output.csv", help="Output CSV file path.")
parser.add_argument('--device', choices=['cpu','cuda'], default='cpu', help="Compute device to use.")
args = parser.parse_args()

# Validate device
device = args.device
if device == 'cuda' and not torch.cuda.is_available():
print("CUDA not available, falling back to CPU")
device = 'cpu'

# Config and checkpoint paths
config_path = os.path.join(os.path.dirname(__file__), 'checkpoint', 'FT4', 'config.yaml')
checkpoint_path = os.path.join(os.path.dirname(__file__), 'checkpoint', 'FT4', 'ckpt.pth')
if not os.path.exists(config_path): raise FileNotFoundError(f"Config not found: {config_path}")
if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

model = load_model(config_path, checkpoint_path, device)

# Gather PDB files
# Recursively collect .pdb files from the given directory and all subdirectories
pdb_files = []
for root, _, files in os.walk(args.pdb_dir):
for f in files:
if f.lower().endswith('.pdb'):
pdb_files.append(os.path.join(root, f))
results = []
for pdb_path in tqdm.tqdm(pdb_files, desc="Encoding PDBs", unit="file"):
basename = os.path.basename(pdb_path)
try:
protein = Protein(pdb_path, device=device)
X, C, S = protein.to_XCS()
except Exception as e:
print(f"Skipped {basename}: load error {e}")
continue
# filter short or highly masked proteins
if X.shape[1] < 5 or (C != -1).sum() < 5:
print(f"Skipped {basename}: too few residues")
continue
with torch.no_grad():
vq_code = model.encode_protein(protein, level=8)[1]
token_list = vq_code.cpu().numpy().tolist() if hasattr(vq_code, 'cpu') else list(vq_code)
results.append((basename, token_list))

# write results
with open(args.output, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['pdb_file', 'tokens'])
for name, tlist in results:
writer.writerow([name, ' '.join(map(str, tlist))])
print(f"Processed {len(results)}/{len(pdb_files)} files. Tokens saved to {args.output}")
86 changes: 86 additions & 0 deletions foldtoken/token_to_pdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
import torch
import argparse
import csv

import tqdm
from omegaconf import OmegaConf
from model_interface import MInterface

# Load FoldToken4 model for decoding

def load_model(device='cpu'):
cwd = os.path.dirname(__file__)
config_path = os.path.join(cwd, 'checkpoint', 'FT4', 'config.yaml')
checkpoint_path = os.path.join(cwd, 'checkpoint', 'FT4', 'ckpt.pth')
config = OmegaConf.load(config_path)
config = OmegaConf.to_container(config, resolve=True)

model = MInterface(**config)
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
# adjust keys if wrapped in DataParallel
for key in list(checkpoint.keys()):
if '_forward_module.' in key:
checkpoint[key.replace('_forward_module.', '')] = checkpoint[key]
del checkpoint[key]
model.load_state_dict(checkpoint, strict=False)
model = model.to(device)
model.eval()
return model

# Convert token list to PDB and save

def tokens_to_pdb(model, tokens, output_path, level=8):
device = next(model.parameters()).device
vq_codes = torch.tensor(tokens, dtype=torch.long, device=device)
# get latent embeddings
h_V = model.model.vq.embed_id(vq_codes, level)
# simple chain encoding
chain_encoding = torch.ones_like(vq_codes, device=device)
# decode to protein object
protein = model.model.decoding(h_V, chain_encoding)
# save PDB
protein.to(output_path)

# CLI

if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Decode tokens CSV to PDB files.")
parser.add_argument('--token_csv', required=True, help="CSV file with columns ['pdb_file','tokens'].")
parser.add_argument('--output_dir', required=True, help="Directory to save reconstructed PDBs.")
parser.add_argument('--device', choices=['cpu','cuda'], default='cpu', help="Compute device.")
parser.add_argument('--level', type=int, default=8, help="Quantization level.")
args = parser.parse_args()

# Decoding currently requires CUDA-enabled build
if not torch.cuda.is_available():
print("Error: decoding tokens requires CUDA-enabled PyTorch. Please install the GPU version and run with --device cuda.")
exit(1)
device = 'cuda'
if args.device == 'cuda' and device == 'cpu':
print("CUDA not available, using CPU which may not support decoding.")
elif args.device == 'cpu' and device == 'cuda':
print("CUDA available, overriding to use CUDA for decoding.")

model = load_model(device)
os.makedirs(args.output_dir, exist_ok=True)

with open(args.token_csv, newline='') as f:
reader = csv.DictReader(f)
for row in tqdm.tqdm(reader, total=sum(1 for _ in open(args.token_csv))):
name = row.get('pdb_file') or row.get('pdb')
tokens = [int(x) for x in row['tokens'].split()]
out_name = os.path.splitext(name)[0] + '.pdb'
# Ensure unique filename by appending a counter if needed
candidate = out_name
root, ext = os.path.splitext(out_name)
counter = 1
while os.path.exists(os.path.join(args.output_dir, candidate)):
candidate = f"{root}_{counter}{ext}"
counter += 1
output_path = os.path.join(args.output_dir, candidate)
try:
tokens_to_pdb(model, tokens, output_path, level=args.level)
# print(f"Saved {output_path}")
except Exception as e:
print(f"Failed {name}: {e}")