Skip to content

Commit a06184d

Browse files
authored
add struct array for C export (#6294)
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent 75cb1ee commit a06184d

File tree

3 files changed

+119
-10
lines changed

3 files changed

+119
-10
lines changed

vortex-cuda/cudf-test/src/lib.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33

44
//! This file is a simple C-compatible API that is called from the cudf-test-harness at CI time.
55
6-
#![allow(clippy::unwrap_used)]
6+
#![allow(clippy::unwrap_used, clippy::expect_used)]
77

88
use std::sync::LazyLock;
99

10-
use arrow_schema::DataType;
1110
use arrow_schema::ffi::FFI_ArrowSchema;
1211
use futures::executor::block_on;
12+
use vortex::array::Array;
1313
use vortex::array::IntoArray;
1414
use vortex::array::arrays::PrimitiveArray;
15+
use vortex::array::arrays::StructArray;
1516
use vortex::array::session::ArraySession;
17+
use vortex::array::validity::Validity;
18+
use vortex::dtype::FieldNames;
1619
use vortex::expr::session::ExprSession;
1720
use vortex::io::session::RuntimeSession;
1821
use vortex::layout::session::LayoutSession;
@@ -42,9 +45,22 @@ pub extern "C" fn export_array(
4245

4346
let primitive = PrimitiveArray::from_iter(0u32..1024);
4447

45-
*schema_ptr = FFI_ArrowSchema::try_from(DataType::UInt32).unwrap();
48+
let array = StructArray::new(
49+
FieldNames::from_iter(["a"]),
50+
vec![primitive.into_array()],
51+
1024,
52+
Validity::NonNullable,
53+
)
54+
.into_array();
4655

47-
match block_on(primitive.into_array().export_device_array(&mut ctx)) {
56+
let data_type = array
57+
.dtype()
58+
.to_arrow_dtype()
59+
.expect("converting schema to Arrow DataType");
60+
61+
*schema_ptr = FFI_ArrowSchema::try_from(data_type).expect("data_type to FFI_ArrowSchema");
62+
63+
match block_on(array.export_device_array(&mut ctx)) {
4864
Ok(exported) => {
4965
*array_ptr = exported;
5066
0

vortex-cuda/src/arrow/canonical.rs

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@ use std::sync::Arc;
66

77
use async_trait::async_trait;
88
use cudarc::driver::sys;
9+
use futures::future::BoxFuture;
910
use vortex_array::ArrayRef;
1011
use vortex_array::Canonical;
1112
use vortex_array::arrays::PrimitiveArray;
1213
use vortex_array::arrays::PrimitiveArrayParts;
14+
use vortex_array::arrays::StructArray;
15+
use vortex_array::arrays::StructArrayParts;
1316
use vortex_array::buffer::BufferHandle;
1417
use vortex_array::validity::Validity;
1518
use vortex_error::VortexResult;
@@ -41,10 +44,7 @@ impl ExportDeviceArray for CanonicalDeviceArrayExport {
4144
) -> VortexResult<ArrowDeviceArray> {
4245
let cuda_array = array.execute_cuda(ctx).await?;
4346

44-
let (arrow_array, sync_event) = match cuda_array {
45-
Canonical::Primitive(primitive) => export_primitive(primitive, ctx).await?,
46-
c => todo!("implement support for exporting {}", c.dtype()),
47-
};
47+
let (arrow_array, sync_event) = export_canonical(cuda_array, ctx).await?;
4848

4949
Ok(ArrowDeviceArray {
5050
array: arrow_array,
@@ -56,6 +56,85 @@ impl ExportDeviceArray for CanonicalDeviceArrayExport {
5656
}
5757
}
5858

59+
fn export_canonical(
60+
cuda_array: Canonical,
61+
ctx: &mut CudaExecutionCtx,
62+
) -> BoxFuture<'_, VortexResult<(ArrowArray, SyncEvent)>> {
63+
Box::pin(async {
64+
match cuda_array {
65+
Canonical::Struct(struct_array) => export_struct(struct_array, ctx).await,
66+
Canonical::Primitive(primitive) => export_primitive(primitive, ctx).await,
67+
c => todo!("support for exporting {} arrays", c.dtype()),
68+
}
69+
})
70+
}
71+
72+
async fn export_struct(
73+
array: StructArray,
74+
ctx: &mut CudaExecutionCtx,
75+
) -> VortexResult<(ArrowArray, SyncEvent)> {
76+
let len = array.len();
77+
let StructArrayParts {
78+
validity, fields, ..
79+
} = array.into_parts();
80+
81+
let null_count = match validity {
82+
Validity::NonNullable | Validity::AllValid => 0,
83+
Validity::AllInvalid => len,
84+
Validity::Array(_) => {
85+
vortex_bail!("Exporting PrimitiveArray with non-trivial validity not supported yet")
86+
}
87+
};
88+
89+
// We need the children to be held across await points.
90+
let mut children = Vec::with_capacity(fields.len());
91+
92+
for field in fields.iter() {
93+
let cuda_field = field.clone().execute_cuda(ctx).await?;
94+
let (arrow_field, _) = export_canonical(cuda_field, ctx).await?;
95+
children.push(arrow_field);
96+
}
97+
98+
let cuda_event = ctx
99+
.stream()
100+
.record_event(None)
101+
.map_err(|_| vortex_err!("failed to create cudaEvent_t"))?;
102+
103+
let children = children
104+
.into_iter()
105+
.map(|array| Box::into_raw(Box::new(array)))
106+
.collect::<Box<[_]>>();
107+
108+
let buffer_ptrs = vec![sys::CUdeviceptr::default()].into_boxed_slice();
109+
110+
let mut private_data = Box::new(PrivateData {
111+
cuda_stream: Arc::clone(ctx.stream()),
112+
buffers: Box::new([None]),
113+
buffer_ptrs,
114+
cuda_event_ptr: cuda_event.cu_event().cast(),
115+
cuda_event,
116+
children,
117+
});
118+
119+
let sync_event: SyncEvent = NonNull::new(&raw mut private_data.cuda_event_ptr);
120+
121+
// Populate the ArrowArray with the child arrays.
122+
let mut arrow_struct = ArrowArray::empty();
123+
arrow_struct.length = len as i64;
124+
arrow_struct.null_count = null_count as i64;
125+
arrow_struct.n_children = fields.len() as i64;
126+
arrow_struct.children = private_data.children.as_mut_ptr();
127+
128+
// StructArray _can_ contain a validity buffer. In this case, we just write a null pointer
129+
// for it.
130+
arrow_struct.n_buffers = 1;
131+
arrow_struct.buffers = private_data.buffer_ptrs.as_mut_ptr();
132+
arrow_struct.release = Some(release_array);
133+
arrow_struct.private_data = Box::into_raw(private_data).cast();
134+
135+
Ok((arrow_struct, sync_event))
136+
}
137+
59138
async fn export_primitive(
60139
array: PrimitiveArray,
61140
ctx: &mut CudaExecutionCtx,
@@ -109,6 +188,7 @@ async fn export_primitive(
109188

110189
let mut private_data = Box::new(PrivateData {
111190
cuda_stream: Arc::clone(ctx.stream()),
191+
children: Box::new([]),
112192
buffers,
113193
buffer_ptrs,
114194
cuda_event_ptr: cuda_event.cu_event().cast(),
@@ -145,7 +225,12 @@ unsafe extern "C" fn release_array(array: *mut ArrowArray) {
145225
std::ptr::replace(&raw mut (*array).private_data, std::ptr::null_mut());
146226

147227
if !private_data_ptr.is_null() {
148-
drop(Box::from_raw(private_data_ptr.cast::<PrivateData>()));
228+
let mut private_data = Box::from_raw(private_data_ptr.cast::<PrivateData>());
229+
let children = std::mem::take(&mut private_data.children);
230+
for child in children {
231+
release_array(child);
232+
}
233+
drop(private_data);
149234
}
150235

151236
// update the release function to NULL to avoid any possibility of double-frees.

vortex-cuda/src/arrow/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ pub struct ArrowDeviceArray {
6666
_reserved: [i64; 3],
6767
}
6868

69+
unsafe impl Send for ArrowDeviceArray {}
70+
unsafe impl Sync for ArrowDeviceArray {}
71+
6972
/// An FFI-compatible version of the ArrowArray that holds pointers to device buffers.
7073
#[repr(C)]
7174
#[derive(Debug)]
@@ -77,7 +80,8 @@ pub(crate) struct ArrowArray {
7780
n_children: i64,
7881
buffers: *mut sys::CUdeviceptr,
7982
children: *mut *mut ArrowArray,
80-
dictionary: *mut ArrowArray,
83+
// NOTE: we don't support exporting dictionary arrays, so we leave this as an opaque pointer.
84+
dictionary: *mut (),
8185
release: Option<unsafe extern "C" fn(arg1: *mut ArrowArray)>,
8286
// When exported, this MUST contain everything that is owned by this array.
8387
// for example, any buffer pointed to in `buffers` must be here, as well
@@ -105,6 +109,9 @@ impl ArrowArray {
105109
}
106110
}
107111

112+
unsafe impl Send for ArrowArray {}
113+
unsafe impl Sync for ArrowArray {}
114+
108115
#[expect(
109116
unused,
110117
reason = "cuda_stream and cuda_buffers need to have deferred drop"
@@ -120,6 +127,7 @@ pub(crate) struct PrivateData {
120127
pub(crate) buffer_ptrs: Box<[sys::CUdeviceptr]>,
121128
pub(crate) cuda_event: CudaEvent,
122129
pub(crate) cuda_event_ptr: cudaEvent_t,
130+
pub(crate) children: Box<[*mut ArrowArray]>,
123131
}
124132

125133
#[async_trait]

0 commit comments

Comments
 (0)