Add JAX FFI Host support #1446
This allows JAX FFI callbacks to run on CPU (Host) in addition to CUDA. Signed-off-by: Ankit Jain <kitsrish@google.com>
Walkthrough
This pull request extends JAX experimental FFI callbacks to support both CUDA and Host (CPU) execution paths. Callbacks now accept a …
Changes: CUDA and Host FFI Execution
Greptile Summary
This PR extends JAX FFI Host (CPU) support in Warp by registering both …

Confidence Score: 3/5
The Host execution path for FfiCallable is structurally sound, but the FfiKernel CPU launch and the register_ffi_callback stream extraction remain broken for CPU dispatch.
- The wp_cpu_launch_kernel call inside FfiKernel.ffi_callback passes device.context as the function pointer, launch_bounds.size (an integer) as the bounds reference, and omits adj_args and apic_info, so every Host FfiKernel invocation will crash or corrupt memory.
- ExecutionContext in register_ffi_callback still unconditionally calls get_stream_from_callframe, dereferencing a CUDA-specific XLA API on the Host platform.
File needing attention: warp/_src/jax_experimental/ffi.py, specifically the wp_cpu_launch_kernel call site in FfiKernel.ffi_callback and the ExecutionContext construction inside register_ffi_callback's ffi_callback closure.

Reviews (5): Last reviewed commit: "[pre-commit.ci] auto code formatting"
    self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
    ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
    ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
    jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
Self-referential NameError in FfiKernel.__init__: ffi_capsule_host is passed as its own argument to pycapsule before the name is ever assigned. This will raise NameError: name 'ffi_capsule_host' is not defined every time an FfiKernel is instantiated, making the entire Host platform registration dead code. The address variable ffi_ccall_address_host should be used here instead, consistent with how the CUDA path is written above.
Suggested change:

    - self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
    - ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
    - ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
    - jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
    + self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
    + ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
    + ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
    + jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
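For illustration, here is a minimal, self-contained sketch of the pattern behind this suggestion. It uses only ctypes and does not touch JAX or Warp. The names CALLBACK_TYPE, make_host_callback_address, and buggy_registration are hypothetical, and CALLBACK_TYPE is only assumed to resemble FFI_CCALLFUNC. The point is that the address has to be read from the c_void_p produced by ctypes.cast. Inside a function, reading it from the not-yet-assigned capsule name fails with UnboundLocalError, which is a subclass of NameError.

```python
import ctypes

# Hypothetical stand-in for the FFI callback prototype; the real FFI_CCALLFUNC
# in ffi.py is only assumed to be a ctypes.CFUNCTYPE prototype of similar shape.
CALLBACK_TYPE = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p)


def make_host_callback_address(py_callback):
    """Wrap py_callback and return (wrapper, integer address of the wrapper)."""
    callback_func_host = CALLBACK_TYPE(py_callback)
    # Cast the function pointer to a void pointer ...
    ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
    # ... and read the address from the *address* variable.
    # The wrapper is returned too so the caller can keep it alive.
    return callback_func_host, ffi_ccall_address_host.value


def buggy_registration(py_callback):
    """Mirrors the typo: the name on the right-hand side is not bound yet."""
    callback_func_host = CALLBACK_TYPE(py_callback)
    ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
    ffi_capsule_host = ffi_capsule_host.value  # raises UnboundLocalError
    return ffi_capsule_host
```

The corrected helper returns a non-zero integer address, while the buggy variant raises before anything could be registered, which matches the behaviour described in the comment above.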
    ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
    ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
    jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")
Same self-referential NameError in register_ffi_callback: ffi_capsule_host is referenced before it is assigned. This causes every call to register_ffi_callback to raise NameError: name 'ffi_capsule_host' is not defined, so no Host target is ever registered. The fix mirrors the working CUDA block two lines above.
Suggested change:

    - ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
    - ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
    - jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")
    + ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
    + ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
    + jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1772-1777: In register_ffi_callback, the Host ffi capsule is
constructed from the wrong variable (ffi_capsule_host) causing a NameError;
change the construction to use the host ccall address value by calling
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then register that capsule
with jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") so the
Host path mirrors the CUDA path (refer to ffi_ccall_address_host,
ffi_capsule_host, register_ffi_target).
- Around line 629-632: The code assigns ffi_ccall_address_host then creates
ffi_capsule_host but mistakenly uses ffi_capsule_host.value (self-referential
NameError); change the capsule creation to use the previously computed address
value (ffi_ccall_address_host.value) so the lines around callback_func_host,
ffi_ccall_address_host, ffi_capsule_host and the
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") call
use ffi_ccall_address_host.value when constructing the pycapsule for the Host
callback that wraps FFI_CCALLFUNC and calls self.ffi_callback.
- Around line 226-229: The Host FFI registration references an undefined
variable: replace the erroneous creation of ffi_capsule_host (currently using
ffi_capsule_host.value) with a capsule built from the c_void_p address you just
made; specifically, after creating callback_func_host and
ffi_ccall_address_host, set ffi_capsule_host =
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then call
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") so the
capsule is created from the correct address (symbols: callback_func_host,
FFI_CCALLFUNC, ffi_ccall_address_host, ffi_capsule_host,
jax.ffi.register_ffi_target, self.name, self.ffi_callback).
📒 Files selected for processing (1)
warp/_src/jax_experimental/ffi.py
Signed-off-by: Ankit Jain <kitsrish@google.com>
Applied the patch from the original CL as requested. Signed-off-by: Ankit Jain <kitsrish@google.com>
Fix a typo where ffi_capsule_host used its own unassigned value instead of ffi_ccall_address_host.value. Signed-off-by: Ankit Jain <kitsrish@google.com>
Restore files to their state in main branch to keep this PR focused on JAX FFI changes. Signed-off-by: Ankit Jain <kitsrish@google.com>
    wp._src.context.runtime.core.wp_cpu_launch_kernel(
        device.context,
        hooks.forward,
        launch_bounds.size,
        kernel_params,
    )
Wrong arguments passed to wp_cpu_launch_kernel
The call is missing one argument and the arguments are in the wrong positions. The registered ctypes signature (in context.py) is (func, bounds, args, adj_args, apic_info), but the call here passes device.context as func, hooks.forward as bounds, launch_bounds.size (an integer, not a pointer) as args, and kernel_params as adj_args, with apic_info omitted entirely. The correct call should place the casted hooks.forward function pointer as the first argument, a reference to launch_bounds as the second, and kernel_params as the third args pointer, with None for adj_args and apic_info. As written, this will pass garbage pointers to the native kernel launcher, causing a crash or silent memory corruption on every Host-platform FfiKernel call.
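To make the arity problem concrete, here is a small ctypes-only sketch. It does not call into Warp. CPU_LAUNCH, fake_cpu_launch_kernel, and launch_forward are hypothetical stand-ins, and the five-pointer prototype is only assumed from the (func, bounds, args, adj_args, apic_info) signature quoted above. It shows the corrected argument order, and that a ctypes function with five declared argument types rejects a four-argument call with TypeError before any native code runs.

```python
import ctypes

# Hypothetical five-pointer prototype standing in for the CPU launch binding:
# (func, bounds, args, adj_args, apic_info), as quoted in the review comment.
CPU_LAUNCH = ctypes.CFUNCTYPE(
    None,
    ctypes.c_void_p,  # func
    ctypes.c_void_p,  # bounds
    ctypes.c_void_p,  # args
    ctypes.c_void_p,  # adj_args
    ctypes.c_void_p,  # apic_info
)

received = []  # argument tuples seen by the fake launcher


def _fake_launch(func, bounds, args, adj_args, apic_info):
    # Record the arguments instead of launching anything.
    received.append((func, bounds, args, adj_args, apic_info))


# Keep a module-level reference so the ctypes wrapper is not garbage collected.
fake_cpu_launch_kernel = CPU_LAUNCH(_fake_launch)


def launch_forward(func_ptr, bounds_ptr, args_ptr):
    """Forward-only launch: adj_args and apic_info are passed as None (NULL)."""
    fake_cpu_launch_kernel(func_ptr, bounds_ptr, args_ptr, None, None)
```

Calling fake_cpu_launch_kernel with only four arguments raises TypeError, which is consistent with the failure mode a later review comment reports for the Host jax_kernel path.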
pre-commit.ci autofix
Add four tests that exercise the Host platform path for FFI:
- test_ffi_jax_kernel_add_host: basic two-input one-output kernel
- test_ffi_jax_kernel_sincos_host: one-input two-output kernel
- test_ffi_jax_kernel_in_out_host: in-out argument handling
- test_ffi_jax_callable_scale_constant_host: jax_callable with scalar constant
Signed-off-by: Ankit Jain <kitsrish@google.com>
Add test coverage for jax_kernel and jax_callable running on the CPU Host platform. Tests cover basic add, sincos (multi-output), in-out args, scalar constant args, and callable variants. Signed-off-by: Ankit Jain <kitsrish@google.com>
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@warp/tests/interop/test_jax.py`:
- Around line 2422-2450: The host-only test registrations for TestJax (the
add_function_test calls registering test_ffi_jax_kernel_host_add,
test_ffi_jax_kernel_host_sincos, test_ffi_jax_kernel_host_in_out,
test_ffi_jax_kernel_host_scale_vec_constant,
test_ffi_jax_callable_host_scale_constant, and
test_ffi_jax_callable_host_in_out) are currently inside the
jax_compatible_cuda_devices conditional; move these specific add_function_test
calls out of that CUDA-only if block so they are always registered on CPU-only
setups, keeping the existing device=None argument and leaving CUDA/GPU-specific
registrations inside the original conditional.
📒 Files selected for processing (1)
warp/tests/interop/test_jax.py
pre-commit.ci autofix
    add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)

    # ffi Host (CPU) tests
    add_function_test(TestJax, "test_ffi_jax_kernel_host_add", test_ffi_jax_kernel_host_add, devices=None)
These Host tests only use jax.devices("cpu"), but they are registered inside the CUDA JAX availability gate. Please move them under CPU JAX availability so CPU-only CI covers the Host backend.
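A rough sketch of the requested gating follows. The stub add_function_test, the function register_ffi_tests, and the jax_cpu_available flag are all hypothetical; the real test module's helpers and availability checks may look different. It only illustrates that the Host registrations should depend on CPU JAX being usable rather than on CUDA devices being present.

```python
registered = []  # (test name, devices) pairs recorded by the stub below


def add_function_test(cls, name, func, devices=None):
    """Hypothetical stub: records what would be registered, nothing more."""
    registered.append((name, devices))


def register_ffi_tests(jax_cpu_available, jax_compatible_cuda_devices):
    # CUDA-only tests stay behind the CUDA availability gate.
    if jax_compatible_cuda_devices:
        add_function_test(None, "test_ffi_callback", None, devices=jax_compatible_cuda_devices)

    # Host tests are gated on CPU JAX only, so CPU-only CI still runs them.
    if jax_cpu_available:
        add_function_test(None, "test_ffi_jax_kernel_host_add", None, devices=None)
```

With no CUDA devices and CPU JAX available, only the Host test is registered; with both available, both registrations happen.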
    None,  # apic_info
    )
    else:
    wp._src.context.runtime.core.wp_cpu_launch_kernel(
This Host path is using the CUDA launch ABI for wp_cpu_launch_kernel. The CPU binding expects (func, bounds, args, adj_args, apic_info), with kernel args packed into the CPU args struct. As written, Host jax_kernel raises TypeError: this function takes at least 5 arguments (4 given) instead of launching.
Please add a …
Please squash this PR down to a single coherent commit before merge.
Please rebase onto current …
nvlukasz left a comment:
Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.
    assert num_inputs == self.num_inputs
    assert num_outputs == self.num_outputs

    if platform == "Host":
Can we deduplicate this argument reconstruction code? Perhaps extract to a helper function.
    # call the Python function with reconstructed arguments
    with wp.ScopedStream(stream, sync_enter=False):
    if stream.is_capturing:
    with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device):
No need to check stream here, since the host path returns early above. If we get here, the stream is not None.
      call_desc.capture = capture

    - elif self.graph_mode == GraphMode.WARP:
    + elif self.graph_mode == GraphMode.WARP and device.is_cuda:
No need to check device.is_cuda here, since the host part returns early above. Same comment on a few more lines below.
@nvlukasz, please allow me until the end of this week to address the outstanding comments. I apologise for the delay.
Description
This PR adds support for running JAX FFI callbacks on the CPU (Host) in addition to CUDA.
Changes:
- Register both "CUDA" and "Host" platforms in register_ffi_callback.
- Handle the "Host" platform case in ffi_callback by using the CPU device and bypassing CUDA-specific features like streams and graphs.
- Update FfiCallable to reconstruct arguments and execute the function on the CPU when running on the Host platform.
Checklist
Test plan
You can verify these changes by running the JAX interop tests which include FFI tests. Ensure they pass on both CPU and GPU (if available).