diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py
index 20d30dd..eb02efc 100644
--- a/lora_diffusion/cli_lora_pti.py
+++ b/lora_diffusion/cli_lora_pti.py
@@ -46,6 +46,8 @@
     prepare_clip_model_sets,
     evaluate_pipe,
     UNET_EXTENDED_TARGET_REPLACE,
+    parse_safeloras_embeds,
+    apply_learned_embed_in_clip,
 )
 
 def preview_training_batch(train_dataloader, mode, n_imgs = 40):
@@ -67,6 +69,52 @@ def preview_training_batch(train_dataloader, mode, n_imgs = 40):
     print(f"\nSaved {imgs_saved} preview training imgs to {outdir}")
     return
 
+def sim_matrix(a, b, eps=1e-8):
+    """
+    Row-wise cosine similarity between a and b; eps added for numerical stability.
+    """
+    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+    return sim_mt
+
+
+def compute_pairwise_distances(x, y):
+    # compute the squared L2 distance of each row in x to each row in y
+    # x is a torch tensor of shape (m, d)
+    # y is a torch tensor of shape (n, d)
+    # returns a torch tensor of shape (m, n)
+
+    n = y.shape[0]
+    m = x.shape[0]
+    d = x.shape[1]
+
+    x = x.unsqueeze(1).expand(m, n, d)
+    y = y.unsqueeze(0).expand(m, n, d)
+
+    return torch.pow(x - y, 2).sum(2)
+
+
+def print_most_similar_tokens(tokenizer, optimized_token, text_encoder, n=10):
+    with torch.no_grad():
+        # get all the token embeddings:
+        token_embeds = text_encoder.get_input_embeddings().weight.data
+
+        # compute the cosine similarity between the optimized token and all other tokens:
+        similarity = sim_matrix(optimized_token.unsqueeze(0), token_embeds).squeeze()
+        similarity = similarity.detach().cpu().numpy()
+
+        distances = compute_pairwise_distances(optimized_token.unsqueeze(0), token_embeds).squeeze()
+        distances = distances.detach().cpu().numpy()
+
+        # print stats for the n most similar tokens:
+        most_similar_tokens = np.argsort(similarity)[::-1]
+
+        print(f"{tokenizer.decode(most_similar_tokens[0])} --> mean: {optimized_token.mean().item():.3f}, std: {optimized_token.std().item():.3f}, norm: {optimized_token.norm():.4f}")
+        for token_id in most_similar_tokens[1:n+1]:
+            print(f"sim of {similarity[token_id]:.3f} & squared L2 of {distances[token_id]:.3f} with \"{tokenizer.decode(token_id)}\"")
+
 
 def get_models(
     pretrained_model_name_or_path,
@@ -139,11 +187,13 @@ def get_models(
         pretrained_vae_name_or_path or pretrained_model_name_or_path,
         subfolder=None if pretrained_vae_name_or_path else "vae",
         revision=None if pretrained_vae_name_or_path else revision,
+        local_files_only=True,
     )
     unet = UNet2DConditionModel.from_pretrained(
         pretrained_model_name_or_path,
         subfolder="unet",
         revision=revision,
+        local_files_only=True,
    )
 
     return (
@@ -151,7 +201,7 @@ def get_models(
         vae.to(device),
         unet.to(device),
         tokenizer,
-        placeholder_token_ids,
+        placeholder_token_ids
     )
 
 
@@ -477,12 +527,13 @@ def train_inversion(
 
         if global_step % accum_iter == 0:
             # print gradient of text encoder embedding
-            print(
-                text_encoder.get_input_embeddings()
-                .weight.grad[index_updates, :]
-                .norm(dim=-1)
-                .mean()
-            )
+            if False:
+                print(
+                    text_encoder.get_input_embeddings()
+                    .weight.grad[index_updates, :]
+                    .norm(dim=-1)
+                    .mean()
+                )
             optimizer.step()
             optimizer.zero_grad()
 
@@ -517,8 +568,10 @@ def train_inversion(
                     index_no_updates
                 ] = orig_embeds_params[index_no_updates]
 
-            for i, t in enumerate(optimizing_embeds):
-                print(f"token {i} --> mean: {t.mean().item():.3f}, std: {t.std().item():.3f}, norm: {t.norm():.4f}")
+            if global_step % 50 == 0:
print("------------------------------") + for i, t in enumerate(optimizing_embeds): + print_most_similar_tokens(tokenizer, t, text_encoder) global_step += 1 progress_bar.update(1) @@ -537,7 +590,7 @@ def train_inversion( placeholder_token_ids=placeholder_token_ids, placeholder_tokens=placeholder_tokens, save_path=os.path.join( - save_path, f"step_inv_{global_step}.safetensors" + save_path, f"step_inv_{global_step:04d}.safetensors" ), save_lora=False, ) @@ -583,7 +636,7 @@ def train_inversion( return import matplotlib.pyplot as plt -def plot_loss_curve(losses, name, moving_avg=20): +def plot_loss_curve(losses, name, moving_avg=5): losses = np.array(losses) losses = np.convolve(losses, np.ones(moving_avg)/moving_avg, mode='valid') plt.plot(losses) @@ -654,7 +707,7 @@ def perform_tuning( vae, text_encoder, scheduler, - optimized_embeddings = text_encoder.get_input_embeddings().weight[:, :], + optimized_embeddings = text_encoder.get_input_embeddings().weight[~index_no_updates, :], train_inpainting=train_inpainting, t_mutliplier=0.8, mixed_precision=True, @@ -683,6 +736,12 @@ def perform_tuning( index_no_updates ] = orig_embeds_params[index_no_updates] + if global_step % 100 == 0: + optimizing_embeds = text_encoder.get_input_embeddings().weight[~index_no_updates] + print("------------------------------") + for i, t in enumerate(optimizing_embeds): + print_most_similar_tokens(tokenizer, t, text_encoder) + global_step += 1 @@ -696,7 +755,7 @@ def perform_tuning( placeholder_token_ids=placeholder_token_ids, placeholder_tokens=placeholder_tokens, save_path=os.path.join( - save_path, f"step_{global_step}.safetensors" + save_path, f"step_{global_step:04d}.safetensors" ), target_replace_module_text=lora_clip_target_modules, target_replace_module_unet=lora_unet_target_modules, @@ -706,8 +765,8 @@ def perform_tuning( .mean() .item() ) - print("LORA Unet Moved", moved) + moved = ( torch.tensor( list(itertools.chain(*inspect_lora(text_encoder).values())) @@ -715,7 +774,6 @@ def perform_tuning( .mean() .item() ) - print("LORA CLIP Moved", moved) if log_wandb: @@ -778,6 +836,7 @@ def train( placeholder_tokens: str = "", placeholder_token_at_data: Optional[str] = None, initializer_tokens: Optional[str] = None, + load_pretrained_inversion_embeddings_path: Optional[str] = None, seed: int = 42, resolution: int = 512, color_jitter: bool = True, @@ -788,7 +847,8 @@ def train( save_steps: int = 100, gradient_accumulation_steps: int = 4, gradient_checkpointing: bool = False, - lora_rank: int = 4, + lora_rank_unet: int = 4, + lora_rank_text_encoder: int = 4, lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"}, lora_clip_target_modules={"CLIPAttention"}, lora_dropout_p: float = 0.0, @@ -825,6 +885,10 @@ def train( script_start_time = time.time() torch.manual_seed(seed) + if use_template == "person" and not use_face_segmentation_condition: + print("### WARNING ### : Using person template without face segmentation condition") + print("When training people, it is highly recommended to use face segmentation condition!!") + # Get a dict with all the arguments: args_dict = locals() @@ -841,7 +905,7 @@ def train( if output_dir is not None: os.makedirs(output_dir, exist_ok=True) - # print(placeholder_tokens, initializer_tokens) + if len(placeholder_tokens) == 0: placeholder_tokens = [] print("PTI : Placeholder Tokens not given, using null token") @@ -874,6 +938,7 @@ def train( print("PTI : Placeholder Tokens", placeholder_tokens) print("PTI : Initializer Tokens", initializer_tokens) + print("PTI : Token Map: ", 
 
     # get the models
     text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
@@ -886,7 +951,8 @@ def train(
     )
 
     noise_scheduler = DDPMScheduler.from_config(
-        pretrained_model_name_or_path, subfolder="scheduler"
+        pretrained_model_name_or_path, subfolder="scheduler",
+        local_files_only=True,
     )
 
     if gradient_checkpointing:
@@ -925,8 +991,6 @@ def train(
         train_inpainting=train_inpainting,
     )
 
-    train_dataset.blur_amount = 200
-
     if train_inpainting:
         assert not cached_latents, "Cached latents not supported for inpainting"
@@ -963,7 +1027,7 @@ def train(
         vae = None
 
     # STEP 1 : Perform Inversion
-    if perform_inversion and not cached_latents:
+    if perform_inversion and not cached_latents and (load_pretrained_inversion_embeddings_path is None):
         preview_training_batch(train_dataloader, "inversion")
 
         print("PTI : Performing Inversion")
@@ -1014,16 +1078,32 @@ def train(
         del ti_optimizer
         print("############### Inversion Done ###############")
 
+    elif load_pretrained_inversion_embeddings_path is not None:
+
+        print("PTI : Loading pretrained inversion embeddings...")
+        from safetensors.torch import safe_open
+        # Load the pretrained embeddings from the lora file:
+        safeloras = safe_open(load_pretrained_inversion_embeddings_path, framework="pt", device="cpu")
+        #monkeypatch_or_replace_safeloras(pipe, safeloras)
+        tok_dict = parse_safeloras_embeds(safeloras)
+        apply_learned_embed_in_clip(
+            tok_dict,
+            text_encoder,
+            tokenizer,
+            idempotent=True,
+        )
+
     # Next perform Tuning with LoRA:
     if not use_extended_lora:
         unet_lora_params, _ = inject_trainable_lora(
             unet,
-            r=lora_rank,
+            r=lora_rank_unet,
             target_replace_module=lora_unet_target_modules,
             dropout_p=lora_dropout_p,
             scale=lora_scale,
         )
         print("PTI : not use_extended_lora...")
+        print("PTI : Will replace modules: ", lora_unet_target_modules)
     else:
         print("PTI : USING EXTENDED UNET!!!")
         lora_unet_target_modules = (
@@ -1031,17 +1111,11 @@ def train(
         )
         print("PTI : Will replace modules: ", lora_unet_target_modules)
         unet_lora_params, _ = inject_trainable_lora_extended(
-            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
+            unet, r=lora_rank_unet, target_replace_module=lora_unet_target_modules
         )
-    n_optimizable_unet_params = sum(
-        [el.numel() for el in itertools.chain(*unet_lora_params)]
-    )
-    print("PTI : n_optimizable_unet_params: ", n_optimizable_unet_params)
-
-    print(f"PTI : has {len(unet_lora_params)} lora")
-    print("PTI : Before training:")
-    inspect_lora(unet)
+    #n_optimizable_unet_params = sum([el.numel() for el in itertools.chain(*unet_lora_params)])
+    #print("PTI : Number of optimizable UNET parameters: ", n_optimizable_unet_params)
 
     params_to_optimize = [
         {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
@@ -1073,15 +1147,15 @@ def train(
         text_encoder_lora_params, _ = inject_trainable_lora(
             text_encoder,
             target_replace_module=lora_clip_target_modules,
-            r=lora_rank,
+            r=lora_rank_text_encoder,
         )
         params_to_optimize += [
-            {
-                "params": itertools.chain(*text_encoder_lora_params),
-                "lr": text_encoder_lr,
-            }
+            {"params": itertools.chain(*text_encoder_lora_params),
+             "lr": text_encoder_lr}
         ]
-        inspect_lora(text_encoder)
+
+        #n_optimizable_text_encoder_params = sum([el.numel() for el in itertools.chain(*text_encoder_lora_params)])
+        #print("PTI : Number of optimizable text-encoder parameters: ", n_optimizable_text_encoder_params)
 
     lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)
 
@@ -1090,8 +1164,6 @@ def train(
     if train_text_encoder:
         print("Training text encoder!")
         text_encoder.train()
 
-    train_dataset.blur_amount = 70
-
     lr_scheduler_lora = get_scheduler(
         lr_scheduler_lora,
         optimizer=lora_optimizers,
@@ -1101,6 +1173,22 @@ def train(
 
     if not cached_latents:
         preview_training_batch(train_dataloader, "tuning")
 
+    #print("PTI : n_optimizable_unet_params: ", n_optimizable_unet_params)
+    print(f"PTI : has {len(unet_lora_params)} lora")
+    print("PTI : Before training:")
+
+    moved = (
+        torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
+        .mean().item())
+    print(f"LORA Unet Moved {moved:.6f}")
+
+
+    moved = (
+        torch.tensor(
+            list(itertools.chain(*inspect_lora(text_encoder).values()))
+        ).mean().item())
+    print(f"LORA CLIP Moved {moved:.6f}")
+
     perform_tuning(
         unet,
         vae,
@@ -1132,6 +1220,8 @@ def train(
     training_time = time.time() - script_start_time
     print(f"Training time: {training_time/60:.1f} minutes")
     args_dict["training_time_s"] = int(training_time)
+    args_dict["n_epochs"] = math.ceil(max_train_steps_tuning / len(train_dataloader.dataset))
+    args_dict["n_training_imgs"] = len(train_dataloader.dataset)
 
     # Save the args_dict to the output directory as a json file:
     with open(os.path.join(output_dir, "lora_training_args.json"), "w") as f:
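
Note on the similarity helpers added to cli_lora_pti.py above: sim_matrix returns row-wise cosine similarities, while compute_pairwise_distances returns squared L2 distances (no square root is taken), which is why the per-token log line labels the value as "squared L2". A minimal standalone sanity check of the two metrics, assuming only torch is installed; the toy shapes here stand in for real token embeddings and are not part of the patch:

    import torch

    a = torch.randn(3, 768)   # e.g. 3 optimized token embeddings
    b = torch.randn(5, 768)   # e.g. 5 vocabulary embeddings

    # cosine similarity: normalize rows, then take inner products
    eps = 1e-8
    a_n = a / a.norm(dim=1, keepdim=True).clamp(min=eps)
    b_n = b / b.norm(dim=1, keepdim=True).clamp(min=eps)
    sim = a_n @ b_n.T                                        # (3, 5), values in [-1, 1]

    # squared L2 distance via broadcasting, matching compute_pairwise_distances
    dist = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)  # (3, 5)

    assert sim.shape == dist.shape == (3, 5)
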
diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py
index e51c301..f191c12 100644
--- a/lora_diffusion/dataset.py
+++ b/lora_diffusion/dataset.py
@@ -44,6 +44,7 @@
     "{}",
     "a picture of {}",
     "a closeup of {}",
+    "a closeup of {}'s face",
     "a closeup photo of {}",
     "a close-up picture of {}",
     "a photo of {}",
@@ -60,6 +61,7 @@
     "{} is having fun, 4k photograph",
     "{} wearing a plaidered shirt standing next to another person",
     "smiling {} in a hoodie and sweater",
+    "{} smiling at the camera",
     "a photo of the cool {}",
     "a close-up photo of {}",
     "a bright photo of {}",
@@ -205,8 +207,7 @@ def __init__(
         resize=True,
         use_mask_captioned_data=False,
         use_face_segmentation_condition=False,
-        train_inpainting=False,
-        blur_amount: int = 70,
+        train_inpainting=False
     ):
         self.size = size
         self.tokenizer = tokenizer
@@ -312,6 +313,11 @@ def __init__(
             for idx in range(len(self.instance_images_path)):
                 self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")
 
+        # Final important variables for this dataset:
+        # self.instance_images_path
+        # self.mask_path
+        # self.captions
+
         self.num_instance_images = len(self.instance_images_path)
         self.token_map = token_map
 
@@ -339,7 +345,12 @@ def __init__(
             ]
         )
 
-        self.blur_amount = blur_amount
+        self.instance_images = []
+
+        if len(self.instance_images_path) < 20:
+            # Load all the images into memory:
+            for f in self.instance_images_path:
+                self.instance_images.append(Image.open(f).convert("RGB"))
 
         print("Captions:")
         print(self.captions)
@@ -348,18 +359,20 @@ def tune_h_flip_prob(self, training_progress):
         if self.h_flip:
             # Tune the h_flip probability to be 0.5 training_progress is 0 and end_prob when training_progress is 1
             self.h_flip_prob = 0.5 + (self.final_flip_prob - 0.5) * training_progress
-            print(f"h_flip_prob: {self.h_flip_prob:.3f}")
 
     def __len__(self):
         return self._length
 
     def __getitem__(self, index):
         example = {}
-        instance_image = Image.open(
-            self.instance_images_path[index % self.num_instance_images]
-        )
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
+
+        if len(self.instance_images) > 0:
+            instance_image = self.instance_images[index % self.num_instance_images]
+        else:
+            instance_image = Image.open(
+                self.instance_images_path[index % self.num_instance_images]
+            ).convert("RGB")
+
         example["instance_images"] = self.image_transforms(instance_image)
 
         if self.train_inpainting:
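
Note on the dataset.py change above: decoded images are now cached in RAM only when there are fewer than 20 instance images, so __getitem__ can skip repeated disk reads and decodes. A rough back-of-envelope for the cache cost, assuming 8-bit RGB at 512x512; decoded size actually depends on the source resolution, so these numbers are illustrative only:

    w = h = 512
    bytes_per_image = w * h * 3                      # ~0.75 MiB per decoded RGB image
    cache_mib = 19 * bytes_per_image / 2**20         # worst case below the threshold
    print(f"worst-case cache: {cache_mib:.1f} MiB")  # ~14.2 MiB -> negligible
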
diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index bc3c5d1..5919178 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -4,6 +4,7 @@
 from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
+import random
 import PIL
 import torch
 import torch.nn as nn
@@ -534,7 +535,6 @@ def convert_loras_to_safeloras(
 ):
     convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
 
-
 def parse_safeloras(
     safeloras,
 ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
@@ -596,6 +596,55 @@ def parse_safeloras(
     return loras
 
+
+def dict_to_lora(tensor_dict, metadata):
+    """
+    Converts a dictionary of tensors + metadata into a Lora
+    """
+    loras = {}
+
+    get_name = lambda k: k.split(":")[0]
+
+    keys = list(tensor_dict.keys())
+    keys.sort(key=get_name)
+
+    for name, module_keys in groupby(keys, get_name):
+        info = metadata.get(name)
+
+        if not info:
+            raise ValueError(
+                f"Tensor {name} has no metadata - is this a Lora safetensor?"
+            )
+
+        # Skip Textual Inversion embeds
+        if info == EMBED_FLAG:
+            continue
+
+        # Handle Loras
+        # Extract the targets
+        target = json.loads(info)
+
+        # Build the result lists - Python needs us to preallocate lists to insert into them
+        module_keys = list(module_keys)
+        ranks = [4] * (len(module_keys) // 2)
+        weights = [None] * len(module_keys)
+
+        for key in module_keys:
+            # Split the model name and index out of the key
+            _, idx, direction = key.split(":")
+            idx = int(idx)
+
+            # Add the rank
+            ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
+
+            # Insert the weight into the list
+            idx = idx * 2 + (1 if direction == "down" else 0)
+            weights[idx] = nn.parameter.Parameter(tensor_dict[key])
+
+        loras[name] = (weights, ranks, target)
+
+    return loras
+
+
 def parse_safeloras_embeds(
     safeloras,
 ) -> Dict[str, torch.Tensor]:
@@ -801,7 +850,7 @@ def monkeypatch_or_replace_safeloras(models, safeloras):
 
     for name, (lora, ranks, target) in loras.items():
         model = getattr(models, name, None)
-        
+
         if not model:
             print(f"No model provided for {name}, contained in Lora")
             continue
@@ -1028,17 +1077,19 @@ def inspect_lora(model):
     moved = {}
 
     for name, _module in model.named_modules():
         if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
+            # get the up and down weight matrices:
             ups = _module.lora_up.weight.data.clone()
             downs = _module.lora_down.weight.data.clone()
-            
+
+            # compose them into the effective weight delta (up @ down):
             wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
-            
+            # summarize it as the mean absolute value of the delta entries:
             dist = wght.flatten().abs().mean().item()
+
             if name in moved:
                 moved[name].append(dist)
             else:
                 moved[name] = [dist]
-
     return moved
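
Note on the new dict_to_lora helper in lora.py: a hedged round-trip sketch, assuming a Lora safetensors file saved by this repo; "lora_weight.safetensors" is a placeholder path. safe_open is used here because, unlike load_file, it exposes both the tensors and the metadata that dict_to_lora consumes:

    from safetensors import safe_open

    tensors = {}
    with safe_open("lora_weight.safetensors", framework="pt", device="cpu") as f:
        metadata = f.metadata()
        for key in f.keys():
            tensors[key] = f.get_tensor(key)

    # {model_name: (weights, ranks, target_replace_modules)}
    loras = dict_to_lora(tensors, metadata)
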
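Note on the "moved" statistic now printed before tuning: inspect_lora reports mean |up @ down| per injected module, i.e. the average magnitude of the effective weight delta. Since lora_up is zero-initialized when injected by this library, the value starts at 0 and grows as training moves the adapters. A toy illustration, with shapes and the 1/r-scaled init chosen only for the example:

    import torch

    r = 4
    up = torch.zeros(320, r)            # lora_up starts at zero
    down = torch.randn(r, 320) / r      # lora_down starts small and random
    delta_w = up @ down                 # effective weight delta
    print(delta_w.abs().mean().item())  # 0.0 before any training step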