From aefe4cb683e8fd980355cd0d18e91ffc09e9d11b Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 15 Jun 2026 14:04:31 -0700 Subject: [PATCH] skip h2d and d2h copies methods Summary: This diff updates gemma4-31b export and runtime pipeline to skip the h2d and d2h copies between prefill and decode, and between previous round next round of decode as well. Differential Revision: D108661628 --- examples/models/gemma4_31b/export.py | 9 +++++ examples/models/gemma4_31b/main.cpp | 55 ++++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index 987f6265d4d..c19582c49b7 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -171,6 +171,7 @@ def _export_cuda( ) from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.passes import MemoryPlanningPass + from executorch.exir.passes.propagate_device_pass import PropagateDeviceConfig from torch.export import Dim, export inductor_config.coordinate_descent_tuning = False @@ -270,6 +271,14 @@ def _export_cuda( alloc_graph_input=False, ), emit_mutable_buffer_names=True, + # Keep method inputs/outputs device-resident so the CUDA backend + # does not insert boundary H2D/D2H copies: the runner stages inputs + # in CUDA memory and reads the sampled token back with a single + # small D2H. CUDA-only (no effect on the MLX path). + propagate_device_config=PropagateDeviceConfig( + skip_h2d_for_method_inputs=True, + skip_d2h_for_method_outputs=True, + ), ), ) diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 1b2cbc5432f..91888f6f751 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -23,8 +23,11 @@ #include #include #include +#include #include #include +#include +#include #include #include @@ -82,6 +85,9 @@ using ::executorch::extension::from_blob; using ::executorch::extension::Module; using ::executorch::runtime::Error; using ::executorch::runtime::EValue; +#ifdef EXECUTORCH_BUILD_CUDA +using ::executorch::extension::clone_tensor_ptr_to; +#endif using SizesType = executorch::aten::SizesType; @@ -181,6 +187,8 @@ int main(int argc, char** argv) { FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); #ifdef EXECUTORCH_BUILD_CUDA + const auto cuda_device = + executorch::aten::Device(executorch::aten::DeviceType::CUDA, 0); if (FLAGS_cuda_graph) { executorch::runtime::BackendOptions<2> cuda_opts; cuda_opts.set_option("enable_cuda_graph_for_method", "decode"); @@ -217,8 +225,9 @@ int main(int argc, char** argv) { ET_LOG(Error, "Failed to load decode method"); return 1; } - auto temp_tensor = - from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); + auto temp_tensor = clone_tensor_ptr_to( + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float), + cuda_device); #else if (FLAGS_cuda_graph) { ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); @@ -304,6 +313,12 @@ int main(int argc, char** argv) { auto pos_tensor = from_blob( pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long); +#ifdef EXECUTORCH_BUILD_CUDA + // skip_h2d: prefill/decode method inputs must already live in CUDA memory. + tokens_tensor = clone_tensor_ptr_to(tokens_tensor, cuda_device); + pos_tensor = clone_tensor_ptr_to(pos_tensor, cuda_device); +#endif + std::vector inputs; inputs.push_back(EValue(tokens_tensor)); inputs.push_back(EValue(pos_tensor)); @@ -356,10 +371,21 @@ int main(int argc, char** argv) { int64_t pos = num_prompt_tokens; std::vector decode_token_data = {static_cast(cur_token)}; std::vector decode_pos_data = {pos}; - auto decode_tokens = from_blob( + auto decode_tokens_cpu = from_blob( decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); - auto decode_pos = from_blob( + auto decode_pos_cpu = from_blob( decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); +#ifdef EXECUTORCH_BUILD_CUDA + // skip_h2d: keep fixed device-resident input buffers across decode steps + // (seeded here with a one-time H2D). Their 8-byte contents are refreshed + // each step from the host (see loop), since the sampled id round-trips to + // the host for EOS detection. + auto decode_tokens = clone_tensor_ptr_to(decode_tokens_cpu, cuda_device); + auto decode_pos = clone_tensor_ptr_to(decode_pos_cpu, cuda_device); +#else + auto decode_tokens = decode_tokens_cpu; + auto decode_pos = decode_pos_cpu; +#endif uint64_t prev_token = cur_token; bool hit_eos = eos_ids.find(cur_token) != eos_ids.end(); @@ -367,6 +393,27 @@ int main(int argc, char** argv) { decode_token_data[0] = static_cast(cur_token); decode_pos_data[0] = pos; +#ifdef EXECUTORCH_BUILD_CUDA + // skip_h2d: refresh the device-resident token/pos buffers ourselves (the + // backend no longer inserts the H2D copy). The prior step's sampled id + // already came back to the host via read_token, so re-upload the 8-byte + // token and position into the fixed device buffers. + ET_CHECK_MSG( + cudaMemcpy( + decode_tokens->mutable_data_ptr(), + decode_token_data.data(), + sizeof(int64_t), + cudaMemcpyHostToDevice) == cudaSuccess, + "Failed to upload decode token H2D"); + ET_CHECK_MSG( + cudaMemcpy( + decode_pos->mutable_data_ptr(), + decode_pos_data.data(), + sizeof(int64_t), + cudaMemcpyHostToDevice) == cudaSuccess, + "Failed to upload decode position H2D"); +#endif + std::vector inputs; inputs.push_back(EValue(decode_tokens)); inputs.push_back(EValue(decode_pos));