Method.execute() silently produces wrong results for non-contiguous input tensors #18562
🐛 Describe the bug
`executorch.runtime.Method.execute()` ignores tensor strides and reads `data_ptr` as if the tensor were contiguous. When a non-contiguous tensor is passed as input, the runtime misinterprets the raw memory layout and produces silently wrong output: no error, no warning.
This occurs naturally with the standard PyTorch image-loading pattern: PIL/OpenCV/torchvision load images as HWC, and `.permute(2, 0, 1)` then produces a CHW view with non-contiguous strides.
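To make the stride mismatch concrete, here is a small illustration (mine, not from the report): `permute` returns a view over the same storage with rearranged strides, and `is_contiguous()` reports `False`:

```python
import torch

# A tiny HWC "image": H=2, W=2, C=3
hwc = torch.arange(12.0).reshape(2, 2, 3)
chw = hwc.permute(2, 0, 1)  # CHW view over the same storage

print(hwc.stride())          # (6, 3, 1) -- contiguous HWC layout
print(chw.stride())          # (1, 6, 3) -- permuted, non-contiguous
print(chw.is_contiguous())   # False
print(chw.data_ptr() == hwc.data_ptr())  # True: same underlying buffer
```

Any consumer that reads the buffer at `data_ptr` while assuming row-major CHW strides will pick up the values in the wrong order.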
Minimal Reproduction
```python
import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.runtime import Runtime


class TinyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 3, 3, padding=1), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)


torch.manual_seed(42)
model = TinyConvNet().eval()
h, w = 64, 48

with tempfile.TemporaryDirectory() as tmp_dir:
    pte_path = str(Path(tmp_dir) / "model.pte")
    exported = torch.export.export(model, (torch.randn(1, 3, h, w),))
    edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
    with open(pte_path, "wb") as f:
        f.write(edge.to_executorch().buffer)

    # Standard image-loading pattern: HWC -> CHW via permute + normalize
    rng = np.random.RandomState(42)
    hwc_uint8 = rng.randint(0, 256, size=(h, w, 3), dtype=np.uint8)
    hwc_float = torch.from_numpy(hwc_uint8).float() / 255.0
    chw = hwc_float.permute(2, 0, 1)  # non-contiguous view
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    chw = (chw - mean) / std
    input_tensor = chw.unsqueeze(0)  # (1, 3, H, W) -- non-contiguous
    assert torch.equal(input_tensor, input_tensor.contiguous())  # same values

    runtime = Runtime.get()
    program = runtime.load_program(pte_path)
    method1 = program.load_method("forward")
    out_non_contig = method1.execute([input_tensor])[0]
    method2 = program.load_method("forward")
    out_contig = method2.execute([input_tensor.contiguous()])[0]

    print(f"Outputs match: {torch.equal(out_contig, out_non_contig)}")
    print(f"Max abs diff: {torch.max(torch.abs(out_contig - out_non_contig)).item():.6f}")
```

Output:

```
Outputs match: False
Max abs diff: 0.127717
```
The two inputs hold identical values (`torch.equal` confirms this), but `method.execute()` produces different outputs depending on memory layout. Only the contiguous input produces correct results.
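The symptom is consistent with the raw buffer being reinterpreted under contiguous strides. The sketch below (my illustration, not the actual runtime code) uses `torch.as_strided` to mimic a reader that ignores a view's real strides:

```python
import torch

t = torch.arange(6.0).reshape(2, 3)
v = t.t()  # non-contiguous (3, 2) view: strides (1, 3)

# Reinterpret v's storage as if it were a contiguous (3, 2) tensor --
# effectively what a contiguous-assuming consumer of data_ptr sees.
misread = torch.as_strided(v, v.shape, (v.shape[1], 1))

print(v)        # the values the caller intended
print(misread)  # the values a contiguous reader extracts
assert not torch.equal(v, misread)
```

The values are all present in memory, just traversed in the wrong order, which is why the corruption is silent rather than a crash.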
Workaround
Call `.contiguous()` on all inputs before passing them to `method.execute()`:

```python
output = method.execute([input_tensor.contiguous()])[0]
```

Expected Behavior
Either:
- `method.execute()` should internally call `.contiguous()` on input tensors with non-standard strides, or
- `method.execute()` should raise an error when it receives a non-contiguous tensor
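Either behavior would prevent silent corruption. Until one lands, callers can normalize inputs with a small defensive wrapper; `normalize_inputs` below is a hypothetical helper name of my own, not an ExecuTorch API:

```python
import torch

def normalize_inputs(inputs):
    """Hypothetical helper: return inputs with every tensor made
    contiguous, so Method.execute() never sees a strided view.
    contiguous() is a no-op (returns self) for already-contiguous
    tensors, so this adds a copy only when one is actually needed."""
    return [t.contiguous() if isinstance(t, torch.Tensor) else t
            for t in inputs]

# Usage (assuming a loaded method, as in the reproduction above):
# outputs = method.execute(normalize_inputs([input_tensor]))
```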
Environment
- ExecuTorch: 1.3.0a0+502d2de
- PyTorch: 2.11.0
- OS: macOS 15.4 (ARM64 / Apple Silicon)
- Python: 3.12.9
- Backend: XNNPACK
Versions
```
Collecting environment information...
PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 26.3.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.4.2)
CMake version: version 4.3.0
Libc version: N/A
Python version: 3.12.9 (main, Mar 17 2025, 21:36:21) [Clang 20.1.0 ] (64-bit runtime)
Python platform: macOS-26.3.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
```