TensorRT on qwen image edit #4719

@ohnlily

Description

We are currently using Alibaba's Qwen image edit model and have compiled it with TensorRT to speed up inference. However, the AOT speedup is only 16% (this evaluation targets the Transformer alone at fp16 precision, with two 2K input images). This speedup falls far short of our expectations.

We saw on the official blog that Flux achieved nearly a 1.5x speedup after compilation with TensorRT.
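To put the two numbers on the same scale (assuming "16%" here means the engine is 1.16x faster than the eager baseline, not a 16% latency reduction), the gap can be expressed with hypothetical latencies chosen only for illustration:

```python
def speedup(t_baseline_ms: float, t_trt_ms: float) -> float:
    """Speedup factor: how many times faster TensorRT is than the baseline."""
    return t_baseline_ms / t_trt_ms

# Hypothetical latencies, for illustration only:
print(round(speedup(116.0, 100.0), 2))  # 1.16 -- our current result
print(speedup(150.0, 100.0))            # 1.5  -- the Flux number from the blog
```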

Therefore, we would like to seek your suggestions to improve our speedup performance. Below are some details of our compilation process.
Hardware and drivers:
GPU: RTX 5090
Driver version: 590.48.01
CUDA version (reported by driver): 13.1
CUDA Toolkit: 12.8

Model: Model parameters from qwen-image-edit2509 were used.
LoRA used: Qwen-Image-Edit-Lightning-4steps-V1.0-bf16.safetensors (weight 1.0). The LoRA weights were merged into the base model parameters with the fuse_lora function provided by diffusers.

Input images: two images, each 2496 × 1664 pixels (width × height).
Output image: one image, 2496 × 1664 pixels.
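As a sanity check on the shape profiles used later, the image token count follows from the resolution. This is a sketch assuming the usual 8x VAE downsampling plus 2x2 patchify (16x per side), and that the image sequence concatenates the output latent with the two conditioning images:

```python
def img_tokens(width: int, height: int, vae_scale: int = 8, patch: int = 2) -> int:
    """Image tokens for one image: (W // 16) * (H // 16) with the assumed factors."""
    per_side = vae_scale * patch          # 16x spatial reduction per side (assumption)
    return (width // per_side) * (height // per_side)

per_image = img_tokens(2496, 1664)        # 156 * 104 = 16224 tokens per image
total = 3 * per_image                     # output latent + two conditioning images
print(per_image, total)                   # 16224 48672
```

The total of 48672 matches the img_seq_len used in the engine shape profiles further down.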

Torch to ONNX core code:

torch.onnx.export(
    transformer,
    (
        hidden_states,
        encoder_hidden_states,
        timestep,
        img_rope_real,
        img_rope_imag,
        txt_rope_real,
        txt_rope_imag,
    ),
    temp_path,
    export_params=True,
    opset_version=19,
    dynamo=False,
    optimize=False,
    input_names=[
        'hidden_states',
        'encoder_hidden_states',
        'timestep',
        'img_rope_real',
        'img_rope_imag',
        'txt_rope_real',
        'txt_rope_imag',
    ],
    output_names=['out_hidden_states'],
    dynamic_axes={
        'hidden_states': {1: 'img_seq_len'},
        'encoder_hidden_states': {1: 'txt_seq_len'},
        'img_rope_real': {0: 'img_seq_len'},
        'img_rope_imag': {0: 'img_seq_len'},
        'txt_rope_real': {0: 'txt_seq_len'},
        'txt_rope_imag': {0: 'txt_seq_len'},
        'out_hidden_states': {1: 'img_seq_len'},
        'joint_hidden_states': {1: 'dim0', 2: 'dim1'},
    },
)
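One detail worth double-checking in the export call above: dynamic_axes contains a key, 'joint_hidden_states', that does not appear in input_names or output_names, so that entry has no effect (the exporter typically only warns about unknown keys). A standalone check, reusing the names from the call:

```python
# Sanity check: every dynamic_axes key should be a declared input or output name.
input_names = [
    'hidden_states', 'encoder_hidden_states', 'timestep',
    'img_rope_real', 'img_rope_imag', 'txt_rope_real', 'txt_rope_imag',
]
output_names = ['out_hidden_states']
dynamic_axes = {
    'hidden_states': {1: 'img_seq_len'},
    'encoder_hidden_states': {1: 'txt_seq_len'},
    'img_rope_real': {0: 'img_seq_len'},
    'img_rope_imag': {0: 'img_seq_len'},
    'txt_rope_real': {0: 'txt_seq_len'},
    'txt_rope_imag': {0: 'txt_seq_len'},
    'out_hidden_states': {1: 'img_seq_len'},
    'joint_hidden_states': {1: 'dim0', 2: 'dim1'},
}
unknown = sorted(set(dynamic_axes) - set(input_names) - set(output_names))
print(unknown)  # ['joint_hidden_states'] -- silently ignored by the export
```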

Onnx to engine command:
CMD="polygraphy convert \
    "$ONNX_PATH" \
    --convert-to trt \
    --output "$ENGINE_PATH" \
    --tf32 --strongly-typed \
    --onnx-flags native_instancenorm \
    --builder-optimization-level 5 \
    --tiling-optimization-level full \
    --precision-constraints none \
    --trt-opt-shapes hidden_states:[1,48672,64] encoder_hidden_states:[1,5522,3584] timestep:[1] img_rope_real:[48672,64] img_rope_imag:[48672,64] txt_rope_real:[5522,64] txt_rope_imag:[5522,64] \
    --trt-min-shapes hidden_states:[1,48672,64] encoder_hidden_states:[1,5456,3584] timestep:[1] img_rope_real:[48672,64] img_rope_imag:[48672,64] txt_rope_real:[5456,64] txt_rope_imag:[5456,64] \
    --trt-max-shapes hidden_states:[1,48672,64] encoder_hidden_states:[1,6126,3584] timestep:[1] img_rope_real:[48672,64] img_rope_imag:[48672,64] txt_rope_real:[6126,64] txt_rope_imag:[6126,64] \
    $VERBOSITY"
echo $CMD
eval $CMD
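When comparing eager vs. engine latency at this model size, warmup and proper synchronization matter. A generic timing helper (a sketch; run_once stands in for either the PyTorch forward or a TensorRT execution, and for GPU work it must synchronize internally, e.g. call torch.cuda.synchronize() before returning):

```python
import statistics
import time

def bench(run_once, warmup: int = 5, iters: int = 20) -> float:
    """Median wall-clock latency in ms over `iters` runs, after `warmup` runs.

    For CUDA workloads, run_once must block until the GPU work is done
    (e.g. by calling torch.cuda.synchronize() before returning), otherwise
    only kernel-launch time is measured.
    """
    for _ in range(warmup):
        run_once()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        run_once()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)

# Usage with a CPU stand-in workload:
latency_ms = bench(lambda: sum(range(100_000)))
print(latency_ms > 0.0)  # True
```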

Main dependencies:
accelerate=1.12.0
peft=0.18.1
diffusers=0.36.0
onnx=1.20.1
polygraphy=0.49.26
nvidia-cublas-cu12=12.6.4.1
nvidia-cuda-cupti-cu12=12.6.80
nvidia-cuda-nvrtc-cu12=12.6.77
nvidia-cuda-runtime-cu12=12.6.77
nvidia-cudnn-cu12=9.10.2.21
nvidia-cufft-cu12=11.3.0.4
nvidia-cufile-cu12=1.11.1.6
nvidia-curand-cu12=10.3.7.77
nvidia-cusolver-cu12=11.7.1.2
nvidia-cusparse-cu12=12.5.4.2
nvidia-cusparselt-cu12=0.7.1
nvidia-ml-py=12.535.133
nvidia-modelopt=0.11.2
nvidia-nccl-cu12=2.27.5
nvidia-nvjitlink-cu12=12.6.85
nvidia-nvshmem-cu12=3.3.20
nvidia-nvtx-cu12=12.6.77
tensorrt=10.0.1
tensorrt-cu12=10.14.1.48.post1
tensorrt-cu12-bindings=10.14.1.48.post1
tensorrt-cu12-libs=10.14.1.48.post1
triton=3.5.1
transformers=4.57.6
tokenizers=0.22.2
torch=2.10.0+cu128
torchvision=0.25.0+cu128
