🐛 Describe the bug
Hello.
Following the recent PyTorch 2.11.0 release, I tried it out for exporting LSTM-based models with dynamic shapes. This was a well-known issue severely affecting a certain group of models, and it has recently been fixed with the `register_lstm_while_loop_decomposition` context manager.
This indeed fixes the LSTM export problem, but it causes another problem inside the ExecuTorch runtime. If the exported model contains any `nn.Embedding` call, inference fails on an input shape different from the one used during export, even though the export explicitly specifies dynamic shapes. The error usually looks like this:
```
[tensor_impl.cpp:117] Attempted to resize a static tensor. Expected shape (1, 16, 32), but received (1, 12, 32).
[op_embedding.cpp:93] Check failed (resize_embedding_output(weight, indices, out) == Error::Ok):
[method.cpp:1384] KernelCall failed at instruction 0:2 in operator aten::embedding.out: 0x12
[method.cpp:1390] arg 0 with type id 1
[method.cpp:1390] arg 1 with type id 1
[method.cpp:1390] arg 2 with type id 4
[method.cpp:1390] arg 3 with type id 5
[method.cpp:1390] arg 4 with type id 5
[method.cpp:1390] arg 5 with type id 1
[method.cpp:1390] arg 6 with type id 1
```
Here is sample code that reproduces the error on a simple model:
```python
import torch
from torch import nn
from torch.export import Dim
from torch.export._patches import register_lstm_while_loop_decomposition

from executorch.exir import to_edge_transform_and_lower
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.runtime import Runtime


class DumbModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 32)
        self.lstm = nn.LSTM(32, 64, batch_first=True)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        return x


def export_model():
    model = DumbModel().eval()
    example_input = torch.randint(0, 100, (1, 16), dtype=torch.long)

    tokens = Dim("tokens", min=1, max=128)
    dynamic_shapes = ({1: tokens},)

    with register_lstm_while_loop_decomposition():
        exported_program = torch.export.export(model, (example_input,), dynamic_shapes=dynamic_shapes)

    edge_program = to_edge_transform_and_lower(
        exported_program,
        partitioner=[XnnpackPartitioner()],
    )
    executorch_program = edge_program.to_executorch()

    with open("simple_embedding.pte", "wb") as f:
        executorch_program.write_to_file(f)
    print("Exported simple_embedding.pte")

    runtime = Runtime.get()
    program = runtime.load_program("simple_embedding.pte")
    forward_method = program.load_method("forward")

    # Use a different sequence length than the one used during the export
    inference_input = torch.randint(0, 100, (1, 5), dtype=torch.long)
    output = forward_method.execute((inference_input,))
    print(f"Inference successful! Input shape: {inference_input.shape}, Output shape: {output[0].shape}")


if __name__ == "__main__":
    export_model()
```
When the model is exported without `to_edge_transform_and_lower` (and without the LSTM, since otherwise dynamic-shape export is not possible), there is no such issue.
Versions
```
Collecting environment information...
PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.3.19.1)
CMake version: version 3.31.10
Libc version: N/A

Python version: 3.12.11 (main, Jun 3 2025, 15:41:47) [Clang 17.0.0 (clang-1700.0.13.3)] (64-bit runtime)
Python platform: macOS-26.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Pro

Versions of relevant libraries:
[pip3] executorch==1.1.0a0+17adba1
[pip3] numpy==2.3.5
[pip3] onnxruntime==1.24.4
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.11.0
[pip3] torchao==0.15.0+git9338966da
[pip3] torchaudio==2.11.0
[pip3] torchcodec==0.10.0
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.6.1
[pip3] torchvision==0.26.0
[conda] Could not collect
```