Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions core/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,45 @@ cc_library(
"RTDevice.cpp",
"TRTEngine.cpp",
"TRTEngineProfiler.cpp",
"TRTRuntimeConfig.cpp",
"execute_engine.cpp",
"runtime.cpp",
"runtime_utils.cpp",
],
hdrs = [
"Platform.h",
"RTDevice.h",
"TensorRTBindingNames.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TRTRuntimeConfig.h",
"TensorRTBindingNames.h",
"runtime.h",
],
copts = if_torch_nccl(["-DUSE_C10D_NCCL"]),
defines = select({
# nvinfer1::IRuntimeConfig (and the matching ICudaEngine::createRuntimeConfig
# / createExecutionContext(IRuntimeConfig*) overloads) was introduced in
# TensorRT 10.11. The TensorRT shipped with the Jetpack l4t-r36.4 toolchain
# (@tensorrt_l4t) predates 10.11 and does not export this type. Every other
# configuration here (RTX, SBSA, Windows, default x86_64 Linux) is on a
# TensorRT >= 10.11 bundle, so it gets the macro.
#
# Gate every IRuntimeConfig-using site in core/runtime with
# `#ifdef TRT_HAS_IRUNTIME_CONFIG`; the Jetpack path falls back to the
# legacy createExecutionContext() no-arg overload.
":jetpack": [],
"//conditions:default": ["TRT_HAS_IRUNTIME_CONFIG"],
}),
linkopts = [
"-lstdc++fs",
],
local_defines = select({
# TensorRT-RTX builds: opt into feature-gated APIs that the runtime layer
# depends on (e.g. IExecutionContext::isStreamCapturable).
":rtx_win": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"],
":rtx_x86_64": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"],
"//conditions:default": [],
}),
deps = [
":tensorrt_binding_names",
"//core/plugins:torch_tensorrt_plugins",
Expand Down Expand Up @@ -135,9 +158,9 @@ cc_library(
hdrs = [
"Platform.h",
"RTDevice.h",
"TensorRTBindingNames.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TensorRTBindingNames.h",
"runtime.h",
],
deps = [
Expand All @@ -151,9 +174,10 @@ filegroup(
srcs = [
"Platform.h",
"RTDevice.h",
"TensorRTBindingNames.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"TRTRuntimeConfig.h",
"TensorRTBindingNames.h",
"runtime.h",
],
visibility = ["//visibility:public"],
Expand Down
124 changes: 90 additions & 34 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <algorithm>
#include <filesystem>
#include <utility>

#include <cuda_runtime.h>
Expand All @@ -23,6 +24,23 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

namespace {
// TensorRT marks unspecified dimensions in dynamic-shape engines with -1.
constexpr int32_t kDynamicDim = -1;

// Returns true iff any of the listed input bindings (including shape tensors) has a
// dynamic dimension.
[[nodiscard]] bool engine_has_dynamic_inputs(
nvinfer1::ICudaEngine* cuda_engine,
std::vector<std::string> const& in_binding_names) {
TORCHTRT_CHECK(cuda_engine != nullptr, "engine_has_dynamic_inputs requires a live ICudaEngine");
return std::any_of(std::begin(in_binding_names), std::cend(in_binding_names), [cuda_engine](std::string const& name) {
auto const dims = cuda_engine->getTensorShape(name.c_str());
return std::any_of(dims.d, dims.d + dims.nbDims, [](int32_t d) { return d == kDynamicDim; });
});
}
} // namespace

std::string slugify(std::string s) {
std::replace(s.begin(), s.end(), '.', '_');
return s;
Expand Down Expand Up @@ -78,7 +96,8 @@ TRTEngine::TRTEngine(
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata,
const ResourceAllocationStrategy resource_allocation_strategy)
const ResourceAllocationStrategy resource_allocation_strategy,
TRTRuntimeConfig runtime_cfg)
: TRTEngine(
"deserialized_trt",
serialized_engine,
Expand All @@ -89,7 +108,8 @@ TRTEngine::TRTEngine(
hardware_compatible,
requires_output_allocator,
serialized_metadata,
resource_allocation_strategy) {}
resource_allocation_strategy,
std::move(runtime_cfg)) {}

TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
: TRTEngine(
Expand All @@ -104,7 +124,13 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
serialized_info[SERIALIZED_METADATA_IDX],
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))
? ResourceAllocationStrategy::kDynamic
: ResourceAllocationStrategy::kStatic)) {
: ResourceAllocationStrategy::kStatic),
make_runtime_config_from_serialized(serialized_info)) {
// Single visible marker that this engine was instantiated through the C++ runtime
// entry point (i.e. torch.classes.tensorrt.Engine), distinguishing it from the Python
// TRTEngine path. Tests look for this string in captured stderr to verify the
// expected backend was exercised.
LOG_INFO("[torch-TensorRT C++ runtime] TRTEngine constructed from serialized info");
this->requires_native_multidevice = std::stoi(serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]);
if (this->requires_native_multidevice) {
LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution");
Expand All @@ -121,7 +147,9 @@ TRTEngine::TRTEngine(
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata,
const ResourceAllocationStrategy resource_allocation_strategy) {
const ResourceAllocationStrategy resource_allocation_strategy,
TRTRuntimeConfig runtime_cfg) {
this->runtime_cfg = std::move(runtime_cfg);
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
Expand Down Expand Up @@ -157,13 +185,7 @@ TRTEngine::TRTEngine(
LOG_DEBUG(
"Resource allocation strategy: "
<< (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));
if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
this->exec_ctx =
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
} else {
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
}
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");
recreate_execution_context();

// Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses)
cudaMalloc(&empty_tensor_placeholder, 1);
Expand Down Expand Up @@ -251,6 +273,8 @@ TRTEngine::TRTEngine(
num_io = std::make_pair(inputs_size, outputs);
}

runtime_cfg.has_dynamic_inputs = engine_has_dynamic_inputs(cuda_engine.get(), in_binding_names);

#ifndef NDEBUG
this->enable_profiling();
#endif
Expand All @@ -270,6 +294,9 @@ TRTEngine::TRTEngine(
}

TRTEngine::~TRTEngine() {
// Marked noexcept so safe to invoke from a destructor without
// explicit try/catch; any I/O error is logged internally.
runtime_cfg.save_runtime_cache();
trt_engine_profiler.reset();
exec_ctx.reset();
cuda_engine.reset();
Expand All @@ -283,8 +310,7 @@ void TRTEngine::disable_profiling() {
torch::cuda::synchronize(device_info.id);
profile_execution = false;
trt_engine_profiler.reset();
exec_ctx = make_trt(cuda_engine->createExecutionContext());
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context");
recreate_execution_context();
}

void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) {
Expand Down Expand Up @@ -381,10 +407,7 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
trt_engine_profiler.reset();
}
bool result = cuda_engine->setWeightStreamingBudgetV2(budget);
exec_ctx = make_trt(cuda_engine->createExecutionContext());
TORCHTRT_CHECK(
(exec_ctx.get() != nullptr),
"Unable to recreate TensorRT execution context after setting new device memory budget");
recreate_execution_context();
if (profile_execution) {
enable_profiling();
}
Expand Down Expand Up @@ -441,6 +464,7 @@ std::string TRTEngine::to_str() const {
ss << " Target Platform: " << target_platform << std::endl;
ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
ss << " Multi-Device Engine: " << (requires_native_multidevice) << std::endl;
ss << runtime_cfg.to_str();
// clang-format on
return ss.str();
}
Expand Down Expand Up @@ -487,7 +511,11 @@ FlattenedState TRTEngine::__obj_flatten__() {
std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]),
std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]));
std::tuple("requires_native_multidevice", serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX]),
std::tuple("has_runtime_cfg", serialized_info[HAS_RUNTIME_CFG_IDX]),
std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]),
std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]),
std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]));
}

std::vector<std::string> TRTEngine::serialize() {
Expand All @@ -514,6 +542,16 @@ std::vector<std::string> TRTEngine::serialize() {
this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
serialized_info[REQUIRES_NATIVE_MULTIDEVICE_IDX] = this->requires_native_multidevice ? "1" : "0";
// rank/world_size are runtime facts (may differ at load time); not serialized.
#ifdef TRT_MAJOR_RTX
serialized_info[HAS_RUNTIME_CFG_IDX] = "1";
#else
serialized_info[HAS_RUNTIME_CFG_IDX] = "0";
#endif
serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path;
serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string(
static_cast<std::underlying_type_t<DynamicShapesKernelStrategy>>(runtime_cfg.dynamic_shapes_kernel_strategy));
serialized_info[CUDA_GRAPH_STRATEGY_IDX] =
std::to_string(static_cast<std::underlying_type_t<CudaGraphStrategyOption>>(runtime_cfg.cuda_graph_strategy));

return serialized_info;
}
Expand All @@ -525,14 +563,11 @@ void TRTEngine::reset_captured_graph() {
void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) {
if (new_strategy != this->resource_allocation_strategy) {
this->resource_allocation_strategy = new_strategy;
if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
LOG_DEBUG("Setting resource allocation strategy to dynamic");
this->exec_ctx =
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
} else {
LOG_DEBUG("Setting resource allocation strategy to static");
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
}
LOG_DEBUG(
"Setting resource allocation strategy to "
<< (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic"
: "static"));
recreate_execution_context();
}
}

Expand Down Expand Up @@ -629,19 +664,40 @@ void TRTEngine::release_nccl_comm() {
LOG_INFO("Releasing NCCL communicator from engine '" << this->name << "'");
torch::cuda::synchronize(device_info.id);
this->exec_ctx.reset();
if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
this->exec_ctx =
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
} else {
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
}
TORCHTRT_CHECK(
(exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context after releasing NCCL comm");
recreate_execution_context();
this->nccl_initialized = false;
LOG_INFO("NCCL communicator released from engine '" << this->name << "'");
}
#endif // ENABLE_TRT_NCCL_COLLECTIVES

bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const {
return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream);
}

void TRTEngine::disable_rtx_native_cudagraphs() {
bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled;
runtime_cfg.disable_rtx_native_cudagraphs(name);
if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) {
// The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx
// so the new strategy takes effect for subsequent enqueueV3 calls.
recreate_execution_context();
}
}

void TRTEngine::recreate_execution_context() {
// Flush any kernels the previous execution context may have compiled into the
// runtime cache before creating the replacement. The destructor also saves, but
// doing it here guards against losing compiled kernels across profiling toggles,
// allocator changes, or process kills that happen between allocator changes and
// teardown. No-op on standard TensorRT or when no cache path is configured.
runtime_cfg.save_runtime_cache();
const auto allocation_strategy = resource_allocation_strategy == ResourceAllocationStrategy::kDynamic
? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED
: nvinfer1::ExecutionContextAllocationStrategy::kSTATIC;
exec_ctx = runtime_cfg.create_execution_context(cuda_engine.get(), allocation_strategy);
TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context");
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
31 changes: 28 additions & 3 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "torch/custom_class.h"

#include "core/runtime/TRTEngineProfiler.h"
#include "core/runtime/TRTRuntimeConfig.h"
#include "core/runtime/TensorRTBindingNames.h"
#include "core/util/prelude.h"

Expand Down Expand Up @@ -47,7 +48,12 @@ using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>, // Platform
std::tuple<std::string, std::string>, // Resource Allocation Strategy
std::tuple<std::string, std::string>>; // requires_native_multidevice
std::tuple<std::string, std::string>, // requires_native_multidevice
std::tuple<std::string, std::string>, // has_runtime_cfg (gates next three)
std::tuple<std::string, std::string>, // Runtime Cache Path (TRT-RTX)
std::tuple<std::string, std::string>, // Dynamic Shapes Kernel Strategy (TRT-RTX)
std::tuple<std::string, std::string> // CUDA Graph Strategy (TRT-RTX)
>;

struct TorchTRTRuntimeStates {
// Indicates whether CUDAGraphs were enabled in the previous execute_engine
Expand Down Expand Up @@ -151,7 +157,8 @@ struct TRTEngine : torch::CustomClassHolder {
bool requires_output_allocator = false,
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);
TRTEngine::ResourceAllocationStrategy::kStatic,
TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{});

TRTEngine(std::vector<std::string> serialized_info);

Expand All @@ -166,7 +173,8 @@ struct TRTEngine : torch::CustomClassHolder {
bool requires_output_allocator = false,
const std::string& serialized_metadata = "",
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
TRTEngine::ResourceAllocationStrategy::kStatic);
TRTEngine::ResourceAllocationStrategy::kStatic,
TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{});

std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
Expand Down Expand Up @@ -273,6 +281,23 @@ struct TRTEngine : torch::CustomClassHolder {
ResourceAllocationStrategy resource_allocation_strategy = kStatic;
void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy);
ResourceAllocationStrategy get_resource_allocation_strategy();

// Owns the IRuntimeConfig (where supported) and TRT-RTX runtime state. On older TRT
// without IRuntimeConfig (e.g. Jetpack) this just carries strategy values that get
// passed to the legacy createExecutionContext overload.
TRTRuntimeConfig runtime_cfg;

// Monolithic-capturability check used when this engine is wrapped by an outer whole-graph
// capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true.
bool is_monolithic_capturable(cudaStream_t stream) const;

// Disable TensorRT-RTX native CUDA graph capture on this engine (one-shot, invoked when
// an outer stream capture is detected around execute_engine). No-op on non-RTX.
void disable_rtx_native_cudagraphs();

private:
// Single entry point that (re)creates exec_ctx via runtime_cfg.create_execution_context.
void recreate_execution_context();
};

} // namespace runtime
Expand Down
Loading
Loading