diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index c176925..469bfda 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -809,7 +809,7 @@ def __init__(self, conf: DictConfig): else: # initialize BlockAdjacency sampling class assert all(x is None for x in (conf.contigmap.inpaint_str_helix, conf.contigmap.inpaint_str_strand, conf.contigmap.inpaint_str_loop)), "can't provide scaffold_dir if you're also specifying per-residue ss" - self.blockadjacency = iu.BlockAdjacency(conf.scaffoldguided, conf.inference.num_designs) + self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs) ################################################# @@ -983,6 +983,29 @@ def sample_init(self): xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:]) + + ####################################### + ### Resolve cyclic peptide indicies ### + ####################################### + if self._conf.inference.cyclic: + if self._conf.inference.cyc_chains is None: + # default to all residues being cyclized + self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() + else: + # use cyc_chains arg to determine cyclic_reses mask + assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string' + cyc_chains = self._conf.inference.cyc_chains + cyc_chains = [i.upper() for i in cyc_chains] + hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains + is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty + + for ch in cyc_chains: + ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() + is_cyclized[ch_mask] = True # set this whole chain to be cyclic + self.cyclic_reses = is_cyclized + else: + self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() + return xT, seq_T def _preprocess(self, seq, xyz_t, t):