diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
new file mode 100644
index 00000000..b4694cd3
--- /dev/null
+++ b/.github/workflows/main.yml
@@ -0,0 +1,95 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: Python application
+
+on:
+ push:
+ branches: [ "main", "github-ci", "fix_tests2" ]
+ pull_request:
+ branches: [ "main", "github-ci" ]
+
+permissions:
+ contents: read
+
+jobs:
+ ppi-scaffolds-test:
+
+ #runs-on: ubuntu-latest
+ runs-on: ubuntu-24.04
+ #container: nvcr.io/nvidia/cuda:12.9.1-cudnn-runtime-ubuntu20.04
+ container: nvcr.io/nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04
+
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python 3.9
+ uses: actions/setup-python@v3
+ with:
+ python-version: "3.9"
+ # - name: Install dependencies
+ # run: |
+ # python -m pip install --upgrade pip
+ # #pip install flake8 pytest
+ # pip install pytest
+ # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+ # - name: Lint with flake8
+ # run: |
+ # # stop the build if there are Python syntax errors or undefined names
+ # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+ # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+ # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+
+ - name: Install dependencies
+ run: |
+ apt-get update -qq && apt-get install -y --no-install-recommends \
+ build-essential \
+ git \
+ curl \
+ wget \
+ ca-certificates
+
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ . $HOME/.local/bin/env
+
+ uv python install 3.9
+ uv venv
+ uv pip install --no-cache-dir -q \
+ dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html \
+ torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 \
+ e3nn==0.3.3 \
+ wandb==0.12.0 \
+ pynvml==11.0.0 \
+ git+https://github.com/NVIDIA/dllogger#egg=dllogger \
+ decorator==5.1.0 \
+ hydra-core==1.3.2 \
+ pyrsistent==0.19.3 \
+ pytest
+
+ uv pip install --no-cache-dir env/SE3Transformer
+ uv pip install --no-cache-dir -e . --no-deps
+ rm -rf ~/.cache # /app/RFdiffusion/tests
+
+ - name: Download weights
+ run: |
+ mkdir models
+ wget -q -O models/Base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt
+ wget -q -O models/Complex_base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt
+ wget -q -O models/Complex_Fold_base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt
+ wget -q -O models/InpaintSeq_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/74f51cfb8b440f50d70878e05361d8f0/InpaintSeq_ckpt.pt
+ wget -q -O models/InpaintSeq_Fold_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/76d00716416567174cdb7ca96e208296/InpaintSeq_Fold_ckpt.pt
+ wget -q -O models/ActiveSite_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/5532d2e1f3a4738decd58b19d633b3c3/ActiveSite_ckpt.pt
+ wget -q -O models/Base_epoch8_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/12fc204edeae5b57713c5ad7dcb97d39/Base_epoch8_ckpt.pt
+ #optional
+ wget -q -O models/Complex_beta_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt
+ #original structure prediction weights
+ wget -q -O models/RF_structure_prediction_weights.pt http://files.ipd.uw.edu/pub/RFdiffusion/1befcb9b28e2f778f53d47f18b7597fa/RF_structure_prediction_weights.pt
+
+ - name: Setup and Run ppi_scaffolds tests
+ run: |
+ tar -xvf examples/ppi_scaffolds_subset.tar.gz -C examples
+ cd tests && uv run python test_diffusion.py
+
+
+ # - name: Test with pytest
+ # run: |
+ # pytest
diff --git a/README.md b/README.md
index 8fb92c51..4186b8b3 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
# RF*diffusion*
-
+
@@ -45,6 +45,7 @@ RFdiffusion is an open source method for structure generation, with or without c
- [Generation of Symmetric Oligomers](#generation-of-symmetric-oligomers)
- [Using Auxiliary Potentials](#using-auxiliary-potentials)
- [Symmetric Motif Scaffolding.](#symmetric-motif-scaffolding)
+ - [RFpeptides macrocycle design](#macrocyclic-peptide-design-with-rfpeptides)
- [A Note on Model Weights](#a-note-on-model-weights)
- [Things you might want to play with at inference time](#things-you-might-want-to-play-with-at-inference-time)
- [Understanding the output files](#understanding-the-output-files)
@@ -466,6 +467,51 @@ Note that the contigs should specify something that is precisely symmetric. Thin
---
+### Macrocyclic peptide design with RFpeptides
+
+We have recently published the RFpeptides protocol for using RFdiffusion to design macrocyclic peptides that bind target proteins with atomic accuracy (Rettie, Juergens, Adebomi, et al., 2025). In this section we briefly outline how to run this inference protocol. We have added two examples for running macrocycle design with the RFpeptides protocol. One for monomeric design, and one for binder design.
+
+NOTE: Until the pull request is merged, you can find this code in the branch `rfpeptides`.
+
+```
+examples/design_macrocyclic_monomer.sh
+examples/design_macrocyclic_binder.sh
+```
+#### RFpeptides binder design
+
+
+To design a macrocyclic peptide to bind a target, the flags needed are very similar to classic binder design, but with two additional flags:
+```
+#!/bin/bash
+
+prefix=./outputs/diffused_binder_cyclic2
+
+# Note that the indices in this pdb file have been
+# shifted by +2 in chain A relative to pdbID 7zkr.
+pdb='./input_pdbs/7zkr_GABARAP.pdb'
+
+num_designs=10
+script="../scripts/run_inference.py"
+$script --config-name base \
+inference.output_prefix=$prefix \
+inference.num_designs=$num_designs \
+'contigmap.contigs=[12-18 A3-117/0]' \
+inference.input_pdb=$pdb \
+inference.cyclic=True \
+diffuser.T=50 \
+inference.cyc_chains='a' \
+ppi.hotspot_res=[\'A51\',\'A52\',\'A50\',\'A48\',\'A62\',\'A65\']
+```
+
+The new flags are `inference.cyclic=True` and `inference.cyc_chains`. Yes, they are somewhat redundant.
+
+`inference.cyclic` simply notifies the program that the user would like to design at least one macrocycle, and `inference.cyc_chains` is just a string containing the letter of every chain you would like to design as a cyclic peptide. In the example above, only chain `A` (`inference.cyc_chains='a'`) is cyclized, but one could do `inference.cyc_chains='abcd'` if they so desired (and the contigs were compatible with this, which the above one is not).
+
+#### RFpeptides monomer design
+For monomer design, you can simply adjust the contigs to only contain a single generated chain e.g., `contigmap.contigs=[12-18]`, keep the `inference.cyclic=True` and `inference.cyc_chains='a'`, and you're off to the races making monomers.
+
+---
+
### A Note on Model Weights
Because of everything we want diffusion to be able to do, there is not *One Model To Rule Them All*. E.g., if you want to run with secondary structure conditioning, this requires a different model than if you don't. Under the hood, we take care of most of this by default - we parse your input and work out the most appropriate checkpoint.
diff --git a/config/inference/base.yaml b/config/inference/base.yaml
index e3798f32..3bb0a5c1 100644
--- a/config/inference/base.yaml
+++ b/config/inference/base.yaml
@@ -21,6 +21,8 @@ inference:
trb_save_ckpt_path: null
schedule_directory_path: null
model_directory_path: null
+ cyclic: False
+ cyc_chains: 'a'
contigmap:
contigs: null
diff --git a/rfdiffusion/Embeddings.py b/rfdiffusion/Embeddings.py
index a052c9db..7b5114ca 100644
--- a/rfdiffusion/Embeddings.py
+++ b/rfdiffusion/Embeddings.py
@@ -4,10 +4,11 @@
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from rfdiffusion.util import get_tips
-from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal
+from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal, find_breaks
from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias
from rfdiffusion.Track_module import PairStr2Pair
import math
+import numpy as np
# Module contains classes and functions to generate initial embeddings
@@ -21,10 +22,34 @@ def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
self.emb = nn.Embedding(self.nbin, d_model)
self.drop = nn.Dropout(p_drop)
- def forward(self, x, idx):
+ def forward(self, x, idx, cyclize=None):
bins = torch.arange(self.minpos, self.maxpos, device=x.device)
seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
#
+
+
+ # adding support for multi-chain cyclic
+ # find chain breaks and label chain ids
+ breaks = find_breaks(idx.squeeze().cpu().numpy(), thresh=35) # NOTE: Hard coded threshold for defining chain breaks here
+ # Typical jump for chainbreaks is +200
+ # Assumes monotonically increasing absolute IDX
+
+ chainids = np.zeros_like(idx.squeeze().cpu().numpy())
+ for i, b in enumerate(breaks):
+ chainids[b:] = i+1
+ chainids = torch.from_numpy(chainids).to(device=idx.device)
+
+ # cyclic peptide
+ if cyclize is not None:
+ for chid in torch.unique(chainids):
+ is_chid = chainids==chid
+ cur_cyclize = cyclize*is_chid
+ cur_mask = cur_cyclize[:,None]*cur_cyclize[None,:] # (L,L)
+ cur_ncyc = torch.sum(cur_cyclize)
+
+ seqsep[:,cur_mask*(seqsep[0]>cur_ncyc//2)] -= cur_ncyc
+ seqsep[:,cur_mask*(seqsep[0]<-cur_ncyc//2)] += cur_ncyc
+
ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
emb = self.emb(ib) #(B, L, L, d_model)
x = x + emb # add relative positional encoding
@@ -56,7 +81,7 @@ def reset_parameter(self):
nn.init.zeros_(self.emb.bias)
- def forward(self, msa, seq, idx):
+ def forward(self, msa, seq, idx, cyclize):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
# - seq: Input Sequence (B, L)
@@ -82,7 +107,7 @@ def forward(self, msa, seq, idx):
right = (seq @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair)
pair = left + right # (B, L, L, d_pair)
- pair = self.pos(pair, idx) # add relative position
+ pair = self.pos(pair, idx, cyclize) # add relative position
# state embedding
# Sergey's one hot trick
diff --git a/rfdiffusion/RoseTTAFoldModel.py b/rfdiffusion/RoseTTAFoldModel.py
index 84fbac43..dcf1a106 100644
--- a/rfdiffusion/RoseTTAFoldModel.py
+++ b/rfdiffusion/RoseTTAFoldModel.py
@@ -69,11 +69,12 @@ def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
t1d=None, t2d=None, xyz_t=None, alpha_t=None,
msa_prev=None, pair_prev=None, state_prev=None,
return_raw=False, return_full=False, return_infer=False,
- use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None):
+ use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None,
+ cyclic_reses=None):
B, N, L = msa_latent.shape[:3]
# Get embeddings
- msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx)
+ msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx, cyclic_reses)
msa_full = self.full_emb(msa_full, seq, idx)
# Do recycling
@@ -101,7 +102,7 @@ def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
is_frozen_residue = motif_mask if self.freeze_track_motif else torch.zeros_like(motif_mask).bool()
msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full, pair, xyz[:,:,:3],
state, idx, use_checkpoint=use_checkpoint,
- motif_mask=is_frozen_residue)
+ motif_mask=is_frozen_residue, cyclic_reses=cyclic_reses)
if return_raw:
# get last structure
diff --git a/rfdiffusion/Track_module.py b/rfdiffusion/Track_module.py
index 12c0863d..27511e5d 100644
--- a/rfdiffusion/Track_module.py
+++ b/rfdiffusion/Track_module.py
@@ -234,7 +234,7 @@ def reset_parameter(self):
nn.init.zeros_(self.embed_e2.bias)
@torch.cuda.amp.autocast(enabled=False)
- def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64, eps=1e-5):
+ def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, cyclic_reses=None, top_k=64, eps=1e-5):
B, N, L = msa.shape[:3]
if motif_mask is None:
@@ -249,7 +249,7 @@ def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64,
node = self.norm_node(self.embed_x(node))
pair = self.norm_edge1(self.embed_e1(pair))
- neighbor = get_seqsep(idx)
+ neighbor = get_seqsep(idx, cyclic_reses)
rbf_feat = rbf(torch.cdist(xyz[:,:,1], xyz[:,:,1]))
pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
pair = self.norm_edge2(self.embed_e2(pair))
@@ -318,18 +318,18 @@ def __init__(self, d_msa=256, d_pair=128,
SE3_param=SE3_param,
p_drop=p_drop)
- def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False):
+ def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False, cyclic_reses=None):
rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]))
if use_checkpoint:
msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat)
- R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask)
+ R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask, cyclic_reses)
else:
msa = self.msa2msa(msa, pair, rbf_feat, state)
pair = self.msa2pair(msa, pair)
pair = self.pair2pair(pair, rbf_feat)
- R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, top_k=0)
+ R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, cyclic_reses=cyclic_reses, top_k=0)
return msa, pair, R, T, state, alpha
@@ -384,7 +384,7 @@ def reset_parameter(self):
self.proj_state2 = init_lecun_normal(self.proj_state2)
nn.init.zeros_(self.proj_state2.bias)
- def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=False, motif_mask=None):
+ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, cyclic_reses=None, use_checkpoint=False, motif_mask=None):
"""
input:
seq: query sequence (B, L)
@@ -425,7 +425,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
state,
idx,
motif_mask=motif_mask,
- use_checkpoint=use_checkpoint)
+ use_checkpoint=use_checkpoint,
+ cyclic_reses=cyclic_reses)
R_s.append(R_in)
T_s.append(T_in)
alpha_s.append(alpha)
@@ -444,7 +445,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
state,
idx,
motif_mask=motif_mask,
- use_checkpoint=use_checkpoint)
+ use_checkpoint=use_checkpoint,
+ cyclic_reses=cyclic_reses)
R_s.append(R_in)
T_s.append(T_in)
alpha_s.append(alpha)
@@ -462,7 +464,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
state,
idx,
top_k=64,
- motif_mask=motif_mask)
+ motif_mask=motif_mask,
+ cyclic_reses=cyclic_reses)
R_s.append(R_in)
T_s.append(T_in)
alpha_s.append(alpha)
diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py
index 3e6505f4..1ad3c7f1 100644
--- a/rfdiffusion/inference/model_runners.py
+++ b/rfdiffusion/inference/model_runners.py
@@ -278,7 +278,29 @@ def sample_init(self, return_forward_trajectory=False):
self.mappings = self.contig_map.get_mappings()
self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:]
self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:]
- self.binderlen = len(self.contig_map.inpaint)
+ self.binderlen = len(self.contig_map.inpaint)
+
+ #######################################
+ ### Resolve cyclic peptide indices  ###
+ #######################################
+ if self._conf.inference.cyclic:
+ if self._conf.inference.cyc_chains is None:
+ # default to all residues being cyclized
+ self.cyclic_reses = ~self.mask_str.to(self.device).squeeze()
+ else:
+ # use cyc_chains arg to determine cyclic_reses mask
+ assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string'
+ cyc_chains = self._conf.inference.cyc_chains
+ cyc_chains = [i.upper() for i in cyc_chains]
+ hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains
+ is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty
+
+ for ch in cyc_chains:
+ ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool()
+ is_cyclized[ch_mask] = True # set this whole chain to be cyclic
+ self.cyclic_reses = is_cyclized
+ else:
+ self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
####################
### Get Hotspots ###
@@ -675,7 +697,8 @@ def sample_step(self, *, t, x_t, seq_init, final_step):
state_prev = None,
t=torch.tensor(t),
return_infer=True,
- motif_mask=self.diffusion_mask.squeeze().to(self.device))
+ motif_mask=self.diffusion_mask.squeeze().to(self.device),
+ cyclic_reses=self.cyclic_reses)
if self.symmetry is not None and self.inf_conf.symmetric_self_cond:
px0 = self.symmetrise_prev_pred(px0=px0,seq_in=seq_in, alpha=alpha)[:,:,:3]
diff --git a/rfdiffusion/util_module.py b/rfdiffusion/util_module.py
index 20ba2dc4..839b7614 100644
--- a/rfdiffusion/util_module.py
+++ b/rfdiffusion/util_module.py
@@ -7,6 +7,13 @@
import dgl
from rfdiffusion.util import base_indices, RTs_by_torsion, xyzs_in_base_frame, rigid_from_3_points
+
+def find_breaks(ix, thresh=35):
+ # finds positions in ix where the jump is greater than 100
+ breaks = np.where(np.diff(ix) > thresh)[0]
+ return np.array(breaks)+1
+
+
def init_lecun_normal(module):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
@@ -91,7 +98,7 @@ def rbf(D):
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
return RBF
-def get_seqsep(idx):
+def get_seqsep(idx, cyclic=None):
'''
Input:
- idx: residue indices of given sequence (B,L)
@@ -104,6 +111,25 @@ def get_seqsep(idx):
neigh = torch.abs(seqsep)
neigh[neigh > 1] = 0.0 # if bonded -- 1.0 / else 0.0
neigh = sign * neigh
+
+ # add cyclic edges
+ breaks = find_breaks(idx.squeeze().cpu().numpy())
+ chainids = np.zeros_like(idx.squeeze().cpu().numpy())
+ for i, b in enumerate(breaks):
+ chainids[b:] = i+1
+ chainids = torch.from_numpy(chainids).to(device=idx.device)
+
+ # add cyclic edges with multiple chains
+ if (cyclic is not None):
+ for chid in torch.unique(chainids):
+ is_chid = chainids==chid
+ cur_cyclic = cyclic*is_chid
+ cur_cres = cur_cyclic.nonzero()
+
+ if cur_cyclic.sum()>=2:
+ neigh[:,cur_cres[-1],cur_cres[0]] = 1
+ neigh[:,cur_cres[0],cur_cres[-1]] = -1
+
return neigh.unsqueeze(-1)
def make_full_graph(xyz, pair, idx, top_k=64, kmin=9):