diff --git a/loralib/layers.py b/loralib/layers.py index 0e54a64b..72b40501 100644 --- a/loralib/layers.py +++ b/loralib/layers.py @@ -255,7 +255,7 @@ def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lor self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) ) self.lora_B = nn.Parameter( - self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size)) + self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size**(self.conv.weight.dim()-3), r*kernel_size)) ) self.scaling = self.lora_alpha / self.r # Freezing the pre-trained weight matrix