From 271069fe420c3441eb9c2f80b3208c7182ed63ca Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 09:52:07 +0100 Subject: [PATCH 01/15] Update lora.py --- lora_diffusion/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 7672b48..5500fcc 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -21,7 +21,7 @@ def __init__(self, in_features, out_features, bias=False, r=4): self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) self.lora_up = nn.Linear(r, out_features, bias=False) - self.scale = 1.0 + self.scale = 8.0 nn.init.normal_(self.lora_down.weight, std=1 / r) nn.init.zeros_(self.lora_up.weight) From 94e3e814401eed0b3ad7c9a2bfd21042f12703eb Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 10:13:57 +0100 Subject: [PATCH 02/15] Update lora.py --- lora_diffusion/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 5500fcc..3450ebf 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -23,7 +23,8 @@ def __init__(self, in_features, out_features, bias=False, r=4): self.lora_up = nn.Linear(r, out_features, bias=False) self.scale = 8.0 - nn.init.normal_(self.lora_down.weight, std=1 / r) + #nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) nn.init.zeros_(self.lora_up.weight) def forward(self, input): From 0eb98b5e8973eb1e87dc3bf14f33f669ba7b7c09 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 10:56:31 +0100 Subject: [PATCH 03/15] Update lora.py --- lora_diffusion/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 3450ebf..784af35 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -21,7 +21,7 @@ def __init__(self, in_features, out_features, bias=False, r=4): self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) self.lora_up = nn.Linear(r, out_features, bias=False) - self.scale = 8.0 + self.scale = 32.0 #nn.init.normal_(self.lora_down.weight, std=1 / r) nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) From 6d309893f1e6c10c627d7c66e5950c3f46d170f8 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 11:33:16 +0100 Subject: [PATCH 04/15] Revert wrong branch --- lora_diffusion/lora.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 784af35..7672b48 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -21,10 +21,9 @@ def __init__(self, in_features, out_features, bias=False, r=4): self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) self.lora_up = nn.Linear(r, out_features, bias=False) - self.scale = 32.0 + self.scale = 1.0 - #nn.init.normal_(self.lora_down.weight, std=1 / r) - nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + nn.init.normal_(self.lora_down.weight, std=1 / r) nn.init.zeros_(self.lora_up.weight) def forward(self, input): From f54f16a1e309f0a6bb9eee7d8e87230e5d110cff Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 13:16:31 +0100 Subject: [PATCH 05/15] Expose alpha hyperparameter Note that LoRA paper defines scale = alpha/r --- lora_diffusion/lora.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 7672b48..010209c 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -10,18 +10,25 @@ class LoraInjectedLinear(nn.Module): - def __init__(self, in_features, out_features, bias=False, r=4): + def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0): super().__init__() if r > min(in_features, out_features): raise ValueError( f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" ) - + + if alpha <= 0: + raise ValueError( + f"LoRA alpha {r} must be greater than 0" + ) + + self.r = r + self.alpha = alpha self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) self.lora_up = nn.Linear(r, out_features, bias=False) - self.scale = 1.0 + self.scale = self.alpha / self.r nn.init.normal_(self.lora_down.weight, std=1 / r) nn.init.zeros_(self.lora_up.weight) @@ -116,6 +123,7 @@ def inject_trainable_lora( model: nn.Module, target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, r: int = 4, + alpha: float = 4.0, loras=None, # path to lora .pt ): """ @@ -137,7 +145,8 @@ def inject_trainable_lora( _child_module.in_features, _child_module.out_features, _child_module.bias is not None, - r, + r=r, + alpha=alpha, ) _tmp.linear.weight = weight if bias is not None: @@ -333,7 +342,7 @@ def load_safeloras(path, device="cpu"): def weight_apply_lora( - model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, alpha=1.0 + model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, scale=1.0 ): for _m, _n, _child_module in _find_modules( @@ -345,12 +354,12 @@ def weight_apply_lora( down_weight = loras.pop(0).detach().to(weight.device) # W <- W + U * D - weight = weight + alpha * (up_weight @ down_weight).type(weight.dtype) + weight = weight + scale * (up_weight @ down_weight).type(weight.dtype) _child_module.weight = nn.Parameter(weight) def monkeypatch_lora( - model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4 + model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4, alpha: float = 4.0, ): for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear] @@ -362,6 +371,7 @@ def monkeypatch_lora( _child_module.out_features, _child_module.bias is not None, r=r, + alpha=alpha, ) _tmp.linear.weight = weight @@ -385,7 +395,7 @@ def monkeypatch_lora( def monkeypatch_replace_lora( - model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4 + model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4, alpha: float = 4.0, ): for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[LoraInjectedLinear] @@ -397,6 +407,7 @@ def monkeypatch_replace_lora( _child_module.linear.out_features, _child_module.linear.bias is not None, r=r, + alpha=alpha, ) _tmp.linear.weight = weight @@ -424,6 +435,7 @@ def monkeypatch_or_replace_lora( loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: Union[int, List[int]] = 4, + alpha: Union[float, List[float]] = 4.0, ): for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] @@ -441,6 +453,7 @@ def monkeypatch_or_replace_lora( _source.out_features, _source.bias is not None, r=r.pop(0) if isinstance(r, list) else r, + alpha=alpha.pop(0) if isinstance(alpha, list) else alpha, ) _tmp.linear.weight = weight @@ -519,11 +532,11 @@ def monkeypatch_add_lora( _module._modules[name].to(weight.device) -def tune_lora_scale(model, alpha: float = 1.0): +def tune_lora_scale(model, alpha: float = 4.0): for _module in model.modules(): if _module.__class__.__name__ == "LoraInjectedLinear": - _module.scale = alpha - + _module.alpha = alpha + _module.scale = _module.alpha / _module.r def _text_lora_path(path: str) -> str: assert path.endswith(".pt"), "Only .pt files are supported" From d9531d7bf0a8ed266cf03dc6cd1482a31da250b2 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 13:21:10 +0100 Subject: [PATCH 06/15] Adjust for parameter rename --- lora_diffusion/cli_lora_add.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lora_diffusion/cli_lora_add.py b/lora_diffusion/cli_lora_add.py index 3a416af..3d11ec2 100644 --- a/lora_diffusion/cli_lora_add.py +++ b/lora_diffusion/cli_lora_add.py @@ -75,13 +75,13 @@ def add( path_1, ).to("cpu") - weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha) + weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), scale=alpha) if with_text_lora: weight_apply_lora( loaded_pipeline.text_encoder, torch.load(_text_lora_path(path_2)), - alpha=alpha, + scale=alpha, target_replace_module=["CLIPAttention"], ) @@ -93,12 +93,12 @@ def add( path_1, ).to("cpu") - weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha) + weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), scale=alpha) if with_text_lora: weight_apply_lora( loaded_pipeline.text_encoder, torch.load(_text_lora_path(path_2)), - alpha=alpha, + scale=alpha, target_replace_module=["CLIPAttention"], ) From 24c8470f14677c4ca8f27e8f92a68045417626a4 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 13:59:58 +0100 Subject: [PATCH 07/15] Add element-wise nonlinearity between Down and UP --- lora_diffusion/lora.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 010209c..2fc9e94 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -10,7 +10,7 @@ class LoraInjectedLinear(nn.Module): - def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0): + def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0, init=None, nonlin: nn.Module = None): super().__init__() if r > min(in_features, out_features): @@ -20,21 +20,29 @@ def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0): if alpha <= 0: raise ValueError( - f"LoRA alpha {r} must be greater than 0" + f"LoRA alpha {alpha} must be greater than 0" ) self.r = r self.alpha = alpha self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) + self.nonlin = nonlin if nonlin is not None self.lora_up = nn.Linear(r, out_features, bias=False) self.scale = self.alpha / self.r - nn.init.normal_(self.lora_down.weight, std=1 / r) + if init=="kaiming": + nn.init.kaiming_uniform_(self.lora_down) + else: + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) def forward(self, input): - return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale + if self.nonlin is not None: + return self.linear(input) + self.lora_up(self.nonlin(self.lora_down(input))) * self.scale + else: + return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} From 953b1db04136e610c079b74ed1a3e69eb6e931b1 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 14:19:27 +0100 Subject: [PATCH 08/15] Add nonlin to necessary patching functions --- lora_diffusion/lora.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 2fc9e94..643019f 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -132,6 +132,8 @@ def inject_trainable_lora( target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, r: int = 4, alpha: float = 4.0, + init=None, + nonlin=None, loras=None, # path to lora .pt ): """ @@ -155,6 +157,8 @@ def inject_trainable_lora( _child_module.bias is not None, r=r, alpha=alpha, + init=init, + nonlin=nonlin, ) _tmp.linear.weight = weight if bias is not None: @@ -367,7 +371,12 @@ def weight_apply_lora( def monkeypatch_lora( - model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4, alpha: float = 4.0, + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: int = 4, + alpha: float = 4.0, + nonlin: nn.Module = None, ): for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear] @@ -380,6 +389,7 @@ def monkeypatch_lora( _child_module.bias is not None, r=r, alpha=alpha, + nonline=nonlin, ) _tmp.linear.weight = weight @@ -403,7 +413,12 @@ def monkeypatch_lora( def monkeypatch_replace_lora( - model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4, alpha: float = 4.0, + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: int = 4, + alpha: float = 4.0, + nonlin: nn.Module = None, ): for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[LoraInjectedLinear] @@ -416,7 +431,8 @@ def monkeypatch_replace_lora( _child_module.linear.bias is not None, r=r, alpha=alpha, - ) + nonlin=nonlin, + ) _tmp.linear.weight = weight if bias is not None: @@ -444,6 +460,7 @@ def monkeypatch_or_replace_lora( target_replace_module=DEFAULT_TARGET_REPLACE, r: Union[int, List[int]] = 4, alpha: Union[float, List[float]] = 4.0, + nonlin: Union[float, List[nn.Module]] = None, ): for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] @@ -462,6 +479,7 @@ def monkeypatch_or_replace_lora( _source.bias is not None, r=r.pop(0) if isinstance(r, list) else r, alpha=alpha.pop(0) if isinstance(alpha, list) else alpha, + nonlin=nonlin.pop(0) if isinstance(nonlin, list) else nonlin, ) _tmp.linear.weight = weight From 9f59cb03bc1bc0f706b27313c88b89499d10c270 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 14:49:12 +0100 Subject: [PATCH 09/15] Missed a few alpha/nonlin --- lora_diffusion/lora.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 643019f..d30123f 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -354,9 +354,16 @@ def load_safeloras(path, device="cpu"): def weight_apply_lora( - model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, scale=1.0 + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: int = 4, + alpha: float = 4.0, + nonlin: nn.Module = None, + #scale=1.0 ): - + scale = alpha/r + for _m, _n, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear] ): @@ -365,11 +372,15 @@ def weight_apply_lora( up_weight = loras.pop(0).detach().to(weight.device) down_weight = loras.pop(0).detach().to(weight.device) - # W <- W + U * D - weight = weight + scale * (up_weight @ down_weight).type(weight.dtype) + if nonlin is None: + # W <- W + U * D + weight = weight + scale * (up_weight @ down_weight).type(weight.dtype) + else: + # W <- W + U * nonlin(D) + weight = weight + scale * (up_weight @ nonlin(down_weight)).type(weight.dtype) + _child_module.weight = nn.Parameter(weight) - def monkeypatch_lora( model, loras, @@ -389,7 +400,7 @@ def monkeypatch_lora( _child_module.bias is not None, r=r, alpha=alpha, - nonline=nonlin, + nonlin=nonlin, ) _tmp.linear.weight = weight @@ -615,6 +626,8 @@ def patch_pipe( unet_path, token: str, r: int = 4, + alpha: float = 4.0, + nonlin: nn.Module = None, patch_unet=True, patch_text=False, patch_ti=False, @@ -635,6 +648,8 @@ def patch_pipe( pipe.unet, torch.load(unet_path), r=r, + alpha=alpha, + nonlin=nonlin, target_replace_module=unet_target_replace_module, ) @@ -645,6 +660,8 @@ def patch_pipe( torch.load(text_path), target_replace_module=text_target_replace_module, r=r, + alpha=alpha, + nonlin=nonlin, ) if patch_ti: print("LoRA : Patching token input") From 2326be52d318ffad76ca56ef56015ae89cb6e3c8 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 15:27:06 +0100 Subject: [PATCH 10/15] Update lora.py --- lora_diffusion/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index d30123f..c091626 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -27,7 +27,7 @@ def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0, init=N self.alpha = alpha self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) - self.nonlin = nonlin if nonlin is not None + self.nonlin = nonlin if nonlin else None self.lora_up = nn.Linear(r, out_features, bias=False) self.scale = self.alpha / self.r @@ -39,7 +39,7 @@ def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0, init=N nn.init.zeros_(self.lora_up.weight) def forward(self, input): - if self.nonlin is not None: + if self.nonlin: return self.linear(input) + self.lora_up(self.nonlin(self.lora_down(input))) * self.scale else: return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale From b20df37acfca03aacf67f578a9549d296f09914c Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 15:29:25 +0100 Subject: [PATCH 11/15] Make kaiming init same as nn.Linear --- lora_diffusion/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index c091626..ec01b6a 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -32,7 +32,7 @@ def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0, init=N self.scale = self.alpha / self.r if init=="kaiming": - nn.init.kaiming_uniform_(self.lora_down) + nn.init.kaiming_uniform_(self.lora_down, a=math.sqrt(5)) else: nn.init.normal_(self.lora_down.weight, std=1 / r) From fc68fcefe21e4e54dd7fd83356582d7c8cddb71b Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 15:57:50 +0100 Subject: [PATCH 12/15] Update lora.py --- lora_diffusion/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index ec01b6a..1c86c9f 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -32,7 +32,7 @@ def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0, init=N self.scale = self.alpha / self.r if init=="kaiming": - nn.init.kaiming_uniform_(self.lora_down, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) else: nn.init.normal_(self.lora_down.weight, std=1 / r) From 314ea0a4a4352406a40b54f4c3fcc1ac2b6b72b6 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 19:57:29 +0100 Subject: [PATCH 13/15] Update lora.py --- lora_diffusion/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 1c86c9f..4f3e57b 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -32,7 +32,8 @@ def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0, init=N self.scale = self.alpha / self.r if init=="kaiming": - nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + pass + # Kaiming with a=math.sqrt(5) is default else: nn.init.normal_(self.lora_down.weight, std=1 / r) From 5b5dfcb14863773a012e774177b3507e07010325 Mon Sep 17 00:00:00 2001 From: brian6091 Date: Thu, 5 Jan 2023 22:16:21 +0100 Subject: [PATCH 14/15] New function tune_lora_alpha Keep tune_lora_scale as previously to minimize breakage. --- lora_diffusion/lora.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 4f3e57b..379b4dc 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -570,12 +570,25 @@ def monkeypatch_add_lora( _module._modules[name].to(weight.device) -def tune_lora_scale(model, alpha: float = 4.0): +def tune_lora_scale(model, alpha: float = 4.0, scale: float = None): + if scale==None: + # Keep original named parameter alpha (which is really scale), + # Swap here so that we can correctly calculate alpha + scale = alpha + + for _module in model.modules(): + if _module.__class__.__name__ == "LoraInjectedLinear": + _module.scale = scale + _module.alpha = _module.r * _module.scale + + +def tune_lora_alpha(model, alpha: float = 4.0): for _module in model.modules(): if _module.__class__.__name__ == "LoraInjectedLinear": _module.alpha = alpha _module.scale = _module.alpha / _module.r - + + def _text_lora_path(path: str) -> str: assert path.endswith(".pt"), "Only .pt files are supported" return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) From f0485a373cb671ac79633cf4a7bae8ce1da3759c Mon Sep 17 00:00:00 2001 From: brian6091 Date: Fri, 6 Jan 2023 21:42:00 +0100 Subject: [PATCH 15/15] Revert to scale, remove alpha --- lora_diffusion/lora.py | 55 ++++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 379b4dc..a04bee3 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -10,7 +10,7 @@ class LoraInjectedLinear(nn.Module): - def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0, init=None, nonlin: nn.Module = None): + def __init__(self, in_features, out_features, bias=False, r=4, scale=1.0, init=None, nonlin: nn.Module = None): super().__init__() if r > min(in_features, out_features): @@ -18,22 +18,21 @@ def __init__(self, in_features, out_features, bias=False, r=4, alpha=4.0, init=N f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" ) - if alpha <= 0: + if scale <= 0: raise ValueError( - f"LoRA alpha {alpha} must be greater than 0" + f"LoRA scale {scale} must be greater than 0" ) self.r = r - self.alpha = alpha + self.scale = scale self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) self.nonlin = nonlin if nonlin else None self.lora_up = nn.Linear(r, out_features, bias=False) - self.scale = self.alpha / self.r if init=="kaiming": pass - # Kaiming with a=math.sqrt(5) is default + # Kaiming with a=math.sqrt(5) is default for nn.Linear else: nn.init.normal_(self.lora_down.weight, std=1 / r) @@ -132,7 +131,7 @@ def inject_trainable_lora( model: nn.Module, target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, r: int = 4, - alpha: float = 4.0, + scale: float = 1.0, init=None, nonlin=None, loras=None, # path to lora .pt @@ -157,7 +156,7 @@ def inject_trainable_lora( _child_module.out_features, _child_module.bias is not None, r=r, - alpha=alpha, + scale=scale, init=init, nonlin=nonlin, ) @@ -359,12 +358,9 @@ def weight_apply_lora( loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4, - alpha: float = 4.0, + scale: float = 1.0, nonlin: nn.Module = None, - #scale=1.0 -): - scale = alpha/r - +): for _m, _n, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear] ): @@ -387,7 +383,7 @@ def monkeypatch_lora( loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4, - alpha: float = 4.0, + scale: float = 1.0, nonlin: nn.Module = None, ): for _module, name, _child_module in _find_modules( @@ -400,7 +396,7 @@ def monkeypatch_lora( _child_module.out_features, _child_module.bias is not None, r=r, - alpha=alpha, + scale=scale, nonlin=nonlin, ) _tmp.linear.weight = weight @@ -429,7 +425,7 @@ def monkeypatch_replace_lora( loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4, - alpha: float = 4.0, + scale: float = 1.0, nonlin: nn.Module = None, ): for _module, name, _child_module in _find_modules( @@ -442,7 +438,7 @@ def monkeypatch_replace_lora( _child_module.linear.out_features, _child_module.linear.bias is not None, r=r, - alpha=alpha, + scale=scale, nonlin=nonlin, ) _tmp.linear.weight = weight @@ -471,7 +467,7 @@ def monkeypatch_or_replace_lora( loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: Union[int, List[int]] = 4, - alpha: Union[float, List[float]] = 4.0, + scale: Union[float, List[float]] = 1.0, nonlin: Union[float, List[nn.Module]] = None, ): for _module, name, _child_module in _find_modules( @@ -490,7 +486,7 @@ def monkeypatch_or_replace_lora( _source.out_features, _source.bias is not None, r=r.pop(0) if isinstance(r, list) else r, - alpha=alpha.pop(0) if isinstance(alpha, list) else alpha, + scale=scale.pop(0) if isinstance(scale, list) else scale, nonlin=nonlin.pop(0) if isinstance(nonlin, list) else nonlin, ) _tmp.linear.weight = weight @@ -547,7 +543,7 @@ def monkeypatch_add_lora( model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, - alpha: float = 1.0, + scale: float = 1.0, beta: float = 1.0, ): for _module, name, _child_module in _find_modules( @@ -570,25 +566,16 @@ def monkeypatch_add_lora( _module._modules[name].to(weight.device) -def tune_lora_scale(model, alpha: float = 4.0, scale: float = None): - if scale==None: +def tune_lora_scale(model, alpha: float = 1.0, scale: float = None): + if alpha: # Keep original named parameter alpha (which is really scale), - # Swap here so that we can correctly calculate alpha scale = alpha for _module in model.modules(): if _module.__class__.__name__ == "LoraInjectedLinear": _module.scale = scale - _module.alpha = _module.r * _module.scale -def tune_lora_alpha(model, alpha: float = 4.0): - for _module in model.modules(): - if _module.__class__.__name__ == "LoraInjectedLinear": - _module.alpha = alpha - _module.scale = _module.alpha / _module.r - - def _text_lora_path(path: str) -> str: assert path.endswith(".pt"), "Only .pt files are supported" return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) @@ -640,7 +627,7 @@ def patch_pipe( unet_path, token: str, r: int = 4, - alpha: float = 4.0, + scale: float = 1.0, nonlin: nn.Module = None, patch_unet=True, patch_text=False, @@ -662,7 +649,7 @@ def patch_pipe( pipe.unet, torch.load(unet_path), r=r, - alpha=alpha, + scale=scale, nonlin=nonlin, target_replace_module=unet_target_replace_module, ) @@ -674,7 +661,7 @@ def patch_pipe( torch.load(text_path), target_replace_module=text_target_replace_module, r=r, - alpha=alpha, + scale=scale, nonlin=nonlin, ) if patch_ti: