Hi, I am using In my experiments, output arrays allocated via Looking at Questions:
Here is the pattern I am currently exploring:

```python
import jax
import jax.numpy as jnp
import warp as wp
import warp.jax_experimental
from jax import Array


@wp.kernel
def warp_kernel(
    inputs: wp.array(dtype=wp.float32), outputs: wp.array1d(dtype=wp.float32)
) -> None:
    idx = wp.tid()
    # potentially unsafe if outputs[0] is garbage
    outputs[0] += inputs[idx]


jax_kernel = warp.jax_experimental.jax_kernel(warp_kernel)


@jax.jit
def fun(x: Array) -> Array:
    y: Array
    (y,) = jax_kernel(x, output_dims={"outputs": (1,)}, launch_dims=x.shape)
    return y[0]


def main() -> None:
    N: int = 10000000
    x: Array = jnp.ones((N,))
    y: Array = fun(x)
    assert jnp.allclose(y, jnp.sum(x))


if __name__ == "__main__":
    main()
```

Thanks for the help!
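To see why the initial contents of the output buffer matter for this accumulation pattern, here is a pure-JAX stand-in for what the kernel computes. It is not Warp code: `accumulate` is a hypothetical helper, and the value `5.0` merely simulates leftover buffer contents (real uninitialized memory could hold anything).

```python
import jax
import jax.numpy as jnp


def accumulate(init: jax.Array, x: jax.Array) -> jax.Array:
    # Hypothetical stand-in for the kernel: every x[i] is added onto
    # element 0 of the accumulator, so the result is init[0] + sum(x).
    return init.at[0].add(jnp.sum(x))


x = jnp.ones((1000,))

# Starting from zeros gives the expected sum.
from_zero = accumulate(jnp.zeros(1), x)
# Starting from leftover contents (simulated as 5.0) shifts the result.
from_garbage = accumulate(jnp.full(1, 5.0), x)

print(float(from_zero[0]))     # 1000.0
print(float(from_garbage[0]))  # 1005.0
```

The final value is always the buffer's starting value plus the sum, so the `assert` in `main()` only holds when the buffer happens to start at zero.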
Hi @liblaf,

You are correct, Warp doesn't explicitly zero-initialize the output buffers. That would be wasteful for pure output buffers that would just be overwritten by the kernel launch. I don't know whether XLA initializes the output buffers in any way, but I wouldn't count on it (for the same reasons).

The best way to ensure initialization is to use in-out arguments. You can initialize the in-out array however you need (not just zero), then modify it in the kernel.

```python
import jax
import jax.numpy as jnp
import warp as wp
import warp.jax_experimental
from jax import Array


@wp.kernel
def warp_kernel(
    inputs: wp.array(dtype=wp.float32), output: wp.array(dtype=wp.float32)
) -> None:
    idx = wp.tid()
    wp.atomic_add(output, 0, inputs[idx])  # <--- note atomic_add()


jax_kernel = warp.jax_experimental.jax_kernel(
    warp_kernel,
    in_out_argnames=["output"],  # <--- note in_out_argnames
)


@jax.jit
def fun(x: Array) -> Array:
    # float32 to match the kernel's wp.float32 argument
    output: Array = jnp.zeros(1, dtype=jnp.float32)
    (y,) = jax_kernel(x, output)
    return y[0]


def main() -> None:
    N: int = 10000000
    x: Array = jnp.ones((N,))
    y: Array = fun(x)
    assert jnp.allclose(y, jnp.sum(x))


if __name__ == "__main__":
    main()
```

Note that JAX doesn't allow modifying arrays in-place, so the output array is a modified copy of the input. Also note that
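The "modified copy" behaviour can be illustrated in plain JAX without Warp. In this sketch, `accumulate_copy` is a hypothetical function that uses JAX's `.at[...].add(...)` update syntax: it returns a new array and leaves the array passed in unchanged, and direct item assignment is rejected outright.

```python
import jax
import jax.numpy as jnp


@jax.jit
def accumulate_copy(output: jax.Array, x: jax.Array) -> jax.Array:
    # .at[...].add(...) does not mutate `output`; it returns a new array
    # holding output[0] + sum(x) in element 0.
    return output.at[0].add(jnp.sum(x))


output = jnp.zeros(1)
x = jnp.ones((10,))
y = accumulate_copy(output, x)

print(float(output[0]))  # 0.0: the input array is untouched
print(float(y[0]))       # 10.0: the result lives in a new array

# Direct item assignment raises TypeError because JAX arrays are immutable.
rejected = False
try:
    output[0] = 1.0
except TypeError:
    rejected = True
print(rejected)  # True
```

This is why the in-out array has to be read back from the kernel's return value (`y` above) rather than from the array that was passed in.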