From e84a039996a65d23ac2a11320f123a4e9ad039a9 Mon Sep 17 00:00:00 2001
From: "Henry Dieckhaus (dieckhau)"
Date: Thu, 17 Jul 2025 15:29:40 -0400
Subject: [PATCH] Bugfix for errors introduced by the RFpeptides update
 (GitHub issues #272 and #273). ScaffoldedSampler was passing the wrong
 config variable to BlockAdjacency and never defined the cyclic_reses
 mask that RFpeptides expects.

---
 rfdiffusion/inference/model_runners.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py
index c176925..469bfda 100644
--- a/rfdiffusion/inference/model_runners.py
+++ b/rfdiffusion/inference/model_runners.py
@@ -809,7 +809,7 @@ def __init__(self, conf: DictConfig):
         else:
             # initialize BlockAdjacency sampling class
             assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss"
-            self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs)
+            self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs)
 
 
        #################################################
@@ -983,6 +983,29 @@ def sample_init(self):
 
         xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:])
 
+
+        ######################################
+        ### Resolve cyclic peptide indices ###
+        ######################################
+        if self._conf.inference.cyclic:
+            if self._conf.inference.cyc_chains is None:
+                # default to all residues being cyclized
+                self.cyclic_reses = ~self.mask_str.to(self.device).squeeze()
+            else:
+                # use the cyc_chains arg to build the cyclic_reses mask
+                assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be a string'
+                cyc_chains = self._conf.inference.cyc_chains
+                cyc_chains = [i.upper() for i in cyc_chains]
+                hal_idx = self.contig_map.hal  # pdb indices of the output; these carry chain identity
+                is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()  # initially all False
+
+                for ch in cyc_chains:
+                    ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool()
+                    is_cyclized[ch_mask] = True  # set this whole chain to be cyclic
+                self.cyclic_reses = is_cyclized
+        else:
+            self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
+
         return xT, seq_T
 
     def _preprocess(self, seq, xyz_t, t):
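
Note for reviewers: below is a minimal, self-contained sketch of the cyclic-mask logic added in the second hunk, runnable outside the sampler. The toy hal_idx and mask_str values are stand-ins for self.contig_map.hal and self.mask_str; the patch's own comments suggest that hal yields (chain, residue index) pairs and that mask_str flags fixed motif positions, but treat those semantics as assumptions here.

    import torch

    # Illustrative stand-ins for sampler state (assumed shapes/semantics):
    hal_idx = [("A", 1), ("A", 2), ("A", 3), ("B", 1), ("B", 2)]  # like contig_map.hal
    mask_str = torch.tensor([[False, False, False, True, True]])  # True = fixed motif residue (assumed)

    # Default branch (inference.cyc_chains is None): cyclize every non-fixed residue.
    cyclic_reses = ~mask_str.squeeze()
    print(cyclic_reses)  # tensor([ True,  True,  True, False, False])

    # Chain-selection branch (inference.cyc_chains="a"): mark whole chains as cyclic.
    cyc_chains = [c.upper() for c in "a"]                      # same case normalization as the patch
    is_cyclized = torch.zeros_like(mask_str).bool().squeeze()  # start all False
    for ch in cyc_chains:
        ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool()
        is_cyclized[ch_mask] = True                            # whole chain A becomes cyclic
    print(is_cyclized)   # tensor([ True,  True,  True, False, False])

One consequence worth noting: the chain-selection branch ignores mask_str, so a cyclized chain that contains motif residues has those residues marked cyclic as well. That matches the patch as written.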
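
With both hunks applied, a scaffold-guided run with a cyclized chain should go end to end. A hypothetical invocation for reference, combining the scaffoldguided options documented in the RFdiffusion README with the inference.cyclic / inference.cyc_chains settings this patch reads (the script path, scaffold_dir, and output_prefix values are placeholders):

    ./scripts/run_inference.py \
        scaffoldguided.scaffoldguided=True \
        scaffoldguided.scaffold_dir=./examples/ppi_scaffolds/ \
        inference.cyclic=True \
        inference.cyc_chains=A \
        inference.num_designs=10 \
        inference.output_prefix=outputs/cyclic_scaffold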