-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathevaluator.py
More file actions
309 lines (251 loc) · 11.3 KB
/
evaluator.py
File metadata and controls
309 lines (251 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""
Module for evaluating image editing with stable diffusion features and perceptual metrics.
Includes functionality for feature extraction and similarity measurements.
"""
import warnings
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn.functional as F
import lpips
from diffusers import StableDiffusionPipeline
# Suppress UserWarnings.
# NOTE: this installs a process-wide "ignore" filter for *every* UserWarning
# (raised by any module, not only this one) as a side effect of importing this
# file. Presumably meant to silence noisy library warnings during evaluation —
# confirm before relying on UserWarnings elsewhere in the same process.
warnings.filterwarnings(action='ignore', category=UserWarning)
class SDFeaturizer(StableDiffusionPipeline):
    """
    Extracts Stable Diffusion 2.1 features for semantic point matching (DIFT).

    Inherits from StableDiffusionPipeline and adds feature extraction
    capabilities. Calling an instance does the following:
    1. Encodes the image to VAE latents.
    2. Noises an ensemble of copies to timestep ``t``.
    3. Runs the UNet once.
    4. Returns the ensemble-averaged output of one intermediate UNet sub-module.
    """

    @torch.no_grad()
    def __call__(
        self,
        img_tensor: torch.Tensor,
        t: int = 261,
        ensemble: int = 8,
        prompt: str = None,
        prompt_embeds: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Extract features from an input image tensor.

        Args:
            img_tensor: Input image tensor (B,C,H,W); B must be 1.
            t: Diffusion timestep at which noise is added before the UNet pass.
            ensemble: Number of independently-noised copies whose features
                are averaged.
            prompt: Text prompt for conditioning; ``None`` means the empty
                prompt. Ignored when ``prompt_embeds`` is given.
            prompt_embeds: Pre-computed prompt embeddings. Their leading
                (batch) dimension must be 1 or ``ensemble`` so that they can
                be expanded to ``ensemble`` copies.

        Returns:
            torch.Tensor: Captured feature map averaged over the ensemble,
            with a leading batch dimension of 1.

        Raises:
            AssertionError: If the batch size of ``img_tensor`` is not 1.
        """
        assert img_tensor.shape[0] == 1, "Batch size must be 1"
        device = self._execution_device

        # Encode image to latent space (mode of the posterior, so deterministic).
        latents = self.vae.encode(img_tensor).latent_dist.mode() * self.vae.config.scaling_factor
        latents = latents.expand(ensemble, -1, -1, -1)

        # Add independent noise to each ensemble member at timestep t.
        t = torch.tensor(t, dtype=torch.long, device=device)
        noise = torch.randn_like(latents)
        latents_noisy = self.scheduler.add_noise(latents, noise, t)

        # Get prompt embeddings (empty prompt when none is supplied).
        if prompt_embeds is None:
            prompt = "" if prompt is None else prompt
            prompt_embeds = self.encode_prompt(
                prompt=prompt,
                device=device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False
            )[0]
        prompt_embeds = prompt_embeds.expand(ensemble, -1, -1)

        # Extract features using a forward hook that keeps only the latest output.
        unet_feature = []

        def hook(_module, _inputs, output):
            unet_feature.clear()
            unet_feature.append(output)

        # NOTE(review): children()[4][1] depends on the registration order of
        # the UNet's sub-modules (presumably up_blocks[1]) — confirm against
        # the installed diffusers version.
        handle = list(self.unet.children())[4][1].register_forward_hook(hook=hook)
        try:
            self.unet(latents_noisy, t, prompt_embeds)
        finally:
            # Always detach the hook. Otherwise an exception during the UNet
            # pass would leave it registered on the shared UNet, and hooks
            # would accumulate across calls.
            handle.remove()
        return unet_feature[0].mean(dim=0, keepdim=True)
class DragEvaluator:
    """
    Evaluator for computing perceptual and distance metrics between images.

    Provides methods for LPIPS similarity and point-based distance
    measurements. The Stable Diffusion featurizer and the LPIPS network are
    loaded lazily on first use and cached on the instance.
    """

    def __init__(self):
        """Initialize device/dtype settings; models are loaded lazily later."""
        self.sd_loaded = False
        self.lpips_loaded = False
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Use half precision only on CUDA. float16 kernels on CPU are missing
        # or very slow in many PyTorch versions, so fall back to float32 there.
        self.dtype = torch.float16 if self.device == 'cuda' else torch.float32

    def load_sd(self):
        """Load the Stable Diffusion 2.1 featurizer if not already loaded."""
        if not self.sd_loaded:
            self.sd_feat = SDFeaturizer.from_pretrained(
                'stabilityai/stable-diffusion-2-1',
                torch_dtype=self.dtype
            ).to(self.device)
            self.sd_loaded = True

    def load_lpips(self):
        """Load the LPIPS (AlexNet backbone) model if not already loaded."""
        if not self.lpips_loaded:
            self.loss_fn_alex = lpips.LPIPS(net='alex').to(self.device).to(self.dtype)
            self.lpips_loaded = True

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """
        Convert an image to a (1, C, H, W) tensor normalized to [-1, 1].

        Args:
            image: Image array laid out as (H, W, C) with values in [0, 255].

        Returns:
            torch.Tensor: Tensor of shape (1, C, H, W) on ``self.device``
            with dtype ``self.dtype``.
        """
        image = torch.from_numpy(np.array(image)).float() / 127.5 - 1
        image = image.unsqueeze(0).permute(0, 3, 1, 2)
        return image.to(self.device).to(self.dtype)

    @torch.no_grad()
    def compute_lpips(self, original_image: np.ndarray, edited_image: np.ndarray) -> float:
        """
        Compute LPIPS perceptual distance between two images.

        Both images are resized to 224x224 (bilinear) before scoring.

        Args:
            original_image: Original image array (H, W, C), values in [0, 255].
            edited_image: Edited image array (H, W, C), values in [0, 255].

        Returns:
            float: LPIPS score (lower means more perceptually similar).
        """
        self.load_lpips()
        image1 = F.interpolate(self.preprocess_image(original_image), (224, 224), mode='bilinear')
        image2 = F.interpolate(self.preprocess_image(edited_image), (224, 224), mode='bilinear')
        return self.loss_fn_alex(image1, image2).item()

    @torch.no_grad()
    def compute_distance(self, original_image: np.ndarray, edited_image: np.ndarray,
                         handle_pts: np.ndarray, target_pts: np.ndarray,
                         prompt: str = None, plot_path: str = None) -> float:
        """
        Compute the mean normalized distance between target points and the
        points where the handle features are found in the edited image.

        Each handle point's feature in the original image is matched by cosine
        similarity against the edited image's features. The search is
        restricted to a disc around that point's handle/target pair, and the
        matched location is then compared with the target.

        Args:
            original_image: Original image array (H, W, C), values in [0, 255].
            edited_image: Edited image array. It is resized to the original
                size if its shape differs.
            handle_pts: Handle point coordinates, (N, 2) as (x, y) pixels. A
                single (x, y) pair is also accepted.
            target_pts: Target point coordinates, (N, 2) as (x, y) pixels.
            prompt: Optional text prompt used to condition the featurizer.
            plot_path: Optional path to save a visualization of the matches.

        Returns:
            float: Mean over points of the L2 norm of (target - matched), with
            x divided by the image width and y by the image height.
        """
        self.load_sd()
        # as_tensor avoids copy-constructing (and warning) for tensor inputs.
        # reshape lets a single (x, y) pair be passed as a flat pair.
        handle_pts = torch.as_tensor(handle_pts, device=self.device).long().reshape(-1, 2)
        target_pts = torch.as_tensor(target_pts, device=self.device).long().reshape(-1, 2)

        # Handle image size mismatch: bring the edited image to the original
        # resolution and rescale the target points by the same factors.
        # NOTE(review): only target_pts are rescaled, so they are assumed to be
        # in edited-image coordinates while handle_pts are in original-image
        # coordinates — confirm against the caller.
        if original_image.shape != edited_image.shape:
            orig_h, orig_w = original_image.shape[:2]
            edit_h, edit_w = edited_image.shape[:2]
            edited_image = cv2.resize(edited_image, (orig_w, orig_h))  # cv2 takes (width, height)
            target_pts = target_pts * torch.tensor([orig_w, orig_h], device=self.device)
            target_pts = (target_pts / torch.tensor([edit_w, edit_h], device=self.device)).long()

        image_h, image_w = original_image.shape[:2]
        # Images are resized to a fixed 768x768 before feature extraction.
        orig_img = F.interpolate(self.preprocess_image(original_image), size=(768, 768))
        edit_img = F.interpolate(self.preprocess_image(edited_image), size=(768, 768))

        # Extract features and resize them back to image resolution, so that
        # pixel coordinates index the feature maps directly.
        orig_feat = F.interpolate(self.sd_feat(orig_img, prompt=prompt), size=(image_h, image_w))
        edit_feat = F.interpolate(self.sd_feat(edit_img, prompt=prompt), size=(image_h, image_w))

        mask = self._create_mask(handle_pts, target_pts, (image_h, image_w))
        matched_pts = self._nn_get_matches(orig_feat, edit_feat, handle_pts, mask=mask)

        # Distance metric: normalize x by width and y by height, then average
        # the per-point Euclidean lengths.
        dist = target_pts - matched_pts
        dist = dist.float() / torch.tensor([image_w, image_h], device=self.device)
        mean_dist = dist.norm(dim=-1).mean().item()

        if plot_path:
            self.plot_drag_result(
                edited_image,
                matched_pts.cpu().numpy(),
                target_pts.cpu().numpy(),
                output_path=plot_path
            )
        return mean_dist

    @staticmethod
    def plot_drag_result(edited_image: np.ndarray, handle_pts: np.ndarray,
                         target_pts: np.ndarray, output_path: str = None):
        """
        Plot drag editing results with arrows from each point to its target.

        Args:
            edited_image: Edited image array.
            handle_pts: (N, 2) (x, y) points drawn in red as "Matched Points";
                numpy arrays or tensors.
            target_pts: (N, 2) (x, y) target points drawn in blue.
            output_path: If given, the figure is saved there and closed;
                otherwise it is shown.
        """
        plt.figure(figsize=(10, 10))
        plt.imshow(edited_image)

        # Convert points to numpy if needed.
        if torch.is_tensor(handle_pts):
            handle_pts = handle_pts.cpu().numpy()
        if torch.is_tensor(target_pts):
            target_pts = target_pts.cpu().numpy()

        # Plot points and arrows.
        plt.scatter(target_pts[:, 0], target_pts[:, 1], c='blue', label='Target Points')
        plt.scatter(handle_pts[:, 0], handle_pts[:, 1], c='red', label='Matched Points')
        for (hx, hy), (tx, ty) in zip(handle_pts, target_pts):
            plt.arrow(hx, hy, tx - hx, ty - hy,
                      color='white', head_width=5, head_length=5)

        plt.legend()
        plt.axis('off')
        if output_path:
            plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
            plt.close()
        else:
            plt.show()

    @staticmethod
    def _create_mask(handle_pts: torch.Tensor, target_pts: torch.Tensor,
                     img_size: tuple) -> torch.Tensor:
        """
        Create per-point search masks from pixel distances to each point pair.

        A pixel is valid for point i if it lies strictly within radius r_i of
        handle i or of target i. The radius is
        r_i = max(|handle_i - target_i| / sqrt(2), 5).

        Args:
            handle_pts: (N, 2) handle coordinates as (x, y).
            target_pts: (N, 2) target coordinates as (x, y).
            img_size: Image dimensions (H, W).

        Returns:
            torch.Tensor: Boolean mask of shape (N, H, W).
        """
        handle_pts, target_pts = handle_pts.float(), target_pts.float()
        h, w = img_size
        min_dist = ((handle_pts - target_pts).norm(dim=1) / 2**0.5).clamp(min=5)
        y_grid, x_grid = torch.meshgrid(
            torch.arange(h, device=handle_pts.device),
            torch.arange(w, device=handle_pts.device),
            indexing="ij"
        )
        y_grid = y_grid.expand(len(handle_pts), -1, -1)
        x_grid = x_grid.expand(len(handle_pts), -1, -1)
        handle_dist = ((x_grid - handle_pts[:, None, None, 0])**2 + (y_grid - handle_pts[:, None, None, 1])**2).sqrt()
        target_dist = ((x_grid - target_pts[:, None, None, 0])**2 + (y_grid - target_pts[:, None, None, 1])**2).sqrt()
        return (handle_dist < min_dist[:, None, None]) | (target_dist < min_dist[:, None, None])

    @staticmethod
    def _nn_get_matches(src_featmaps: torch.Tensor, trg_featmaps: torch.Tensor, query: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Find nearest-neighbor (cosine similarity) matches between feature maps.

        Args:
            src_featmaps: Source feature maps of shape (1, C, H, W).
            trg_featmaps: Target feature maps of shape (1, C, H, W).
            query: (N, 2) query points in the source map as (x, y) pixels.
            mask: Optional boolean mask, reshapeable to (N, H*W) (or (1, H*W)
                to apply to all queries). Only True positions can be matched.

        Returns:
            torch.Tensor: (N, 2) float tensor of matched (x, y) coordinates in
            the target map. On ties the first (row-major) position is returned.
        """
        _, c, h, w = src_featmaps.shape
        query = query.long()
        # Gather the source feature vector at each query point -> (C, N).
        src_feat = src_featmaps[0, :, query[:, 1], query[:, 0]]
        src_feat = F.normalize(src_feat, p=2, dim=0)
        trg_featmaps = F.normalize(trg_featmaps, p=2, dim=1)
        # reshape (not view) also works if the normalized map is non-contiguous.
        trg_featmaps = trg_featmaps.reshape(c, -1)
        similarity = torch.mm(src_feat.t(), trg_featmaps)  # (N, H*W) cosine similarity
        if mask is not None:
            similarity = similarity.masked_fill(~mask.reshape(-1, h * w).bool(), -torch.inf)
        best_idx = similarity.argmax(dim=-1)
        y_coords = best_idx // w
        x_coords = best_idx % w
        return torch.stack((x_coords, y_coords), dim=1).float()