95 changes: 95 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python application

on:
push:
branches: [ "main", "github-ci", "fix_tests2" ]
pull_request:
branches: [ "main", "github-ci" ]

permissions:
contents: read

jobs:
ppi-scaffolds-test:

#runs-on: ubuntu-latest
runs-on: ubuntu-24.04
#container: nvcr.io/nvidia/cuda:12.9.1-cudnn-runtime-ubuntu20.04
container: nvcr.io/nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v3
with:
python-version: "3.9"
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# #pip install flake8 pytest
# pip install pytest
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
# - name: Lint with flake8
# run: |
# # stop the build if there are Python syntax errors or undefined names
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

- name: Install dependencies
run: |
apt-get update -qq && apt-get install -y --no-install-recommends \
build-essential \
git \
curl \
wget \
ca-certificates

curl -LsSf https://astral.sh/uv/install.sh | sh
. $HOME/.local/bin/env

uv python install 3.9
uv venv
uv pip install --no-cache-dir -q \
dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html \
torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 \
e3nn==0.3.3 \
wandb==0.12.0 \
pynvml==11.0.0 \
git+https://github.com/NVIDIA/dllogger#egg=dllogger \
decorator==5.1.0 \
hydra-core==1.3.2 \
pyrsistent==0.19.3 \
pytest

uv pip install --no-cache-dir env/SE3Transformer
uv pip install --no-cache-dir -e . --no-deps
rm -rf ~/.cache # /app/RFdiffusion/tests

- name: Download weights
run: |
mkdir models
wget -q -O models/Base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt
wget -q -O models/Complex_base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt
wget -q -O models/Complex_Fold_base_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt
wget -q -O models/InpaintSeq_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/74f51cfb8b440f50d70878e05361d8f0/InpaintSeq_ckpt.pt
wget -q -O models/InpaintSeq_Fold_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/76d00716416567174cdb7ca96e208296/InpaintSeq_Fold_ckpt.pt
wget -q -O models/ActiveSite_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/5532d2e1f3a4738decd58b19d633b3c3/ActiveSite_ckpt.pt
wget -q -O models/Base_epoch8_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/12fc204edeae5b57713c5ad7dcb97d39/Base_epoch8_ckpt.pt
#optional
wget -q -O models/Complex_beta_ckpt.pt http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt
#original structure prediction weights
wget -q -O models/RF_structure_prediction_weights.pt http://files.ipd.uw.edu/pub/RFdiffusion/1befcb9b28e2f778f53d47f18b7597fa/RF_structure_prediction_weights.pt

- name: Setup and Run ppi_scaffolds tests
run: |
tar -xvf examples/ppi_scaffolds_subset.tar.gz -C examples
cd tests && uv run python test_diffusion.py


# - name: Test with pytest
# run: |
# pytest
48 changes: 47 additions & 1 deletion README.md
@@ -1,5 +1,5 @@
# RF*diffusion*

.
<!--
<img width="1115" alt="Screen Shot 2023-01-19 at 5 56 33 PM" src="https://user-images.githubusercontent.com/56419265/213588200-f8f44dba-276e-4dd2-b844-15acc441458d.png">
-->
@@ -45,6 +45,7 @@ RFdiffusion is an open source method for structure generation, with or without c
- [Generation of Symmetric Oligomers](#generation-of-symmetric-oligomers)
- [Using Auxiliary Potentials](#using-auxiliary-potentials)
- [Symmetric Motif Scaffolding.](#symmetric-motif-scaffolding)
- [RFpeptides macrocycle design](#macrocyclic-peptide-design-with-rfpeptides)
- [A Note on Model Weights](#a-note-on-model-weights)
- [Things you might want to play with at inference time](#things-you-might-want-to-play-with-at-inference-time)
- [Understanding the output files](#understanding-the-output-files)
@@ -466,6 +467,51 @@ Note that the contigs should specify something that is precisely symmetric. Thin

---

### Macrocyclic peptide design with RFpeptides
<img src="./img/rfpeptides_fig1.png" alt="alt text" width="400px" align="right"/>
We have recently published the RFpeptides protocol, which uses RFdiffusion to design macrocyclic peptides that bind target proteins with atomic accuracy (Rettie, Juergens, Adebomi, et al., 2025). In this section we briefly outline how to run this inference protocol. We have added two examples of macrocycle design with RFpeptides: one for monomeric design and one for binder design.

NOTE: Until the pull request is merged, you can find this code in the branch `rfpeptides`.

```
examples/design_macrocyclic_monomer.sh
examples/design_macrocyclic_binder.sh
```
#### RFpeptides binder design
<img src="./img/rfpeptides_binder.png" alt="alt text" width="1100" align="center"/>

To design a macrocyclic peptide to bind a target, the setup is very similar to classic binder design, but with two additional flags:
```
#!/bin/bash

prefix=./outputs/diffused_binder_cyclic2

# Note that the indices in this pdb file have been
# shifted by +2 in chain A relative to pdbID 7zkr.
pdb='./input_pdbs/7zkr_GABARAP.pdb'

num_designs=10
script="../scripts/run_inference.py"
$script --config-name base \
inference.output_prefix=$prefix \
inference.num_designs=$num_designs \
'contigmap.contigs=[12-18 A3-117/0]' \
inference.input_pdb=$pdb \
inference.cyclic=True \
diffuser.T=50 \
inference.cyc_chains='a' \
ppi.hotspot_res=[\'A51\',\'A52\',\'A50\',\'A48\',\'A62\',\'A65\']
```

The new flags are `inference.cyclic=True` and `inference.cyc_chains`. Yes, they are somewhat redundant.

`inference.cyclic` simply notifies the program that the user would like to design at least one macrocycle, and `inference.cyc_chains` is a string containing the letter of every chain you would like to design as a cyclic peptide. In the example above, only chain `A` (`inference.cyc_chains='a'`) is cyclized, but one could do `inference.cyc_chains='abcd'` if they so desired (provided the contigs were compatible with this, which the one above is not).
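As a rough illustration of what the chain string does: it is mapped, case-insensitively, to a per-residue boolean mask over the output chains. The sketch below mimics the logic added in `model_runners.py`; here `hal_idx` is a stand-in for `contig_map.hal`, and its `(chain, residue)` tuple format is an assumption for illustration.

```python
import torch

def cyc_chain_mask(cyc_chains: str, hal_idx):
    """Boolean mask: True for residues whose chain letter appears
    (case-insensitively) in cyc_chains."""
    chains = {c.upper() for c in cyc_chains}
    return torch.tensor([idx[0].upper() in chains for idx in hal_idx])

# Toy example: a 3-residue designed chain A and a 2-residue target chain B.
hal_idx = [("A", 1), ("A", 2), ("A", 3), ("B", 1), ("B", 2)]
print(cyc_chain_mask("a", hal_idx))  # only chain A residues are marked
```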

#### RFpeptides monomer design
For monomer design, you can simply adjust the contigs to only contain a single generated chain e.g., `contigmap.contigs=[12-18]`, keep the `inference.cyclic=True` and `inference.cyc_chains='a'`, and you're off to the races making monomers.
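For reference, a monomer script in the same style as the binder example above might look like the following (the output prefix and design count are placeholders, not the contents of `examples/design_macrocyclic_monomer.sh`):

```
#!/bin/bash
# Hypothetical monomer variant of the binder script above. With a
# single generated chain in the contig, no input pdb or hotspots are needed.
script="../scripts/run_inference.py"
$script --config-name base \
    inference.output_prefix=./outputs/diffused_cyclic_monomer \
    inference.num_designs=10 \
    'contigmap.contigs=[12-18]' \
    inference.cyclic=True \
    inference.cyc_chains='a' \
    diffuser.T=50
```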

---

### A Note on Model Weights

Because of everything we want diffusion to be able to do, there is not *One Model To Rule Them All*. E.g., if you want to run with secondary structure conditioning, this requires a different model than if you don't. Under the hood, we take care of most of this by default - we parse your input and work out the most appropriate checkpoint.
2 changes: 2 additions & 0 deletions config/inference/base.yaml
@@ -21,6 +21,8 @@ inference:
trb_save_ckpt_path: null
schedule_directory_path: null
model_directory_path: null
cyclic: False
cyc_chains: 'a'

contigmap:
contigs: null
33 changes: 29 additions & 4 deletions rfdiffusion/Embeddings.py
@@ -4,10 +4,11 @@
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from rfdiffusion.util import get_tips
from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal
from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal, find_breaks
from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias
from rfdiffusion.Track_module import PairStr2Pair
import math
import numpy as np

# Module contains classes and functions to generate initial embeddings

@@ -21,10 +22,34 @@ def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
self.emb = nn.Embedding(self.nbin, d_model)
self.drop = nn.Dropout(p_drop)

def forward(self, x, idx):
def forward(self, x, idx, cyclize=None):
bins = torch.arange(self.minpos, self.maxpos, device=x.device)
seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
#


# adding support for multi-chain cyclic
# find chain breaks and label chain ids
breaks = find_breaks(idx.squeeze().cpu().numpy(), thresh=35) # NOTE: Hard coded threshold for defining chain breaks here
# Typical jump for chainbreaks is +200
# Assumes monotonically increasing absolute IDX

chainids = np.zeros_like(idx.squeeze().cpu().numpy())
for i, b in enumerate(breaks):
chainids[b:] = i+1
chainids = torch.from_numpy(chainids).to(device=idx.device)
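The actual `find_breaks` lives in `rfdiffusion.util_module`; a minimal sketch consistent with the comments above (a jump larger than `thresh` in the monotonically increasing absolute index starts a new chain, with chain breaks typically encoded as a +200 jump) might be:

```python
# Hedged sketch of the assumed behavior of find_breaks, not the
# repository implementation.
import numpy as np

def find_breaks(idx, thresh=35):
    """Positions i at which a new chain is assumed to start."""
    return (np.where(np.diff(idx) > thresh)[0] + 1).tolist()

idx = np.array([1, 2, 3, 204, 205, 206])  # two chains, +200 jump between them
print(find_breaks(idx))  # chain break detected at position 3
```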

# cyclic peptide
if cyclize is not None:
for chid in torch.unique(chainids):
is_chid = chainids==chid
cur_cyclize = cyclize*is_chid
cur_mask = cur_cyclize[:,None]*cur_cyclize[None,:] # (L,L)
cur_ncyc = torch.sum(cur_cyclize)

seqsep[:,cur_mask*(seqsep[0]>cur_ncyc//2)] -= cur_ncyc
seqsep[:,cur_mask*(seqsep[0]<-cur_ncyc//2)] += cur_ncyc

ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
emb = self.emb(ib) #(B, L, L, d_model)
x = x + emb # add relative positional encoding
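A self-contained sketch of the wrap applied above, reduced to a single chain: within a cyclized chain of length n, any separation beyond n//2 is shifted by n, so the chain's termini look adjacent to the relative-position embedding. (The code above applies this per chain via `chainids`; the single-chain case is an illustrative simplification.)

```python
import torch

def cyclic_seqsep(idx, cyclize):
    """Sequence separation with wrap-around inside one cyclized chain."""
    seqsep = idx[:, None] - idx[None, :]        # (L, L) pairwise separations
    mask = cyclize[:, None] & cyclize[None, :]  # pairs inside the cyclic chain
    n = int(cyclize.sum())
    seqsep[mask & (seqsep > n // 2)] -= n       # wrap large positive separations
    seqsep[mask & (seqsep < -n // 2)] += n      # wrap large negative separations
    return seqsep

idx = torch.arange(6)                  # a single 6-residue chain
cyc = torch.ones(6, dtype=torch.bool)  # the whole chain is cyclized
s = cyclic_seqsep(idx, cyc)
print(s[5, 0], s[0, 5])  # the termini now look like immediate neighbors
```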
@@ -56,7 +81,7 @@ def reset_parameter(self):

nn.init.zeros_(self.emb.bias)

def forward(self, msa, seq, idx):
def forward(self, msa, seq, idx, cyclize):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
# - seq: Input Sequence (B, L)
@@ -82,7 +107,7 @@ def forward(self, msa, seq, idx):
right = (seq @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair)

pair = left + right # (B, L, L, d_pair)
pair = self.pos(pair, idx) # add relative position
pair = self.pos(pair, idx, cyclize) # add relative position

# state embedding
# Sergey's one hot trick
7 changes: 4 additions & 3 deletions rfdiffusion/RoseTTAFoldModel.py
@@ -69,11 +69,12 @@ def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
t1d=None, t2d=None, xyz_t=None, alpha_t=None,
msa_prev=None, pair_prev=None, state_prev=None,
return_raw=False, return_full=False, return_infer=False,
use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None):
use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None,
cyclic_reses=None):

B, N, L = msa_latent.shape[:3]
# Get embeddings
msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx)
msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx, cyclic_reses)
msa_full = self.full_emb(msa_full, seq, idx)

# Do recycling
@@ -101,7 +102,7 @@ def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
is_frozen_residue = motif_mask if self.freeze_track_motif else torch.zeros_like(motif_mask).bool()
msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full, pair, xyz[:,:,:3],
state, idx, use_checkpoint=use_checkpoint,
motif_mask=is_frozen_residue)
motif_mask=is_frozen_residue, cyclic_reses=cyclic_reses)

if return_raw:
# get last structure
21 changes: 12 additions & 9 deletions rfdiffusion/Track_module.py
@@ -234,7 +234,7 @@ def reset_parameter(self):
nn.init.zeros_(self.embed_e2.bias)

@torch.cuda.amp.autocast(enabled=False)
def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64, eps=1e-5):
def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, cyclic_reses=None, top_k=64, eps=1e-5):
B, N, L = msa.shape[:3]

if motif_mask is None:
@@ -249,7 +249,7 @@ def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64,
node = self.norm_node(self.embed_x(node))
pair = self.norm_edge1(self.embed_e1(pair))

neighbor = get_seqsep(idx)
neighbor = get_seqsep(idx, cyclic_reses)
rbf_feat = rbf(torch.cdist(xyz[:,:,1], xyz[:,:,1]))
pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
pair = self.norm_edge2(self.embed_e2(pair))
@@ -318,18 +318,18 @@ def __init__(self, d_msa=256, d_pair=128,
SE3_param=SE3_param,
p_drop=p_drop)

def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False):
def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False, cyclic_reses=None):
rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]))
if use_checkpoint:
msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat)
R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask)
R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask, cyclic_reses)
else:
msa = self.msa2msa(msa, pair, rbf_feat, state)
pair = self.msa2pair(msa, pair)
pair = self.pair2pair(pair, rbf_feat)
R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, top_k=0)
R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, cyclic_reses=cyclic_reses, top_k=0)

return msa, pair, R, T, state, alpha

@@ -384,7 +384,7 @@ def reset_parameter(self):
self.proj_state2 = init_lecun_normal(self.proj_state2)
nn.init.zeros_(self.proj_state2.bias)

def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=False, motif_mask=None):
def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, cyclic_reses=None, use_checkpoint=False, motif_mask=None):
"""
input:
seq: query sequence (B, L)
@@ -425,7 +425,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
state,
idx,
motif_mask=motif_mask,
use_checkpoint=use_checkpoint)
use_checkpoint=use_checkpoint,
cyclic_reses=cyclic_reses)
R_s.append(R_in)
T_s.append(T_in)
alpha_s.append(alpha)
@@ -444,7 +445,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
state,
idx,
motif_mask=motif_mask,
use_checkpoint=use_checkpoint)
use_checkpoint=use_checkpoint,
cyclic_reses=cyclic_reses)
R_s.append(R_in)
T_s.append(T_in)
alpha_s.append(alpha)
@@ -462,7 +464,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
state,
idx,
top_k=64,
motif_mask=motif_mask)
motif_mask=motif_mask,
cyclic_reses=cyclic_reses)
R_s.append(R_in)
T_s.append(T_in)
alpha_s.append(alpha)
27 changes: 25 additions & 2 deletions rfdiffusion/inference/model_runners.py
@@ -278,7 +278,29 @@ def sample_init(self, return_forward_trajectory=False):
self.mappings = self.contig_map.get_mappings()
self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:]
self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:]
self.binderlen = len(self.contig_map.inpaint)
self.binderlen = len(self.contig_map.inpaint)

######################################
### Resolve cyclic peptide indices ###
######################################
if self._conf.inference.cyclic:
if self._conf.inference.cyc_chains is None:
# default to all residues being cyclized
self.cyclic_reses = ~self.mask_str.to(self.device).squeeze()
else:
# use cyc_chains arg to determine cyclic_reses mask
assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string'
cyc_chains = self._conf.inference.cyc_chains
cyc_chains = [i.upper() for i in cyc_chains]
hal_idx = self.contig_map.hal # pdb indices of the output, which encode chain identity
is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty

for ch in cyc_chains:
ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool()
is_cyclized[ch_mask] = True # set this whole chain to be cyclic
self.cyclic_reses = is_cyclized
else:
self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()

####################
### Get Hotspots ###
@@ -675,7 +697,8 @@ def sample_step(self, *, t, x_t, seq_init, final_step):
state_prev = None,
t=torch.tensor(t),
return_infer=True,
motif_mask=self.diffusion_mask.squeeze().to(self.device))
motif_mask=self.diffusion_mask.squeeze().to(self.device),
cyclic_reses=self.cyclic_reses)

if self.symmetry is not None and self.inf_conf.symmetric_self_cond:
px0 = self.symmetrise_prev_pred(px0=px0,seq_in=seq_in, alpha=alpha)[:,:,:3]