diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto
index 99f8078bf..2c82bbb1b 100644
--- a/native-engine/auron-planner/proto/auron.proto
+++ b/native-engine/auron-planner/proto/auron.proto
@@ -116,6 +116,9 @@ message PhysicalExprNode {
     // SparkPartitionID
     SparkPartitionIdExprNode spark_partition_id_expr = 20101;
 
+    // MonotonicIncreasingID
+    MonotonicIncreasingIdExprNode monotonic_increasing_id_expr = 20102;
+
     // BloomFilterMightContain
     BloomFilterMightContainExprNode bloom_filter_might_contain_expr = 20200;
   }
@@ -365,6 +368,12 @@ message StringContainsExprNode {
 message RowNumExprNode {
 }
+message SparkPartitionIdExprNode {
+}
+
+message MonotonicIncreasingIdExprNode {
+}
+
 message BloomFilterMightContainExprNode {
   string uuid = 1;
   PhysicalExprNode bloom_filter_expr = 2;
 }
@@ -917,5 +926,3 @@ message ArrowType {
 // }
 //}
 message EmptyMessage{}
-
-message SparkPartitionIdExprNode {}
diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs
index cfab99e17..49ada28d0 100644
--- a/native-engine/auron-planner/src/planner.rs
+++ b/native-engine/auron-planner/src/planner.rs
@@ -52,7 +52,9 @@ use datafusion::{
 use datafusion_ext_exprs::{
     bloom_filter_might_contain::BloomFilterMightContainExpr, cast::TryCastExpr,
     get_indexed_field::GetIndexedFieldExpr, get_map_value::GetMapValueExpr,
-    named_struct::NamedStructExpr, row_num::RowNumExpr, spark_partition_id::SparkPartitionIdExpr,
+    named_struct::NamedStructExpr, row_num::RowNumExpr,
+    spark_monotonically_increasing_id::SparkMonotonicallyIncreasingIdExpr,
+    spark_partition_id::SparkPartitionIdExpr,
     spark_scalar_subquery_wrapper::SparkScalarSubqueryWrapperExpr,
     spark_udf_wrapper::SparkUDFWrapperExpr, string_contains::StringContainsExpr,
     string_ends_with::StringEndsWithExpr, string_starts_with::StringStartsWithExpr,
@@ -965,6 +967,9 @@ impl PhysicalPlanner {
             ExprType::SparkPartitionIdExpr(_) => {
                 Arc::new(SparkPartitionIdExpr::new(self.partition_id))
             }
+            ExprType::MonotonicIncreasingIdExpr(_) => {
+                Arc::new(SparkMonotonicallyIncreasingIdExpr::new(self.partition_id))
+            }
             ExprType::BloomFilterMightContainExpr(e) => Arc::new(BloomFilterMightContainExpr::new(
                 e.uuid.clone(),
                 self.try_parse_physical_expr_box_required(&e.bloom_filter_expr, input_schema)?,
diff --git a/native-engine/datafusion-ext-exprs/src/lib.rs b/native-engine/datafusion-ext-exprs/src/lib.rs
index bb2757f00..c6732b576 100644
--- a/native-engine/datafusion-ext-exprs/src/lib.rs
+++ b/native-engine/datafusion-ext-exprs/src/lib.rs
@@ -23,6 +23,7 @@ pub mod get_indexed_field;
 pub mod get_map_value;
 pub mod named_struct;
 pub mod row_num;
+pub mod spark_monotonically_increasing_id;
 pub mod spark_partition_id;
 pub mod spark_scalar_subquery_wrapper;
 pub mod spark_udf_wrapper;
diff --git a/native-engine/datafusion-ext-exprs/src/spark_monotonically_increasing_id.rs b/native-engine/datafusion-ext-exprs/src/spark_monotonically_increasing_id.rs
new file mode 100644
index 000000000..7fb646af6
--- /dev/null
+++ b/native-engine/datafusion-ext-exprs/src/spark_monotonically_increasing_id.rs
@@ -0,0 +1,243 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    any::Any,
+    fmt::{Debug, Display, Formatter},
+    hash::{Hash, Hasher},
+    sync::{
+        Arc,
+        atomic::{AtomicI64, Ordering::SeqCst},
+    },
+};
+
+use arrow::{
+    array::{Int64Array, RecordBatch},
+    datatypes::{DataType, Schema},
+};
+use datafusion::{
+    common::Result,
+    logical_expr::ColumnarValue,
+    physical_expr::{PhysicalExpr, PhysicalExprRef},
+};
+
+pub struct SparkMonotonicallyIncreasingIdExpr {
+    partition_id: i64,
+    row_counter: AtomicI64,
+}
+
+impl SparkMonotonicallyIncreasingIdExpr {
+    pub fn new(partition_id: usize) -> Self {
+        Self {
+            partition_id: partition_id as i64,
+            row_counter: AtomicI64::new(0),
+        }
+    }
+}
+
+impl Display for SparkMonotonicallyIncreasingIdExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "MonotonicallyIncreasingID")
+    }
+}
+
+impl Debug for SparkMonotonicallyIncreasingIdExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "MonotonicallyIncreasingID")
+    }
+}
+
+impl PartialEq for SparkMonotonicallyIncreasingIdExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.partition_id == other.partition_id
+    }
+}
+
+impl Eq for SparkMonotonicallyIncreasingIdExpr {}
+
+impl Hash for SparkMonotonicallyIncreasingIdExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.partition_id.hash(state);
+    }
+}
+
+impl PhysicalExpr for SparkMonotonicallyIncreasingIdExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(false)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let num_rows = batch.num_rows();
+        let start_row = self.row_counter.fetch_add(num_rows as i64, SeqCst);
+
+        // pack the partition index into the bits above bit 33 and the running
+        // per-partition row count into the lower 33 bits
+        let partition_offset = self.partition_id << 33;
+        let array: Int64Array = (start_row..start_row + num_rows as i64)
+            .map(|row_id| partition_offset | row_id)
+            .collect();
+
+        Ok(ColumnarValue::Array(Arc::new(array)))
+    }
+
+    fn children(&self) -> Vec<&PhysicalExprRef> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<PhysicalExprRef>,
+    ) -> Result<PhysicalExprRef> {
+        Ok(Arc::new(Self::new(self.partition_id as usize)))
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "fmt_sql not used")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::{
+        array::Int64Array,
+        datatypes::{Field, Schema},
+        record_batch::RecordBatch,
+    };
+
+    use super::*;
+
+    #[test]
+    fn test_data_type_and_nullable() {
+        let expr = SparkMonotonicallyIncreasingIdExpr::new(0);
+        let schema = Schema::new(vec![] as Vec<Field>);
+        assert_eq!(
+            expr.data_type(&schema).expect("data_type failed"),
+            DataType::Int64
+        );
+        assert!(!expr.nullable(&schema).expect("nullable failed"));
+    }
+
+    #[test]
+    fn test_evaluate_generates_monotonic_ids() {
+        let expr = SparkMonotonicallyIncreasingIdExpr::new(0);
+        let schema = Schema::new(vec![Field::new("col", DataType::Int64, false)]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
+        )
+        .expect("RecordBatch creation failed");
+
+        let result = expr.evaluate(&batch).expect("evaluate failed");
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr
+                    .as_any()
+                    .downcast_ref::<Int64Array>()
+                    .expect("downcast failed");
+                assert_eq!(int_arr.len(), 3);
+                assert_eq!(int_arr.value(0), 0);
+                assert_eq!(int_arr.value(1), 1);
+                assert_eq!(int_arr.value(2), 2);
+            }
+            _ => unreachable!("Expected Array result"),
+        }
+
+        let batch2 = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(Int64Array::from(vec![4, 5]))],
+        )
+        .expect("RecordBatch creation failed");
+
+        let result2 = expr.evaluate(&batch2).expect("evaluate failed");
+        match result2 {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr
+                    .as_any()
+                    .downcast_ref::<Int64Array>()
+                    .expect("downcast failed");
+                assert_eq!(int_arr.len(), 2);
+                assert_eq!(int_arr.value(0), 3);
+                assert_eq!(int_arr.value(1), 4);
+            }
+            _ => unreachable!("Expected Array result"),
+        }
+    }
+
+    #[test]
+    fn test_evaluate_with_partition_offset() {
+        let partition_id = 5;
+        let expr = SparkMonotonicallyIncreasingIdExpr::new(partition_id);
+        let schema = Schema::new(vec![Field::new("col", DataType::Int64, false)]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(Int64Array::from(vec![1, 2]))],
+        )
+        .expect("RecordBatch creation failed");
+
+        let result = expr.evaluate(&batch).expect("evaluate failed");
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr
+                    .as_any()
+                    .downcast_ref::<Int64Array>()
+                    .expect("downcast failed");
+                let expected_offset = (partition_id as i64) << 33;
+                assert_eq!(int_arr.value(0), expected_offset);
+                assert_eq!(int_arr.value(1), expected_offset + 1);
+            }
+            _ => unreachable!("Expected Array result"),
+        }
+    }
+
+    #[test]
+    fn test_different_partitions_have_different_ranges() {
+        let schema = Schema::new(vec![Field::new("col", DataType::Int64, false)]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(Int64Array::from(vec![1, 2]))],
+        )
+        .expect("RecordBatch creation failed");
+
+        let expr1 = SparkMonotonicallyIncreasingIdExpr::new(0);
+        let expr2 = SparkMonotonicallyIncreasingIdExpr::new(1);
+
+        let result1 = expr1.evaluate(&batch).expect("evaluate failed");
+        let result2 = expr2.evaluate(&batch).expect("evaluate failed");
+
+        match (result1, result2) {
+            (ColumnarValue::Array(arr1), ColumnarValue::Array(arr2)) => {
+                let int_arr1 = arr1
+                    .as_any()
+                    .downcast_ref::<Int64Array>()
+                    .expect("downcast failed");
+                let int_arr2 = arr2
+                    .as_any()
+                    .downcast_ref::<Int64Array>()
+                    .expect("downcast failed");
+
+                assert_ne!(int_arr1.value(0), int_arr2.value(0));
+                assert_eq!(int_arr1.value(0), 0);
+                assert_eq!(int_arr2.value(0), 1i64 << 33);
+            }
+            _ => unreachable!("Expected Array results"),
+        }
+    }
+}
diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index 1427e01d3..f3a443a45 100644
--- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.Generator
 import org.apache.spark.sql.catalyst.expressions.Like
 import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.expressions.SparkPartitionID
@@ -529,6 +530,13 @@ class ShimsImpl extends Shims with Logging {
           .setSparkPartitionIdExpr(pb.SparkPartitionIdExprNode.newBuilder())
           .build())
 
+      case _: MonotonicallyIncreasingID =>
+        Some(
+          pb.PhysicalExprNode
+            .newBuilder()
+            .setMonotonicIncreasingIdExpr(pb.MonotonicIncreasingIdExprNode.newBuilder())
+            .build())
+
       case StringSplit(str, pat @ Literal(_, StringType), Literal(-1, IntegerType))
           // native StringSplit implementation does not support regex, so only most frequently
          // used cases without regex are supported
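
A quick way to sanity-check the IDs produced by the new expression is the standalone sketch below (illustration only, not part of the patch; monotonic_id is a made-up helper). It reproduces the bit layout used in evaluate() above: the partition index sits above bit 33 and the per-partition row counter fills the lower 33 bits, so partition 1 starts at 2^33 = 8589934592.

// Standalone Rust sketch of the ID layout implemented by
// SparkMonotonicallyIncreasingIdExpr (assumed helper, not in the patch).
fn monotonic_id(partition_id: i64, row_in_partition: i64) -> i64 {
    (partition_id << 33) | row_in_partition
}

fn main() {
    // Partition 0 starts at 0; partition 1 starts at 2^33.
    assert_eq!(monotonic_id(0, 2), 2);
    assert_eq!(monotonic_id(1, 0), 1i64 << 33);
    assert_eq!(monotonic_id(5, 7), (5i64 << 33) | 7);

    // The parts can be recovered from an ID.
    let id = monotonic_id(5, 7);
    assert_eq!(id >> 33, 5);
    assert_eq!(id & ((1i64 << 33) - 1), 7);
    println!("id for partition 5, row 7 = {id}");
}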