feat(runtime): TRT-RTX runtime cache, dynamic shapes strategy, and native CUDA graph support on the C++ runtime#2
Open
tp5uiuc wants to merge 8 commits into
Conversation
tp5uiuc
commented
May 28, 2026
tp5uiuc
commented
May 28, 2026
ffc62ad to
53470af
Compare
…y, and native CUDA graph support to C++ runtime - Introduce IRuntimeConfig scaffolding and bump ABI to v9 - Add runtime cache to C++ runtime for TensorRT-RTX - Add dynamic shapes kernel specialization strategy to C++ runtime - Add TensorRT-RTX native CUDA graph strategy to C++ runtime - Extract TRTRuntimeConfig - Consolidate C++ runtime tests and add model-level coverage
…xecution_context release_nccl_comm() previously rebuilt the IExecutionContext via direct calls to ICudaEngine::createExecutionContext, bypassing the TRTRuntimeConfig plumbing introduced earlier in this PR. On that path the RTX runtime cache was not flushed before context teardown, and the dynamic shapes kernel specialization and CUDA graph strategies stored on TRTRuntimeConfig were not re-applied to the new context. Delegate to recreate_execution_context() instead. It saves the runtime cache, ensures TRTRuntimeConfig is initialized, sets the allocation strategy from resource_allocation_strategy, and creates the new exec context via createExecutionContext(runtime_cfg.config.get()), keeping all strategies live across the NCCL bind/release cycle.
cuda_graph_strategy and dynamic_shapes_kernel_specialization_strategy are TRT-RTX-only at runtime, but they are accepted on every build through the public compile() / CompilationSettings surface. Their string-to-enum lookup lived inside the 'if ENABLED_FEATURES.tensorrt_rtx:' block in _pack_engine_info(), so on a standard (non-RTX) build a typo like cuda_graph_strategy="wholee_graph_capture" was silently dropped instead of raising. Hoist the membership check into TorchTensorRTModule.__init__ so that invalid strategy names always raise ValueError, regardless of backend. The RTX-gated index population in _pack_engine_info() keeps reading the maps unchanged -- only the redundant validation moves. Fixes the L1 dynamo core tests on standard-TensorRT Windows: TestCudaGraphStrategyInvalidValue::test_invalid_strategy_raises TestDynamicShapesKernelStrategyCppInvalidValue::test_invalid_strategy_raises
The C++ runtime config introduced in this branch unconditionally referenced
nvinfer1::IRuntimeConfig, which is only available on TensorRT-RTX and on
standard TensorRT >= 10.11. The TensorRT shipped with the Jetpack l4t-r36.4
toolchain (@tensorrt_l4t) predates 10.11 and does not export this type, so
the aarch64-jetpack build fails:
./core/runtime/TRTRuntimeConfig.h:47:29: error: 'IRuntimeConfig' is not
a member of 'nvinfer1'
Inject a TRT_HAS_IRUNTIME_CONFIG macro from core/runtime/BUILD via a
'defines = select({...})' on //core/runtime:runtime. The macro is set on
every build configuration except :jetpack (RTX, SBSA, Windows, default
x86_64 Linux). This is symmetric with how TRT_MAJOR_RTX and
ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION are already injected per-config
in the same target.
In the C++ sources, gate the IRuntimeConfig-using state with
'#ifdef TRT_HAS_IRUNTIME_CONFIG' inside TRTRuntimeConfig.{h,cpp}, and
expose a single TRTRuntimeConfig::create_execution_context member that
selects the right createExecutionContext overload internally:
- IRuntimeConfig path (>= 10.11 / RTX): set the allocation strategy on
the IRuntimeConfig and call createExecutionContext(IRuntimeConfig*).
- Legacy path (older TRT, e.g. Jetpack): call the legacy
createExecutionContext(ExecutionContextAllocationStrategy) overload
directly. The Jetpack path therefore still respects the user-requested
kDynamic / kSTATIC choice.
Callers in TRTEngine.cpp invoke runtime_cfg.create_execution_context(...)
and stay free of any TRT_HAS_IRUNTIME_CONFIG branching. The previous
public TRTRuntimeConfig::set_execution_context_allocation_strategy method
had only one caller and is removed.
The pre-existing TRT_MAJOR_RTX-gated runtime_cache / dynamic-shapes /
cuda-graph blocks remain a strict subset of TRT_HAS_IRUNTIME_CONFIG, so
behavior on TRT-RTX and on modern standard TensorRT is unchanged.
Note: macro semantics are now 'is the build config named jetpack?'
rather than 'does TRT actually export IRuntimeConfig?'. If @tensorrt_l4t
ever bumps to 10.11+, the BUILD select needs to be updated to flip the
gate on for jetpack.
Add tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py covering the three TRT-RTX features (runtime cache, dynamic shapes kernel strategy, native CUDA graph strategy) when use_python_runtime=False. The Python-runtime tests assert on Python TRTEngine attributes that the C++ engine (torch.classes.tensorrt.Engine) does not expose, so the C++ tests instead verify externally observable behavior: strategy-name typo validation in TorchTensorRTModule.__init__, compile+infer correctness via cosine similarity, and runtime-cache file persistence on destruction.
verify_serialization_fmt iterates over the serialized engine info and fetches the human-readable index name from kSerializedInfoIndexNames. On RTX builds SERIALIZATION_LEN is 15 but only 12 names were initialized, leaving the remaining 3 std::array slots zero-initialized to nullptr. fprintf(\"%s\", name) on a null pointer is undefined behavior and segfaults in practice when an engine is deserialized via the def_pickle path. Gate the three RTX-only names on TRT_MAJOR_RTX to mirror the SerializedInfoIndex enum and keep the array fully initialized on both backends.
3dfb264 to
0e09b02
Compare
…timeConfig
Two review-feedback changes:
1. Revert sink-by-value on pre-existing TRTEngine constructor parameters
(serialized_engine, serialized_metadata, mod_name) back to const-ref. A
broader sink-by-value sweep across all existing fields belongs in a separate
follow-up; this PR only keeps pass-by-value + std::move for the new
TRTRuntimeConfig parameter.
2. Gate the lazy-strategy capturability check on whether the engine actually
has dynamic-shape inputs, mirroring the Python _is_monolithic_capturable
implementation. Static-shape engines remain monolithically capturable under
the lazy strategy because lazy only swaps specialized kernels mid-run on
dynamic-shape inputs.
- New file-local helper engine_has_dynamic_inputs() in TRTEngine.cpp walks
the input bindings (including shape tensors) and reports whether any
dimension is dynamic.
- TRTRuntimeConfig gains a cached bool has_dynamic_inputs (default true so
the conservative branch is taken if the flag is never populated); the
TRTEngine constructor assigns to it once after binding names are known.
- is_monolithic_capturable returns true under kLazy iff has_dynamic_inputs
is false.
Drop the #ifdef TRT_MAJOR_RTX gate on the new SerializedInfoIndex entries so standard TRT and TRT-RTX engines share an identical on-disk layout. A saved program can be inspected and round-tripped across backends without a length mismatch. Add HAS_RUNTIME_CFG_IDX as a sentinel flag immediately before the three TRTRuntimeConfig slots. The producer writes \"1\" iff it authored the next three slots; the consumer treats them as defaults when the flag is \"0\". SERIALIZATION_LEN is now 16 on both backends (was 15 on RTX, 12 on standard). The Python layout-check op table merges the previous RTX-only checks into _LAYOUT_CPP_CHECKS, and register_jit_hooks always exposes the four *_IDX accessors.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR is the C++-runtime counterpart of pytorch#4294. It carries the C++/Bazel side of the three TensorRT-RTX runtime features that already work on the Python runtime through
TRTEngine:IRuntimeCache) loaded on engine setup, saved on destructorwhole_graph_capture)Stacked on top of
reintroduce-rtx-features(PR pytorch#4294). The base diff (this PR's changes) is purely the C++ runtime layer, the engine-info packing, and a smoke-test file — settings declarations, PythonTRTEnginemethods, and the existing Python tests are all owned by the parent branch and are not in this diff.Origin
This replaces pytorch#4202 after the Python-runtime-rework PR (pytorch#4222) landed and changed the runtime layout. Pre-rebase snapshot of pytorch#4202 is preserved at
backup/feat-trtrtx-cpp-runtime-pre-refactor-2026-05-28(headabeca04b).What's in this diff
New C++ files
core/runtime/TRTRuntimeConfig.{h,cpp}— encapsulates all TRT-RTX-onlyIRuntimeConfigstate, runtime-cache I/O, enum helpers, and the RTX-native cudagraph hooks. Confines#ifdef TRT_MAJOR_RTX/#ifdef TRT_HAS_IRUNTIME_CONFIGto one TU.Modified C++ files
core/runtime/runtime.h— addsHAS_RUNTIME_CFG_IDX+ threeTRTRuntimeConfigslots (RUNTIME_CACHE_PATH_IDX,DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX,CUDA_GRAPH_STRATEGY_IDX) toSerializedInfoIndex. The format is unified across standard TRT and TRT-RTX —SERIALIZATION_LENis16on both backends, with the producer settingHAS_RUNTIME_CFG_IDX = "1"only when the next three slots are meaningful, so engines round-trip across backends without a length mismatch.core/runtime/TRTEngine.{h,cpp}— holds aTRTRuntimeConfigby value; new privaterecreate_execution_context()replaces all 8 directcreateExecutionContextcall sites; destructor calls the noexceptsave_runtime_cache()so cached kernels are persisted across runs.core/runtime/execute_engine.cpp—effective_cudagraphs = cudagraphs_enabled && !runtime_cfg.uses_internal_capture(). On RTX with internal capture active and an outer stream capture detected, callsdisable_rtx_native_cudagraphs()one-shot so the outer capture isn't disturbed.core/runtime/register_jit_hooks.cpp— addsHAS_RUNTIME_CFG_IDX()+ three*_IDX()accessors (unconditionally) so Python can validate the layout at load time.core/runtime/BUILD— addsTRTRuntimeConfig.{h,cpp}, theTRT_HAS_IRUNTIME_CONFIGdefine for non-Jetpack configs, andENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATIONas alocal_defineson:rtx_win/:rtx_x86_64so the RTX header's feature gate is satisfied without nesting#ifdefs in the source.Python wiring (minimal)
py/torch_tensorrt/dynamo/runtime/_serialized_engine_layout.py— addsHAS_RUNTIME_CFG_IDX+ the threeTRTRuntimeConfigindex aliases;SERIALIZATION_LEN = 16;_LAYOUT_CPP_CHECKSextended with the new entries.py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py— strategy-name maps + unconditional validation in__init__;_pack_engine_infowrites the flag and the three slots on every build.Tests
tests/py/dynamo/runtime/test_001_cpp_runtime_rtx_features.py— new smoke-test file for the C++ runtime. The Python-runtime tests (already on the parent branch) assert via PythonTRTEngineattributes that the C++torch.classes.tensorrt.Enginedoes not expose, so this file exercises externally-observable behavior: strategy-name typo rejection at__init__, compile + infer correctness viacosine_similarity > COSINE_THRESHOLD, and runtime-cache file persistence on engine destruction.What's intentionally NOT in this diff
py/torch_tensorrt/dynamo/_defaults.py,_settings.py,_compiler.py— owned by the parent branch (feat: reintroduce TRT-RTX runtime cache, dynamic shapes, and native CUDA graph support pytorch/TensorRT#4294). PR feat(runtime): add TensorRT-RTX runtime cache, dynamic shapes strategy, and native CUDA graph support to C++ runtime pytorch/TensorRT#4202's original duplicates were dropped."8"to"9"—mainis already at"9"from Python runtime rework pytorch/TensorRT#4222.TRTEngineattributes the C++ engine doesn't expose. The new smoke-test file gives C++ runtime coverage without requiring those changes.cuda_graph_strategy/dynamic_shapes_kernel_specialization_strategy/runtime_cache_pathfromCompilationSettingsto runtime context managers — tracked as a follow-up in Move TRT-RTX runtime mode controls from CompilationSettings to runtime context managers pytorch/TensorRT#4310 per @narendasan's review feedback on feat(runtime): file-lock the TRT-RTX runtime cache pytorch/TensorRT#4237.Test plan
A100, TRT-RTX 1.5.0.103, PyTorch nightly
2.13.0.dev20260521+cu130, CUDA 13.0, built withpython3 setup.py bdist_wheel --use-rtx.Layout self-check (matches across both backends):
(
_assert_serialized_layout_matches_cpp()runs at import time and matched.)Total: 51 passed, 4 skipped, 0 failures. Skips are all non-RTX gates.
Notes for the reviewer
reintroduce-rtx-features(notmain) so only the C++-side delta shows up. When that branch lands upstream, this branch should be rebased ontomain.core/runtime/TRTRuntimeConfig.{h,cpp}is structured per the round-2/round-3 review feedback on the original feat(runtime): add TensorRT-RTX runtime cache, dynamic shapes strategy, and native CUDA graph support to C++ runtime pytorch/TensorRT#4202._TorchTensorRTModule.pyvalidation lives in__init__(not in_pack_engine_info) so typos fail fast on every build, regardless of backend.TRTRuntimeConfigparameter on theTRTEngineconstructor; the pre-existing string parameters (serialized_engine,serialized_metadata,mod_name) stayconst std::string&per review feedback (a broader sink-by-value sweep belongs in a separate follow-up MR).TRTRuntimeConfig::is_monolithic_capturablegates the lazy-strategy capturability check on whether the engine has dynamic-shape inputs, mirroring the Python_is_monolithic_capturableimplementation. Static-shape engines remain monolithically capturable under the lazy strategy.Follow-ups
cuda_graph_strategy,dynamic_shapes_kernel_specialization_strategy,runtime_cache_path) out ofCompilationSettingsinto runtime context managers, per @narendasan.🤖 Generated with Claude Code