-
Notifications
You must be signed in to change notification settings - Fork 3.7k
GridSample operator performance improvement on bilinear interpolation… #27359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,8 +1,14 @@ | ||||||||||||||||||||||||||||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||||||||||||||||||||||||||||
| // Licensed under the MIT License. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| #include "core/providers/cpu/tensor/grid_sample.h" | ||||||||||||||||||||||||||||
| #include <type_traits> | ||||||||||||||||||||||||||||
| #include <vector> | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| #if defined(MLAS_NEON_INTRINSICS) | ||||||||||||||||||||||||||||
| #include <arm_neon.h> | ||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| #include "core/providers/cpu/tensor/grid_sample.h" | ||||||||||||||||||||||||||||
| #include "core/framework/element_type_lists.h" | ||||||||||||||||||||||||||||
| #include "core/framework/TensorSeq.h" | ||||||||||||||||||||||||||||
| #include "core/providers/common.h" | ||||||||||||||||||||||||||||
|
|
@@ -148,6 +154,181 @@ T GridSample<T>::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, | |||||||||||||||||||||||||||
| return pixel; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| namespace { | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| constexpr uint8_t kTopLeftMask = 1u << 0; | ||||||||||||||||||||||||||||
| constexpr uint8_t kTopRightMask = 1u << 1; | ||||||||||||||||||||||||||||
| constexpr uint8_t kBottomLeftMask = 1u << 2; | ||||||||||||||||||||||||||||
| constexpr uint8_t kBottomRightMask = 1u << 3; | ||||||||||||||||||||||||||||
| constexpr uint8_t kAllNeighborsMask = kTopLeftMask | kTopRightMask | kBottomLeftMask | kBottomRightMask; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||
| struct BilinearSamplePlan2D { | ||||||||||||||||||||||||||||
| int64_t x1; | ||||||||||||||||||||||||||||
| int64_t x2; | ||||||||||||||||||||||||||||
| int64_t y1; | ||||||||||||||||||||||||||||
| int64_t y2; | ||||||||||||||||||||||||||||
| T w11; | ||||||||||||||||||||||||||||
| T w12; | ||||||||||||||||||||||||||||
| T w21; | ||||||||||||||||||||||||||||
| T w22; | ||||||||||||||||||||||||||||
| uint8_t mask = 0; | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
| // PrecomputeBilinearSamplePlan2D, the loop runs across all H_out * W_out points, using the right nx/ny for each (oy, ox) and storing that point’s four indices, four weights, and mask in plans[idx] | ||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||
| void PrecomputeBilinearSamplePlan2D(const T* grid_data, | ||||||||||||||||||||||||||||
| int64_t H_out, | ||||||||||||||||||||||||||||
| int64_t W_out, | ||||||||||||||||||||||||||||
| int64_t H_in, | ||||||||||||||||||||||||||||
| int64_t W_in, | ||||||||||||||||||||||||||||
| std::vector<BilinearSamplePlan2D<T>>& plans) { | ||||||||||||||||||||||||||||
| const int64_t point_count = H_out * W_out; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for (int64_t idx = 0; idx < point_count; ++idx) { | ||||||||||||||||||||||||||||
| auto& plan = plans[onnxruntime::narrow<size_t>(idx)]; | ||||||||||||||||||||||||||||
| const T nx = grid_data[idx * 2]; | ||||||||||||||||||||||||||||
| const T ny = grid_data[idx * 2 + 1]; | ||||||||||||||||||||||||||||
| const T x = GsDenormalize<T>(nx, W_in, false); | ||||||||||||||||||||||||||||
| const T y = GsDenormalize<T>(ny, H_in, false); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| const int64_t x1 = static_cast<int64_t>(std::floor(x)); | ||||||||||||||||||||||||||||
| const int64_t y1 = static_cast<int64_t>(std::floor(y)); | ||||||||||||||||||||||||||||
| const int64_t x2 = x1 + 1; | ||||||||||||||||||||||||||||
| const int64_t y2 = y1 + 1; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| const T dx2 = static_cast<T>(x2) - x; | ||||||||||||||||||||||||||||
| const T dx1 = x - static_cast<T>(x1); | ||||||||||||||||||||||||||||
| const T dy2 = static_cast<T>(y2) - y; | ||||||||||||||||||||||||||||
| const T dy1 = y - static_cast<T>(y1); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| uint8_t mask = 0; | ||||||||||||||||||||||||||||
| if (x1 >= 0 && x1 < W_in && y1 >= 0 && y1 < H_in) { | ||||||||||||||||||||||||||||
| mask |= kTopLeftMask; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if (x2 >= 0 && x2 < W_in && y1 >= 0 && y1 < H_in) { | ||||||||||||||||||||||||||||
| mask |= kTopRightMask; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if (x1 >= 0 && x1 < W_in && y2 >= 0 && y2 < H_in) { | ||||||||||||||||||||||||||||
| mask |= kBottomLeftMask; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if (x2 >= 0 && x2 < W_in && y2 >= 0 && y2 < H_in) { | ||||||||||||||||||||||||||||
| mask |= kBottomRightMask; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| plan.x1 = x1; | ||||||||||||||||||||||||||||
| plan.x2 = x2; | ||||||||||||||||||||||||||||
| plan.y1 = y1; | ||||||||||||||||||||||||||||
| plan.y2 = y2; | ||||||||||||||||||||||||||||
| plan.w11 = dy2 * dx2; | ||||||||||||||||||||||||||||
| plan.w12 = dy2 * dx1; | ||||||||||||||||||||||||||||
| plan.w21 = dy1 * dx2; | ||||||||||||||||||||||||||||
| plan.w22 = dy1 * dx1; | ||||||||||||||||||||||||||||
| plan.mask = mask; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||
| void EvaluatePlanForChannel(const T* input_data, | ||||||||||||||||||||||||||||
| T* output_data, | ||||||||||||||||||||||||||||
| int64_t W_in, | ||||||||||||||||||||||||||||
| const BilinearSamplePlan2D<T>* plan_data, | ||||||||||||||||||||||||||||
| int64_t point_count) { | ||||||||||||||||||||||||||||
| for (int64_t idx = 0; idx < point_count; ++idx) { | ||||||||||||||||||||||||||||
| const auto& plan = plan_data[idx]; | ||||||||||||||||||||||||||||
| if (plan.mask == 0) { | ||||||||||||||||||||||||||||
| output_data[idx] = T{}; | ||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| #if defined(MLAS_NEON_INTRINSICS) | ||||||||||||||||||||||||||||
| if constexpr (std::is_same_v<T, float>) { | ||||||||||||||||||||||||||||
| if (plan.mask == kAllNeighborsMask) { | ||||||||||||||||||||||||||||
| const float* row1_ptr = input_data + plan.y1 * W_in + plan.x1; | ||||||||||||||||||||||||||||
| const float* row2_ptr = input_data + plan.y2 * W_in + plan.x1; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| float32x2_t row1 = vld1_f32(row1_ptr); // [p11, p12] | ||||||||||||||||||||||||||||
| float32x2_t row2 = vld1_f32(row2_ptr); // [p21, p22] | ||||||||||||||||||||||||||||
| float32x4_t neighbors = vcombine_f32(row1, row2); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| float32x2_t weights_row1 = vdup_n_f32(plan.w12); | ||||||||||||||||||||||||||||
| weights_row1 = vset_lane_f32(plan.w11, weights_row1, 0); | ||||||||||||||||||||||||||||
| float32x2_t weights_row2 = vdup_n_f32(plan.w22); | ||||||||||||||||||||||||||||
| weights_row2 = vset_lane_f32(plan.w21, weights_row2, 0); | ||||||||||||||||||||||||||||
| float32x4_t weights = vcombine_f32(weights_row1, weights_row2); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| float32x4_t products = vmulq_f32(neighbors, weights); | ||||||||||||||||||||||||||||
| float32x2_t sum_pairs = vadd_f32(vget_low_f32(products), vget_high_f32(products)); | ||||||||||||||||||||||||||||
| float32x2_t accum = vpadd_f32(sum_pairs, sum_pairs); | ||||||||||||||||||||||||||||
| output_data[idx] = vget_lane_f32(accum, 0); | ||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| T p11 = T{}; | ||||||||||||||||||||||||||||
| T p12 = T{}; | ||||||||||||||||||||||||||||
| T p21 = T{}; | ||||||||||||||||||||||||||||
| T p22 = T{}; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if (plan.mask == kAllNeighborsMask) { | ||||||||||||||||||||||||||||
| const int64_t row1 = plan.y1 * W_in; | ||||||||||||||||||||||||||||
| const int64_t row2 = plan.y2 * W_in; | ||||||||||||||||||||||||||||
| p11 = input_data[row1 + plan.x1]; | ||||||||||||||||||||||||||||
| p12 = input_data[row1 + plan.x2]; | ||||||||||||||||||||||||||||
| p21 = input_data[row2 + plan.x1]; | ||||||||||||||||||||||||||||
| p22 = input_data[row2 + plan.x2]; | ||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||
| if (plan.mask & kTopLeftMask) { | ||||||||||||||||||||||||||||
| p11 = input_data[plan.y1 * W_in + plan.x1]; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if (plan.mask & kTopRightMask) { | ||||||||||||||||||||||||||||
| p12 = input_data[plan.y1 * W_in + plan.x2]; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if (plan.mask & kBottomLeftMask) { | ||||||||||||||||||||||||||||
| p21 = input_data[plan.y2 * W_in + plan.x1]; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| if (plan.mask & kBottomRightMask) { | ||||||||||||||||||||||||||||
| p22 = input_data[plan.y2 * W_in + plan.x2]; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| output_data[idx] = plan.w11 * p11 + plan.w12 * p12 + plan.w21 * p21 + plan.w22 * p22; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||
| void TryRunBilinearZerosFastPath2D(const Tensor& input, | ||||||||||||||||||||||||||||
| const Tensor& grid, | ||||||||||||||||||||||||||||
| Tensor& output, | ||||||||||||||||||||||||||||
| int64_t n, | ||||||||||||||||||||||||||||
| int64_t C, | ||||||||||||||||||||||||||||
| int64_t H_in, | ||||||||||||||||||||||||||||
| int64_t W_in, | ||||||||||||||||||||||||||||
| int64_t H_out, | ||||||||||||||||||||||||||||
| int64_t W_out, | ||||||||||||||||||||||||||||
| concurrency::ThreadPool* tp, | ||||||||||||||||||||||||||||
| std::vector<BilinearSamplePlan2D<T>>& sampling_plan) { | ||||||||||||||||||||||||||||
| const int64_t plane_in = H_in * W_in; | ||||||||||||||||||||||||||||
| const int64_t plane_out = H_out * W_out; | ||||||||||||||||||||||||||||
| sampling_plan.resize(onnxruntime::narrow<size_t>(plane_out)); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| const T* grid_data = grid.Data<T>() + n * plane_out * 2; | ||||||||||||||||||||||||||||
| PrecomputeBilinearSamplePlan2D(grid_data, H_out, W_out, H_in, W_in, sampling_plan); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| const T* input_data = input.Data<T>(); | ||||||||||||||||||||||||||||
| T* output_data = output.MutableData<T>(); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| concurrency::ThreadPool::TrySimpleParallelFor( | ||||||||||||||||||||||||||||
| tp, onnxruntime::narrow<std::ptrdiff_t>(C), | ||||||||||||||||||||||||||||
| [&](std::ptrdiff_t c) { | ||||||||||||||||||||||||||||
| const T* X_data = input_data + (n * C + c) * plane_in; | ||||||||||||||||||||||||||||
| T* Y_data = output_data + (n * C + c) * plane_out; | ||||||||||||||||||||||||||||
| EvaluatePlanForChannel(X_data, Y_data, W_in, sampling_plan.data(), plane_out); | ||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| } // namespace | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| // When grid sampling, padding is applied before interpolation. | ||||||||||||||||||||||||||||
| // For instance, in bilinear mode and zeros padding-mode, pixel p at actual | ||||||||||||||||||||||||||||
| // image location (-0.5, -0.5) | ||||||||||||||||||||||||||||
|
|
@@ -210,13 +391,14 @@ Status GridSample<T>::Compute(OpKernelContext* context) const { | |||||||||||||||||||||||||||
| T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr; | ||||||||||||||||||||||||||||
| for (int64_t n = 0; n < N; n++) { | ||||||||||||||||||||||||||||
| const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| const auto run_generic_path_for_n = [&](int64_t n_idx) { | ||||||||||||||||||||||||||||
| const T* grid_data = grid->Data<T>() + n_idx * (H_out * W_out) * 2; | ||||||||||||||||||||||||||||
| concurrency::ThreadPool::TrySimpleParallelFor( | ||||||||||||||||||||||||||||
| tp, onnxruntime::narrow<std::ptrdiff_t>(C), | ||||||||||||||||||||||||||||
| [&](std::ptrdiff_t c) { | ||||||||||||||||||||||||||||
| const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in); | ||||||||||||||||||||||||||||
| T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out); | ||||||||||||||||||||||||||||
| const T* X_data = input->Data<T>() + (n_idx * C + c) * (H_in * W_in); | ||||||||||||||||||||||||||||
| T* Y_data = Y.MutableData<T>() + (n_idx * C + c) * (H_out * W_out); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for (int64_t oy = 0; oy < H_out; oy++) { | ||||||||||||||||||||||||||||
| for (int64_t ox = 0; ox < W_out; ox++) { | ||||||||||||||||||||||||||||
|
|
@@ -265,6 +447,19 @@ Status GridSample<T>::Compute(OpKernelContext* context) const { | |||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| const bool can_use_fast_path = (mode_ == Linear && padding_mode_ == Zeros && !align_corners_); | ||||||||||||||||||||||||||||
| for (int64_t n = 0; n < N; n++) { | ||||||||||||||||||||||||||||
| if (can_use_fast_path) { | ||||||||||||||||||||||||||||
| // Choose fast path when all 4 neighbors are within the image and use zero for out-of-boundary neighbors. | ||||||||||||||||||||||||||||
| // This fast path can be 2-3x faster than the generic path with boundary check and supports Neon optimization. | ||||||||||||||||||||||||||||
| // sampling_plan helps precomputing a separate plan entry per output pixel. | ||||||||||||||||||||||||||||
| std::vector<BilinearSamplePlan2D<T>> sampling_plan; | ||||||||||||||||||||||||||||
|
Comment on lines
+453
to
+458
|
||||||||||||||||||||||||||||
| for (int64_t n = 0; n < N; n++) { | |
| if (can_use_fast_path) { | |
| // Choose fast path when all 4 neighbors are within the image and use zero for out-of-boundary neighbors. | |
| // This fast path can be 2-3x faster than the generic path with boundary check and supports Neon optimization. | |
| // sampling_plan helps precomputing a separate plan entry per output pixel. | |
| std::vector<BilinearSamplePlan2D<T>> sampling_plan; | |
| std::vector<BilinearSamplePlan2D<T>> sampling_plan; | |
| for (int64_t n = 0; n < N; n++) { | |
| if (can_use_fast_path) { | |
| // Choose fast path when all 4 neighbors are within the image and use zero for out-of-boundary neighbors. | |
| // This fast path can be 2-3x faster than the generic path with boundary check and supports Neon optimization. | |
| // sampling_plan helps precomputing a separate plan entry per output pixel. | |
| sampling_plan.clear(); |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -727,6 +727,38 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corner | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RunTests(test, GetExecutionProviders(20)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_linear_zeros_mixed_bounds) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Crafts grid points that mix fully in-bounds sampling with cases where either the right, bottom, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // or both neighbors fall outside the source image so zero padding must be applied. This ensures | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // the optimized bilinear fast path matches the generic implementation for boundary handling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| OpTester test("GridSample", 20); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::string mode = "linear"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::string padding_mode = "zeros"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t align_corners = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::initializer_list<int64_t> X_shape{1, 1, 2, 2}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::initializer_list<TypeParam> X_data{TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f)}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::initializer_list<int64_t> Grid_shape{1, 2, 2, 2}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // (nx, ny) pairs: center (in-bounds), right edge (x out), bottom edge (y out), corner (both out) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::initializer_list<TypeParam> Grid_data{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeParam(0.0f), TypeParam(0.0f), // center (all neighbors in bounds) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeParam(0.9f), TypeParam(0.0f), // near right edge (right neighbors out of bounds) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeParam(0.0f), TypeParam(0.9f), // near bottom edge (bottom neighbors out) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeParam(0.9f), TypeParam(0.9f)}; // near bottom-right corner (both right and bottom neighbors out) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::initializer_list<int64_t> Y_shape{1, 1, 2, 2}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::initializer_list<TypeParam> Y_data{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeParam(2.5f), // all neighbors in bounds | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeParam(1.8f), // right neighbors partially out-of-bounds | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeParam(2.1f), // bottom neighbors partially out-of-bounds | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TypeParam(1.44f)}; // both right and bottom neighbors out-of-bounds | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test.AddInput<TypeParam>("X", X_shape, X_data); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test.AddAttribute("mode", mode); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test.AddAttribute("padding_mode", padding_mode); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test.AddAttribute("align_corners", align_corners); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test.AddOutput<TypeParam>("Y", Y_shape, Y_data); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RunTests(test, GetExecutionProviders(20)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_linear_zeros_mixed_bounds_left_top) { | |
| // Similar to test_grid_sample_20_4D_linear_zeros_mixed_bounds but focuses on left/top boundary cases, | |
| // where the left and/or top neighbors fall outside the source image and zero padding must be applied. | |
| // This ensures the optimized bilinear fast path correctly handles left/top boundary conditions. | |
| OpTester test("GridSample", 20); | |
| std::string mode = "linear"; | |
| std::string padding_mode = "zeros"; | |
| int64_t align_corners = 0; | |
| std::initializer_list<int64_t> X_shape{1, 1, 2, 2}; | |
| std::initializer_list<TypeParam> X_data{TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f)}; | |
| std::initializer_list<int64_t> Grid_shape{1, 2, 2, 2}; | |
| // (nx, ny) pairs: center (in-bounds), left edge (x out), top edge (y out), corner (both out) | |
| std::initializer_list<TypeParam> Grid_data{ | |
| TypeParam(0.0f), TypeParam(0.0f), // center (all neighbors in bounds) | |
| TypeParam(-0.9f), TypeParam(0.0f), // near left edge (left neighbors out of bounds) | |
| TypeParam(0.0f), TypeParam(-0.9f), // near top edge (top neighbors out of bounds) | |
| TypeParam(-0.9f), TypeParam(-0.9f)}; // near top-left corner (both left and top neighbors out of bounds) | |
| std::initializer_list<int64_t> Y_shape{1, 1, 2, 2}; | |
| std::initializer_list<TypeParam> Y_data{ | |
| TypeParam(2.5f), // all neighbors in bounds | |
| TypeParam(1.2f), // left neighbors partially out-of-bounds | |
| TypeParam(0.9f), // top neighbors partially out-of-bounds | |
| TypeParam(0.36f)}; // both left and top neighbors out-of-bounds | |
| test.AddInput<TypeParam>("X", X_shape, X_data); | |
| test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data); | |
| test.AddAttribute("mode", mode); | |
| test.AddAttribute("padding_mode", padding_mode); | |
| test.AddAttribute("align_corners", align_corners); | |
| test.AddOutput<TypeParam>("Y", Y_shape, Y_data); | |
| RunTests(test, GetExecutionProviders(20)); | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment exceeds the typical line length and could be wrapped or reformatted for better readability. Consider breaking it into multiple lines or simplifying the wording.