Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
Binary file added inference/.DS_Store
Binary file not shown.
Binary file added inference/stable-diffusion/.DS_Store
Binary file not shown.
1,031 changes: 1,031 additions & 0 deletions inference/stable-diffusion/StableDiffusion2_1-inpaint.ipynb

Large diffs are not rendered by default.

193 changes: 193 additions & 0 deletions inference/stable-diffusion/src-inpaint/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

import os
os.environ["NEURON_FUSE_SOFTMAX"] = "1"
import time
import copy
import torch
import shutil
import argparse
import numpy as np
import torch_neuronx
import torch.nn as nn
from wrapper import NeuronTextEncoder, UNetWrap, NeuronUNet, get_attention_scores
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import Attention

# Latent-space spatial dimensions: 512x512 pixel images are encoded by the
# VAE with an 8x downsampling factor, giving 64x64 latents.
height = 512 // 8
width = 512 // 8

def compile_text_encoder(text_encoder, args):
    """Trace the CLIP text encoder with torch_neuronx and save it as TorchScript.

    The compiled module is written to <args.model_path>/text_encoder/model.pt.
    """
    print("Compiling text encoder...")
    base_dir = 'text_encoder'
    for root in (args.checkpoints_path, args.model_path):
        os.makedirs(os.path.join(root, base_dir), exist_ok=True)
    start = time.time()

    # Wrap the encoder so its custom return type becomes a plain tensor list.
    wrapped = NeuronTextEncoder(text_encoder)

    # These token ids index a torch.nn.Embedding lookup table, so the example
    # input must hold valid vocabulary ids (random values could fall out of
    # range): BOS, three prompt tokens, EOS, then padding.
    example_ids = torch.tensor([[49406, 18376, 525, 7496, 49407] + [0] * 72])
    traced = torch_neuronx.trace(
        wrapped.neuron_text_encoder, example_ids,
        #compiler_workdir=os.path.join(args.checkpoints_path, base_dir),
    )

    # Persist the compiled encoder.
    torch.jit.save(traced, os.path.join(args.model_path, base_dir, 'model.pt'))

    # Drop references to free memory before the next compilation step.
    del wrapped
    del traced
    print(f"Done. Elapsed time: {(time.time() - start) * 1000}ms")

def compile_vae(decoder, args, dtype):
    """Trace the VAE decoder with torch_neuronx and save it as TorchScript.

    The compiled module is written to <args.model_path>/vae_decoder/model.pt.
    """
    print("Compiling VAE...")
    base_dir = 'vae_decoder'
    for root in (args.checkpoints_path, args.model_path):
        os.makedirs(os.path.join(root, base_dir), exist_ok=True)
    start = time.time()

    # Example input: one 4-channel latent at latent resolution.
    example_latent = torch.randn([1, 4, height, width]).type(dtype)
    traced = torch_neuronx.trace(
        decoder,
        example_latent,
        #compiler_workdir=os.path.join(args.checkpoints_path, base_dir),
        compiler_args=["--verbose", "info"]
    )

    # Persist the compiled decoder.
    torch.jit.save(traced, os.path.join(args.model_path, base_dir, 'model.pt'))

    # Drop references to free memory before the next compilation step.
    del decoder
    del traced
    print(f"Done. Elapsed time: {(time.time() - start) * 1000}ms")

def compile_unet(unet, args, dtype):
    """Trace the inpainting U-Net with torch_neuronx and save it as TorchScript.

    The compiled module is written to <args.model_path>/unet/model.pt.
    """
    print("Compiling U-Net...")
    base_dir = 'unet'
    for root in (args.checkpoints_path, args.model_path):
        os.makedirs(os.path.join(root, base_dir), exist_ok=True)
    start = time.time()

    # Example inputs for batch size 1: a 9-channel latent sample (presumably
    # latents + masked-image latents + mask for the inpainting UNet -- verify
    # against the pipeline), a scalar timestep, and a 77x1024 CLIP embedding.
    sample = torch.randn([1, 9, height, width]).type(dtype)
    timestep = torch.tensor(999).type(dtype).expand((1,))
    text_embedding = torch.randn([1, 77, 1024]).type(dtype)

    traced = torch_neuronx.trace(
        unet,
        (sample, timestep, text_embedding),
        #compiler_workdir=os.path.join(args.checkpoints_path, base_dir),
        compiler_args=["--model-type=unet-inference", "--verbose=info"]
    )

    # Persist the compiled UNet.
    torch.jit.save(traced, os.path.join(args.model_path, base_dir, 'model.pt'))

    # Drop references to free memory before the next compilation step.
    del unet
    del traced
    print(f"Done. Elapsed time: {(time.time() - start) * 1000}ms")

def compile_vae_post_quant_conv(post_quant_conv, args, dtype):
    """Trace the VAE post_quant_conv layer and save it as TorchScript.

    The compiled module is written to
    <args.model_path>/vae_post_quant_conv/model.pt.
    """
    print("Compiling Post Quant Conv...")
    base_dir = 'vae_post_quant_conv'
    for root in (args.checkpoints_path, args.model_path):
        os.makedirs(os.path.join(root, base_dir), exist_ok=True)
    start = time.time()

    # Example input: one 4-channel latent at latent resolution.
    example_latent = torch.randn([1, 4, height, width]).type(dtype)
    traced = torch_neuronx.trace(
        post_quant_conv,
        example_latent,
        #compiler_workdir=os.path.join(args.checkpoints_path, base_dir),
        compiler_args=["--verbose", "info"]
    )

    # Persist the compiled layer.
    torch.jit.save(traced, os.path.join(args.model_path, base_dir, 'model.pt'))

    # Drop references to free memory before the next compilation step.
    del post_quant_conv
    del traced
    print(f"Done. Elapsed time: {(time.time() - start) * 1000}ms")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compile the Stable Diffusion 2 inpainting model parts for AWS Neuron')
    parser.add_argument('--model-path', type=str, help="Path where we'll save the model",
                        default=os.environ["SM_MODEL_DIR"])
    parser.add_argument('--checkpoints-path', type=str,
                        help="Path where we'll save the best model and cache",
                        default='/opt/ml/checkpoints')
    parser.add_argument('--dtype', type=str,
                        help="Datatype of the weights (fp32, bf16 or fp16)", default='fp32')

    args = parser.parse_args()

    # make sure the checkpoint path exists
    os.makedirs(args.checkpoints_path, exist_ok=True)

    # Model ID for the SD inpainting pipeline
    model_id = "stabilityai/stable-diffusion-2-inpainting"

    # BUG FIX: --dtype used to be parsed but silently ignored (dtype was
    # hard-coded to torch.float32). Honor the flag now; any unrecognized
    # value still falls back to fp32, preserving the previous default.
    dtype = {'bf16': torch.bfloat16, 'fp16': torch.float16}.get(args.dtype, torch.float32)

    # --- Compile CLIP text encoder and save ---

    # Only keep the model being compiled in RAM to minimize memory pressure
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=dtype)
    text_encoder = copy.deepcopy(pipe.text_encoder)
    del pipe
    compile_text_encoder(text_encoder, args)

    # --- Compile VAE decoder and save ---

    # Only keep the model being compiled in RAM to minimize memory pressure
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=dtype)
    decoder = copy.deepcopy(pipe.vae.decoder)
    del pipe
    compile_vae(decoder, args, dtype)

    # --- Compile UNet and save ---

    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=dtype)

    # Replace original cross-attention module with custom cross-attention module for better performance
    Attention.get_attention_scores = get_attention_scores

    # Apply double wrapper to deal with custom return type
    pipe.unet = NeuronUNet(UNetWrap(pipe.unet))

    # Only keep the model being compiled in RAM to minimize memory pressure
    unet = copy.deepcopy(pipe.unet.unetwrap)
    del pipe
    compile_unet(unet, args, dtype)

    # --- Compile VAE post_quant_conv and save ---

    # Only keep the model being compiled in RAM to minimize memory pressure
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=dtype)
    post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)
    del pipe
    compile_vae_post_quant_conv(post_quant_conv, args, dtype)

    # Ship the inference entry point and its requirements with the model artifact.
    code_path = os.path.join(args.model_path, 'code')
    os.makedirs(code_path, exist_ok=True)

    shutil.copyfile('inference.py', os.path.join(code_path, 'inference.py'))
    #shutil.copyfile('wrapper.py', os.path.join(code_path, 'wrapper.py'))
    shutil.copyfile('requirements.txt', os.path.join(code_path, 'requirements.txt'))
155 changes: 155 additions & 0 deletions inference/stable-diffusion/src-inpaint/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import os
os.environ['NEURON_RT_NUM_CORES'] = '2'
import torch
import torch.nn as nn
import torch_neuronx
import time
from diffusers import StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.attention_processor import Attention

import threading
import argparse
import sys
import copy
import PIL
import math
import json
import requests
import io
from io import BytesIO
import base64
from PIL import Image


# Hugging Face model id of the Stable Diffusion 2 inpainting checkpoint
# loaded by model_fn.
model_id = "stabilityai/stable-diffusion-2-inpainting"
# Host-side dtype used when loading the (non-compiled) pipeline parts.
dtype = torch.float32

class UNetWrap(nn.Module):
    """Thin wrapper that forces the UNet to return a plain tuple.

    Passing return_dict=False makes the wrapped UNet emit a tuple instead of
    its custom output type, which is required for tracing.
    """

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):
        # cross_attention_kwargs is accepted for interface compatibility but
        # not forwarded (same as the original implementation).
        return self.unet(sample, timestep, encoder_hidden_states, return_dict=False)

class NeuronUNet(nn.Module):
    """Adapter exposing the diffusers UNet interface over a compiled UNetWrap."""

    def __init__(self, unetwrap):
        super().__init__()
        self.unetwrap = unetwrap
        # Mirror the attributes the StableDiffusion pipeline reads off the UNet.
        self.config = unetwrap.unet.config
        self.in_channels = unetwrap.unet.in_channels
        self.device = unetwrap.unet.device

    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, return_dict=False):
        # The traced module expects a float timestep broadcast to batch size.
        batched_timestep = timestep.float().expand((sample.shape[0],))
        out = self.unetwrap(sample, batched_timestep, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=out)

class NeuronTextEncoder(nn.Module):
    """Wraps a text encoder so it returns a list holding the last hidden state."""

    def __init__(self, text_encoder):
        super().__init__()
        self.neuron_text_encoder = text_encoder
        # Mirror the attributes the pipeline inspects on the original encoder.
        self.config = text_encoder.config
        self.dtype = text_encoder.dtype
        self.device = text_encoder.device

    def forward(self, emb, attention_mask=None):
        # attention_mask is accepted for interface compatibility but ignored
        # (same as the original implementation).
        encoded = self.neuron_text_encoder(emb)
        return [encoded['last_hidden_state']]

# Optimized attention: drop-in replacement monkey-patched onto diffusers'
# Attention.get_attention_scores.
def get_attention_scores(self, query, key, attn_mask):
    """Compute softmax attention probabilities for batched Q/K tensors.

    attn_mask is accepted for interface compatibility but ignored. The 1/8
    scale inside custom_badbmm corresponds to 1/sqrt(64), i.e. it assumes a
    head dimension of 64 -- TODO confirm against the model config.
    """
    original_dtype = query.dtype

    if self.upcast_attention:
        query = query.float()
        key = key.float()

    if query.size() == key.size():
        # Square case: compute K @ Q^T, softmax over the key axis (dim=1),
        # then transpose back. Mathematically this equals
        # softmax(Q @ K^T, dim=-1); presumably this form suits the Neuron
        # compiler better.
        scores = custom_badbmm(key, query.transpose(-1, -2))
        if self.upcast_softmax:
            scores = scores.float()
        probs = torch.nn.functional.softmax(scores, dim=1).permute(0, 2, 1)
    else:
        # General case: the standard scaled-dot-product form.
        scores = custom_badbmm(query, key.transpose(-1, -2))
        if self.upcast_softmax:
            scores = scores.float()
        probs = torch.nn.functional.softmax(scores, dim=-1)

    return probs.to(original_dtype)

def custom_badbmm(a, b):
    """Batched matmul followed by the fixed attention scale 0.125 (= 1/sqrt(64))."""
    return torch.bmm(a, b) * 0.125

def model_fn(model_dir, context=None):
    """SageMaker hook: assemble the inpainting pipeline from compiled Neuron parts.

    Loads the base diffusers pipeline, then swaps in the four TorchScript
    modules produced at compile time (found under model_dir).
    """
    global model_id, dtype
    print("Loading model parts...")
    start = time.time()

    compiled = {
        name: os.path.join(model_dir, name, 'model.pt')
        for name in ('text_encoder', 'vae_decoder', 'unet', 'vae_post_quant_conv')
    }

    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=dtype)

    # Load the compiled UNet onto two neuron cores via data parallelism.
    pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
    pipe.unet.unetwrap = torch_neuronx.DataParallel(
        torch.jit.load(compiled['unet']), [0, 1], set_dynamic_batching=False)

    # The remaining compiled parts each run on a single neuron core.
    pipe.text_encoder = NeuronTextEncoder(pipe.text_encoder)
    pipe.text_encoder.neuron_text_encoder = torch.jit.load(compiled['text_encoder'])
    pipe.vae.decoder = torch.jit.load(compiled['vae_decoder'])
    pipe.vae.post_quant_conv = torch.jit.load(compiled['vae_post_quant_conv'])

    print(f"Done. Elapsed time: {(time.time() - start) * 1000}ms")
    return pipe

def input_fn(request_body, request_content_type, context=None):
    """SageMaker hook: parse and validate a JSON inpainting request.

    Expects a JSON body {"prompt": str, "init_image": b64 str,
    "mask_image": b64 str} and returns the tuple
    (prompt, init_image, mask_image, height, width) with a fixed 512x512 size.

    Raises:
        ValueError: if the prompt is missing, not a string, or shorter than 5 chars.
        Exception: if the content type is not application/json.
    """
    if request_content_type != 'application/json':
        raise Exception(f"Unsupported mime type: {request_content_type}. Supported: application/json")

    req = json.loads(request_body)
    prompt = req.get('prompt')
    init_image = req.get('init_image')
    mask_image = req.get('mask_image')
    height = 512
    width = 512

    # BUG FIX: the original `raise("...")` raised a plain string, which is
    # itself a TypeError at runtime; raise a proper exception type instead.
    # Also use isinstance rather than comparing type(prompt) to str.
    if not isinstance(prompt, str) or len(prompt) < 5:
        raise ValueError("Invalid prompt. It needs to be a string > 5")

    return prompt, init_image, mask_image, height, width

def predict_fn(input_req, model, context=None):
    """SageMaker hook: decode the base64 images and run the inpainting pipeline.

    Returns the first generated PIL image.
    """
    prompt, init_image, mask_image, height, width = input_req

    def _decode(b64_payload):
        # Base64 string -> PIL image, normalized to RGB at the requested size.
        raw = base64.b64decode(b64_payload)
        return Image.open(io.BytesIO(raw)).convert("RGB").resize((width, height))

    result = model(prompt,
                   image=_decode(init_image),
                   mask_image=_decode(mask_image),
                   height=height, width=width)
    return result.images[0]

def output_fn(image, accept, context=None):
    """SageMaker hook: serialize the generated PIL image as JPEG bytes.

    Raises:
        Exception: if the requested accept type is not image/jpeg.
    """
    if accept != 'image/jpeg':
        raise Exception(f'Invalid data type. Expected image/jpeg, got {accept}')

    out = io.BytesIO()
    # Carry over the source ICC profile when one is present.
    image.save(out, 'jpeg', icc_profile=image.info.get('icc_profile'))
    return out.getvalue()
6 changes: 6 additions & 0 deletions inference/stable-diffusion/src-inpaint/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
diffusers==0.20.2
transformers==4.33.1
accelerate==0.22.0
safetensors==0.3.1
matplotlib
Pillow
Loading