Skip to content

Add JAX FFI Host support#1446

Draft
loney7 wants to merge 10 commits into
NVIDIA:mainfrom
loney7:loney7/ffi-host-support
Draft

Add JAX FFI Host support#1446
loney7 wants to merge 10 commits into
NVIDIA:mainfrom
loney7:loney7/ffi-host-support

Conversation

@loney7
Copy link
Copy Markdown

@loney7 loney7 commented May 8, 2026

Description

This PR adds support for running JAX FFI callbacks on the CPU (Host) in addition to CUDA.

Changes:

  • Registered FFI targets for both "CUDA" and "Host" platforms in register_ffi_callback.
  • Handled the "Host" platform case in ffi_callback by using the CPU device and bypassing CUDA-specific features like streams and graphs.
  • Updated FfiCallable to reconstruct arguments and execute the function on the CPU when running on the Host platform.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Test plan

You can verify these changes by running the JAX interop tests which include FFI tests. Ensure they pass on both CPU and GPU (if available).

uv run warp/tests/interop/test_jax.py


<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Separate CUDA and CPU execution paths for FFI callbacks with device-aware execution scoping.
  * CUDA graph compatibility enabled only for CUDA execution.
  * CPU/host path can take a direct Python execution route for faster host calls.
  * Separate CUDA and Host callback registrations for JAX interop.

* **Tests**
  * Added CPU/host FFI tests covering kernels and callables (add, sincos, in/out args, scale).
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

This allows JAX FFI callbacks to run on CPU (Host) in addition to CUDA.

Signed-off-by: Ankit Jain <kitsrish@google.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 8, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: 4309af4c-4ae6-4c4b-a771-7485ec1b3c0e

📥 Commits

Reviewing files that changed from the base of the PR and between f6048f4 and afe74c8.

📒 Files selected for processing (1)
  • warp/tests/interop/test_jax.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • warp/tests/interop/test_jax.py

📝 Walkthrough

Walkthrough

This pull request extends JAX experimental FFI callbacks to support both CUDA and Host (CPU) execution paths. Callbacks now accept a platform parameter, register separate FFI targets for each platform, and conditionally dispatch to platform-appropriate device and kernel launch logic. CUDA-specific traits and graph capture modes are guarded by platform checks.

Changes

CUDA and Host FFI Execution

Layer / File(s) Summary
Callback Protocol & Platform Parameter
warp/_src/jax_experimental/ffi.py
FfiKernel.ffi_callback(), FfiCallable.ffi_callback(), and register_ffi_callback() callbacks gain platform="CUDA" parameter. CUDA graph compatibility traits are now enabled only when platform=="CUDA".
Dual-Platform FFI Registration
warp/_src/jax_experimental/ffi.py
FfiKernel and FfiCallable register separate FFI capsules for both CUDA and Host platforms, each passing the appropriate platform argument. register_ffi_callback() stores capsules under distinct registry keys (_cuda and _host suffixes).
FfiKernel Platform-Conditional Launch
warp/_src/jax_experimental/ffi.py
FfiKernel execution branches by platform: CUDA selects CUDA device and stream, calls wp_cuda_launch_kernel; Host uses CPU device with no stream, calls wp_cpu_launch_kernel.
FfiCallable Host Execution Short-Circuit
warp/_src/jax_experimental/ffi.py
When platform=="Host", FfiCallable reconstructs Warp arrays on the CPU device from the call frame and directly invokes the wrapped Python function, bypassing all CUDA graph logic.
FfiCallable CUDA Execution & Graph Capture
warp/_src/jax_experimental/ffi.py
CUDA graph modes are guarded by device.is_cuda. Execution scopes conditionally use ScopedStream (when stream present) or ScopedDevice, restricting graph capture and replay to CUDA devices only.
CPU Host Tests
warp/tests/interop/test_jax.py
New CPU-only jax_kernel and jax_callable tests added and registered in TestJax with devices=None.

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 5.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Add JAX FFI Host support' directly and clearly summarizes the main change: adding Host/CPU platform support to JAX FFI callbacks alongside existing CUDA support.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Comment @coderabbitai help to get the list of available commands and usage tips.

@greptile-apps
Copy link
Copy Markdown

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR extends JAX FFI Host (CPU) support in Warp by registering both "CUDA" and "Host" FFI targets for FfiKernel, FfiCallable, and register_ffi_callback, and adding a separate CPU execution code path in each callback.

  • FfiKernel / FfiCallable / register_ffi_callback: Dual ctypes wrapper functions are now created and registered via jax.ffi.register_ffi_target for both platforms; CUDA-graph traits and stream retrieval are guarded on platform == \"CUDA\"; the Host path early-returns after calling self.func directly under wp.ScopedDevice(\"cpu\").
  • Tests: Six new host-only test cases cover jax_kernel and jax_callable on CPU for addition, sincos, in-out args, and scalar scale operations.
  • Known outstanding issues: The wp_cpu_launch_kernel call inside FfiKernel.ffi_callback still passes arguments in the wrong order/type, and register_ffi_callback's ExecutionContext constructor still unconditionally invokes get_stream_from_callframe on the Host platform.

Confidence Score: 3/5

The Host execution path for FfiCallable is structurally sound, but the FfiKernel CPU launch and the register_ffi_callback stream extraction remain broken for CPU dispatch.

The wp_cpu_launch_kernel call inside FfiKernel.ffi_callback passes device.context as the function pointer, launch_bounds.size (an integer) as the bounds reference, and omits adj_args and apic_info — every Host FfiKernel invocation will crash or corrupt memory. ExecutionContext in register_ffi_callback still unconditionally calls get_stream_from_callframe, dereferencing a CUDA-specific XLA API on the Host platform.

warp/_src/jax_experimental/ffi.py — specifically the wp_cpu_launch_kernel call site in FfiKernel.ffi_callback and the ExecutionContext construction inside register_ffi_callback's ffi_callback closure.

Important Files Changed

Filename Overview
warp/_src/jax_experimental/ffi.py Adds Host/CPU execution paths to FfiKernel, FfiCallable, and register_ffi_callback; the FfiKernel CPU launch call still passes arguments in the wrong order and wrong types to wp_cpu_launch_kernel, and ExecutionContext in register_ffi_callback still unconditionally calls get_stream_from_callframe for Host calls.
warp/tests/interop/test_jax.py Adds six new CPU/host FFI test cases for kernel and callable paths; tests are well-structured and cover add, sincos, in-out, and scalar-scale scenarios.

Reviews (5): Last reviewed commit: "[pre-commit.ci] auto code formatting" | Re-trigger Greptile

Comment on lines +226 to +229
self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Self-referential NameError in FfiKernel.__init__: ffi_capsule_host is passed as its own argument to pycapsule before the name is ever assigned. This will raise NameError: name 'ffi_capsule_host' is not defined every time an FfiKernel is instantiated, making the entire Host platform registration dead code. The address variable ffi_ccall_address_host should be used here instead — consistent with how the CUDA path is written above.

Suggested change
self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")
self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host"))
ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host")

Comment on lines +1775 to +1777
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Same self-referential NameError in register_ffi_callback: ffi_capsule_host is referenced before it is assigned. This causes every call to register_ffi_callback to raise NameError: name 'ffi_capsule_host' is not defined, so no Host target is ever registered. The fix mirrors the working CUDA block two lines above.

Suggested change
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_capsule_host.value)
jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value)
jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host")

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1772-1777: In register_ffi_callback, the Host ffi capsule is
constructed from the wrong variable (ffi_capsule_host) causing a NameError;
change the construction to use the host ccall address value by calling
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then register that capsule
with jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") so the
Host path mirrors the CUDA path (refer to ffi_ccall_address_host,
ffi_capsule_host, register_ffi_target).
- Around line 629-632: The code assigns ffi_ccall_address_host then creates
ffi_capsule_host but mistakenly uses ffi_capsule_host.value (self-referential
NameError); change the capsule creation to use the previously computed address
value (ffi_ccall_address_host.value) so the lines around callback_func_host,
ffi_ccall_address_host, ffi_capsule_host and the
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") call
use ffi_ccall_address_host.value when constructing the pycapsule for the Host
callback that wraps FFI_CCALLFUNC and calls self.ffi_callback.
- Around line 226-229: The Host FFI registration references an undefined
variable: replace the erroneous creation of ffi_capsule_host (currently using
ffi_capsule_host.value) with a capsule built from the c_void_p address you just
made; specifically, after creating callback_func_host and
ffi_ccall_address_host, set ffi_capsule_host =
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then call
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") so the
capsule is created from the correct address (symbols: callback_func_host,
FFI_CCALLFUNC, ffi_ccall_address_host, ffi_capsule_host,
jax.ffi.register_ffi_target, self.name, self.ffi_callback).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: d2c93d0f-cbef-4891-b361-773ad8140f6c

📥 Commits

Reviewing files that changed from the base of the PR and between 54327e3 and 09569ba.

📒 Files selected for processing (1)
  • warp/_src/jax_experimental/ffi.py

Comment thread warp/_src/jax_experimental/ffi.py
Comment thread warp/_src/jax_experimental/ffi.py
Comment thread warp/_src/jax_experimental/ffi.py
Ankit Jain added 4 commits May 9, 2026 00:50
Signed-off-by: Ankit Jain <kitsrish@google.com>
Applied the patch from the original CL as requested.

Signed-off-by: Ankit Jain <kitsrish@google.com>
Fix a typo where ffi_capsule_host used its own unassigned value instead of ffi_ccall_address_host.value.

Signed-off-by: Ankit Jain <kitsrish@google.com>
Restore files to their state in main branch to keep this PR focused on JAX FFI changes.

Signed-off-by: Ankit Jain <kitsrish@google.com>
Comment on lines +464 to +469
wp._src.context.runtime.core.wp_cpu_launch_kernel(
device.context,
hooks.forward,
launch_bounds.size,
kernel_params,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Wrong arguments passed to wp_cpu_launch_kernel

The call is missing one argument and the arguments are in the wrong positions. The registered ctypes signature (in context.py) is (func, bounds, args, adj_args, apic_info), but the call here passes device.context as func, hooks.forward as bounds, launch_bounds.size (an integer, not a pointer) as args, and kernel_params as adj_args, with apic_info omitted entirely. The correct call should place the casted hooks.forward function pointer as the first argument, a reference to launch_bounds as the second, and kernel_params as the third args pointer, with None for adj_args and apic_info. As written, this will pass garbage pointers to the native kernel launcher, causing a crash or silent memory corruption on every Host-platform FfiKernel call.

@loney7
Copy link
Copy Markdown
Author

loney7 commented May 8, 2026

pre-commit.ci autofix

pre-commit-ci Bot and others added 3 commits May 8, 2026 23:56
Add four tests that exercise the Host platform path for FFI:
- test_ffi_jax_kernel_add_host: basic two-input one-output kernel
- test_ffi_jax_kernel_sincos_host: one-input two-output kernel
- test_ffi_jax_kernel_in_out_host: in-out argument handling
- test_ffi_jax_callable_scale_constant_host: jax_callable with scalar constant

Signed-off-by: Ankit Jain <kitsrish@google.com>
Add test coverage for jax_kernel and jax_callable running on the CPU
Host platform. Tests cover basic add, sincos (multi-output), in-out
args, scalar constant args, and callable variants.

Signed-off-by: Ankit Jain <kitsrish@google.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/tests/interop/test_jax.py`:
- Around line 2422-2450: The host-only test registrations for TestJax (the
add_function_test calls registering test_ffi_jax_kernel_host_add,
test_ffi_jax_kernel_host_sincos, test_ffi_jax_kernel_host_in_out,
test_ffi_jax_kernel_host_scale_vec_constant,
test_ffi_jax_callable_host_scale_constant, and
test_ffi_jax_callable_host_in_out) are currently inside the
jax_compatible_cuda_devices conditional; move these specific add_function_test
calls out of that CUDA-only if block so they are always registered on CPU-only
setups, keeping the existing device=None argument and leaving CUDA/GPU-specific
registrations inside the original conditional.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: fcfb2e10-e509-41b9-b675-c8960056b6be

📥 Commits

Reviewing files that changed from the base of the PR and between d39169c and f6048f4.

📒 Files selected for processing (1)
  • warp/tests/interop/test_jax.py

Comment thread warp/tests/interop/test_jax.py
@loney7
Copy link
Copy Markdown
Author

loney7 commented May 9, 2026

pre-commit.ci autofix

@shi-eric shi-eric requested a review from nvlukasz May 12, 2026 16:45
add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)

# ffi Host (CPU) tests
add_function_test(TestJax, "test_ffi_jax_kernel_host_add", test_ffi_jax_kernel_host_add, devices=None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These Host tests only use jax.devices("cpu"), but they are registered inside the CUDA JAX availability gate. Please move them under CPU JAX availability so CPU-only CI covers the Host backend.

None, # apic_info
)
else:
wp._src.context.runtime.core.wp_cpu_launch_kernel(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Host path is using the CUDA launch ABI for wp_cpu_launch_kernel. The CPU binding expects (func, bounds, args, adj_args, apic_info), with kernel args packed into the CPU args struct. As written, Host jax_kernel raises TypeError: this function takes at least 5 arguments (4 given) instead of launching.

@shi-eric
Copy link
Copy Markdown
Contributor

Please add a CHANGELOG.md entry under Unreleased for this JAX Host FFI behavior change.

@shi-eric
Copy link
Copy Markdown
Contributor

Please squash this PR down to a single coherent commit before merge.

@shi-eric
Copy link
Copy Markdown
Contributor

Please rebase onto current main and move the fix/tests to the promoted JAX code paths. Commit 604a8961df6d40ea64ff1e740b23581e4c72c96f promoted the JAX code from jax_experimental to jax after this PR was opened, so the final diff should target the current locations.

Copy link
Copy Markdown
Contributor

@nvlukasz nvlukasz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.

assert num_inputs == self.num_inputs
assert num_outputs == self.num_outputs

if platform == "Host":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we deduplicate this argument reconstruction code? Perhaps extract to a helper function.

# call the Python function with reconstructed arguments
with wp.ScopedStream(stream, sync_enter=False):
if stream.is_capturing:
with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to check stream here, since the host path returns early above. If we get here, the stream is not None.

call_desc.capture = capture

elif self.graph_mode == GraphMode.WARP:
elif self.graph_mode == GraphMode.WARP and device.is_cuda:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to check device.is_cuda here, since the host part returns early above. Same comment on a few more lines below.

@loney7
Copy link
Copy Markdown
Author

loney7 commented May 19, 2026

Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.

@nvlukasz , please allow till the end of this week to address the outstanding comments. I apologise for the delay.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants