
ARM quantizers fail on nn.LeakyReLU due to device placement of negative_slope constant #16541

@lapid92

Description

🐛 Describe the bug

I am seeing device-dependent failures when quantizing a model that uses nn.LeakyReLU with ARM backends (VGF / Ethos-U).

The issue does not reproduce with the XNNPACK quantizer and appears specific to ARM quantizers.
I suspect that the negative_slope constant used by nn.LeakyReLU stays on the CPU when the rest of the model is moved to another device, which leads to the device-mismatch failures described below.
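A minimal, CPU-only illustration of the suspected mechanism follows. It assumes the constant ends up as a tensor that is not registered as a parameter or buffer; I have not confirmed where the Arm quantizer actually creates it. nn.Module.to() only converts registered parameters and buffers, so a plain tensor attribute is left behind. A dtype conversion stands in for a device move so the sketch runs without an accelerator. ScaleWithPlainTensor and ScaleWithBuffer are hypothetical toy modules, not part of ExecuTorch.

```python
import torch
import torch.nn as nn


class ScaleWithPlainTensor(nn.Module):
    """Hypothetical toy module: the constant is a plain tensor attribute."""

    def __init__(self, slope: float = 0.2):
        super().__init__()
        # Not a parameter or buffer, so nn.Module.to() does not touch it.
        self.slope = torch.tensor(slope)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.slope


class ScaleWithBuffer(nn.Module):
    """Hypothetical toy module: the constant is a registered buffer."""

    def __init__(self, slope: float = 0.2):
        super().__init__()
        # Registered buffers are converted (and moved) by nn.Module.to().
        self.register_buffer("slope", torch.tensor(slope))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.slope


plain = ScaleWithPlainTensor().to(torch.float64)
buffered = ScaleWithBuffer().to(torch.float64)

# The plain attribute keeps its original dtype, and would likewise keep its
# original device if .to(device) were used instead.
print(plain.slope.dtype)     # torch.float32
print(buffered.slope.dtype)  # torch.float64
```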

Before applying the patch, prepare_pt2e() itself failed with a device-mismatch RuntimeError.

The flow described below assumes the patch has been applied.

Observed behavior

There are two failure modes depending on device placement:

  1. Model kept on CPU during quantization (prepare_pt2e(), calibration, and convert_pt2e()) and then moved to a non-CPU device:
    • Quantization succeeds.
    • Moving the quantized model to a non-CPU device and running inference fails with a RuntimeError due to a device mismatch.

  2. Model moved to a non-CPU device before prepare_pt2e():
    • Running quantization calibration fails with a RuntimeError due to a device mismatch.

This reproduces with VGF and Ethos-U; XNNPACK works.
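To narrow down which tensor stays on the CPU in either failure mode, a small diagnostic like the one below could be run on the model after moving it. It is a sketch: find_unregistered_tensors and report_devices are hypothetical helpers, not PyTorch or ExecuTorch APIs. It relies on parameters and buffers living in the module's _parameters/_buffers dicts, so any tensor found directly in a submodule's __dict__ is one that nn.Module.to() will not move.

```python
import torch
import torch.nn as nn


def find_unregistered_tensors(module: nn.Module) -> dict[str, torch.Tensor]:
    """Return tensors stored as plain attributes (not parameters or buffers).

    Hypothetical helper. nn.Module.to() only converts parameters and buffers,
    so tensors returned here stay on the device where they were created.
    """
    found: dict[str, torch.Tensor] = {}
    for mod_name, submodule in module.named_modules():
        # Parameters and buffers are kept in _parameters/_buffers, so a Tensor
        # stored directly in the instance __dict__ is unregistered.
        for attr_name, value in vars(submodule).items():
            if isinstance(value, torch.Tensor):
                key = f"{mod_name}.{attr_name}" if mod_name else attr_name
                found[key] = value
    return found


def report_devices(module: nn.Module) -> None:
    """Print the device of every parameter, buffer and unregistered tensor."""
    for name, p in module.named_parameters():
        print(f"param  {name}: {p.device}")
    for name, b in module.named_buffers():
        print(f"buffer {name}: {b.device}")
    for name, t in find_unregistered_tensors(module).items():
        print(f"plain  {name}: {t.device}")
```

For example, calling report_devices(m) right after m = m.to(device) in the reproduction below would show whether any module-level tensor is still on cpu. A constant created by a node inside the graph, rather than stored on the module, would not show up here.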

Reproduce Steps

import torch
import torch.nn as nn
from executorch.backends.arm.quantizer import VgfQuantizer
from executorch.backends.arm.quantizer import get_symmetric_quantization_config
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.vgf import VgfCompileSpec
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

class LeakyReluNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        x = self.conv(x)
        x = self.leaky_relu(x)
        return x

device = 'mps' # or 'cuda'

float_model = LeakyReluNet().eval().to(device)
x = torch.randn(1, 3, 16, 16, device=device)

exported_program = torch.export.export(
    float_model, (x, ), strict=True)

model = exported_program.module(check_guards=False)

# VGF target quantization
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
compile_spec = VgfCompileSpec(tosa_spec)
quantizer = VgfQuantizer(compile_spec)

operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)

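# Sanity check: the exported float model runs on `device`
model(x)
m = prepare_pt2e(model, quantizer)
# Already on `device` here; .to() only re-moves registered parameters and buffers
m = m.to(device)
x = x.to(device)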

# Calibrate (simplified)
with torch.no_grad():
    # Fails with RuntimeError due to a device mismatch
    m(x)

m = convert_pt2e(m)
m(x)

Error Message

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!
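If the CPU tensor is instead created inside the graph, for example by a node with a hard-coded device=cpu keyword argument introduced while the quantizer transforms the graph (an assumption, not verified), then moving the module with .to(device) cannot fix it. The untested sketch below rewrites such keyword arguments on a torch.fx.GraphModule, such as the one returned by prepare_pt2e(). retarget_device_kwargs is a hypothetical helper, not an ExecuTorch or torchao API.

```python
import torch
import torch.fx as fx


def retarget_device_kwargs(gm: fx.GraphModule, device) -> int:
    """Point every hard-coded `device=` kwarg in gm's graph at `device`.

    Hypothetical helper. Returns the number of nodes that were rewritten.
    """
    target = torch.device(device)
    changed = 0
    for node in gm.graph.nodes:
        old = node.kwargs.get("device")
        # str() normalises both plain strings ("cpu") and torch.device objects.
        if old is not None and torch.device(str(old)) != target:
            node.kwargs = {**node.kwargs, "device": target}
            changed += 1
    if changed:
        # Regenerate gm.forward() from the edited graph.
        gm.recompile()
    return changed
```

In the reproduction, this would be called as retarget_device_kwargs(m, device) right after prepare_pt2e() and before calibration. It only helps if the offending constant really comes from such a node.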

Current Workaround

Replacing nn.LeakyReLU with a custom implementation that stores negative_slope as a tensor (a non-trainable nn.Parameter, so it moves with the rest of the model on .to(device)) resolves the issue.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeakyReLUFromReLU(nn.Module):
    def __init__(self, negative_slope: float = 0.1):
        super().__init__()
        self.negative_slope = nn.Parameter(data=torch.tensor(negative_slope), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_pos = F.relu(x)
        x_neg = -F.relu(-x)
        return x_pos + self.negative_slope * x_neg

class LeakyReluNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.leaky_relu = LeakyReLUFromReLU(negative_slope=0.2)

    def forward(self, x):
        x = self.conv(x)
        x = self.leaky_relu(x)
        return x
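A quick CPU-only sanity check of the workaround: since max(x, 0) + slope * min(x, 0) equals leaky ReLU, and -relu(-x) equals min(x, 0), the custom module should match nn.LeakyReLU numerically. Because negative_slope is a Parameter, it should also follow .to(); a dtype conversion stands in for a device move here. Note that the workaround's default negative_slope is 0.1 while nn.LeakyReLU defaults to 0.01, so the slope is passed explicitly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeakyReLUFromReLU(nn.Module):
    # Same module as the workaround above: negative_slope is a non-trainable
    # Parameter, so nn.Module.to() converts/moves it with the rest of the model.
    def __init__(self, negative_slope: float = 0.1):
        super().__init__()
        self.negative_slope = nn.Parameter(torch.tensor(negative_slope), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_pos = F.relu(x)    # max(x, 0)
        x_neg = -F.relu(-x)  # min(x, 0)
        return x_pos + self.negative_slope * x_neg


x = torch.randn(4, 8)
custom = LeakyReLUFromReLU(negative_slope=0.2)
reference = nn.LeakyReLU(negative_slope=0.2)
matches = torch.allclose(custom(x), reference(x))
print(matches)  # True

# The slope is a registered parameter, so .to() converts it with the module.
converted = LeakyReLUFromReLU(negative_slope=0.2).to(torch.float64)
print(converted.negative_slope.dtype)  # torch.float64
```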

Versions

PyTorch version: 2.10.0.dev20251120
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: version 3.31.10
Libc version: N/A

Python version: 3.12.0 (v3.12.0:0fb18b02c8, Oct 2 2023, 09:45:56) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-26.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Max

Versions of relevant libraries:
[pip3] executorch==1.1.0a0+00859b1
[pip3] numpy==2.4.0
[pip3] pytorch_tokenizers==1.0.1
[pip3] torch==2.10.0.dev20251120
[pip3] torchao==0.16.0+git08e5e203f
[pip3] torchaudio==2.10.0.dev20251120
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.6.1
[pip3] torchvision==0.25.0.dev20251120
[conda] Could not collect

Labels

module: arm (Issues related to arm backend)
