Description
🐛 Describe the bug
During model quantization for Ethos-U, LayerNorm ops are decomposed into atomic ops. The current decomposition uses aten::avg_pool2d to remove the mean of the tensor, which has two issues:
- The mean is calculated and removed twice, which is redundant.
- The aten::avg_pool2d op runs very inefficiently and is the main cause of the increased inference time.
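To make the two points concrete, here is a minimal sketch (not the actual decomposition pass) of how a mean over the last dimension can be expressed as a pooling op, and how the centered tensor can be reused so that the mean is only removed once. The tensor shape and the helper name `pooled_mean` are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

# Hypothetical input; the shape is an assumption, not taken from the attached script.
x = torch.randn(1, 64, 384)
n = x.shape[-1]


def pooled_mean(t: torch.Tensor) -> torch.Tensor:
    """Mean over the last dim, written as an average pool (hypothetical helper)."""
    t4 = t.unsqueeze(1)  # view (B, H, W) as (N, C, H, W)
    # A window spanning the full width yields one average per row.
    return F.avg_pool2d(t4, kernel_size=(1, t.shape[-1])).squeeze(1)


# The same mean written as a plain sum reduction.
mean_pool = pooled_mean(x)
mean_sum = x.sum(dim=-1, keepdim=True) / n

# Redundant form: x is centered once for the output and again inside the variance.
centered = x - pooled_mean(x)
var_redundant = pooled_mean((x - pooled_mean(x)) ** 2)

# Reusing the centered tensor removes the second mean computation.
var_reuse = (centered * centered).sum(dim=-1, keepdim=True) / n

print((mean_pool - mean_sum).abs().max(), (var_redundant - var_reuse).abs().max())
```

Both forms agree numerically up to floating-point error; the difference is only in which ops end up in the graph.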
See attached script for reproducing the issue.
Current behaviour:
python ln_script.py
Output (most cycles are spent on AvgPool):
Batch Inference time 469.55 ms, 2.13 inferences/s (batch size 1)
An equivalent "decomposed" implementation of LayerNorm, which bypasses the LayerNorm decomposition pass and uses sum instead of mean, gives a much lower inference time:
python ln_script.py --replace_layernorm
Output:
Batch Inference time 8.75 ms, 114.27 inferences/s (batch size 1)
Note: the tensor sizes in the script are taken from the whisper-tiny encoder model.
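For readers without the attached script, a sum-based replacement could look roughly like the sketch below. This is an illustration, not the script's actual code: the class name `SumLayerNorm` is hypothetical, and the hidden size of 384 is assumed here as the whisper-tiny model dimension.

```python
import torch
from torch import nn


class SumLayerNorm(nn.Module):
    """LayerNorm over the last dim, written with sum reductions (hypothetical name)."""

    def __init__(self, normalized_dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_dim))
        self.bias = nn.Parameter(torch.zeros(normalized_dim))
        self.eps = eps
        self.inv_n = 1.0 / normalized_dim  # constant scale replaces the mean op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean as sum * (1 / N), computed once.
        mean = x.sum(dim=-1, keepdim=True) * self.inv_n
        centered = x - mean
        # Biased variance from the already-centered tensor (no second mean removal).
        var = (centered * centered).sum(dim=-1, keepdim=True) * self.inv_n
        return centered * torch.rsqrt(var + self.eps) * self.weight + self.bias


# Compare against the reference module on an assumed input shape.
torch.manual_seed(0)
x = torch.randn(1, 16, 384)
reference = nn.LayerNorm(384)
candidate = SumLayerNorm(384)
print((candidate(x) - reference(x)).abs().max())
```

With default parameters (weight of ones, bias of zeros) the output should match `torch.nn.LayerNorm` up to floating-point error. Whether this lowers to efficient ops on Ethos-U depends on the backend and is not shown by this sketch.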
Versions
PyTorch version: 2.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 26.3 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.4.2)
CMake version: version 4.2.1
Libc version: N/A
Python version: 3.12.10 (v3.12.10:0cc81280367, Apr 8 2025, 08:46:59) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-26.3-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Apple M4 Max
Versions of relevant libraries:
[pip3] executorch==1.0.0
[pip3] numpy==2.3.4
[pip3] onnx==1.20.0
[pip3] onnx-ir==0.1.14
[pip3] onnxruntime==1.23.2
[pip3] onnxscript==0.5.7
[pip3] torch==2.9.0
[pip3] torchao==0.14.0
[pip3] torchaudio==2.9.0
[pip3] torchcodec==0.9.1
[pip3] torchmetrics==1.8.2
[pip3] torchvision==0.24.0
[conda] Could not collect
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell