From 9cd9edf4dba42756d7dd768c937c0a1b663ee4b1 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 4 Apr 2024 16:59:03 +0300 Subject: [PATCH 01/15] face_id --- ip_adapter/ip_adapter_faceid.py | 473 ++++++++++++++++++++++++++++++++ 1 file changed, 473 insertions(+) create mode 100644 ip_adapter/ip_adapter_faceid.py diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py new file mode 100644 index 0000000..d0efa21 --- /dev/null +++ b/ip_adapter/ip_adapter_faceid.py @@ -0,0 +1,473 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .utils import is_torch2_available, get_generator + +if is_torch2_available(): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + CNAttnProcessor2_0 as CNAttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor +from .resampler import Resampler + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim) + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class IPAdapter: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["blocks"]): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + self.target_blocks = target_blocks + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ImageProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + def set_ip_adapter(self): 
+ unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + selected = False + for block_name in self.target_blocks: + if block_name in name: + selected = True + break + if selected: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + else: + attn_procs[name] = AttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ).to(self.device, dtype=torch.float16) + unet.set_attn_processor(attn_procs) + if hasattr(self.pipe, "controlnet"): + if isinstance(self.pipe.controlnet, MultiControlNetModel): + for controlnet in self.pipe.controlnet.nets: + controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + else: + self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None): + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds + else: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) + + if content_prompt_embeds is not None: + clip_image_embeds = clip_image_embeds - content_prompt_embeds + + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + pil_image=None, + clip_image_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + neg_content_prompt=None, + neg_content_scale=1.0, + **kwargs, + ): + self.set_scale(scale) + + if pil_image is not None: + num_prompts = 1 
if isinstance(pil_image, Image.Image) else len(pil_image) + else: + num_prompts = clip_image_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + if neg_content_prompt is not None: + with torch.inference_mode(): + ( + prompt_embeds_, # torch.Size([1, 77, 2048]) + negative_prompt_embeds_, + pooled_prompt_embeds_, # torch.Size([1, 1280]) + negative_pooled_prompt_embeds_, + ) = self.pipe.encode_prompt( + neg_content_prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + pooled_prompt_embeds_ *= neg_content_scale + else: + pooled_prompt_embeds_ = None + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=pooled_prompt_embeds_ + ) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterXL(IPAdapter): + """SDXL""" + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + neg_content_prompt=None, + neg_content_scale=1.0, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + if neg_content_prompt is not None: + with torch.inference_mode(): + ( + prompt_embeds_, # torch.Size([1, 77, 2048]) + negative_prompt_embeds_, + pooled_prompt_embeds_, # torch.Size([1, 1280]) + negative_pooled_prompt_embeds_, + ) = self.pipe.encode_prompt( + neg_content_prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + pooled_prompt_embeds_ *= neg_content_scale + else: + pooled_prompt_embeds_ = None + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, + content_prompt_embeds=pooled_prompt_embeds_) + bs_embed, 
seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + self.generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=self.generator, + **kwargs, + ).images + + return images + + +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" + + def init_proj(self): + image_proj_model = Resampler( + dim=self.pipe.unet.config.cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter with full features""" + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + +class IPAdapterPlusXL(IPAdapter): + """SDXL""" + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = 
self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images From 9b3b87f52a91f0e99996b1eb56fb2195ed7d1c65 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 8 Apr 2024 08:48:59 +0300 Subject: [PATCH 02/15] face_id --- ip_adapter/attention_processor_faceid.py | 426 +++++++++++++++++++++++ ip_adapter/ip_adapter_faceid.py | 169 ++++++--- 2 files changed, 554 insertions(+), 41 deletions(-) create mode 100644 ip_adapter/attention_processor_faceid.py diff --git a/ip_adapter/attention_processor_faceid.py b/ip_adapter/attention_processor_faceid.py new file mode 100644 index 0000000..7f7fd6e --- /dev/null +++ b/ip_adapter/attention_processor_faceid.py @@ -0,0 +1,426 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.models.lora import LoRALinearLayer + + +class LoRAAttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. 
+ """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, + num_tokens=4): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + 
hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAAttnProcessor2_0(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAttnProcessor2_0(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. 
+ rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, + num_tokens=4): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + self.num_tokens = num_tokens + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + # query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # for text + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = 
ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index d0efa21..3152bad 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -8,22 +8,109 @@ from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from .utils import is_torch2_available, get_generator - -if is_torch2_available(): - from .attention_processor import ( - AttnProcessor2_0 as AttnProcessor, - ) - from .attention_processor import ( - CNAttnProcessor2_0 as CNAttnProcessor, +USE_DAFAULT_ATTN = False # should be True for visualization_attnmap +if is_torch2_available() and (not USE_DAFAULT_ATTN): + from .attention_processor_faceid import ( + LoRAAttnProcessor2_0 as LoRAAttnProcessor, ) - from .attention_processor import ( - IPAttnProcessor2_0 as IPAttnProcessor, + from .attention_processor_faceid import ( + LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor, ) else: - from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor -from .resampler import Resampler + from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor +from .resampler import PerceiverAttention, FeedForward + + +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = 
self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + +class ProjPlusModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=cross_attention_dim // 64, + embedding_dim=clip_embeddings_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + out = self.perceiver_resampler(x, clip_embeds) + if shortcut: + out = x + scale * out + return out class ImageProjModel(torch.nn.Module): """Projection Model""" @@ -64,12 +151,14 @@ def forward(self, image_embeds): return clip_extra_context_tokens -class IPAdapter: - def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["blocks"]): +class IPAdapterFaceID: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, target_blocks=["blocks"], torch_dtype=torch.float16): self.device = device self.image_encoder_path = image_encoder_path self.ip_ckpt = ip_ckpt + self.lora_rank = lora_rank self.num_tokens = num_tokens + self.torch_dtype = torch_dtype self.target_blocks = target_blocks self.pipe = sd_pipe.to(self.device) @@ -80,17 +169,18 @@ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, t self.device, dtype=torch.float16 ) self.clip_image_processor = CLIPImageProcessor() + # image proj model self.image_proj_model = self.init_proj() self.load_ip_adapter() def init_proj(self): - image_proj_model = ImageProjModel( + image_proj_model = MLPProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, clip_embeddings_dim=self.image_encoder.config.projection_dim, - clip_extra_context_tokens=self.num_tokens, - ).to(self.device, dtype=torch.float16) + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) return image_proj_model def set_ip_adapter(self): @@ -107,7 +197,9 @@ def set_ip_adapter(self): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None: - attn_procs[name] = AttnProcessor() + attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + ).to(self.device, dtype=self.torch_dtype) else: selected = False for block_name in self.target_blocks: @@ -115,24 +207,16 @@ def set_ip_adapter(self): selected = True break if selected: - attn_procs[name] = IPAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, - ).to(self.device, dtype=torch.float16) + ).to(self.device, dtype=self.torch_dtype) else: - attn_procs[name] = AttnProcessor( + attn_procs[name] = LoRAAttnProcessor( hidden_size=hidden_size, 
cross_attention_dim=cross_attention_dim, - ).to(self.device, dtype=torch.float16) + ).to(self.device, dtype=self.torch_dtype) unet.set_attn_processor(attn_procs) - if hasattr(self.pipe, "controlnet"): - if isinstance(self.pipe.controlnet, MultiControlNetModel): - for controlnet in self.pipe.controlnet.nets: - controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) - else: - self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) def load_ip_adapter(self): if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": @@ -150,7 +234,8 @@ def load_ip_adapter(self): ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) @torch.inference_mode() - def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None): + def get_image_embeds(self, faceid_embeds, pil_image=None, clip_image_embeds=None,content_prompt_embeds=None): + if pil_image is not None: if isinstance(pil_image, Image.Image): pil_image = [pil_image] @@ -162,19 +247,23 @@ def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_promp if content_prompt_embeds is not None: clip_image_embeds = clip_image_embeds - content_prompt_embeds - image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) - return image_prompt_embeds, uncond_image_prompt_embeds + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds) + style_image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds)) + return image_prompt_embeds, style_image_prompt_embeds, uncond_image_prompt_embeds def set_scale(self, scale): for attn_processor in self.pipe.unet.attn_processors.values(): - if isinstance(attn_processor, IPAttnProcessor): + if isinstance(attn_processor, LoRAIPAttnProcessor): attn_processor.scale = scale def generate( self, pil_image=None, clip_image_embeds=None, + faceid_embeds=None, prompt=None, negative_prompt=None, scale=1.0, @@ -187,11 +276,10 @@ def generate( **kwargs, ): self.set_scale(scale) - if pil_image is not None: num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) else: - num_prompts = clip_image_embeds.size(0) + num_prompts = faceid_embeds.size(0) if prompt is None: prompt = "best quality, high quality" @@ -220,9 +308,8 @@ def generate( else: pooled_prompt_embeds_ = None - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( - pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=pooled_prompt_embeds_ - ) + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=pooled_prompt_embeds_) + bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) From fb6df2cbede31b0ea74b1c21e068413f692aeb39 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 8 Apr 2024 13:04:42 +0300 Subject: [PATCH 03/15] face_id implementation --- ip_adapter/ip_adapter_faceid.py | 174 ++------------------------------ 1 file changed, 10 insertions(+), 164 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 3152bad..53bb162 100644 --- a/ip_adapter/ip_adapter_faceid.py 
+++ b/ip_adapter/ip_adapter_faceid.py @@ -7,6 +7,8 @@ from PIL import Image from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from .utils import is_torch2_available, get_generator + USE_DAFAULT_ATTN = False # should be True for visualization_attnmap if is_torch2_available() and (not USE_DAFAULT_ATTN): @@ -152,9 +154,8 @@ def forward(self, image_embeds): class IPAdapterFaceID: - def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, target_blocks=["blocks"], torch_dtype=torch.float16): + def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, target_blocks=["blocks"], torch_dtype=torch.float16): self.device = device - self.image_encoder_path = image_encoder_path self.ip_ckpt = ip_ckpt self.lora_rank = lora_rank self.num_tokens = num_tokens @@ -164,11 +165,6 @@ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, self.pipe = sd_pipe.to(self.device) self.set_ip_adapter() - # load image encoder - self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( - self.device, dtype=torch.float16 - ) - self.clip_image_processor = CLIPImageProcessor() # image proj model self.image_proj_model = self.init_proj() @@ -178,7 +174,7 @@ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, def init_proj(self): image_proj_model = MLPProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, - clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_embeddings_dim=512, num_tokens=self.num_tokens, ).to(self.device, dtype=self.torch_dtype) return image_proj_model @@ -234,25 +230,14 @@ def load_ip_adapter(self): ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) @torch.inference_mode() - def get_image_embeds(self, faceid_embeds, pil_image=None, clip_image_embeds=None,content_prompt_embeds=None): - - if pil_image is not None: - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds - else: - clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) - - if content_prompt_embeds is not None: - clip_image_embeds = clip_image_embeds - content_prompt_embeds - + def get_image_embeds(self, faceid_embeds, content_prompt_embeds=None): faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + if content_prompt_embeds is not None: + faceid_embeds = faceid_embeds - content_prompt_embeds image_prompt_embeds = self.image_proj_model(faceid_embeds) - style_image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds)) - return image_prompt_embeds, style_image_prompt_embeds, uncond_image_prompt_embeds + return image_prompt_embeds, uncond_image_prompt_embeds def set_scale(self, scale): for attn_processor in self.pipe.unet.attn_processors.values(): @@ -261,8 +246,6 @@ def set_scale(self, scale): def generate( self, - pil_image=None, - clip_image_embeds=None, faceid_embeds=None, prompt=None, negative_prompt=None, @@ -276,10 +259,7 @@ def generate( **kwargs, ): self.set_scale(scale) - if pil_image is not None: - num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) - else: - num_prompts = faceid_embeds.size(0) + num_prompts = faceid_embeds.size(0) if prompt 
is None: prompt = "best quality, high quality" @@ -308,7 +288,7 @@ def generate( else: pooled_prompt_embeds_ = None - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=pooled_prompt_embeds_) + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, pooled_prompt_embeds_) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) @@ -424,137 +404,3 @@ def generate( ).images return images - - -class IPAdapterPlus(IPAdapter): - """IP-Adapter with fine-grained features""" - - def init_proj(self): - image_proj_model = Resampler( - dim=self.pipe.unet.config.cross_attention_dim, - depth=4, - dim_head=64, - heads=12, - num_queries=self.num_tokens, - embedding_dim=self.image_encoder.config.hidden_size, - output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4, - ).to(self.device, dtype=torch.float16) - return image_proj_model - - @torch.inference_mode() - def get_image_embeds(self, pil_image=None, clip_image_embeds=None): - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self.image_encoder( - torch.zeros_like(clip_image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) - return image_prompt_embeds, uncond_image_prompt_embeds - - -class IPAdapterFull(IPAdapterPlus): - """IP-Adapter with full features""" - - def init_proj(self): - image_proj_model = MLPProjModel( - cross_attention_dim=self.pipe.unet.config.cross_attention_dim, - clip_embeddings_dim=self.image_encoder.config.hidden_size, - ).to(self.device, dtype=torch.float16) - return image_proj_model - - -class IPAdapterPlusXL(IPAdapter): - """SDXL""" - - def init_proj(self): - image_proj_model = Resampler( - dim=1280, - depth=4, - dim_head=64, - heads=20, - num_queries=self.num_tokens, - embedding_dim=self.image_encoder.config.hidden_size, - output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4, - ).to(self.device, dtype=torch.float16) - return image_proj_model - - @torch.inference_mode() - def get_image_embeds(self, pil_image): - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self.image_encoder( - torch.zeros_like(clip_image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) - return image_prompt_embeds, uncond_image_prompt_embeds - - def generate( - self, - pil_image, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - num_inference_steps=30, - **kwargs, - ): - self.set_scale(scale) - - num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) - - if prompt is None: - prompt = "best quality, high 
quality" - if negative_prompt is None: - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) - bs_embed, seq_len, _ = image_prompt_embeds.shape - image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) - image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - - with torch.inference_mode(): - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.pipe.encode_prompt( - prompt, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) - negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - - generator = get_generator(seed, self.device) - - images = self.pipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - num_inference_steps=num_inference_steps, - generator=generator, - **kwargs, - ).images - - return images From 730a34d4a12c180b114b1e6e9fdc4165eacc73cc Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 8 Apr 2024 13:07:07 +0300 Subject: [PATCH 04/15] start faceidplus and faceidxl implementations --- ip_adapter/ip_adapter_faceid.py | 166 +++++++++++++++++++++++++++++--- 1 file changed, 152 insertions(+), 14 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 53bb162..785d9b1 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -333,6 +333,7 @@ def generate( num_samples=4, seed=None, num_inference_steps=30, + neg_content_emb=None, neg_content_prompt=None, neg_content_scale=1.0, **kwargs, @@ -351,20 +352,23 @@ def generate( if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts - if neg_content_prompt is not None: - with torch.inference_mode(): - ( - prompt_embeds_, # torch.Size([1, 77, 2048]) - negative_prompt_embeds_, - pooled_prompt_embeds_, # torch.Size([1, 1280]) - negative_pooled_prompt_embeds_, - ) = self.pipe.encode_prompt( - neg_content_prompt, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - pooled_prompt_embeds_ *= neg_content_scale + if neg_content_emb is None: + if neg_content_prompt is not None: + with torch.inference_mode(): + ( + prompt_embeds_, # torch.Size([1, 77, 2048]) + negative_prompt_embeds_, + pooled_prompt_embeds_, # torch.Size([1, 1280]) + negative_pooled_prompt_embeds_, + ) = self.pipe.encode_prompt( + neg_content_prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + pooled_prompt_embeds_ *= neg_content_scale + else: + pooled_prompt_embeds_ = neg_content_emb else: pooled_prompt_embeds_ = None @@ -404,3 +408,137 @@ def generate( ).images return images + + +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" + + def init_proj(self): + 
image_proj_model = Resampler( + dim=self.pipe.unet.config.cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter with full features""" + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + +class IPAdapterPlusXL(IPAdapter): + """SDXL""" + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = 
uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images From 91340c0eb63c70f8f3cbca7c144c6cc106e58326 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 8 Apr 2024 13:15:35 +0300 Subject: [PATCH 05/15] faceidxl --- ip_adapter/ip_adapter_faceid.py | 66 +++++++++------------------------ 1 file changed, 17 insertions(+), 49 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 785d9b1..f525fc4 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -321,26 +321,26 @@ def generate( return images -class IPAdapterXL(IPAdapter): +class IPAdapterFaceIDXL(IPAdapterFaceID): """SDXL""" def generate( - self, - pil_image, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - num_inference_steps=30, - neg_content_emb=None, - neg_content_prompt=None, - neg_content_scale=1.0, - **kwargs, + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + neg_content_emb=None, + neg_content_prompt=None, + neg_content_scale=1.0, + **kwargs, ): self.set_scale(scale) - num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + num_prompts = faceid_embeds.size(0) if prompt is None: prompt = "best quality, high quality" @@ -371,9 +371,8 @@ def generate( pooled_prompt_embeds_ = neg_content_emb else: pooled_prompt_embeds_ = None + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, content_prompt_embeds=pooled_prompt_embeds_) - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, - content_prompt_embeds=pooled_prompt_embeds_) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) @@ -395,7 +394,7 @@ def generate( prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - self.generator = get_generator(seed, self.device) + generator = get_generator(seed, self.device) images = self.pipe( prompt_embeds=prompt_embeds, @@ -403,44 +402,13 @@ def generate( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_inference_steps=num_inference_steps, - generator=self.generator, + generator=generator, **kwargs, ).images return images -class IPAdapterPlus(IPAdapter): - """IP-Adapter with fine-grained features""" - - def init_proj(self): - image_proj_model = Resampler( - dim=self.pipe.unet.config.cross_attention_dim, - depth=4, - dim_head=64, 
- heads=12, - num_queries=self.num_tokens, - embedding_dim=self.image_encoder.config.hidden_size, - output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4, - ).to(self.device, dtype=torch.float16) - return image_proj_model - - @torch.inference_mode() - def get_image_embeds(self, pil_image=None, clip_image_embeds=None): - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self.image_encoder( - torch.zeros_like(clip_image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) - return image_prompt_embeds, uncond_image_prompt_embeds - - class IPAdapterFull(IPAdapterPlus): """IP-Adapter with full features""" From b0516779f5c5d9148831627a79a0b63b7e421143 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 8 Apr 2024 14:24:46 +0300 Subject: [PATCH 06/15] fixes --- ip_adapter/ip_adapter_faceid.py | 154 ++++++++++++++++++++++++-------- 1 file changed, 118 insertions(+), 36 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index f525fc4..9fb067c 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -9,8 +9,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from .utils import is_torch2_available, get_generator - -USE_DAFAULT_ATTN = False # should be True for visualization_attnmap +USE_DAFAULT_ATTN = False # should be True for visualization_attnmap if is_torch2_available() and (not USE_DAFAULT_ATTN): from .attention_processor_faceid import ( LoRAAttnProcessor2_0 as LoRAAttnProcessor, @@ -114,6 +113,7 @@ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): out = x + scale * out return out + class ImageProjModel(torch.nn.Module): """Projection Model""" @@ -136,25 +136,29 @@ def forward(self, image_embeds): class MLPProjModel(torch.nn.Module): - """SD model with image prompt""" - - def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): super().__init__() + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + self.proj = torch.nn.Sequential( - torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2), torch.nn.GELU(), - torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), - torch.nn.LayerNorm(cross_attention_dim) + torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens), ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) - def forward(self, image_embeds): - clip_extra_context_tokens = self.proj(image_embeds) - return clip_extra_context_tokens + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x class IPAdapterFaceID: - def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, target_blocks=["blocks"], torch_dtype=torch.float16): + def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, target_blocks=["blocks"], + torch_dtype=torch.float16): self.device = device self.ip_ckpt = ip_ckpt self.lora_rank = 
lora_rank @@ -165,7 +169,6 @@ def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, targe self.pipe = sd_pipe.to(self.device) self.set_ip_adapter() - # image proj model self.image_proj_model = self.init_proj() @@ -174,7 +177,7 @@ def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, targe def init_proj(self): image_proj_model = MLPProjModel( cross_attention_dim=self.pipe.unet.config.cross_attention_dim, - clip_embeddings_dim=512, + id_embeddings_dim=512, num_tokens=self.num_tokens, ).to(self.device, dtype=self.torch_dtype) return image_proj_model @@ -204,7 +207,8 @@ def set_ip_adapter(self): break if selected: attn_procs[name] = LoRAIPAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, + rank=self.lora_rank, num_tokens=self.num_tokens, ).to(self.device, dtype=self.torch_dtype) else: @@ -325,18 +329,18 @@ class IPAdapterFaceIDXL(IPAdapterFaceID): """SDXL""" def generate( - self, - faceid_embeds=None, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - num_inference_steps=30, - neg_content_emb=None, - neg_content_prompt=None, - neg_content_scale=1.0, - **kwargs, + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + neg_content_emb=None, + neg_content_prompt=None, + neg_content_scale=1.0, + **kwargs, ): self.set_scale(scale) @@ -371,7 +375,8 @@ def generate( pooled_prompt_embeds_ = neg_content_emb else: pooled_prompt_embeds_ = None - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, content_prompt_embeds=pooled_prompt_embeds_) + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, + content_prompt_embeds=pooled_prompt_embeds_) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) @@ -409,16 +414,93 @@ def generate( return images -class IPAdapterFull(IPAdapterPlus): - """IP-Adapter with full features""" +class IPAdapterFaceIDXL(IPAdapterFaceID): + """SDXL""" - def init_proj(self): - image_proj_model = MLPProjModel( - cross_attention_dim=self.pipe.unet.config.cross_attention_dim, - clip_embeddings_dim=self.image_encoder.config.hidden_size, - ).to(self.device, dtype=torch.float16) - return image_proj_model + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + neg_content_emb=None, + neg_content_prompt=None, + neg_content_scale=1.0, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + if neg_content_emb is None: + if neg_content_prompt is not None: + with torch.inference_mode(): + ( + prompt_embeds_, # torch.Size([1, 77, 2048]) + negative_prompt_embeds_, + pooled_prompt_embeds_, # torch.Size([1, 1280]) + negative_pooled_prompt_embeds_, + ) = self.pipe.encode_prompt( + neg_content_prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + 
pooled_prompt_embeds_ *= neg_content_scale + else: + pooled_prompt_embeds_ = neg_content_emb + else: + pooled_prompt_embeds_ = None + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, + content_prompt_embeds=pooled_prompt_embeds_) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images class IPAdapterPlusXL(IPAdapter): """SDXL""" From a35b6f5dfdf790ed66afbfe4f1e02eaa266490a8 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 8 Apr 2024 16:34:07 +0300 Subject: [PATCH 07/15] adding plus --- ip_adapter/ip_adapter_faceid.py | 71 +++------------------------------ 1 file changed, 5 insertions(+), 66 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 9fb067c..da4451c 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -502,15 +502,15 @@ def generate( return images -class IPAdapterPlusXL(IPAdapter): - """SDXL""" +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" def init_proj(self): image_proj_model = Resampler( - dim=1280, + dim=self.pipe.unet.config.cross_attention_dim, depth=4, dim_head=64, - heads=20, + heads=12, num_queries=self.num_tokens, embedding_dim=self.image_encoder.config.hidden_size, output_dim=self.pipe.unet.config.cross_attention_dim, @@ -519,7 +519,7 @@ def init_proj(self): return image_proj_model @torch.inference_mode() - def get_image_embeds(self, pil_image): + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values @@ -531,64 +531,3 @@ def get_image_embeds(self, pil_image): ).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds - - def generate( - self, - pil_image, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - num_inference_steps=30, - **kwargs, - ): - self.set_scale(scale) - - num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) - - if prompt is None: - prompt = "best quality, high quality" - if negative_prompt is None: - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - 
if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) - bs_embed, seq_len, _ = image_prompt_embeds.shape - image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) - image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - - with torch.inference_mode(): - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.pipe.encode_prompt( - prompt, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) - negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - - generator = get_generator(seed, self.device) - - images = self.pipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - num_inference_steps=num_inference_steps, - generator=generator, - **kwargs, - ).images - - return images From a2c052b9c4a47b32be21859b94e67fd7adf36681 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 8 Apr 2024 16:40:42 +0300 Subject: [PATCH 08/15] adding plus --- ip_adapter/ip_adapter_faceid.py | 166 ++++++++++++++++++++++++++++---- 1 file changed, 147 insertions(+), 19 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index da4451c..7922595 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -502,32 +502,160 @@ def generate( return images -class IPAdapterPlus(IPAdapter): - """IP-Adapter with fine-grained features""" + +class IPAdapterFaceIDPlus: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, + torch_dtype=torch.float16): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.lora_rank = lora_rank + self.num_tokens = num_tokens + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=self.torch_dtype + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() def init_proj(self): - image_proj_model = Resampler( - dim=self.pipe.unet.config.cross_attention_dim, - depth=4, - dim_head=64, - heads=12, - num_queries=self.num_tokens, - embedding_dim=self.image_encoder.config.hidden_size, - output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4, - ).to(self.device, dtype=torch.float16) + image_proj_model = ProjPlusModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) return image_proj_model + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 
+ if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + ).to(self.device, dtype=self.torch_dtype) + else: + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) + @torch.inference_mode() - def get_image_embeds(self, pil_image=None, clip_image_embeds=None): - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) + def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut): + if isinstance(face_image, Image.Image): + pil_image = [face_image] + clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=self.torch_dtype) clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self.image_encoder( torch.zeros_like(clip_image), output_hidden_states=True ).hidden_states[-2] - uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, + shortcut=shortcut, scale=s_scale) return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, LoRAIPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=False, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + 
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, + shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + From fb922b46b54b62cd78a6558c6e1b2aa5ffd33c88 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 8 Apr 2024 17:31:14 +0300 Subject: [PATCH 09/15] adjust attention processor for faceid --- ip_adapter/attention_processor_faceid.py | 43 +++++++++++++++++++++-- ip_adapter/ip_adapter_faceid.py | 44 +++++++++++++++++++----- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/ip_adapter/attention_processor_faceid.py b/ip_adapter/attention_processor_faceid.py index 7f7fd6e..1ace1d6 100644 --- a/ip_adapter/attention_processor_faceid.py +++ b/ip_adapter/attention_processor_faceid.py @@ -105,7 +105,7 @@ class LoRAIPAttnProcessor(nn.Module): """ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, - num_tokens=4): + num_tokens=4, skip=False): super().__init__() self.rank = rank @@ -120,6 +120,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha= self.cross_attention_dim = cross_attention_dim self.scale = scale self.num_tokens = num_tokens + self.skip = skip self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) @@ -176,6 +177,21 @@ def __call__( hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) + if not self.skip: + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) @@ -315,12 +331,13 @@ class LoRAIPAttnProcessor2_0(nn.Module): """ def 
__init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, - num_tokens=4): + num_tokens=4, skip=False): super().__init__() self.rank = rank self.lora_scale = lora_scale self.num_tokens = num_tokens + self.skip = skip self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) @@ -392,6 +409,28 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) + if not self.skip: + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + #print(self.attn_map.shape) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + # for ip ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 7922595..2e3f4d1 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -197,7 +197,9 @@ def set_ip_adapter(self): hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None: attn_procs[name] = LoRAAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=self.lora_rank, ).to(self.device, dtype=self.torch_dtype) else: selected = False @@ -207,14 +209,19 @@ def set_ip_adapter(self): break if selected: attn_procs[name] = LoRAIPAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, ).to(self.device, dtype=self.torch_dtype) else: - attn_procs[name] = LoRAAttnProcessor( + attn_procs[name] = LoRAIPAttnProcessor( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, + scale=1.0, + rank=self.lora_rank, + skip=True ).to(self.device, dtype=self.torch_dtype) unet.set_attn_processor(attn_procs) @@ -550,13 +557,34 @@ def set_ip_adapter(self): hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None: attn_procs[name] = LoRAAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=self.lora_rank, ).to(self.device, dtype=self.torch_dtype) else: - attn_procs[name] = LoRAIPAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, - num_tokens=self.num_tokens, - ).to(self.device, dtype=self.torch_dtype) + selected = False + for block_name in self.target_blocks: + if block_name in name: + 
selected = True + break + if selected: + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + rank=self.lora_rank, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + else: + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + rank=self.lora_rank, + num_tokens=self.num_tokens, + skip=True + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) def load_ip_adapter(self): From 0ae21440260655e10c4a3c6c8c567d94a5a36f60 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 9 Apr 2024 11:51:15 +0300 Subject: [PATCH 10/15] fix attention processor for faceid --- ip_adapter/attention_processor_faceid.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/ip_adapter/attention_processor_faceid.py b/ip_adapter/attention_processor_faceid.py index 1ace1d6..c514b45 100644 --- a/ip_adapter/attention_processor_faceid.py +++ b/ip_adapter/attention_processor_faceid.py @@ -192,19 +192,6 @@ def __call__( hidden_states = hidden_states + self.scale * ip_hidden_states - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - self.attn_map = ip_attention_probs - ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - - hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) From 0eeaeaa86fec8cf5676e512ae073cda7e7afb258 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 9 Apr 2024 12:41:34 +0300 Subject: [PATCH 11/15] fixes --- ip_adapter/ip_adapter_faceid.py | 46 +++------------------------------ 1 file changed, 4 insertions(+), 42 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 2e3f4d1..251bfeb 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -7,6 +7,8 @@ from PIL import Image from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor from .utils import is_torch2_available, get_generator USE_DAFAULT_ATTN = False # should be True for visualization_attnmap @@ -114,48 +116,6 @@ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): return out -class ImageProjModel(torch.nn.Module): - """Projection Model""" - - def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): - super().__init__() - - self.generator = None - self.cross_attention_dim = cross_attention_dim - self.clip_extra_context_tokens = clip_extra_context_tokens - self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) - self.norm = torch.nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds): - embeds = image_embeds - clip_extra_context_tokens = self.proj(embeds).reshape( - -1, self.clip_extra_context_tokens, self.cross_attention_dim - ) - clip_extra_context_tokens = self.norm(clip_extra_context_tokens) - return clip_extra_context_tokens - - -class MLPProjModel(torch.nn.Module): - def __init__(self, cross_attention_dim=768, 
id_embeddings_dim=512, num_tokens=4): - super().__init__() - - self.cross_attention_dim = cross_attention_dim - self.num_tokens = num_tokens - - self.proj = torch.nn.Sequential( - torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2), - torch.nn.GELU(), - torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens), - ) - self.norm = torch.nn.LayerNorm(cross_attention_dim) - - def forward(self, id_embeds): - x = self.proj(id_embeds) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - x = self.norm(x) - return x - - class IPAdapterFaceID: def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, target_blocks=["blocks"], torch_dtype=torch.float16): @@ -270,6 +230,7 @@ def generate( **kwargs, ): self.set_scale(scale) + num_prompts = faceid_embeds.size(0) if prompt is None: @@ -638,6 +599,7 @@ def generate( s_scale=1.0, shortcut=False, **kwargs, + ): self.set_scale(scale) From a6961936cef34b8c0637a84b20c5775da54dcd14 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 9 Apr 2024 13:43:05 +0300 Subject: [PATCH 12/15] fixes --- ip_adapter/attention_processor_faceid.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/ip_adapter/attention_processor_faceid.py b/ip_adapter/attention_processor_faceid.py index c514b45..9b576a7 100644 --- a/ip_adapter/attention_processor_faceid.py +++ b/ip_adapter/attention_processor_faceid.py @@ -418,24 +418,6 @@ def __call__( hidden_states = hidden_states + self.scale * ip_hidden_states - # for ip - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + self.scale * ip_hidden_states - # linear proj hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) # dropout From 281519224448922bf23d96fa0b568ff2dbf3e7e9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 9 Apr 2024 13:47:21 +0300 Subject: [PATCH 13/15] fixes --- ip_adapter/ip_adapter_faceid.py | 89 --------------------------------- 1 file changed, 89 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 251bfeb..f040e17 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -382,95 +382,6 @@ def generate( return images -class IPAdapterFaceIDXL(IPAdapterFaceID): - """SDXL""" - - def generate( - self, - faceid_embeds=None, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - num_inference_steps=30, - neg_content_emb=None, - neg_content_prompt=None, - neg_content_scale=1.0, - **kwargs, - ): - self.set_scale(scale) - - num_prompts = faceid_embeds.size(0) - - if prompt is None: - prompt = "best quality, high quality" - if negative_prompt is None: - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = 
[negative_prompt] * num_prompts - - if neg_content_emb is None: - if neg_content_prompt is not None: - with torch.inference_mode(): - ( - prompt_embeds_, # torch.Size([1, 77, 2048]) - negative_prompt_embeds_, - pooled_prompt_embeds_, # torch.Size([1, 1280]) - negative_pooled_prompt_embeds_, - ) = self.pipe.encode_prompt( - neg_content_prompt, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - pooled_prompt_embeds_ *= neg_content_scale - else: - pooled_prompt_embeds_ = neg_content_emb - else: - pooled_prompt_embeds_ = None - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, - content_prompt_embeds=pooled_prompt_embeds_) - - bs_embed, seq_len, _ = image_prompt_embeds.shape - image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) - image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - - with torch.inference_mode(): - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.pipe.encode_prompt( - prompt, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) - negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - - generator = get_generator(seed, self.device) - - images = self.pipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - num_inference_steps=num_inference_steps, - generator=generator, - **kwargs, - ).images - - return images - - class IPAdapterFaceIDPlus: def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16): From 07f5f71ca8e147c9496c04d11f200b111caedb72 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 9 Apr 2024 13:54:53 +0300 Subject: [PATCH 14/15] faceid plus --- ip_adapter/ip_adapter_faceid.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index f040e17..18b77e5 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -510,7 +510,6 @@ def generate( s_scale=1.0, shortcut=False, **kwargs, - ): self.set_scale(scale) @@ -559,4 +558,3 @@ def generate( return images - From 08a087a360df1beeabc5da6b2b0f1f45a1e54997 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 9 Apr 2024 14:12:06 +0300 Subject: [PATCH 15/15] bugfix + faceid plusxl --- ip_adapter/ip_adapter_faceid.py | 138 +++++++++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py index 18b77e5..46c465c 100644 --- a/ip_adapter/ip_adapter_faceid.py +++ b/ip_adapter/ip_adapter_faceid.py @@ -384,13 +384,14 @@ def generate( class IPAdapterFaceIDPlus: def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, - torch_dtype=torch.float16): + torch_dtype=torch.float16, target_blocks=["blocks"]): self.device = device self.image_encoder_path = image_encoder_path self.ip_ckpt = ip_ckpt self.lora_rank = lora_rank self.num_tokens = num_tokens 
self.torch_dtype = torch_dtype + self.target_blocks = target_blocks self.pipe = sd_pipe.to(self.device) self.set_ip_adapter() @@ -558,3 +559,138 @@ def generate( return images +class IPAdapterFaceIDXL(IPAdapterFaceID): + """SDXL""" + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): + """SDXL""" + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=True, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + 
num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + **kwargs, + ).images + + return images
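
For reference, a minimal usage sketch of the FaceID classes added by this patch series. The insightface model name, the checkpoint filename, and the base SDXL repo id below are illustrative assumptions, not values taken from these patches; substitute whatever you actually use.

# Minimal usage sketch (assumptions noted above): compute an identity embedding
# with insightface, wrap an SDXL pipeline with IPAdapterFaceIDXL, and generate.
import cv2
import torch
from diffusers import StableDiffusionXLPipeline
from insightface.app import FaceAnalysis

from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL

# 1. Compute a face ID embedding with insightface (assumed dependency).
app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
face = cv2.imread("person.jpg")  # hypothetical reference image
faceid_embeds = torch.from_numpy(app.get(face)[0].normed_embedding).unsqueeze(0)

# 2. Wrap an SDXL pipeline with the FaceID adapter. The plain FaceID variant
#    needs no CLIP image encoder; only the Plus variants take image_encoder_path.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed base model
    torch_dtype=torch.float16,
)
ip_model = IPAdapterFaceIDXL(pipe, ip_ckpt="ip-adapter-faceid_sdxl.bin", device="cuda")

# 3. Generate images conditioned on the identity embedding.
images = ip_model.generate(
    faceid_embeds=faceid_embeds,
    prompt="closeup portrait photo, natural light, 85mm",
    negative_prompt="monochrome, lowres, bad anatomy",
    num_samples=2,
    num_inference_steps=30,
    seed=42,
)
images[0].save("faceid_xl_sample.png")

The Plus variants (IPAdapterFaceIDPlus and IPAdapterFaceIDPlusXL) follow the same flow but additionally take image_encoder_path at construction time and the aligned face crop as face_image, plus the s_scale and shortcut arguments that control how the CLIP features of the crop are fused with the ID embedding in ProjPlusModel.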