@@ -6,10 +6,13 @@ use std::sync::Arc;
66
77use async_trait:: async_trait;
88use cudarc:: driver:: sys;
9+ use futures:: future:: BoxFuture ;
910use vortex_array:: ArrayRef ;
1011use vortex_array:: Canonical ;
1112use vortex_array:: arrays:: PrimitiveArray ;
1213use vortex_array:: arrays:: PrimitiveArrayParts ;
14+ use vortex_array:: arrays:: StructArray ;
15+ use vortex_array:: arrays:: StructArrayParts ;
1316use vortex_array:: buffer:: BufferHandle ;
1417use vortex_array:: validity:: Validity ;
1518use vortex_error:: VortexResult ;
@@ -41,10 +44,7 @@ impl ExportDeviceArray for CanonicalDeviceArrayExport {
4144 ) -> VortexResult < ArrowDeviceArray > {
4245 let cuda_array = array. execute_cuda ( ctx) . await ?;
4346
44- let ( arrow_array, sync_event) = match cuda_array {
45- Canonical :: Primitive ( primitive) => export_primitive ( primitive, ctx) . await ?,
46- c => todo ! ( "implement support for exporting {}" , c. dtype( ) ) ,
47- } ;
47+ let ( arrow_array, sync_event) = export_canonical ( cuda_array, ctx) . await ?;
4848
4949 Ok ( ArrowDeviceArray {
5050 array : arrow_array,
@@ -56,6 +56,85 @@ impl ExportDeviceArray for CanonicalDeviceArrayExport {
5656 }
5757}
5858
59+ fn export_canonical (
60+ cuda_array : Canonical ,
61+ ctx : & mut CudaExecutionCtx ,
62+ ) -> BoxFuture < ' _ , VortexResult < ( ArrowArray , SyncEvent ) > > {
63+ Box :: pin ( async {
64+ match cuda_array {
65+ Canonical :: Struct ( struct_array) => export_struct ( struct_array, ctx) . await ,
66+ Canonical :: Primitive ( primitive) => export_primitive ( primitive, ctx) . await ,
67+ c => todo ! ( "support for exporting {} arrays" , c. dtype( ) ) ,
68+ }
69+ } )
70+ }
71+
72+ async fn export_struct (
73+ array : StructArray ,
74+ ctx : & mut CudaExecutionCtx ,
75+ ) -> VortexResult < ( ArrowArray , SyncEvent ) > {
76+ let len = array. len ( ) ;
77+ let StructArrayParts {
78+ validity, fields, ..
79+ } = array. into_parts ( ) ;
80+
81+ let null_count = match validity {
82+ Validity :: NonNullable | Validity :: AllValid => 0 ,
83+ Validity :: AllInvalid => len,
84+ Validity :: Array ( _) => {
85+ vortex_bail ! ( "Exporting PrimitiveArray with non-trivial validity not supported yet" )
86+ }
87+ } ;
88+
89+ // We need the children to be held across await points.
90+ let mut children = Vec :: with_capacity ( fields. len ( ) ) ;
91+
92+ for field in fields. iter ( ) {
93+ let cuda_field = field. clone ( ) . execute_cuda ( ctx) . await ?;
94+ let ( arrow_field, _) = export_canonical ( cuda_field, ctx) . await ?;
95+ children. push ( arrow_field) ;
96+ }
97+
98+ let cuda_event = ctx
99+ . stream ( )
100+ . record_event ( None )
101+ . map_err ( |_| vortex_err ! ( "failed to create cudaEvent_t" ) ) ?;
102+
103+ let children = children
104+ . into_iter ( )
105+ . map ( |array| Box :: into_raw ( Box :: new ( array) ) )
106+ . collect :: < Box < [ _ ] > > ( ) ;
107+
108+ let buffer_ptrs = vec ! [ sys:: CUdeviceptr :: default ( ) ] . into_boxed_slice ( ) ;
109+
110+ let mut private_data = Box :: new ( PrivateData {
111+ cuda_stream : Arc :: clone ( ctx. stream ( ) ) ,
112+ buffers : Box :: new ( [ None ] ) ,
113+ buffer_ptrs,
114+ cuda_event_ptr : cuda_event. cu_event ( ) . cast ( ) ,
115+ cuda_event,
116+ children,
117+ } ) ;
118+
119+ let sync_event: SyncEvent = NonNull :: new ( & raw mut private_data. cuda_event_ptr ) ;
120+
121+ // Populate the ArrowArray with the child arrays.
122+ let mut arrow_struct = ArrowArray :: empty ( ) ;
123+ arrow_struct. length = len as i64 ;
124+ arrow_struct. null_count = null_count as i64 ;
125+ arrow_struct. n_children = fields. len ( ) as i64 ;
126+ arrow_struct. children = private_data. children . as_mut_ptr ( ) ;
127+
128+ // StructArray _can_ contain a validity buffer. In this case, we just write a null pointer
129+ // for it.
130+ arrow_struct. n_buffers = 1 ;
131+ arrow_struct. buffers = private_data. buffer_ptrs . as_mut_ptr ( ) ;
132+ arrow_struct. release = Some ( release_array) ;
133+ arrow_struct. private_data = Box :: into_raw ( private_data) . cast ( ) ;
134+
135+ Ok ( ( arrow_struct, sync_event) )
136+ }
137+
59138async fn export_primitive (
60139 array : PrimitiveArray ,
61140 ctx : & mut CudaExecutionCtx ,
@@ -109,6 +188,7 @@ async fn export_primitive(
109188
110189 let mut private_data = Box :: new ( PrivateData {
111190 cuda_stream : Arc :: clone ( ctx. stream ( ) ) ,
191+ children : Box :: new ( [ ] ) ,
112192 buffers,
113193 buffer_ptrs,
114194 cuda_event_ptr : cuda_event. cu_event ( ) . cast ( ) ,
@@ -145,7 +225,12 @@ unsafe extern "C" fn release_array(array: *mut ArrowArray) {
145225 std:: ptr:: replace ( & raw mut ( * array) . private_data , std:: ptr:: null_mut ( ) ) ;
146226
147227 if !private_data_ptr. is_null ( ) {
148- drop ( Box :: from_raw ( private_data_ptr. cast :: < PrivateData > ( ) ) ) ;
228+ let mut private_data = Box :: from_raw ( private_data_ptr. cast :: < PrivateData > ( ) ) ;
229+ let children = std:: mem:: take ( & mut private_data. children ) ;
230+ for child in children {
231+ release_array ( child) ;
232+ }
233+ drop ( private_data) ;
149234 }
150235
151236 // update the release function to NULL to avoid any possibility of double-frees.
0 commit comments