@@ -6,10 +6,12 @@ use vortex_dtype::DType;
66use vortex_dtype:: DecimalType ;
77use vortex_dtype:: NativeDecimalType ;
88use vortex_dtype:: PrecisionScale ;
9+ use vortex_dtype:: i256;
910use vortex_dtype:: match_each_decimal_value_type;
1011use vortex_error:: VortexExpect ;
1112use vortex_error:: VortexResult ;
1213use vortex_error:: vortex_bail;
14+ use vortex_error:: vortex_err;
1315use vortex_vector:: Scalar ;
1416use vortex_vector:: ScalarOps ;
1517use vortex_vector:: Vector ;
@@ -113,7 +115,59 @@ impl<D: NativeDecimalType> Cast for DScalar<D> {
113115 {
114116 Ok ( self . clone ( ) . into ( ) )
115117 }
116- // TODO(connor): cast to different precision/scale
118+ // TODO(connor): cast to different scale
119+ DType :: Decimal ( ddt, n)
120+ if ddt. scale ( ) == self . scale ( ) && ( n. is_nullable ( ) || self . is_valid ( ) ) =>
121+ {
122+ let p = ddt. precision ( ) ;
123+ if p <= <i8 as NativeDecimalType >:: MAX_PRECISION {
124+ DScalar :: maybe_new (
125+ PrecisionScale :: < i8 > :: new ( ddt. precision ( ) , ddt. scale ( ) ) ,
126+ self . value ( ) . and_then ( |v| v. to_i8 ( ) ) ,
127+ )
128+ . map ( |ds| ds. into ( ) )
129+ . ok_or_else ( || vortex_err ! ( "Couldn't cast DScalar ({self:?} to {ddt:?}" ) )
130+ } else if p <= <i16 as NativeDecimalType >:: MAX_PRECISION {
131+ DScalar :: maybe_new (
132+ PrecisionScale :: < i16 > :: new ( ddt. precision ( ) , ddt. scale ( ) ) ,
133+ self . value ( ) . and_then ( |v| v. to_i16 ( ) ) ,
134+ )
135+ . map ( |ds| ds. into ( ) )
136+ . ok_or_else ( || vortex_err ! ( "Couldn't cast DScalar ({self:?} to {ddt:?}" ) )
137+ } else if p <= <i32 as NativeDecimalType >:: MAX_PRECISION {
138+ DScalar :: maybe_new (
139+ PrecisionScale :: < i32 > :: new ( ddt. precision ( ) , ddt. scale ( ) ) ,
140+ self . value ( ) . and_then ( |v| v. to_i32 ( ) ) ,
141+ )
142+ . map ( |ds| ds. into ( ) )
143+ . ok_or_else ( || vortex_err ! ( "Couldn't cast DScalar ({self:?} to {ddt:?}" ) )
144+ } else if p <= <i64 as NativeDecimalType >:: MAX_PRECISION {
145+ DScalar :: maybe_new (
146+ PrecisionScale :: < i64 > :: new ( ddt. precision ( ) , ddt. scale ( ) ) ,
147+ self . value ( ) . and_then ( |v| v. to_i64 ( ) ) ,
148+ )
149+ . map ( |ds| ds. into ( ) )
150+ . ok_or_else ( || vortex_err ! ( "Couldn't cast DScalar ({self:?} to {ddt:?}" ) )
151+ } else if p <= <i128 as NativeDecimalType >:: MAX_PRECISION {
152+ DScalar :: maybe_new (
153+ PrecisionScale :: < i128 > :: new ( ddt. precision ( ) , ddt. scale ( ) ) ,
154+ self . value ( ) . and_then ( |v| v. to_i128 ( ) ) ,
155+ )
156+ . map ( |ds| ds. into ( ) )
157+ . ok_or_else ( || vortex_err ! ( "Couldn't cast DScalar ({self:?} to {ddt:?}" ) )
158+ } else if p <= <i256 as NativeDecimalType >:: MAX_PRECISION {
159+ DScalar :: maybe_new (
160+ PrecisionScale :: < i256 > :: new ( ddt. precision ( ) , ddt. scale ( ) ) ,
161+ self . value ( ) . and_then ( |v| v. to_i256 ( ) ) ,
162+ )
163+ . map ( |ds| ds. into ( ) )
164+ . ok_or_else ( || vortex_err ! ( "Couldn't cast DScalar ({self:?} to {ddt:?}" ) )
165+ } else {
166+ vortex_bail ! (
167+ "Target precision {p} is out of range for supported decimal values"
168+ )
169+ }
170+ }
117171 DType :: Decimal ( ..) => {
118172 vortex_bail ! (
119173 "Casting DScalar to {} with different precision/scale not yet implemented" ,
@@ -126,3 +180,198 @@ impl<D: NativeDecimalType> Cast for DScalar<D> {
126180 }
127181 }
128182}
183+
184+ #[ cfg( test) ]
185+ mod tests {
186+ use rstest:: rstest;
187+ use vortex_dtype:: DType ;
188+ use vortex_dtype:: DecimalDType ;
189+ use vortex_dtype:: DecimalTypeDowncast ;
190+ use vortex_dtype:: NativeDecimalType ;
191+ use vortex_dtype:: Nullability ;
192+ use vortex_dtype:: PrecisionScale ;
193+ use vortex_dtype:: i256;
194+ use vortex_error:: VortexResult ;
195+ use vortex_vector:: ScalarOps ;
196+ use vortex_vector:: decimal:: DScalar ;
197+
198+ use crate :: cast:: Cast ;
199+
200+ #[ rstest]
201+ #[ case( 2 , 0 , 42i8 ) ]
202+ #[ case( 2 , 1 , 99i8 ) ]
203+ #[ case( 2 , -1 , 10i8 ) ]
204+ fn cast_dscalar_identity (
205+ #[ case] precision : u8 ,
206+ #[ case] scale : i8 ,
207+ #[ case] value : i8 ,
208+ ) -> VortexResult < ( ) > {
209+ let ps = PrecisionScale :: < i8 > :: new ( precision, scale) ;
210+ let scalar = DScalar :: maybe_new ( ps, Some ( value) ) . unwrap ( ) ;
211+ let target = DType :: Decimal (
212+ DecimalDType :: new ( precision, scale) ,
213+ Nullability :: NonNullable ,
214+ ) ;
215+ let result = scalar. cast ( & target) ?;
216+ let ds = result. into_decimal ( ) . into_i8 ( ) ;
217+ assert_eq ! ( ds. value( ) , Some ( value) ) ;
218+ assert_eq ! ( ds. precision( ) , precision) ;
219+ assert_eq ! ( ds. scale( ) , scale) ;
220+ Ok ( ( ) )
221+ }
222+
223+ #[ test]
224+ fn cast_dscalar_null_to_nullable ( ) -> VortexResult < ( ) > {
225+ let ps = PrecisionScale :: < i8 > :: new ( 2 , 0 ) ;
226+ let scalar = DScalar :: maybe_new ( ps, None ) . unwrap ( ) ;
227+ let target = DType :: Decimal ( DecimalDType :: new ( 2 , 0 ) , Nullability :: Nullable ) ;
228+ let result = scalar. cast ( & target) ?;
229+ assert ! ( !result. as_decimal( ) . is_valid( ) ) ;
230+ Ok ( ( ) )
231+ }
232+
233+ #[ test]
234+ fn cast_dscalar_null_to_non_nullable_fails ( ) {
235+ let ps = PrecisionScale :: < i8 > :: new ( 2 , 0 ) ;
236+ let scalar = DScalar :: maybe_new ( ps, None ) . unwrap ( ) ;
237+ let target = DType :: Decimal ( DecimalDType :: new ( 2 , 0 ) , Nullability :: NonNullable ) ;
238+ assert ! ( scalar. cast( & target) . is_err( ) ) ;
239+ }
240+
241+ #[ rstest]
242+ #[ case( 2 , 4 , 42i8 ) ] // i8 -> i16 (precision 2 -> 4)
243+ #[ case( 2 , 9 , 99i8 ) ] // i8 -> i32 (precision 2 -> 9)
244+ #[ case( 2 , 18 , 10i8 ) ] // i8 -> i64 (precision 2 -> 18)
245+ #[ case( 2 , 38 , 55i8 ) ] // i8 -> i128 (precision 2 -> 38)
246+ fn cast_dscalar_upcast_precision (
247+ #[ case] src_precision : u8 ,
248+ #[ case] target_precision : u8 ,
249+ #[ case] value : i8 ,
250+ ) -> VortexResult < ( ) > {
251+ let scale = 0i8 ;
252+ let ps = PrecisionScale :: < i8 > :: new ( src_precision, scale) ;
253+ let scalar = DScalar :: maybe_new ( ps, Some ( value) ) . unwrap ( ) ;
254+ let target = DType :: Decimal (
255+ DecimalDType :: new ( target_precision, scale) ,
256+ Nullability :: NonNullable ,
257+ ) ;
258+ let result = scalar. cast ( & target) ?;
259+ let ds = result. as_decimal ( ) ;
260+ assert ! ( ds. is_valid( ) ) ;
261+ assert_eq ! ( ds. precision( ) , target_precision) ;
262+ assert_eq ! ( ds. scale( ) , scale) ;
263+ Ok ( ( ) )
264+ }
265+
266+ #[ test]
267+ fn cast_dscalar_i8_to_i16 ( ) -> VortexResult < ( ) > {
268+ let ps = PrecisionScale :: < i8 > :: new ( 2 , 0 ) ;
269+ let scalar = DScalar :: maybe_new ( ps, Some ( 42i8 ) ) . unwrap ( ) ;
270+ // Precision 4 requires i16
271+ let target = DType :: Decimal ( DecimalDType :: new ( 4 , 0 ) , Nullability :: NonNullable ) ;
272+ let result = scalar. cast ( & target) ?;
273+ let ds = result. into_decimal ( ) . into_i16 ( ) ;
274+ assert_eq ! ( ds. value( ) , Some ( 42i16 ) ) ;
275+ assert_eq ! ( ds. precision( ) , 4 ) ;
276+ Ok ( ( ) )
277+ }
278+
279+ #[ test]
280+ fn cast_dscalar_i8_to_i32 ( ) -> VortexResult < ( ) > {
281+ let ps = PrecisionScale :: < i8 > :: new ( 2 , 0 ) ;
282+ let scalar = DScalar :: maybe_new ( ps, Some ( 99i8 ) ) . unwrap ( ) ;
283+ // Precision 9 requires i32
284+ let target = DType :: Decimal ( DecimalDType :: new ( 9 , 0 ) , Nullability :: NonNullable ) ;
285+ let result = scalar. cast ( & target) ?;
286+ let ds = result. into_decimal ( ) . into_i32 ( ) ;
287+ assert_eq ! ( ds. value( ) , Some ( 99i32 ) ) ;
288+ assert_eq ! ( ds. precision( ) , 9 ) ;
289+ Ok ( ( ) )
290+ }
291+
292+ #[ test]
293+ fn cast_dscalar_i16_to_i64 ( ) -> VortexResult < ( ) > {
294+ let ps = PrecisionScale :: < i16 > :: new ( 4 , 2 ) ;
295+ let scalar = DScalar :: maybe_new ( ps, Some ( 1234i16 ) ) . unwrap ( ) ;
296+ // Precision 18 requires i64
297+ let target = DType :: Decimal ( DecimalDType :: new ( 18 , 2 ) , Nullability :: NonNullable ) ;
298+ let result = scalar. cast ( & target) ?;
299+ let ds = result. into_decimal ( ) . into_i64 ( ) ;
300+ assert_eq ! ( ds. value( ) , Some ( 1234i64 ) ) ;
301+ assert_eq ! ( ds. precision( ) , 18 ) ;
302+ assert_eq ! ( ds. scale( ) , 2 ) ;
303+ Ok ( ( ) )
304+ }
305+
306+ #[ test]
307+ fn cast_dscalar_i32_to_i128 ( ) -> VortexResult < ( ) > {
308+ let ps = PrecisionScale :: < i32 > :: new ( 9 , 0 ) ;
309+ let scalar = DScalar :: maybe_new ( ps, Some ( 123456789i32 ) ) . unwrap ( ) ;
310+ // Precision 38 requires i128
311+ let target = DType :: Decimal ( DecimalDType :: new ( 38 , 0 ) , Nullability :: NonNullable ) ;
312+ let result = scalar. cast ( & target) ?;
313+ let ds = result. into_decimal ( ) . into_i128 ( ) ;
314+ assert_eq ! ( ds. value( ) , Some ( 123456789i128 ) ) ;
315+ assert_eq ! ( ds. precision( ) , 38 ) ;
316+ Ok ( ( ) )
317+ }
318+
319+ #[ test]
320+ fn cast_dscalar_different_scale_fails ( ) {
321+ let ps = PrecisionScale :: < i8 > :: new ( 2 , 0 ) ;
322+ let scalar = DScalar :: maybe_new ( ps, Some ( 42i8 ) ) . unwrap ( ) ;
323+ let target = DType :: Decimal ( DecimalDType :: new ( 2 , 1 ) , Nullability :: NonNullable ) ;
324+ assert ! ( scalar. cast( & target) . is_err( ) ) ;
325+ }
326+
327+ #[ test]
328+ fn cast_dscalar_to_non_decimal_fails ( ) {
329+ use vortex_dtype:: PType ;
330+ let ps = PrecisionScale :: < i8 > :: new ( 2 , 0 ) ;
331+ let scalar = DScalar :: maybe_new ( ps, Some ( 42i8 ) ) . unwrap ( ) ;
332+ let target = DType :: Primitive ( PType :: I32 , Nullability :: NonNullable ) ;
333+ assert ! ( scalar. cast( & target) . is_err( ) ) ;
334+ }
335+
336+ #[ test]
337+ fn cast_dscalar_downcast_precision_within_same_type ( ) -> VortexResult < ( ) > {
338+ // Downcast within the same native type (i8 precision 2 -> precision 1)
339+ // should work as long as the value fits
340+ let ps = PrecisionScale :: < i8 > :: new ( 2 , 0 ) ;
341+ let scalar = DScalar :: maybe_new ( ps, Some ( 9i8 ) ) . unwrap ( ) ; // value 9 fits in precision 1
342+ let target = DType :: Decimal ( DecimalDType :: new ( 1 , 0 ) , Nullability :: NonNullable ) ;
343+ let result = scalar. cast ( & target) ?;
344+ let ds = result. into_decimal ( ) . into_i8 ( ) ;
345+ assert_eq ! ( ds. value( ) , Some ( 9i8 ) ) ;
346+ assert_eq ! ( ds. precision( ) , 1 ) ;
347+ Ok ( ( ) )
348+ }
349+
350+ #[ test]
351+ fn cast_dscalar_downcast_value_too_large_fails ( ) {
352+ // Value 42 doesn't fit in precision 1 (max 9)
353+ let ps = PrecisionScale :: < i8 > :: new ( 2 , 0 ) ;
354+ let scalar = DScalar :: maybe_new ( ps, Some ( 42i8 ) ) . unwrap ( ) ;
355+ let target = DType :: Decimal ( DecimalDType :: new ( 1 , 0 ) , Nullability :: NonNullable ) ;
356+ assert ! ( scalar. cast( & target) . is_err( ) ) ;
357+ }
358+
359+ #[ rstest]
360+ #[ case( <i8 as NativeDecimalType >:: MAX_PRECISION ) ]
361+ #[ case( <i16 as NativeDecimalType >:: MAX_PRECISION ) ]
362+ #[ case( <i32 as NativeDecimalType >:: MAX_PRECISION ) ]
363+ #[ case( <i64 as NativeDecimalType >:: MAX_PRECISION ) ]
364+ #[ case( <i128 as NativeDecimalType >:: MAX_PRECISION ) ]
365+ #[ case( <i256 as NativeDecimalType >:: MAX_PRECISION ) ]
366+ fn cast_dscalar_to_max_precision_boundary ( #[ case] target_precision : u8 ) -> VortexResult < ( ) > {
367+ let ps = PrecisionScale :: < i8 > :: new ( 1 , 0 ) ;
368+ let scalar = DScalar :: maybe_new ( ps, Some ( 1i8 ) ) . unwrap ( ) ;
369+ let target = DType :: Decimal (
370+ DecimalDType :: new ( target_precision, 0 ) ,
371+ Nullability :: NonNullable ,
372+ ) ;
373+ let result = scalar. cast ( & target) ?;
374+ assert_eq ! ( result. as_decimal( ) . precision( ) , target_precision) ;
375+ Ok ( ( ) )
376+ }
377+ }
0 commit comments