Interesting Kernels
Diagonal Matrix Multiplication
Problem 13 in level 1 involves multiplying a diagonal matrix by another matrix:
torch.diag(A) @ B
torch.diag() takes a vector of diagonal elements and returns a 2-D square tensor with those elements on the diagonal, so the expression above is an ordinary matrix-matrix multiplication.
Mathematically, multiplying a matrix by a diagonal matrix is equivalent to scaling each row (or column, if the diagonal matrix is on the right side) of the original matrix by the corresponding diagonal element. As a result, the diagonal matrix doesn’t need to be explicitly constructed, reducing both memory usage and computational overhead.
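This identity is easy to check in PyTorch directly; a minimal sketch (shapes and variable names here are illustrative, not from the benchmark):

import torch

N, M = 256, 512
A = torch.randn(N)     # diagonal entries
B = torch.randn(N, M)

# Materializes an N x N matrix and performs a full matmul.
dense = torch.diag(A) @ B

# Equivalent row scaling via broadcasting; no N x N matrix is ever built.
scaled = A.unsqueeze(1) * B

assert torch.allclose(dense, scaled)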
This is the problem that gets the >12x speedup over torch and torch.compile() in level 1 for multiple models; one example of the generated CUDA kernels is below:
// Each thread writes one output element: out[row][col] = diag[row] * mat[row][col],
// so the full diagonal matrix is never materialized.
__global__ void diag_matmul_kernel(
    const float* diag,   // N diagonal elements
    const float* mat,    // N x M input matrix
    float* out,          // N x M output matrix
    const int N,
    const int M) {
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < N && col < M) {
    out[row * M + col] = diag[row] * mat[row * M + col];
  }
}
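For anyone who wants to reproduce the comparison, a kernel like the one above can be compiled and launched from Python with torch.utils.cpp_extension.load_inline. The sketch below is mine, not the generated code verbatim: the module name, the 16x16 block shape, and the diag_matmul wrapper are illustrative choices, and it assumes contiguous float32 CUDA tensors.

import torch
from torch.utils.cpp_extension import load_inline

cuda_src = r"""
#include <torch/extension.h>

__global__ void diag_matmul_kernel(const float* diag, const float* mat, float* out,
                                   const int N, const int M) {
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < N && col < M) {
    out[row * M + col] = diag[row] * mat[row * M + col];
  }
}

torch::Tensor diag_matmul(torch::Tensor diag, torch::Tensor mat) {
  const int N = mat.size(0);
  const int M = mat.size(1);
  auto out = torch::empty_like(mat);
  const dim3 block(16, 16);  // illustrative block shape, not tuned
  const dim3 grid((M + block.x - 1) / block.x, (N + block.y - 1) / block.y);
  diag_matmul_kernel<<<grid, block>>>(
      diag.data_ptr<float>(), mat.data_ptr<float>(), out.data_ptr<float>(), N, M);
  return out;
}
"""

ext = load_inline(
    name="diag_matmul_ext",  # hypothetical module name
    cpp_sources="torch::Tensor diag_matmul(torch::Tensor diag, torch::Tensor mat);",
    cuda_sources=cuda_src,
    functions=["diag_matmul"],
)

A = torch.randn(256, device="cuda")
B = torch.randn(256, 512, device="cuda")
assert torch.allclose(ext.diag_matmul(A, B), torch.diag(A) @ B)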
I think this should be problem 12?