[AURON #1891] Implement randn() function #1938
base: master
Conversation
Pull request overview
This PR implements the randn() function to improve Spark function coverage in Auron. The function generates random values from a standard normal distribution with optional seed support.
Changes:
- Added a Rust implementation of the `spark_randn` function with seed handling
- Registered the new function in the Scala converter and the Rust function registry
- Added the `rand_distr` dependency for normal distribution sampling (see the sketch below)
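For reviewers unfamiliar with `rand_distr`: the dependency is used for seeded sampling from a standard normal distribution, roughly as in the standalone sketch below. This is illustrative only, not the PR's exact code; the actual function appears further down in `spark_randn.rs`.

```rust
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::StandardNormal;

fn main() {
    // A fixed seed makes the draw reproducible: the same seed yields the same sample.
    let mut rng = StdRng::seed_from_u64(42);
    let sample: f64 = rng.sample(StandardNormal);
    println!("N(0, 1) sample with seed 42: {sample}");
}
```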
Reviewed changes
Copilot reviewed 5 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala | Added case handler for Randn expression to route to native implementation |
| native-engine/datafusion-ext-functions/src/spark_randn.rs | New implementation of randn function with seed handling and unit tests |
| native-engine/datafusion-ext-functions/src/lib.rs | Registered Spark_Randn function in the extension function factory |
| native-engine/datafusion-ext-functions/Cargo.toml | Added rand and rand_distr dependencies |
| Cargo.toml | Added rand_distr workspace dependency |
| Cargo.lock | Updated lock file with rand_distr package metadata |
#[cfg(test)]
mod test {
    use std::error::Error;

    use datafusion::{common::ScalarValue, logical_expr::ColumnarValue};

    use crate::spark_randn::spark_randn;

    #[test]
    fn test_randn_with_seed_reproducibility() -> Result<(), Box<dyn Error>> {
        // Same seed should produce same result
        let seed = ColumnarValue::Scalar(ScalarValue::Int64(Some(42)));

        let result1 = spark_randn(&vec![seed.clone()])?;
        let result2 = spark_randn(&vec![seed])?;

        match (result1, result2) {
            (
                ColumnarValue::Scalar(ScalarValue::Float64(Some(v1))),
                ColumnarValue::Scalar(ScalarValue::Float64(Some(v2))),
            ) => {
                assert_eq!(v1, v2, "Same seed should produce same result");
            }
            _ => panic!("Expected Float64 scalar results"),
        }
        Ok(())
    }

    #[test]
    fn test_randn_different_seeds() -> Result<(), Box<dyn Error>> {
        // Different seeds should produce different results (with very high probability)
        let seed1 = ColumnarValue::Scalar(ScalarValue::Int64(Some(42)));
        let seed2 = ColumnarValue::Scalar(ScalarValue::Int64(Some(123)));

        let result1 = spark_randn(&vec![seed1])?;
        let result2 = spark_randn(&vec![seed2])?;

        match (result1, result2) {
            (
                ColumnarValue::Scalar(ScalarValue::Float64(Some(v1))),
                ColumnarValue::Scalar(ScalarValue::Float64(Some(v2))),
            ) => {
                assert_ne!(v1, v2, "Different seeds should produce different results");
            }
            _ => panic!("Expected Float64 scalar results"),
        }
        Ok(())
    }

    #[test]
    fn test_randn_no_seed() -> Result<(), Box<dyn Error>> {
        // Without seed, should still produce a valid float
        let result = spark_randn(&vec![])?;

        match result {
            ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => {
                assert!(v.is_finite(), "Result should be a finite number");
            }
            _ => panic!("Expected Float64 scalar result"),
        }
        Ok(())
    }

    #[test]
    fn test_randn_with_int32_seed() -> Result<(), Box<dyn Error>> {
        // Int32 seed should work
        let seed = ColumnarValue::Scalar(ScalarValue::Int32(Some(42)));

        let result = spark_randn(&vec![seed])?;

        match result {
            ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => {
                assert!(v.is_finite(), "Result should be a finite number");
            }
            _ => panic!("Expected Float64 scalar result"),
        }
        Ok(())
    }

    #[test]
    fn test_randn_with_null_seed() -> Result<(), Box<dyn Error>> {
        // Null seed should be treated as no seed (random)
        let seed = ColumnarValue::Scalar(ScalarValue::Null);

        let result = spark_randn(&vec![seed])?;

        match result {
            ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => {
                assert!(v.is_finite(), "Result should be a finite number");
            }
            _ => panic!("Expected Float64 scalar result"),
        }
        Ok(())
    }
}
Copilot AI · Jan 21, 2026
The test coverage is incomplete as it only tests scalar seed values. Tests should be added for columnar (array) inputs to verify that the function correctly generates different random values for each row in a batch. This is the primary use case when randn() is used in DataFrame queries.
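To make this concrete, a columnar-style test could look roughly like the sketch below. The helper `randn_batch` and its `num_rows` argument are hypothetical stand-ins for an array-aware `spark_randn`; they are not part of this PR.

```rust
#[cfg(test)]
mod columnar_sketch {
    use datafusion::arrow::array::{Array, Float64Array};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use rand_distr::StandardNormal;

    // Hypothetical per-row generator standing in for an array-aware spark_randn.
    fn randn_batch(seed: u64, num_rows: usize) -> Float64Array {
        let mut rng = StdRng::seed_from_u64(seed);
        Float64Array::from_iter_values((0..num_rows).map(|_| rng.sample::<f64, _>(StandardNormal)))
    }

    #[test]
    fn rows_in_a_batch_get_distinct_values() {
        let batch = randn_batch(42, 8);
        assert_eq!(batch.len(), 8);
        // With overwhelming probability, consecutive draws from the RNG differ.
        assert_ne!(batch.value(0), batch.value(1));
    }
}
```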
pub fn spark_randn(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    // Parse seed argument, or generate random seed if not provided
    let seed: u64 = if args.is_empty() {
        rand::random()
    } else {
        match &args[0] {
            ColumnarValue::Scalar(ScalarValue::Int64(Some(s))) => *s as u64,
            ColumnarValue::Scalar(ScalarValue::Int32(Some(s))) => *s as u64,
            _ => rand::random(),
        }
    };

    let mut rng = StdRng::seed_from_u64(seed);
    let value: f64 = StandardNormal.sample(&mut rng);

    Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(value))))
}
Copilot AI · Jan 21, 2026
This implementation only handles scalar seed inputs but does not handle columnar (array) data properly. When randn() is used in a DataFrame context (e.g., df.select(randn(seed))), it should generate a different random value for each row in the dataset, not just return a single value.
The function needs to be updated to handle ColumnarValue::Array inputs where the seed could be an array, and generate an array of random values - one per row. This is critical for correct Spark compatibility when randn is used in SELECT queries over datasets with multiple rows.
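One possible shape for such a fix is sketched below, under the assumption that the caller can supply the batch's row count (for example through a batch-aware UDF entry point). The name `spark_randn_columnar` and the `num_rows` parameter are illustrative only, and a faithful port would also need to fold the partition index into the seed to match Spark's per-partition behavior.

```rust
use std::sync::Arc;

use datafusion::{
    arrow::array::Float64Array,
    common::{Result, ScalarValue},
    logical_expr::ColumnarValue,
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::StandardNormal;

// Illustrative sketch: one N(0, 1) sample per row instead of a single scalar for the whole batch.
pub fn spark_randn_columnar(args: &[ColumnarValue], num_rows: usize) -> Result<ColumnarValue> {
    // Seed handling mirrors the scalar version: explicit Int32/Int64 seed, otherwise random.
    let seed: u64 = match args.first() {
        Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(s)))) => *s as u64,
        Some(ColumnarValue::Scalar(ScalarValue::Int32(Some(s)))) => *s as u64,
        _ => rand::random(),
    };

    let mut rng = StdRng::seed_from_u64(seed);
    let values = Float64Array::from_iter_values(
        (0..num_rows).map(|_| rng.sample::<f64, _>(StandardNormal)),
    );

    Ok(ColumnarValue::Array(Arc::new(values)))
}
```

Within a single batch this gives each row its own draw from the seeded generator, which is the property the comment above calls out.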
/// from the standard normal distribution
///
/// - Takes an optional seed (i64) for reproducibility
/// - If no seed is provided, uses a random seed
Copilot AI · Jan 21, 2026
The documentation states "If no seed is provided, uses a random seed", which is misleading. In Spark's randn() function, when no seed is provided, each invocation should generate different random values (using a randomized seed per partition and row). The current implementation with a single random seed would produce the same value for all rows in a batch when the seed argument is absent, which does not match Spark's behavior.
Suggested change:
- /// - If no seed is provided, uses a random seed
+ /// - If no seed is provided, a random seed is chosen once per invocation, so all rows in the
+ /// batch share the same value (unlike Spark's randn, which yields different values per row)
Which issue does this PR close?
Closes #1891
Rationale for this change
This improves Spark function coverage in Auron.
What changes are included in this PR?
Adds support for the Spark randn function.
Are there any user-facing changes?
Yes, it adds the randn function.
How was this patch tested?
Unit tests and [TODO] manual testing in spark-shell.