From 8e137e1e804152ab79b759c29c050738012efc80 Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Thu, 5 Feb 2026 21:28:26 +0200 Subject: [PATCH 1/9] MOD-13577 support Homogenues array floating point forcing(deserialization path only) --- Cargo.toml | 1 + src/array.rs | 176 +++++++++++++++++++++++++++++------- src/de.rs | 177 ++++++++++++++++++++++++++++++++++--- src/{alloc.rs => error.rs} | 14 +++ src/lib.rs | 6 +- 5 files changed, 325 insertions(+), 49 deletions(-) rename src/{alloc.rs => error.rs} (56%) diff --git a/Cargo.toml b/Cargo.toml index 9332c82..ed0021a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ serde_json = { workspace = true } ctor = { version = "0.1.16", optional = true } paste = "1.0.15" half = "2.0.0" +thiserror = "2.0.18" [dev-dependencies] mockalloc = "0.1.2" diff --git a/src/array.rs b/src/array.rs index 2cc0d70..4360d6d 100644 --- a/src/array.rs +++ b/src/array.rs @@ -9,8 +9,9 @@ use std::iter::FromIterator; use std::ops::{Index, IndexMut}; use std::slice::{from_raw_parts, from_raw_parts_mut, SliceIndex}; +use crate::error::IJsonError; use crate::{ - alloc::AllocError, + error::AllocError, thin::{ThinMut, ThinMutExt, ThinRef, ThinRefExt}, value::TypeTag, Defrag, DefragAllocator, IValue, @@ -54,6 +55,41 @@ impl Default for ArrayTag { } } +/// Enum representing different types of floating-point types +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum FloatType { + /// F16 + F16, + /// BF16 + BF16, + /// F32 + F32, + /// F64 + F64, +} + +impl fmt::Display for FloatType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FloatType::F16 => write!(f, "F16"), + FloatType::BF16 => write!(f, "BF16"), + FloatType::F32 => write!(f, "F32"), + FloatType::F64 => write!(f, "F64"), + } + } +} + +impl From for ArrayTag { + fn from(fp_type: FloatType) -> Self { + match fp_type { + FloatType::F16 => ArrayTag::F16, + FloatType::BF16 => ArrayTag::BF16, + FloatType::F32 => ArrayTag::F32, + FloatType::F64 => 
ArrayTag::F64, + } + } +} + impl ArrayTag { fn from_type() -> Self { use ArrayTag::*; @@ -182,14 +218,25 @@ impl ArrayTag { /// Determines the ArrayTag for an IValue if it represents a primitive type /// Prefers signed types over unsigned types for positive values to be more conservative fn from_ivalue(value: &IValue) -> ArrayTag { + Self::from_ivalue_with_hint(value, None) + } + + /// Determines the ArrayTag for an IValue, using the provided fp_type for floating-point types. + /// + /// When `fp_type` is `Some`, uses the hinted type directly for floating-point values. + fn from_ivalue_with_hint(value: &IValue, fp_type: Option) -> ArrayTag { use ArrayTag::*; if let Some(num) = value.as_number() { if num.has_decimal_point() { - num.to_f16() - .map(|_| F16) - .or_else(|| num.to_bf16().map(|_| BF16)) - .or_else(|| num.to_f32().map(|_| F32)) - .or_else(|| num.to_f64().map(|_| F64)) + fp_type.map(ArrayTag::from).unwrap_or_else(|| { + num.to_f16() + .map(|_| F16) + .or_else(|| num.to_bf16().map(|_| BF16)) + .or_else(|| num.to_f32().map(|_| F32)) + .or_else(|| num.to_f64().map(|_| F64)) + // Safety: We know the value is a decimal number, and f64 can represent any JSON number + .unwrap_or_else(|| unsafe { std::hint::unreachable_unchecked() }) + }) } else { num.to_i8() .map(|_| I8) @@ -200,9 +247,9 @@ impl ArrayTag { .or_else(|| num.to_u32().map(|_| U32)) .or_else(|| num.to_i64().map(|_| I64)) .or_else(|| num.to_u64().map(|_| U64)) + // Safety: We know the value is a number, and we've checked all possible number types + .unwrap_or_else(|| unsafe { std::hint::unreachable_unchecked() }) } - // Safety: We know the value is a number, and we've checked all possible number types - .unwrap_or_else(|| unsafe { std::hint::unreachable_unchecked() }) } else { Heterogeneous } @@ -401,11 +448,11 @@ impl Header { const TAG_MASK: u64 = 0xF; const TAG_SHIFT: u64 = 60; - const fn new(len: usize, cap: usize, tag: ArrayTag) -> Result { + const fn new(len: usize, cap: usize, tag: ArrayTag) -> 
Result { // assert!(len <= Self::LEN_MASK as usize, "Length exceeds 30-bit limit"); // assert!(cap <= Self::CAP_MASK as usize, "Capacity exceeds 30-bit limit"); if len > Self::LEN_MASK as usize || cap > Self::CAP_MASK as usize { - return Err(AllocError); + return Err(IJsonError::Alloc(AllocError)); } let packed = ((len as u64) & Self::LEN_MASK) << Self::LEN_SHIFT @@ -670,7 +717,7 @@ impl IArray { .pad_to_align()) } - fn alloc(cap: usize, tag: ArrayTag) -> Result<*mut Header, AllocError> { + fn alloc(cap: usize, tag: ArrayTag) -> Result<*mut Header, IJsonError> { unsafe { let ptr = alloc(Self::layout(cap, tag).map_err(|_| AllocError)?).cast::
(); ptr.write(Header::new(0, cap, tag)?); @@ -678,7 +725,7 @@ impl IArray { } } - fn realloc(ptr: *mut Header, new_cap: usize) -> Result<*mut Header, AllocError> { + fn realloc(ptr: *mut Header, new_cap: usize) -> Result<*mut Header, IJsonError> { unsafe { let tag = (*ptr).type_tag(); let old_layout = Self::layout((*ptr).cap(), tag).map_err(|_| AllocError)?; @@ -706,13 +753,13 @@ impl IArray { /// Constructs a new `IArray` with the specified capacity. At least that many items /// can be added to the array without reallocating. #[must_use] - pub fn with_capacity(cap: usize) -> Result { + pub fn with_capacity(cap: usize) -> Result { Self::with_capacity_and_tag(cap, ArrayTag::Heterogeneous) } /// Constructs a new `IArray` with the specified capacity and array type. #[must_use] - fn with_capacity_and_tag(cap: usize, tag: ArrayTag) -> Result { + fn with_capacity_and_tag(cap: usize, tag: ArrayTag) -> Result { if cap == 0 { Ok(Self::new()) } else { @@ -743,7 +790,7 @@ impl IArray { /// Converts this array to a new type, promoting all existing elements. /// This is used for automatic type promotion when incompatible types are added. - fn promote_to_type(&mut self, new_tag: ArrayTag) -> Result<(), AllocError> { + fn promote_to_type(&mut self, new_tag: ArrayTag) -> Result<(), IJsonError> { if self.is_static() || self.header().type_tag() == new_tag { return Ok(()); } @@ -898,7 +945,7 @@ impl IArray { self.header_mut().as_mut_slice_unchecked::() } - fn resize_internal(&mut self, cap: usize) -> Result<(), AllocError> { + fn resize_internal(&mut self, cap: usize) -> Result<(), IJsonError> { if self.is_static() || cap == 0 { let tag = if self.is_static() { ArrayTag::Heterogeneous @@ -916,7 +963,7 @@ impl IArray { } /// Reserves space for at least this many additional items. 
- pub fn reserve(&mut self, additional: usize) -> Result<(), AllocError> { + pub fn reserve(&mut self, additional: usize) -> Result<(), IJsonError> { let hd = self.header(); let current_capacity = hd.cap(); let desired_capacity = hd.len().checked_add(additional).ok_or(AllocError)?; @@ -956,7 +1003,7 @@ impl IArray { /// on or after this index will be shifted down to accomodate this. For large /// arrays, insertions near the front will be slow as it will require shifting /// a large number of items. - pub fn insert(&mut self, index: usize, item: impl Into) -> Result<(), AllocError> { + pub fn insert(&mut self, index: usize, item: impl Into) -> Result<(), IJsonError> { let item = item.into(); let current_tag = self.header().type_tag(); let len = self.len(); @@ -1080,8 +1127,49 @@ impl IArray { } } + /// Pushes a new item onto the back of the array with a specific floating-point type. + /// + /// If the item cannot be represented in the specified floating-point type, + /// returns an error. + pub(crate) fn push_with_fp_type( + &mut self, + item: impl Into, + fp_type: FloatType, + ) -> Result<(), IJsonError> { + let desired_tag = fp_type.into(); + let current_tag = self.header().type_tag(); + let len = self.len(); + let item = item.into(); + let can_fit = || match fp_type { + FloatType::F16 => item.to_f16().is_some(), + FloatType::BF16 => item.to_bf16().is_some(), + FloatType::F32 => item.to_f32().is_some(), + FloatType::F64 => item.to_f64().is_some(), + }; + + if (desired_tag != current_tag && len > 0) || !can_fit() { + return Err(IJsonError::OutOfRange(fp_type)); + } + + // We can fit the item into the array, so we can push it directly + + if len == 0 { + if self.is_static() { + *self = IArray::with_capacity_and_tag(4, desired_tag)?; + } else { + self.promote_to_type(desired_tag)?; + } + } + + self.reserve(1)?; + unsafe { + self.header_mut().push(item); + } + Ok(()) + } + /// Pushes a new item onto the back of the array. 
- pub fn push(&mut self, item: impl Into) -> Result<(), AllocError> { + pub fn push(&mut self, item: impl Into) -> Result<(), IJsonError> { let item = item.into(); let current_tag = self.header().type_tag(); let len = self.len(); @@ -1425,11 +1513,11 @@ pub trait TryExtend { /// Returns an `AllocError` if allocation fails. /// # Errors /// Returns an `AllocError` if memory allocation fails during the extension. - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError>; + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError>; } impl + private::Sealed> TryExtend for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let iter = iter.into_iter(); self.reserve(iter.size_hint().0)?; for v in iter { @@ -1442,7 +1530,7 @@ impl + private::Sealed> TryExtend for IArray { macro_rules! extend_impl_int { ($($ty:ty),*) => { $(impl TryExtend<$ty> for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let expected_tag = ArrayTag::from_type::<$ty>(); let iter = iter.into_iter(); let size_hint = iter.size_hint().0; @@ -1494,7 +1582,7 @@ macro_rules! extend_impl_int { macro_rules! extend_impl_float { ($($ty:ty),*) => { $(impl TryExtend<$ty> for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let expected_tag = ArrayTag::from_type::<$ty>(); let iter = iter.into_iter(); let size_hint = iter.size_hint().0; @@ -1564,13 +1652,13 @@ pub trait TryFromIterator { /// Returns an `AllocError` if allocation fails. /// # Errors /// Returns `AllocError` if memory allocation fails during the construction. 
- fn try_from_iter>(iter: U) -> Result + fn try_from_iter>(iter: U) -> Result where Self: Sized; } impl + private::Sealed> TryFromIterator for IArray { - fn try_from_iter>(iter: T) -> Result { + fn try_from_iter>(iter: T) -> Result { let mut res = IArray::new(); res.try_extend(iter)?; Ok(res) @@ -1580,7 +1668,7 @@ impl + private::Sealed> TryFromIterator for IArray { macro_rules! from_iter_impl { ($($ty:ty),*) => { $(impl TryFromIterator<$ty> for IArray { - fn try_from_iter>(iter: T) -> Result { + fn try_from_iter>(iter: T) -> Result { let iter = iter.into_iter(); let mut res = IArray::with_capacity_and_tag(iter.size_hint().0, ArrayTag::from_type::<$ty>())?; res.try_extend(iter)?; @@ -1599,13 +1687,13 @@ pub trait TryCollect: Iterator + Sized { /// Returns an `AllocError` if allocation fails. /// # Errors /// Returns `AllocError` if memory allocation fails during the collection. - fn try_collect(self) -> Result + fn try_collect(self) -> Result where B: TryFromIterator; } impl> TryCollect for I { - fn try_collect(self) -> Result + fn try_collect(self) -> Result where B: TryFromIterator, { @@ -1614,7 +1702,7 @@ impl> TryCollect for I { } impl + private::Sealed> TryFrom> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: Vec) -> Result { let mut res = IArray::with_capacity(other.len())?; res.try_extend(other.into_iter().map(Into::into))?; @@ -1623,7 +1711,7 @@ impl + private::Sealed> TryFrom> for IArray { } impl + Clone + private::Sealed> TryFrom<&[T]> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: &[T]) -> Result { let mut res = IArray::with_capacity(other.len())?; res.try_extend(other.iter().cloned().map(Into::into))?; @@ -1634,7 +1722,7 @@ impl + Clone + private::Sealed> TryFrom<&[T]> for IArray { macro_rules! 
from_slice_impl { ($($ty:ty),*) => {$( impl TryFrom> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: Vec<$ty>) -> Result { let mut res = IArray::with_capacity_and_tag(other.len(), ArrayTag::from_type::<$ty>())?; TryExtend::<$ty>::try_extend(&mut res, other.into_iter().map(Into::into))?; @@ -1642,7 +1730,7 @@ macro_rules! from_slice_impl { } } impl TryFrom<&[$ty]> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: &[$ty]) -> Result { let mut res = IArray::with_capacity_and_tag(other.len(), ArrayTag::from_type::<$ty>())?; TryExtend::<$ty>::try_extend(&mut res, other.iter().cloned().map(Into::into))?; @@ -3207,4 +3295,28 @@ mod tests { } } } + + #[test] + fn test_push_with_fp_type_creates_typed_array() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(1.5), FloatType::F16) + .unwrap(); + arr.push_with_fp_type(IValue::from(2.5), FloatType::F16) + .unwrap(); + + assert_eq!(arr.len(), 2); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + } + + #[test] + fn test_push_with_fp_type_overflow() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(1.5), FloatType::F16) + .unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + arr.push_with_fp_type(IValue::from(100000.0), FloatType::F16) + .unwrap_err(); + assert_eq!(arr.len(), 1); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + } } diff --git a/src/de.rs b/src/de.rs index 4eabe26..1f1faa9 100644 --- a/src/de.rs +++ b/src/de.rs @@ -8,14 +8,40 @@ use serde::de::{ use serde::{forward_to_deserialize_any, Deserialize, Deserializer}; use serde_json::error::Error; -use crate::{DestructuredRef, IArray, INumber, IObject, IString, IValue}; +use crate::{DestructuredRef, FloatType, IArray, INumber, IObject, IString, IValue}; + +/// Seed for deserializing an [`IValue`]. 
+#[derive(Debug)] +pub struct IValueDeserSeed { + /// Optional floating point type enforcment type for homogeneous arrays. + pub fpha_type: Option, +} + +impl IValueDeserSeed { + /// Creates a new [`IValueDeserSeed`] with the given floating point type enforcment type for homogeneous arrays. + pub fn new(fpha_type: Option) -> Self { + IValueDeserSeed { fpha_type } + } +} + +impl<'de> DeserializeSeed<'de> for IValueDeserSeed { + type Value = IValue; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Pass hint to a custom visitor + deserializer.deserialize_any(ValueVisitor::new(self.fpha_type)) + } +} impl<'de> Deserialize<'de> for IValue { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - deserializer.deserialize_any(ValueVisitor) + deserializer.deserialize_any(ValueVisitor::new(None)) } } @@ -42,7 +68,7 @@ impl<'de> Deserialize<'de> for IArray { where D: Deserializer<'de>, { - deserializer.deserialize_seq(ArrayVisitor) + deserializer.deserialize_seq(ArrayVisitor { fpha_type: None }) } } @@ -51,11 +77,19 @@ impl<'de> Deserialize<'de> for IObject { where D: Deserializer<'de>, { - deserializer.deserialize_map(ObjectVisitor) + deserializer.deserialize_map(ObjectVisitor { fpha_type: None }) } } -struct ValueVisitor; +struct ValueVisitor { + fpha_type: Option, +} + +impl ValueVisitor { + fn new(fpha_type: Option) -> Self { + ValueVisitor { fpha_type } + } +} impl<'de> Visitor<'de> for ValueVisitor { type Value = IValue; @@ -104,7 +138,7 @@ impl<'de> Visitor<'de> for ValueVisitor { where D: Deserializer<'de>, { - Deserialize::deserialize(deserializer) + IValueDeserSeed::new(self.fpha_type).deserialize(deserializer) } #[inline] @@ -117,14 +151,22 @@ impl<'de> Visitor<'de> for ValueVisitor { where V: SeqAccess<'de>, { - ArrayVisitor.visit_seq(visitor).map(Into::into) + ArrayVisitor { + fpha_type: self.fpha_type, + } + .visit_seq(visitor) + .map(Into::into) } fn visit_map(self, visitor: V) -> Result where V: 
MapAccess<'de>, { - ObjectVisitor.visit_map(visitor).map(Into::into) + ObjectVisitor { + fpha_type: self.fpha_type, + } + .visit_map(visitor) + .map(Into::into) } } @@ -192,7 +234,9 @@ impl<'de> Visitor<'de> for StringVisitor { } } -struct ArrayVisitor; +struct ArrayVisitor { + fpha_type: Option, +} impl<'de> Visitor<'de> for ArrayVisitor { type Value = IArray; @@ -208,15 +252,20 @@ impl<'de> Visitor<'de> for ArrayVisitor { { let mut arr = IArray::with_capacity(visitor.size_hint().unwrap_or(0)) .map_err(|_| SError::custom("Failed to allocate array"))?; - while let Some(v) = visitor.next_element::()? { - arr.push(v) - .map_err(|_| SError::custom("Failed to push to array"))?; + while let Some(v) = visitor.next_element_seed(IValueDeserSeed::new(self.fpha_type))? { + match self.fpha_type { + Some(fp_type) => arr.push_with_fp_type(v, fp_type), + None => arr.push(v), + } + .map_err(|e| SError::custom(e.to_string()))?; } Ok(arr) } } -struct ObjectVisitor; +struct ObjectVisitor { + fpha_type: Option, +} impl<'de> Visitor<'de> for ObjectVisitor { type Value = IObject; @@ -230,7 +279,8 @@ impl<'de> Visitor<'de> for ObjectVisitor { V: MapAccess<'de>, { let mut obj = IObject::with_capacity(visitor.size_hint().unwrap_or(0)); - while let Some((k, v)) = visitor.next_entry::()? { + while let Some(k) = visitor.next_key::()? 
{ + let v = visitor.next_value_seed(IValueDeserSeed::new(self.fpha_type))?; obj.insert(k, v); } Ok(obj) @@ -999,3 +1049,102 @@ where { T::deserialize(value) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::ArraySliceRef; + use serde::de::DeserializeSeed; + + #[test] + fn test_deserialize_with_f64_fp() { + let json = r#"[1.5, 2.5, 3.5]"#; + let seed = IValueDeserSeed::new(Some(FloatType::F64)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_f32_fp() { + let json = r#"[1.5, 2.5, 3.5]"#; + let seed = IValueDeserSeed::new(Some(FloatType::F32)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_f16_fp() { + let json = r#"[0.5, 1.0, 1.5]"#; + let seed = IValueDeserSeed::new(Some(FloatType::F16)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_bf16_fp() { + let json = r#"[0.5, 1.0, 2.0]"#; + let seed = IValueDeserSeed::new(Some(FloatType::BF16)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::BF16(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_mixed_array_with_fp() { + let json = r#"[1, "string", 3.5]"#; + let seed = 
IValueDeserSeed::new(Some(FloatType::F32)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_deserialize_integer_array_with_fp() { + let json = r#"[1, 2, 3]"#; + let seed = IValueDeserSeed::new(Some(FloatType::F32)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_f16_value_no_fit() { + let json = r#"[0.5, 100000.0, 1.5]"#; + let seed = IValueDeserSeed::new(Some(FloatType::F16)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_deserialize_bf16_value_too_large() { + let json = r#"[1e39, 2e39]"#; + let seed = IValueDeserSeed::new(Some(FloatType::BF16)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_deserialize_f32_value_too_large() { + let json = r#"[1e39, 2e39]"#; + let seed = IValueDeserSeed::new(Some(FloatType::F32)); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } +} diff --git a/src/alloc.rs b/src/error.rs similarity index 56% rename from src/alloc.rs rename to src/error.rs index af0c87d..d5b9f36 100644 --- a/src/alloc.rs +++ b/src/error.rs @@ -2,6 +2,9 @@ use std::error::Error; use std::fmt; +use thiserror::Error; + +use crate::FloatType; /// Error type for fallible allocation /// This error is returned when an allocation fails. 
@@ -16,3 +19,14 @@ impl fmt::Display for AllocError { f.write_str("memory allocation failed") } } + +/// Error type for ijson +#[derive(Error, Debug)] +pub enum IJsonError { + /// Memory allocation failed + #[error("memory allocation failed")] + Alloc(#[from] AllocError), + /// Value out of range for the specified floating-point type + #[error("value out of range for {0}")] + OutOfRange(FloatType), +} diff --git a/src/lib.rs b/src/lib.rs index afcff7b..12850fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,11 +36,11 @@ pub mod unsafe_string; #[cfg(not(feature = "thread_safe"))] pub use unsafe_string::IString; -pub mod alloc; +pub mod error; mod thin; mod value; -pub use array::IArray; +pub use array::{FloatType, IArray}; pub use number::INumber; pub use object::IObject; use std::alloc::Layout; @@ -51,7 +51,7 @@ pub use value::{ mod de; mod ser; -pub use de::from_value; +pub use de::{from_value, IValueDeserSeed}; pub use ser::to_value; /// Trait to implement defrag allocator From bfd678b22cb3f597c06bc03fff5403519a1673b6 Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Mon, 9 Feb 2026 11:42:30 +0200 Subject: [PATCH 2/9] add fallback option --- src/de.rs | 125 +++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 95 insertions(+), 30 deletions(-) diff --git a/src/de.rs b/src/de.rs index 1f1faa9..1eb3681 100644 --- a/src/de.rs +++ b/src/de.rs @@ -8,19 +8,51 @@ use serde::de::{ use serde::{forward_to_deserialize_any, Deserialize, Deserializer}; use serde_json::error::Error; +use crate::error::IJsonError; use crate::{DestructuredRef, FloatType, IArray, INumber, IObject, IString, IValue}; +#[derive(Debug, Clone, Copy)] +pub struct FPHAConfig { + /// Floating point type for homogeneous arrays. + pub fpha_type: FloatType, + /// If `fallback` is true, arrays that don't fit the fpha_type will fall back to regular push. + pub fpha_fallback: bool, +} + +impl FPHAConfig { + /// Creates a new [`FPHAConfig`] with the given floating point type. 
+ pub fn new(fpha_type: FloatType, fpha_fallback: bool) -> Self { + Self { + fpha_type, + fpha_fallback, + } + } + + pub fn new_with_type(fpha_type: FloatType) -> Self { + Self { + fpha_type, + fpha_fallback: false, + } + } + + /// Sets the fallback behavior. + pub fn with_fallback(mut self, fallback: bool) -> Self { + self.fpha_fallback = fallback; + self + } +} + /// Seed for deserializing an [`IValue`]. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct IValueDeserSeed { - /// Optional floating point type enforcment type for homogeneous arrays. - pub fpha_type: Option, + /// Optional FPHA configuration for homogeneous arrays. + pub fpha_config: Option, } impl IValueDeserSeed { /// Creates a new [`IValueDeserSeed`] with the given floating point type enforcment type for homogeneous arrays. - pub fn new(fpha_type: Option) -> Self { - IValueDeserSeed { fpha_type } + pub fn new(fpha_config: Option) -> Self { + IValueDeserSeed { fpha_config } } } @@ -32,7 +64,7 @@ impl<'de> DeserializeSeed<'de> for IValueDeserSeed { D: Deserializer<'de>, { // Pass hint to a custom visitor - deserializer.deserialize_any(ValueVisitor::new(self.fpha_type)) + deserializer.deserialize_any(ValueVisitor::new(self.fpha_config)) } } @@ -68,7 +100,7 @@ impl<'de> Deserialize<'de> for IArray { where D: Deserializer<'de>, { - deserializer.deserialize_seq(ArrayVisitor { fpha_type: None }) + deserializer.deserialize_seq(ArrayVisitor { fpha_config: None }) } } @@ -77,17 +109,17 @@ impl<'de> Deserialize<'de> for IObject { where D: Deserializer<'de>, { - deserializer.deserialize_map(ObjectVisitor { fpha_type: None }) + deserializer.deserialize_map(ObjectVisitor { fpha_config: None }) } } struct ValueVisitor { - fpha_type: Option, + fpha_config: Option, } impl ValueVisitor { - fn new(fpha_type: Option) -> Self { - ValueVisitor { fpha_type } + fn new(fpha_config: Option) -> Self { + ValueVisitor { fpha_config } } } @@ -138,7 +170,7 @@ impl<'de> Visitor<'de> for ValueVisitor { where D: 
Deserializer<'de>, { - IValueDeserSeed::new(self.fpha_type).deserialize(deserializer) + IValueDeserSeed::new(self.fpha_config).deserialize(deserializer) } #[inline] @@ -152,7 +184,7 @@ impl<'de> Visitor<'de> for ValueVisitor { V: SeqAccess<'de>, { ArrayVisitor { - fpha_type: self.fpha_type, + fpha_config: self.fpha_config, } .visit_seq(visitor) .map(Into::into) @@ -163,7 +195,7 @@ impl<'de> Visitor<'de> for ValueVisitor { V: MapAccess<'de>, { ObjectVisitor { - fpha_type: self.fpha_type, + fpha_config: self.fpha_config, } .visit_map(visitor) .map(Into::into) @@ -235,7 +267,7 @@ impl<'de> Visitor<'de> for StringVisitor { } struct ArrayVisitor { - fpha_type: Option, + fpha_config: Option, } impl<'de> Visitor<'de> for ArrayVisitor { @@ -252,10 +284,16 @@ impl<'de> Visitor<'de> for ArrayVisitor { { let mut arr = IArray::with_capacity(visitor.size_hint().unwrap_or(0)) .map_err(|_| SError::custom("Failed to allocate array"))?; - while let Some(v) = visitor.next_element_seed(IValueDeserSeed::new(self.fpha_type))? { - match self.fpha_type { - Some(fp_type) => arr.push_with_fp_type(v, fp_type), - None => arr.push(v), + while let Some(v) = visitor.next_element_seed(IValueDeserSeed::new(self.fpha_config))? { + match self.fpha_config.map(|c| (c.fpha_type, c.fpha_fallback)) { + Some((fp_type, fallback)) => { + arr.push_with_fp_type(v.clone(), fp_type) + .or_else(|_| match self.fpha_config { + Some(c) if fallback => arr.push(v), + _ => Err(IJsonError::OutOfRange(fp_type)), + }) + } + None => arr.push(v).map_err(Into::into), } .map_err(|e| SError::custom(e.to_string()))?; } @@ -264,7 +302,7 @@ impl<'de> Visitor<'de> for ArrayVisitor { } struct ObjectVisitor { - fpha_type: Option, + fpha_config: Option, } impl<'de> Visitor<'de> for ObjectVisitor { @@ -280,7 +318,7 @@ impl<'de> Visitor<'de> for ObjectVisitor { { let mut obj = IObject::with_capacity(visitor.size_hint().unwrap_or(0)); while let Some(k) = visitor.next_key::()? 
{ - let v = visitor.next_value_seed(IValueDeserSeed::new(self.fpha_type))?; + let v = visitor.next_value_seed(IValueDeserSeed::new(self.fpha_config))?; obj.insert(k, v); } Ok(obj) @@ -1059,7 +1097,7 @@ mod tests { #[test] fn test_deserialize_with_f64_fp() { let json = r#"[1.5, 2.5, 3.5]"#; - let seed = IValueDeserSeed::new(Some(FloatType::F64)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F64))); let mut deserializer = serde_json::Deserializer::from_str(json); let value = seed.deserialize(&mut deserializer).unwrap(); @@ -1071,7 +1109,7 @@ mod tests { #[test] fn test_deserialize_with_f32_fp() { let json = r#"[1.5, 2.5, 3.5]"#; - let seed = IValueDeserSeed::new(Some(FloatType::F32)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); let mut deserializer = serde_json::Deserializer::from_str(json); let value = seed.deserialize(&mut deserializer).unwrap(); @@ -1083,7 +1121,7 @@ mod tests { #[test] fn test_deserialize_with_f16_fp() { let json = r#"[0.5, 1.0, 1.5]"#; - let seed = IValueDeserSeed::new(Some(FloatType::F16)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); let mut deserializer = serde_json::Deserializer::from_str(json); let value = seed.deserialize(&mut deserializer).unwrap(); @@ -1095,7 +1133,7 @@ mod tests { #[test] fn test_deserialize_with_bf16_fp() { let json = r#"[0.5, 1.0, 2.0]"#; - let seed = IValueDeserSeed::new(Some(FloatType::BF16)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::BF16))); let mut deserializer = serde_json::Deserializer::from_str(json); let value = seed.deserialize(&mut deserializer).unwrap(); @@ -1107,7 +1145,7 @@ mod tests { #[test] fn test_deserialize_mixed_array_with_fp() { let json = r#"[1, "string", 3.5]"#; - let seed = IValueDeserSeed::new(Some(FloatType::F32)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); let mut deserializer = 
serde_json::Deserializer::from_str(json); let _error = seed.deserialize(&mut deserializer).unwrap_err(); } @@ -1115,7 +1153,7 @@ mod tests { #[test] fn test_deserialize_integer_array_with_fp() { let json = r#"[1, 2, 3]"#; - let seed = IValueDeserSeed::new(Some(FloatType::F32)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); let mut deserializer = serde_json::Deserializer::from_str(json); let value = seed.deserialize(&mut deserializer).unwrap(); @@ -1127,24 +1165,51 @@ mod tests { #[test] fn test_deserialize_f16_value_no_fit() { let json = r#"[0.5, 100000.0, 1.5]"#; - let seed = IValueDeserSeed::new(Some(FloatType::F16)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); let mut deserializer = serde_json::Deserializer::from_str(json); let _error = seed.deserialize(&mut deserializer).unwrap_err(); + + let seed = IValueDeserSeed::new(Some( + FPHAConfig::new_with_type(FloatType::F16).with_fallback(true), + )); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); + assert_eq!(arr.len(), 3); } #[test] fn test_deserialize_bf16_value_too_large() { let json = r#"[1e39, 2e39]"#; - let seed = IValueDeserSeed::new(Some(FloatType::BF16)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::BF16))); let mut deserializer = serde_json::Deserializer::from_str(json); let _error = seed.deserialize(&mut deserializer).unwrap_err(); + + let seed = IValueDeserSeed::new(Some( + FPHAConfig::new_with_type(FloatType::BF16).with_fallback(true), + )); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_))); + assert_eq!(arr.len(), 2); } #[test] fn 
test_deserialize_f32_value_too_large() { let json = r#"[1e39, 2e39]"#; - let seed = IValueDeserSeed::new(Some(FloatType::F32)); + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); let mut deserializer = serde_json::Deserializer::from_str(json); let _error = seed.deserialize(&mut deserializer).unwrap_err(); + + let seed = IValueDeserSeed::new(Some( + FPHAConfig::new_with_type(FloatType::F32).with_fallback(true), + )); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_))); + assert_eq!(arr.len(), 2); } } From 2079a6851e8783ddc66ad7fbbe418df07f591b4b Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Mon, 9 Feb 2026 11:43:53 +0200 Subject: [PATCH 3/9] export FPHAConfig --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 12850fe..d4fabe6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,7 +51,7 @@ pub use value::{ mod de; mod ser; -pub use de::{from_value, IValueDeserSeed}; +pub use de::{from_value, IValueDeserSeed, FPHAConfig}; pub use ser::to_value; /// Trait to implement defrag allocator From cec89cfe4d7dd235679731af50cdfb5a91fe0349 Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Mon, 9 Feb 2026 11:47:29 +0200 Subject: [PATCH 4/9] docs --- src/de.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/de.rs b/src/de.rs index 1eb3681..f6e510b 100644 --- a/src/de.rs +++ b/src/de.rs @@ -12,6 +12,7 @@ use crate::error::IJsonError; use crate::{DestructuredRef, FloatType, IArray, INumber, IObject, IString, IValue}; #[derive(Debug, Clone, Copy)] +/// Configuration for floating point homogeneous arrays. pub struct FPHAConfig { /// Floating point type for homogeneous arrays. 
pub fpha_type: FloatType, @@ -28,6 +29,7 @@ impl FPHAConfig { } } + /// Creates a new [`FPHAConfig`] with the given floating point type. pub fn new_with_type(fpha_type: FloatType) -> Self { Self { fpha_type, From bc60e68506c7ca3aa87d0d3c9f50b5881c6e82aa Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Mon, 9 Feb 2026 11:53:38 +0200 Subject: [PATCH 5/9] fmt --- src/array.rs | 2 +- src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array.rs b/src/array.rs index 4360d6d..6db4d30 100644 --- a/src/array.rs +++ b/src/array.rs @@ -59,7 +59,7 @@ impl Default for ArrayTag { #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum FloatType { /// F16 - F16, + F16 = 1, /// BF16 BF16, /// F32 diff --git a/src/lib.rs b/src/lib.rs index d4fabe6..86e8255 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,7 +51,7 @@ pub use value::{ mod de; mod ser; -pub use de::{from_value, IValueDeserSeed, FPHAConfig}; +pub use de::{from_value, FPHAConfig, IValueDeserSeed}; pub use ser::to_value; /// Trait to implement defrag allocator From 931d520d5d32f19288f09263136293ee5c7a75dd Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Mon, 9 Feb 2026 12:18:18 +0200 Subject: [PATCH 6/9] comments --- src/array.rs | 14 ++++++++++++++ src/de.rs | 18 +++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/array.rs b/src/array.rs index 6db4d30..8768418 100644 --- a/src/array.rs +++ b/src/array.rs @@ -79,6 +79,20 @@ impl fmt::Display for FloatType { } } +impl TryFrom<u8> for FloatType { + type Error = (); + + fn try_from(value: u8) -> Result<Self, Self::Error> { + match value { + 1 => Ok(FloatType::F16), + 2 => Ok(FloatType::BF16), + 3 => Ok(FloatType::F32), + 4 => Ok(FloatType::F64), + _ => Err(()), + } + } +} + impl From<FloatType> for ArrayTag { fn from(fp_type: FloatType) -> Self { match fp_type { diff --git a/src/de.rs b/src/de.rs index f6e510b..9b3cf37 100644 --- a/src/de.rs +++ b/src/de.rs @@ -8,7 +8,6 @@ use serde::de::{ use
serde::{forward_to_deserialize_any, Deserialize, Deserializer}; use serde_json::error::Error; -use crate::error::IJsonError; use crate::{DestructuredRef, FloatType, IArray, INumber, IObject, IString, IValue}; #[derive(Debug, Clone, Copy)] @@ -287,14 +286,15 @@ impl<'de> Visitor<'de> for ArrayVisitor { let mut arr = IArray::with_capacity(visitor.size_hint().unwrap_or(0)) .map_err(|_| SError::custom("Failed to allocate array"))?; while let Some(v) = visitor.next_element_seed(IValueDeserSeed::new(self.fpha_config))? { - match self.fpha_config.map(|c| (c.fpha_type, c.fpha_fallback)) { - Some((fp_type, fallback)) => { - arr.push_with_fp_type(v.clone(), fp_type) - .or_else(|_| match self.fpha_config { - Some(c) if fallback => arr.push(v), - _ => Err(IJsonError::OutOfRange(fp_type)), - }) - } + // Match `Some(..)` twice so the value is cloned only when fallback is enabled + match self.fpha_config { + Some(FPHAConfig { + fpha_type, + fpha_fallback: true, + }) => arr + .push_with_fp_type(v.clone(), fpha_type) + .or_else(|_| arr.push(v).map_err(Into::into)), + Some(FPHAConfig { fpha_type, ..
}) => arr.push_with_fp_type(v, fpha_type), None => arr.push(v).map_err(Into::into), } .map_err(|e| SError::custom(e.to_string()))?; From 3ba050f4ff630dae593b3f6aa83f6bd6e04132c3 Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Mon, 9 Feb 2026 12:27:01 +0200 Subject: [PATCH 7/9] lower fuzz time --- .github/actions/fuzz_tests/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/fuzz_tests/action.yml b/.github/actions/fuzz_tests/action.yml index 95c7b09..16774d5 100644 --- a/.github/actions/fuzz_tests/action.yml +++ b/.github/actions/fuzz_tests/action.yml @@ -8,7 +8,7 @@ inputs: fuzz_time: description: 'Maximum time in seconds to run fuzzing' required: false - default: '180' + default: '120' cargo_fuzz_version: description: 'Version of cargo-fuzz to install' required: false From 8cd9ddffa9b87463dcc8d3977a77db6d4a1f58c3 Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Mon, 9 Feb 2026 12:43:47 +0200 Subject: [PATCH 8/9] bring back old code --- src/array.rs | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/src/array.rs b/src/array.rs index 8768418..07f422a 100644 --- a/src/array.rs +++ b/src/array.rs @@ -232,25 +232,14 @@ impl ArrayTag { /// Determines the ArrayTag for an IValue if it represents a primitive type /// Prefers signed types over unsigned types for positive values to be more conservative fn from_ivalue(value: &IValue) -> ArrayTag { - Self::from_ivalue_with_hint(value, None) - } - - /// Determines the ArrayTag for an IValue, using the provided fp_type for floating-point types. - /// - /// When `fp_type` is `Some`, uses the hinted type directly for floating-point values. 
- fn from_ivalue_with_hint(value: &IValue, fp_type: Option<FloatType>) -> ArrayTag { use ArrayTag::*; if let Some(num) = value.as_number() { if num.has_decimal_point() { - fp_type.map(ArrayTag::from).unwrap_or_else(|| { - num.to_f16() - .map(|_| F16) - .or_else(|| num.to_bf16().map(|_| BF16)) - .or_else(|| num.to_f32().map(|_| F32)) - .or_else(|| num.to_f64().map(|_| F64)) - // Safety: We know the value is a decimal number, and f64 can represent any JSON number - .unwrap_or_else(|| unsafe { std::hint::unreachable_unchecked() }) - }) + num.to_f16() + .map(|_| F16) + .or_else(|| num.to_bf16().map(|_| BF16)) + .or_else(|| num.to_f32().map(|_| F32)) + .or_else(|| num.to_f64().map(|_| F64)) } else { num.to_i8() .map(|_| I8) @@ -261,9 +250,9 @@ impl ArrayTag { .or_else(|| num.to_u32().map(|_| U32)) .or_else(|| num.to_i64().map(|_| I64)) .or_else(|| num.to_u64().map(|_| U64)) - // Safety: We know the value is a number, and we've checked all possible number types - .unwrap_or_else(|| unsafe { std::hint::unreachable_unchecked() }) } + // Safety: We know the value is a number, and we've checked all possible number types + .unwrap_or_else(|| unsafe { std::hint::unreachable_unchecked() }) } else { Heterogeneous } From 3cfee7d014a39a3b71d5817ab82dfebc07ad8710 Mon Sep 17 00:00:00 2001 From: avivdavid23 Date: Mon, 9 Feb 2026 15:10:47 +0200 Subject: [PATCH 9/9] change to lossy push --- src/array.rs | 37 ++++++++++++++++++++++++--------- src/de.rs | 58 ++++++++++++++++++++++++++-------------------------- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/src/array.rs b/src/array.rs index 07f422a..8d725c8 100644 --- a/src/array.rs +++ b/src/array.rs @@ -611,6 +611,26 @@ trait HeaderMut<'a>: ThinMutExt<'a, Header> { self.set_len(index + 1); } + // Safety: Space must already be allocated for the item, + // and the item must be a number. The array type must be a floating-point type.
+ unsafe fn push_lossy(&mut self, item: IValue) { + use ArrayTag::*; + let index = self.len(); + + macro_rules! push_lossy_impl { + ($(($tag:ident, $ty:ty)),*) => { + match self.type_tag() { + $($tag => self.reborrow().raw_array_ptr_mut().cast::<$ty>().add(index).write( + paste::paste!(item.[<to_ $ty _lossy>]()).unwrap()),)* + _ => unreachable!(), + } + } + } + + push_lossy_impl!((F16, f16), (BF16, bf16), (F32, f32), (F64, f64)); + self.set_len(index + 1); + } + fn pop(&mut self) -> Option<IValue> { if self.len() == 0 { None @@ -1130,10 +1150,7 @@ impl IArray { } } - /// Pushes a new item onto the back of the array with a specific floating-point type. - /// - /// If the item cannot be represented in the specified floating-point type, - /// returns an error. + /// Pushes a new item onto the back of the array with a specific floating-point type, potentially losing precision. pub(crate) fn push_with_fp_type( &mut self, item: impl Into<IValue>, @@ -1144,10 +1161,10 @@ impl IArray { let len = self.len(); let item = item.into(); let can_fit = || match fp_type { - FloatType::F16 => item.to_f16().is_some(), - FloatType::BF16 => item.to_bf16().is_some(), - FloatType::F32 => item.to_f32().is_some(), - FloatType::F64 => item.to_f64().is_some(), + FloatType::F16 => item.to_f16_lossy().map_or(false, |v| v.is_finite()), + FloatType::BF16 => item.to_bf16_lossy().map_or(false, |v| v.is_finite()), + FloatType::F32 => item.to_f32_lossy().map_or(false, |v| v.is_finite()), + FloatType::F64 => item.to_f64_lossy().map_or(false, |v| v.is_finite()), }; if (desired_tag != current_tag && len > 0) || !can_fit() { @@ -1166,7 +1183,7 @@ impl IArray { self.reserve(1)?; unsafe { - self.header_mut().push(item); + self.header_mut().push_lossy(item); } Ok(()) } @@ -3312,7 +3329,7 @@ mod tests { } #[test] - fn test_push_with_fp_type_overflow() { + fn test_push_with_fp_type_overflow_rejected() { let mut arr = IArray::new(); arr.push_with_fp_type(IValue::from(1.5), FloatType::F16) .unwrap(); diff --git a/src/de.rs b/src/de.rs index
9b3cf37..c7d2926 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1165,53 +1165,53 @@ mod tests { } #[test] - fn test_deserialize_f16_value_no_fit() { + fn test_deserialize_f16_value_overflow_rejected() { let json = r#"[0.5, 100000.0, 1.5]"#; let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); let mut deserializer = serde_json::Deserializer::from_str(json); let _error = seed.deserialize(&mut deserializer).unwrap_err(); - - let seed = IValueDeserSeed::new(Some( - FPHAConfig::new_with_type(FloatType::F16).with_fallback(true), - )); - let mut deserializer = serde_json::Deserializer::from_str(json); - let value = seed.deserialize(&mut deserializer).unwrap(); - let arr = value.as_array().unwrap(); - assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); - assert_eq!(arr.len(), 3); } #[test] - fn test_deserialize_bf16_value_too_large() { + fn test_deserialize_bf16_value_overflow_rejected() { let json = r#"[1e39, 2e39]"#; let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::BF16))); let mut deserializer = serde_json::Deserializer::from_str(json); let _error = seed.deserialize(&mut deserializer).unwrap_err(); - - let seed = IValueDeserSeed::new(Some( - FPHAConfig::new_with_type(FloatType::BF16).with_fallback(true), - )); - let mut deserializer = serde_json::Deserializer::from_str(json); - let value = seed.deserialize(&mut deserializer).unwrap(); - let arr = value.as_array().unwrap(); - assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_))); - assert_eq!(arr.len(), 2); } #[test] - fn test_deserialize_f32_value_too_large() { + fn test_deserialize_f32_value_overflow_rejected() { let json = r#"[1e39, 2e39]"#; let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); let mut deserializer = serde_json::Deserializer::from_str(json); let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } - let seed = IValueDeserSeed::new(Some( - FPHAConfig::new_with_type(FloatType::F32).with_fallback(true), - )); 
- let mut deserializer = serde_json::Deserializer::from_str(json); - let value = seed.deserialize(&mut deserializer).unwrap(); - let arr = value.as_array().unwrap(); - assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_))); - assert_eq!(arr.len(), 2); + #[test] + fn test_ser_deser_roundtrip_preserves_type() { + let json = r#"[0.2, 1.0, 1.2]"#; + + for fp_type in [FloatType::F16, FloatType::BF16, FloatType::F32] { + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type))); + let mut de = serde_json::Deserializer::from_str(json); + let original = seed.deserialize(&mut de).unwrap(); + + let serialized = serde_json::to_string(&original).unwrap(); + + let reload_seed = + IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type).with_fallback(true))); + let mut de = serde_json::Deserializer::from_str(&serialized); + let roundtripped = reload_seed.deserialize(&mut de).unwrap(); + + let arr = roundtripped.as_array().unwrap(); + assert_eq!(arr.len(), 3); + let roundtrip_tag = arr.as_slice().type_tag(); + assert_eq!( + roundtrip_tag, + fp_type.into(), + "roundtrip should preserve {fp_type}" + ); + } } }