diff --git a/olive/data/component/sd_lora/__init__.py b/olive/data/component/sd_lora/__init__.py index 862c45ce3..d3bef4acc 100644 --- a/olive/data/component/sd_lora/__init__.py +++ b/olive/data/component/sd_lora/__init__.py @@ -2,3 +2,22 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +from olive.data.component.sd_lora import ( + aspect_ratio_bucketing, + auto_caption, + auto_tagging, + dataset, + image_filtering, + image_resizing, + preprocess_chain, +) + +__all__ = [ + "aspect_ratio_bucketing", + "auto_caption", + "auto_tagging", + "dataset", + "image_filtering", + "image_resizing", + "preprocess_chain", +] diff --git a/olive/data/component/sd_lora/aspect_ratio_bucketing.py b/olive/data/component/sd_lora/aspect_ratio_bucketing.py index c6e1f0fd2..019e692e7 100644 --- a/olive/data/component/sd_lora/aspect_ratio_bucketing.py +++ b/olive/data/component/sd_lora/aspect_ratio_bucketing.py @@ -256,6 +256,81 @@ def aspect_ratio_bucketing( except Exception as e: logger.warning("Failed to process %s: %s", image_path, e) + # Process class images for DreamBooth (if present) + if hasattr(dataset, "class_image_paths") and dataset.class_image_paths: + logger.info("Processing %d class images for DreamBooth", len(dataset.class_image_paths)) + + # Prepare class images output directory + class_output_dir = None + if output_dir: + class_output_dir = Path(output_dir) / "class_images" + class_output_dir.mkdir(parents=True, exist_ok=True) + + for i, class_path in enumerate(dataset.class_image_paths): + class_path = Path(class_path) # noqa: PLW2901 + try: + with Image.open(class_path) as img: + orig_w, orig_h = img.size + orig_aspect = orig_w / orig_h + + # Find best matching bucket + best_bucket = _find_best_bucket(orig_w, orig_h, buckets) + bucket_w, bucket_h = best_bucket + + # Calculate crop coordinates + crops_coords_top_left = _calculate_crop_coords( + orig_w, orig_h, bucket_w, bucket_h, crop_to_bucket, crop_position + ) + + final_path = str(class_path) + + # Resize class image if requested + if resize_images and class_output_dir: + if class_path.suffix: + out_name = f"class_{i:06d}{class_path.suffix}" + else: + out_name = f"class_{i:06d}.jpg" + out_path = class_output_dir / out_name + + if not overwrite and out_path.exists(): + final_path = str(out_path) + else: + if img.mode != "RGB": + img = img.convert("RGB") # noqa: PLW2901 + + resized = resize_image( + img, + bucket_w, + bucket_h, + resize_mode=resize_mode, + crop_position=crop_position, + fill_color=fill_color, + resample_filter=resample_filter, + ) + + if out_path.suffix: + resized.save(out_path, quality=95) + else: + resized.save(out_path, format="JPEG", quality=95) + final_path = str(out_path) + + # Update class image path in dataset + dataset.class_image_paths[i] = Path(final_path) + + # Store bucket assignment for class image + bucket_assignments[final_path] = { + "bucket": best_bucket, + "original_size": (orig_w, orig_h), + "aspect_ratio": orig_aspect, + "crops_coords_top_left": crops_coords_top_left, + } + bucket_counts[best_bucket] += 1 + + except Exception as e: + logger.warning("Failed to process class image %s: %s", class_path, e) + + logger.info("Processed %d class images", len(dataset.class_image_paths)) + # Log bucket distribution logger.info("Bucket distribution:") for bucket, count in sorted(bucket_counts.items(), key=lambda x: -x[1])[:10]: diff --git a/olive/data/component/sd_lora/image_resizing.py 
b/olive/data/component/sd_lora/image_resizing.py index d4055883e..e8de88199 100644 --- a/olive/data/component/sd_lora/image_resizing.py +++ b/olive/data/component/sd_lora/image_resizing.py @@ -51,9 +51,12 @@ def image_resizing( """ from PIL import Image + from olive.data.component.sd_lora.utils import calculate_cover_size + # Validate resize_mode early - resize_mode = ResizeMode(resize_mode) + resize_mode_enum = ResizeMode(resize_mode) resample_filter = get_resample_filter(resample_mode) + crop_to_bucket = resize_mode_enum == ResizeMode.COVER # Prepare output directory if specified if output_dir: @@ -62,22 +65,59 @@ def image_resizing( processed_count = 0 skipped_count = 0 - - for i, item in enumerate(dataset): - image_path = Path(item["image_path"]) - - # Determine output path - if output_dir: - out_path = Path(output_dir) / image_path.name + bucket_assignments = {} + target_bucket = (target_resolution, target_resolution) + + def _calculate_crop_coords(orig_w: int, orig_h: int) -> tuple[int, int]: + """Calculate crop coordinates for SDXL time embeddings.""" + if not crop_to_bucket: + return (0, 0) + + new_w, new_h = calculate_cover_size(orig_w, orig_h, target_resolution, target_resolution) + pos = CropPosition(crop_position) + if pos == CropPosition.CENTER: + left = (new_w - target_resolution) // 2 + top = (new_h - target_resolution) // 2 + elif pos == CropPosition.TOP: + left = (new_w - target_resolution) // 2 + top = 0 + elif pos == CropPosition.BOTTOM: + left = (new_w - target_resolution) // 2 + top = new_h - target_resolution + elif pos == CropPosition.LEFT: + left = 0 + top = (new_h - target_resolution) // 2 + elif pos == CropPosition.RIGHT: + left = new_w - target_resolution + top = (new_h - target_resolution) // 2 else: - out_path = image_path + left = (new_w - target_resolution) // 2 + top = (new_h - target_resolution) // 2 - # Check if already processed - if not overwrite and out_path.exists() and out_path != image_path: - skipped_count += 1 - continue + return (top, left) + + def _process_image(image_path: Path, out_path: Path, prefix: str = "") -> Optional[str]: + """Process a single image and return the final path.""" + nonlocal processed_count, skipped_count try: + # Get original size for bucket assignment (needed even if skipping resize) + with Image.open(image_path) as img: + orig_w, orig_h = img.size + + # Check if already processed + if not overwrite and out_path.exists() and out_path != image_path: + skipped_count += 1 + # Still need to add bucket assignment for skipped files + crops_coords = _calculate_crop_coords(orig_w, orig_h) + bucket_assignments[str(out_path)] = { + "bucket": target_bucket, + "original_size": (orig_w, orig_h), + "aspect_ratio": orig_w / orig_h, + "crops_coords_top_left": crops_coords, + } + return str(out_path) + with Image.open(image_path) as img: # Convert to RGB if necessary if img.mode != "RGB": @@ -87,7 +127,7 @@ def image_resizing( img, target_resolution, target_resolution, - resize_mode=resize_mode, + resize_mode=resize_mode_enum, crop_position=crop_position, fill_color=fill_color, resample_filter=resample_filter, @@ -96,13 +136,70 @@ def image_resizing( result.save(out_path, quality=95) processed_count += 1 - # Update dataset path if output location changed - if out_path != image_path: - dataset.image_paths[i] = out_path + # Store bucket assignment + crops_coords = _calculate_crop_coords(orig_w, orig_h) + bucket_assignments[str(out_path)] = { + "bucket": target_bucket, + "original_size": (orig_w, orig_h), + "aspect_ratio": orig_w / orig_h, 
+ "crops_coords_top_left": crops_coords, + } + + return str(out_path) except Exception as e: - logger.warning("Failed to resize %s: %s", image_path, e) + logger.warning("Failed to resize %s%s: %s", prefix, image_path, e) + return None + + # Process instance images + for i, item in enumerate(dataset): + image_path = Path(item["image_path"]) + + # Determine output path + if output_dir: + out_path = Path(output_dir) / image_path.name + else: + out_path = image_path + + final_path = _process_image(image_path, out_path) + + # Update dataset path if output location changed + if final_path and out_path != image_path: + if hasattr(dataset, "set_image_path"): + dataset.set_image_path(i, out_path) + elif hasattr(dataset, "image_paths"): + dataset.image_paths[i] = out_path + + logger.info("Resized %d instance images, skipped %d", processed_count, skipped_count) + + # Process class images for DreamBooth (if present) + if hasattr(dataset, "class_image_paths") and dataset.class_image_paths: + logger.info("Processing %d class images for DreamBooth", len(dataset.class_image_paths)) + + class_processed = 0 + class_output_dir = None + if output_dir: + class_output_dir = Path(output_dir) / "class_images" + class_output_dir.mkdir(parents=True, exist_ok=True) + + for i, class_path in enumerate(dataset.class_image_paths): + class_path = Path(class_path) # noqa: PLW2901 + + if class_output_dir: + out_path = class_output_dir / f"class_{i:06d}{class_path.suffix or '.jpg'}" + else: + out_path = class_path + + final_path = _process_image(class_path, out_path, prefix="class image ") + + if final_path: + class_processed += 1 + dataset.class_image_paths[i] = Path(final_path) + + logger.info("Processed %d class images", class_processed) - logger.info("Resized %d images, skipped %d", processed_count, skipped_count) + # Store bucket assignments in dataset (for compatibility with aspect_ratio_bucketing) + dataset.bucket_assignments = bucket_assignments + dataset.buckets = [target_bucket] return dataset diff --git a/olive/data/container/image_data_container.py b/olive/data/container/image_data_container.py index 4f5213a59..5d56e4660 100644 --- a/olive/data/container/image_data_container.py +++ b/olive/data/container/image_data_container.py @@ -99,17 +99,23 @@ def _convert_hf_dataset(self, dataset, image_column: str, caption_column: Option return HuggingFaceImageDataset(dataset, image_column, caption_column) - def pre_process(self, dataset): - """Run preprocessing with HuggingFace dataset support.""" - # Check if this is a HuggingFace dataset and convert if needed + def load_dataset(self): + """Load dataset, extracting ImageDataContainer-specific params first.""" + # Pop image_column and caption_column so they don't get passed to huggingface_dataset + params = self.config.load_dataset_config.params + image_column = params.pop("image_column", "image") + caption_column = params.pop("caption_column", None) + + # Load the raw HuggingFace dataset + dataset = super().load_dataset() + + # Convert to HuggingFaceImageDataset if needed if self._is_huggingface_dataset(): - load_params = self.config.load_dataset_config.params - image_column = load_params.get("image_column", "image") - caption_column = load_params.get("caption_column") logger.info( - "Converting HuggingFace dataset: image_column=%s, caption_column=%s", image_column, caption_column + "Converting HuggingFace dataset: image_column=%s, caption_column=%s", + image_column, + caption_column, ) dataset = self._convert_hf_dataset(dataset, image_column, caption_column) - # Run the 
standard preprocessing - return super().pre_process(dataset) + return dataset diff --git a/olive/model/handler/diffusers.py b/olive/model/handler/diffusers.py index 013c97d73..f292d4ef3 100644 --- a/olive/model/handler/diffusers.py +++ b/olive/model/handler/diffusers.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union from olive.common.utils import StrEnumBase @@ -61,6 +62,9 @@ def __init__( model_attributes: Additional model attributes. """ + if not self.is_valid_model(model_path): + raise ValueError(f"The provided model_path '{model_path}' is not a valid diffusion model.") + super().__init__( framework=Framework.PYTORCH, model_file_format=ModelFileFormat.PYTORCH_ENTIRE_MODEL, @@ -73,6 +77,33 @@ def __init__( self.load_kwargs = load_kwargs or {} self._pipeline = None + @classmethod + def is_valid_model(cls, model_path: str) -> bool: + """Check if the path is a valid diffusion model. + + Diffusion models are identified by the presence of a model_index.json file. + + Args: + model_path: Local path or HuggingFace model ID. + + Returns: + True if the path points to a valid diffusion model. + + """ + # Local path + path = Path(model_path) + if path.is_dir(): + return (path / "model_index.json").exists() + + # HuggingFace model ID - try to check if model_index.json exists + try: + from huggingface_hub import hf_hub_download + + hf_hub_download(model_path, "model_index.json") + return True + except Exception: + return False + @property def adapter_path(self) -> Optional[str]: """Return the path to the LoRA adapter.""" @@ -86,8 +117,6 @@ def size_on_disk(self) -> int: Returns 0 if unable to compute (e.g., for HuggingFace Hub IDs). """ try: - from pathlib import Path - model_path = Path(self.model_path) if not model_path.exists(): # Remote model (HuggingFace Hub ID) diff --git a/olive/olive_config.json b/olive/olive_config.json index db98a8d62..77debc181 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -483,6 +483,16 @@ "supported_quantization_encodings": [ ], "dataset": "dataset" }, + "SDLoRA": { + "module_path": "olive.passes.diffusers.lora.SDLoRA", + "supported_providers": [ "*" ], + "supported_accelerators": [ "gpu" ], + "supported_precisions": [ "*" ], + "extra_dependencies": [ "sd-lora" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ], + "dataset": "dataset_required" + }, "QNNContextBinaryGenerator": { "module_path": "olive.passes.qnn.context_binary_generator.QNNContextBinaryGenerator", "supported_providers": [ "*" ], diff --git a/olive/passes/diffusers/__init__.py b/olive/passes/diffusers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/olive/passes/diffusers/lora.py b/olive/passes/diffusers/lora.py new file mode 100644 index 000000000..51af1ef29 --- /dev/null +++ b/olive/passes/diffusers/lora.py @@ -0,0 +1,1096 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import logging +import math +import os +import tempfile +from copy import deepcopy +from pathlib import Path +from typing import Optional, Union + +from olive.common.utils import StrEnumBase +from olive.data.config import DataConfig +from olive.hardware.accelerator import AcceleratorSpec +from olive.model import DiffusersModelHandler +from olive.model.handler.diffusers import DiffusersModelType +from olive.passes import Pass +from olive.passes.olive_pass import PassConfigParam +from olive.passes.pass_config import BasePassConfig + +logger = logging.getLogger(__name__) + + +class MixedPrecision(StrEnumBase): + """Mixed precision training mode.""" + + NO = "no" + FP16 = "fp16" + BF16 = "bf16" + + +class LRSchedulerType(StrEnumBase): + """Learning rate scheduler type.""" + + CONSTANT = "constant" + LINEAR = "linear" + COSINE = "cosine" + COSINE_WITH_RESTARTS = "cosine_with_restarts" + POLYNOMIAL = "polynomial" + CONSTANT_WITH_WARMUP = "constant_with_warmup" + + +class DiffusionTrainingArguments: + """Training arguments for diffusion model LoRA fine-tuning.""" + + def __init__( + self, + learning_rate: float = 1e-4, + max_train_steps: int = 1000, + train_batch_size: int = 1, + gradient_accumulation_steps: int = 4, + gradient_checkpointing: bool = True, + mixed_precision: Union[str, MixedPrecision] = MixedPrecision.BF16, + lr_scheduler: Union[str, LRSchedulerType] = LRSchedulerType.CONSTANT, + lr_warmup_steps: int = 0, + snr_gamma: Optional[float] = None, + max_grad_norm: float = 1.0, + checkpointing_steps: int = 500, + logging_steps: int = 10, + seed: Optional[int] = None, + # Flux-specific + guidance_scale: float = 3.5, + ): + self.learning_rate = learning_rate + self.max_train_steps = max_train_steps + self.train_batch_size = train_batch_size + self.gradient_accumulation_steps = gradient_accumulation_steps + self.gradient_checkpointing = gradient_checkpointing + self.mixed_precision = mixed_precision + self.lr_scheduler = lr_scheduler + self.lr_warmup_steps = lr_warmup_steps + self.snr_gamma = snr_gamma + self.max_grad_norm = max_grad_norm + self.checkpointing_steps = checkpointing_steps + self.logging_steps = logging_steps + self.seed = seed + self.guidance_scale = guidance_scale + + +class SDLoRA(Pass): + """Run LoRA fine-tuning on diffusion models. + + Supports: + - Stable Diffusion 1.5: UNet-based, CLIP text encoder + - Stable Diffusion XL: UNet-based, dual CLIP text encoders + - Flux.1: DiT-based (Transformer), CLIP + T5 text encoders + + Trains LoRA adapters on the denoising model (UNet for SD, Transformer for Flux). + """ + + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: + return { + # Model config + "model_type": PassConfigParam( + type_=DiffusersModelType, + default_value=DiffusersModelType.AUTO, + description="Model type: 'sd15', 'sdxl', 'flux', or 'auto' to detect automatically.", + ), + # LoRA config + "r": PassConfigParam( + type_=int, + default_value=16, + description="LoRA rank. Flux typically needs higher rank (16-64) than SD (4-16).", + ), + "alpha": PassConfigParam( + type_=float, + default_value=None, + description="LoRA alpha for scaling. Defaults to r if not specified.", + ), + "lora_dropout": PassConfigParam( + type_=float, + default_value=0.0, + description="Dropout probability for LoRA layers.", + ), + "target_modules": PassConfigParam( + type_=list[str], + default_value=None, + description=( + "Target modules for LoRA. 
Defaults depend on model type:\n" + "- SD/SDXL: ['to_k', 'to_q', 'to_v', 'to_out.0']\n" + "- Flux: ['to_k', 'to_q', 'to_v', 'to_out.0', 'add_k_proj', 'add_q_proj', 'add_v_proj']" + ), + ), + # Data config + "train_data_config": PassConfigParam( + type_=Union[DataConfig, dict], + required=True, + description="Data config for training dataset.", + ), + # Training config + "training_args": PassConfigParam( + type_=Union[DiffusionTrainingArguments, dict], + default_value=None, + description="Training arguments. See DiffusionTrainingArguments for options.", + ), + # Output config + "merge_lora": PassConfigParam( + type_=bool, + default_value=False, + description=( + "Merge LoRA weights into base model and save merged model. " + "If False, only saves LoRA adapter weights." + ), + ), + # DreamBooth config + "dreambooth": PassConfigParam( + type_=bool, + default_value=False, + description=( + "Enable DreamBooth training with prior preservation loss. " + "Use this when training on a specific subject (e.g., your dog, your face). " + "Requires dataset to include 'class_image_path' and 'class_caption' fields." + ), + ), + "prior_loss_weight": PassConfigParam( + type_=float, + default_value=1.0, + description="Weight of the prior preservation loss (only used when dreambooth=True).", + ), + } + + def _run_for_config( + self, model: DiffusersModelHandler, config: BasePassConfig, output_model_path: str + ) -> DiffusersModelHandler: + """Run diffusion model LoRA training.""" + # Initialize training args + if config.training_args is None: + training_args = DiffusionTrainingArguments() + elif isinstance(config.training_args, dict): + training_args = DiffusionTrainingArguments(**config.training_args) + else: + training_args = config.training_args + + # Detect model type + model_type = self._detect_model_type(model, config) + logger.info("Detected model type: %s", model_type) + + # Route to appropriate training method + if model_type == DiffusersModelType.FLUX: + return self._train_flux(model, config, training_args, output_model_path) + else: + return self._train_sd(model, config, training_args, model_type, output_model_path) + + def _train_sd( + self, + model: DiffusersModelHandler, + config: BasePassConfig, + training_args: DiffusionTrainingArguments, + model_type: DiffusersModelType, + output_model_path: str, + ) -> DiffusersModelHandler: + """Train LoRA for Stable Diffusion (SD1.5/SDXL).""" + import torch + import torch.nn.functional as F + from accelerate import Accelerator + from accelerate.utils import ProjectConfiguration, set_seed + from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel + from diffusers.optimization import get_scheduler + from diffusers.training_utils import compute_snr + from peft import LoraConfig + from tqdm.auto import tqdm + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + # Setup accelerator + with tempfile.TemporaryDirectory(prefix="olive_sd_lora_") as temp_dir: + project_config = ProjectConfiguration( + project_dir=temp_dir, + logging_dir=os.path.join(temp_dir, "logs"), + ) + accelerator = Accelerator( + gradient_accumulation_steps=training_args.gradient_accumulation_steps, + mixed_precision=training_args.mixed_precision, + project_config=project_config, + ) + + if training_args.seed is not None: + set_seed(training_args.seed) + + # Load models + model_path = model.model_path + logger.info("Loading SD models from %s", model_path) + + # Load text encoders (frozen, for encoding prompts only) + text_encoder = 
CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder") + text_encoder_2 = None + if model_type == DiffusersModelType.SDXL: + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_path, subfolder="text_encoder_2") + + # Load VAE and UNet + vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet") + + # Load noise scheduler + noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler") + + # Set weight dtype + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Freeze all base models and move to device + vae.requires_grad_(False) + vae.to(accelerator.device, dtype=weight_dtype) + + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if text_encoder_2: + text_encoder_2.requires_grad_(False) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + + # Setup LoRA for UNet only (after moving to device) + lora_alpha = config.alpha if config.alpha is not None else config.r + target_modules = config.target_modules or ["to_k", "to_q", "to_v", "to_out.0"] + + unet_lora_config = LoraConfig( + r=config.r, + lora_alpha=lora_alpha, + lora_dropout=config.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + unet.add_adapter(unet_lora_config) + + # LoRA trainable parameters should be fp32 for stable training with mixed precision + for param in unet.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.float32) + + # Log trainable parameters + trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in unet.parameters()) + logger.info( + "UNet trainable parameters: %d / %d (%.2f%%)", + trainable_params, + total_params, + 100 * trainable_params / total_params, + ) + + # Enable gradient checkpointing + if training_args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Load dataset + train_dataset = self._load_dataset(config) + train_dataloader = self._create_dataloader( + train_dataset, training_args, model_path, model_type, prior_preservation=config.dreambooth + ) + + # Calculate training steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) + num_train_epochs = math.ceil(training_args.max_train_steps / num_update_steps_per_epoch) + + # Setup optimizer (UNet only) + params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=training_args.learning_rate, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-8, + ) + + # Setup LR scheduler + lr_scheduler = get_scheduler( + training_args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=training_args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=training_args.max_train_steps * accelerator.num_processes, + ) + + # Prepare with accelerator + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # Training loop + logger.info("***** Running SD LoRA training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num epochs = %d", num_train_epochs) + logger.info(" Batch size = %d", 
training_args.train_batch_size) + logger.info(" Gradient accumulation steps = %d", training_args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", training_args.max_train_steps) + if config.dreambooth: + logger.info(" Prior preservation enabled with weight = %.2f", config.prior_loss_weight) + + global_step = 0 + progress_bar = tqdm( + range(training_args.max_train_steps), + disable=not accelerator.is_local_main_process, + desc="Training", + ) + + for _ in range(num_train_epochs): + unet.train() + + for _, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Encode images to latents + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=weight_dtype) + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise and timesteps + noise = torch.randn_like(latents) + bsz = latents.shape[0] + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + + # Add noise + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get text embeddings (frozen) + with torch.no_grad(): + if model_type == DiffusersModelType.SDXL: + encoder_hidden_states, pooled = self._encode_prompt_sdxl( + batch, text_encoder, text_encoder_2 + ) + add_time_ids = self._compute_time_ids_batch(batch, latents.device, weight_dtype) + added_cond_kwargs = {"text_embeds": pooled, "time_ids": add_time_ids} + else: + encoder_hidden_states = self._encode_prompt_sd15(batch, text_encoder) + added_cond_kwargs = None + + # UNet forward (with gradients for LoRA training) + # Cast to weight_dtype to match UNet's expected input dtype + noisy_latents = noisy_latents.to(dtype=weight_dtype) + encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) + + if model_type == DiffusersModelType.SDXL: + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + else: + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] + + # Compute loss target + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}") + + # Batch is [instance_images, class_images] concatenated + if config.dreambooth: + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Instance loss + if training_args.snr_gamma is not None: + # SNR weighting only on instance portion + instance_timesteps = timesteps[: len(timesteps) // 2] + snr = compute_snr(noise_scheduler, instance_timesteps) + mse_loss_weights = torch.stack( + [snr, torch.full_like(snr, training_args.snr_gamma)], dim=1 + ).min(dim=1)[0] + if noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = mse_loss_weights + 1 + mse_loss_weights = mse_loss_weights / snr + instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + instance_loss = ( + instance_loss.mean(dim=list(range(1, len(instance_loss.shape)))) * mse_loss_weights + ) + instance_loss = instance_loss.mean() + else: + instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), 
target_prior.float(), reduction="mean") + + # Combined loss + loss = instance_loss + config.prior_loss_weight * prior_loss + else: + # Standard LoRA training without prior preservation + if training_args.snr_gamma is not None: + snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack( + [snr, torch.full_like(snr, training_args.snr_gamma)], dim=1 + ).min(dim=1)[0] + if noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = mse_loss_weights + 1 + mse_loss_weights = mse_loss_weights / snr + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(params_to_optimize, training_args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % training_args.logging_steps == 0: + logger.info( + "Step %d: loss=%.4f, lr=%.6f", + global_step, + loss.detach().item(), + lr_scheduler.get_last_lr()[0], + ) + + if global_step % training_args.checkpointing_steps == 0: + save_path = os.path.join(temp_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + + if global_step >= training_args.max_train_steps: + break + + # Save final model + accelerator.wait_for_everyone() + + # Save adapter weights + adapter_path = Path(output_model_path) / "adapter" + + if accelerator.is_main_process: + adapter_path.mkdir(parents=True, exist_ok=True) + + if config.merge_lora: + # Merge LoRA and save full UNet + unet_unwrapped = accelerator.unwrap_model(unet) + unet_merged = unet_unwrapped.merge_and_unload() + unet_merged.save_pretrained(adapter_path) + logger.info("Saved merged UNet to %s", adapter_path) + else: + # Save LoRA adapter in diffusers-compatible format + unet_unwrapped = accelerator.unwrap_model(unet) + unet_unwrapped = unet_unwrapped.to(torch.float32) + + from peft import get_peft_model_state_dict + + unet_lora_state_dict = get_peft_model_state_dict(unet_unwrapped) + + # Use appropriate pipeline class based on model type + if model_type == DiffusersModelType.SDXL: + from diffusers import StableDiffusionXLPipeline + + StableDiffusionXLPipeline.save_lora_weights( + save_directory=str(adapter_path), + unet_lora_layers=unet_lora_state_dict, + safe_serialization=True, + ) + else: + from diffusers import StableDiffusionPipeline + + StableDiffusionPipeline.save_lora_weights( + save_directory=str(adapter_path), + unet_lora_layers=unet_lora_state_dict, + safe_serialization=True, + ) + logger.info("Saved UNet LoRA to %s", adapter_path) + + accelerator.end_training() + + # Clean up + del unet + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Return model handler + output_model = deepcopy(model) + output_model.set_resource("adapter_path", str(adapter_path)) + return output_model + + def _train_flux( + self, + model: DiffusersModelHandler, + config: BasePassConfig, + training_args: DiffusionTrainingArguments, + output_model_path: str, + ) -> DiffusersModelHandler: + """Train LoRA for Flux models.""" + import torch + import torch.nn.functional as F + from accelerate import Accelerator + from accelerate.utils import ProjectConfiguration, set_seed + from diffusers import AutoencoderKL, FluxTransformer2DModel + from 
diffusers.optimization import get_scheduler + from peft import LoraConfig + from tqdm.auto import tqdm + from transformers import CLIPTextModel, T5EncoderModel + + # Flux requires bfloat16 + if training_args.mixed_precision == "fp16": + logger.warning("Flux requires bfloat16, switching from fp16") + training_args.mixed_precision = "bf16" + + with tempfile.TemporaryDirectory(prefix="olive_flux_lora_") as temp_dir: + project_config = ProjectConfiguration( + project_dir=temp_dir, + logging_dir=os.path.join(temp_dir, "logs"), + ) + accelerator = Accelerator( + gradient_accumulation_steps=training_args.gradient_accumulation_steps, + mixed_precision=training_args.mixed_precision, + project_config=project_config, + ) + + if training_args.seed is not None: + set_seed(training_args.seed) + + model_path = model.model_path + logger.info("Loading Flux models from %s", model_path) + + # Set weight dtype (Flux needs bfloat16) + weight_dtype = torch.bfloat16 + + # Load text encoders (frozen, for encoding prompts only) + text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder") + text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2") + + # Load VAE and Transformer + vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae") + transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer") + + # Freeze all base models and move to device + vae.requires_grad_(False) + vae.to(accelerator.device, dtype=weight_dtype) + + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + text_encoder_2.requires_grad_(False) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) + + transformer.requires_grad_(False) + transformer.to(accelerator.device, dtype=weight_dtype) + + # Setup LoRA for transformer only (after moving to device) + lora_alpha = config.alpha if config.alpha is not None else config.r + target_modules = config.target_modules or [ + "to_k", + "to_q", + "to_v", + "to_out.0", + "add_k_proj", + "add_q_proj", + "add_v_proj", + "to_add_out", + ] + + transformer_lora_config = LoraConfig( + r=config.r, + lora_alpha=lora_alpha, + lora_dropout=config.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + # LoRA trainable parameters should be fp32 for stable training with mixed precision + for param in transformer.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.float32) + + # Log trainable parameters + trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in transformer.parameters()) + logger.info( + "Flux Transformer trainable parameters: %d / %d (%.2f%%)", + trainable_params, + total_params, + 100 * trainable_params / total_params, + ) + + # Enable gradient checkpointing + if training_args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # Load dataset + train_dataset = self._load_dataset(config) + train_dataloader = self._create_dataloader( + train_dataset, + training_args, + model_path, + DiffusersModelType.FLUX, + prior_preservation=config.dreambooth, + ) + + # Calculate training steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) + num_train_epochs = math.ceil(training_args.max_train_steps / num_update_steps_per_epoch) + + # Setup optimizer (transformer only) + params_to_optimize = list(filter(lambda p: 
p.requires_grad, transformer.parameters())) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=training_args.learning_rate, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-8, + ) + + # Setup LR scheduler + lr_scheduler = get_scheduler( + training_args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=training_args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=training_args.max_train_steps * accelerator.num_processes, + ) + + # Prepare with accelerator + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # Training loop + logger.info("***** Running Flux LoRA training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num epochs = %d", num_train_epochs) + logger.info(" Batch size = %d", training_args.train_batch_size) + logger.info(" Gradient accumulation steps = %d", training_args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", training_args.max_train_steps) + if config.dreambooth: + logger.info(" Prior preservation enabled with weight = %.2f", config.prior_loss_weight) + + global_step = 0 + progress_bar = tqdm( + range(training_args.max_train_steps), + disable=not accelerator.is_local_main_process, + desc="Training Flux", + ) + + for _ in range(num_train_epochs): + transformer.train() + + for _, batch in enumerate(train_dataloader): + with accelerator.accumulate(transformer): + # Encode images to latents + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=weight_dtype) + latents = vae.encode(pixel_values).latent_dist.sample() + latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor + + # Save latent dimensions before packing (needed for image IDs) + batch_size, _channels, latent_height, latent_width = latents.shape + + # Pack latents for Flux + latents = self._pack_latents(latents) + + # Sample noise and timesteps (flow matching) + noise = torch.randn_like(latents) + + # Flux uses continuous timesteps in [0, 1] + u = torch.rand(batch_size, device=latents.device, dtype=weight_dtype) + timesteps = (u * 1000).to(latents.device) + + # Flow matching: sigmas = t + sigmas = self._get_sigmas(timesteps, latents.ndim, latents.dtype, latents.device) + noisy_latents = sigmas * noise + (1.0 - sigmas) * latents + + # Get text embeddings (frozen) + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = self._encode_prompt_flux( + batch, text_encoder, text_encoder_2 + ) + # Get latent image IDs + latent_image_ids = self._prepare_latent_image_ids( + batch_size, latent_height // 2, latent_width // 2, latents.device, weight_dtype + ) + + # Transformer forward (with gradients for LoRA training) + model_pred = transformer( + hidden_states=noisy_latents, + timestep=timesteps / 1000, + guidance=torch.full( + (batch_size,), training_args.guidance_scale, device=latents.device, dtype=weight_dtype + ), + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + # Flow matching target: velocity = noise - data + target = noise - latents + + # Prior preservation: split predictions and targets (official HuggingFace approach) + if config.dreambooth: + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Instance loss + instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # 
Prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Combined loss + loss = instance_loss + config.prior_loss_weight * prior_loss + else: + # Standard LoRA training + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(params_to_optimize, training_args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % training_args.logging_steps == 0: + logger.info( + "Step %d: loss=%.4f, lr=%.6f", + global_step, + loss.detach().item(), + lr_scheduler.get_last_lr()[0], + ) + + if global_step % training_args.checkpointing_steps == 0: + save_path = os.path.join(temp_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + + if global_step >= training_args.max_train_steps: + break + + # Save final model + accelerator.wait_for_everyone() + output_path = Path(output_model_path) + + # Save adapter weights + adapter_path = output_path / "adapter" + + if accelerator.is_main_process: + adapter_path.mkdir(parents=True, exist_ok=True) + + if config.merge_lora: + # Merge LoRA and save full transformer + transformer_unwrapped = accelerator.unwrap_model(transformer) + transformer_merged = transformer_unwrapped.merge_and_unload() + transformer_merged.save_pretrained(adapter_path) + logger.info("Saved merged Transformer to %s", adapter_path) + else: + # Save LoRA adapter in diffusers-compatible format + transformer_unwrapped = accelerator.unwrap_model(transformer) + transformer_unwrapped = transformer_unwrapped.to(torch.float32) + + from diffusers import FluxPipeline + from peft import get_peft_model_state_dict + + transformer_lora_state_dict = get_peft_model_state_dict(transformer_unwrapped) + + FluxPipeline.save_lora_weights( + save_directory=str(adapter_path), + transformer_lora_layers=transformer_lora_state_dict, + safe_serialization=True, + ) + logger.info("Saved Flux Transformer LoRA to %s", adapter_path) + + accelerator.end_training() + + # Clean up + del transformer + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Return model handler + output_model = deepcopy(model) + output_model.set_resource("adapter_path", str(adapter_path)) + return output_model + + def _detect_model_type(self, model: DiffusersModelHandler, config: BasePassConfig) -> DiffusersModelType: + """Detect the model type.""" + if config.model_type != DiffusersModelType.AUTO: + return config.model_type + + return model.detected_model_type + + def _load_dataset(self, config): + """Load training dataset. + + Returns the dataset with bucket_assignments preserved for SDXL time embeddings. + """ + data_config = config.train_data_config + if isinstance(data_config, dict): + data_config = DataConfig(**data_config) + + # Load and preprocess dataset through container + data_container = data_config.to_data_container() + return data_container.pre_process(data_container.load_dataset()) + + def _create_dataloader(self, dataset, training_args, model_path, model_type, prior_preservation=False): + """Create training dataloader with image loading and tokenization. + + Args: + dataset: Dataset with image_path, caption, and bucket_assignments. + For DreamBooth, also expects class_image_path and class_caption. + training_args: Training arguments. + model_path: Path to the diffusion model for loading tokenizers. 
+ model_type: Type of diffusion model (SD15, SDXL, FLUX). + prior_preservation: Whether to include class images for prior preservation. + + Raises: + ValueError: If dataset has not been preprocessed with aspect_ratio_bucketing or image_resizing. + + """ + import numpy as np + import torch + from PIL import Image + from transformers import AutoTokenizer + + # Validate that dataset has been preprocessed + bucket_assignments = getattr(dataset, "bucket_assignments", None) + if not bucket_assignments: + raise ValueError( + "Dataset must be preprocessed with 'aspect_ratio_bucketing' or 'image_resizing'. " + "Please configure train_data_config with appropriate preprocessing steps. " + "Example: {'pre_process_config': {'name': 'aspect_ratio_bucketing', 'params': {'base_resolution': 1024}}}" + ) + + # Load tokenizers based on model type + tokenizers = {} + if model_type == DiffusersModelType.FLUX: + tokenizers["clip"] = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") + tokenizers["t5"] = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer_2") + elif model_type == DiffusersModelType.SDXL: + tokenizers["one"] = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") + tokenizers["two"] = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer_2") + else: # SD15 + tokenizers["main"] = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") + + def process_image(image_path): + """Load and process a single image. + + Raises: + ValueError: If image_path is not found in bucket_assignments. + + """ + if image_path not in bucket_assignments: + raise ValueError( + f"Image '{image_path}' not found in bucket_assignments. " + "All images (including class images) must be preprocessed." + ) + + assignment = bucket_assignments[image_path] + original_size = assignment["original_size"] + bucket_size = assignment["bucket"] + crops_coords = assignment["crops_coords_top_left"] + + # Load image - should already be resized during preprocessing + img = Image.open(image_path).convert("RGB") + + # Verify image size matches bucket assignment + bucket_w, bucket_h = bucket_size + if img.size != (bucket_w, bucket_h): + raise ValueError( + f"Image '{image_path}' size {img.size} does not match bucket assignment {bucket_size}. " + "Please ensure images are resized during preprocessing (set resize_images=True)." 
+ ) + + img_array = np.array(img, dtype=np.float32) / 127.5 - 1.0 + pixel_tensor = torch.from_numpy(img_array.transpose(2, 0, 1)) + + return pixel_tensor, original_size, bucket_size, crops_coords + + def tokenize_captions(captions_list): + """Tokenize a list of captions.""" + result = {} + for name, tok in tokenizers.items(): + max_len = 512 if name == "t5" else 77 + tokens = tok( + captions_list, padding="max_length", max_length=max_len, truncation=True, return_tensors="pt" + ) + key_map = { + "main": "input_ids", + "one": "input_ids_one", + "two": "input_ids_two", + "clip": "input_ids_one", + "t5": "input_ids_t5", + } + result[key_map[name]] = tokens.input_ids + return result + + def collate_fn(examples): + # Instance images + pixel_values = [] + captions = [] + original_sizes = [] + target_sizes = [] + crops_coords = [] + + for ex in examples: + # Process instance image + image_path = ex.get("image_path", "") + pixel_tensor, orig_size, tgt_size, crop_coords = process_image(image_path) + pixel_values.append(pixel_tensor) + original_sizes.append(orig_size) + target_sizes.append(tgt_size) + crops_coords.append(crop_coords) + captions.append(ex.get("caption", "")) + + # For DreamBooth with prior preservation: concatenate class images after instance images + # This follows the official HuggingFace approach - single forward pass, then torch.chunk + if prior_preservation: + class_image_count = 0 + for ex in examples: + class_image_path = ex.get("class_image_path", "") + if class_image_path: + class_image_count += 1 + # Process class image (must be in bucket_assignments after preprocessing) + pixel_tensor, orig_size, tgt_size, crop_coords = process_image(class_image_path) + pixel_values.append(pixel_tensor) + original_sizes.append(orig_size) + target_sizes.append(tgt_size) + crops_coords.append(crop_coords) + captions.append(ex.get("class_caption", "")) + + if class_image_count == 0: + logger.warning( + "prior_preservation is enabled but no class images found in batch. " + "Ensure dataset includes 'class_image_path' and 'class_caption' fields." + ) + elif class_image_count != len(examples): + logger.warning( + "prior_preservation: only %d/%d examples have class images. 
" + "All examples should have class images for proper prior preservation.", + class_image_count, + len(examples), + ) + + result = {"pixel_values": torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()} + + # Tokenize all captions (instance + class) + tokens = tokenize_captions(captions) + result.update(tokens) + + # Add size info for SDXL + if model_type == DiffusersModelType.SDXL: + result["original_sizes"] = original_sizes + result["target_sizes"] = target_sizes + result["crops_coords_top_left"] = crops_coords + + return result + + # Use bucket batch sampler to ensure images in each batch have the same dimensions + from olive.data.component.sd_lora.dataloader import BucketBatchSampler + + batch_sampler = BucketBatchSampler( + dataset, + batch_size=training_args.train_batch_size, + drop_last=False, + shuffle=True, + seed=training_args.seed, + ) + + return torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=0, + pin_memory=True, + ) + + def _encode_prompt_sd15(self, batch, text_encoder): + """Encode prompts for SD 1.5 (frozen text encoder).""" + input_ids = batch["input_ids"].to(text_encoder.device) + return text_encoder(input_ids, return_dict=False)[0] + + def _encode_prompt_sdxl(self, batch, text_encoder, text_encoder_2): + """Encode prompts for SDXL (frozen text encoders).""" + import torch + + input_ids_one = batch["input_ids_one"].to(text_encoder.device) + input_ids_two = batch["input_ids_two"].to(text_encoder_2.device) + + encoder_hidden_states = text_encoder(input_ids_one, output_hidden_states=True).hidden_states[-2] + encoder_output_2 = text_encoder_2(input_ids_two, output_hidden_states=True) + pooled = encoder_output_2[0] + encoder_hidden_states_2 = encoder_output_2.hidden_states[-2] + + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_2], dim=-1) + return encoder_hidden_states, pooled + + def _encode_prompt_flux(self, batch, text_encoder, text_encoder_2): + """Encode prompts for Flux (frozen text encoders).""" + import torch + + input_ids_clip = batch["input_ids_one"].to(text_encoder.device) + input_ids_t5 = batch.get("input_ids_t5", batch.get("input_ids_two")).to(text_encoder_2.device) + + # CLIP encoder for pooled embeddings + clip_output = text_encoder(input_ids_clip, output_hidden_states=True) + pooled_prompt_embeds = clip_output.pooler_output + + # T5 encoder for main text embeddings + t5_output = text_encoder_2(input_ids_t5) + prompt_embeds = t5_output[0] + + # Create text IDs + text_ids = torch.zeros(prompt_embeds.shape[0], prompt_embeds.shape[1], 3, device=prompt_embeds.device) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def _compute_time_ids_batch(self, batch, device, dtype): + """Compute SDXL time IDs for a batch.""" + import torch + + add_time_ids_list = [] + for i in range(len(batch["original_sizes"])): + original_size = batch["original_sizes"][i] + target_size = batch["target_sizes"][i] + crops_coords = batch["crops_coords_top_left"][i] + add_time_ids = list(original_size + crops_coords + target_size) + add_time_ids_list.append(add_time_ids) + + return torch.tensor(add_time_ids_list, dtype=dtype, device=device) + + def _pack_latents(self, latents): + """Pack latents for Flux (reshape to sequence format).""" + batch_size, channels, height, width = latents.shape + latents = latents.view(batch_size, channels, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + return latents.reshape(batch_size, (height // 2) * (width 
// 2), channels * 4) + + def _get_sigmas(self, timesteps, n_dim, dtype, device): + """Get sigmas for flow matching.""" + sigmas = timesteps / 1000.0 + while len(sigmas.shape) < n_dim: + sigmas = sigmas.unsqueeze(-1) + return sigmas.to(dtype=dtype, device=device) + + def _prepare_latent_image_ids(self, batch_size, height, width, device, dtype): + """Prepare latent image IDs for Flux.""" + import torch + + latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype) + latent_image_ids[..., 1] = torch.arange(height, device=device, dtype=dtype)[:, None] + latent_image_ids[..., 2] = torch.arange(width, device=device, dtype=dtype)[None, :] + latent_image_ids = latent_image_ids.reshape(height * width, 3) + return latent_image_ids.unsqueeze(0).expand(batch_size, -1, -1) diff --git a/test/model/test_diffusers_model.py b/test/model/test_diffusers_model.py index b75d27f5b..76fae9c63 100644 --- a/test/model/test_diffusers_model.py +++ b/test/model/test_diffusers_model.py @@ -52,6 +52,9 @@ def test_detected_model_type_explicit(self, model_type, expected): model = DiffusersModelHandler(model_path=self.model_path, model_type=model_type) assert model.detected_model_type == expected + @patch("olive.model.handler.diffusers.DiffusersModelHandler.is_valid_model", return_value=True) + @patch("diffusers.UNet2DConditionModel.load_config", side_effect=Exception("not found")) + @patch("diffusers.FluxTransformer2DModel.load_config", side_effect=Exception("not found")) @pytest.mark.parametrize( ("model_path", "expected"), [ @@ -63,15 +66,14 @@ def test_detected_model_type_explicit(self, model_type, expected): ("my-custom-sd-model", DiffusersModelType.SD15), ], ) - @patch("diffusers.UNet2DConditionModel.load_config", side_effect=Exception("not found")) - @patch("diffusers.FluxTransformer2DModel.load_config", side_effect=Exception("not found")) - def test_detected_model_type_auto_from_path(self, mock_flux, mock_unet, model_path, expected): + def test_detected_model_type_auto_from_path(self, mock_flux, mock_unet, mock_is_valid, model_path, expected): model = DiffusersModelHandler(model_path=model_path, model_type=DiffusersModelType.AUTO) assert model.detected_model_type == expected + @patch("olive.model.handler.diffusers.DiffusersModelHandler.is_valid_model", return_value=True) @patch("diffusers.UNet2DConditionModel.load_config", side_effect=Exception("not found")) @patch("diffusers.FluxTransformer2DModel.load_config", side_effect=Exception("not found")) - def test_detected_model_type_auto_raises_error(self, mock_flux, mock_unet): + def test_detected_model_type_auto_raises_error(self, mock_flux, mock_unet, mock_is_valid): model = DiffusersModelHandler(model_path="some-random-model", model_type=DiffusersModelType.AUTO) with pytest.raises(ValueError, match="Cannot detect model type"): _ = model.detected_model_type diff --git a/test/passes/diffusers/__init__.py b/test/passes/diffusers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/passes/diffusers/conftest.py b/test/passes/diffusers/conftest.py new file mode 100644 index 000000000..cc183a4e4 --- /dev/null +++ b/test/passes/diffusers/conftest.py @@ -0,0 +1,102 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from unittest.mock import MagicMock + +import pytest +import torch + +from olive.model import DiffusersModelHandler + + +@pytest.fixture +def test_image_folder(tmp_path): + """Create a test image folder with images and captions.""" + from PIL import Image + + data_dir = tmp_path / "train_images" + data_dir.mkdir(parents=True, exist_ok=True) + + for i in range(4): + img = Image.new("RGB", (64, 64), color=(i * 50, i * 50, i * 50)) + img.save(data_dir / f"image_{i}.png") + (data_dir / f"image_{i}.txt").write_text(f"a test image {i}") + + return str(data_dir) + + +@pytest.fixture +def output_folder(tmp_path): + """Create output folder.""" + folder = tmp_path / "output" + folder.mkdir() + return str(folder) + + +@pytest.fixture +def mock_accelerator(): + """Create a mock accelerator.""" + mock_acc = MagicMock() + mock_acc.device = "cpu" + mock_acc.mixed_precision = "no" + mock_acc.num_processes = 1 + mock_acc.is_main_process = True + mock_acc.is_local_main_process = True + mock_acc.sync_gradients = True + mock_acc.gradient_accumulation_steps = 1 + mock_acc.prepare.side_effect = lambda *args: args + mock_acc.backward = MagicMock() + mock_acc.clip_grad_norm_ = MagicMock() + mock_acc.unwrap_model = lambda x: x + mock_acc.wait_for_everyone = MagicMock() + mock_acc.end_training = MagicMock() + mock_acc.save_state = MagicMock() + mock_acc.accumulate = MagicMock(return_value=MagicMock(__enter__=MagicMock(), __exit__=MagicMock())) + return mock_acc + + +@pytest.fixture +def mock_torch_model(): + """Create a mock torch model with parameters.""" + mock_model = MagicMock() + param = torch.nn.Parameter(torch.randn(4, 4)) + mock_model.parameters.return_value = [param] + mock_model.named_parameters.return_value = [("weight", param)] + mock_model.requires_grad_ = MagicMock(return_value=mock_model) + mock_model.to = MagicMock(return_value=mock_model) + mock_model.train = MagicMock(return_value=mock_model) + mock_model.eval = MagicMock(return_value=mock_model) + mock_model.config = MagicMock() + mock_model.config.in_channels = 4 + return mock_model + + +@pytest.fixture +def mock_input_model_sd15(): + """Create mock input model for SD 1.5.""" + model = MagicMock(spec=DiffusersModelHandler) + model.model_path = "runwayml/stable-diffusion-v1-5" + model.model_attributes = {} + model.get_resource.return_value = model.model_path + return model + + +@pytest.fixture +def mock_input_model_sdxl(): + """Create mock input model for SDXL.""" + model = MagicMock(spec=DiffusersModelHandler) + model.model_path = "stabilityai/stable-diffusion-xl-base-1.0" + model.model_attributes = {} + model.get_resource.return_value = model.model_path + return model + + +@pytest.fixture +def mock_input_model_flux(): + """Create mock input model for Flux.""" + model = MagicMock(spec=DiffusersModelHandler) + model.model_path = "black-forest-labs/FLUX.1-dev" + model.model_attributes = {} + model.get_resource.return_value = model.model_path + return model diff --git a/test/passes/diffusers/test_lora.py b/test/passes/diffusers/test_lora.py new file mode 100644 index 000000000..f8d24081f --- /dev/null +++ b/test/passes/diffusers/test_lora.py @@ -0,0 +1,387 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from unittest.mock import MagicMock, patch + +import torch + +from olive.model.handler.diffusers import DiffusersModelType +from olive.passes.diffusers.lora import SDLoRA +from olive.passes.olive_pass import create_pass_from_dict + +# Constants +SD15_SCALING_FACTOR = 0.18215 +SDXL_SCALING_FACTOR = 0.13025 +DEFAULT_TRAINING_ARGS = {"max_train_steps": 1, "train_batch_size": 1} + + +def get_pass_config(data_dir, **kwargs): + return { + "train_data_config": { + "name": "test_data", + "type": "ImageDataContainer", + "load_dataset_config": { + "type": "image_folder_dataset", + "params": {"data_dir": data_dir}, + }, + }, + **kwargs, + } + + +def setup_tokenizer_mock(mock_auto_tokenizer): + mock_tokenizer = MagicMock() + mock_tokenizer.model_max_length = 77 + mock_token_output = MagicMock() + mock_token_output.input_ids = torch.ones(1, 77, dtype=torch.long) + mock_tokenizer.return_value = mock_token_output + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + return mock_tokenizer + + +def setup_sd_mocks( + mock_unet, + mock_vae, + mock_clip, + mock_scheduler, + mock_get_peft_model, + mock_torch_model, + scaling_factor=SD15_SCALING_FACTOR, +): + # UNet + mock_unet.from_pretrained.return_value = mock_torch_model + mock_get_peft_model.return_value = mock_torch_model + + # VAE + mock_vae_model = MagicMock() + mock_vae_model.config = MagicMock() + mock_vae_model.config.scaling_factor = scaling_factor + mock_vae_model.requires_grad_ = MagicMock(return_value=mock_vae_model) + mock_vae_model.to = MagicMock(return_value=mock_vae_model) + mock_latent_dist = MagicMock() + mock_latent_dist.sample.return_value = torch.randn(1, 4, 8, 8) + mock_vae_model.encode.return_value = mock_latent_dist + mock_vae.from_pretrained.return_value = mock_vae_model + + # CLIP + mock_clip_model = MagicMock() + mock_clip_model.return_value = (torch.randn(1, 77, 768),) + mock_clip_model.requires_grad_ = MagicMock(return_value=mock_clip_model) + mock_clip_model.to = MagicMock(return_value=mock_clip_model) + mock_clip.from_pretrained.return_value = mock_clip_model + + # Noise scheduler + mock_noise_scheduler = MagicMock() + mock_noise_scheduler.config.num_train_timesteps = 1000 + mock_noise_scheduler.config.prediction_type = "epsilon" + mock_noise_scheduler.add_noise.return_value = torch.randn(1, 4, 8, 8) + mock_scheduler.from_pretrained.return_value = mock_noise_scheduler + + return mock_vae_model, mock_clip_model + + +@patch("diffusers.StableDiffusionPipeline.save_lora_weights") +@patch("peft.get_peft_model") +@patch("peft.LoraConfig") +@patch("diffusers.DDPMScheduler") +@patch("diffusers.UNet2DConditionModel") +@patch("diffusers.AutoencoderKL") +@patch("transformers.CLIPTextModel") +@patch("transformers.AutoTokenizer") +@patch("accelerate.Accelerator") +@patch("diffusers.optimization.get_scheduler") +def test_sd_lora_train_sd15( + mock_get_scheduler, + mock_accelerator_cls, + mock_auto_tokenizer, + mock_clip, + mock_vae, + mock_unet, + mock_scheduler, + mock_lora_config, + mock_get_peft_model, + mock_save_lora, + test_image_folder, + output_folder, + mock_accelerator, + mock_torch_model, + mock_input_model_sd15, +): + mock_accelerator_cls.return_value = mock_accelerator + mock_get_scheduler.return_value = MagicMock() + setup_tokenizer_mock(mock_auto_tokenizer) + setup_sd_mocks(mock_unet, mock_vae, mock_clip, mock_scheduler, mock_get_peft_model, mock_torch_model) + + config = get_pass_config( + test_image_folder, + 
model_type=DiffusersModelType.SD15, + training_args=DEFAULT_TRAINING_ARGS, + ) + p = create_pass_from_dict(SDLoRA, config, disable_search=True) + result = p.run(mock_input_model_sd15, output_folder) + + assert result is not None + mock_unet.from_pretrained.assert_called_once() + mock_vae.from_pretrained.assert_called_once() + mock_clip.from_pretrained.assert_called_once() + + +@patch("peft.get_peft_model") +@patch("peft.LoraConfig") +@patch("diffusers.DDPMScheduler") +@patch("diffusers.UNet2DConditionModel") +@patch("diffusers.AutoencoderKL") +@patch("transformers.CLIPTextModel") +@patch("transformers.AutoTokenizer") +@patch("accelerate.Accelerator") +@patch("diffusers.optimization.get_scheduler") +def test_sd_lora_merge_lora( + mock_get_scheduler, + mock_accelerator_cls, + mock_auto_tokenizer, + mock_clip, + mock_vae, + mock_unet, + mock_scheduler, + mock_lora_config, + mock_get_peft_model, + test_image_folder, + output_folder, + mock_accelerator, + mock_torch_model, + mock_input_model_sd15, +): + mock_accelerator_cls.return_value = mock_accelerator + mock_get_scheduler.return_value = MagicMock() + setup_tokenizer_mock(mock_auto_tokenizer) + + # Add merge-specific mocks + mock_torch_model.merge_and_unload = MagicMock(return_value=mock_torch_model) + mock_torch_model.save_pretrained = MagicMock() + + setup_sd_mocks(mock_unet, mock_vae, mock_clip, mock_scheduler, mock_get_peft_model, mock_torch_model) + + config = get_pass_config( + test_image_folder, + model_type=DiffusersModelType.SD15, + merge_lora=True, + training_args=DEFAULT_TRAINING_ARGS, + ) + p = create_pass_from_dict(SDLoRA, config, disable_search=True) + result = p.run(mock_input_model_sd15, output_folder) + + assert result is not None + mock_torch_model.merge_and_unload.assert_called_once() + mock_torch_model.save_pretrained.assert_called_once() + + +@patch("diffusers.StableDiffusionXLPipeline.save_lora_weights") +@patch("peft.get_peft_model") +@patch("peft.LoraConfig") +@patch("diffusers.DDPMScheduler") +@patch("diffusers.UNet2DConditionModel") +@patch("diffusers.AutoencoderKL") +@patch("transformers.CLIPTextModel") +@patch("transformers.CLIPTextModelWithProjection") +@patch("transformers.AutoTokenizer") +@patch("accelerate.Accelerator") +@patch("diffusers.optimization.get_scheduler") +def test_sd_lora_train_sdxl( + mock_get_scheduler, + mock_accelerator_cls, + mock_auto_tokenizer, + mock_clip_proj, + mock_clip, + mock_vae, + mock_unet, + mock_scheduler, + mock_lora_config, + mock_get_peft_model, + mock_save_lora, + test_image_folder, + output_folder, + mock_accelerator, + mock_torch_model, + mock_input_model_sdxl, +): + mock_accelerator_cls.return_value = mock_accelerator + mock_get_scheduler.return_value = MagicMock() + setup_tokenizer_mock(mock_auto_tokenizer) + setup_sd_mocks( + mock_unet, mock_vae, mock_clip, mock_scheduler, mock_get_peft_model, mock_torch_model, SDXL_SCALING_FACTOR + ) + + # SDXL second text encoder + mock_clip_proj_model = MagicMock() + mock_clip_proj_model.return_value = (torch.randn(1, 77, 1280),) + mock_clip_proj_model.requires_grad_ = MagicMock(return_value=mock_clip_proj_model) + mock_clip_proj_model.to = MagicMock(return_value=mock_clip_proj_model) + mock_clip_proj.from_pretrained.return_value = mock_clip_proj_model + + config = get_pass_config( + test_image_folder, + model_type=DiffusersModelType.SDXL, + training_args=DEFAULT_TRAINING_ARGS, + ) + p = create_pass_from_dict(SDLoRA, config, disable_search=True) + result = p.run(mock_input_model_sdxl, output_folder) + + assert result is not None + 
mock_unet.from_pretrained.assert_called_once() + mock_clip.from_pretrained.assert_called_once() + mock_clip_proj.from_pretrained.assert_called_once() + + +@patch("diffusers.FluxPipeline.save_lora_weights") +@patch("peft.LoraConfig") +@patch("diffusers.FluxTransformer2DModel") +@patch("diffusers.AutoencoderKL") +@patch("transformers.CLIPTextModel") +@patch("transformers.T5EncoderModel") +@patch("transformers.AutoTokenizer") +@patch("accelerate.Accelerator") +@patch("diffusers.optimization.get_scheduler") +def test_sd_lora_train_flux( + mock_get_scheduler, + mock_accelerator_cls, + mock_auto_tokenizer, + mock_t5, + mock_clip, + mock_vae, + mock_transformer, + mock_lora_config, + mock_save_lora, + test_image_folder, + output_folder, + mock_accelerator, + mock_torch_model, + mock_input_model_flux, +): + mock_accelerator_cls.return_value = mock_accelerator + mock_get_scheduler.return_value = MagicMock() + + # Tokenizer + mock_tokenizer = MagicMock() + mock_tokenizer.model_max_length = 77 + mock_token_output = MagicMock() + mock_token_output.input_ids = torch.ones(1, 77, dtype=torch.long) + mock_tokenizer.return_value = mock_token_output + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + # Transformer + mock_torch_model.add_adapter = MagicMock() + mock_torch_model.enable_gradient_checkpointing = MagicMock() + mock_transformer.from_pretrained.return_value = mock_torch_model + + # VAE for Flux + mock_vae_model = MagicMock() + mock_vae_model.config = MagicMock() + mock_vae_model.config.scaling_factor = SDXL_SCALING_FACTOR + mock_vae_model.requires_grad_ = MagicMock(return_value=mock_vae_model) + mock_vae_model.to = MagicMock(return_value=mock_vae_model) + mock_latent_dist = MagicMock() + mock_latent_dist.latent_dist.sample.return_value = torch.randn(1, 16, 8, 8) + mock_vae_model.encode.return_value = mock_latent_dist + mock_vae.from_pretrained.return_value = mock_vae_model + + # CLIP + mock_clip_model = MagicMock() + mock_clip_model.return_value = MagicMock(pooler_output=torch.randn(1, 768)) + mock_clip_model.requires_grad_ = MagicMock(return_value=mock_clip_model) + mock_clip_model.to = MagicMock(return_value=mock_clip_model) + mock_clip.from_pretrained.return_value = mock_clip_model + + # T5 + mock_t5_model = MagicMock() + mock_t5_model.return_value = (torch.randn(1, 512, 4096),) + mock_t5_model.requires_grad_ = MagicMock(return_value=mock_t5_model) + mock_t5_model.to = MagicMock(return_value=mock_t5_model) + mock_t5.from_pretrained.return_value = mock_t5_model + + config = get_pass_config( + test_image_folder, + model_type=DiffusersModelType.FLUX, + training_args=DEFAULT_TRAINING_ARGS, + ) + p = create_pass_from_dict(SDLoRA, config, disable_search=True) + result = p.run(mock_input_model_flux, output_folder) + + assert result is not None + mock_transformer.from_pretrained.assert_called_once() + mock_clip.from_pretrained.assert_called_once() + mock_t5.from_pretrained.assert_called_once() + + +@patch("diffusers.StableDiffusionPipeline.save_lora_weights") +@patch("peft.get_peft_model") +@patch("peft.LoraConfig") +@patch("diffusers.DDPMScheduler") +@patch("diffusers.UNet2DConditionModel") +@patch("diffusers.AutoencoderKL") +@patch("transformers.CLIPTextModel") +@patch("transformers.AutoTokenizer") +@patch("accelerate.Accelerator") +@patch("diffusers.optimization.get_scheduler") +def test_sd_lora_dreambooth_sd15( + mock_get_scheduler, + mock_accelerator_cls, + mock_auto_tokenizer, + mock_clip, + mock_vae, + mock_unet, + mock_scheduler, + mock_lora_config, + mock_get_peft_model, + 
mock_save_lora, + test_image_folder, + output_folder, + mock_accelerator, + mock_torch_model, + mock_input_model_sd15, +): + mock_accelerator_cls.return_value = mock_accelerator + mock_get_scheduler.return_value = MagicMock() + setup_tokenizer_mock(mock_auto_tokenizer) + + # DreamBooth needs batch size 2 tensors + mock_torch_model.return_value = (torch.randn(2, 4, 8, 8),) + mock_unet.from_pretrained.return_value = mock_torch_model + mock_get_peft_model.return_value = mock_torch_model + + # VAE + mock_vae_model = MagicMock() + mock_vae_model.config = MagicMock() + mock_vae_model.config.scaling_factor = SD15_SCALING_FACTOR + mock_vae_model.requires_grad_ = MagicMock(return_value=mock_vae_model) + mock_vae_model.to = MagicMock(return_value=mock_vae_model) + mock_latent_dist = MagicMock() + mock_latent_dist.sample.return_value = torch.randn(2, 4, 8, 8) + mock_vae_model.encode.return_value = mock_latent_dist + mock_vae.from_pretrained.return_value = mock_vae_model + + # CLIP + mock_clip_model = MagicMock() + mock_clip_model.return_value = (torch.randn(2, 77, 768),) + mock_clip_model.requires_grad_ = MagicMock(return_value=mock_clip_model) + mock_clip_model.to = MagicMock(return_value=mock_clip_model) + mock_clip.from_pretrained.return_value = mock_clip_model + + # Noise scheduler + mock_noise_scheduler = MagicMock() + mock_noise_scheduler.config.num_train_timesteps = 1000 + mock_noise_scheduler.config.prediction_type = "epsilon" + mock_noise_scheduler.add_noise.return_value = torch.randn(2, 4, 8, 8) + mock_scheduler.from_pretrained.return_value = mock_noise_scheduler + + config = get_pass_config( + test_image_folder, + model_type=DiffusersModelType.SD15, + dreambooth=True, + prior_loss_weight=1.0, + training_args=DEFAULT_TRAINING_ARGS, + ) + p = create_pass_from_dict(SDLoRA, config, disable_search=True) + result = p.run(mock_input_model_sd15, output_folder) + + assert result is not None
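
Note for reviewers of test_sd_lora_dreambooth_sd15: the batch-size-2 mock tensors and prior_loss_weight=1.0 reflect the usual DreamBooth prior-preservation setup, where each batch concatenates instance and class samples and the loss splits the prediction back into the two halves. The sketch below is illustrative only, not the SDLoRA pass's actual code; names such as model_pred, target, and prior_loss_weight are assumptions chosen to match the test config.

import torch
import torch.nn.functional as F


def dreambooth_prior_preservation_loss(
    model_pred: torch.Tensor, target: torch.Tensor, prior_loss_weight: float
) -> torch.Tensor:
    """Illustrative DreamBooth loss (assumed formulation, not the pass implementation).

    The batch is assumed to be [instance_samples; class_samples] concatenated along
    dim 0, which is why the test feeds (2, 4, 8, 8) tensors through the mocks.
    """
    # Split the stacked batch back into instance and class halves.
    instance_pred, prior_pred = torch.chunk(model_pred, 2, dim=0)
    instance_target, prior_target = torch.chunk(target, 2, dim=0)

    # Subject loss plus weighted prior-preservation loss on the class images.
    instance_loss = F.mse_loss(instance_pred.float(), instance_target.float(), reduction="mean")
    prior_loss = F.mse_loss(prior_pred.float(), prior_target.float(), reduction="mean")
    return instance_loss + prior_loss_weight * prior_loss

With prior_loss_weight=1.0, as in the test's pass config, the class-image term contributes on equal footing with the subject term, which is why the test only needs to verify that training completes and the shapes line up rather than checking loss values.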