diff --git a/velox/experimental/cudf/connectors/hive/CudfHiveDataSource.cpp b/velox/experimental/cudf/connectors/hive/CudfHiveDataSource.cpp index ccda153fa66b..ee47b362661b 100644 --- a/velox/experimental/cudf/connectors/hive/CudfHiveDataSource.cpp +++ b/velox/experimental/cudf/connectors/hive/CudfHiveDataSource.cpp @@ -46,7 +46,6 @@ #include #include #include -#include #include #include @@ -99,7 +98,7 @@ CudfHiveDataSource::CudfHiveDataSource( VELOX_CHECK_NOT_NULL( tableHandle_, "TableHandle must be an instance of HiveTableHandle"); - // Copy subfield filters + // Copy subfield filters. for (const auto& [k, v] : tableHandle_->subfieldFilters()) { subfieldFilters_.emplace(k.clone(), v->clone()); // Add fields in the filter to the columns to read if not there @@ -493,6 +492,40 @@ void CudfHiveDataSource::setupCudfDataSourceAndOptions() { dataSource_ = std::move(cudf::io::make_datasources(sourceInfo).front()); } + RowTypePtr readerFilterType = nullptr; + bool hasDecimalFilter = false; + if (subfieldFilters_.size()) { + readerFilterType = [&] { + if (tableHandle_->dataColumns()) { + std::vector newNames; + std::vector newTypes; + + for (const auto& name : readColumnNames_) { + // Ensure all columns being read are available to the filter + auto parsedType = tableHandle_->dataColumns()->findChild(name); + newNames.emplace_back(std::move(name)); + newTypes.push_back(parsedType); + } + + return ROW(std::move(newNames), std::move(newTypes)); + } else { + return outputType_; + } + }(); + + for (const auto& [field, _] : subfieldFilters_) { + if (!field.valid()) { + continue; + } + const auto& fieldName = field.baseName(); + const auto fieldType = readerFilterType->findChild(fieldName); + if (fieldType && fieldType->isDecimal()) { + hasDecimalFilter = true; + break; + } + } + } + // Reader options readerOptions_ = cudf::io::parquet_reader_options::builder(std::move(sourceInfo)) @@ -501,6 +534,7 @@ void CudfHiveDataSource::setupCudfDataSourceAndOptions() { .allow_mismatched_pq_schemas( cudfHiveConfig_->isAllowMismatchedCudfHiveSchemas()) .timestamp_type(cudfHiveConfig_->timestampType()) + .use_jit_filter(hasDecimalFilter) .build(); // Set skip_bytes and num_bytes if available @@ -511,6 +545,7 @@ void CudfHiveDataSource::setupCudfDataSourceAndOptions() { readerOptions_.set_num_bytes(split_->size()); } + // Set filter expression created in constructor if any subfield filters if (subfieldFilterExpr_ != nullptr) { readerOptions_.set_filter(*subfieldFilterExpr_); } diff --git a/velox/experimental/cudf/exec/CudfHashJoin.cpp b/velox/experimental/cudf/exec/CudfHashJoin.cpp index aebc4debe16c..0a966e03a965 100644 --- a/velox/experimental/cudf/exec/CudfHashJoin.cpp +++ b/velox/experimental/cudf/exec/CudfHashJoin.cpp @@ -21,6 +21,7 @@ #include "velox/experimental/cudf/exec/Utilities.h" #include "velox/experimental/cudf/exec/VeloxCudfInterop.h" #include "velox/experimental/cudf/expression/AstExpression.h" +#include "velox/experimental/cudf/expression/AstExpressionUtils.h" #include "velox/experimental/cudf/expression/ExpressionEvaluator.h" #include "velox/core/PlanNode.h" @@ -416,6 +417,8 @@ CudfHashJoinProbe::CudfHashJoinProbe( // simplify expression exec::ExprSet exprs({joinNode_->filter()}, operatorCtx_->execCtx()); VELOX_CHECK_EQ(exprs.exprs().size(), 1); + useAstFilter_ = CudfConfig::getInstance().astExpressionEnabled && + !containsDecimalType(exprs.exprs()[0]); // Create a reusable evaluator for the filter column. This is expensive to // build, and the expression + input schema are stable for the lifetime of @@ -431,25 +434,27 @@ CudfHashJoinProbe::CudfHashJoinProbe( // and the column locations in that schema translate to column locations // in whole tables - // create ast tree - if (joinNode_->isRightJoin() || joinNode_->isRightSemiFilterJoin()) { - createAstTree( - exprs.exprs()[0], - tree_, - scalars_, - buildType_, - probeType_, - rightPrecomputeInstructions_, - leftPrecomputeInstructions_); - } else { - createAstTree( - exprs.exprs()[0], - tree_, - scalars_, - probeType_, - buildType_, - leftPrecomputeInstructions_, - rightPrecomputeInstructions_); + if (useAstFilter_) { + // create ast tree + if (joinNode_->isRightJoin() || joinNode_->isRightSemiFilterJoin()) { + createAstTree( + exprs.exprs()[0], + tree_, + scalars_, + buildType_, + probeType_, + rightPrecomputeInstructions_, + leftPrecomputeInstructions_); + } else { + createAstTree( + exprs.exprs()[0], + tree_, + scalars_, + probeType_, + buildType_, + leftPrecomputeInstructions_, + rightPrecomputeInstructions_); + } } } } @@ -803,15 +808,35 @@ std::vector> CudfHashJoinProbe::innerJoin( std::vector> joinedCols; if (joinNode_->filter()) { - cudfOutputs.push_back(filteredOutputIndices( - leftTableView, - leftIndicesCol, - rightTableView, - rightIndicesCol, - extendedLeftView, - extendedRightView, - cudf::join_kind::INNER_JOIN, - stream)); + if (useAstFilter_) { + cudfOutputs.push_back(filteredOutputIndices( + leftTableView, + leftIndicesCol, + rightTableView, + rightIndicesCol, + extendedLeftView, + extendedRightView, + cudf::join_kind::INNER_JOIN, + stream)); + } else { + auto filterFunc = + [stream]( + std::vector>&& joinedCols, + cudf::column_view filterColumn) { + auto filterTable = + std::make_unique(std::move(joinedCols)); + auto filteredTable = cudf::apply_boolean_mask( + *filterTable, filterColumn, stream, get_output_mr()); + return filteredTable->release(); + }; + cudfOutputs.push_back(filteredOutput( + leftTableView, + leftIndicesCol, + rightTableView, + rightIndicesCol, + filterFunc, + stream)); + } } else { cudfOutputs.push_back(unfilteredOutput( leftTableView, @@ -878,15 +903,35 @@ std::vector> CudfHashJoinProbe::leftJoin( std::vector> joinedCols; if (joinNode_->filter()) { - cudfOutputs.push_back(filteredOutputIndices( - leftTableView, - leftIndicesCol, - rightTableView, - rightIndicesCol, - extendedLeftView, - extendedRightView, - cudf::join_kind::LEFT_JOIN, - stream)); + if (useAstFilter_) { + cudfOutputs.push_back(filteredOutputIndices( + leftTableView, + leftIndicesCol, + rightTableView, + rightIndicesCol, + extendedLeftView, + extendedRightView, + cudf::join_kind::LEFT_JOIN, + stream)); + } else { + auto filterFunc = + [stream]( + std::vector>&& joinedCols, + cudf::column_view filterColumn) { + auto filterTable = + std::make_unique(std::move(joinedCols)); + auto filteredTable = cudf::apply_boolean_mask( + *filterTable, filterColumn, stream, get_output_mr()); + return filteredTable->release(); + }; + cudfOutputs.push_back(filteredOutput( + leftTableView, + leftIndicesCol, + rightTableView, + rightIndicesCol, + filterFunc, + stream)); + } } else { cudfOutputs.push_back(unfilteredOutput( leftTableView, @@ -1194,6 +1239,9 @@ std::vector> CudfHashJoinProbe::leftSemiFilterJoin( std::unique_ptr> leftJoinIndices; if (joinNode_->filter()) { + if (!useAstFilter_) { + VELOX_NYI("Join filter requires AST for semi joins"); + } leftJoinIndices = cudf::mixed_left_semi_join( leftTableView.select(leftKeyIndices_), rightTableView.select(rightKeyIndices_), @@ -1244,6 +1292,9 @@ CudfHashJoinProbe::rightSemiFilterJoin( std::unique_ptr> rightJoinIndices; if (joinNode_->filter()) { + if (!useAstFilter_) { + VELOX_NYI("Join filter requires AST for semi joins"); + } rightJoinIndices = cudf::mixed_left_semi_join( rightTableView.select(rightKeyIndices_), leftTableView.select(leftKeyIndices_), @@ -1313,6 +1364,9 @@ std::vector> CudfHashJoinProbe::antiJoin( std::unique_ptr> leftJoinIndices; if (joinNode_->filter()) { + if (!useAstFilter_) { + VELOX_NYI("Join filter requires AST for anti joins"); + } leftJoinIndices = cudf::mixed_left_anti_join( leftTableView.select(leftKeyIndices_), rightTableView.select(rightKeyIndices_), @@ -1402,10 +1456,10 @@ RowVectorPtr CudfHashJoinProbe::getOutput() { for (size_t li = 0; li < leftColumnOutputIndices_.size(); ++li) { auto outIdx = leftColumnOutputIndices_[li]; auto probeChannel = leftColumnIndicesToGather_[li]; - auto leftCudfType = - veloxToCudfTypeId(probeType_->childAt(probeChannel)); + auto leftCudfDataType = + veloxToCudfDataType(probeType_->childAt(probeChannel)); auto nullScalar = cudf::make_default_constructed_scalar( - cudf::data_type{leftCudfType}, stream, get_temp_mr()); + leftCudfDataType, stream, get_temp_mr()); outCols[outIdx] = cudf::make_column_from_scalar( *nullScalar, m, stream, get_output_mr()); } diff --git a/velox/experimental/cudf/exec/CudfHashJoin.h b/velox/experimental/cudf/exec/CudfHashJoin.h index 382b30f659cf..cb6070bf7b13 100644 --- a/velox/experimental/cudf/exec/CudfHashJoin.h +++ b/velox/experimental/cudf/exec/CudfHashJoin.h @@ -201,6 +201,7 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { /** @brief Output column positions for right table columns */ std::vector rightColumnOutputIndices_; bool finished_{false}; + bool useAstFilter_{true}; // Copied from HashProbe.h // Indicates whether to skip probe input data processing or not. It only diff --git a/velox/experimental/cudf/expression/AstExpression.cpp b/velox/experimental/cudf/expression/AstExpression.cpp index 00693061be48..f049a037cb8c 100644 --- a/velox/experimental/cudf/expression/AstExpression.cpp +++ b/velox/experimental/cudf/expression/AstExpression.cpp @@ -114,8 +114,7 @@ ColumnOrView ASTExpression::eval( } }(); if (finalize) { - const auto requestedType = - cudf::data_type(cudf_velox::veloxToCudfTypeId(expr_->type())); + const auto requestedType = cudf_velox::veloxToCudfDataType(expr_->type()); auto resultView = asView(result); if (resultView.type() != requestedType) { result = cudf::cast(resultView, requestedType, stream, mr); diff --git a/velox/experimental/cudf/expression/AstExpression.h b/velox/experimental/cudf/expression/AstExpression.h index 47f88404c840..978d713bb2b2 100644 --- a/velox/experimental/cudf/expression/AstExpression.h +++ b/velox/experimental/cudf/expression/AstExpression.h @@ -20,6 +20,8 @@ #include +#include + namespace facebook::velox::cudf_velox { const std::string kAstEvaluatorName = "ast"; diff --git a/velox/experimental/cudf/expression/AstExpressionUtils.h b/velox/experimental/cudf/expression/AstExpressionUtils.h index 64fc5b0d1ce3..534d896e4418 100644 --- a/velox/experimental/cudf/expression/AstExpressionUtils.h +++ b/velox/experimental/cudf/expression/AstExpressionUtils.h @@ -23,6 +23,7 @@ #include "velox/experimental/cudf/expression/AstUtils.h" // TODO(kn): in another PR // #include "velox/experimental/cudf/CudfNoDefaults.h" +#include "velox/experimental/cudf/expression/DecimalUtils.h" #include "velox/expression/ConstantExpr.h" #include "velox/expression/FieldReference.h" @@ -225,6 +226,14 @@ bool isAstExprSupported(const std::shared_ptr& expr) { using velox::exec::FieldReference; using Op = cudf::ast::ast_operator; + // reject anything with DECIMAL for now + // @TODO implement DECIMAL in AST and JIT + if (containsDecimalType(expr)) { + LOG(WARNING) << "DECIMAL expression not supported by AST/JIT: " + << expr->toString(); + return false; + } + const auto name = stripPrefix(expr->name(), CudfConfig::getInstance().functionNamePrefix); const auto len = expr->inputs().size(); @@ -232,7 +241,7 @@ bool isAstExprSupported(const std::shared_ptr& expr) { // Literals and field references are always supported auto isSupportedLiteral = [&](const TypePtr& type) { try { - auto cudfType = cudf::data_type(veloxToCudfTypeId(type)); + auto cudfType = veloxToCudfDataType(type); return cudf::is_fixed_width(cudfType) || cudfType.id() == cudf::type_id::STRING; } catch (...) { @@ -260,8 +269,7 @@ bool isAstExprSupported(const std::shared_ptr& expr) { inputCudfDataTypes.reserve(len); for (const auto& input : expr->inputs()) { try { - inputCudfDataTypes.push_back( - cudf::data_type(veloxToCudfTypeId(input->type()))); + inputCudfDataTypes.push_back(veloxToCudfDataType(input->type())); } catch (...) { return false; } @@ -386,7 +394,11 @@ cudf::ast::expression const& AstContext::addPrecomputeInstructionOnSide( auto nestedIndices = getNestedColumnIndices( inputRowSchema[sideIdx].get()->childAt(columnIndex), fieldName); precomputeInstructions[sideIdx].get().emplace_back( - columnIndex, instruction, newColumnIndex, nestedIndices, node); + columnIndex, + instruction, + newColumnIndex, + std::move(nestedIndices), + node); } auto side = static_cast(sideIdx); return tree.push(cudf::ast::column_reference(newColumnIndex, side)); diff --git a/velox/experimental/cudf/expression/AstUtils.h b/velox/experimental/cudf/expression/AstUtils.h index 739eb24c3a25..c3f9c7214a07 100644 --- a/velox/experimental/cudf/expression/AstUtils.h +++ b/velox/experimental/cudf/expression/AstUtils.h @@ -22,6 +22,7 @@ #include "velox/vector/VectorTypeUtils.h" #include +#include #include #include @@ -32,7 +33,18 @@ cudf::ast::literal makeLiteralFromScalar( cudf::scalar& scalar, const TypePtr& type) { if constexpr (cudf::is_fixed_width()) { - if (type->isIntervalDayTime()) { + if (type->isDecimal()) { + if (type->kind() == TypeKind::BIGINT) { + using CudfScalarType = cudf::fixed_point_scalar; + return cudf::ast::literal{*static_cast(&scalar)}; + } + if (type->kind() == TypeKind::HUGEINT) { + using CudfScalarType = cudf::fixed_point_scalar; + return cudf::ast::literal{*static_cast(&scalar)}; + } + VELOX_UNREACHABLE( + "Invalid Decimal Type (bad TypeKind: {})", type->kind()); + } else if (type->isIntervalDayTime()) { using CudfDurationType = cudf::duration_ms; if constexpr (std::is_same_v) { using CudfScalarType = cudf::duration_scalar; @@ -89,15 +101,28 @@ std::unique_ptr makeScalarFromValue( if constexpr (cudf::is_fixed_width()) { if (type->isDecimal()) { - VELOX_FAIL("Decimal not supported"); - /* TODO: enable after rewriting using binary ops - using CudfDecimalType = cudf::numeric::decimal64; - using cudfScalarType = cudf::fixed_point_scalar; - auto scalar = std::make_unique(value, - type->scale(), - true, - stream, - mr);*/ + // Velox DECIMAL scale is positive for fractional digits + // cuDF scale is negative for fractional digits + // @TODO check the bigger picture here! + if (type->kind() == TypeKind::BIGINT) { + auto const decimalType = + std::dynamic_pointer_cast(type); + VELOX_CHECK(decimalType, "Invalid Decimal Type (failed dynamic_cast)"); + auto const cudfScale = numeric::scale_type{-decimalType->scale()}; + using CudfDecimalType = cudf::fixed_point_scalar; + return std::make_unique( + value, cudfScale, !isNull, stream, mr); + } else if (type->kind() == TypeKind::HUGEINT) { + auto const decimalType = + std::dynamic_pointer_cast(type); + VELOX_CHECK(decimalType, "Invalid Decimal Type (failed dynamic_cast)"); + auto const cudfScale = numeric::scale_type{-decimalType->scale()}; + using CudfDecimalType = cudf::fixed_point_scalar; + return std::make_unique( + value, cudfScale, !isNull, stream, mr); + } + VELOX_UNREACHABLE( + "Invalid Decimal Type (bad TypeKind: {})", type->kind()); } else if (type->isIntervalYearMonth()) { VELOX_FAIL("Interval year month not supported"); } else if (type->isIntervalDayTime()) { diff --git a/velox/experimental/cudf/expression/CMakeLists.txt b/velox/experimental/cudf/expression/CMakeLists.txt index 5f35dd3ae398..549318f97268 100644 --- a/velox/experimental/cudf/expression/CMakeLists.txt +++ b/velox/experimental/cudf/expression/CMakeLists.txt @@ -15,9 +15,11 @@ add_library( velox_cudf_expression AstExpression.cpp + DecimalExpressionKernels.cu ExpressionEvaluator.cpp JitExpression.cpp SubfieldFiltersToAst.cpp + DecimalUtils.cpp ) target_link_libraries( @@ -27,3 +29,5 @@ target_link_libraries( ) target_compile_options(velox_cudf_expression PRIVATE -Wno-missing-field-initializers) + +set_target_properties(velox_cudf_expression PROPERTIES CUDA_STANDARD 20 CUDA_STANDARD_REQUIRED ON) diff --git a/velox/experimental/cudf/expression/DecimalExpressionKernels.cu b/velox/experimental/cudf/expression/DecimalExpressionKernels.cu new file mode 100644 index 000000000000..41678a859aa4 --- /dev/null +++ b/velox/experimental/cudf/expression/DecimalExpressionKernels.cu @@ -0,0 +1,140 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/experimental/cudf/expression/DecimalExpressionKernels.h" + +#include +#include +#include +#include +#include + +#include + +#include + +namespace facebook::velox::cudf_velox { +namespace { + +template +__global__ void decimalDivideKernel( + const InT* lhs, + const InT* rhs, + OutT* out, + int32_t numRows, + __int128_t scale) { + int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= numRows) { + return; + } + __int128_t numerator = static_cast<__int128_t>(lhs[idx]); + __int128_t denom = static_cast<__int128_t>(rhs[idx]); + if (denom == 0) { + out[idx] = OutT{0}; + return; + } + int sign = 1; + if (numerator < 0) { + numerator = -numerator; + sign = -sign; + } + if (denom < 0) { + denom = -denom; + sign = -sign; + } + __int128_t scaled = numerator * scale; + __int128_t quotient = scaled / denom; + __int128_t remainder = scaled % denom; + if (remainder * 2 >= denom) { + ++quotient; + } + if (sign < 0) { + quotient = -quotient; + } + out[idx] = static_cast(quotient); +} + +inline __int128_t pow10Int128(int32_t exp) { + __int128_t value = 1; + for (int32_t i = 0; i < exp; ++i) { + value *= 10; + } + return value; +} + +template +void launchDivideKernel( + const cudf::column_view& lhs, + const cudf::column_view& rhs, + cudf::mutable_column_view out, + int32_t aRescale, + rmm::cuda_stream_view stream) { + if (lhs.size() == 0) { + return; + } + int32_t blockSize = 256; + int32_t gridSize = (lhs.size() + blockSize - 1) / blockSize; + auto scale = pow10Int128(aRescale); + decimalDivideKernel<<>>( + lhs.data(), rhs.data(), out.data(), lhs.size(), scale); + CUDF_CUDA_TRY(cudaGetLastError()); +} + +} // namespace + +std::unique_ptr decimalDivide( + const cudf::column_view& lhs, + const cudf::column_view& rhs, + cudf::data_type outputType, + int32_t aRescale, + rmm::cuda_stream_view stream) { + CUDF_EXPECTS(lhs.size() == rhs.size(), "Decimal divide requires equal sizes"); + CUDF_EXPECTS( + lhs.type().id() == rhs.type().id(), + "Decimal divide requires matching input types"); + CUDF_EXPECTS( + aRescale >= 0, "Decimal divide requires non-negative rescale factor"); + + auto [nullMask, nullCount] = + cudf::bitmask_and(cudf::table_view({lhs, rhs}), stream); + auto out = cudf::make_fixed_width_column( + outputType, lhs.size(), std::move(nullMask), nullCount, stream); + + if (lhs.type().id() == cudf::type_id::DECIMAL64) { + if (outputType.id() == cudf::type_id::DECIMAL64) { + launchDivideKernel( + lhs, rhs, out->mutable_view(), aRescale, stream); + } else { + CUDF_EXPECTS( + outputType.id() == cudf::type_id::DECIMAL128, + "Unexpected output type for decimal divide"); + launchDivideKernel( + lhs, rhs, out->mutable_view(), aRescale, stream); + } + } else { + CUDF_EXPECTS( + lhs.type().id() == cudf::type_id::DECIMAL128, + "Unsupported input type for decimal divide"); + CUDF_EXPECTS( + outputType.id() == cudf::type_id::DECIMAL128, + "Unexpected output type for decimal divide"); + launchDivideKernel<__int128_t, __int128_t>( + lhs, rhs, out->mutable_view(), aRescale, stream); + } + + return out; +} + +} // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/expression/DecimalExpressionKernels.h b/velox/experimental/cudf/expression/DecimalExpressionKernels.h new file mode 100644 index 000000000000..1533d1391597 --- /dev/null +++ b/velox/experimental/cudf/expression/DecimalExpressionKernels.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include + +#include + +namespace facebook::velox::cudf_velox { + +std::unique_ptr decimalDivide( + const cudf::column_view& lhs, + const cudf::column_view& rhs, + cudf::data_type outputType, + int32_t aRescale, + rmm::cuda_stream_view stream); + +} // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/expression/DecimalUtils.cpp b/velox/experimental/cudf/expression/DecimalUtils.cpp new file mode 100644 index 000000000000..10b62f98d856 --- /dev/null +++ b/velox/experimental/cudf/expression/DecimalUtils.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/experimental/cudf/expression/DecimalUtils.h" + +namespace facebook::velox::cudf_velox { + +bool containsDecimalType(const std::shared_ptr& expr) { + if (!expr) { + return false; + } + if (expr->type() && expr->type()->isDecimal()) { + return true; + } + for (const auto& input : expr->inputs()) { + if (containsDecimalType(input)) { + return true; + } + } + return false; +} + +} // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/expression/DecimalUtils.h b/velox/experimental/cudf/expression/DecimalUtils.h new file mode 100644 index 000000000000..65fec1b63cce --- /dev/null +++ b/velox/experimental/cudf/expression/DecimalUtils.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/Expr.h" + +namespace facebook::velox::cudf_velox { + +bool containsDecimalType(const std::shared_ptr& expr); + +} // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/expression/ExpressionEvaluator.cpp b/velox/experimental/cudf/expression/ExpressionEvaluator.cpp index 401a995d6451..96c5b56221c0 100644 --- a/velox/experimental/cudf/expression/ExpressionEvaluator.cpp +++ b/velox/experimental/cudf/expression/ExpressionEvaluator.cpp @@ -16,23 +16,31 @@ #include "velox/experimental/cudf/exec/Validation.h" #include "velox/experimental/cudf/exec/VeloxCudfInterop.h" #include "velox/experimental/cudf/expression/AstUtils.h" +#include "velox/experimental/cudf/expression/DecimalExpressionKernels.h" #include "velox/experimental/cudf/expression/ExpressionEvaluator.h" +#include "velox/common/base/Exceptions.h" #include "velox/expression/ConstantExpr.h" #include "velox/expression/FieldReference.h" #include "velox/expression/FunctionSignature.h" #include "velox/expression/SignatureBinder.h" +#include "velox/type/DecimalUtil.h" #include "velox/type/Type.h" #include "velox/vector/BaseVector.h" +#include #include #include #include #include +#include #include #include +#include #include #include +#include +#include #include #include #include @@ -41,11 +49,72 @@ #include #include #include +#include #include +#include namespace facebook::velox::cudf_velox { namespace { +bool decimalScalarIsZero( + const cudf::scalar& scalar, + rmm::cuda_stream_view stream) { + if (!scalar.is_valid(stream)) { + return false; + } + if (scalar.type().id() == cudf::type_id::DECIMAL64) { + auto const& dec = + static_cast const&>( + scalar); + return dec.value(stream) == 0; + } + if (scalar.type().id() == cudf::type_id::DECIMAL128) { + auto const& dec = + static_cast const&>( + scalar); + return dec.value(stream) == 0; + } + return false; +} + +bool hasDecimalZero( + const cudf::column_view& col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) { + if (col.is_empty()) { + return false; + } + std::unique_ptr zero; + auto scale = numeric::scale_type{col.type().scale()}; + if (col.type().id() == cudf::type_id::DECIMAL64) { + zero = + cudf::make_fixed_point_scalar(0, scale, stream, mr); + } else if (col.type().id() == cudf::type_id::DECIMAL128) { + zero = cudf::make_fixed_point_scalar( + 0, scale, stream, mr); + } else { + return false; + } + + auto equals = cudf::binary_operation( + col, + *zero, + cudf::binary_operator::EQUAL, + cudf::data_type{cudf::type_id::BOOL8}, + stream, + mr); + auto anyAgg = cudf::make_any_aggregation(); + auto anyScalar = cudf::reduce( + equals->view(), + *anyAgg, + cudf::data_type{cudf::type_id::BOOL8}, + stream, + mr); + auto const& boolScalar = + static_cast const&>(*anyScalar); + return boolScalar.is_valid(stream) && boolScalar.value(stream); +} + struct CudfExpressionEvaluatorEntry { int priority; CudfExpressionEvaluatorCanEvaluate canEvaluate; @@ -181,10 +250,9 @@ class CastFunction : public CudfFunction { CastFunction(const std::shared_ptr& expr) { VELOX_CHECK_EQ(expr->inputs().size(), 1, "cast expects exactly 1 input"); - targetCudfType_ = - cudf::data_type(cudf_velox::veloxToCudfTypeId(expr->type())); - auto sourceType = cudf::data_type( - cudf_velox::veloxToCudfTypeId(expr->inputs()[0]->type())); + targetCudfType_ = cudf_velox::veloxToCudfDataType(expr->type()); + auto sourceType = + cudf_velox::veloxToCudfDataType(expr->inputs()[0]->type()); VELOX_CHECK( cudf::is_supported_cast(sourceType, targetCudfType_), "Cast from {} to {} is not supported", @@ -298,10 +366,9 @@ class BinaryFunction : public CudfFunction { BinaryFunction( const std::shared_ptr& expr, cudf::binary_operator op) - : op_(op), - type_(cudf::data_type(cudf_velox::veloxToCudfTypeId(expr->type()))) { + : op_(op), type_(cudf_velox::veloxToCudfDataType(expr->type())) { VELOX_CHECK_EQ( - expr->inputs().size(), 2, "binary function expects exactly 2 inputs"); + expr->inputs().size(), 2, "Binary function expects exactly 2 inputs"); if (auto constExpr = std::dynamic_pointer_cast( expr->inputs()[0])) { auto constValue = constExpr->value(); @@ -317,34 +384,585 @@ class BinaryFunction : public CudfFunction { VELOX_CHECK( !(left_ != nullptr && right_ != nullptr), - "Not support both left and right are literals"); + "Binary function on two literals is not supported"); } ColumnOrView eval( std::vector& inputColumns, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) const override { + auto isComparisonOp = [](cudf::binary_operator op) { + switch (op) { + case cudf::binary_operator::EQUAL: + case cudf::binary_operator::NOT_EQUAL: + case cudf::binary_operator::GREATER: + case cudf::binary_operator::GREATER_EQUAL: + case cudf::binary_operator::LESS: + case cudf::binary_operator::LESS_EQUAL: + return true; + default: + return false; + } + }; if (left_ == nullptr && right_ == nullptr) { + if (op_ == cudf::binary_operator::DIV && cudf::is_fixed_point(type_)) { + auto lhsView = asView(inputColumns[0]); + auto rhsView = asView(inputColumns[1]); + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (type_.id() == cudf::type_id::DECIMAL128) { + if (lhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, lhsView.type().scale()}; + lhsCast = cudf::cast(lhsView, castType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, rhsView.type().scale()}; + rhsCast = cudf::cast(rhsView, castType, stream, mr); + rhsView = rhsCast->view(); + } + } + if (hasDecimalZero(rhsView, stream, mr)) { + VELOX_USER_FAIL("Division by zero"); + } + auto lhsScale = -lhsView.type().scale(); + auto rhsScale = -rhsView.type().scale(); + auto outScale = -type_.scale(); + auto aRescale = outScale - lhsScale + rhsScale; + return decimalDivide(lhsView, rhsView, type_, aRescale, stream); + } + auto lhsView = asView(inputColumns[0]); + auto rhsView = asView(inputColumns[1]); + if (isComparisonOp(op_) && cudf::is_fixed_point(lhsView.type()) && + cudf::is_fixed_point(rhsView.type())) { + auto lhsScale = -lhsView.type().scale(); + auto rhsScale = -rhsView.type().scale(); + auto targetScale = lhsScale > rhsScale ? lhsScale : rhsScale; + auto targetTypeId = (lhsView.type().id() == cudf::type_id::DECIMAL128 || + rhsView.type().id() == cudf::type_id::DECIMAL128) + ? cudf::type_id::DECIMAL128 + : cudf::type_id::DECIMAL64; + auto targetType = + cudf::data_type{targetTypeId, numeric::scale_type{-targetScale}}; + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (lhsView.type() != targetType) { + lhsCast = cudf::cast(lhsView, targetType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type() != targetType) { + rhsCast = cudf::cast(rhsView, targetType, stream, mr); + rhsView = rhsCast->view(); + } + return cudf::binary_operation(lhsView, rhsView, op_, type_, stream, mr); + } + if (cudf::is_fixed_point(type_)) { + if (op_ == cudf::binary_operator::ADD || + op_ == cudf::binary_operator::SUB || + op_ == cudf::binary_operator::MOD) { + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (lhsView.type() != type_) { + lhsCast = cudf::cast(lhsView, type_, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type() != type_) { + rhsCast = cudf::cast(rhsView, type_, stream, mr); + rhsView = rhsCast->view(); + } + return cudf::binary_operation( + lhsView, rhsView, op_, type_, stream, mr); + } + if (op_ == cudf::binary_operator::MUL) { + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (type_.id() == cudf::type_id::DECIMAL128) { + if (lhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, lhsView.type().scale()}; + lhsCast = cudf::cast(lhsView, castType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, rhsView.type().scale()}; + rhsCast = cudf::cast(rhsView, castType, stream, mr); + rhsView = rhsCast->view(); + } + } + return cudf::binary_operation( + lhsView, rhsView, op_, type_, stream, mr); + } + } + return cudf::binary_operation(lhsView, rhsView, op_, type_, stream, mr); + } else if (left_ == nullptr) { + if (op_ == cudf::binary_operator::DIV && cudf::is_fixed_point(type_)) { + if (decimalScalarIsZero(*right_, stream)) { + VELOX_USER_FAIL("Division by zero"); + } + auto lhsView = asView(inputColumns[0]); + auto lhsScale = -lhsView.type().scale(); + auto rhsScale = -right_->type().scale(); + auto outScale = -type_.scale(); + auto aRescale = outScale - lhsScale + rhsScale; + auto rhsCol = + cudf::make_column_from_scalar(*right_, lhsView.size(), stream, mr); + auto rhsView = rhsCol->view(); + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (type_.id() == cudf::type_id::DECIMAL128) { + if (lhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, lhsView.type().scale()}; + lhsCast = cudf::cast(lhsView, castType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, rhsView.type().scale()}; + rhsCast = cudf::cast(rhsView, castType, stream, mr); + rhsView = rhsCast->view(); + } + } + return decimalDivide(lhsView, rhsView, type_, aRescale, stream); + } + auto lhsView = asView(inputColumns[0]); + if (isComparisonOp(op_) && cudf::is_fixed_point(lhsView.type()) && + cudf::is_fixed_point(right_->type())) { + auto rhsCol = + cudf::make_column_from_scalar(*right_, lhsView.size(), stream, mr); + auto rhsView = rhsCol->view(); + auto lhsScale = -lhsView.type().scale(); + auto rhsScale = -rhsView.type().scale(); + auto targetScale = lhsScale > rhsScale ? lhsScale : rhsScale; + auto targetTypeId = (lhsView.type().id() == cudf::type_id::DECIMAL128 || + rhsView.type().id() == cudf::type_id::DECIMAL128) + ? cudf::type_id::DECIMAL128 + : cudf::type_id::DECIMAL64; + auto targetType = + cudf::data_type{targetTypeId, numeric::scale_type{-targetScale}}; + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (lhsView.type() != targetType) { + lhsCast = cudf::cast(lhsView, targetType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type() != targetType) { + rhsCast = cudf::cast(rhsView, targetType, stream, mr); + rhsView = rhsCast->view(); + } + return cudf::binary_operation(lhsView, rhsView, op_, type_, stream, mr); + } + if (cudf::is_fixed_point(type_)) { + auto rhsCol = + cudf::make_column_from_scalar(*right_, lhsView.size(), stream, mr); + auto rhsView = rhsCol->view(); + if (op_ == cudf::binary_operator::ADD || + op_ == cudf::binary_operator::SUB || + op_ == cudf::binary_operator::MOD) { + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (lhsView.type() != type_) { + lhsCast = cudf::cast(lhsView, type_, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type() != type_) { + rhsCast = cudf::cast(rhsView, type_, stream, mr); + rhsView = rhsCast->view(); + } + return cudf::binary_operation( + lhsView, rhsView, op_, type_, stream, mr); + } + if (op_ == cudf::binary_operator::MUL) { + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (type_.id() == cudf::type_id::DECIMAL128) { + if (lhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, lhsView.type().scale()}; + lhsCast = cudf::cast(lhsView, castType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, rhsView.type().scale()}; + rhsCast = cudf::cast(rhsView, castType, stream, mr); + rhsView = rhsCast->view(); + } + } + return cudf::binary_operation( + lhsView, rhsView, op_, type_, stream, mr); + } + } return cudf::binary_operation( + asView(inputColumns[0]), *right_, op_, type_, stream, mr); + } + if (op_ == cudf::binary_operator::DIV && cudf::is_fixed_point(type_)) { + auto rhsView = asView(inputColumns[0]); + if (hasDecimalZero(rhsView, stream, mr)) { + VELOX_USER_FAIL("Division by zero"); + } + auto lhsScale = -left_->type().scale(); + auto rhsScale = -rhsView.type().scale(); + auto outScale = -type_.scale(); + auto aRescale = outScale - lhsScale + rhsScale; + auto lhsCol = + cudf::make_column_from_scalar(*left_, rhsView.size(), stream, mr); + auto lhsView = lhsCol->view(); + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (type_.id() == cudf::type_id::DECIMAL128) { + if (lhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, lhsView.type().scale()}; + lhsCast = cudf::cast(lhsView, castType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, rhsView.type().scale()}; + rhsCast = cudf::cast(rhsView, castType, stream, mr); + rhsView = rhsCast->view(); + } + } + return decimalDivide(lhsView, rhsView, type_, aRescale, stream); + } + auto rhsView = asView(inputColumns[0]); + if (isComparisonOp(op_) && cudf::is_fixed_point(left_->type()) && + cudf::is_fixed_point(rhsView.type())) { + auto lhsCol = + cudf::make_column_from_scalar(*left_, rhsView.size(), stream, mr); + auto lhsView = lhsCol->view(); + auto lhsScale = -lhsView.type().scale(); + auto rhsScale = -rhsView.type().scale(); + auto targetScale = lhsScale > rhsScale ? lhsScale : rhsScale; + auto targetTypeId = (lhsView.type().id() == cudf::type_id::DECIMAL128 || + rhsView.type().id() == cudf::type_id::DECIMAL128) + ? cudf::type_id::DECIMAL128 + : cudf::type_id::DECIMAL64; + auto targetType = + cudf::data_type{targetTypeId, numeric::scale_type{-targetScale}}; + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (lhsView.type() != targetType) { + lhsCast = cudf::cast(lhsView, targetType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type() != targetType) { + rhsCast = cudf::cast(rhsView, targetType, stream, mr); + rhsView = rhsCast->view(); + } + return cudf::binary_operation(lhsView, rhsView, op_, type_, stream, mr); + } + if (cudf::is_fixed_point(type_)) { + auto lhsCol = + cudf::make_column_from_scalar(*left_, rhsView.size(), stream, mr); + auto lhsView = lhsCol->view(); + if (op_ == cudf::binary_operator::ADD || + op_ == cudf::binary_operator::SUB || + op_ == cudf::binary_operator::MOD) { + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (lhsView.type() != type_) { + lhsCast = cudf::cast(lhsView, type_, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type() != type_) { + rhsCast = cudf::cast(rhsView, type_, stream, mr); + rhsView = rhsCast->view(); + } + return cudf::binary_operation(lhsView, rhsView, op_, type_, stream, mr); + } + if (op_ == cudf::binary_operator::MUL) { + std::unique_ptr lhsCast; + std::unique_ptr rhsCast; + if (type_.id() == cudf::type_id::DECIMAL128) { + if (lhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, lhsView.type().scale()}; + lhsCast = cudf::cast(lhsView, castType, stream, mr); + lhsView = lhsCast->view(); + } + if (rhsView.type().id() == cudf::type_id::DECIMAL64) { + auto castType = cudf::data_type{ + cudf::type_id::DECIMAL128, rhsView.type().scale()}; + rhsCast = cudf::cast(rhsView, castType, stream, mr); + rhsView = rhsCast->view(); + } + } + return cudf::binary_operation(lhsView, rhsView, op_, type_, stream, mr); + } + } + return cudf::binary_operation(*left_, rhsView, op_, type_, stream, mr); + } + + private: + const cudf::binary_operator op_; + const cudf::data_type type_; + std::unique_ptr left_; + std::unique_ptr right_; +}; + +class UnaryFunction : public CudfFunction { + public: + UnaryFunction( + const std::shared_ptr& expr, + cudf::unary_operator op) + : op_(op) { + VELOX_CHECK_EQ( + expr->inputs().size(), 1, "Unary function expects exactly 1 input"); + auto constExpr = + std::dynamic_pointer_cast(expr->inputs()[0]); + VELOX_CHECK_NULL( + constExpr, "Unary function on literal input is not supported"); + // @TODO (seves 1/28/26) + // binary functions require at least ONE input to be non-literal + // do we need to support unary functions with ONLY a literal input? + // assuming not for now + } + + ColumnOrView eval( + std::vector& inputColumns, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override { + return cudf::unary_operation(asView(inputColumns[0]), op_, stream, mr); + } + + private: + const cudf::unary_operator op_; +}; + +class LogicalFunction : public CudfFunction { + public: + LogicalFunction( + const std::shared_ptr& expr, + cudf::binary_operator op) + : op_(op) { + VELOX_CHECK_GE( + expr->inputs().size(), 2, "Logical function expects at least 2 inputs"); + literals_.reserve(expr->inputs().size()); + for (const auto& input : expr->inputs()) { + auto constExpr = + std::dynamic_pointer_cast(input); + if (constExpr) { + literals_.push_back(VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + createCudfScalar, + constExpr->value()->typeKind(), + constExpr->value())); + } else { + literals_.push_back(nullptr); + } + } + } + + ColumnOrView eval( + std::vector& inputColumns, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override { + size_t rowCount = 0; + if (!inputColumns.empty()) { + rowCount = asView(inputColumns[0]).size(); + } + if (rowCount == 0 && inputColumns.empty()) { + rowCount = 1; + } + + std::vector> literalColumns; + literalColumns.reserve(literals_.size()); + std::vector operands; + operands.reserve(literals_.size()); + + size_t columnIndex = 0; + for (const auto& literal : literals_) { + if (literal) { + auto column = cudf::make_column_from_scalar(*literal, rowCount, stream); + operands.push_back(column->view()); + literalColumns.push_back(std::move(column)); + } else { + VELOX_CHECK_LT(columnIndex, inputColumns.size()); + operands.push_back(asView(inputColumns[columnIndex++])); + } + } + + VELOX_CHECK(!operands.empty()); + if (operands.size() == 1) { + if (!literalColumns.empty()) { + return std::move(literalColumns[0]); + } + return operands[0]; + } + + auto result = cudf::binary_operation( + operands[0], operands[1], op_, kBoolType, stream, mr); + for (size_t i = 2; i < operands.size(); ++i) { + result = cudf::binary_operation( + result->view(), operands[i], op_, kBoolType, stream, mr); + } + return result; + } + + private: + static constexpr cudf::data_type kBoolType{cudf::type_id::BOOL8}; + const cudf::binary_operator op_; + std::vector> literals_; +}; + +class BetweenFunction : public CudfFunction { + public: + BetweenFunction(const std::shared_ptr& expr) { + // must have exactly three inputs: value, min, max + VELOX_CHECK_EQ( + expr->inputs().size(), 3, "Between function expects exactly 3 inputs"); + // value must not be a literal + auto constExpr = + std::dynamic_pointer_cast(expr->inputs()[0]); + VELOX_CHECK_NULL( + constExpr, "Between function with literal input is not supported"); + if (auto constExpr = std::dynamic_pointer_cast( + expr->inputs()[1])) { + // min is a literal + auto constValue = constExpr->value(); + minLiteral_ = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + createCudfScalar, constValue->typeKind(), constValue); + } + if (auto constExpr = std::dynamic_pointer_cast( + expr->inputs()[2])) { + // max is a literal + auto constValue = constExpr->value(); + maxLiteral_ = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + createCudfScalar, constValue->typeKind(), constValue); + } + } + + ColumnOrView eval( + std::vector& inputColumns, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override { + // return (value >= min) && (value <= max) + std::unique_ptr geResultColumn, leResultColumn; + if (minLiteral_) { + geResultColumn = cudf::binary_operation( + asView(inputColumns[0]), + *minLiteral_, + cudf::binary_operator::GREATER_EQUAL, + kBoolType, + stream, + mr); + } else { + geResultColumn = cudf::binary_operation( + asView(inputColumns[0]), + asView(inputColumns[1]), + cudf::binary_operator::GREATER_EQUAL, + kBoolType, + stream, + mr); + } + if (maxLiteral_) { + leResultColumn = cudf::binary_operation( + asView(inputColumns[0]), + *maxLiteral_, + cudf::binary_operator::LESS_EQUAL, + kBoolType, + stream, + mr); + } else { + leResultColumn = cudf::binary_operation( + asView(inputColumns[0]), + asView(inputColumns[2]), + cudf::binary_operator::LESS_EQUAL, + kBoolType, + stream, + mr); + } + return cudf::binary_operation( + geResultColumn->view(), + leResultColumn->view(), + cudf::binary_operator::LOGICAL_AND, + kBoolType, + stream, + mr); + } + + private: + static constexpr cudf::data_type kBoolType{cudf::type_id::BOOL8}; + std::unique_ptr minLiteral_; + std::unique_ptr maxLiteral_; +}; + +class GreatestLeastFunction : public CudfFunction { + public: + GreatestLeastFunction( + const std::shared_ptr& expr, + cudf::binary_operator op) + : op_(op), type_(cudf_velox::veloxToCudfDataType(expr->type())) { + // must have at least three inputs + VELOX_CHECK_GE( + expr->inputs().size(), + 3, + "Greatest/Least function expects at least 3 inputs"); + // scan inputs for literals + for (size_t i = 0; i < expr->inputs().size(); ++i) { + auto constExpr = std::dynamic_pointer_cast( + expr->inputs()[i]); + if (constExpr) { + literals_.push_back(VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + createCudfScalar, + constExpr->value()->typeKind(), + constExpr->value())); + } else { + literals_.push_back(nullptr); + } + } + } + + ColumnOrView eval( + std::vector& inputColumns, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) const override { + // construct a chain of NULL_MIN or NULL_MAX operations + std::unique_ptr result; + // the first pair of values + if (literals_[0] && literals_[1]) { + // no variant of cudf::binary_operation that takes two scalars so we must + // create columns + auto col0 = cudf::make_column_from_scalar(*literals_[0], 1, stream); + auto col1 = cudf::make_column_from_scalar(*literals_[1], 1, stream); + result = cudf::binary_operation( + col0->view(), col1->view(), op_, type_, stream, mr); + } else if (literals_[0]) { + result = cudf::binary_operation( + *literals_[0], asView(inputColumns[1]), op_, type_, stream, mr); + } else if (literals_[1]) { + result = cudf::binary_operation( + asView(inputColumns[0]), *literals_[1], op_, type_, stream, mr); + } else { + result = cudf::binary_operation( asView(inputColumns[0]), asView(inputColumns[1]), op_, type_, stream, mr); - } else if (left_ == nullptr) { - return cudf::binary_operation( - asView(inputColumns[0]), *right_, op_, type_, stream, mr); } - return cudf::binary_operation( - *left_, asView(inputColumns[0]), op_, type_, stream, mr); + // remaining values + for (size_t i = 2; i < inputColumns.size(); ++i) { + if (literals_[i]) { + result = cudf::binary_operation( + result->view(), *literals_[i], op_, type_, stream, mr); + } else { + result = cudf::binary_operation( + result->view(), asView(inputColumns[i]), op_, type_, stream, mr); + } + } + return result; } private: const cudf::binary_operator op_; const cudf::data_type type_; - std::unique_ptr left_; - std::unique_ptr right_; + std::vector> literals_; }; class SwitchFunction : public CudfFunction { @@ -875,25 +1493,48 @@ bool registerBuiltinFunctions(const std::string& prefix) { .variableArity("T") .build()}); + registerCudfFunction( + "and", + [](const std::string&, const std::shared_ptr& expr) { + return std::make_shared( + expr, cudf::binary_operator::LOGICAL_AND); + }, + {FunctionSignatureBuilder() + .returnType("boolean") + .argumentType("boolean") + .variableArity("boolean") + .build()}); + + registerCudfFunction( + "or", + [](const std::string&, const std::shared_ptr& expr) { + return std::make_shared( + expr, cudf::binary_operator::LOGICAL_OR); + }, + {FunctionSignatureBuilder() + .returnType("boolean") + .argumentType("boolean") + .variableArity("boolean") + .build()}); + registerCudfFunction( prefix + "round", [](const std::string&, const std::shared_ptr& expr) { return std::make_shared(expr); }, - {// TODO(dm): Enable after adding decimal support to velox-cudf - // FunctionSignatureBuilder() - // .integerVariable("p") - // .integerVariable("s") - // .returnType("decimal(p,s)") - // .argumentType("decimal(p,s)") - // .build(), - // FunctionSignatureBuilder() - // .integerVariable("p") - // .integerVariable("s") - // .returnType("decimal(p,s)") - // .argumentType("decimal(p,s)") - // .constantArgumentType("integer") - // .build(), + {FunctionSignatureBuilder() + .integerVariable("p") + .integerVariable("s") + .returnType("decimal(p,s)") + .argumentType("decimal(p,s)") + .build(), + FunctionSignatureBuilder() + .integerVariable("p") + .integerVariable("s") + .returnType("decimal(p,s)") + .argumentType("decimal(p,s)") + .constantArgumentType("integer") + .build(), FunctionSignatureBuilder() .returnType("tinyint") .argumentType("tinyint") @@ -986,34 +1627,6 @@ bool registerBuiltinFunctions(const std::string& prefix) { .constantArgumentType("varchar") .build()}); - // Our cudf binary ops can take all numeric types but instead of listing them - // all, we're testing if input types can be casted to double. Coersion will - // pass because all numerics can be casted to double. - // TODO (dm): This could break for decimal - registerCudfFunctions( - {prefix + "greaterthan", prefix + "gt"}, - [](const std::string&, const std::shared_ptr& expr) { - return std::make_shared( - expr, cudf::binary_operator::GREATER); - }, - {FunctionSignatureBuilder() - .returnType("boolean") - .argumentType("double") - .argumentType("double") - .build()}); - - registerCudfFunction( - prefix + "divide", - [](const std::string&, const std::shared_ptr& expr) { - return std::make_shared( - expr, cudf::binary_operator::DIV); - }, - {FunctionSignatureBuilder() - .returnType("double") - .argumentType("double") - .argumentType("double") - .build()}); - // No prefix because switch and if are special form registerCudfFunctions( {"switch", "if"}, @@ -1043,6 +1656,227 @@ bool registerBuiltinFunctions(const std::string& prefix) { } else { registerPrestoFunctions(prefix); } + + // + // regular binary operators + // + + auto registerBinaryOp = [&](const std::vector& aliases, + cudf::binary_operator op) { + auto decimalBinarySignature = [&](cudf::binary_operator decimalOp) { + std::string rPrecisionConstraint; + std::string rScaleConstraint; + switch (decimalOp) { + case cudf::binary_operator::ADD: + case cudf::binary_operator::SUB: + rPrecisionConstraint = + "min(38, max(a_precision - a_scale, b_precision - b_scale) + " + "max(a_scale, b_scale) + 1)"; + rScaleConstraint = "max(a_scale, b_scale)"; + break; + case cudf::binary_operator::MUL: + rPrecisionConstraint = "min(38, a_precision + b_precision)"; + rScaleConstraint = "a_scale + b_scale"; + break; + case cudf::binary_operator::DIV: + rPrecisionConstraint = + "min(38, a_precision + b_scale + max(0, b_scale - a_scale))"; + rScaleConstraint = "max(a_scale, b_scale)"; + break; + case cudf::binary_operator::MOD: + rPrecisionConstraint = + "min(b_precision - b_scale, a_precision - a_scale) + " + "max(a_scale, b_scale)"; + rScaleConstraint = "max(a_scale, b_scale)"; + break; + default: + VELOX_FAIL("Unsupported decimal binary operator"); + } + + return FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable("r_precision", rPrecisionConstraint) + .integerVariable("r_scale", rScaleConstraint) + .returnType("decimal(r_precision, r_scale)") + .argumentType("decimal(a_precision, a_scale)") + .argumentType("decimal(b_precision, b_scale)") + .build(); + }; + + registerCudfFunctions( + aliases, + [op]( + const std::string&, + const std::shared_ptr& expr) { + return std::make_shared(expr, op); + }, + {FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .argumentType("double") + .build(), + decimalBinarySignature(op)}); + }; + + registerBinaryOp( + {prefix + "plus", prefix + "add"}, cudf::binary_operator::ADD); + registerBinaryOp( + {prefix + "minus", prefix + "subtract"}, cudf::binary_operator::SUB); + registerBinaryOp({prefix + "multiply"}, cudf::binary_operator::MUL); + registerBinaryOp({prefix + "divide"}, cudf::binary_operator::DIV); + registerBinaryOp({prefix + "mod"}, cudf::binary_operator::MOD); + + // + // regular comparison operators + // + + auto registerComparisonOp = [&](const std::vector& aliases, + cudf::binary_operator op) { + registerCudfFunctions( + aliases, + [op]( + const std::string&, + const std::shared_ptr& expr) { + return std::make_shared(expr, op); + }, + {FunctionSignatureBuilder() + .returnType("boolean") + .argumentType("double") + .argumentType("double") + .build(), + FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .returnType("boolean") + .argumentType("decimal(a_precision, a_scale)") + .argumentType("decimal(b_precision, b_scale)") + .build()}); + }; + + registerComparisonOp( + {prefix + "equal", prefix + "eq"}, cudf::binary_operator::EQUAL); + registerComparisonOp( + {prefix + "notequal", prefix + "neq"}, cudf::binary_operator::NOT_EQUAL); + registerComparisonOp( + {prefix + "greaterthanorequal", prefix + "gte"}, + cudf::binary_operator::GREATER_EQUAL); + registerComparisonOp( + {prefix + "lessthanorequal", prefix + "lte"}, + cudf::binary_operator::LESS_EQUAL); + registerComparisonOp( + {prefix + "greaterthan", prefix + "gt"}, cudf::binary_operator::GREATER); + registerComparisonOp( + {prefix + "lessthan", prefix + "lt"}, cudf::binary_operator::LESS); + + // + // regular unary operators + // + + auto registerUnaryOp = [&](const std::vector& aliases, + cudf::unary_operator op) { + registerCudfFunctions( + aliases, + [op]( + const std::string&, + const std::shared_ptr& expr) { + return std::make_shared(expr, op); + }, + {FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .build(), + FunctionSignatureBuilder() + .integerVariable("p") + .integerVariable("s") + .returnType("decimal(p,s)") + .argumentType("decimal(p,s)") + .build()}); + }; + + registerUnaryOp({prefix + "abs"}, cudf::unary_operator::ABS); + registerUnaryOp({prefix + "negate"}, cudf::unary_operator::NEGATE); + registerUnaryOp({prefix + "floor"}, cudf::unary_operator::FLOOR); + + // @TODO (seves 1/28/26) + // uncomment this once DecimalCeilFunction exists + // registerUnaryOp({prefix + "ceil"}, cudf::unary_operator::CEIL); + + // @TODO (seves 1/28/26) + // truncate + // no direct cudf mapping + // perhaps a compound operation using round/round_decimal + + // + // between + // + + registerCudfFunction( + prefix + "between", + [](const std::string&, const std::shared_ptr& expr) { + return std::make_shared(expr); + }, + {FunctionSignatureBuilder() + .returnType("boolean") + .argumentType("double") + .argumentType("double") + .argumentType("double") + .build(), + FunctionSignatureBuilder() + .integerVariable("p") + .integerVariable("s") + .returnType("boolean") + .argumentType("decimal(p,s)") + .argumentType("decimal(p,s)") + .argumentType("decimal(p,s)") + .build()}); + + // + // greatest & least + // + + registerCudfFunction( + prefix + "greatest", + [](const std::string&, const std::shared_ptr& expr) { + return std::make_shared( + expr, cudf::binary_operator::NULL_MAX); + }, + {FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .variableArity("double") + .build(), + FunctionSignatureBuilder() + .integerVariable("p") + .integerVariable("s") + .returnType("decimal(p,s)") + .argumentType("decimal(p,s)") + .variableArity("decimal(p,s)") + .build()}); + + registerCudfFunction( + prefix + "least", + [](const std::string&, const std::shared_ptr& expr) { + return std::make_shared( + expr, cudf::binary_operator::NULL_MIN); + }, + {FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .variableArity("double") + .build(), + FunctionSignatureBuilder() + .integerVariable("p") + .integerVariable("s") + .returnType("decimal(p,s)") + .argumentType("decimal(p,s)") + .variableArity("decimal(p,s)") + .build()}); + return true; } @@ -1091,8 +1925,7 @@ ColumnOrView FunctionExpression::eval( auto result = function_->eval(subexprResults, stream, mr); if (finalize) { - const auto requestedType = - cudf::data_type(cudf_velox::veloxToCudfTypeId(expr_->type())); + const auto requestedType = cudf_velox::veloxToCudfDataType(expr_->type()); auto resultView = asView(result); if (resultView.type() != requestedType) { return cudf::cast(resultView, requestedType, stream, mr); @@ -1125,8 +1958,8 @@ bool FunctionExpression::canEvaluate(std::shared_ptr expr) { if (srcType == nullptr || dstType == nullptr) { return false; } - auto src = cudf::data_type(cudf_velox::veloxToCudfTypeId(srcType)); - auto dst = cudf::data_type(cudf_velox::veloxToCudfTypeId(dstType)); + auto src = cudf_velox::veloxToCudfDataType(srcType); + auto dst = cudf_velox::veloxToCudfDataType(dstType); return cudf::is_supported_cast(src, dst); } diff --git a/velox/experimental/cudf/expression/JitExpression.cpp b/velox/experimental/cudf/expression/JitExpression.cpp index 6e89c4551392..3b0b5b2c64bf 100644 --- a/velox/experimental/cudf/expression/JitExpression.cpp +++ b/velox/experimental/cudf/expression/JitExpression.cpp @@ -66,7 +66,7 @@ ColumnOrView JitExpression::eval( }(); if (finalize) { const auto requestedType = - cudf::data_type(cudf_velox::veloxToCudfTypeId(expr_.expr_->type())); + cudf_velox::veloxToCudfDataType(expr_.expr_->type()); auto resultView = asView(result); if (resultView.type() != requestedType) { result = cudf::cast(resultView, requestedType, stream, mr); diff --git a/velox/experimental/cudf/expression/PrecomputeInstruction.h b/velox/experimental/cudf/expression/PrecomputeInstruction.h index bafaec01433e..4e8fb024ee10 100644 --- a/velox/experimental/cudf/expression/PrecomputeInstruction.h +++ b/velox/experimental/cudf/expression/PrecomputeInstruction.h @@ -20,6 +20,8 @@ #include +#include + namespace facebook::velox::cudf_velox { // Pre-compute instructions for the expression, @@ -47,12 +49,12 @@ struct PrecomputeInstruction { int depIndex, const std::string& name, int newIndex, - const std::vector& nestedIndices, + std::vector&& nestedIndices, const std::shared_ptr& node = nullptr) : dependent_column_index(depIndex), ins_name(name), new_column_index(newIndex), - nested_dependent_column_indices(nestedIndices), + nested_dependent_column_indices(std::move(nestedIndices)), cudf_expression(node) {} }; diff --git a/velox/experimental/cudf/expression/SubfieldFiltersToAst.cpp b/velox/experimental/cudf/expression/SubfieldFiltersToAst.cpp index c3b89d5eb343..577881a361f8 100644 --- a/velox/experimental/cudf/expression/SubfieldFiltersToAst.cpp +++ b/velox/experimental/cudf/expression/SubfieldFiltersToAst.cpp @@ -17,6 +17,7 @@ #include "velox/experimental/cudf/expression/SubfieldFiltersToAst.h" #include "velox/common/base/Exceptions.h" +#include "velox/type/DecimalUtil.h" #include #include @@ -27,6 +28,16 @@ namespace facebook::velox::cudf_velox { namespace { +std::pair getInt128BoundsForType(const TypePtr& type) { + if (type->isDecimal()) { + const auto [precision, _] = getDecimalPrecisionScale(*type); + const auto maxAbs = DecimalUtil::kPowersOfTen[precision] - 1; + return {-maxAbs, maxAbs}; + } + return { + std::numeric_limits::min(), + std::numeric_limits::max()}; +} template < typename RangeT, @@ -178,6 +189,101 @@ std::reference_wrapper buildBigintRangeExpr( } } +std::reference_wrapper buildHugeintRangeExpr( + const common::Filter& filter, + cudf::ast::tree& tree, + std::vector>& scalars, + const cudf::ast::expression& columnRef, + const TypePtr& columnTypePtr) { + using Op = cudf::ast::ast_operator; + using Operation = cudf::ast::operation; + + auto* hugeintRange = static_cast(&filter); + const auto lower = hugeintRange->lower(); + const auto upper = hugeintRange->upper(); + + const auto [minVal, maxVal] = getInt128BoundsForType(columnTypePtr); + const bool skipLowerBound = lower <= minVal; + const bool skipUpperBound = upper >= maxVal; + + auto addLiteral = [&](int128_t value) -> const cudf::ast::expression& { + variant veloxVariant = value; + const auto& literal = makeScalarAndLiteral( + columnTypePtr, veloxVariant, scalars); + return tree.push(literal); + }; + + if (lower == upper) { + if (skipLowerBound || skipUpperBound) { + return tree.push(Operation{Op::NOT_EQUAL, columnRef, columnRef}); + } + auto const& literal = addLiteral(lower); + return tree.push(Operation{Op::EQUAL, columnRef, literal}); + } + + const cudf::ast::expression* lowerExpr = nullptr; + if (!skipLowerBound) { + auto const& lowerLiteral = addLiteral(lower); + lowerExpr = + &tree.push(Operation{Op::GREATER_EQUAL, columnRef, lowerLiteral}); + } + + const cudf::ast::expression* upperExpr = nullptr; + if (!skipUpperBound) { + auto const& upperLiteral = addLiteral(upper); + upperExpr = &tree.push(Operation{Op::LESS_EQUAL, columnRef, upperLiteral}); + } + + if (lowerExpr && upperExpr) { + return tree.push(Operation{Op::NULL_LOGICAL_AND, *lowerExpr, *upperExpr}); + } else if (lowerExpr) { + return *lowerExpr; + } else if (upperExpr) { + return *upperExpr; + } + + // No bounds => pass-through filter. + return tree.push(Operation{Op::EQUAL, columnRef, columnRef}); +} + +template +const cudf::ast::expression& buildHashInListExpr( + const common::Filter& filter, + cudf::ast::tree& tree, + const cudf::ast::expression& columnRef, + std::vector>& scalars, + const TypePtr& columnTypePtr, + bool isNegated = false) { + using Op = cudf::ast::ast_operator; + using Operation = cudf::ast::operation; + + auto* valuesFilter = dynamic_cast(&filter); + VELOX_CHECK_NOT_NULL(valuesFilter, "Filter is not a hash-table list filter"); + auto const& values = valuesFilter->values(); + VELOX_CHECK(!values.empty(), "Empty List filter not supported"); + + std::vector exprVec; + for (const auto& value : values) { + variant veloxVariant = static_cast(value); + auto const& literal = tree.push( + makeScalarAndLiteral(columnTypePtr, veloxVariant, scalars)); + auto const& equalExpr = tree.push( + Operation{isNegated ? Op::NOT_EQUAL : Op::EQUAL, columnRef, literal}); + exprVec.push_back(&equalExpr); + } + + const cudf::ast::expression* result = exprVec[0]; + for (size_t i = 1; i < exprVec.size(); ++i) { + result = &tree.push( + Operation{ + isNegated ? Op::NULL_LOGICAL_AND : Op::NULL_LOGICAL_OR, + *result, + *exprVec[i]}); + } + + return *result; +} + template auto createFloatingPointRangeExpr( const common::Filter& filter, @@ -354,6 +460,21 @@ cudf::ast::expression const& createAstFromSubfieldFilter( return result.get(); } + case common::FilterKind::kHugeintRange: { + auto const& columnType = inputRowSchema->childAt(columnIndex); + auto const& expr = + buildHugeintRangeExpr(filter, tree, scalars, columnRef, columnType); + return expr.get(); + } + + case common::FilterKind::kBigintValuesUsingHashTable: { + auto const& columnType = inputRowSchema->childAt(columnIndex); + return buildHashInListExpr< + TypeKind::BIGINT, + common::BigintValuesUsingHashTable, + int64_t>(filter, tree, columnRef, scalars, columnType); + } + case common::FilterKind::kBigintValuesUsingBitmask: { auto const& columnType = inputRowSchema->childAt(columnIndex); // Dispatch by the column's integer kind and cast filter values to it. @@ -370,6 +491,14 @@ cudf::ast::expression const& createAstFromSubfieldFilter( return result.get(); } + case common::FilterKind::kHugeintValuesUsingHashTable: { + auto const& columnType = inputRowSchema->childAt(columnIndex); + return buildHashInListExpr< + TypeKind::HUGEINT, + common::HugeintValuesUsingHashTable, + int128_t>(filter, tree, columnRef, scalars, columnType); + } + case common::FilterKind::kBytesValues: { return buildInListExpr( filter, tree, columnRef, scalars, false, stream, mr); diff --git a/velox/experimental/cudf/tests/CMakeLists.txt b/velox/experimental/cudf/tests/CMakeLists.txt index 76c5fc812f0a..d6336e3a2f32 100644 --- a/velox/experimental/cudf/tests/CMakeLists.txt +++ b/velox/experimental/cudf/tests/CMakeLists.txt @@ -16,6 +16,7 @@ add_executable(velox_cudf_aggregation_test Main.cpp AggregationTest.cpp) add_executable(velox_cudf_assign_unique_id_test Main.cpp AssignUniqueIdTest.cpp) add_executable(velox_cudf_config_test Main.cpp ConfigTest.cpp) add_executable(velox_cudf_expression_selection_test Main.cpp ExpressionEvaluatorSelectionTest.cpp) +add_executable(velox_cudf_decimal_expression_test Main.cpp DecimalExpressionTest.cpp) add_executable(velox_cudf_filter_project_test Main.cpp FilterProjectTest.cpp) add_executable(velox_cudf_hash_join_test HashJoinTest.cpp Main.cpp) add_executable(velox_cudf_limit_test Main.cpp LimitTest.cpp) @@ -63,6 +64,12 @@ add_test( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} ) +add_test( + NAME velox_cudf_decimal_expression_test + COMMAND velox_cudf_decimal_expression_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + add_test( NAME velox_cudf_filter_project_test COMMAND velox_cudf_filter_project_test @@ -157,6 +164,7 @@ set_tests_properties( velox_cudf_expression_selection_test PROPERTIES LABELS cuda_driver TIMEOUT 3000 ) +set_tests_properties(velox_cudf_decimal_expression_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_filter_project_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_hash_join_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) set_tests_properties(velox_cudf_limit_test PROPERTIES LABELS cuda_driver TIMEOUT 3000) @@ -215,6 +223,17 @@ target_link_libraries( gtest_main ) +target_link_libraries( + velox_cudf_decimal_expression_test + velox_cudf_exec + velox_exec + velox_exec_test_lib + velox_functions_test_lib + velox_test_util + gtest + gtest_main +) + target_link_libraries( velox_cudf_filter_project_test velox_cudf_exec diff --git a/velox/experimental/cudf/tests/DecimalExpressionTest.cpp b/velox/experimental/cudf/tests/DecimalExpressionTest.cpp new file mode 100644 index 000000000000..2cc9dde7a164 --- /dev/null +++ b/velox/experimental/cudf/tests/DecimalExpressionTest.cpp @@ -0,0 +1,915 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/experimental/cudf/CudfConfig.h" +#include "velox/experimental/cudf/exec/ToCudf.h" + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/file/FileSystems.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/type/DecimalUtil.h" + +#include + +namespace facebook::velox::cudf_velox { +namespace { + +class CudfDecimalTest : public exec::test::OperatorTestBase { + protected: + void SetUp() override { + exec::test::OperatorTestBase::SetUp(); + filesystems::registerLocalFileSystem(); + parse::registerTypeResolver(); + functions::prestosql::registerAllScalarFunctions(); + CudfConfig::getInstance().allowCpuFallback = false; + // Ensure a CUDA device is selected and initialized (RMM asserts otherwise). + int deviceCount = 0; + auto status = cudaGetDeviceCount(&deviceCount); + if (status != cudaSuccess) { + GTEST_SKIP() << "cudaGetDeviceCount failed: " << static_cast(status) + << " (" << cudaGetErrorString(status) << ")"; + } + if (deviceCount == 0) { + GTEST_SKIP() << "No CUDA devices visible (check CUDA_VISIBLE_DEVICES)"; + } + VELOX_CHECK_EQ(0, static_cast(cudaSetDevice(0))); + VELOX_CHECK_EQ(0, static_cast(cudaFree(nullptr))); + registerCudf(); + } + + void TearDown() override { + unregisterCudf(); + exec::test::OperatorTestBase::TearDown(); + } +}; + +TEST_F(CudfDecimalTest, decimal64And128ArithmeticAndComparison) { + // Short decimal (64-bit) uses scale 2, long decimal (128-bit) uses scale 10. + auto rowType = ROW({ + {"d64_a", DECIMAL(12, 2)}, + {"d64_b", DECIMAL(12, 2)}, + {"d128_a", DECIMAL(38, 10)}, + {"d128_b", DECIMAL(38, 10)}, + }); + + // Raw values are already scaled. + auto input = makeRowVector( + {"d64_a", "d64_b", "d128_a", "d128_b"}, + { + makeFlatVector( + {12345, -2500, 999999}, + DECIMAL(12, 2)), // 123.45, -25.00, 9999.99 + makeFlatVector( + {6789, 1500, -50000}, DECIMAL(12, 2)), // 67.89, 15.00, -500.00 + makeFlatVector( + { + static_cast(123'456'789'012), // 12.3456789012 + static_cast(-987'654'321'098), // -98.7654321098 + static_cast(555'000'000'000), // 55.5000000000 + }, + DECIMAL(38, 10)), + makeFlatVector( + { + static_cast(222'222'222'222), // 22.2222222222 + static_cast(333'333'333'333), // 33.3333333333 + static_cast(-111'111'111'111), // -11.1111111111 + }, + DECIMAL(38, 10)), + }); + + std::vector vectors = {input}; + createDuckDbTable(vectors); + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({ + "d64_a + d64_b AS sum64", + "d64_a - d64_b AS diff64", + "d64_a > d64_b AS gt64", + "d128_a + d128_b AS sum128", + "d128_a - d128_b AS diff128", + "d128_a < d128_b AS lt128", + }) + .planNode(); + + facebook::velox::exec::test::AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT d64_a + d64_b AS sum64, " + "d64_a - d64_b AS diff64, " + "d64_a > d64_b AS gt64, " + "d128_a + d128_b AS sum128, " + "d128_a - d128_b AS diff128, " + "d128_a < d128_b AS lt128 " + "FROM tmp"); +} + +TEST_F(CudfDecimalTest, decimalIdentityProjection64And128) { + auto rowType = ROW({ + {"d64", DECIMAL(12, 2)}, + {"d128", DECIMAL(38, 10)}, + }); + + // Max absolute raw value for DECIMAL(38,10) is 10^28 - 1 (28 integer digits). + const int128_t max38p10 = facebook::velox::DecimalUtil::kPowersOfTen[28] - 1; + + auto input = makeRowVector( + {"d64", "d128"}, + { + makeFlatVector( + { + // Near max/min for DECIMAL(12,2): +/- 99,999,999,999.99 + 9'999'999'999'999, // 99,999,999,999.99 + -9'999'999'999'999, // -99,999,999,999.99 + // Mid-range values + 123'45, // 1,23.45 + -2'500, // -25.00 + 999'999, // 9,999.99 + -1'000, // -10.00 + 0, + 1, // 0.01 + -1, // -0.01 + }, + DECIMAL(12, 2)), + makeFlatVector( + { + // Near max/min for DECIMAL(38,10): +/- (10^28 - 1) with scale + // 10 + max38p10, + -max38p10, + // Mid-range values + static_cast(123'456'789'012), // 12.3456789012 + static_cast(-987'654'321'098), // -98.7654321098 + static_cast(555'000'000'000), // 55.5000000000 + static_cast(44'388'888'889), // 4.4388888889 + static_cast(1), // 0.0000000001 + static_cast(-1), // -0.0000000001 + static_cast(0), + }, + DECIMAL(38, 10)), + }); + + std::vector vectors = {input}; + createDuckDbTable(vectors); + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"d64", "d128"}) + .planNode(); + + facebook::velox::exec::test::AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT d64, d128 FROM tmp"); +} + +TEST_F(CudfDecimalTest, decimalAddition64And128) { + auto rowType = ROW({ + {"d64_a", DECIMAL(12, 2)}, + {"d64_b", DECIMAL(12, 2)}, + {"d128_a", DECIMAL(38, 10)}, + {"d128_b", DECIMAL(38, 10)}, + }); + + const int128_t max38p10 = facebook::velox::DecimalUtil::kPowersOfTen[28] - 1; + const int128_t min38p10 = -max38p10; + + auto input = makeRowVector( + {"d64_a", "d64_b", "d128_a", "d128_b"}, + { + makeFlatVector( + { + 9'999'999'999'99, // 9,999,999,999.99 (near max for 12,2) + -9'999'999'999'99, // -9,999,999,999.99 + 123'45, // 1,23.45 + -2'500, // -25.00 + 0, + }, + DECIMAL(12, 2)), + makeFlatVector( + { + 1, // 0.01 + -1, // -0.01 + 9'999, // 99.99 + -100, // -1.00 + 50, // 0.50 + }, + DECIMAL(12, 2)), + makeFlatVector( + { + max38p10, + min38p10, + static_cast(123'456'789'012), // 12.3456789012 + static_cast(-987'654'321'098), // -98.7654321098 + static_cast(0), + }, + DECIMAL(38, 10)), + makeFlatVector( + { + static_cast(1), // 0.0000000001 + static_cast(-1), // -0.0000000001 + static_cast(44'388'888'889), // 4.4388888889 + static_cast(555'000'000'000), // 55.5000000000 + max38p10, + }, + DECIMAL(38, 10)), + }); + + std::vector vectors = {input}; + createDuckDbTable(vectors); + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({ + "d64_a + d64_b AS sum64", + "d128_a + d128_b AS sum128", + }) + .planNode(); + + facebook::velox::exec::test::AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT d64_a + d64_b AS sum64, d128_a + d128_b AS sum128 FROM tmp"); +} + +TEST_F(CudfDecimalTest, decimalMultiplyPromotesToLong) { + // Two short decimals whose product requires long decimal precision. + auto rowType = ROW({ + {"a", DECIMAL(10, 0)}, + {"b", DECIMAL(10, 0)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector( + {9'999'999'999, 1'234'567'890, -2'000'000'000}, DECIMAL(10, 0)), + makeFlatVector({9'999'999'999, -2, 4}, DECIMAL(10, 0)), + }); + + std::vector vectors = {input}; + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"a * b AS prod"}) + .planNode(); + + const int128_t expected0 = static_cast(9'999'999'999LL) * + static_cast(9'999'999'999LL); + auto expected = makeRowVector( + {"prod"}, + {makeFlatVector( + {expected0, + static_cast(-2'469'135'780LL), + static_cast(-8'000'000'000LL)}, + DECIMAL(20, 0))}); + + // CPU (no cuDF adapter registered). + unregisterCudf(); + auto cpuResult = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + registerCudf(); + + // GPU (enable cuDF, no fallback). + auto gpuResult = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + + // Verify promotion to long decimal and exact results on CPU/GPU. + ASSERT_TRUE(cpuResult->childAt(0)->type()->isLongDecimal()); + ASSERT_TRUE(gpuResult->childAt(0)->type()->isLongDecimal()); + facebook::velox::test::assertEqualVectors(expected, cpuResult); + facebook::velox::test::assertEqualVectors(expected, gpuResult); +} + +TEST_F(CudfDecimalTest, decimalAddPromotesToLong) { + // Two short decimals whose sum requires long decimal precision. + auto rowType = ROW({ + {"a", DECIMAL(18, 0)}, + {"b", DECIMAL(18, 0)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector( + {999'999'999'999'999'999LL, + -900'000'000'000'000'000LL, + 123'456'789'012'345'678LL}, + DECIMAL(18, 0)), + makeFlatVector( + {999'999'999'999'999'999LL, + -900'000'000'000'000'000LL, + -123'456'789'012'345'678LL}, + DECIMAL(18, 0)), + }); + + std::vector vectors = {input}; + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"a + b AS sum"}) + .planNode(); + + auto expected = makeRowVector( + {"sum"}, + {makeFlatVector( + {static_cast(1'999'999'999'999'999'998LL), + static_cast(-1'800'000'000'000'000'000LL), + static_cast(0)}, + DECIMAL(19, 0))}); + + // CPU (no cuDF adapter registered). + unregisterCudf(); + auto cpuResult = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + registerCudf(); + + // GPU (cuDF enabled). + auto gpuResult = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + + ASSERT_TRUE(cpuResult->childAt(0)->type()->isLongDecimal()); + ASSERT_TRUE(gpuResult->childAt(0)->type()->isLongDecimal()); + facebook::velox::test::assertEqualVectors(expected, cpuResult); + facebook::velox::test::assertEqualVectors(expected, gpuResult); +} + +TEST_F(CudfDecimalTest, decimalAddDifferentScales) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 1)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector({12345, -2500, 100}, DECIMAL(10, 2)), + makeFlatVector({10, -25, 3}, DECIMAL(10, 1)), + }); + + std::vector vectors = {input}; + createDuckDbTable(vectors); + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"a + b AS sum"}) + .planNode(); + + facebook::velox::exec::test::AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT a + b AS sum FROM tmp"); +} + +TEST_F(CudfDecimalTest, decimalSubtractDifferentScales) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 1)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector({12345, -2500, 100}, DECIMAL(10, 2)), + makeFlatVector({10, -25, 3}, DECIMAL(10, 1)), + }); + + std::vector vectors = {input}; + createDuckDbTable(vectors); + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"a - b AS diff"}) + .planNode(); + + facebook::velox::exec::test::AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT a - b AS diff FROM tmp"); +} + +TEST_F(CudfDecimalTest, decimalMultiplyDifferentScales) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 1)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector({12345, -2500, 100}, DECIMAL(10, 2)), + makeFlatVector({10, -25, 3}, DECIMAL(10, 1)), + }); + + std::vector vectors = {input}; + createDuckDbTable(vectors); + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"a * b AS prod"}) + .planNode(); + + auto expected = makeRowVector( + {"prod"}, + {makeFlatVector( + { + static_cast(12345) * static_cast(10), + static_cast(-2500) * static_cast(-25), + static_cast(100) * static_cast(3), + }, + DECIMAL(20, 3))}); + + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(CudfDecimalTest, decimalCompareDecimalDecimal) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector({120, -250, 10}, DECIMAL(10, 2)), + makeFlatVector({110, -250, 30}, DECIMAL(10, 2)), + }); + + std::vector vectors = {input}; + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({ + "a = b AS eq", + "a != b AS neq", + "a < b AS lt", + "a <= b AS lte", + "a > b AS gt", + "a >= b AS gte", + }) + .planNode(); + + auto expected = makeRowVector( + {"eq", "neq", "lt", "lte", "gt", "gte"}, + { + makeNullableFlatVector({false, true, false}, BOOLEAN()), + makeNullableFlatVector({true, false, true}, BOOLEAN()), + makeNullableFlatVector({false, false, true}, BOOLEAN()), + makeNullableFlatVector({false, true, true}, BOOLEAN()), + makeNullableFlatVector({true, false, false}, BOOLEAN()), + makeNullableFlatVector({true, true, false}, BOOLEAN()), + }); + + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(CudfDecimalTest, decimalCompareWithLiteral) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"a"}, + {makeNullableFlatVector( + {120, 110, 130, std::nullopt}, DECIMAL(10, 2))}); + + std::vector vectors = {input}; + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({ + "a = CAST('1.20' AS DECIMAL(10, 2)) AS eq_r", + "a != CAST('1.20' AS DECIMAL(10, 2)) AS neq_r", + "a < CAST('1.20' AS DECIMAL(10, 2)) AS lt_r", + "a <= CAST('1.20' AS DECIMAL(10, 2)) AS lte_r", + "a > CAST('1.20' AS DECIMAL(10, 2)) AS gt_r", + "a >= CAST('1.20' AS DECIMAL(10, 2)) AS gte_r", + "CAST('1.20' AS DECIMAL(10, 2)) = a AS eq_l", + "CAST('1.20' AS DECIMAL(10, 2)) != a AS neq_l", + "CAST('1.20' AS DECIMAL(10, 2)) < a AS lt_l", + "CAST('1.20' AS DECIMAL(10, 2)) <= a AS lte_l", + "CAST('1.20' AS DECIMAL(10, 2)) > a AS gt_l", + "CAST('1.20' AS DECIMAL(10, 2)) >= a AS gte_l", + }) + .planNode(); + + auto expected = makeRowVector( + {"eq_r", + "neq_r", + "lt_r", + "lte_r", + "gt_r", + "gte_r", + "eq_l", + "neq_l", + "lt_l", + "lte_l", + "gt_l", + "gte_l"}, + { + makeNullableFlatVector( + {true, false, false, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {false, true, true, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {false, true, false, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {true, true, false, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {false, false, true, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {true, false, true, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {true, false, false, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {false, true, true, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {false, false, true, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {true, false, true, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {false, true, false, std::nullopt}, BOOLEAN()), + makeNullableFlatVector( + {true, true, false, std::nullopt}, BOOLEAN()), + }); + + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(CudfDecimalTest, decimalLogicalAndOrProject) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector({100, 200, 300, 400, 500}, DECIMAL(10, 2)), + makeFlatVector({250, 150, 350, 100, 500}, DECIMAL(10, 2)), + }); + + std::vector vectors = {input}; + createDuckDbTable(vectors); + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({ + "(a > CAST('1.00' AS DECIMAL(10, 2)) " + "AND b < CAST('2.00' AS DECIMAL(10, 2))) AS and2", + "(a > CAST('1.00' AS DECIMAL(10, 2)) " + "AND b < CAST('3.00' AS DECIMAL(10, 2)) " + "AND a < CAST('4.00' AS DECIMAL(10, 2))) AS and3", + "(a < CAST('1.00' AS DECIMAL(10, 2)) " + "OR b > CAST('3.00' AS DECIMAL(10, 2)) " + "OR a = CAST('2.00' AS DECIMAL(10, 2))) AS or3", + }) + .planNode(); + + facebook::velox::exec::test::AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT " + "(a > CAST('1.00' AS DECIMAL(10, 2)) " + " AND b < CAST('2.00' AS DECIMAL(10, 2))) AS and2, " + "(a > CAST('1.00' AS DECIMAL(10, 2)) " + " AND b < CAST('3.00' AS DECIMAL(10, 2)) " + " AND a < CAST('4.00' AS DECIMAL(10, 2))) AS and3, " + "(a < CAST('1.00' AS DECIMAL(10, 2)) " + " OR b > CAST('3.00' AS DECIMAL(10, 2)) " + " OR a = CAST('2.00' AS DECIMAL(10, 2))) AS or3 " + "FROM tmp"); +} + +TEST_F(CudfDecimalTest, decimalLogicalAndOrFilter) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector({100, 200, 300, 400, 500}, DECIMAL(10, 2)), + makeFlatVector({250, 150, 350, 100, 500}, DECIMAL(10, 2)), + }); + + std::vector vectors = {input}; + createDuckDbTable(vectors); + + const std::string filter = + "((a between CAST('1.50' AS DECIMAL(10, 2)) AND " + "CAST('3.50' AS DECIMAL(10, 2)) " + "AND b < CAST('3.00' AS DECIMAL(10, 2)) " + "AND a > CAST('1.00' AS DECIMAL(10, 2))) " + "OR a = CAST('4.00' AS DECIMAL(10, 2)) " + "OR a = CAST('5.00' AS DECIMAL(10, 2)))"; + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .filter(filter) + .project({"a", "b"}) + .planNode(); + + facebook::velox::exec::test::AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT a, b FROM tmp WHERE " + "((a between CAST('1.50' AS DECIMAL(10, 2)) AND " + "CAST('3.50' AS DECIMAL(10, 2)) " + "AND b < CAST('3.00' AS DECIMAL(10, 2)) " + "AND a > CAST('1.00' AS DECIMAL(10, 2))) " + "OR a = CAST('4.00' AS DECIMAL(10, 2)) " + "OR a = CAST('5.00' AS DECIMAL(10, 2)))"); +} + +TEST_F(CudfDecimalTest, decimalBinaryNullPropagation) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeNullableFlatVector( + {100, std::nullopt, 300, std::nullopt}, DECIMAL(10, 2)), + makeNullableFlatVector( + {200, 200, std::nullopt, std::nullopt}, DECIMAL(10, 2)), + }); + + std::vector vectors = {input}; + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({ + "a + b AS sum", + "a / b AS div", + "a = b AS eq", + }) + .planNode(); + + auto expected = makeRowVector( + {"sum", "div", "eq"}, + { + makeNullableFlatVector( + {300, std::nullopt, std::nullopt, std::nullopt}, DECIMAL(11, 2)), + makeNullableFlatVector( + {50, std::nullopt, std::nullopt, std::nullopt}, DECIMAL(12, 2)), + makeNullableFlatVector( + {false, std::nullopt, std::nullopt, std::nullopt}, BOOLEAN()), + }); + + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(CudfDecimalTest, decimalMultiplyDoubleCast) { + auto rowType = ROW({ + {"d", DECIMAL(10, 2)}, + {"x", DOUBLE()}, + }); + + auto input = makeRowVector( + {"d", "x"}, + { + makeFlatVector({125, -250, 50}, DECIMAL(10, 2)), + makeFlatVector({2.0, -4.0, 0.0}), + }); + + std::vector vectors = {input}; + + auto expected = + makeRowVector({"prod"}, {makeFlatVector({2.5, 10.0, 0.0})}); + + auto runAndAssert = [&](bool useCudf) { + if (!useCudf) { + unregisterCudf(); + } + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"cast(d as double) * x AS prod"}) + .planNode(); + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults( + pool()); + facebook::velox::test::assertEqualVectors(expected, result); + }; + + runAndAssert(true); + runAndAssert(false); +} + +TEST_F(CudfDecimalTest, decimalMultiplyDoubleCastRight) { + auto rowType = ROW({ + {"d", DECIMAL(10, 2)}, + {"x", DOUBLE()}, + }); + + auto input = makeRowVector( + {"d", "x"}, + { + makeFlatVector({125, -250, 50}, DECIMAL(10, 2)), + makeFlatVector({2.0, -4.0, 0.0}), + }); + + std::vector vectors = {input}; + + auto expected = + makeRowVector({"prod"}, {makeFlatVector({2.5, 10.0, 0.0})}); + + auto runAndAssert = [&](bool useCudf) { + if (!useCudf) { + unregisterCudf(); + } + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"x * cast(d as double) AS prod"}) + .planNode(); + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults( + pool()); + facebook::velox::test::assertEqualVectors(expected, result); + }; + + runAndAssert(true); + runAndAssert(false); +} + +TEST_F(CudfDecimalTest, decimalAstRecursiveMixedScaleAdd) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 1)}, + {"x", DOUBLE()}, + }); + + auto input = makeRowVector( + {"a", "b", "x"}, + { + makeFlatVector({12345, -2500, 100}, DECIMAL(10, 2)), + makeFlatVector({10, -25, 3}, DECIMAL(10, 1)), + makeFlatVector({2.0, -4.0, 0.0}), + }); + + std::vector vectors = {input}; + + auto expected = + makeRowVector({"prod"}, {makeFlatVector({126.45, -31.5, 1.3})}); + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"cast(a + b as double) + x AS prod"}) + .planNode(); + + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(CudfDecimalTest, decimalCastToDoubleProjection) { + auto rowType = ROW({ + {"d", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"d"}, {makeFlatVector({125, -250, 50}, DECIMAL(10, 2))}); + + std::vector vectors = {input}; + + auto expected = + makeRowVector({"d_double"}, {makeFlatVector({1.25, -2.5, 0.5})}); + + auto runAndAssert = [&](bool useCudf) { + if (!useCudf) { + unregisterCudf(); + } + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"cast(d as double) AS d_double"}) + .planNode(); + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults( + pool()); + facebook::velox::test::assertEqualVectors(expected, result); + }; + + runAndAssert(true); + runAndAssert(false); +} + +TEST_F(CudfDecimalTest, decimalCastToRealProjection) { + auto rowType = ROW({ + {"d", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"d"}, {makeFlatVector({125, -250, 50}, DECIMAL(10, 2))}); + + std::vector vectors = {input}; + + auto expected = + makeRowVector({"d_real"}, {makeFlatVector({1.25f, -2.5f, 0.5f})}); + + auto runAndAssert = [&](bool useCudf) { + if (!useCudf) { + unregisterCudf(); + } + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"cast(d as real) AS d_real"}) + .planNode(); + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults( + pool()); + facebook::velox::test::assertEqualVectors(expected, result); + }; + + runAndAssert(true); + runAndAssert(false); +} + +TEST_F(CudfDecimalTest, decimalDivideRounds) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector({200, 100, -200}, DECIMAL(10, 2)), + makeFlatVector({300, 300, 300}, DECIMAL(10, 2)), + }); + + std::vector vectors = {input}; + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"a / b AS div"}) + .planNode(); + + auto computeDiv = [](int64_t a, int64_t b) { + __int128_t out = 0; + facebook::velox::DecimalUtil:: + divideWithRoundUp<__int128_t, __int128_t, __int128_t>( + out, + static_cast<__int128_t>(a), + static_cast<__int128_t>(b), + false, + 2, + 0); + return static_cast(out); + }; + + auto expected = makeRowVector( + {"div"}, + {makeFlatVector( + {computeDiv(200, 300), computeDiv(100, 300), computeDiv(-200, 300)}, + DECIMAL(12, 2))}); + + auto result = + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()); + facebook::velox::test::assertEqualVectors(expected, result); +} + +TEST_F(CudfDecimalTest, decimalDivideByZero) { + auto rowType = ROW({ + {"a", DECIMAL(10, 2)}, + {"b", DECIMAL(10, 2)}, + }); + + auto input = makeRowVector( + {"a", "b"}, + { + makeFlatVector({100}, DECIMAL(10, 2)), + makeFlatVector({0}, DECIMAL(10, 2)), + }); + + std::vector vectors = {input}; + + auto plan = exec::test::PlanBuilder() + .values(vectors) + .project({"a / b AS div"}) + .planNode(); + + VELOX_ASSERT_USER_THROW( + facebook::velox::exec::test::AssertQueryBuilder(plan).copyResults(pool()), + "Division by zero"); +} + +} // namespace +} // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/tests/FilterProjectTest.cpp b/velox/experimental/cudf/tests/FilterProjectTest.cpp index 3b43fc844278..2d48e0d5224b 100644 --- a/velox/experimental/cudf/tests/FilterProjectTest.cpp +++ b/velox/experimental/cudf/tests/FilterProjectTest.cpp @@ -686,14 +686,25 @@ TEST_F(CudfFilterProjectTest, round) { AssertQueryBuilder(plan).assertResults(expected); } -// TODO (dm): Enable after adding decimal support to velox-cudf -TEST_F(CudfFilterProjectTest, DISABLED_roundDecimal) { +TEST_F(CudfFilterProjectTest, roundDecimal) { parse::ParseOptions options; options.parseIntegerAsBigint = false; + // Note that the underlying cudf::round_decimal function returns + // a value with the specified scale, and rounds the internal integer + // value accordingly, e.g. rounding 41.2389 to 2 decimal places + // results in an internal integer value of 4124 with a scale of 2. + // + // When the specified scale is non-zero, Velox inserts extra casts + // to restore the original scale. However, when the specified scale + // is zero, Velox does NOT do that, and the result has a scale of 0. + + // Input values 41.2389 and -45.6789 as DECIMAL(10, 4). auto decimalData = makeRowVector( {makeFlatVector({412389, -456789}, DECIMAL(10, 4))}); + // Round to 2 decimal places. + // Expected values are 41.24 and -45.68 as DECIMAL(10, 4). auto plan = PlanBuilder() .setParseOptions(options) .values({decimalData}) @@ -703,15 +714,19 @@ TEST_F(CudfFilterProjectTest, DISABLED_roundDecimal) { {makeFlatVector({412400, -456800}, DECIMAL(10, 4))}); AssertQueryBuilder(plan).assertResults(decimalExpected); + // Round to 0 decimal places. + // Expected values are 41.0 and -46.0 as DECIMAL(10, 0). plan = PlanBuilder() .setParseOptions(options) .values({decimalData}) .project({"round(c0) as c1"}) .planNode(); - decimalExpected = makeRowVector( - {makeFlatVector({410000, -460000}, DECIMAL(10, 4))}); + decimalExpected = + makeRowVector({makeFlatVector({41, -46}, DECIMAL(10, 0))}); AssertQueryBuilder(plan).assertResults(decimalExpected); + // Round to -1 decimal places. + // Expected values are 40.0 and -50.0 as DECIMAL(10, 4). plan = PlanBuilder() .setParseOptions(options) .values({decimalData}) diff --git a/velox/experimental/cudf/tests/SubfieldFilterAstTest.cpp b/velox/experimental/cudf/tests/SubfieldFilterAstTest.cpp index e4bb1af4421b..b2f613ada896 100644 --- a/velox/experimental/cudf/tests/SubfieldFilterAstTest.cpp +++ b/velox/experimental/cudf/tests/SubfieldFilterAstTest.cpp @@ -142,6 +142,11 @@ class SubfieldFilterAstTest : public OperatorTestBase { veloxExpected = filter.testBool(v); break; } + case TypeKind::HUGEINT: { + auto v = fieldVec->asFlatVector()->valueAt(i); + veloxExpected = filter.testInt128(v); + break; + } case TypeKind::VARCHAR: { auto sv = fieldVec->asFlatVector()->valueAt(i); veloxExpected = filter.testBytes(sv.data(), sv.size()); @@ -475,6 +480,42 @@ TEST_F(SubfieldFilterAstTest, SmallIntTypeBounds) { testFilterExecution(rowType, columnName, *filter, vec, expr); } +TEST_F(SubfieldFilterAstTest, DecimalRange) { + const std::string columnName = "c0"; + auto rowType = ROW({{columnName, DECIMAL(20, 2)}}); + // Range [1.23, 4.56] encoded as unscaled integer values. + auto filter = std::make_unique( + int128_t{123}, int128_t{456}, /*nullAllowed*/ false); + + common::Subfield subfield(columnName); + cudf::ast::tree tree; + std::vector> scalars; + const auto& expr = + createAstFromSubfieldFilter(subfield, *filter, tree, scalars, rowType); + + EXPECT_GT(tree.size(), 0UL); + auto vec = makeTestVector(rowType, 100); + testFilterExecution(rowType, columnName, *filter, vec, expr); +} + +TEST_F(SubfieldFilterAstTest, DecimalInList) { + const std::string columnName = "c0"; + auto rowType = ROW({{columnName, DECIMAL(20, 2)}}); + // Values [1.23, 4.56] encoded as unscaled integer values. + std::vector values = {int128_t{123}, int128_t{456}}; + auto filter = common::createHugeintValues(values, /*nullAllowed*/ false); + + common::Subfield subfield(columnName); + cudf::ast::tree tree; + std::vector> scalars; + const auto& expr = + createAstFromSubfieldFilter(subfield, *filter, tree, scalars, rowType); + + EXPECT_GT(tree.size(), 0UL); + auto vec = makeTestVector(rowType, 100); + testFilterExecution(rowType, columnName, *filter, vec, expr); +} + TEST_F(SubfieldFilterAstTest, EmptyInListHandling) { auto rowType = ROW({{"c0", BIGINT()}}); std::vector emptyVals = {}; diff --git a/velox/experimental/cudf/tests/TableScanTest.cpp b/velox/experimental/cudf/tests/TableScanTest.cpp index cfcdb652f282..4e2815c6236f 100644 --- a/velox/experimental/cudf/tests/TableScanTest.cpp +++ b/velox/experimental/cudf/tests/TableScanTest.cpp @@ -19,6 +19,7 @@ #include "velox/experimental/cudf/connectors/hive/CudfHiveConnectorSplit.h" #include "velox/experimental/cudf/connectors/hive/CudfHiveDataSource.h" #include "velox/experimental/cudf/connectors/hive/CudfHiveTableHandle.h" +#include "velox/experimental/cudf/expression/SubfieldFiltersToAst.h" #include "velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.h" #include "velox/common/base/Fs.h" @@ -42,6 +43,8 @@ #include "velox/type/Type.h" #include "velox/type/tests/SubfieldFiltersBuilder.h" +#include + #include using namespace facebook::velox; @@ -56,6 +59,37 @@ using namespace facebook::velox::cudf_velox; using namespace facebook::velox::cudf_velox::exec; using namespace facebook::velox::cudf_velox::exec::test; +namespace { +struct StatsFilterMetrics { + cudf::size_type inputRowGroups{0}; + std::optional rowGroupsAfterStats; + cudf::size_type outputRows{0}; +}; + +StatsFilterMetrics readParquetWithStatsFilter( + const std::string& filePath, + const RowTypePtr& rowType, + const common::SubfieldFilters& filters, + bool useJitFilter) { + cudf::ast::tree tree; + std::vector> scalars; + auto const& expr = + createAstFromSubfieldFilters(filters, tree, scalars, rowType); + + auto options = + cudf::io::parquet_reader_options::builder(cudf::io::source_info(filePath)) + .use_jit_filter(useJitFilter) + .build(); + options.set_filter(expr); + + auto result = cudf::io::read_parquet(options); + return { + result.metadata.num_input_row_groups, + result.metadata.num_row_groups_after_stats_filter, + result.tbl->num_rows()}; +} +} // namespace + class TableScanTest : public virtual CudfHiveConnectorTestBase { protected: void SetUp() override { @@ -526,6 +560,164 @@ TEST_F(TableScanTest, filterPushdown) { #endif } +// Disable this test and the one below for now, pending a CUDF fix. +// simoneves 2/25/26 +// @TODO simoneves/mattgara re-enable once fixed. + +TEST_F(TableScanTest, DISABLED_decimalFilterPushdown) { + auto rowType = ROW({"c0", "c1"}, {DECIMAL(12, 2), DECIMAL(20, 2)}); + + auto vector = makeRowVector( + {"c0", "c1"}, + { + makeFlatVector( + {123, 500, -250, 300, 400, 200}, DECIMAL(12, 2)), + makeFlatVector( + {int128_t{200}, + int128_t{200}, + int128_t{700}, + int128_t{700}, + int128_t{900}, + int128_t{-100}}, + DECIMAL(20, 2)), + }); + + std::vector vectors = {vector}; + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + createDuckDbTable(vectors); + + // c0 between 1.00 and 4.00 and c1 in (2.00, 7.00) + common::SubfieldFilters subfieldFilters = + common::test::SubfieldFiltersBuilder() + .add( + "c0", + std::make_unique( + int64_t{100}, int64_t{400}, /*nullAllowed*/ false)) + .add( + "c1", + common::createHugeintValues( + {int128_t{200}, int128_t{700}}, /*nullAllowed*/ false)) + .build(); + + auto tableHandle = makeTableHandle( + "parquet_table", rowType, std::move(subfieldFilters), nullptr); + + auto assignments = + facebook::velox::exec::test::HiveConnectorTestBase::allRegularColumns( + rowType); + + auto plan = PlanBuilder() + .startTableScan() + .outputType(rowType) + .tableHandle(tableHandle) + .assignments(assignments) + .endTableScan() + .planNode(); + + assertQuery( + plan, + {filePath}, + "SELECT c0, c1 FROM tmp " + "WHERE c0 BETWEEN CAST('1.00' AS DECIMAL(12, 2)) " + "AND CAST('4.00' AS DECIMAL(12, 2)) " + "AND c1 IN (CAST('2.00' AS DECIMAL(20, 2)), " + "CAST('7.00' AS DECIMAL(20, 2)))"); +} + +TEST_F(TableScanTest, DISABLED_decimalStatsFilterIoPruning) { + auto rowType = ROW({"c0", "c1"}, {DECIMAL(12, 2), DECIMAL(20, 2)}); + auto vec0 = makeRowVector( + {"c0", "c1"}, + {makeFlatVector({100, 200}, DECIMAL(12, 2)), + makeFlatVector( + {int128_t{1000}, int128_t{2000}}, DECIMAL(20, 2))}); + auto vec1 = makeRowVector( + {"c0", "c1"}, + {makeFlatVector({300, 400}, DECIMAL(12, 2)), + makeFlatVector( + {int128_t{3000}, int128_t{4000}}, DECIMAL(20, 2))}); + auto vec2 = makeRowVector( + {"c0", "c1"}, + {makeFlatVector({500, 600}, DECIMAL(12, 2)), + makeFlatVector( + {int128_t{5000}, int128_t{6000}}, DECIMAL(20, 2))}); + + std::vector vectors = {vec0, vec1, vec2}; + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + + common::SubfieldFilters filters = + common::test::SubfieldFiltersBuilder() + .add( + "c0", + std::make_unique( + int64_t{300}, int64_t{400}, /*nullAllowed*/ false)) + .add( + "c1", + std::make_unique( + int128_t{3000}, int128_t{4000}, /*nullAllowed*/ false)) + .build(); + + auto metrics = readParquetWithStatsFilter( + filePath->getPath(), rowType, filters, /*useJitFilter*/ true); + EXPECT_EQ(metrics.inputRowGroups, 3); + ASSERT_TRUE(metrics.rowGroupsAfterStats.has_value()); + EXPECT_EQ(metrics.rowGroupsAfterStats.value(), 1); + EXPECT_EQ(metrics.outputRows, 2); +} + +TEST_F(TableScanTest, doubleStatsFilterIoPruning) { + auto rowType = ROW({"c0", "c1"}, {DOUBLE(), DOUBLE()}); + auto vec0 = makeRowVector( + {"c0", "c1"}, + {makeFlatVector({1.0, 2.0}), + makeFlatVector({10.0, 20.0})}); + auto vec1 = makeRowVector( + {"c0", "c1"}, + {makeFlatVector({3.0, 4.0}), + makeFlatVector({30.0, 40.0})}); + auto vec2 = makeRowVector( + {"c0", "c1"}, + {makeFlatVector({5.0, 6.0}), + makeFlatVector({50.0, 60.0})}); + + std::vector vectors = {vec0, vec1, vec2}; + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + + common::SubfieldFilters filters = + common::test::SubfieldFiltersBuilder() + .add( + "c0", + std::make_unique( + 3.0, + /*lowerUnbounded*/ false, + /*lowerExclusive*/ false, + 4.0, + /*upperUnbounded*/ false, + /*upperExclusive*/ false, + /*nullAllowed*/ false)) + .add( + "c1", + std::make_unique( + 30.0, + /*lowerUnbounded*/ false, + /*lowerExclusive*/ false, + 40.0, + /*upperUnbounded*/ false, + /*upperExclusive*/ false, + /*nullAllowed*/ false)) + .build(); + + auto metrics = readParquetWithStatsFilter( + filePath->getPath(), rowType, filters, /*useJitFilter*/ true); + EXPECT_EQ(metrics.inputRowGroups, 3); + ASSERT_TRUE(metrics.rowGroupsAfterStats.has_value()); + EXPECT_EQ(metrics.rowGroupsAfterStats.value(), 1); + EXPECT_EQ(metrics.outputRows, 2); +} + TEST_F(TableScanTest, splitOffsetAndLength) { auto vectors = makeVectors(10, 1'000); auto filePath = TempFilePath::create(); diff --git a/velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.cpp b/velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.cpp index 86522942de97..5a26ecd97afa 100644 --- a/velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.cpp +++ b/velox/experimental/cudf/tests/utils/CudfHiveConnectorTestBase.cpp @@ -319,7 +319,7 @@ CudfHiveConnectorTestBase::makeCudfHiveInsertTableHandle( std::make_shared( tableColumnNames.at(i), tableColumnTypes.at(i), - cudf::data_type{veloxToCudfTypeId(tableColumnTypes.at(i))})); + veloxToCudfDataType(tableColumnTypes.at(i)))); } return std::make_shared(