
SDPA dispatch to TPU Pallas kernel is missing inverse sqrt scaling and always overrides causal to True #39

@mozuysal-hubx

Description


Hi @qihqi,

I ran into the following while experimenting. When `scaled_dot_product_attention` is dispatched to the Pallas TPU kernel, this code is executed:

```python
return flash_attention.flash_attention(
    query, key, value, causal=True, block_sizes=block_sizes
)
```

This always forces `causal=True`, which is wrong for diffusion text-to-image models (their attention is bidirectional, not causal) and does not match the reference implementation, which honors the `is_causal` parameter.
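To make the effect of the hard-coded flag concrete, here is a minimal pure-JAX reference attention that honors `is_causal` the way torch SDPA documents it (a top-left-aligned lower-triangular mask). This is a sketch for CPU comparison only, not the torchax or torch code; `reference_attention` is a hypothetical helper, and arrays are assumed to be laid out as `[batch, heads, seq, head_dim]`.

```python
import math

import jax
import jax.numpy as jnp


def reference_attention(q, k, v, is_causal=False, scale=None):
    """Plain softmax attention on [batch, heads, seq, head_dim] arrays.

    Hypothetical helper: mirrors torch SDPA semantics for `is_causal` and the
    default 1/sqrt(head_dim) scale, nothing else (no mask, dropout or GQA).
    """
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if is_causal:
        q_len, k_len = q.shape[2], k.shape[2]
        # Lower-triangular mask: query i may only attend to keys 0..i.
        mask = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
        logits = jnp.where(mask, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (1, 2, 8, 16)
q = jax.random.normal(rng_q, shape)
k = jax.random.normal(rng_k, shape)
v = jax.random.normal(rng_v, shape)

full = reference_attention(q, k, v, is_causal=False)   # what a diffusion model expects
causal = reference_attention(q, k, v, is_causal=True)  # what the hard-coded flag computes

# The last query position sees every key either way, so only that row agrees.
print(jnp.allclose(full[:, :, -1], causal[:, :, -1], atol=1e-5))  # True
print(jnp.allclose(full, causal, atol=1e-5))                      # False
```

With `causal=True` the first query position can only attend to the first key, so its output collapses to `v[..., 0, :]`, which is very different from the bidirectional result.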

The other issue is that the inverse square root scaling (`1/sqrt(head_dim)`) applied in the reference implementation is not part of the Pallas kernel's computation, so it has to be applied before the Pallas call, as the torch_xla attention backend in diffusers does.
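Folding the factor into the query is exact up to rounding, because `(q * s) @ k^T == s * (q @ k^T)`. The sketch below checks this in pure JAX, using a hypothetical `unscaled_attention` helper as a stand-in for a kernel that applies no scaling (it is not the Pallas kernel). As far as I know the JAX Pallas `flash_attention` also accepts an `sm_scale` keyword defaulting to 1.0, so passing the factor there may be an alternative; that is worth verifying against the installed JAX version.

```python
import math

import jax
import jax.numpy as jnp


def unscaled_attention(q, k, v):
    """Hypothetical stand-in for a kernel computing softmax(q @ k^T) @ v with no scaling."""
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k)
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), v)


def sdpa_via_unscaled_kernel(q, k, v, scale=None):
    # torch SDPA uses 1/sqrt(head_dim) when `scale` is None.
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])
    # Scaling q before the kernel is equivalent to scaling the logits inside it.
    return unscaled_attention(q * scale, k, v)


rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 4, 32, 64)
q, k, v = (jax.random.normal(r, shape) for r in (rng_q, rng_k, rng_v))

# Reference: scale the logits directly, as SDPA does.
logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) / math.sqrt(shape[-1])
expected = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), v)

print(jnp.allclose(sdpa_via_unscaled_kernel(q, k, v), expected, atol=1e-5))  # True
print(jnp.allclose(unscaled_attention(q, k, v), expected, atol=1e-5))        # False
```

Without the factor, logits for `head_dim=64` are about 8x larger, so the softmax becomes much sharper and the output drifts away from the reference.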

The following replaces the torch SDPA call with a correct Pallas TPU flash attention call:

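```python
import math

import jax
import torch
import torchax
import torchax.interop
from jax.experimental.pallas.ops.tpu import flash_attention


def use_tpu_flash_attention(env):
    env.config.use_tpu_flash_attention = True

    @jax.default_matmul_precision("default")
    def custom_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                         is_causal=False, scale=None, enable_gqa=False):
        # NOTE: attn_mask, dropout_p and enable_gqa are not handled yet.
        # Block sizes are capped by the actual batch size and sequence lengths.
        block_sizes = flash_attention.BlockSizes(
            block_b=min(2, query.shape[0]),
            block_q=min(512, query.shape[2]),
            block_k_major=min(512, key.shape[2]),
            block_k=min(512, key.shape[2]),
            block_q_major_dkv=min(512, query.shape[2]),
            block_k_major_dkv=min(512, key.shape[2]),
            block_k_dkv=min(512, key.shape[2]),
            block_q_dkv=min(512, query.shape[2]),
            block_k_major_dq=min(512, key.shape[2]),
            block_k_dq=min(256, key.shape[2]),
            block_q_dq=min(1024, query.shape[2]),
        )
        # The Pallas kernel does not apply 1/sqrt(head_dim), so fold the SDPA
        # scale (the caller's `scale`, or the default) into the query.
        if scale is None:
            scale = 1.0 / math.sqrt(query.shape[-1])
        query = query * scale
        jq, jk, jv = torchax.interop.jax_view((query, key, value))
        # Honor the caller's is_causal instead of hard-coding causal=True.
        res = flash_attention.flash_attention(
            jq, jk, jv, block_sizes=block_sizes, causal=is_causal)
        return torchax.interop.torch_view(res)

    env.override_op_definition(torch.nn.functional.scaled_dot_product_attention, custom_attention)
```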

This still does not handle `attn_mask` (and ignores `dropout_p` and `enable_gqa`), but it fixes the two issues above.
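One possible direction for the mask, sketched under assumptions: torch SDPA documents a boolean `attn_mask` as "True = may attend" and a float mask as an additive bias, so a boolean mask can be converted to an additive bias of 0 / large-negative. The JAX Pallas `flash_attention` appears to expose an optional `ab` (attention bias) argument that could take such a bias, but whether that is the right fix, and the memory cost of materializing a `[batch, heads, q_len, k_len]` bias, would need checking. `mask_to_bias` and `biased_attention` below are hypothetical helpers for a CPU reference check, not existing APIs.

```python
import jax
import jax.numpy as jnp


def mask_to_bias(attn_mask, dtype=jnp.float32):
    """Convert an SDPA-style mask into an additive attention bias.

    Boolean masks: True (may attend) -> 0, False -> most negative finite value.
    Float masks are already additive and are only cast to `dtype`.
    """
    if attn_mask.dtype == jnp.bool_:
        neg = jnp.finfo(dtype).min
        return jnp.where(attn_mask, jnp.zeros((), dtype), jnp.full((), neg, dtype))
    return attn_mask.astype(dtype)


def biased_attention(q, k, v, bias, scale):
    # Reference computation: softmax(scale * q @ k^T + bias) @ v.
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale + bias
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), v)


rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
b, h, s, d = 1, 2, 8, 16
q, k, v = (jax.random.normal(r, (b, h, s, d)) for r in (rng_q, rng_k, rng_v))

valid = 5  # hypothetical: only the first 5 keys are real tokens, the rest are padding
key_mask = jnp.arange(s) < valid                      # [s] bool over the key axis
attn_mask = jnp.broadcast_to(key_mask, (b, h, s, s))  # True = may attend

scale = 1.0 / d ** 0.5
masked = biased_attention(q, k, v, mask_to_bias(attn_mask), scale)
# Masking the padded keys should match attending over the valid keys only.
trimmed = biased_attention(q, k[:, :, :valid], v[:, :, :valid], 0.0, scale)
print(jnp.allclose(masked, trimmed, atol=1e-5))  # True
```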

I can open a PR fixing these two issues, or wait for the core dev team to fix them. Thanks.
