Skip to content

Commit 65e7263

Browse files
authored
Support precision casting for DScalar (#5937)
This is a missing piece to support #5863 --------- Signed-off-by: Adam Gutglick <[email protected]>
1 parent 5cd8b29 commit 65e7263

File tree

1 file changed

+250
-1
lines changed

1 file changed

+250
-1
lines changed

vortex-compute/src/cast/dvector.rs

Lines changed: 250 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ use vortex_dtype::DType;
66
use vortex_dtype::DecimalType;
77
use vortex_dtype::NativeDecimalType;
88
use vortex_dtype::PrecisionScale;
9+
use vortex_dtype::i256;
910
use vortex_dtype::match_each_decimal_value_type;
1011
use vortex_error::VortexExpect;
1112
use vortex_error::VortexResult;
1213
use vortex_error::vortex_bail;
14+
use vortex_error::vortex_err;
1315
use vortex_vector::Scalar;
1416
use vortex_vector::ScalarOps;
1517
use vortex_vector::Vector;
@@ -113,7 +115,59 @@ impl<D: NativeDecimalType> Cast for DScalar<D> {
113115
{
114116
Ok(self.clone().into())
115117
}
116-
// TODO(connor): cast to different precision/scale
118+
// TODO(connor): cast to different scale
119+
DType::Decimal(ddt, n)
120+
if ddt.scale() == self.scale() && (n.is_nullable() || self.is_valid()) =>
121+
{
122+
let p = ddt.precision();
123+
if p <= <i8 as NativeDecimalType>::MAX_PRECISION {
124+
DScalar::maybe_new(
125+
PrecisionScale::<i8>::new(ddt.precision(), ddt.scale()),
126+
self.value().and_then(|v| v.to_i8()),
127+
)
128+
.map(|ds| ds.into())
129+
.ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}"))
130+
} else if p <= <i16 as NativeDecimalType>::MAX_PRECISION {
131+
DScalar::maybe_new(
132+
PrecisionScale::<i16>::new(ddt.precision(), ddt.scale()),
133+
self.value().and_then(|v| v.to_i16()),
134+
)
135+
.map(|ds| ds.into())
136+
.ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}"))
137+
} else if p <= <i32 as NativeDecimalType>::MAX_PRECISION {
138+
DScalar::maybe_new(
139+
PrecisionScale::<i32>::new(ddt.precision(), ddt.scale()),
140+
self.value().and_then(|v| v.to_i32()),
141+
)
142+
.map(|ds| ds.into())
143+
.ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}"))
144+
} else if p <= <i64 as NativeDecimalType>::MAX_PRECISION {
145+
DScalar::maybe_new(
146+
PrecisionScale::<i64>::new(ddt.precision(), ddt.scale()),
147+
self.value().and_then(|v| v.to_i64()),
148+
)
149+
.map(|ds| ds.into())
150+
.ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}"))
151+
} else if p <= <i128 as NativeDecimalType>::MAX_PRECISION {
152+
DScalar::maybe_new(
153+
PrecisionScale::<i128>::new(ddt.precision(), ddt.scale()),
154+
self.value().and_then(|v| v.to_i128()),
155+
)
156+
.map(|ds| ds.into())
157+
.ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}"))
158+
} else if p <= <i256 as NativeDecimalType>::MAX_PRECISION {
159+
DScalar::maybe_new(
160+
PrecisionScale::<i256>::new(ddt.precision(), ddt.scale()),
161+
self.value().and_then(|v| v.to_i256()),
162+
)
163+
.map(|ds| ds.into())
164+
.ok_or_else(|| vortex_err!("Couldn't cast DScalar ({self:?} to {ddt:?}"))
165+
} else {
166+
vortex_bail!(
167+
"Target precision {p} is out of range for supported decimal values"
168+
)
169+
}
170+
}
117171
DType::Decimal(..) => {
118172
vortex_bail!(
119173
"Casting DScalar to {} with different precision/scale not yet implemented",
@@ -126,3 +180,198 @@ impl<D: NativeDecimalType> Cast for DScalar<D> {
126180
}
127181
}
128182
}
183+
184+
#[cfg(test)]
185+
mod tests {
186+
use rstest::rstest;
187+
use vortex_dtype::DType;
188+
use vortex_dtype::DecimalDType;
189+
use vortex_dtype::DecimalTypeDowncast;
190+
use vortex_dtype::NativeDecimalType;
191+
use vortex_dtype::Nullability;
192+
use vortex_dtype::PrecisionScale;
193+
use vortex_dtype::i256;
194+
use vortex_error::VortexResult;
195+
use vortex_vector::ScalarOps;
196+
use vortex_vector::decimal::DScalar;
197+
198+
use crate::cast::Cast;
199+
200+
#[rstest]
201+
#[case(2, 0, 42i8)]
202+
#[case(2, 1, 99i8)]
203+
#[case(2, -1, 10i8)]
204+
fn cast_dscalar_identity(
205+
#[case] precision: u8,
206+
#[case] scale: i8,
207+
#[case] value: i8,
208+
) -> VortexResult<()> {
209+
let ps = PrecisionScale::<i8>::new(precision, scale);
210+
let scalar = DScalar::maybe_new(ps, Some(value)).unwrap();
211+
let target = DType::Decimal(
212+
DecimalDType::new(precision, scale),
213+
Nullability::NonNullable,
214+
);
215+
let result = scalar.cast(&target)?;
216+
let ds = result.into_decimal().into_i8();
217+
assert_eq!(ds.value(), Some(value));
218+
assert_eq!(ds.precision(), precision);
219+
assert_eq!(ds.scale(), scale);
220+
Ok(())
221+
}
222+
223+
#[test]
224+
fn cast_dscalar_null_to_nullable() -> VortexResult<()> {
225+
let ps = PrecisionScale::<i8>::new(2, 0);
226+
let scalar = DScalar::maybe_new(ps, None).unwrap();
227+
let target = DType::Decimal(DecimalDType::new(2, 0), Nullability::Nullable);
228+
let result = scalar.cast(&target)?;
229+
assert!(!result.as_decimal().is_valid());
230+
Ok(())
231+
}
232+
233+
#[test]
234+
fn cast_dscalar_null_to_non_nullable_fails() {
235+
let ps = PrecisionScale::<i8>::new(2, 0);
236+
let scalar = DScalar::maybe_new(ps, None).unwrap();
237+
let target = DType::Decimal(DecimalDType::new(2, 0), Nullability::NonNullable);
238+
assert!(scalar.cast(&target).is_err());
239+
}
240+
241+
#[rstest]
242+
#[case(2, 4, 42i8)] // i8 -> i16 (precision 2 -> 4)
243+
#[case(2, 9, 99i8)] // i8 -> i32 (precision 2 -> 9)
244+
#[case(2, 18, 10i8)] // i8 -> i64 (precision 2 -> 18)
245+
#[case(2, 38, 55i8)] // i8 -> i128 (precision 2 -> 38)
246+
fn cast_dscalar_upcast_precision(
247+
#[case] src_precision: u8,
248+
#[case] target_precision: u8,
249+
#[case] value: i8,
250+
) -> VortexResult<()> {
251+
let scale = 0i8;
252+
let ps = PrecisionScale::<i8>::new(src_precision, scale);
253+
let scalar = DScalar::maybe_new(ps, Some(value)).unwrap();
254+
let target = DType::Decimal(
255+
DecimalDType::new(target_precision, scale),
256+
Nullability::NonNullable,
257+
);
258+
let result = scalar.cast(&target)?;
259+
let ds = result.as_decimal();
260+
assert!(ds.is_valid());
261+
assert_eq!(ds.precision(), target_precision);
262+
assert_eq!(ds.scale(), scale);
263+
Ok(())
264+
}
265+
266+
#[test]
267+
fn cast_dscalar_i8_to_i16() -> VortexResult<()> {
268+
let ps = PrecisionScale::<i8>::new(2, 0);
269+
let scalar = DScalar::maybe_new(ps, Some(42i8)).unwrap();
270+
// Precision 4 requires i16
271+
let target = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable);
272+
let result = scalar.cast(&target)?;
273+
let ds = result.into_decimal().into_i16();
274+
assert_eq!(ds.value(), Some(42i16));
275+
assert_eq!(ds.precision(), 4);
276+
Ok(())
277+
}
278+
279+
#[test]
280+
fn cast_dscalar_i8_to_i32() -> VortexResult<()> {
281+
let ps = PrecisionScale::<i8>::new(2, 0);
282+
let scalar = DScalar::maybe_new(ps, Some(99i8)).unwrap();
283+
// Precision 9 requires i32
284+
let target = DType::Decimal(DecimalDType::new(9, 0), Nullability::NonNullable);
285+
let result = scalar.cast(&target)?;
286+
let ds = result.into_decimal().into_i32();
287+
assert_eq!(ds.value(), Some(99i32));
288+
assert_eq!(ds.precision(), 9);
289+
Ok(())
290+
}
291+
292+
#[test]
293+
fn cast_dscalar_i16_to_i64() -> VortexResult<()> {
294+
let ps = PrecisionScale::<i16>::new(4, 2);
295+
let scalar = DScalar::maybe_new(ps, Some(1234i16)).unwrap();
296+
// Precision 18 requires i64
297+
let target = DType::Decimal(DecimalDType::new(18, 2), Nullability::NonNullable);
298+
let result = scalar.cast(&target)?;
299+
let ds = result.into_decimal().into_i64();
300+
assert_eq!(ds.value(), Some(1234i64));
301+
assert_eq!(ds.precision(), 18);
302+
assert_eq!(ds.scale(), 2);
303+
Ok(())
304+
}
305+
306+
#[test]
307+
fn cast_dscalar_i32_to_i128() -> VortexResult<()> {
308+
let ps = PrecisionScale::<i32>::new(9, 0);
309+
let scalar = DScalar::maybe_new(ps, Some(123456789i32)).unwrap();
310+
// Precision 38 requires i128
311+
let target = DType::Decimal(DecimalDType::new(38, 0), Nullability::NonNullable);
312+
let result = scalar.cast(&target)?;
313+
let ds = result.into_decimal().into_i128();
314+
assert_eq!(ds.value(), Some(123456789i128));
315+
assert_eq!(ds.precision(), 38);
316+
Ok(())
317+
}
318+
319+
#[test]
320+
fn cast_dscalar_different_scale_fails() {
321+
let ps = PrecisionScale::<i8>::new(2, 0);
322+
let scalar = DScalar::maybe_new(ps, Some(42i8)).unwrap();
323+
let target = DType::Decimal(DecimalDType::new(2, 1), Nullability::NonNullable);
324+
assert!(scalar.cast(&target).is_err());
325+
}
326+
327+
#[test]
328+
fn cast_dscalar_to_non_decimal_fails() {
329+
use vortex_dtype::PType;
330+
let ps = PrecisionScale::<i8>::new(2, 0);
331+
let scalar = DScalar::maybe_new(ps, Some(42i8)).unwrap();
332+
let target = DType::Primitive(PType::I32, Nullability::NonNullable);
333+
assert!(scalar.cast(&target).is_err());
334+
}
335+
336+
#[test]
337+
fn cast_dscalar_downcast_precision_within_same_type() -> VortexResult<()> {
338+
// Downcast within the same native type (i8 precision 2 -> precision 1)
339+
// should work as long as the value fits
340+
let ps = PrecisionScale::<i8>::new(2, 0);
341+
let scalar = DScalar::maybe_new(ps, Some(9i8)).unwrap(); // value 9 fits in precision 1
342+
let target = DType::Decimal(DecimalDType::new(1, 0), Nullability::NonNullable);
343+
let result = scalar.cast(&target)?;
344+
let ds = result.into_decimal().into_i8();
345+
assert_eq!(ds.value(), Some(9i8));
346+
assert_eq!(ds.precision(), 1);
347+
Ok(())
348+
}
349+
350+
#[test]
351+
fn cast_dscalar_downcast_value_too_large_fails() {
352+
// Value 42 doesn't fit in precision 1 (max 9)
353+
let ps = PrecisionScale::<i8>::new(2, 0);
354+
let scalar = DScalar::maybe_new(ps, Some(42i8)).unwrap();
355+
let target = DType::Decimal(DecimalDType::new(1, 0), Nullability::NonNullable);
356+
assert!(scalar.cast(&target).is_err());
357+
}
358+
359+
#[rstest]
360+
#[case(<i8 as NativeDecimalType>::MAX_PRECISION)]
361+
#[case(<i16 as NativeDecimalType>::MAX_PRECISION)]
362+
#[case(<i32 as NativeDecimalType>::MAX_PRECISION)]
363+
#[case(<i64 as NativeDecimalType>::MAX_PRECISION)]
364+
#[case(<i128 as NativeDecimalType>::MAX_PRECISION)]
365+
#[case(<i256 as NativeDecimalType>::MAX_PRECISION)]
366+
fn cast_dscalar_to_max_precision_boundary(#[case] target_precision: u8) -> VortexResult<()> {
367+
let ps = PrecisionScale::<i8>::new(1, 0);
368+
let scalar = DScalar::maybe_new(ps, Some(1i8)).unwrap();
369+
let target = DType::Decimal(
370+
DecimalDType::new(target_precision, 0),
371+
Nullability::NonNullable,
372+
);
373+
let result = scalar.cast(&target)?;
374+
assert_eq!(result.as_decimal().precision(), target_precision);
375+
Ok(())
376+
}
377+
}

0 commit comments

Comments
 (0)