-
Notifications
You must be signed in to change notification settings - Fork 142
Description
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
import torch
import cv2
import numpy as np
from PIL import Image
# --- Configuration -----------------------------------------------------------
# Paths to the base Stable Diffusion checkpoint and the BrushNet checkpoint.
base_model_path = "/home/sobey/lzq/BrushEdit-main/model/BrushEdit/base_model/realisticVisionV60B1_v51VAE"
brushnet_path = "/home/sobey/lzq/BrushEdit-main/model/Brushnet/segmentation_mask_brushnet_ckpt"
# When True, paste the generated pixels back onto the untouched source image
# through a feathered mask (see _blend_with_source).
blended = False
# Source image, mask image and text prompt.
image_path = "/home/sobey/lzq/BrushEdit-main/examples/brushnet/src/test_image.jpg"
mask_path = "/home/sobey/lzq/BrushEdit-main/examples/brushnet/src/test_mask.jpg"
caption = "A cake on the table."
# Conditioning scale forwarded to the BrushNet pipeline call.
brushnet_conditioning_scale = 1.0
# Sampling settings (previously inlined as literals in the pipeline call).
num_inference_steps = 50
seed = 1234
# A mask pixel is treated as masked when the sum of its three colour channels
# is strictly greater than this value.
mask_threshold = 255
# Gaussian kernel size used to feather the mask edge when blending (odd sizes).
blend_blur_kernel = (21, 21)


def _read_rgb(path):
    """Read *path* with OpenCV and return the image as an RGB array.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise
            surface later as a confusing TypeError.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"could not read image: {path}")
    # OpenCV loads channels in BGR order; reverse the last axis to get RGB.
    return bgr[:, :, ::-1]


def _read_binary_mask(path, threshold=255):
    """Load the mask image at *path* as a float array of shape (H, W, 1).

    Entries are 1.0 where the per-pixel channel sum exceeds *threshold* and
    0.0 elsewhere.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    raw = cv2.imread(path)
    if raw is None:
        raise FileNotFoundError(f"could not read mask: {path}")
    return (raw.sum(-1) > threshold).astype(np.float64)[:, :, np.newaxis]


def _blend_with_source(generated, source_rgb, mask, blur_kernel=(21, 21)):
    """Paste *generated* onto *source_rgb* inside a feathered version of *mask*.

    Args:
        generated: PIL image returned by the pipeline.
        source_rgb: original RGB uint8 array of shape (H, W, 3).
        mask: float array of shape (H, W, 1) holding 0.0 / 1.0.
        blur_kernel: Gaussian kernel size used to feather the mask edge.

    Returns:
        A PIL image with the same height and width as *source_rgb*.
    """
    generated_np = np.array(generated)
    height, width = source_rgb.shape[:2]
    if generated_np.shape[:2] != (height, width):
        # NOTE(review): the pipeline may return an image rounded down to a
        # multiple of its VAE scale factor; resize so the arrays broadcast.
        generated_np = cv2.resize(
            generated_np, (width, height), interpolation=cv2.INTER_LANCZOS4
        )
    # Blur the 2-D mask and add the channel axis back explicitly, so the
    # result is (H, W, 1) regardless of how OpenCV treats a singleton channel.
    feathered = cv2.GaussianBlur(mask[:, :, 0] * 255, blur_kernel, 0)[:, :, np.newaxis] / 255
    # Union of the hard mask and its blurred halo: alpha is exactly 1 inside
    # the mask and equals the blurred value outside it, so the seam fades
    # outward only and no generated pixel inside the mask is attenuated.
    alpha = 1 - (1 - mask) * (1 - feathered)
    pasted = source_rgb * (1 - alpha) + generated_np * alpha
    # Round (rather than truncate) when converting back to 8-bit.
    pasted = np.clip(np.rint(pasted), 0, 255).astype(np.uint8)
    return Image.fromarray(pasted)


# --- Model setup -------------------------------------------------------------
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False
)
# Swap in the UniPC multistep scheduler and enable diffusers' model CPU offload.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

# --- Inputs ------------------------------------------------------------------
source_rgb = _read_rgb(image_path)
mask_np = _read_binary_mask(mask_path, mask_threshold)
if source_rgb.shape[:2] != mask_np.shape[:2]:
    raise ValueError(
        f"image {image_path} is {source_rgb.shape[1]}x{source_rgb.shape[0]} but "
        f"mask {mask_path} is {mask_np.shape[1]}x{mask_np.shape[0]}; sizes must match"
    )

# Zero out the masked pixels of the source before handing it to the pipeline.
init_image = Image.fromarray((source_rgb * (1 - mask_np)).astype(np.uint8)).convert("RGB")
# 3-channel mask image: white (255) where mask_np is 1, black elsewhere.
mask_image = Image.fromarray(mask_np.astype(np.uint8).repeat(3, -1) * 255).convert("RGB")

# --- Inference ---------------------------------------------------------------
generator = torch.Generator("cuda").manual_seed(seed)
image = pipe(
    caption,
    init_image,
    mask_image,
    num_inference_steps=num_inference_steps,
    generator=generator,
    brushnet_conditioning_scale=brushnet_conditioning_scale,
).images[0]

if blended:
    # Reuse the already-loaded source and mask instead of re-reading them.
    image = _blend_with_source(image, source_rgb, mask_np, blend_blur_kernel)

# NOTE(review): nothing visible here writes `image` out; if this is the end of
# the script, add e.g. image.save("output.png").
