Split oversized record batches before coalescing sort-merge-join output.

Adds ExecutionContext::split_with_default_batch_size, which wraps a
SendableRecordBatchStream in a function-local SplitLargeBatchStream adapter:

  * input batches for which is_empty() returns true are skipped;
  * batches with num_rows() <= batch_size() are forwarded unchanged;
  * larger batches are kept in current_batch and emitted as consecutive
    slice()s of at most batch_size() rows, with current_offset tracking the
    next row to emit; current_batch is cleared once the offset reaches
    num_rows();
  * a batch_size() of 0 is reported as DataFusionError::Internal
    ("Invalid batch size: 0") when a non-empty batch arrives;
  * errors and end-of-stream from the input are passed through as-is;
  * schema() delegates to the wrapped input stream.

SortMergeJoinExec now feeds its output through this splitter before handing
it to coalesce_with_default_batch_size.

NOTE(review): this patch text was restored from a copy in which line breaks,
indentation and generic arguments (Arc<Self>, Option<RecordBatch>,
Result<RecordBatch>, Poll<Option<Self::Item>>) had been lost. The leading
whitespace of context lines and the position of the blank trailing context
line in the second hunk are assumptions -- confirm against the target tree
before applying.

diff --git a/native-engine/datafusion-ext-plans/src/common/execution_context.rs b/native-engine/datafusion-ext-plans/src/common/execution_context.rs
index 15b6e69d0..4513ab1a9 100644
--- a/native-engine/datafusion-ext-plans/src/common/execution_context.rs
+++ b/native-engine/datafusion-ext-plans/src/common/execution_context.rs
@@ -141,6 +141,89 @@ impl ExecutionContext {
             .counter(name.to_owned(), self.partition_id)
     }
 
+    pub fn split_with_default_batch_size(
+        self: &Arc<Self>,
+        input: SendableRecordBatchStream,
+    ) -> SendableRecordBatchStream {
+        struct SplitLargeBatchStream {
+            input: SendableRecordBatchStream,
+            current_batch: Option<RecordBatch>,
+            current_offset: usize,
+        }
+
+        impl SplitLargeBatchStream {
+            fn split_next_chunk(&mut self) -> Option<RecordBatch> {
+                let batch = self.current_batch.as_ref()?;
+                let target_batch_size = batch_size();
+                let num_rows = batch.num_rows();
+
+                if self.current_offset >= num_rows {
+                    self.current_batch = None;
+                    return None;
+                }
+
+                let chunk_size = std::cmp::min(target_batch_size, num_rows - self.current_offset);
+                let chunk = batch.slice(self.current_offset, chunk_size);
+                self.current_offset += chunk_size;
+
+                if self.current_offset >= num_rows {
+                    self.current_batch = None;
+                }
+
+                Some(chunk)
+            }
+        }
+
+        impl RecordBatchStream for SplitLargeBatchStream {
+            fn schema(&self) -> SchemaRef {
+                self.input.schema()
+            }
+        }
+
+        impl Stream for SplitLargeBatchStream {
+            type Item = Result<RecordBatch>;
+
+            fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+                loop {
+                    if let Some(chunk) = self.split_next_chunk() {
+                        return Poll::Ready(Some(Ok(chunk)));
+                    }
+
+                    match ready!(self.input.as_mut().poll_next_unpin(cx)) {
+                        Some(Ok(batch)) => {
+                            if batch.is_empty() {
+                                continue;
+                            }
+
+                            let target_batch_size = batch_size();
+                            if target_batch_size == 0 {
+                                return Poll::Ready(Some(Err(DataFusionError::Internal(
+                                    "Invalid batch size: 0".to_string(),
+                                ))));
+                            }
+
+                            let num_rows = batch.num_rows();
+                            if num_rows <= target_batch_size {
+                                return Poll::Ready(Some(Ok(batch)));
+                            } else {
+                                self.current_batch = Some(batch);
+                                self.current_offset = 0;
+                            }
+                        }
+                        Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                        None => return Poll::Ready(None),
+                    }
+                }
+            }
+        }
+
+        Box::pin(SplitLargeBatchStream {
+            input,
+            current_batch: None,
+            current_offset: 0,
+        })
+    }
+
     pub fn coalesce_with_default_batch_size(
         self: &Arc<Self>,
         input: SendableRecordBatchStream,
diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
index 78eda5b62..430a93e75 100644
--- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
@@ -205,7 +205,8 @@ impl SortMergeJoinExec {
                     .sub_duration(poll_time.duration());
             })
         });
-        Ok(exec_ctx.coalesce_with_default_batch_size(output))
+        Ok(exec_ctx
+            .coalesce_with_default_batch_size(exec_ctx.split_with_default_batch_size(output)))
     }
 }
 