Maybe wrong positional embedding? #38

@Felix-Zhenghao

Description

In the implementation of DAT++, the rpe_table is initialized as:

self.rpe_table = nn.Parameter(
    torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
)

Shouldn't the table be torch.zeros(self.n_heads, self.q_h * 2 + 1, self.q_w * 2 + 1)?

For instance, q_h and q_w can both be 56. Then each (x, y) displacement lies within the square $\{(x, y) \mid x \in [-1, 1],\ y \in [-1, 1]\}$. Each unit-length half of an axis is divided into 56 steps, so each axis spans 2 × 56 = 112 intervals and has 113 vertices. Therefore the grid has 113 × 113 vertices rather than 111 × 111.

If we use torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1), we ignore the boundary of the 2D square.
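To make the counting argument concrete, here is a minimal sketch (the grid size 56 and the head count are just example values, not taken from the repo's config) comparing the number of displacement-grid vertices per axis with the size of the table as initialized:

```python
import torch

q_h = q_w = 56  # example query grid size from the discussion above
n_heads = 8     # hypothetical head count, for illustration only

# Displacements between normalized coordinates lie in [-1, 1]; sampling
# that range with step 1 / q_h and keeping both endpoints gives
# 2 * q_h + 1 vertices per axis.
grid = torch.linspace(-1.0, 1.0, steps=2 * q_h + 1)
print(grid.numel())  # 113

# The table as initialized in the implementation covers only 2 * q_h - 1
# entries per axis, i.e. it drops the two boundary vertices.
rpe_table = torch.zeros(n_heads, q_h * 2 - 1, q_w * 2 - 1)
print(rpe_table.shape)  # torch.Size([8, 111, 111])
```

Under this reading, the table is two entries short per axis; whether that is intentional (e.g. displacements of exactly ±1 never occurring in practice) is the open question of this issue.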

@Vladimir2506 am I wrong anywhere? Thank you if you can give an answer!
