Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/models/gemma4_31b/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _export_cuda(
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.propagate_device_pass import PropagateDeviceConfig
from torch.export import Dim, export

inductor_config.coordinate_descent_tuning = False
Expand Down Expand Up @@ -270,6 +271,14 @@ def _export_cuda(
alloc_graph_input=False,
),
emit_mutable_buffer_names=True,
# Keep method inputs/outputs device-resident so the CUDA backend
# does not insert boundary H2D/D2H copies: the runner stages inputs
# in CUDA memory and reads the sampled token back with a single
# small D2H. CUDA-only (no effect on the MLX path).
propagate_device_config=PropagateDeviceConfig(
skip_h2d_for_method_inputs=True,
skip_d2h_for_method_outputs=True,
),
),
)

Expand Down
55 changes: 51 additions & 4 deletions examples/models/gemma4_31b/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
#include <executorch/extension/llm/sampler/util.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/extension/tensor/tensor_ptr.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/portable_type/device.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>
#include <pytorch/tokenizers/hf_tokenizer.h>

Expand Down Expand Up @@ -82,6 +85,9 @@ using ::executorch::extension::from_blob;
using ::executorch::extension::Module;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;
#ifdef EXECUTORCH_BUILD_CUDA
using ::executorch::extension::clone_tensor_ptr_to;
#endif

using SizesType = executorch::aten::SizesType;

Expand Down Expand Up @@ -181,6 +187,8 @@ int main(int argc, char** argv) {
FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);

#ifdef EXECUTORCH_BUILD_CUDA
const auto cuda_device =
executorch::aten::Device(executorch::aten::DeviceType::CUDA, 0);
if (FLAGS_cuda_graph) {
executorch::runtime::BackendOptions<2> cuda_opts;
cuda_opts.set_option("enable_cuda_graph_for_method", "decode");
Expand Down Expand Up @@ -217,8 +225,9 @@ int main(int argc, char** argv) {
ET_LOG(Error, "Failed to load decode method");
return 1;
}
auto temp_tensor =
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
auto temp_tensor = clone_tensor_ptr_to(
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float),
cuda_device);
#else
if (FLAGS_cuda_graph) {
ET_LOG(Info, "--cuda_graph ignored on non-CUDA build");
Expand Down Expand Up @@ -304,6 +313,12 @@ int main(int argc, char** argv) {
auto pos_tensor = from_blob(
pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long);

#ifdef EXECUTORCH_BUILD_CUDA
// skip_h2d: prefill/decode method inputs must already live in CUDA memory.
tokens_tensor = clone_tensor_ptr_to(tokens_tensor, cuda_device);
pos_tensor = clone_tensor_ptr_to(pos_tensor, cuda_device);
#endif

std::vector<EValue> inputs;
inputs.push_back(EValue(tokens_tensor));
inputs.push_back(EValue(pos_tensor));
Expand Down Expand Up @@ -356,17 +371,49 @@ int main(int argc, char** argv) {
int64_t pos = num_prompt_tokens;
std::vector<int64_t> decode_token_data = {static_cast<int64_t>(cur_token)};
std::vector<int64_t> decode_pos_data = {pos};
auto decode_tokens = from_blob(
auto decode_tokens_cpu = from_blob(
decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long);
auto decode_pos = from_blob(
auto decode_pos_cpu = from_blob(
decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long);
#ifdef EXECUTORCH_BUILD_CUDA
// skip_h2d: keep fixed device-resident input buffers across decode steps
// (seeded here with a one-time H2D). Their 8-byte contents are refreshed
// each step from the host (see loop), since the sampled id round-trips to
// the host for EOS detection.
auto decode_tokens = clone_tensor_ptr_to(decode_tokens_cpu, cuda_device);
auto decode_pos = clone_tensor_ptr_to(decode_pos_cpu, cuda_device);
#else
auto decode_tokens = decode_tokens_cpu;
auto decode_pos = decode_pos_cpu;
#endif

uint64_t prev_token = cur_token;
bool hit_eos = eos_ids.find(cur_token) != eos_ids.end();
for (int32_t step = 0; step < FLAGS_max_new_tokens && !hit_eos; step++) {
decode_token_data[0] = static_cast<int64_t>(cur_token);
decode_pos_data[0] = pos;

#ifdef EXECUTORCH_BUILD_CUDA
// skip_h2d: refresh the device-resident token/pos buffers ourselves (the
// backend no longer inserts the H2D copy). The prior step's sampled id
// already came back to the host via read_token, so re-upload the 8-byte
// token and position into the fixed device buffers.
ET_CHECK_MSG(
cudaMemcpy(
decode_tokens->mutable_data_ptr(),
decode_token_data.data(),
sizeof(int64_t),
cudaMemcpyHostToDevice) == cudaSuccess,
"Failed to upload decode token H2D");
ET_CHECK_MSG(
cudaMemcpy(
decode_pos->mutable_data_ptr(),
decode_pos_data.data(),
sizeof(int64_t),
cudaMemcpyHostToDevice) == cudaSuccess,
"Failed to upload decode position H2D");
#endif

std::vector<EValue> inputs;
inputs.push_back(EValue(decode_tokens));
inputs.push_back(EValue(decode_pos));
Expand Down
Loading