diff --git a/README.md b/README.md index 0baeb33..874c419 100644 --- a/README.md +++ b/README.md @@ -4,82 +4,57 @@ A lightweight neural network library written in C11 for embedded systems. ## Overview -cTensor is a compact tensor computation library designed for small client-side devices, such as mobile phones, microcontrollers. The library implements automatic differentiation and dynamic compute graph functionality, allowing for efficient training and deployment of neural networks on resource-constrained devices. +cTensor is a compact tensor computation library designed for small client-side devices, such as mobile phones and microcontrollers. The library implements automatic differentiation and dynamic compute graph functionality, allowing for efficient training and deployment of neural networks on resource-constrained devices. -## Current Status - -This project is under active development. The prototype demonstrates basic tensor operations and neural network functionality using the Iris dataset as an example. Many core mathematical operators and features are still being implemented. +This library was developed as part of GSoC 2025 and has been successfully validated on ARM Cortex-M3 microcontrollers, achieving 90% classification accuracy on the Iris dataset in a bare-metal environment. ## Features -### Currently Implemented - +### Core Infrastructure - **Lightweight C11 Implementation:** Minimal dependencies for wide compatibility -- **Automatic Differentiation Framework:** Basic gradient computation infrastructure -- **Dynamic Compute Graph:** Groundwork for efficient computation flow -- **Basic Tensor Operations:** - - Basic arithmetic: add, subtract, multiply, divide, power - - Element-wise operations: square, reciprocal - - Matrix multiplication - - Tensor transpose -- **Reduction Operations:** - - Sum (all elements or along dimension) - - Mean (all elements or along dimension) - - Max (all elements or along dimension with indices) - - Min (all elements or along dimension with indices) - - Argmax function -- **Neural Network Components:** - - Linear layer - - Activation functions: ReLU, Sigmoid, Softmax - - Cross-entropy loss - - Softmax cross-entropy (combined operation) - - Glorot weight initialization -- **SGD Optimizer:** Stochastic gradient descent implementation -- **Memory Management:** Pool-based memory allocation system -- **Tensor Utilities:** - - Element access and manipulation - - Tensor detachment - - Tensor unsqueeze operation - - Broadcasting support for element-wise operations - - Dataset normalization and shuffling utilities - -### Development Roadmap - -The following features are planned for implementation: - -#### Math Operators -- **Unary Operations:** - - Negative (Tensor_neg) - - Absolute value (Tensor_abs) -- **Mathematical Functions:** - - Logarithm (nn_log) - - Exponential (nn_exp) - - Trigonometric functions (nn_sin, nn_cos, nn_tan) - -#### Broadcasting System Enhancements -- Broadcasting for Matmul - -#### Activation Functions -- ELU (Exponential Linear Unit) -- SELU (Scaled Exponential Linear Unit) -- Additional activation functions - -#### Loss Functions -- Mean Squared Error (MSE) -- Mean Absolute Error (MAE) -- Huber Loss -- Enhanced multi-class classification losses - -#### Advanced Optimizers -- Adam optimizer -- RMSProp optimizer -- AdaGrad optimizer -- Weight decay implementation -- Gradient clipping - -#### Performance Enhancements -- Profiling and benchmarking infrastructure -- Loop unrolling and SIMD optimizations where applicable +- **Automatic Differentiation Framework:** Complete gradient computation with backward pass +- **Dynamic Compute Graph:** Efficient computation flow with gradient tracking +- **Pool-based Memory Management:** Efficient memory allocation system for embedded devices + +### Tensor Operations +- **Basic Arithmetic:** add, subtract, multiply, divide, power (both tensor-tensor and tensor-scalar) +- **Unary Operations:** negation, absolute value, square, reciprocal +- **Matrix Operations:** matrix multiplication, transpose +- **Mathematical Functions:** logarithm, exponential, sine, cosine, tangent +- **Shape Operations:** unsqueeze, detach +- **Broadcasting:** Element-wise broadcasting for operations on tensors with different shapes + +### Reduction Operations +- **Sum:** All elements or along specific dimension +- **Mean:** All elements or along specific dimension +- **Max/Min:** All elements or along dimension with indices +- **Argmax:** Find indices of maximum values + +### Neural Network Components +- **Layers:** Linear (fully connected) layer +- **Activation Functions:** ReLU, Sigmoid, Tanh, ELU, SELU, Softmax +- **Loss Functions:** Cross-entropy, Softmax Cross-entropy, MSE, MAE, Huber Loss +- **Weight Initialization:** Glorot/Xavier initialization + +### Optimizers +- **SGD:** Stochastic Gradient Descent with momentum +- **Adam:** Adaptive moment estimation +- **RMSProp:** Root Mean Square Propagation +- **AdaGrad:** Adaptive Gradient Algorithm +- **Features:** Weight decay support for all optimizers + +### Training Utilities +- **Gradient Clipping:** By norm, value, range, positive/negative values +- **Evaluation Mode:** Disable gradient computation for inference +- **Dataset Utilities:** Normalization, shuffling + +## Validation + +cTensor has been successfully deployed and tested on: +- **ARM Cortex-M3 (STM32F103ZE)** using Keil MDK simulation +- **Task:** Neural network classification on Iris dataset +- **Result:** 90% accuracy matching desktop performance +- **Complete validation project:** [cTensor_Cortex_SIM](https://github.com/PrimedErwin/cTensor_Cortex_SIM) ## Getting Started @@ -121,8 +96,6 @@ and run `main.exe` from root directory cTensor uses a custom test framework. To run the tests: -For a more detailed guide, refer to [Testing Documentation](tests/README.md). - ```bash # Build the test executable with CMake mkdir -p build && cd build @@ -133,42 +106,105 @@ cmake --build . ./cten_exe ``` +For detailed testing information, refer to [Testing Documentation](tests/README.md). + ## Usage Example -The repository includes a simple example in `src2/main.c` that demonstrates how to train a neural network on the Iris dataset: +Here's a complete example of training a neural network to predict sine wave values with noise: ```c #include "cten.h" #include +#include +#include + +// Define memory pools +enum MemoryPoolIds { + PoolId_Default = 0, + PoolId_Model = 1, + PoolId_Optimizer = 2, +}; + +// Define the model structure +typedef struct { + Tensor w1, b1; + Tensor w2, b2; + Tensor w3, b3; +} Model; + +// Forward pass for the model +Tensor Model_forward(Model* model, Tensor x) { + x = nn_linear(x, model->w1, model->b1); + x = nn_elu(x, 1.0f); + x = nn_linear(x, model->w2, model->b2); + x = nn_elu(x, 1.0f); + x = nn_linear(x, model->w3, model->b3); + return x; +} int main() { - // Initialize cTensor library cten_initilize(); - - // Load the Iris dataset - const float (*X)[4]; - const int* y; - int num_samples = load_iris_dataset(&X, &y); - - // Create a simple neural network - TensorShape input_shape = {1, 4, 0, 0}; // 4 features - TensorShape hidden_shape = {4, 10, 0, 0}; // 10 hidden units - TensorShape output_shape = {10, 3, 0, 0}; // 3 classes (iris species) - - // Initialize network parameters with Glorot initialization - Tensor W1 = Glorot_init(hidden_shape, true); - Tensor b1 = Tensor_zeros((TensorShape){1, 10, 0, 0}, true); - Tensor W2 = Glorot_init(output_shape, true); - Tensor b2 = Tensor_zeros((TensorShape){1, 3, 0, 0}, true); - - // Setup optimizer - Tensor params[4] = {W1, b1, W2, b2}; - optim_sgd* optimizer = optim_sgd_new(4, params); - optim_sgd_config(optimizer, 0.01f, 0.9f); - + + // Generate sine wave data + int n_samples = 2048; + float* x_data = malloc(n_samples * sizeof(float)); + float* y_data = malloc(n_samples * sizeof(float)); + // ... (data generation logic) ... + + // Create model and allocate in its own memory pool + Model model; + cten_begin_malloc(PoolId_Model); + model.w1 = Glorot_init((TensorShape){1, 64}, true); + model.b1 = Tensor_zeros((TensorShape){1, 64}, true); + model.w2 = Glorot_init((TensorShape){64, 32}, true); + model.b2 = Tensor_zeros((TensorShape){1, 32}, true); + model.w3 = Glorot_init((TensorShape){32, 1}, true); + model.b3 = Tensor_zeros((TensorShape){1, 1}, true); + cten_end_malloc(); + + // Create optimizer + float learning_rate = 0.01f; + cten_begin_malloc(PoolId_Optimizer); + optim_adam* optimizer = optim_adam_new(6, (Tensor*)&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f); + cten_end_malloc(); + // Training loop - // ... - + int batch_size = 64; + for (int epoch = 0; epoch < 200; epoch++) { + // ... (training logic with batching, loss calculation, backpropagation) ... + + cten_begin_malloc(PoolId_Default); // for temporary tensors in each step + + // ... create input and y_true tensors ... + + optim_adam_zerograd(optimizer); + Tensor y_pred = Model_forward(&model, input); + + // Combined Loss + Tensor huber = nn_huber_loss(y_true, y_pred, 1.0f); + Tensor mae = nn_mae_loss(y_true, y_pred); + Tensor loss = Tensor_add(huber, Tensor_mulf(mae, 0.3f)); + + Tensor_backward(loss, Tensor_ones((TensorShape){1}, false)); + + // Gradient Clipping + cten_clip_grad_norm((Tensor*)&model, 6, 5.0f); + + optim_adam_step(optimizer); + + cten_end_malloc(); + cten_free(PoolId_Default); // free temporary tensors + } + + // Evaluate model + cten_begin_eval(); + // ... (evaluation logic) ... + cten_end_eval(); + + // Free memory pools + cten_free(PoolId_Optimizer); + cten_free(PoolId_Model); + cten_finalize(); return 0; } @@ -218,10 +254,25 @@ Tensor Tensor_powf(Tensor self, float other); Tensor Tensor_matmul(Tensor self, Tensor other); // Unary operations +Tensor Tensor_neg(Tensor self); +Tensor Tensor_abs(Tensor self); Tensor Tensor_square(Tensor self); Tensor Tensor_reciprocal(Tensor self); ``` +### Mathematical Functions + +```c +// Logarithmic and exponential +Tensor nn_log(Tensor self); +Tensor nn_exp(Tensor self); + +// Trigonometric functions +Tensor nn_sin(Tensor self); +Tensor nn_cos(Tensor self); +Tensor nn_tan(Tensor self); +``` + ### Reduction Operations ```c @@ -252,25 +303,58 @@ Tensor nn_linear(Tensor input, Tensor weight, Tensor bias); Tensor nn_relu(Tensor input); Tensor nn_sigmoid(Tensor input); Tensor nn_tanh(Tensor input); -Tensor nn_softmax(Tensor input); +Tensor nn_elu(Tensor self, float alpha); +Tensor nn_selu(Tensor self); +Tensor nn_softmax(Tensor input, int dim); // Loss functions Tensor nn_crossentropy(Tensor y_true, Tensor y_pred); Tensor nn_softmax_crossentropy(Tensor y_true, Tensor logits); +Tensor nn_mse_loss(Tensor y_true, Tensor y_pred); +Tensor nn_mae_loss(Tensor y_true, Tensor y_pred); +Tensor nn_huber_loss(Tensor y_true, Tensor y_pred, float delta); // Weight initialization Tensor Glorot_init(TensorShape shape, bool requires_grad); ``` -### Optimizer +### Optimizers ```c // SGD Optimizer -optim_sgd* optim_sgd_new(int n_params, Tensor* params); +optim_sgd* optim_sgd_new(int n_params, Tensor* params, float weight_decay); void optim_sgd_config(optim_sgd* self, float lr, float momentum); void optim_sgd_zerograd(optim_sgd* self); void optim_sgd_step(optim_sgd* self); -void optim_sgd_delete(optim_sgd* self); + +// Adam Optimizer +optim_adam* optim_adam_new(int n_params, Tensor* params, float lr, + float β1, float β2, float ε, float weight_decay); +void optim_adam_zerograd(optim_adam* self); +void optim_adam_step(optim_adam* self); + +// RMSProp Optimizer +optim_rmsprop* optim_rmsprop_new(int n_params, Tensor* params, float lr, + float β, float ε, float weight_decay); +void optim_rmsprop_zerograd(optim_rmsprop* self); +void optim_rmsprop_step(optim_rmsprop* self); + +// AdaGrad Optimizer +optim_adagrad* optim_adagrad_new(int n_params, Tensor* params, float lr, + float ε, float weight_decay); +void optim_adagrad_zerograd(optim_adagrad* self); +void optim_adagrad_step(optim_adagrad* self); +``` + +### Gradient Clipping + +```c +// Gradient clipping functions +void cten_clip_grad_norm(Tensor* params, int n_params, float max_norm); +void cten_clip_grad_value(Tensor* params, int n_params, float max_value); +void cten_clip_grad_value_range(Tensor* params, int n_params, float min_value, float max_value); +void cten_clip_grad_positive(Tensor* params, int n_params, float max_value); +void cten_clip_grad_negative(Tensor* params, int n_params, float min_value); ``` ### Utility Functions @@ -291,6 +375,11 @@ void Tensor_shuffle_dataset(const float (*X)[4], const int *y, float (*X_shuffle void cten_begin_eval(); bool cten_is_eval(); void cten_end_eval(); + +// Broadcasting +bool cten_elemwise_broadcast(Tensor* a, Tensor* b); +Tensor reduce_gradient_for_broadcasting(Tensor grad, TensorShape original_shape, + TensorShape broadcasted_shape); ``` ## Memory Management @@ -298,6 +387,8 @@ void cten_end_eval(); cTensor uses a pool-based memory management system to efficiently handle tensor allocations: ```c +void cten_initilize(); +void cten_finalize(); void cten_begin_malloc(PoolId id); void cten_end_malloc(); void cten_free(PoolId id); @@ -308,29 +399,52 @@ void cten_free(PoolId id); ``` cTensor/ ├── include/ # Header files defining the API -├── src/ # Core implementation files -│ ├── basic.c # Basic tensor operations -│ ├── nn.c # Neural network primitives -│ ├── operator.c # Mathematical operators +│ └── cten.h # Complete API header +├── src/ # Core implementation files +│ ├── basic.c # Basic tensor operations +│ ├── nn.c # Neural network primitives +│ ├── operator.c # Mathematical operators +│ ├── context.c # Memory management +│ ├── utils.c # Utility functions +│ ├── optimizer/ # Optimizer implementations │ └── ... -├── src2/ # Example applications -│ └── main.c # Iris dataset example -└── tests/ # Test suite +├── src2/ # Example applications +│ └── main.c # Sine regression example +└── tests/ # Test suite ``` -## API Reference -For a detailed API reference, refer to [API Documentation](API.md). +## Implemented Features Summary + +| Category | Components | Status | +|----------|------------|--------| +| **Core Structs** | `Tensor`, `GradNode`, `TensorMaxMinResult` | ✅ | +| **Autograd** | `Tensor_backward`, `requires_grad`, `detach` | ✅ | +| **Tensor Creation** | `Tensor_new`, `zeros`, `ones`, `Glorot_init` | ✅ | +| **Binary Operations** | `add`, `sub`, `mul`, `div`, `pow`, `matmul` | ✅ | +| **Unary Operations** | `neg`, `abs`, `square`, `reciprocal` | ✅ | +| **Math Functions** | `log`, `exp`, `sin`, `cos`, `tan` | ✅ | +| **Aggregations** | `sum`, `mean`, `max`, `min` (with indices) | ✅ | +| **Search/Sort** | `argmax` | ✅ | +| **Shape Operations** | `transpose`, `unsqueeze` | ✅ | +| **NN Layers** | `nn_linear` | ✅ | +| **Activations** | `ReLU`, `Sigmoid`, `Tanh`, `ELU`, `SELU`, `Softmax` | ✅ | +| **Loss Functions** | `CrossEntropy`, `MSE`, `MAE`, `Huber` | ✅ | +| **Optimizers** | `SGD`, `Adam`, `RMSProp`, `AdaGrad` | ✅ | +| **Training Utils** | `Gradient Clipping`, `Evaluation Mode`, `Weight Decay` | ✅ | ## Contributing -Contributions to cTensor are welcome! The project needs implementation of various components as outlined in the Development Roadmap section. Key areas for contribution include: +Contributions to cTensor are welcome! Key areas for contribution include: + +1. **Performance Optimization:** Benchmarking and SIMD implementations +2. **Advanced Layers:** Convolutional and recurrent neural network layers +3. **Documentation:** Examples, tutorials, and API documentation improvements +4. **Testing:** Expanding test coverage and validation on different platforms + +## GSoC 2025 Acknowledgments -1. **Activation Functions:** Implementing additional activation functions (ELU, SELU) with gradient support -2. **Loss Functions:** Adding more loss functions (MSE, MAE, Huber) with gradient support -3. **Advanced Optimizers:** Creating additional optimizers beyond SGD (Adam, RMSProp, AdaGrad) -4. **Performance Optimization:** Enhancing computational efficiency through benchmarking and optimizations -5. **Documentation:** Improving examples, tutorials, and API documentation +This project was developed during Google Summer of Code 2025 by [Advait Gaur](https://github.com/Advaitgaur004) under the mentorship of [PrimedErwin](https://github.com/PrimedErwin), [Anurag Bhat](https://github.com/faze-geek), and [blueloveTH](https://github.com/blueloveTH). The project successfully transformed cTensor from a basic prototype into a functional deep learning framework suitable for embedded applications. ## License -This project is licensed under the MIT License - see the LICENSE file for details. \ No newline at end of file +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. \ No newline at end of file diff --git a/src/operator.c b/src/operator.c index d4f9e63..5004d21 100644 --- a/src/operator.c +++ b/src/operator.c @@ -7,6 +7,7 @@ #include #include #include +#include #ifdef Tensor_mean #undef Tensor_mean @@ -226,40 +227,91 @@ Tensor Tensor_sum(Tensor self, ...) { static Tensor GradFn_matmul(Tensor self, int i) { return Tensor_transpose(Tensor_detach(self.node->inputs[1 - i])); - ; } Tensor Tensor_matmul(Tensor self, Tensor other) { int self_dim = TensorShape_dim(self.shape); int other_dim = TensorShape_dim(other.shape); + assert(self_dim >= 2); assert(other_dim >= 2); + int batch_self = (self_dim >= 3) ? self.shape[0] : 1; + int batch_other = (other_dim >= 3) ? other.shape[0] : 1; + int batch = (batch_self > batch_other) ? batch_self : batch_other; + + int group_self = (self_dim == 4) ? self.shape[1] : 1; + int group_other = (other_dim == 4) ? other.shape[1] : 1; + int group = (group_self > group_other) ? group_self : group_other; + int m = self.shape[self_dim - 2]; int n = self.shape[self_dim - 1]; int p = other.shape[other_dim - 1]; assert(n == other.shape[other_dim - 2]); - TensorShape res_shape; - memcpy(res_shape, self.shape, sizeof(TensorShape)); - res_shape[self_dim - 1] = p; - Tensor res = Tensor_new( - res_shape, - self.node != NULL || - other.node != NULL); // here weight/bias have .node != NULL, so res have GradNode - - for(int i = 0; i < m; i++) { - for(int j = 0; j < p; j++) { - float sum = 0; - for(int k = 0; k < n; k++) { - sum += self.data->flex[i * n + k] * other.data->flex[k * p + j]; + bool has4D = (self_dim == 4 || other_dim == 4); + + TensorShape res_shape = {0, 0, 0, 0}; + if (self_dim <= 2 && other_dim <= 2) { + res_shape[0] = m; + res_shape[1] = p; + } else { + res_shape[0] = batch; + if (has4D) { + res_shape[1] = group; + res_shape[2] = m; + res_shape[3] = p; + } else { + res_shape[1] = m; + res_shape[2] = p; + res_shape[3] = 0; + } + } + + Tensor res = Tensor_new(res_shape, self.node != NULL || other.node != NULL); + + for (int b = 0; b < batch; b++) { + int self_b = (batch_self <= b) ? batch_self - 1 : b; + int other_b = (batch_other <= b) ? batch_other - 1 : b; + + for (int g = 0; g < group; g++) { + int self_g = (group_self <= g) ? group_self - 1 : g; + int other_g = (group_other <= g) ? group_other - 1 : g; + + int offset_self = 0; + if (self_dim == 4) { + offset_self = self_b * self.shape[1] * m * n + self_g * m * n; + } else if (self_dim == 3) { + offset_self = self_b * m * n; + } + + int offset_other = 0; + if (other_dim == 4) { + offset_other = other_b * other.shape[1] * n * p + other_g * n * p; + } else if (other_dim == 3) { + offset_other = other_b * n * p; + } + + int offset_res = ((batch > 1) ? b * group + g : g) * m * p; + + float* self_ptr = self.data->flex + offset_self; + float* other_ptr = other.data->flex + offset_other; + float* res_ptr = res.data->flex + offset_res; + + for (int i = 0; i < m; i++) { + for (int j = 0; j < p; j++) { + float sum = 0; + for (int k = 0; k < n; k++) { + sum += self_ptr[i * n + k] * other_ptr[k * p + j]; + } + res_ptr[i * p + j] = sum; + } } - res.data->flex[i * p + j] = sum; } } - if(res.node != NULL) { + if (res.node != NULL) { res.node->grad_fn = GradFn_matmul; res.node->inputs[0] = self; res.node->inputs[1] = other; diff --git a/src2/main.c b/src2/main.c index 870e0db..5e30a71 100644 --- a/src2/main.c +++ b/src2/main.c @@ -4,108 +4,121 @@ #include #include +#define PI 3.14159265358979323846 + enum MemoryPoolIds { PoolId_Default = 0, PoolId_Model = 1, PoolId_Optimizer = 2, }; -typedef struct Model { - Tensor weight_1, weight_2; - Tensor bias_1, bias_2; +typedef struct { + Tensor w1, b1; + Tensor w2, b2; + Tensor w3, b3; } Model; Tensor Model_forward(Model* model, Tensor x) { - x = nn_linear(x, model->weight_1, model->bias_1); - x = nn_relu(x); - x = nn_linear(x, model->weight_2, model->bias_2); + x = nn_linear(x, model->w1, model->b1); + x = nn_elu(x, 1.0f); + x = nn_linear(x, model->w2, model->b2); + x = nn_elu(x, 1.0f); + x = nn_linear(x, model->w3, model->b3); return x; } +float rand_float() { + return (float)rand() / (RAND_MAX / 2.0f) - 1.0f; +} + +void generate_sine_data(float* x_data, float* y_data, int n_samples, float noise_level) { + for (int i = 0; i < n_samples; i++) { + x_data[i] = rand_float() * 4.0f * PI; + + // Generate Gaussian noise using the Box-Muller transform + float u1 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 2.0f); + float u2 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 2.0f); + float z = sqrtf(-2.0f * logf(u1)) * cosf(2.0f * PI * u2); + + y_data[i] = sin(x_data[i]) + z * noise_level; + } +} int main() { cten_initilize(); - - // load iris dataset - const float(*X)[4]; - const int* y; - int n_samples = load_iris_dataset(&X, &y); - int n_features = 4; - int n_classes = 3; - - // Shuffle the dataset - float (*X_shuffled)[4] = malloc(n_samples * sizeof(*X_shuffled)); - int* y_shuffled = malloc(n_samples * sizeof(int)); - Tensor_shuffle_dataset(X, y, X_shuffled, y_shuffled, n_samples, n_features); - X = (const float(*)[4])X_shuffled; - y = (const int*)y_shuffled; - - int n_train_samples = n_samples * 0.8; - int n_test_samples = n_samples - n_train_samples; - - printf("n_samples: %d\n", n_samples); - printf("n_train_samples: %d\n", n_train_samples); - printf("n_test_samples: %d\n", n_test_samples); - - //normalize the dataset - float(*X_norm)[4] = malloc(n_samples * sizeof(*X_norm)); - Tensor_normalize_dataset(X, X_norm, n_samples, n_train_samples, n_features); - X = (const float(*)[4])X_norm; + + // Generating Sine Data + int n_samples = 2048; + int n_train_samples = n_samples * 0.8; + int n_test_samples = n_samples - n_train_samples; + float* x_data = malloc(n_samples * sizeof(float)); + float* y_data = malloc(n_samples * sizeof(float)); + generate_sine_data(x_data, y_data, n_samples, 0.05f); // create model Model model; cten_begin_malloc(PoolId_Model); - model.weight_1 = Glorot_init((TensorShape){n_features, 32}, true); - model.bias_1 = Tensor_zeros((TensorShape){1, 32}, true); - model.weight_2 = Glorot_init((TensorShape){32, n_classes}, true); - model.bias_2 = Tensor_zeros((TensorShape){1, n_classes}, true); + model.w1 = Glorot_init((TensorShape){1, 64}, true); + model.b1 = Tensor_zeros((TensorShape){1, 64}, true); + model.w2 = Glorot_init((TensorShape){64, 32}, true); + model.b2 = Tensor_zeros((TensorShape){1, 32}, true); + model.w3 = Glorot_init((TensorShape){32, 1}, true); + model.b3 = Tensor_zeros((TensorShape){1, 1}, true); cten_end_malloc(); // create optimizer + float learning_rate = 0.01f; cten_begin_malloc(PoolId_Optimizer); - optim_sgd* optimizer = optim_sgd_new(4, (Tensor*)&model); - optim_sgd_config(optimizer, 0.01f, 0.0f); + optim_adam* optimizer = optim_adam_new(6, (Tensor*)&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f); cten_end_malloc(); // train model - int batch_size = 8; - for(int epoch = 0; epoch < 3; epoch++) { - printf("==> epoch: %d\n", epoch); - float epoch_loss = 0.0f; + int batch_size = 64; + for (int epoch = 0; epoch < 200; epoch++) { + // Manual Learning Rate Scheduler + if (epoch > 0 && epoch % 100 == 0) { + learning_rate *= 0.7f; + printf("Epoch %d: Learning rate decreased to %f\n", epoch, learning_rate); + } + + float total_loss = 0.0f; int num_batches = 0; - for(int i = 0; i < n_train_samples; i += batch_size) { - int actual_batch_size = i + batch_size <= n_train_samples ? batch_size : n_train_samples - i; - printf(" batch: %d/%d samples\n", i, n_train_samples); - cten_begin_malloc(PoolId_Default); - Tensor input = Tensor_zeros((TensorShape){actual_batch_size, n_features}, false); - Tensor y_true = Tensor_zeros((TensorShape){actual_batch_size, n_classes}, false); - - for(int j = 0; j < actual_batch_size; j++) { - for(int k = 0; k < n_features; k++) { - input.data->flex[j * n_features + k] = X[i + j][k]; - } - // one-hot encoding - y_true.data->flex[j * n_classes + y[i + j]] = 1.0f; + for (int i = 0; i < n_train_samples; i += batch_size) { + int current_batch_size = (i + batch_size > n_train_samples) ? (n_train_samples - i) : batch_size; + cten_begin_malloc(PoolId_Default); + Tensor input = Tensor_zeros((TensorShape){current_batch_size, 1}, false); + Tensor y_true = Tensor_zeros((TensorShape){current_batch_size, 1}, false); + + for (int j = 0; j < current_batch_size; j++) { + input.data->flex[j] = x_data[i + j]; + y_true.data->flex[j] = y_data[i + j]; } - // zero the gradients - optim_sgd_zerograd(optimizer); - // forward pass - Tensor logit = Model_forward(&model, input); - Tensor loss = nn_softmax_crossentropy(y_true, logit); - epoch_loss += loss.data->flex[0]; - num_batches++; + + optim_adam_zerograd(optimizer); + Tensor y_pred = Model_forward(&model, input); - Tensor grad = Tensor_ones((TensorShape){1}, false); - Tensor_backward(loss, grad); + // Combined Loss: Huber + 30% MAE + Tensor huber = nn_huber_loss(y_true, y_pred, 1.0f); + Tensor mae = nn_mae_loss(y_true, y_pred); + Tensor loss = Tensor_add(huber, Tensor_mulf(mae, 0.3f)); + total_loss += loss.data->flex[0]; + num_batches++; + + Tensor_backward(loss, Tensor_ones((TensorShape){1}, false)); - optim_sgd_step(optimizer); + // Gradient Clipping + cten_clip_grad_norm((Tensor*)&model, 6, 5.0f); + + optim_adam_step(optimizer); cten_end_malloc(); // free temporary tensors cten_free(PoolId_Default); } - printf("Epoch %d average loss: %.6f\n", epoch, epoch_loss / num_batches); + if (epoch % 50 == 0) { + printf("Epoch %d, Average Loss: %.6f\n", epoch, total_loss / num_batches); + } } // free optimizer @@ -113,31 +126,24 @@ int main() { // evaluate model cten_begin_eval(); - int correct = 0; - for(int i = n_train_samples; i < n_samples; i++) { + float total_test_mse = 0; + for (int i = n_train_samples; i < n_samples; i++) { cten_begin_malloc(PoolId_Default); - // prepare input and target - Tensor input = Tensor_zeros((TensorShape){1, n_features}, false); - Tensor y_true = Tensor_zeros((TensorShape){1, n_classes}, false); - for(int j = 0; j < n_features; j++) { - input.data->flex[j] = X[i][j]; + Tensor input = Tensor_zeros((TensorShape){1, 1}, false); + input.data->flex[0] = x_data[i]; + + Tensor y_pred = Model_forward(&model, input); + + float true_val = y_data[i]; + float pred_val = y_pred.data->flex[0]; + total_test_mse += (true_val - pred_val) * (true_val - pred_val); + + if (i%50 == 0) { + printf("Input: %.3f, True: %.3f, Predicted: %.3f\n", x_data[i], true_val, pred_val); } - y_true.data->flex[0 * n_classes + y[i]] = 1.0f; //Writing 0 here just to follow the architecture of the code - - // forward pass - Tensor logit = Model_forward(&model, input); - Tensor y_pred = nn_softmax(logit); - Tensor loss = nn_crossentropy(y_true, y_pred); - // calculate accuracy - int pred_classes[1]; - Tensor_argmax(y_pred, pred_classes); - if(pred_classes[0] == y[i]) correct++; - printf("Sample %d - True: %d, Pred: %d\n", i - n_train_samples, y[i], pred_classes[0]); - cten_end_malloc(); - // free temporary tensors cten_free(PoolId_Default); } - printf("accuracy: %.4f\n", (float)correct / n_test_samples); + printf("Final Test MSE: %.6f\n", total_test_mse / n_test_samples); cten_end_eval(); // free model diff --git a/tests/Operator/test_matmul.c b/tests/Operator/test_matmul.c index 4da838f..e0362fe 100644 --- a/tests/Operator/test_matmul.c +++ b/tests/Operator/test_matmul.c @@ -259,37 +259,160 @@ void test_matmul_operator() { } } - // TODO : Currently MatMul Doesnt support batch matrix multiplication - // - // // Test Case 8: Batch Matrix Multiplication - // { - // const char* tc_name = "matmul_batch_matrices"; - - // // Sub-test 1: Batch matrix multiplication (2x3x4 * 2x4x5) - // { - // TensorShape s1_shape = {2, 3, 4}; - // float d1[] = {0.9256f, 0.4219f, 0.3916f, 0.6438f, 0.8790f, 0.0543f, 0.0463f, 0.5632f, - // 0.7813f, 0.9841f, 0.7979f, 0.8884f, 0.5976f, 0.0739f, 0.8306f, 0.0435f, 0.2653f, - // 0.7424f, 0.9176f, 0.6326f, 0.2545f, 0.6777f, 0.9430f, 0.4921f}; TensorShape s2_shape - // = {2, 4, 5}; float d2[] = {0.1146f, 0.8401f, 0.0189f, 0.9417f, 0.9551f, 0.3073f, - // 0.5162f, 0.6919f, 0.3872f, 0.9831f, 0.8261f, 0.6104f, 0.1850f, 0.4844f, 0.0732f, - // 0.8003f, 0.3244f, 0.6337f, 0.4984f, 0.1917f, 0.5972f, 0.8280f, 0.1163f, 0.1445f, - // 0.5281f, 0.3753f, 0.7377f, 0.0097f, 0.0460f, 0.8825f, 0.1283f, 0.3434f, 0.9592f, - // 0.2614f, 0.8935f, 0.9233f, 0.1056f, 0.1819f, 0.9243f, 0.1263f}; TensorShape exp_shape - // = {2, 3, 5}; float exp_d[] = {1.0745f, 1.4433f, 0.7899f, 1.5456f, 1.4509f, 0.6064f, - // 0.9774f, 0.4197f, 1.1520f, 1.0043f, 1.7620f, 1.9396f, 1.4062f, 1.9461f, 1.9424f, - // 0.5314f, 0.8391f, 0.8748f, 0.3471f, 1.1284f, 1.1388f, 1.1492f, 1.0333f, - // 0.8970f, 1.6950f, 0.9817f, 1.0865f, 1.0302f, 0.7693f, 1.6373f}; - - // Tensor t1 = create_test_tensor(s1_shape, d1, false); - // Tensor t2 = create_test_tensor(s2_shape, d2, false); - // Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); - // Tensor actual_res = Tensor_matmul(t1, t2); - - // compare_tensors(&actual_res, &expected_res, op_name, tc_name, 1, - // TEST_FLOAT_TOLERANCE); - // } - // } + // Test Case 8: Batch Matrix Multiplication + + { + const char* tc_name = "matmul_batch_matrices"; + + // Sub-test 1: Batch matrix multiplication (2x3x4 * 2x4x5) - existing + { + TensorShape s1_shape = {2, 3, 4}; + float d1[] = {0.9256f, 0.4219f, 0.3916f, 0.6438f, 0.8790f, 0.0543f, 0.0463f, 0.5632f, + 0.7813f, 0.9841f, 0.7979f, 0.8884f, 0.5976f, 0.0739f, 0.8306f, 0.0435f, 0.2653f, + 0.7424f, 0.9176f, 0.6326f, 0.2545f, 0.6777f, 0.9430f, 0.4921f}; + + TensorShape s2_shape= {2, 4, 5}; + float d2[] = {0.1146f, 0.8401f, 0.0189f, 0.9417f, 0.9551f, 0.3073f, + 0.5162f, 0.6919f, 0.3872f, 0.9831f, 0.8261f, 0.6104f, 0.1850f, 0.4844f, 0.0732f, + 0.8003f, 0.3244f, 0.6337f, 0.4984f, 0.1917f, 0.5972f, 0.8280f, 0.1163f, 0.1445f, + 0.5281f, 0.3753f, 0.7377f, 0.0097f, 0.0460f, 0.8825f, 0.1283f, 0.3434f, 0.9592f, + 0.2614f, 0.8935f, 0.9233f, 0.1056f, 0.1819f, 0.9243f, 0.1263f}; + + TensorShape exp_shape = {2, 3, 5}; + float exp_d[] = {1.0745f, 1.4433f, 0.7899f, 1.5456f, 1.4509f, 0.6064f, + 0.9774f, 0.4197f, 1.1520f, 1.0043f, 1.7620f, 1.9396f, 1.4062f, 1.9461f, 1.9424f, + 0.5314f, 0.8391f, 0.8748f, 0.3471f, 1.1284f, 1.1388f, 1.1492f, 1.0333f, + 0.8970f, 1.6950f, 0.9817f, 1.0865f, 1.0302f, 0.7693f, 1.6372f}; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE); + } + + // Sub-test case 2: Batch matrix multiplication using integers only (2x3x4 * 2x4x5) + { + TensorShape s1_shape = {2, 3, 4}; + float d1[] = { + /* batch0 */ 2.0f, 6.0f, 0.0f, 6.0f, + 3.0f, 5.0f, 9.0f, 9.0f, + 9.0f, 2.0f, 1.0f, 7.0f, + /* batch1 */ 6.0f, 8.0f, 4.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 9.0f, + 5.0f, 7.0f, 3.0f, 8.0f, + }; + + TensorShape s2_shape = {2, 4, 5}; + float d2[] = { + /* batch0 */ 0.0f, 8.0f, 9.0f, 3.0f, 7.0f, + 3.0f, 8.0f, 5.0f, 5.0f, 4.0f, + 3.0f, 0.0f, 8.0f, 4.0f, 0.0f, + 7.0f, 3.0f, 4.0f, 9.0f, 4.0f, + /* batch1 */ 8.0f, 3.0f, 2.0f, 1.0f, 6.0f, + 7.0f, 5.0f, 0.0f, 9.0f, 3.0f, + 1.0f, 3.0f, 4.0f, 4.0f, 1.0f, + 1.0f, 9.0f, 5.0f, 0.0f, 5.0f, + }; + + TensorShape exp_shape = {2, 3, 5}; + float exp_d[] = { + /* batch0 */ 60.0f, 82.0f, 72.0f, 90.0f, 62.0f, + 105.0f, 91.0f, 160.0f, 151.0f, 77.0f, + 58.0f, 109.0f, 127.0f, 104.0f, 99.0f, + /* batch1 */ 115.0f, 133.0f, 63.0f, 94.0f, 99.0f, + 96.0f, 136.0f, 73.0f, 67.0f, 100.0f, + 100.0f, 131.0f, 62.0f, 80.0f, 94.0f, + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 3: Batch of identity matrices — result should equal second operand + // s1: {3,2,2} (3 identity matrices), s2: {3,2,2} + { + TensorShape s1_shape = {3, 2, 2}; + float d1[] = { + /* batch0 */ 1.0f, 0.0f, 0.0f, 1.0f, + /* batch1 */ 1.0f, 0.0f, 0.0f, 1.0f, + /* batch2 */ 1.0f, 0.0f, 0.0f, 1.0f, + }; + TensorShape s2_shape = {3, 2, 2}; + float d2[] = { + /* batch0 */ 1.0f, 2.0f, 3.0f, 4.0f, + /* batch1 */ 5.0f, 6.0f, 7.0f, 8.0f, + /* batch2 */ 9.0f, 10.0f, 11.0f, 12.0f, + }; + TensorShape exp_shape = {3, 2, 2}; + float exp_d[] = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f, + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 3, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 4: Rectangular per-batch multiply (2 batches): {2,1,3} @ {2,3,2} -> {2,1,2} + { + TensorShape s1_shape = {2, 1, 3}; + float d1[] = { + /* batch0 */ 1.0f, 2.0f, 3.0f, // row vector + /* batch1 */ 4.0f, 5.0f, 6.0f, // row vector + }; + TensorShape s2_shape = {2, 3, 2}; + float d2[] = { + /* batch0 */ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, // 3x2 + /* batch1 */ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, // reuse same matrix for simplicity + }; + TensorShape exp_shape = {2, 1, 2}; + float exp_d[] = { + /* batch0 */ 58.0f, 64.0f, // [1,2,3] @ [[7,8],[9,10],[11,12]] + /* batch1 */ 139.0f, 154.0f, // [4,5,6] @ same + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 4, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 5: Batch of column-result matrices using ones to test reduction (4 batches): {4,2,3}@{4,3,1} -> {4,2,1} + { + TensorShape s1_shape = {4, 2, 3}; + // each 2x3 filled with ones + float d1[4 * 2 * 3]; + for(int i = 0; i < 4 * 2 * 3; ++i) d1[i] = 1.0f; + TensorShape s2_shape = {4, 3, 1}; + // each 3x1 filled with ones + float d2[4 * 3 * 1]; + for(int i = 0; i < 4 * 3 * 1; ++i) d2[i] = 1.0f; + TensorShape exp_shape = {4, 2, 1}; + // each 2x1 entry will be sum of 3 ones = 3 + float exp_d[4 * 2 * 1]; + for(int i = 0; i < 4 * 2 * 1; ++i) exp_d[i] = 3.0f; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 5, TEST_FLOAT_TOLERANCE); + } + } // Test Case 9: Special Matrix Content { @@ -315,121 +438,277 @@ void test_matmul_operator() { // TODO: Problem in Matmul Broadcasting // // Test Case 10: Broadcasting - // { - // const char* tc_name = "matmul_broadcasting"; - - // // Sub-test 1: Simple matrix multiplication {4,5} @ {5,3} -> {4,3} - // { - // TensorShape s1_shape = {4, 5}; - // float d1[] = { - // 0.3745f, 0.9507f, 0.7320f, 0.5987f, 0.1560f, // Row 0 - // 0.1560f, 0.0581f, 0.8662f, 0.6011f, 0.7081f, // Row 1 - // 0.0206f, 0.9699f, 0.8324f, 0.2123f, 0.1818f, // Row 2 - // 0.1834f, 0.3042f, 0.5248f, 0.4319f, 0.2912f, // Row 3 - // }; - - // TensorShape s2_shape = {5, 3}; - // float d2[] = { - // 0.6119f, 0.1395f, 0.2921f, // Row 0 - // 0.3664f, 0.4561f, 0.7852f, // Row 1 - // 0.1997f, 0.5142f, 0.5924f, // Row 2 - // 0.0465f, 0.6075f, 0.1705f, // Row 3 - // 0.0651f, 0.9489f, 0.9656f, // Row 4 - // }; - - // TensorShape exp_shape = {4, 3}; - // float exp_d[] = { - // 0.7616f, 1.3740f, 1.5423f, // Row 0 - // 0.3637f, 1.5308f, 1.3906f, // Row 1 - // 0.5558f, 1.1748f, 1.4725f, // Row 2 - // 0.3675f, 0.9730f, 0.9582f, // Row 3 - // }; - - // Tensor t1 = create_test_tensor(s1_shape, d1, false); - // Tensor t2 = create_test_tensor(s2_shape, d2, false); - // Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); - // Tensor actual_res = Tensor_matmul(t1, t2); - - // compare_tensors(&actual_res, &expected_res, op_name, tc_name, 1, - // TEST_FLOAT_TOLERANCE); - // } - - // // Sub-test 2: 3D Broadcasting {1,3,2} @ {2,2,4} -> {2,3,4} - // { - // TensorShape s1_shape = {1, 3, 2}; - // float d1[] = { - // 0.8084f, 0.3046f, // [0,0,:] - // 0.0977f, 0.6842f, // [0,1,:] - // 0.4402f, 0.1220f, // [0,2,:] - // }; - - // TensorShape s2_shape = {2, 2, 4}; - // float d2[] = { - // // Batch 0 - // 0.4952f, 0.0344f, 0.9093f, 0.2588f, // [0,0,:] - // 0.6625f, 0.3117f, 0.5201f, 0.5467f, // [0,1,:] - // // Batch 1 - // 0.1849f, 0.9696f, 0.7751f, 0.9395f, // [1,0,:] - // 0.8948f, 0.5979f, 0.9219f, 0.0885f, // [1,1,:] - // }; - - // TensorShape exp_shape = {2, 3, 4}; - // float exp_d[] = { - // // Batch 0 - // 0.6021f, 0.1228f, 0.8935f, 0.3757f, // [0,0,:] - // 0.5017f, 0.2166f, 0.4447f, 0.3994f, // [0,1,:] - // 0.2988f, 0.0532f, 0.4637f, 0.1806f, // [0,2,:] - // // Batch 1 - // 0.4220f, 0.9659f, 0.9074f, 0.7864f, // [1,0,:] - // 0.6303f, 0.5038f, 0.7065f, 0.1523f, // [1,1,:] - // 0.1906f, 0.4997f, 0.4537f, 0.4243f, // [1,2,:] - // }; - - // Tensor t1 = create_test_tensor(s1_shape, d1, false); - // Tensor t2 = create_test_tensor(s2_shape, d2, false); - // Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); - // Tensor actual_res = Tensor_matmul(t1, t2); - // compare_tensors(&actual_res, &expected_res, op_name, tc_name, 2, - // TEST_FLOAT_TOLERANCE); - // } - - // // Sub-test 3: 4D Broadcasting {2,1,2,3} @ {1,1,3,2} -> {2,1,2,2} - // { - // TensorShape s1_shape = {2, 1, 2, 3}; - // float d1[] = { - // // Batch 0 - // 0.1960f, 0.0452f, 0.3253f, // [0,0,0,:] - // 0.3887f, 0.2713f, 0.8287f, // [0,0,1,:] - // // Batch 1 - // 0.3568f, 0.2809f, 0.5427f, // [1,0,0,:] - // 0.1409f, 0.8022f, 0.0746f, // [1,0,1,:] - // }; - - // TensorShape s2_shape = {1, 1, 3, 2}; - // float d2[] = { - // 0.9869f, 0.7722f, // [0,0,0,:] - // 0.1987f, 0.0055f, // [0,0,1,:] - // 0.8155f, 0.7069f, // [0,0,2,:] - // }; - - // TensorShape exp_shape = {2, 1, 2, 2}; - // float exp_d[] = { - // // Batch 0 - // 0.4677f, 0.3816f, // [0,0,0,:] - // 1.1133f, 0.8875f, // [0,0,1,:] - // // Batch 1 - // 0.8504f, 0.6607f, // [1,0,0,:] - // 0.3593f, 0.1660f, // [1,0,1,:] - // }; - - // Tensor t1 = create_test_tensor(s1_shape, d1, false); - // Tensor t2 = create_test_tensor(s2_shape, d2, false); - // Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); - // Tensor actual_res = Tensor_matmul(t1, t2); - // compare_tensors(&actual_res, &expected_res, op_name, tc_name, 3, - // TEST_FLOAT_TOLERANCE); - // } - // } + { + const char* tc_name = "matmul_broadcasting"; + + // Sub-test 1: Simple matrix multiplication {4,5} @ {5,3} -> {4,3} + { + TensorShape s1_shape = {4, 5}; + float d1[] = { + 0.3745f, 0.9507f, 0.7320f, 0.5987f, 0.1560f, // Row 0 + 0.1560f, 0.0581f, 0.8662f, 0.6011f, 0.7081f, // Row 1 + 0.0206f, 0.9699f, 0.8324f, 0.2123f, 0.1818f, // Row 2 + 0.1834f, 0.3042f, 0.5248f, 0.4319f, 0.2912f, // Row 3 + }; + + TensorShape s2_shape = {5, 3}; + float d2[] = { + 0.6119f, 0.1395f, 0.2921f, // Row 0 + 0.3664f, 0.4561f, 0.7852f, // Row 1 + 0.1997f, 0.5142f, 0.5924f, // Row 2 + 0.0465f, 0.6075f, 0.1705f, // Row 3 + 0.0651f, 0.9489f, 0.9656f, // Row 4 + }; + + TensorShape exp_shape = {4, 3}; + float exp_d[] = { + 0.7617f, 1.3740f, 1.5422f, // Row 0 + 0.3638f, 1.5307f, 1.3906f, // Row 1 + 0.5559f, 1.1747f, 1.4724f, // Row 2 + 0.3675f, 0.9729f, 0.9581f, // Row 3 + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 2: 3D Broadcasting {1,3,2} @ {2,2,4} -> {2,3,4} + { + TensorShape s1_shape = {1, 3, 2}; + float d1[] = { + 0.8084f, 0.3046f, + 0.0977f, 0.6842f, + 0.4402f, 0.1220f, + }; + + TensorShape s2_shape = {2, 2, 4}; + float d2[] = { + 0.4952f, 0.0344f, 0.9093f, 0.2588f, + 0.6625f, 0.3117f, 0.5201f, 0.5467f, + + 0.1849f, 0.9696f, 0.7751f, 0.9395f, + 0.8948f, 0.5979f, 0.9219f, 0.0885f, + }; + + TensorShape exp_shape = {2, 3, 4}; + float exp_d[] = { + 0.6021f, 0.1228f, 0.8935f, 0.3757f, + 0.5017f, 0.2166f, 0.4447f, 0.3994f, + 0.2988f, 0.0532f, 0.4637f, 0.1806f, + + 0.4220f, 0.9659f, 0.9074f, 0.7864f, + 0.6303f, 0.5038f, 0.7065f, 0.1523f, + 0.1906f, 0.4997f, 0.4537f, 0.4243f, + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 3: 4D Broadcasting {2,1,2,3} @ {1,1,3,2} -> {2,1,2,2} + { + TensorShape s1_shape = {2, 1, 2, 3}; + float d1[] = { + 0.1960f, 0.0452f, 0.3253f, + 0.3887f, 0.2713f, 0.8287f, + + 0.3568f, 0.2809f, 0.5427f, + 0.1409f, 0.8022f, 0.0746f, + }; + + TensorShape s2_shape = {1, 1, 3, 2}; + float d2[] = { + 0.9869f, 0.7722f, + 0.1987f, 0.0055f, + 0.8155f, 0.7069f, + }; + + TensorShape exp_shape = {2, 1, 2, 2}; + float exp_d[] = { + 0.4677f, 0.3816f, + 1.1133f, 0.8875f, + + 0.8505f, 0.6607f, + 0.3593f, 0.1659f, + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 3, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 4: 3D × 3D Broadcasting {2,1,3} @ {1,3,2} -> {2,1,2} + { + TensorShape s1_shape = {2, 1, 3}; + float d1[] = { + 1.0f, 2.0f, 3.0f, + + 4.0f, 5.0f, 6.0f, + }; + + TensorShape s2_shape = {1, 3, 2}; + float d2[] = { + 1.0f, 2.0f, + 3.0f, 4.0f, + 5.0f, 6.0f, + }; + + TensorShape exp_shape = {2, 1, 2}; + float exp_d[] = { + 22.0f, 28.0f, + + 49.0f, 64.0f, + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 4, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 5: 3D × 4D Broadcasting {1,2,3} @ {2,1,3,2} -> {2,1,2,2} + { + TensorShape s1_shape = {1, 2, 3}; + float d1[] = { + 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, + }; + + TensorShape s2_shape = {2, 1, 3, 2}; + float d2[] = { + 1.0f, 0.0f, + 0.0f, 1.0f, + 1.0f, 1.0f, + + 2.0f, 1.0f, + 1.0f, 2.0f, + 0.0f, 1.0f, + }; + + TensorShape exp_shape = {2, 1, 2, 2}; + float exp_d[] = { + 4.0f, 5.0f, + 10.0f, 11.0f, + + 4.0f, 8.0f, + 13.0f, 20.0f, + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 5, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 6: 4D × 4D Broadcasting {1,2,2,3} @ {2,1,3,2} -> {2,2,2,2} + { + TensorShape s1_shape = {1, 2, 2, 3}; + float d1[] = { + 1.0f, 0.0f, 1.0f, + 0.0f, 1.0f, 1.0f, + + 2.0f, 1.0f, 0.0f, + 1.0f, 2.0f, 1.0f, + }; + + TensorShape s2_shape = {2, 1, 3, 2}; + float d2[] = { + 1.0f, 1.0f, + 1.0f, 0.0f, + 0.0f, 1.0f, + + 0.0f, 1.0f, + 1.0f, 1.0f, + 1.0f, 0.0f, + }; + + TensorShape exp_shape = {2, 2, 2, 2}; + float exp_d[] = { + 1.0f, 2.0f, + 1.0f, 1.0f, + + 3.0f, 2.0f, + 3.0f, 2.0f, + + 1.0f, 1.0f, + 2.0f, 1.0f, + + 1.0f, 3.0f, + 3.0f, 3.0f, + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 6, TEST_FLOAT_TOLERANCE); + } + + // Sub-test 7: 4D × 3D Broadcasting {2,2,2,3} @ {1,3,4} -> {2,2,2,4} + { + TensorShape s1_shape = {2, 2, 2, 3}; + float d1[] = { + 1.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, + + 0.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 1.0f, + + 2.0f, 0.0f, 0.0f, + 0.0f, 2.0f, 0.0f, + + 0.0f, 0.0f, 2.0f, + 1.0f, 1.0f, 1.0f, + }; + + TensorShape s2_shape = {1, 3, 4}; + float d2[] = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f, + }; + + TensorShape exp_shape = {2, 2, 2, 4}; + float exp_d[] = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + + 9.0f, 10.0f, 11.0f, 12.0f, + 15.0f, 18.0f, 21.0f, 24.0f, + + + 2.0f, 4.0f, 6.0f, 8.0f, + 10.0f, 12.0f, 14.0f, 16.0f, + + 18.0f, 20.0f, 22.0f, 24.0f, + 15.0f, 18.0f, 21.0f, 24.0f, + }; + + Tensor t1 = create_test_tensor(s1_shape, d1, false); + Tensor t2 = create_test_tensor(s2_shape, d2, false); + Tensor expected_res = create_test_tensor(exp_shape, exp_d, false); + Tensor actual_res = Tensor_matmul(t1, t2); + + compare_tensors(&actual_res, &expected_res, op_name, tc_name, 7, TEST_FLOAT_TOLERANCE); + } + } cten_free(pool_id); }