diff --git a/.gitignore b/.gitignore
index 0aaad68..76428c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,5 +5,5 @@ __pycache__
 __test*
 merged_lora*
 wandb
-exps
+exps*
 .vscode
\ No newline at end of file
diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py
index 7cc8d92..ba58ae6 100644
--- a/lora_diffusion/cli_lora_pti.py
+++ b/lora_diffusion/cli_lora_pti.py
@@ -36,11 +36,13 @@
     PivotalTuningDatasetCapation,
     extract_lora_ups_down,
     inject_trainable_lora,
+    inject_trainable_lora_extended,
     inspect_lora,
     save_lora_weight,
     save_all,
     prepare_clip_model_sets,
     evaluate_pipe,
+    UNET_EXTENDED_TARGET_REPLACE,
 )
 
 
@@ -418,6 +420,8 @@ def perform_tuning(
     placeholder_tokens,
     save_path,
     lr_scheduler_lora,
+    lora_unet_target_modules,
+    lora_clip_target_modules,
 ):
 
     progress_bar = tqdm(range(num_steps))
@@ -467,6 +471,8 @@ def perform_tuning(
                     save_path=os.path.join(
                         save_path, f"step_{global_step}.safetensors"
                     ),
+                    target_replace_module_text=lora_clip_target_modules,
+                    target_replace_module_unet=lora_unet_target_modules,
                 )
                 moved = (
                     torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
@@ -521,11 +527,12 @@ def train(
     lora_rank: int = 4,
     lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
     lora_clip_target_modules={"CLIPAttention"},
+    use_extended_lora: bool = False,
     clip_ti_decay: bool = True,
     learning_rate_unet: float = 1e-4,
     learning_rate_text: float = 1e-5,
     learning_rate_ti: float = 5e-4,
-    continue_inversion: bool = True,
+    continue_inversion: bool = False,
     continue_inversion_lr: Optional[float] = None,
     use_face_segmentation_condition: bool = False,
     scale_lr: bool = False,
@@ -690,9 +697,21 @@ def train(
         del ti_optimizer
 
     # Next perform Tuning with LoRA:
-    unet_lora_params, _ = inject_trainable_lora(
-        unet, r=lora_rank, target_replace_module=lora_unet_target_modules
-    )
+    if not use_extended_lora:
+        unet_lora_params, _ = inject_trainable_lora(
+            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
+        )
+    else:
+        print("USING EXTENDED UNET!!!")
+        lora_unet_target_modules = (
+            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
+        )
+        print("Will replace modules: ", lora_unet_target_modules)
+
+        unet_lora_params, _ = inject_trainable_lora_extended(
+            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
+        )
+    print(f"PTI : has {len(unet_lora_params)} lora")
 
     print("Before training:")
     inspect_lora(unet)
@@ -720,7 +739,8 @@ def train(
         )
         for param in params_to_freeze:
             param.requires_grad = False
-
+    else:
+        text_encoder.requires_grad_(False)
     if train_text_encoder:
         text_encoder_lora_params, _ = inject_trainable_lora(
             text_encoder,
@@ -763,6 +783,8 @@ def train(
         placeholder_token_ids=placeholder_token_ids,
         save_path=output_dir,
         lr_scheduler_lora=lr_scheduler_lora,
+        lora_unet_target_modules=lora_unet_target_modules,
+        lora_clip_target_modules=lora_clip_target_modules,
     )
 
 
diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index 6e02423..a4e1b81 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -30,7 +30,7 @@ def safe_save(
 
 
 class LoraInjectedLinear(nn.Module):
-    def __init__(self, in_features, out_features, bias=False, r=4):
+    def __init__(self, in_features, out_features, bias=False, r=4, dropout_p=0.1):
         super().__init__()
 
         if r > min(in_features, out_features):
@@ -40,6 +40,7 @@ def __init__(self, in_features, out_features, bias=False, r=4):
 
         self.linear = nn.Linear(in_features, out_features, bias)
         self.lora_down = nn.Linear(in_features, r, bias=False)
+        self.dropout = nn.Dropout(dropout_p)
         self.lora_up = nn.Linear(r, out_features, bias=False)
         self.scale = 1.0
 
@@ -47,12 +48,82 @@ def __init__(self, in_features, out_features, bias=False, r=4):
         nn.init.zeros_(self.lora_up.weight)
 
     def forward(self, input):
-        return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
+        return (
+            self.linear(input)
+            + self.lora_up(self.dropout(self.lora_down(input))) * self.scale
+        )
+
+
+class LoraInjectedConv2d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups: int = 1,
+        bias: bool = True,
+        r: int = 4,
+        dropout_p: float = 0.1,
+    ):
+        super().__init__()
+        if r > min(in_channels, out_channels):
+            raise ValueError(
+                f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
+            )
+
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+
+        self.lora_down = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=r,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=False,
+        )
+        self.dropout = nn.Dropout(dropout_p)
+        self.lora_up = nn.Conv2d(
+            in_channels=r,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.scale = 1.0
+
+        nn.init.normal_(self.lora_down.weight, std=1 / r)
+        nn.init.zeros_(self.lora_up.weight)
+
+    def forward(self, input):
+        return (
+            self.conv(input)
+            + self.lora_up(self.dropout(self.lora_down(input))) * self.scale
+        )
 
 
 UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
+
+UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
+
 TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
+TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
+
 DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
 
 EMBED_FLAG = "<embed>"
 
@@ -79,7 +150,10 @@ def _find_modules_v2(
     model,
     ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
     search_class: List[Type[nn.Module]] = [nn.Linear],
-    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
+    exclude_children_of: Optional[List[Type[nn.Module]]] = [
+        LoraInjectedLinear,
+        LoraInjectedConv2d,
+    ],
 ):
     """
     Find all modules of a certain class (or union of classes) that are direct or
@@ -183,12 +257,85 @@ def inject_trainable_lora(
 
     return require_grad_params, names
 
 
+def inject_trainable_lora_extended(
+    model: nn.Module,
+    target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
+    r: int = 4,
+    loras=None,  # path to lora .pt
+):
+    """
+    inject lora into model, and returns lora parameter groups.
+    """
+
+    require_grad_params = []
+    names = []
+
+    if loras != None:
+        loras = torch.load(loras)
+
+    for _module, name, _child_module in _find_modules(
+        model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
+    ):
+        if _child_module.__class__ == nn.Linear:
+            weight = _child_module.weight
+            bias = _child_module.bias
+            _tmp = LoraInjectedLinear(
+                _child_module.in_features,
+                _child_module.out_features,
+                _child_module.bias is not None,
+                r,
+            )
+            _tmp.linear.weight = weight
+            if bias is not None:
+                _tmp.linear.bias = bias
+        elif _child_module.__class__ == nn.Conv2d:
+            weight = _child_module.weight
+            bias = _child_module.bias
+            _tmp = LoraInjectedConv2d(
+                _child_module.in_channels,
+                _child_module.out_channels,
+                _child_module.kernel_size,
+                _child_module.stride,
+                _child_module.padding,
+                _child_module.dilation,
+                _child_module.groups,
+                _child_module.bias is not None,
+                r,
+            )
+
+            _tmp.conv.weight = weight
+            if bias is not None:
+                _tmp.conv.bias = bias
+
+        # switch the module
+        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
+        if bias is not None:
+            _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
+
+        _module._modules[name] = _tmp
+
+        require_grad_params.append(_module._modules[name].lora_up.parameters())
+        require_grad_params.append(_module._modules[name].lora_down.parameters())
+
+        if loras != None:
+            _module._modules[name].lora_up.weight = loras.pop(0)
+            _module._modules[name].lora_down.weight = loras.pop(0)
+
+        _module._modules[name].lora_up.weight.requires_grad = True
+        _module._modules[name].lora_down.weight.requires_grad = True
+        names.append(name)
+
+    return require_grad_params, names
+
+
 def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
 
     loras = []
 
     for _m, _n, _child_module in _find_modules(
-        model, target_replace_module, search_class=[LoraInjectedLinear]
+        model,
+        target_replace_module,
+        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
     ):
         loras.append((_child_module.lora_up, _child_module.lora_down))
@@ -246,7 +393,12 @@ def save_safeloras_with_embeds(
         for i, (_up, _down) in enumerate(
             extract_lora_ups_down(model, target_replace_module)
         ):
-            metadata[f"{name}:{i}:rank"] = str(_down.out_features)
+            try:
+                rank = getattr(_down, "out_features")
+            except:
+                rank = getattr(_down, "out_channels")
+
+            metadata[f"{name}:{i}:rank"] = str(rank)
 
             weights[f"{name}:{i}:up"] = _up.weight
             weights[f"{name}:{i}:down"] = _down.weight
@@ -424,54 +576,28 @@ def weight_apply_lora(
         _child_module.weight = nn.Parameter(weight)
 
 
-def monkeypatch_lora(
-    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
+def monkeypatch_or_replace_lora(
+    model,
+    loras,
+    target_replace_module=DEFAULT_TARGET_REPLACE,
+    r: Union[int, List[int]] = 4,
 ):
     for _module, name, _child_module in _find_modules(
-        model, target_replace_module, search_class=[nn.Linear]
+        model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
     ):
-        weight = _child_module.weight
-        bias = _child_module.bias
-        _tmp = LoraInjectedLinear(
-            _child_module.in_features,
-            _child_module.out_features,
-            _child_module.bias is not None,
-            r=r,
-        )
-        _tmp.linear.weight = weight
-
-        if bias is not None:
-            _tmp.linear.bias = bias
-
-        # switch the module
-        _module._modules[name] = _tmp
-
-        up_weight = loras.pop(0)
-        down_weight = loras.pop(0)
-
-        _module._modules[name].lora_up.weight = nn.Parameter(
-            up_weight.type(weight.dtype)
-        )
-        _module._modules[name].lora_down.weight = nn.Parameter(
-            down_weight.type(weight.dtype)
+        _source = (
+            _child_module.linear
+            if isinstance(_child_module, LoraInjectedLinear)
+            else _child_module
         )
 
-        _module._modules[name].to(weight.device)
-
-
-def monkeypatch_replace_lora(
-    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
-):
-    for _module, name, _child_module in _find_modules(
-        model, target_replace_module, search_class=[LoraInjectedLinear]
-    ):
-        weight = _child_module.linear.weight
-        bias = _child_module.linear.bias
+        weight = _source.weight
+        bias = _source.bias
         _tmp = LoraInjectedLinear(
-            _child_module.linear.in_features,
-            _child_module.linear.out_features,
-            _child_module.linear.bias is not None,
-            r=r,
+            _source.in_features,
+            _source.out_features,
+            _source.bias is not None,
+            r=r.pop(0) if isinstance(r, list) else r,
         )
         _tmp.linear.weight = weight
 
@@ -494,33 +620,72 @@ def monkeypatch_or_replace_lora(
         _module._modules[name].to(weight.device)
 
 
-def monkeypatch_or_replace_lora(
+def monkeypatch_or_replace_lora_extended(
     model,
     loras,
     target_replace_module=DEFAULT_TARGET_REPLACE,
     r: Union[int, List[int]] = 4,
 ):
     for _module, name, _child_module in _find_modules(
-        model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
+        model,
+        target_replace_module,
+        search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
     ):
-        _source = (
-            _child_module.linear
-            if isinstance(_child_module, LoraInjectedLinear)
-            else _child_module
-        )
-        weight = _source.weight
-        bias = _source.bias
-        _tmp = LoraInjectedLinear(
-            _source.in_features,
-            _source.out_features,
-            _source.bias is not None,
-            r=r.pop(0) if isinstance(r, list) else r,
-        )
-        _tmp.linear.weight = weight
+        if (_child_module.__class__ == nn.Linear) or (
+            _child_module.__class__ == LoraInjectedLinear
+        ):
+            if len(loras[0].shape) != 2:
+                continue
 
-        if bias is not None:
-            _tmp.linear.bias = bias
+            _source = (
+                _child_module.linear
+                if isinstance(_child_module, LoraInjectedLinear)
+                else _child_module
+            )
+
+            weight = _source.weight
+            bias = _source.bias
+            _tmp = LoraInjectedLinear(
+                _source.in_features,
+                _source.out_features,
+                _source.bias is not None,
+                r=r.pop(0) if isinstance(r, list) else r,
+            )
+            _tmp.linear.weight = weight
+
+            if bias is not None:
+                _tmp.linear.bias = bias
+
+        elif (_child_module.__class__ == nn.Conv2d) or (
+            _child_module.__class__ == LoraInjectedConv2d
+        ):
+            if len(loras[0].shape) != 4:
+                continue
+            _source = (
+                _child_module.conv
+                if isinstance(_child_module, LoraInjectedConv2d)
+                else _child_module
+            )
+
+            weight = _source.weight
+            bias = _source.bias
+            _tmp = LoraInjectedConv2d(
+                _source.in_channels,
+                _source.out_channels,
+                _source.kernel_size,
+                _source.stride,
+                _source.padding,
+                _source.dilation,
+                _source.groups,
+                _source.bias is not None,
+                r=r.pop(0) if isinstance(r, list) else r,
+            )
+
+            _tmp.conv.weight = weight
+
+            if bias is not None:
+                _tmp.conv.bias = bias
 
         # switch the module
         _module._modules[name] = _tmp
 
@@ -548,7 +713,7 @@ def monkeypatch_or_replace_safeloras(models, safeloras):
             print(f"No model provided for {name}, contained in Lora")
             continue
 
-        monkeypatch_or_replace_lora(model, lora, target, ranks)
+        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
 
 
 def monkeypatch_remove_lora(model):
@@ -596,7 +761,7 @@ def monkeypatch_add_lora(
 
 def tune_lora_scale(model, alpha: float = 1.0):
     for _module in model.modules():
-        if _module.__class__.__name__ == "LoraInjectedLinear":
+        if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
             _module.scale = alpha
 
 
@@ -755,9 +920,9 @@ def inspect_lora(model):
 def save_all(
     unet,
     text_encoder,
-    placeholder_token_ids,
-    placeholder_tokens,
     save_path,
+    placeholder_token_ids=None,
+    placeholder_tokens=None,
     save_lora=True,
     save_ti=True,
     target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
@@ -801,7 +966,7 @@ def save_all(
         ), f"Save path : {save_path} should end with .safetensors"
 
         loras = {}
-        embeds = None
+        embeds = {}
 
         if save_lora:
 
@@ -809,7 +974,6 @@ def save_all(
             loras["text_encoder"] = (text_encoder, target_replace_module_text)
 
         if save_ti:
-            embeds = {}
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                 learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                 print(
diff --git a/setup.py b/setup.py
index 1280523..1181d78 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
     name="lora_diffusion",
     py_modules=["lora_diffusion"],
-    version="0.1.0",
+    version="0.1.2",
     description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
     author="Simo Ryu",
     packages=find_packages(),
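
For anyone who wants to poke at the new path outside the CLI, below is a minimal usage sketch, not part of this patch. It assumes a diffusers version contemporaneous with this change (one whose UNet still exposes CrossAttention/ResnetBlock2D blocks), and the checkpoint id is only a placeholder; any UNet2DConditionModel works.

# Sketch only: exercise the extended injection path on a diffusers UNet.
import itertools

from diffusers import UNet2DConditionModel

from lora_diffusion.lora import (
    UNET_EXTENDED_TARGET_REPLACE,
    inject_trainable_lora_extended,
    tune_lora_scale,
)

# Placeholder checkpoint; substitute whatever UNet you actually use.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Patches nn.Linear *and* nn.Conv2d children of the extended target set
# (ResnetBlock2D on top of the attention blocks).
unet_lora_params, names = inject_trainable_lora_extended(
    unet, r=4, target_replace_module=UNET_EXTENDED_TARGET_REPLACE
)
n_trainable = sum(p.numel() for p in itertools.chain(*unet_lora_params))
print(f"patched {len(names)} modules, {n_trainable} trainable LoRA params")

# Global LoRA scaling now also reaches the injected Conv2d adapters.
tune_lora_scale(unet, 0.8)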