diff --git a/.github/actions/fuzz_tests/action.yml b/.github/actions/fuzz_tests/action.yml index 95c7b09..16774d5 100644 --- a/.github/actions/fuzz_tests/action.yml +++ b/.github/actions/fuzz_tests/action.yml @@ -8,7 +8,7 @@ inputs: fuzz_time: description: 'Maximum time in seconds to run fuzzing' required: false - default: '180' + default: '120' cargo_fuzz_version: description: 'Version of cargo-fuzz to install' required: false diff --git a/Cargo.toml b/Cargo.toml index 9332c82..ed0021a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ serde_json = { workspace = true } ctor = { version = "0.1.16", optional = true } paste = "1.0.15" half = "2.0.0" +thiserror = "2.0.18" [dev-dependencies] mockalloc = "0.1.2" diff --git a/src/array.rs b/src/array.rs index 2cc0d70..8d725c8 100644 --- a/src/array.rs +++ b/src/array.rs @@ -9,8 +9,9 @@ use std::iter::FromIterator; use std::ops::{Index, IndexMut}; use std::slice::{from_raw_parts, from_raw_parts_mut, SliceIndex}; +use crate::error::IJsonError; use crate::{ - alloc::AllocError, + error::AllocError, thin::{ThinMut, ThinMutExt, ThinRef, ThinRefExt}, value::TypeTag, Defrag, DefragAllocator, IValue, @@ -54,6 +55,55 @@ impl Default for ArrayTag { } } +/// Enum representing different types of floating-point types +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum FloatType { + /// F16 + F16 = 1, + /// BF16 + BF16, + /// F32 + F32, + /// F64 + F64, +} + +impl fmt::Display for FloatType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FloatType::F16 => write!(f, "F16"), + FloatType::BF16 => write!(f, "BF16"), + FloatType::F32 => write!(f, "F32"), + FloatType::F64 => write!(f, "F64"), + } + } +} + +impl TryFrom for FloatType { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(FloatType::F16), + 2 => Ok(FloatType::BF16), + 3 => Ok(FloatType::F32), + 4 => Ok(FloatType::F64), + _ => Err(()), + } + } +} + +impl From for ArrayTag { + fn from(fp_type: FloatType) -> Self { + match fp_type { + FloatType::F16 => ArrayTag::F16, + FloatType::BF16 => ArrayTag::BF16, + FloatType::F32 => ArrayTag::F32, + FloatType::F64 => ArrayTag::F64, + } + } +} + impl ArrayTag { fn from_type() -> Self { use ArrayTag::*; @@ -401,11 +451,11 @@ impl Header { const TAG_MASK: u64 = 0xF; const TAG_SHIFT: u64 = 60; - const fn new(len: usize, cap: usize, tag: ArrayTag) -> Result { + const fn new(len: usize, cap: usize, tag: ArrayTag) -> Result { // assert!(len <= Self::LEN_MASK as usize, "Length exceeds 30-bit limit"); // assert!(cap <= Self::CAP_MASK as usize, "Capacity exceeds 30-bit limit"); if len > Self::LEN_MASK as usize || cap > Self::CAP_MASK as usize { - return Err(AllocError); + return Err(IJsonError::Alloc(AllocError)); } let packed = ((len as u64) & Self::LEN_MASK) << Self::LEN_SHIFT @@ -561,6 +611,26 @@ trait HeaderMut<'a>: ThinMutExt<'a, Header> { self.set_len(index + 1); } + // Safety: Space must already be allocated for the item, + // and the item must be a number. The array type must be a floating-point type. + unsafe fn push_lossy(&mut self, item: IValue) { + use ArrayTag::*; + let index = self.len(); + + macro_rules! push_lossy_impl { + ($(($tag:ident, $ty:ty)),*) => { + match self.type_tag() { + $($tag => self.reborrow().raw_array_ptr_mut().cast::<$ty>().add(index).write( + paste::paste!(item.[]()).unwrap()),)* + _ => unreachable!(), + } + } + } + + push_lossy_impl!((F16, f16), (BF16, bf16), (F32, f32), (F64, f64)); + self.set_len(index + 1); + } + fn pop(&mut self) -> Option { if self.len() == 0 { None @@ -670,7 +740,7 @@ impl IArray { .pad_to_align()) } - fn alloc(cap: usize, tag: ArrayTag) -> Result<*mut Header, AllocError> { + fn alloc(cap: usize, tag: ArrayTag) -> Result<*mut Header, IJsonError> { unsafe { let ptr = alloc(Self::layout(cap, tag).map_err(|_| AllocError)?).cast::
(); ptr.write(Header::new(0, cap, tag)?); @@ -678,7 +748,7 @@ impl IArray { } } - fn realloc(ptr: *mut Header, new_cap: usize) -> Result<*mut Header, AllocError> { + fn realloc(ptr: *mut Header, new_cap: usize) -> Result<*mut Header, IJsonError> { unsafe { let tag = (*ptr).type_tag(); let old_layout = Self::layout((*ptr).cap(), tag).map_err(|_| AllocError)?; @@ -706,13 +776,13 @@ impl IArray { /// Constructs a new `IArray` with the specified capacity. At least that many items /// can be added to the array without reallocating. #[must_use] - pub fn with_capacity(cap: usize) -> Result { + pub fn with_capacity(cap: usize) -> Result { Self::with_capacity_and_tag(cap, ArrayTag::Heterogeneous) } /// Constructs a new `IArray` with the specified capacity and array type. #[must_use] - fn with_capacity_and_tag(cap: usize, tag: ArrayTag) -> Result { + fn with_capacity_and_tag(cap: usize, tag: ArrayTag) -> Result { if cap == 0 { Ok(Self::new()) } else { @@ -743,7 +813,7 @@ impl IArray { /// Converts this array to a new type, promoting all existing elements. /// This is used for automatic type promotion when incompatible types are added. - fn promote_to_type(&mut self, new_tag: ArrayTag) -> Result<(), AllocError> { + fn promote_to_type(&mut self, new_tag: ArrayTag) -> Result<(), IJsonError> { if self.is_static() || self.header().type_tag() == new_tag { return Ok(()); } @@ -898,7 +968,7 @@ impl IArray { self.header_mut().as_mut_slice_unchecked::() } - fn resize_internal(&mut self, cap: usize) -> Result<(), AllocError> { + fn resize_internal(&mut self, cap: usize) -> Result<(), IJsonError> { if self.is_static() || cap == 0 { let tag = if self.is_static() { ArrayTag::Heterogeneous @@ -916,7 +986,7 @@ impl IArray { } /// Reserves space for at least this many additional items. - pub fn reserve(&mut self, additional: usize) -> Result<(), AllocError> { + pub fn reserve(&mut self, additional: usize) -> Result<(), IJsonError> { let hd = self.header(); let current_capacity = hd.cap(); let desired_capacity = hd.len().checked_add(additional).ok_or(AllocError)?; @@ -956,7 +1026,7 @@ impl IArray { /// on or after this index will be shifted down to accomodate this. For large /// arrays, insertions near the front will be slow as it will require shifting /// a large number of items. - pub fn insert(&mut self, index: usize, item: impl Into) -> Result<(), AllocError> { + pub fn insert(&mut self, index: usize, item: impl Into) -> Result<(), IJsonError> { let item = item.into(); let current_tag = self.header().type_tag(); let len = self.len(); @@ -1080,8 +1150,46 @@ impl IArray { } } + /// Pushes a new item onto the back of the array with a specific floating-point type, potentially losing precision. + pub(crate) fn push_with_fp_type( + &mut self, + item: impl Into, + fp_type: FloatType, + ) -> Result<(), IJsonError> { + let desired_tag = fp_type.into(); + let current_tag = self.header().type_tag(); + let len = self.len(); + let item = item.into(); + let can_fit = || match fp_type { + FloatType::F16 => item.to_f16_lossy().map_or(false, |v| v.is_finite()), + FloatType::BF16 => item.to_bf16_lossy().map_or(false, |v| v.is_finite()), + FloatType::F32 => item.to_f32_lossy().map_or(false, |v| v.is_finite()), + FloatType::F64 => item.to_f64_lossy().map_or(false, |v| v.is_finite()), + }; + + if (desired_tag != current_tag && len > 0) || !can_fit() { + return Err(IJsonError::OutOfRange(fp_type)); + } + + // We can fit the item into the array, so we can push it directly + + if len == 0 { + if self.is_static() { + *self = IArray::with_capacity_and_tag(4, desired_tag)?; + } else { + self.promote_to_type(desired_tag)?; + } + } + + self.reserve(1)?; + unsafe { + self.header_mut().push_lossy(item); + } + Ok(()) + } + /// Pushes a new item onto the back of the array. - pub fn push(&mut self, item: impl Into) -> Result<(), AllocError> { + pub fn push(&mut self, item: impl Into) -> Result<(), IJsonError> { let item = item.into(); let current_tag = self.header().type_tag(); let len = self.len(); @@ -1425,11 +1533,11 @@ pub trait TryExtend { /// Returns an `AllocError` if allocation fails. /// # Errors /// Returns an `AllocError` if memory allocation fails during the extension. - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError>; + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError>; } impl + private::Sealed> TryExtend for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let iter = iter.into_iter(); self.reserve(iter.size_hint().0)?; for v in iter { @@ -1442,7 +1550,7 @@ impl + private::Sealed> TryExtend for IArray { macro_rules! extend_impl_int { ($($ty:ty),*) => { $(impl TryExtend<$ty> for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let expected_tag = ArrayTag::from_type::<$ty>(); let iter = iter.into_iter(); let size_hint = iter.size_hint().0; @@ -1494,7 +1602,7 @@ macro_rules! extend_impl_int { macro_rules! extend_impl_float { ($($ty:ty),*) => { $(impl TryExtend<$ty> for IArray { - fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), AllocError> { + fn try_extend(&mut self, iter: impl IntoIterator) -> Result<(), IJsonError> { let expected_tag = ArrayTag::from_type::<$ty>(); let iter = iter.into_iter(); let size_hint = iter.size_hint().0; @@ -1564,13 +1672,13 @@ pub trait TryFromIterator { /// Returns an `AllocError` if allocation fails. /// # Errors /// Returns `AllocError` if memory allocation fails during the construction. - fn try_from_iter>(iter: U) -> Result + fn try_from_iter>(iter: U) -> Result where Self: Sized; } impl + private::Sealed> TryFromIterator for IArray { - fn try_from_iter>(iter: T) -> Result { + fn try_from_iter>(iter: T) -> Result { let mut res = IArray::new(); res.try_extend(iter)?; Ok(res) @@ -1580,7 +1688,7 @@ impl + private::Sealed> TryFromIterator for IArray { macro_rules! from_iter_impl { ($($ty:ty),*) => { $(impl TryFromIterator<$ty> for IArray { - fn try_from_iter>(iter: T) -> Result { + fn try_from_iter>(iter: T) -> Result { let iter = iter.into_iter(); let mut res = IArray::with_capacity_and_tag(iter.size_hint().0, ArrayTag::from_type::<$ty>())?; res.try_extend(iter)?; @@ -1599,13 +1707,13 @@ pub trait TryCollect: Iterator + Sized { /// Returns an `AllocError` if allocation fails. /// # Errors /// Returns `AllocError` if memory allocation fails during the collection. - fn try_collect(self) -> Result + fn try_collect(self) -> Result where B: TryFromIterator; } impl> TryCollect for I { - fn try_collect(self) -> Result + fn try_collect(self) -> Result where B: TryFromIterator, { @@ -1614,7 +1722,7 @@ impl> TryCollect for I { } impl + private::Sealed> TryFrom> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: Vec) -> Result { let mut res = IArray::with_capacity(other.len())?; res.try_extend(other.into_iter().map(Into::into))?; @@ -1623,7 +1731,7 @@ impl + private::Sealed> TryFrom> for IArray { } impl + Clone + private::Sealed> TryFrom<&[T]> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: &[T]) -> Result { let mut res = IArray::with_capacity(other.len())?; res.try_extend(other.iter().cloned().map(Into::into))?; @@ -1634,7 +1742,7 @@ impl + Clone + private::Sealed> TryFrom<&[T]> for IArray { macro_rules! from_slice_impl { ($($ty:ty),*) => {$( impl TryFrom> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: Vec<$ty>) -> Result { let mut res = IArray::with_capacity_and_tag(other.len(), ArrayTag::from_type::<$ty>())?; TryExtend::<$ty>::try_extend(&mut res, other.into_iter().map(Into::into))?; @@ -1642,7 +1750,7 @@ macro_rules! from_slice_impl { } } impl TryFrom<&[$ty]> for IArray { - type Error = AllocError; + type Error = IJsonError; fn try_from(other: &[$ty]) -> Result { let mut res = IArray::with_capacity_and_tag(other.len(), ArrayTag::from_type::<$ty>())?; TryExtend::<$ty>::try_extend(&mut res, other.iter().cloned().map(Into::into))?; @@ -3207,4 +3315,28 @@ mod tests { } } } + + #[test] + fn test_push_with_fp_type_creates_typed_array() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(1.5), FloatType::F16) + .unwrap(); + arr.push_with_fp_type(IValue::from(2.5), FloatType::F16) + .unwrap(); + + assert_eq!(arr.len(), 2); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + } + + #[test] + fn test_push_with_fp_type_overflow_rejected() { + let mut arr = IArray::new(); + arr.push_with_fp_type(IValue::from(1.5), FloatType::F16) + .unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + arr.push_with_fp_type(IValue::from(100000.0), FloatType::F16) + .unwrap_err(); + assert_eq!(arr.len(), 1); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + } } diff --git a/src/de.rs b/src/de.rs index 4eabe26..c7d2926 100644 --- a/src/de.rs +++ b/src/de.rs @@ -8,14 +8,73 @@ use serde::de::{ use serde::{forward_to_deserialize_any, Deserialize, Deserializer}; use serde_json::error::Error; -use crate::{DestructuredRef, IArray, INumber, IObject, IString, IValue}; +use crate::{DestructuredRef, FloatType, IArray, INumber, IObject, IString, IValue}; + +#[derive(Debug, Clone, Copy)] +/// Configuration for floating point homogeneous arrays. +pub struct FPHAConfig { + /// Floating point type for homogeneous arrays. + pub fpha_type: FloatType, + /// If `fallback` is true, arrays that don't fit the fpha_type will fall back to regular push. + pub fpha_fallback: bool, +} + +impl FPHAConfig { + /// Creates a new [`FPHAConfig`] with the given floating point type. + pub fn new(fpha_type: FloatType, fpha_fallback: bool) -> Self { + Self { + fpha_type, + fpha_fallback, + } + } + + /// Creates a new [`FPHAConfig`] with the given floating point type and fallback behavior. + pub fn new_with_type(fpha_type: FloatType) -> Self { + Self { + fpha_type, + fpha_fallback: false, + } + } + + /// Sets the fallback behavior. + pub fn with_fallback(mut self, fallback: bool) -> Self { + self.fpha_fallback = fallback; + self + } +} + +/// Seed for deserializing an [`IValue`]. +#[derive(Debug, Default)] +pub struct IValueDeserSeed { + /// Optional FPHA configuration for homogeneous arrays. + pub fpha_config: Option, +} + +impl IValueDeserSeed { + /// Creates a new [`IValueDeserSeed`] with the given floating point type enforcment type for homogeneous arrays. + pub fn new(fpha_config: Option) -> Self { + IValueDeserSeed { fpha_config } + } +} + +impl<'de> DeserializeSeed<'de> for IValueDeserSeed { + type Value = IValue; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Pass hint to a custom visitor + deserializer.deserialize_any(ValueVisitor::new(self.fpha_config)) + } +} impl<'de> Deserialize<'de> for IValue { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - deserializer.deserialize_any(ValueVisitor) + deserializer.deserialize_any(ValueVisitor::new(None)) } } @@ -42,7 +101,7 @@ impl<'de> Deserialize<'de> for IArray { where D: Deserializer<'de>, { - deserializer.deserialize_seq(ArrayVisitor) + deserializer.deserialize_seq(ArrayVisitor { fpha_config: None }) } } @@ -51,11 +110,19 @@ impl<'de> Deserialize<'de> for IObject { where D: Deserializer<'de>, { - deserializer.deserialize_map(ObjectVisitor) + deserializer.deserialize_map(ObjectVisitor { fpha_config: None }) } } -struct ValueVisitor; +struct ValueVisitor { + fpha_config: Option, +} + +impl ValueVisitor { + fn new(fpha_config: Option) -> Self { + ValueVisitor { fpha_config } + } +} impl<'de> Visitor<'de> for ValueVisitor { type Value = IValue; @@ -104,7 +171,7 @@ impl<'de> Visitor<'de> for ValueVisitor { where D: Deserializer<'de>, { - Deserialize::deserialize(deserializer) + IValueDeserSeed::new(self.fpha_config).deserialize(deserializer) } #[inline] @@ -117,14 +184,22 @@ impl<'de> Visitor<'de> for ValueVisitor { where V: SeqAccess<'de>, { - ArrayVisitor.visit_seq(visitor).map(Into::into) + ArrayVisitor { + fpha_config: self.fpha_config, + } + .visit_seq(visitor) + .map(Into::into) } fn visit_map(self, visitor: V) -> Result where V: MapAccess<'de>, { - ObjectVisitor.visit_map(visitor).map(Into::into) + ObjectVisitor { + fpha_config: self.fpha_config, + } + .visit_map(visitor) + .map(Into::into) } } @@ -192,7 +267,9 @@ impl<'de> Visitor<'de> for StringVisitor { } } -struct ArrayVisitor; +struct ArrayVisitor { + fpha_config: Option, +} impl<'de> Visitor<'de> for ArrayVisitor { type Value = IArray; @@ -208,15 +285,27 @@ impl<'de> Visitor<'de> for ArrayVisitor { { let mut arr = IArray::with_capacity(visitor.size_hint().unwrap_or(0)) .map_err(|_| SError::custom("Failed to allocate array"))?; - while let Some(v) = visitor.next_element::()? { - arr.push(v) - .map_err(|_| SError::custom("Failed to push to array"))?; + while let Some(v) = visitor.next_element_seed(IValueDeserSeed::new(self.fpha_config))? { + // Matching Some(..) twice, to avoind cloning the value :/ + match self.fpha_config { + Some(FPHAConfig { + fpha_type, + fpha_fallback: true, + }) => arr + .push_with_fp_type(v.clone(), fpha_type) + .or_else(|_| arr.push(v).map_err(Into::into)), + Some(FPHAConfig { fpha_type, .. }) => arr.push_with_fp_type(v, fpha_type), + None => arr.push(v).map_err(Into::into), + } + .map_err(|e| SError::custom(e.to_string()))?; } Ok(arr) } } -struct ObjectVisitor; +struct ObjectVisitor { + fpha_config: Option, +} impl<'de> Visitor<'de> for ObjectVisitor { type Value = IObject; @@ -230,7 +319,8 @@ impl<'de> Visitor<'de> for ObjectVisitor { V: MapAccess<'de>, { let mut obj = IObject::with_capacity(visitor.size_hint().unwrap_or(0)); - while let Some((k, v)) = visitor.next_entry::()? { + while let Some(k) = visitor.next_key::()? { + let v = visitor.next_value_seed(IValueDeserSeed::new(self.fpha_config))?; obj.insert(k, v); } Ok(obj) @@ -999,3 +1089,129 @@ where { T::deserialize(value) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::ArraySliceRef; + use serde::de::DeserializeSeed; + + #[test] + fn test_deserialize_with_f64_fp() { + let json = r#"[1.5, 2.5, 3.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F64))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_f32_fp() { + let json = r#"[1.5, 2.5, 3.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_f16_fp() { + let json = r#"[0.5, 1.0, 1.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F16(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_with_bf16_fp() { + let json = r#"[0.5, 1.0, 2.0]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::BF16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::BF16(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_mixed_array_with_fp() { + let json = r#"[1, "string", 3.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_deserialize_integer_array_with_fp() { + let json = r#"[1, 2, 3]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let value = seed.deserialize(&mut deserializer).unwrap(); + + let arr = value.as_array().unwrap(); + assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_))); + assert_eq!(arr.len(), 3); + } + + #[test] + fn test_deserialize_f16_value_overflow_rejected() { + let json = r#"[0.5, 100000.0, 1.5]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_deserialize_bf16_value_overflow_rejected() { + let json = r#"[1e39, 2e39]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::BF16))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_deserialize_f32_value_overflow_rejected() { + let json = r#"[1e39, 2e39]"#; + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32))); + let mut deserializer = serde_json::Deserializer::from_str(json); + let _error = seed.deserialize(&mut deserializer).unwrap_err(); + } + + #[test] + fn test_ser_deser_roundtrip_preserves_type() { + let json = r#"[0.2, 1.0, 1.2]"#; + + for fp_type in [FloatType::F16, FloatType::BF16, FloatType::F32] { + let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type))); + let mut de = serde_json::Deserializer::from_str(json); + let original = seed.deserialize(&mut de).unwrap(); + + let serialized = serde_json::to_string(&original).unwrap(); + + let reload_seed = + IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type).with_fallback(true))); + let mut de = serde_json::Deserializer::from_str(&serialized); + let roundtripped = reload_seed.deserialize(&mut de).unwrap(); + + let arr = roundtripped.as_array().unwrap(); + assert_eq!(arr.len(), 3); + let roundtrip_tag = arr.as_slice().type_tag(); + assert_eq!( + roundtrip_tag, + fp_type.into(), + "roundtrip should preserve {fp_type}" + ); + } + } +} diff --git a/src/alloc.rs b/src/error.rs similarity index 56% rename from src/alloc.rs rename to src/error.rs index af0c87d..d5b9f36 100644 --- a/src/alloc.rs +++ b/src/error.rs @@ -2,6 +2,9 @@ use std::error::Error; use std::fmt; +use thiserror::Error; + +use crate::FloatType; /// Error type for fallible allocation /// This error is returned when an allocation fails. @@ -16,3 +19,14 @@ impl fmt::Display for AllocError { f.write_str("memory allocation failed") } } + +/// Error type for ijson +#[derive(Error, Debug)] +pub enum IJsonError { + /// Memory allocation failed + #[error("memory allocation failed")] + Alloc(#[from] AllocError), + /// Value out of range for the specified floating-point type + #[error("value out of range for {0}")] + OutOfRange(FloatType), +} diff --git a/src/lib.rs b/src/lib.rs index afcff7b..86e8255 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,11 +36,11 @@ pub mod unsafe_string; #[cfg(not(feature = "thread_safe"))] pub use unsafe_string::IString; -pub mod alloc; +pub mod error; mod thin; mod value; -pub use array::IArray; +pub use array::{FloatType, IArray}; pub use number::INumber; pub use object::IObject; use std::alloc::Layout; @@ -51,7 +51,7 @@ pub use value::{ mod de; mod ser; -pub use de::from_value; +pub use de::{from_value, FPHAConfig, IValueDeserSeed}; pub use ser::to_value; /// Trait to implement defrag allocator