lightx2v/common/ops/attn/ulysses_attn.py: 212 changes (168 additions, 44 deletions)
@@ -26,6 +26,7 @@ def apply(
attention_type="flash_attn2",
seq_p_group=None,
use_fp8_comm=False,
use_tensor_fusion=False,
enable_head_parallel=False,
img_first=True,
**kwargs,
@@ -44,6 +45,8 @@
Returns:
torch.Tensor: The computed attention result
"""
use_qkv_fusion = use_tensor_fusion

if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
@@ -94,48 +97,125 @@ def apply(
txt_q, txt_k, txt_v = q[:txt_qkv_len, :, :].contiguous(), k[:txt_qkv_len, :, :].contiguous(), v[:txt_qkv_len, :, :].contiguous()
img_q, img_k, img_v = q[txt_qkv_len:, :, :].contiguous(), k[txt_qkv_len:, :, :].contiguous(), v[txt_qkv_len:, :, :].contiguous()

img_qkv = torch.stack([img_q, img_k, img_v], dim=0).reshape(3, img_qkv_len, world_size, shard_heads, hidden_dims)
original_dtype = img_qkv.dtype
if use_qkv_fusion:
img_qkv = torch.stack([img_q, img_k, img_v], dim=0).reshape(3, img_qkv_len, world_size, shard_heads, hidden_dims)
original_dtype = img_qkv.dtype
else:
img_q = img_q.reshape(img_qkv_len, world_size, shard_heads, hidden_dims)
img_k = img_k.reshape(img_qkv_len, world_size, shard_heads, hidden_dims)
img_v = img_v.reshape(img_qkv_len, world_size, shard_heads, hidden_dims)
original_dtype = img_q.dtype

if enable_head_parallel:
img_qkv = img_qkv.permute(3, 2, 1, 0, 4).contiguous() # (shard_heads, world_size, img_qkv_len, 3, hidden_dims)
output_qkv = torch.empty_like(img_qkv)
if use_qkv_fusion:
img_qkv = img_qkv.permute(3, 2, 1, 0, 4).contiguous() # (shard_heads, world_size, img_qkv_len, 3, hidden_dims)
output_qkv = torch.empty_like(img_qkv)
else:
img_q = img_q.permute(2, 1, 0, 3).contiguous() # (shard_heads, world_size, img_qkv_len, hidden_dims)
img_k = img_k.permute(2, 1, 0, 3).contiguous()
img_v = img_v.permute(2, 1, 0, 3).contiguous()
output_q = torch.empty_like(img_q)
output_k = torch.empty_like(img_k)
output_v = torch.empty_like(img_v)
Comment on lines +110 to +119 (Contributor, severity: medium):

This section, and several others throughout the apply method (e.g., lines 123-166, 168-181, and the subsequent for loop), introduces significant code duplication due to the nested if use_qkv_fusion: and else: blocks. This pattern makes the code harder to read, understand, and maintain.

Consider refactoring to reduce this duplication. For instance, you could prepare the tensors (e.g., img_qkv or individual img_q, img_k, img_v) and their corresponding output placeholders (output_qkv or output_q, output_k, output_v) in a unified manner before entering the communication and processing loops. This would allow the subsequent logic to operate on a consistent structure, regardless of whether QKV fusion is enabled, thereby reducing the need for repeated if/else checks.

For example, you could define lists of tensors to communicate and lists of output tensors, and then iterate over these lists in the communication and waiting phases.
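
A minimal sketch of that list-based approach could look like the following; it assumes the surrounding variables from this diff (img_qkv, img_q/img_k/img_v, seq_p_group, shard_heads), and the helper launch_all_to_all and the names recv/pending are hypothetical, introduced here only for illustration:

```python
import torch
import torch.distributed as dist

def launch_all_to_all(tensors, group):
    """Post one async all_to_all per head for every tensor in `tensors`.

    Each tensor is laid out as (shard_heads, world_size, ...). Returns the
    matching receive buffers and, per head, the list of pending work handles.
    """
    recv = [torch.empty_like(t) for t in tensors]
    shard_heads = tensors[0].shape[0]
    pending = [[] for _ in range(shard_heads)]
    for h in range(shard_heads):
        for src, dst in zip(tensors, recv):
            pending[h].append(
                dist.all_to_all_single(dst[h], src[h], group=group, async_op=True)
            )
    return recv, pending

# With fusion enabled the list holds the single stacked QKV tensor; without it,
# it holds q, k and v separately. The downstream loop then only differs in how
# it reshapes the received buffers, not in how it communicates or waits:
#
#   tensors = [img_qkv] if use_qkv_fusion else [img_q, img_k, img_v]
#   recv, pending = launch_all_to_all(tensors, seq_p_group)
#   for h in range(shard_heads):
#       for work in pending[h]:
#           work.wait()
#       # reshape recv[0][h] (fused) or recv[0][h], recv[1][h], recv[2][h]
#       # and run the per-head attention exactly as in the existing loop
```

The fp8 path could reuse the same helper by appending the quantized payloads and their scales to the list, so the wait-and-dequantize phase also becomes a loop over list entries rather than a second if/else.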


# Communicate the image queries, keys, and values
if use_fp8_comm:
img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
img_qkv_fp8 = img_qkv_fp8.reshape(shard_heads, world_size, img_qkv_len, 3, hidden_dims)
img_qkv_scale = img_qkv_scale.reshape(shard_heads, world_size, img_qkv_len, 3, 1)
output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
output_qkv_scale = torch.empty_like(img_qkv_scale)
comm_fp8_works = []
comm_scale_works = []
for h in range(shard_heads):
work_fp8 = dist.all_to_all_single(output_qkv_fp8[h], img_qkv_fp8[h], group=seq_p_group, async_op=True)
work_scale = dist.all_to_all_single(output_qkv_scale[h], img_qkv_scale[h], group=seq_p_group, async_op=True)
comm_fp8_works.append(work_fp8)
comm_scale_works.append(work_scale)
if use_qkv_fusion:
img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
img_qkv_fp8 = img_qkv_fp8.reshape(shard_heads, world_size, img_qkv_len, 3, hidden_dims)
img_qkv_scale = img_qkv_scale.reshape(shard_heads, world_size, img_qkv_len, 3, 1)
output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
output_qkv_scale = torch.empty_like(img_qkv_scale)
comm_fp8_works = []
comm_scale_works = []
for h in range(shard_heads):
work_fp8 = dist.all_to_all_single(output_qkv_fp8[h], img_qkv_fp8[h], group=seq_p_group, async_op=True)
work_scale = dist.all_to_all_single(output_qkv_scale[h], img_qkv_scale[h], group=seq_p_group, async_op=True)
comm_fp8_works.append(work_fp8)
comm_scale_works.append(work_scale)
else:
img_q_fp8, img_q_scale = quant_fp8_vllm(img_q.reshape(-1, hidden_dims))
img_k_fp8, img_k_scale = quant_fp8_vllm(img_k.reshape(-1, hidden_dims))
img_v_fp8, img_v_scale = quant_fp8_vllm(img_v.reshape(-1, hidden_dims))
img_q_fp8 = img_q_fp8.reshape(shard_heads, world_size, img_qkv_len, hidden_dims)
img_k_fp8 = img_k_fp8.reshape(shard_heads, world_size, img_qkv_len, hidden_dims)
img_v_fp8 = img_v_fp8.reshape(shard_heads, world_size, img_qkv_len, hidden_dims)
img_q_scale = img_q_scale.reshape(shard_heads, world_size, img_qkv_len, 1)
img_k_scale = img_k_scale.reshape(shard_heads, world_size, img_qkv_len, 1)
img_v_scale = img_v_scale.reshape(shard_heads, world_size, img_qkv_len, 1)
output_q_fp8 = torch.empty_like(img_q_fp8)
output_k_fp8 = torch.empty_like(img_k_fp8)
output_v_fp8 = torch.empty_like(img_v_fp8)
output_q_scale = torch.empty_like(img_q_scale)
output_k_scale = torch.empty_like(img_k_scale)
output_v_scale = torch.empty_like(img_v_scale)
comm_fp8_works = []
comm_scale_works = []
for h in range(shard_heads):
work_q_fp8 = dist.all_to_all_single(output_q_fp8[h], img_q_fp8[h], group=seq_p_group, async_op=True)
work_k_fp8 = dist.all_to_all_single(output_k_fp8[h], img_k_fp8[h], group=seq_p_group, async_op=True)
work_v_fp8 = dist.all_to_all_single(output_v_fp8[h], img_v_fp8[h], group=seq_p_group, async_op=True)
work_q_scale = dist.all_to_all_single(output_q_scale[h], img_q_scale[h], group=seq_p_group, async_op=True)
work_k_scale = dist.all_to_all_single(output_k_scale[h], img_k_scale[h], group=seq_p_group, async_op=True)
work_v_scale = dist.all_to_all_single(output_v_scale[h], img_v_scale[h], group=seq_p_group, async_op=True)
comm_fp8_works.append(work_q_fp8)
comm_fp8_works.append(work_k_fp8)
comm_fp8_works.append(work_v_fp8)
comm_scale_works.append(work_q_scale)
comm_scale_works.append(work_k_scale)
comm_scale_works.append(work_v_scale)
else:
comm_works = []
for h in range(shard_heads):
work = dist.all_to_all_single(output_qkv[h], img_qkv[h], group=seq_p_group, async_op=True)
comm_works.append(work)
if use_qkv_fusion:
comm_works = []
for h in range(shard_heads):
work = dist.all_to_all_single(output_qkv[h], img_qkv[h], group=seq_p_group, async_op=True)
comm_works.append(work)
else:
comm_works = []
for h in range(shard_heads):
work_q = dist.all_to_all_single(output_q[h], img_q[h], group=seq_p_group, async_op=True)
work_k = dist.all_to_all_single(output_k[h], img_k[h], group=seq_p_group, async_op=True)
work_v = dist.all_to_all_single(output_v[h], img_v[h], group=seq_p_group, async_op=True)
comm_works.append(work_q)
comm_works.append(work_k)
comm_works.append(work_v)

# Compute attention head by head
single_head = 1
head_attns = []
for h in range(shard_heads):
if use_fp8_comm:
comm_fp8_works[h].wait()
comm_scale_works[h].wait()
output_qkv[h] = dequant_fp8_vllm(output_qkv_fp8[h], output_qkv_scale[h], original_dtype)
if use_qkv_fusion:
comm_fp8_works[h].wait()
comm_scale_works[h].wait()
output_qkv[h] = dequant_fp8_vllm(output_qkv_fp8[h], output_qkv_scale[h], original_dtype)
else:
comm_fp8_works[3 * h].wait()
comm_fp8_works[3 * h + 1].wait()
comm_fp8_works[3 * h + 2].wait()
comm_scale_works[3 * h].wait()
comm_scale_works[3 * h + 1].wait()
comm_scale_works[3 * h + 2].wait()
output_q[h] = dequant_fp8_vllm(output_q_fp8[h], output_q_scale[h], original_dtype)
output_k[h] = dequant_fp8_vllm(output_k_fp8[h], output_k_scale[h], original_dtype)
output_v[h] = dequant_fp8_vllm(output_v_fp8[h], output_v_scale[h], original_dtype)
else:
comm_works[h].wait()

qkv = output_qkv[h].reshape(global_img_seqlen, 3, single_head, hidden_dims).transpose(0, 1)
shard_img_q = qkv[0] # (global_img_seqlen, single_head, hidden_dims)
shard_img_k = qkv[1]
shard_img_v = qkv[2]
if use_qkv_fusion:
comm_works[h].wait()
else:
comm_works[3 * h].wait()
comm_works[3 * h + 1].wait()
comm_works[3 * h + 2].wait()

if use_qkv_fusion:
qkv = output_qkv[h].reshape(global_img_seqlen, 3, single_head, hidden_dims).transpose(0, 1)
shard_img_q = qkv[0] # (global_img_seqlen, single_head, hidden_dims)
shard_img_k = qkv[1]
shard_img_v = qkv[2]
else:
shard_img_q = output_q[h].reshape(global_img_seqlen, single_head, hidden_dims)
shard_img_k = output_k[h].reshape(global_img_seqlen, single_head, hidden_dims)
shard_img_v = output_v[h].reshape(global_img_seqlen, single_head, hidden_dims)

# Process the text queries, keys, and values, selecting the current head on the current rank
shard_txt_q = txt_q[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
@@ -160,27 +240,71 @@ def apply(
attn = torch.cat(head_attns, dim=1)

else:
img_qkv = img_qkv.permute(2, 1, 0, 3, 4).contiguous() # (world_size, img_qkv_len, 3, shard_heads, hidden_dims)
if use_qkv_fusion:
img_qkv = img_qkv.permute(2, 1, 0, 3, 4).contiguous() # (world_size, img_qkv_len, 3, shard_heads, hidden_dims)
else:
img_q = img_q.permute(1, 0, 2, 3).contiguous() # (world_size, img_q_len, shard_heads, hidden_dims)
img_k = img_k.permute(1, 0, 2, 3).contiguous() # (world_size, img_k_len, shard_heads, hidden_dims)
img_v = img_v.permute(1, 0, 2, 3).contiguous() # (world_size, img_v_len, shard_heads, hidden_dims)

# Communicate the image queries, keys, and values
if use_fp8_comm:
img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
img_qkv_fp8 = img_qkv_fp8.reshape(world_size, img_qkv_len, shard_heads, 3, hidden_dims)
img_qkv_scale = img_qkv_scale.reshape(world_size, img_qkv_len, shard_heads, 3, 1)
output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
output_qkv_scale = torch.empty_like(img_qkv_scale)
dist.all_to_all_single(output_qkv_fp8, img_qkv_fp8, group=seq_p_group)
dist.all_to_all_single(output_qkv_scale, img_qkv_scale, group=seq_p_group)
output_qkv = dequant_fp8_vllm(output_qkv_fp8, output_qkv_scale, original_dtype)
if use_qkv_fusion:
img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
img_qkv_fp8 = img_qkv_fp8.reshape(world_size, img_qkv_len, shard_heads, 3, hidden_dims)
img_qkv_scale = img_qkv_scale.reshape(world_size, img_qkv_len, shard_heads, 3, 1)
output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
output_qkv_scale = torch.empty_like(img_qkv_scale)
dist.all_to_all_single(output_qkv_fp8, img_qkv_fp8, group=seq_p_group)
dist.all_to_all_single(output_qkv_scale, img_qkv_scale, group=seq_p_group)
output_qkv = dequant_fp8_vllm(output_qkv_fp8, output_qkv_scale, original_dtype)
else:
img_q_fp8, img_q_scale = quant_fp8_vllm(img_q.reshape(-1, hidden_dims))
img_k_fp8, img_k_scale = quant_fp8_vllm(img_k.reshape(-1, hidden_dims))
img_v_fp8, img_v_scale = quant_fp8_vllm(img_v.reshape(-1, hidden_dims))
img_q_fp8 = img_q_fp8.reshape(world_size, img_qkv_len, shard_heads, hidden_dims)
img_k_fp8 = img_k_fp8.reshape(world_size, img_qkv_len, shard_heads, hidden_dims)
img_v_fp8 = img_v_fp8.reshape(world_size, img_qkv_len, shard_heads, hidden_dims)
img_q_scale = img_q_scale.reshape(world_size, img_qkv_len, shard_heads, 1)
img_k_scale = img_k_scale.reshape(world_size, img_qkv_len, shard_heads, 1)
img_v_scale = img_v_scale.reshape(world_size, img_qkv_len, shard_heads, 1)
output_q_fp8 = torch.empty_like(img_q_fp8)
output_k_fp8 = torch.empty_like(img_k_fp8)
output_v_fp8 = torch.empty_like(img_v_fp8)
output_q_scale = torch.empty_like(img_q_scale)
output_k_scale = torch.empty_like(img_k_scale)
output_v_scale = torch.empty_like(img_v_scale)
dist.all_to_all_single(output_q_fp8, img_q_fp8, group=seq_p_group)
dist.all_to_all_single(output_k_fp8, img_k_fp8, group=seq_p_group)
dist.all_to_all_single(output_v_fp8, img_v_fp8, group=seq_p_group)
dist.all_to_all_single(output_q_scale, img_q_scale, group=seq_p_group)
dist.all_to_all_single(output_k_scale, img_k_scale, group=seq_p_group)
dist.all_to_all_single(output_v_scale, img_v_scale, group=seq_p_group)
output_q = dequant_fp8_vllm(output_q_fp8, output_q_scale, original_dtype)
output_k = dequant_fp8_vllm(output_k_fp8, output_k_scale, original_dtype)
output_v = dequant_fp8_vllm(output_v_fp8, output_v_scale, original_dtype)
else:
output_qkv = torch.empty_like(img_qkv)
dist.all_to_all_single(output_qkv, img_qkv, group=seq_p_group)
if use_qkv_fusion:
output_qkv = torch.empty_like(img_qkv)
dist.all_to_all_single(output_qkv, img_qkv, group=seq_p_group)
else:
output_q = torch.empty_like(img_q)
output_k = torch.empty_like(img_k)
output_v = torch.empty_like(img_v)
dist.all_to_all_single(output_q, img_q, group=seq_p_group)
dist.all_to_all_single(output_k, img_k, group=seq_p_group)
dist.all_to_all_single(output_v, img_v, group=seq_p_group)

# Perform the attention computation
qkv = output_qkv.reshape(global_img_seqlen, 3, shard_heads, hidden_dims).transpose(0, 1)
shard_img_q = qkv[0] # (global_img_seqlen, shard_head, hidden_dims)
shard_img_k = qkv[1]
shard_img_v = qkv[2]
if use_qkv_fusion:
qkv = output_qkv.reshape(global_img_seqlen, 3, shard_heads, hidden_dims).transpose(0, 1)
shard_img_q = qkv[0] # (global_img_seqlen, shard_head, hidden_dims)
shard_img_k = qkv[1]
shard_img_v = qkv[2]
else:
shard_img_q = output_q.reshape(global_img_seqlen, shard_heads, hidden_dims)
shard_img_k = output_k.reshape(global_img_seqlen, shard_heads, hidden_dims)
shard_img_v = output_v.reshape(global_img_seqlen, shard_heads, hidden_dims)

# Process the text queries, keys, and values, selecting the current head on the current rank
shard_txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]