diff --git a/KernelBench/changelog/constant_fill_fixes.txt b/KernelBench/changelog/constant_fill_fixes.txt new file mode 100644 index 00000000..2c84bd89 --- /dev/null +++ b/KernelBench/changelog/constant_fill_fixes.txt @@ -0,0 +1,52 @@ +Changelog: Constant Fill Problems Fixes +======================================== + +Date: 2025-12-20 + +Fixed 3 problems that produced constant (zero) outputs regardless of input. + +-------------------------------------------------------------------------------- + +1. level2/80_Gemm_Max_Subtract_GELU.py + + Issue: After max(dim=1, keepdim=True), shape is (B,1). The mean along dim=1 + of a single-element tensor equals the value itself, so x - mean = 0. + + Fix: Changed mean dimension from 1 to 0. + - x = x - x.mean(dim=1, keepdim=True) + + x = x - x.mean(dim=0, keepdim=True) + + Why: Shape is (B,1), so mean(dim=0) gives scalar mean across B samples; each + sample's max differs, producing non-zero deviations from batch mean. + +-------------------------------------------------------------------------------- + +2. level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout.py + + Issue: min(x, 0.0) forces all values ≤ 0, then clamp(min=0.0) forces all + values to exactly 0. + + Fix: Changed min to use max_value instead of min_value; set max_value=0.5. + - x = torch.min(x, torch.tensor(min_value, device=x.device)) + + x = torch.min(x, torch.tensor(max_value, device=x.device)) + - max_value = 1.0 + + max_value = 0.5 + + Why: min(x, 0.5) caps at 0.5; clamp bounds to [0,0.5], giving output in [0,0.5] + range which preserves Conv3d/GroupNorm variation. + +-------------------------------------------------------------------------------- + +3. level2/23_Conv3d_GroupNorm_Mean.py + + Issue: GroupNorm normalizes to zero mean per group (with default affine + params γ=1, β=0). The global mean of zero-mean data is ~0. + + Fix: Replaced mean with amax (global max pooling). + - x = x.mean(dim=[1, 2, 3, 4]) + + x = x.amax(dim=[1, 2, 3, 4]) + + Why: After GroupNorm, mean is ~0 but max varies per input because different + inputs have different extreme values in the normalized distribution. + +-------------------------------------------------------------------------------- diff --git a/KernelBench/changelog/redundant_op_fixes.txt b/KernelBench/changelog/redundant_op_fixes.txt new file mode 100644 index 00000000..d3c6a20a --- /dev/null +++ b/KernelBench/changelog/redundant_op_fixes.txt @@ -0,0 +1,89 @@ +Changelog: Redundant Operation Fixes +===================================== + +Date: 2025-12-20 + +Removed 7 redundant operations that had no effect on model output. + +-------------------------------------------------------------------------------- + +1. level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py + + Issue: Second global avg pool is no-op (tensor is N×C×1×1 after first pool). + + Fix: Removed second mean operation. + - x = torch.mean(x, dim=[2, 3], keepdim=True) # First + - x = torch.mean(x, dim=[2, 3], keepdim=True) # Second (removed) + +-------------------------------------------------------------------------------- + +2. level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py + + Issue: Hardtanh[-1,1] after tanh→GELU is redundant (GELU of tanh output + is already in approximately [-0.16, 0.84] ⊂ [-1, 1]). + + Fix: Removed Hardtanh. + - x = torch.nn.functional.hardtanh(x, min_val=-1, max_val=1) + +-------------------------------------------------------------------------------- + +3. 
level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py + + Issue: Final clamp[-1,1] after tanh is redundant (tanh already outputs [-1,1]). + + Fix: Removed final clamp. + - x = torch.clamp(x, min=-1.0, max=1.0) + +-------------------------------------------------------------------------------- + +4. level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py + + Issue: LeakyReLU after ReLU is identity (ReLU output is ≥0, LeakyReLU is + identity for non-negative inputs). + + Fix: Removed LeakyReLU. + - x = torch.nn.functional.leaky_relu(x, negative_slope=0.01) + +-------------------------------------------------------------------------------- + +5. level3/36_LSTMHn.py + + Issue: fc layer computes output but returns h_n (state[0]) instead, making + fc dead code. + + Fix: Removed fc layer from __init__ and forward. + - self.fc = nn.Linear(hidden_size, output_size) + - out = self.fc(out[:, -1, :]) + +-------------------------------------------------------------------------------- + +6. level3/37_LSTMCn.py + + Issue: fc layer computes output but returns c_n (state[1]) instead, making + fc dead code. + + Fix: Removed fc layer from __init__ and forward. + - self.fc = nn.Linear(hidden_size, output_size) + - out = self.fc(out[:, -1, :]) + +-------------------------------------------------------------------------------- + +7. level3/49_Mamba2ReturnFinalState.py + + Issue: Y_diag einsum is computed but never used (returns new_states[:, -1]). + L is only used to compute Y_diag, so both are dead code. + + Fix: Removed dead code computing L and Y_diag. + - L = torch.exp(self.segsum(A_blocks)) + - Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", ...) + +-------------------------------------------------------------------------------- + +TODO: Pending Name Changes (5 files) +------------------------------------- +[ ] level2/23_Conv3d_GroupNorm_Mean.py → 23_Conv3d_GroupNorm_Amax.py +[ ] level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py → 44_ConvTranspose2d_Multiply_GlobalAvgPool_Mean.py +[ ] level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py → 95_Matmul_Add_Swish_Tanh_GELU.py +[ ] level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py → 81_Gemm_Swish_Divide_Clamp_Tanh.py +[ ] level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py → 7_Conv3d_ReLU_GELU_Sigmoid_BiasAdd.py + diff --git a/KernelBench/level2/23_Conv3d_GroupNorm_Mean.py b/KernelBench/level2/23_Conv3d_GroupNorm_Mean.py index f3d4f8ae..6669646a 100644 --- a/KernelBench/level2/23_Conv3d_GroupNorm_Mean.py +++ b/KernelBench/level2/23_Conv3d_GroupNorm_Mean.py @@ -19,7 +19,7 @@ def forward(self, x): """ x = self.conv(x) x = self.group_norm(x) - x = x.mean(dim=[1, 2, 3, 4]) # Compute mean across all dimensions except batch + x = x.amax(dim=[1, 2, 3, 4]) # Global max pool return x batch_size = 128 diff --git a/KernelBench/level2/23_Conv3d_GroupNorm_Mean_OLD.py b/KernelBench/level2/23_Conv3d_GroupNorm_Mean_OLD.py new file mode 100644 index 00000000..333d8abe --- /dev/null +++ b/KernelBench/level2/23_Conv3d_GroupNorm_Mean_OLD.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + Model that performs a 3D convolution, applies Group Normalization, computes the mean + """ + def __init__(self, in_channels, out_channels, kernel_size, num_groups): + super(Model, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size) + self.group_norm = nn.GroupNorm(num_groups, out_channels) + + def forward(self, x): + """ + Args: + x (torch.Tensor): Input tensor of shape (batch_size, in_channels, D, H, W). 
+ Returns: + torch.Tensor: Output tensor of shape (batch_size, 1). + """ + x = self.conv(x) + x = self.group_norm(x) + x = x.mean(dim=[1, 2, 3, 4]) # Compute mean across all dimensions except batch + return x + +batch_size = 128 +in_channels = 3 +out_channels = 24 +D, H, W = 24, 32, 32 +kernel_size = 3 +num_groups = 8 + +def get_inputs(): + return [torch.rand(batch_size, in_channels, D, H, W)] + +def get_init_inputs(): + return [in_channels, out_channels, kernel_size, num_groups] + diff --git a/KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py b/KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py index dc4d51a6..6644c56c 100644 --- a/KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py +++ b/KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py @@ -14,8 +14,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, outp def forward(self, x): x = self.conv_transpose(x) x = x * self.multiplier - x = torch.mean(x, dim=[2, 3], keepdim=True) # First global average pooling - x = torch.mean(x, dim=[2, 3], keepdim=True) # Second global average pooling + x = torch.mean(x, dim=[2, 3], keepdim=True) # Global average pooling return x batch_size = 16 diff --git a/KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean_OLD.py b/KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean_OLD.py new file mode 100644 index 00000000..e71a7d23 --- /dev/null +++ b/KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean_OLD.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + Model that performs a transposed convolution, multiplies by a scalar, applies global average pooling, + another global average pooling + """ + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier): + super(Model, self).__init__() + self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding) + self.multiplier = multiplier + + def forward(self, x): + x = self.conv_transpose(x) + x = x * self.multiplier + x = torch.mean(x, dim=[2, 3], keepdim=True) # First global average pooling + x = torch.mean(x, dim=[2, 3], keepdim=True) # Second global average pooling + return x + +batch_size = 16 +in_channels = 64 +out_channels = 128 +height, width = 128, 128 +kernel_size = 3 +stride = 2 +padding = 1 +output_padding = 1 +multiplier = 0.5 + +def get_inputs(): + return [torch.rand(batch_size, in_channels, height, width)] + +def get_init_inputs(): + return [in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier] + diff --git a/KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py b/KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py index 498103a2..a74359ab 100644 --- a/KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py +++ b/KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py @@ -13,7 +13,6 @@ def __init__(self, in_channels, out_channels, kernel_size, bias_shape): def forward(self, x): x = self.conv(x) x = torch.relu(x) - x = torch.nn.functional.leaky_relu(x, negative_slope=0.01) x = torch.nn.functional.gelu(x) x = torch.sigmoid(x) x = x + self.bias diff --git a/KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd_OLD.py 
b/KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd_OLD.py new file mode 100644 index 00000000..94f5c490 --- /dev/null +++ b/KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd_OLD.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + Model that performs a 3D convolution, applies ReLU, LeakyReLU, GELU, Sigmoid activations, and bias in sequence. + """ + def __init__(self, in_channels, out_channels, kernel_size, bias_shape): + super(Model, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size) + self.bias = nn.Parameter(torch.randn(bias_shape)) + + def forward(self, x): + x = self.conv(x) + x = torch.relu(x) + x = torch.nn.functional.leaky_relu(x, negative_slope=0.01) + x = torch.nn.functional.gelu(x) + x = torch.sigmoid(x) + x = x + self.bias + return x + +batch_size = 64 +in_channels = 8 +out_channels = 32 +depth, height, width = 32, 64, 64 +kernel_size = 3 +bias_shape = (out_channels, 1, 1, 1) + +def get_inputs(): + return [torch.rand(batch_size, in_channels, depth, height, width)] + +def get_init_inputs(): + return [in_channels, out_channels, kernel_size, bias_shape] + diff --git a/KernelBench/level2/80_Gemm_Max_Subtract_GELU.py b/KernelBench/level2/80_Gemm_Max_Subtract_GELU.py index 486376a9..25f76a22 100644 --- a/KernelBench/level2/80_Gemm_Max_Subtract_GELU.py +++ b/KernelBench/level2/80_Gemm_Max_Subtract_GELU.py @@ -20,7 +20,7 @@ def forward(self, x): """ x = self.gemm(x) x = torch.max(x, dim=self.max_dim, keepdim=True).values - x = x - x.mean(dim=1, keepdim=True) + x = x - x.mean(dim=0, keepdim=True) x = torch.nn.functional.gelu(x) return x diff --git a/KernelBench/level2/80_Gemm_Max_Subtract_GELU_OLD.py b/KernelBench/level2/80_Gemm_Max_Subtract_GELU_OLD.py new file mode 100644 index 00000000..e0fc26ea --- /dev/null +++ b/KernelBench/level2/80_Gemm_Max_Subtract_GELU_OLD.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + Model that performs a GEMM, followed by a max operation, subtraction, and GELU activation. 
+ """ + def __init__(self, in_features, out_features, max_dim): + super(Model, self).__init__() + self.gemm = nn.Linear(in_features, out_features) + self.max_dim = max_dim + + def forward(self, x): + """ + Args: + x: Input tensor of shape (batch_size, in_features) + + Returns: + Output tensor of shape (batch_size, out_features) + """ + x = self.gemm(x) + x = torch.max(x, dim=self.max_dim, keepdim=True).values + x = x - x.mean(dim=1, keepdim=True) + x = torch.nn.functional.gelu(x) + return x + +batch_size = 1024 +in_features = 8192 +out_features = 8192 +max_dim = 1 + +def get_inputs(): + return [torch.rand(batch_size, in_features)] + +def get_init_inputs(): + return [in_features, out_features, max_dim] + diff --git a/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py b/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py index 1aeca85a..0792c3fd 100644 --- a/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py +++ b/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py @@ -21,7 +21,6 @@ def forward(self, x): x = x / 2.0 x = torch.clamp(x, min=-1.0, max=1.0) # Clamp between -1 and 1 x = torch.tanh(x) # Tanh activation - x = torch.clamp(x, min=-1.0, max=1.0) # Clamp between -1 and 1 return x batch_size = 1024 diff --git a/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp_OLD.py b/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp_OLD.py new file mode 100644 index 00000000..561e889f --- /dev/null +++ b/KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp_OLD.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + Simple model that performs a gemm, swish, divide, clamp, tanh, and clamp operations. + """ + def __init__(self, in_features, out_features, bias=True): + super(Model, self).__init__() + self.gemm = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x): + """ + Args: + x (torch.Tensor): Input tensor of shape (batch_size, in_features). + Returns: + torch.Tensor: Output tensor of shape (batch_size, out_features). 
+ """ + x = self.gemm(x) + x = x * torch.sigmoid(x) # Swish activation + x = x / 2.0 + x = torch.clamp(x, min=-1.0, max=1.0) # Clamp between -1 and 1 + x = torch.tanh(x) # Tanh activation + x = torch.clamp(x, min=-1.0, max=1.0) # Clamp between -1 and 1 + return x + +batch_size = 1024 +in_features = 8192 +out_features = 8192 + +def get_inputs(): + return [torch.rand(batch_size, in_features)] + +def get_init_inputs(): + return [in_features, out_features] + diff --git a/KernelBench/level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout.py b/KernelBench/level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout.py index d29fa677..d55b9110 100644 --- a/KernelBench/level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout.py +++ b/KernelBench/level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout.py @@ -14,7 +14,7 @@ def __init__(self, in_channels, out_channels, kernel_size, groups, min_value, ma def forward(self, x): x = self.conv(x) x = self.norm(x) - x = torch.min(x, torch.tensor(min_value, device=x.device)) + x = torch.min(x, torch.tensor(max_value, device=x.device)) x = torch.clamp(x, min=min_value, max=max_value) x = self.dropout(x) return x @@ -26,7 +26,7 @@ def forward(self, x): kernel_size = 3 groups = 8 min_value = 0.0 -max_value = 1.0 +max_value = 0.5 dropout_p = 0.2 def get_inputs(): diff --git a/KernelBench/level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout_OLD.py b/KernelBench/level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout_OLD.py new file mode 100644 index 00000000..c26342dd --- /dev/null +++ b/KernelBench/level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout_OLD.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + Model that performs a 3D convolution, applies Group Normalization, minimum, clamp, and dropout. + """ + def __init__(self, in_channels, out_channels, kernel_size, groups, min_value, max_value, dropout_p): + super(Model, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size) + self.norm = nn.GroupNorm(groups, out_channels) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + x = torch.min(x, torch.tensor(min_value, device=x.device)) + x = torch.clamp(x, min=min_value, max=max_value) + x = self.dropout(x) + return x + +batch_size = 128 +in_channels = 3 +out_channels = 16 +depth, height, width = 16, 64, 64 +kernel_size = 3 +groups = 8 +min_value = 0.0 +max_value = 1.0 +dropout_p = 0.2 + +def get_inputs(): + return [torch.rand(batch_size, in_channels, depth, height, width)] + +def get_init_inputs(): + return [in_channels, out_channels, kernel_size, groups, min_value, max_value, dropout_p] + diff --git a/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py b/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py index de9ccb40..83947d99 100644 --- a/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py +++ b/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py @@ -16,7 +16,6 @@ def forward(self, x): x = torch.sigmoid(x) * x # Swish x = torch.tanh(x) x = torch.nn.functional.gelu(x) # GELU - x = torch.nn.functional.hardtanh(x, min_val=-1, max_val=1) # Hardtanh return x batch_size = 1024 diff --git a/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh_OLD.py b/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh_OLD.py new file mode 100644 index 00000000..3579e67a --- /dev/null +++ b/KernelBench/level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh_OLD.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + Simple model that performs a matrix multiplication, 
adds a value, applies Swish, Tanh, GELU, and Hardtanh activation functions. + """ + def __init__(self, in_features, out_features, add_value_shape): + super(Model, self).__init__() + self.matmul = nn.Linear(in_features, out_features) + self.add_value = nn.Parameter(torch.randn(add_value_shape)) + + def forward(self, x): + x = self.matmul(x) + x = x + self.add_value + x = torch.sigmoid(x) * x # Swish + x = torch.tanh(x) + x = torch.nn.functional.gelu(x) # GELU + x = torch.nn.functional.hardtanh(x, min_val=-1, max_val=1) # Hardtanh + return x + +batch_size = 1024 +in_features = 8192 +out_features = 8192 +add_value_shape = (out_features,) + +def get_inputs(): + return [torch.rand(batch_size, in_features)] + +def get_init_inputs(): + return [in_features, out_features, add_value_shape] + diff --git a/KernelBench/level3/36_LSTMHn.py b/KernelBench/level3/36_LSTMHn.py index b3365931..b44daccd 100644 --- a/KernelBench/level3/36_LSTMHn.py +++ b/KernelBench/level3/36_LSTMHn.py @@ -15,7 +15,6 @@ def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0 super(Model, self).__init__() # Initialize hidden state with random values self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False) - self.fc = nn.Linear(hidden_size, output_size) def forward(self, x,h0,c0): """ @@ -26,11 +25,7 @@ def forward(self, x,h0,c0): """ # Forward propagate LSTM - out, state = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size) - - # Decode the hidden state of the last time step - out = self.fc(out[:, -1, :]) # out: tensor of shape (batch_size, output_size) - + _, state = self.lstm(x, (h0, c0)) return state[0] # Test code diff --git a/KernelBench/level3/36_LSTMHn_OLD.py b/KernelBench/level3/36_LSTMHn_OLD.py new file mode 100644 index 00000000..73d0b456 --- /dev/null +++ b/KernelBench/level3/36_LSTMHn_OLD.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0): + """ + Initialize the LSTM model. + + :param input_size: The number of expected features in the input `x` + :param hidden_size: The number of features in the hidden state `h` + :param num_layers: Number of recurrent layers + :param output_size: The number of output features + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to `dropout` + """ + super(Model, self).__init__() + # Initialize hidden state with random values + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False) + self.fc = nn.Linear(hidden_size, output_size) + + def forward(self, x,h0,c0): + """ + Forward pass through the LSTM model. 
+ + :param x: The input tensor, shape (batch_size, sequence_length, input_size) + :return: The output tensor, shape (batch_size, sequence_length, output_size) + """ + + # Forward propagate LSTM + out, state = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size) + + # Decode the hidden state of the last time step + out = self.fc(out[:, -1, :]) # out: tensor of shape (batch_size, output_size) + + return state[0] + +# Test code +batch_size = 10 +sequence_length = 512 +input_size = 128 +hidden_size = 256 +num_layers = 6 +output_size = 10 +dropout = 0.0 + +def get_inputs(): + return [torch.rand(batch_size, sequence_length, input_size),torch.rand((num_layers, batch_size, hidden_size)),torch.rand((num_layers, batch_size, hidden_size))] + +def get_init_inputs(): + return [input_size, hidden_size, num_layers, output_size, dropout] + diff --git a/KernelBench/level3/37_LSTMCn.py b/KernelBench/level3/37_LSTMCn.py index 06f4a326..1c710a7b 100644 --- a/KernelBench/level3/37_LSTMCn.py +++ b/KernelBench/level3/37_LSTMCn.py @@ -15,7 +15,6 @@ def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0 super(Model, self).__init__() # Initialize hidden state with random values self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False) - self.fc = nn.Linear(hidden_size, output_size) def forward(self, x, h0, c0): """ @@ -26,11 +25,7 @@ def forward(self, x, h0, c0): """ # Forward propagate LSTM - out, state = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size) - - # Decode the hidden state of the last time step - out = self.fc(out[:, -1, :]) # out: tensor of shape (batch_size, output_size) - + _, state = self.lstm(x, (h0, c0)) return state[1] # Test code diff --git a/KernelBench/level3/37_LSTMCn_OLD.py b/KernelBench/level3/37_LSTMCn_OLD.py new file mode 100644 index 00000000..98e01a34 --- /dev/null +++ b/KernelBench/level3/37_LSTMCn_OLD.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0): + """ + Initialize the LSTM model. + + :param input_size: The number of expected features in the input `x` + :param hidden_size: The number of features in the hidden state `h` + :param num_layers: Number of recurrent layers + :param output_size: The number of output features + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to `dropout` + """ + super(Model, self).__init__() + # Initialize hidden state with random values + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False) + self.fc = nn.Linear(hidden_size, output_size) + + def forward(self, x, h0, c0): + """ + Forward pass through the LSTM model. 
+ + :param x: The input tensor, shape (batch_size, sequence_length, input_size) + :return: The output tensor, shape (batch_size, sequence_length, output_size) + """ + + # Forward propagate LSTM + out, state = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size) + + # Decode the hidden state of the last time step + out = self.fc(out[:, -1, :]) # out: tensor of shape (batch_size, output_size) + + return state[1] + +# Test code +batch_size = 10 +sequence_length = 512 +input_size = 128 +hidden_size = 256 +num_layers = 6 +output_size = 10 +dropout = 0.0 + +def get_inputs(): + return [torch.rand(batch_size, sequence_length, input_size),torch.rand((num_layers, batch_size, hidden_size)),torch.rand((num_layers, batch_size, hidden_size))] + +def get_init_inputs(): + return [input_size, hidden_size, num_layers, output_size, dropout] + diff --git a/KernelBench/level3/49_Mamba2ReturnFinalState.py b/KernelBench/level3/49_Mamba2ReturnFinalState.py index e0d70bf0..e2548357 100644 --- a/KernelBench/level3/49_Mamba2ReturnFinalState.py +++ b/KernelBench/level3/49_Mamba2ReturnFinalState.py @@ -57,12 +57,7 @@ def forward(self, X, initial_states=None): A_blocks = rearrange(A_blocks, "b c l h -> b h c l") A_cumsum = torch.cumsum(A_blocks, dim=-1) - # 1. Compute diagonal block outputs - L = torch.exp(self.segsum(A_blocks)) - Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", - C_blocks, B_blocks, L, X_blocks) - - # 2. Compute intra-chunk states + # Compute intra-chunk states decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B_blocks, decay_states, X_blocks) diff --git a/KernelBench/level3/49_Mamba2ReturnFinalState_OLD.py b/KernelBench/level3/49_Mamba2ReturnFinalState_OLD.py new file mode 100644 index 00000000..12eb2179 --- /dev/null +++ b/KernelBench/level3/49_Mamba2ReturnFinalState_OLD.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +class Model(nn.Module): + def __init__(self, batch_size, seq_length, n_heads, d_head, d_state, block_len=64): + """ + Mamba Structured State Space model implementation for benchmarking. + + :param batch_size: Size of the batch + :param seq_length: Length of the input sequence + :param n_heads: Number of attention heads + :param d_head: Dimension of each head + :param d_state: Dimension of the state space + :param block_len: Length of each block for chunked computation + """ + super(Model, self).__init__() + + assert seq_length % block_len == 0, "Sequence length must be divisible by block length" + + self.batch_size = batch_size + self.seq_length = seq_length + self.n_heads = n_heads + self.d_head = d_head + self.d_state = d_state + self.block_len = block_len + + # Initialize parameters + self.A = nn.Parameter(torch.randn(batch_size, seq_length, n_heads)) + self.B = nn.Parameter(torch.randn(batch_size, seq_length, n_heads, d_state)) + self.C = nn.Parameter(torch.randn(batch_size, seq_length, n_heads, d_state)) + + def segsum(self, x): + """Naive segment sum calculation.""" + T = x.size(-1) + x_cumsum = torch.cumsum(x, dim=-1) + x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + def forward(self, X, initial_states=None): + """ + Forward pass implementing the SSD operation. 
+ + :param X: Input tensor of shape (batch, length, n_heads, d_head) + :param initial_states: Optional initial states + :return: Output tensor Y and final state + """ + # Rearrange into blocks/chunks + X_blocks, A_blocks, B_blocks, C_blocks = [ + rearrange(x, "b (c l) ... -> b c l ...", l=self.block_len) + for x in (X, self.A, self.B, self.C) + ] + + A_blocks = rearrange(A_blocks, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A_blocks, dim=-1) + + # 1. Compute diagonal block outputs + L = torch.exp(self.segsum(A_blocks)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", + C_blocks, B_blocks, L, X_blocks) + + # 2. Compute intra-chunk states + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", + B_blocks, decay_states, X_blocks) + + # 3. Compute inter-chunk recurrence + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + + decay_chunk = torch.exp(self.segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + return new_states[:, -1] + +# Test parameters +batch_size = 2048 +seq_length = 128 +n_heads = 8 +d_head = 64 +d_state = 16 +block_len = 64 + +def get_inputs(): + return [torch.rand(batch_size, seq_length, n_heads, d_head)] + +def get_init_inputs(): + return [batch_size, seq_length, n_heads, d_head, d_state, block_len] + diff --git a/temporary_tests/test_constant_fill_fixes.py b/temporary_tests/test_constant_fill_fixes.py new file mode 100644 index 00000000..15fd04a3 --- /dev/null +++ b/temporary_tests/test_constant_fill_fixes.py @@ -0,0 +1,94 @@ +""" +Test that constant-fill problems produce constant outputs (OLD) +and varying outputs after fix. + +Run with: pytest tests/test_constant_fill_fixes.py -v +Or directly: python tests/test_constant_fill_fixes.py +""" +import os +import sys +import importlib.util +import torch + +KERNEL_BENCH_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../KernelBench")) + + +def load_model_from_file(filepath): + """Load Model class and input functions from a KernelBench file.""" + spec = importlib.util.spec_from_file_location("module", filepath) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.Model, module.get_inputs, module.get_init_inputs + + +def check_constant_vs_varying(old_path, new_path, atol=1e-5): + """ + Verify OLD model produces constant output, NEW model produces varying output. + Returns (old_is_constant, new_varies) booleans. 
+ """ + OldModel, get_inputs, get_init_inputs = load_model_from_file(old_path) + NewModel, _, _ = load_model_from_file(new_path) + + torch.manual_seed(42) + init_inputs = get_init_inputs() + + old_model = OldModel(*init_inputs).eval() + new_model = NewModel(*init_inputs).eval() + + with torch.no_grad(): + torch.manual_seed(1) + x1 = get_inputs()[0] + torch.manual_seed(2) + x2 = get_inputs()[0] + + old_out1, old_out2 = old_model(x1), old_model(x2) + new_out1, new_out2 = new_model(x1), new_model(x2) + + # OLD should be constant (approximately zero or same for different inputs) + old_is_constant = torch.allclose(old_out1, old_out2, atol=atol) + + # NEW should vary with input + new_varies = not torch.allclose(new_out1, new_out2, atol=atol) + + return old_is_constant, new_varies + + +def test_80_gemm_max_subtract_gelu(): + """mean(dim=1) on (B,1) → value itself → x - mean = 0""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level2/80_Gemm_Max_Subtract_GELU_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level2/80_Gemm_Max_Subtract_GELU.py") + + old_const, new_varies = check_constant_vs_varying(old_path, new_path) + assert old_const, "OLD should produce constant output" + assert new_varies, "NEW should produce varying output" + print("✓ test_80_gemm_max_subtract_gelu passed") + + +def test_83_conv3d_groupnorm_min_clamp_dropout(): + """min(x,0) + clamp(min=0) → all zeros""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout.py") + + old_const, new_varies = check_constant_vs_varying(old_path, new_path) + assert old_const, "OLD should produce constant output" + assert new_varies, "NEW should produce varying output" + print("✓ test_83_conv3d_groupnorm_min_clamp_dropout passed") + + +def test_23_conv3d_groupnorm_mean(): + """GroupNorm zero-mean → global mean ≈ 0""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level2/23_Conv3d_GroupNorm_Mean_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level2/23_Conv3d_GroupNorm_Mean.py") + + old_const, new_varies = check_constant_vs_varying(old_path, new_path) + assert old_const, "OLD should produce constant output" + assert new_varies, "NEW should produce varying output" + print("✓ test_23_conv3d_groupnorm_mean passed") + + +if __name__ == "__main__": + test_80_gemm_max_subtract_gelu() + test_83_conv3d_groupnorm_min_clamp_dropout() + test_23_conv3d_groupnorm_mean() + print("\nAll tests passed!") + diff --git a/temporary_tests/test_redundant_op_fixes.py b/temporary_tests/test_redundant_op_fixes.py new file mode 100644 index 00000000..2147b007 --- /dev/null +++ b/temporary_tests/test_redundant_op_fixes.py @@ -0,0 +1,127 @@ +""" +Test that removing redundant operations produces equivalent outputs. + +Run with: pytest tests/test_redundant_op_fixes.py -v +Or directly: python tests/test_redundant_op_fixes.py +""" +import os +import importlib.util +import torch + +KERNEL_BENCH_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../KernelBench")) + + +def load_model_from_file(filepath): + """Load Model class and input functions from a KernelBench file.""" + spec = importlib.util.spec_from_file_location("module", filepath) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.Model, module.get_inputs, module.get_init_inputs + + +def check_equivalence(old_path, new_path, atol=1e-5): + """ + Verify OLD and NEW models produce equivalent outputs. 
+ Returns True if outputs match within tolerance. + """ + OldModel, get_inputs, get_init_inputs = load_model_from_file(old_path) + NewModel, _, _ = load_model_from_file(new_path) + + torch.manual_seed(42) + init_inputs = get_init_inputs() + + old_model = OldModel(*init_inputs).eval() + new_model = NewModel(*init_inputs).eval() + + # Copy weights from old to new (they may have different params due to removed layers) + old_state = old_model.state_dict() + new_state = new_model.state_dict() + # Only copy matching keys + for key in new_state: + if key in old_state: + new_state[key] = old_state[key] + new_model.load_state_dict(new_state) + + with torch.no_grad(): + torch.manual_seed(123) + inputs = get_inputs() + + old_out = old_model(*inputs) + new_out = new_model(*inputs) + + return torch.allclose(old_out, new_out, atol=atol) + + +def test_44_double_global_avg_pool(): + """Second global avg pool is no-op (tensor already 1x1 after first)""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py") + + assert check_equivalence(old_path, new_path), "Outputs should be equivalent" + print("✓ test_44_double_global_avg_pool passed") + + +def test_95_hardtanh_after_tanh_gelu(): + """Hardtanh redundant: tanh→GELU output is already in [-1,1]""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py") + + assert check_equivalence(old_path, new_path), "Outputs should be equivalent" + print("✓ test_95_hardtanh_after_tanh_gelu passed") + + +def test_81_clamp_after_tanh(): + """Clamp [-1,1] after tanh is redundant (tanh already outputs [-1,1])""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py") + + assert check_equivalence(old_path, new_path), "Outputs should be equivalent" + print("✓ test_81_clamp_after_tanh passed") + + +def test_7_leakyrelu_after_relu(): + """LeakyReLU after ReLU is identity (all values already ≥0)""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py") + + assert check_equivalence(old_path, new_path), "Outputs should be equivalent" + print("✓ test_7_leakyrelu_after_relu passed") + + +def test_36_lstm_hn_dead_fc(): + """fc layer is dead code (computes but returns h_n instead)""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level3/36_LSTMHn_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level3/36_LSTMHn.py") + + assert check_equivalence(old_path, new_path), "Outputs should be equivalent" + print("✓ test_36_lstm_hn_dead_fc passed") + + +def test_37_lstm_cn_dead_fc(): + """fc layer is dead code (computes but returns c_n instead)""" + old_path = os.path.join(KERNEL_BENCH_PATH, "level3/37_LSTMCn_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level3/37_LSTMCn.py") + + assert check_equivalence(old_path, new_path), "Outputs should be equivalent" + print("✓ test_37_lstm_cn_dead_fc passed") + + +def test_49_mamba2_dead_y_diag(): + """Y_diag and L are computed but never used in return value""" + old_path = os.path.join(KERNEL_BENCH_PATH, 
"level3/49_Mamba2ReturnFinalState_OLD.py") + new_path = os.path.join(KERNEL_BENCH_PATH, "level3/49_Mamba2ReturnFinalState.py") + + assert check_equivalence(old_path, new_path), "Outputs should be equivalent" + print("✓ test_49_mamba2_dead_y_diag passed") + + +if __name__ == "__main__": + test_44_double_global_avg_pool() + test_95_hardtanh_after_tanh_gelu() + test_81_clamp_after_tanh() + test_7_leakyrelu_after_relu() + test_36_lstm_hn_dead_fc() + test_37_lstm_cn_dead_fc() + test_49_mamba2_dead_y_diag() + print("\nAll equivalence tests passed!") +