diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs index 3f2e06773..ecabfaed9 100644 --- a/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs +++ b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs @@ -56,6 +56,10 @@ impl FullJoiner { self.lindices.len() >= self.join_params.batch_size } + fn has_enough_room(&self, new_size: usize) -> bool { + self.lindices.len() + new_size <= self.join_params.batch_size + } + async fn flush( mut self: Pin<&mut Self>, cur1: &mut StreamCursor, @@ -160,9 +164,26 @@ impl Joiner for FullJoiner, + right: Arc, + on: JoinOn, + join_type: JoinType, + batch_size: usize, + ) -> Result<(Vec, Vec)> { + MemManager::init(1000000); + let session_config = SessionConfig::new().with_batch_size(batch_size); + let session_ctx = SessionContext::new_with_config(session_config); + let task_ctx = session_ctx.task_ctx(); + let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?; + + let join: Arc = match test_type { + SMJ => { + let sort_options = vec![SortOptions::default(); on.len()]; + Arc::new(SortMergeJoinExec::try_new( + schema, + left, + right, + on, + join_type, + sort_options, + )?) + } + BHJLeftProbed => { + let right = Arc::new(BroadcastJoinBuildHashMapExec::new( + right, + on.iter().map(|(_, right_key)| right_key.clone()).collect(), + )); + Arc::new(BroadcastJoinExec::try_new( + schema, + left, + right, + on, + join_type, + JoinSide::Right, + true, + None, + )?) + } + BHJRightProbed => { + let left = Arc::new(BroadcastJoinBuildHashMapExec::new( + left, + on.iter().map(|(left_key, _)| left_key.clone()).collect(), + )); + Arc::new(BroadcastJoinExec::try_new( + schema, + left, + right, + on, + join_type, + JoinSide::Left, + true, + None, + )?) + } + SHJLeftProbed => Arc::new(BroadcastJoinExec::try_new( + schema, + left, + right, + on, + join_type, + JoinSide::Right, + false, + None, + )?), + SHJRightProbed => Arc::new(BroadcastJoinExec::try_new( + schema, + left, + right, + on, + join_type, + JoinSide::Left, + false, + None, + )?), + }; + let columns = columns(&join.schema()); + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + const ALL_TEST_TYPE: [TestType; 5] = [ SMJ, BHJLeftProbed, @@ -428,6 +513,118 @@ mod tests { Ok(()) } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn join_inner_batchsize() -> Result<()> { + for test_type in ALL_TEST_TYPE { + let left = build_table( + ("a1", &vec![1, 1, 1, 1, 1]), + ("b1", &vec![1, 2, 3, 4, 5]), + ("c1", &vec![1, 2, 3, 4, 5]), + ); + let right = build_table( + ("a2", &vec![1, 1, 1, 1, 1, 1, 1]), + ("b2", &vec![1, 2, 3, 4, 5, 6, 7]), + ("c2", &vec![1, 2, 3, 4, 5, 6, 7]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?), + Arc::new(Column::new_with_schema("a2", &right.schema())?), + )]; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 1 | 1 | 1 | 1 |", + "| 1 | 1 | 1 | 1 | 2 | 2 |", + "| 1 | 1 | 1 | 1 | 3 | 3 |", + "| 1 | 1 | 1 | 1 | 4 | 4 |", + "| 1 | 1 | 1 | 1 | 5 | 5 |", + "| 1 | 1 | 1 | 1 | 6 | 6 |", + "| 1 | 1 | 1 | 1 | 7 | 7 |", + "| 1 | 2 | 2 | 1 | 1 | 1 |", + "| 1 | 2 | 2 | 1 | 2 | 2 |", + "| 1 | 2 | 2 | 1 | 3 | 3 |", + "| 1 | 2 | 2 | 1 | 4 | 4 |", + "| 1 | 2 | 2 | 1 | 5 | 5 |", + "| 1 | 2 | 2 | 1 | 6 | 6 |", + "| 1 | 2 | 2 | 1 | 7 | 7 |", + "| 1 | 3 | 3 | 1 | 1 | 1 |", + "| 1 | 3 | 3 | 1 | 2 | 2 |", + "| 1 | 3 | 3 | 1 | 3 | 3 |", + "| 1 | 3 | 3 | 1 | 4 | 4 |", + "| 1 | 3 | 3 | 1 | 5 | 5 |", + "| 1 | 3 | 3 | 1 | 6 | 6 |", + "| 1 | 3 | 3 | 1 | 7 | 7 |", + "| 1 | 4 | 4 | 1 | 1 | 1 |", + "| 1 | 4 | 4 | 1 | 2 | 2 |", + "| 1 | 4 | 4 | 1 | 3 | 3 |", + "| 1 | 4 | 4 | 1 | 4 | 4 |", + "| 1 | 4 | 4 | 1 | 5 | 5 |", + "| 1 | 4 | 4 | 1 | 6 | 6 |", + "| 1 | 4 | 4 | 1 | 7 | 7 |", + "| 1 | 5 | 5 | 1 | 1 | 1 |", + "| 1 | 5 | 5 | 1 | 2 | 2 |", + "| 1 | 5 | 5 | 1 | 3 | 3 |", + "| 1 | 5 | 5 | 1 | 4 | 4 |", + "| 1 | 5 | 5 | 1 | 5 | 5 |", + "| 1 | 5 | 5 | 1 | 6 | 6 |", + "| 1 | 5 | 5 | 1 | 7 | 7 |", + "+----+----+----+----+----+----+", + ]; + let (_, batches) = join_collect_with_batch_size( + test_type, + left.clone(), + right.clone(), + on.clone(), + Inner, + 2, + ) + .await?; + assert_batches_sorted_eq!(expected, &batches); + let (_, batches) = join_collect_with_batch_size( + test_type, + left.clone(), + right.clone(), + on.clone(), + Inner, + 3, + ) + .await?; + assert_batches_sorted_eq!(expected, &batches); + let (_, batches) = join_collect_with_batch_size( + test_type, + left.clone(), + right.clone(), + on.clone(), + Inner, + 4, + ) + .await?; + assert_batches_sorted_eq!(expected, &batches); + let (_, batches) = join_collect_with_batch_size( + test_type, + left.clone(), + right.clone(), + on.clone(), + Inner, + 5, + ) + .await?; + assert_batches_sorted_eq!(expected, &batches); + let (_, batches) = join_collect_with_batch_size( + test_type, + left.clone(), + right.clone(), + on.clone(), + Inner, + 7, + ) + .await?; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn join_left_one() -> Result<()> { for test_type in ALL_TEST_TYPE {