From ee922889df990f676443a721155aa597689773ce Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Sat, 31 Jan 2026 15:21:23 -0700 Subject: [PATCH 1/2] hybrid-array: add `Flatten` and `Unflatten` traits These were extracted from the `module-lattice` crate (which in turn was extracted from `ml-dsa`). They were contributed by @bifurcation and seem generally useful. I have extracted them largely verbatim except cleaning up a few things which were triggering lint failures. Having them as traits instead of inherent methods is nice because you can bound on them, where the bounds they actually require would be somewhat annoying to have downstream code repeat. --- src/flatten.rs | 139 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 7 ++- 2 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 src/flatten.rs diff --git a/src/flatten.rs b/src/flatten.rs new file mode 100644 index 0000000..ec30f02 --- /dev/null +++ b/src/flatten.rs @@ -0,0 +1,139 @@ +use crate::{ + Array, ArraySize, + typenum::{Prod, Quot, U0, Unsigned}, +}; +use core::{ + mem::ManuallyDrop, + ops::{Div, Mul, Rem}, + ptr, +}; + +/// Defines a sequence of sequences that can be merged into a bigger overall sequence. +pub trait Flatten { + /// Size of the output array. + type OutputSize: ArraySize; + + /// Flatten array. + fn flatten(self) -> Array; +} + +impl Flatten> for Array, N> +where + N: ArraySize, + M: ArraySize + Mul, + Prod: ArraySize, +{ + type OutputSize = Prod; + + // SAFETY: this is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed + // to be safe by the Rust memory layout of these types. + fn flatten(self) -> Array { + let whole = ManuallyDrop::new(self); + unsafe { ptr::read(whole.as_ptr().cast()) } + } +} + +/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size. +pub trait Unflatten +where + M: ArraySize, +{ + /// Part of the array we're decomposing into. + type Part; + + /// Unflatten array into `Self::Part` chunks. + fn unflatten(self) -> Array; +} + +impl Unflatten for Array +where + T: Default, + N: ArraySize + Div + Rem, + M: ArraySize, + Quot: ArraySize, +{ + type Part = Array>; + + // SAFETY: this is doing the same thing as what is done in `Array::split`. + // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to + // be safe by the Rust memory layout of these types. + fn unflatten(self) -> Array { + let part_size = Quot::::USIZE; + let whole = ManuallyDrop::new(self); + Array::from_fn(|i| unsafe { + let offset = i.checked_mul(part_size).expect("overflow"); + ptr::read(whole.as_ptr().add(offset).cast()) + }) + } +} + +impl<'a, T, N, M> Unflatten for &'a Array +where + T: Default, + N: ArraySize + Div + Rem, + M: ArraySize, + Quot: ArraySize, +{ + type Part = &'a Array>; + + // SAFETY: this is doing the same thing as what is done in `Array::split`. + // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to + // be safe by the Rust memory layout of these types. + fn unflatten(self) -> Array { + let part_size = Quot::::USIZE; + let mut ptr: *const T = self.as_ptr(); + Array::from_fn(|_i| unsafe { + let part = &*(ptr.cast()); + ptr = ptr.add(part_size); + part + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + Array, + sizes::{U2, U5}, + }; + + #[test] + fn flatten() { + let flat: Array = Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let unflat2: Array, _> = Array([ + Array([1, 2]), + Array([3, 4]), + Array([5, 6]), + Array([7, 8]), + Array([9, 10]), + ]); + let unflat5: Array, _> = + Array([Array([1, 2, 3, 4, 5]), Array([6, 7, 8, 9, 10])]); + + // Flatten + let actual = unflat2.flatten(); + assert_eq!(flat, actual); + + let actual = unflat5.flatten(); + assert_eq!(flat, actual); + + // Unflatten + let actual: Array, U5> = flat.unflatten(); + assert_eq!(unflat2, actual); + + let actual: Array, U2> = flat.unflatten(); + assert_eq!(unflat5, actual); + + // Unflatten on references + let actual: Array<&Array, U5> = (&flat).unflatten(); + for (i, part) in actual.iter().enumerate() { + assert_eq!(&unflat2[i], *part); + } + + let actual: Array<&Array, U2> = (&flat).unflatten(); + for (i, part) in actual.iter().enumerate() { + assert_eq!(&unflat5[i], *part); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 1d83b9f..d8b5405 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,6 +100,7 @@ extern crate alloc; pub mod sizes; +mod flatten; mod from_fn; mod iter; mod traits; @@ -107,7 +108,11 @@ mod traits; #[cfg(feature = "serde")] mod serde; -pub use crate::{iter::TryFromIteratorError, traits::*}; +pub use crate::{ + flatten::{Flatten, Unflatten}, + iter::TryFromIteratorError, + traits::*, +}; pub use typenum; use core::{ From 61ab0dc2ffa42a43be711a049d30e41989a34aa1 Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Sat, 31 Jan 2026 15:28:21 -0700 Subject: [PATCH 2/2] Remove unnecessary Default bounds --- src/flatten.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/flatten.rs b/src/flatten.rs index ec30f02..f3159b4 100644 --- a/src/flatten.rs +++ b/src/flatten.rs @@ -47,7 +47,6 @@ where impl Unflatten for Array where - T: Default, N: ArraySize + Div + Rem, M: ArraySize, Quot: ArraySize, @@ -69,7 +68,6 @@ where impl<'a, T, N, M> Unflatten for &'a Array where - T: Default, N: ArraySize + Div + Rem, M: ArraySize, Quot: ArraySize,