
TRT 10.15.1 Myelin failure on RTX PRO 6000 (sm_120) with Fast-FoundationStereo post_runner #4715

@baker-git

Description

TensorRT 10.15.1 fails to build a working engine for the post_runner network from Fast-FoundationStereo on the RTX PRO 6000 (Blackwell, sm_120). The Myelin compiler finds zero valid tactics for a fused node containing 3D ConvTranspose + Cast operations.

This is NVIDIA's own stereo depth model failing on NVIDIA's own GPU with the latest TensorRT release.

The feature_runner network from the same model builds and runs correctly - only the post_runner is affected.

Environment

  • TensorRT: 10.15.1.29 (pip, cu12)
  • GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition (sm_120, 96GB)
  • Driver: 570.211.01
  • CUDA: 12.8
  • OS: Ubuntu 22.04 (GCP Deep Learning VM, pytorch-2-7-cu128-ubuntu-2204-nvidia-570)
  • PyTorch: 2.7.1+cu128
  • ONNX opset: 17

Steps to Reproduce

  1. Clone Fast-FoundationStereo and download the 23-36-37 checkpoint
  2. Patch ChannelAttentionEnhancement.forward() in core/submodule.py to replace nn.AdaptiveAvgPool2d(1) / nn.AdaptiveMaxPool2d(1) with x.mean(dim=[2,3], keepdim=True) / x.amax(dim=[2,3], keepdim=True). This is required because adaptive pooling at 1920x1088 produces a 480x272 pooling kernel that exceeds TRT's maximum kernel size (a separate issue; see Additional Note below)
  3. Export ONNX at 1920x1088:
    python scripts/make_onnx.py --model_dir weights/23-36-37/model_best_bp2_serialize.pth --save_path output/ --height 1088 --width 1920 --valid_iters 8
    
  4. Build TRT engine:
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    # Parse the exported post_runner graph; surface parser errors if any
    with open("output/post_runner.onnx", "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30)  # 16 GiB
    config.set_flag(trt.BuilderFlag.FP16)

    engine = builder.build_serialized_network(network, config)  # Returns None

Error Output

[Autotuner]: No valid tactics to print (all tactics failed)
Internal Error: MyelinCheckException: autotuner.cpp:2318: CHECK(sorted_ids.size() > 0) failed. Must have costs

[TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception
  [myelin_graph.h:1208: attachExceptionMsgToGraph] MyelinCheckException: autotuner.cpp:2318:
  CHECK(sorted_ids.size() > 0) failed. Must have costs
  In compileGraph at /_src/optimizer/myelin/codeGenerator.cpp:1762

[TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error
  (Could not find any implementation for node
  {ForeignNode[stem_2x_cast + /Cast_202 + /Cast_202_output_0_cast.../Cast_205 + disp_castOut]}.
  In computeCosts at /_src/optimizer/common/tactic/optimizer.cpp:4234)

What I've Tried

| Attempt | Result |
| --- | --- |
| FP16 (default opt level 3) | Build fails: zero tactics found |
| FP32 (no FP16 flag) | Build fails: same error |
| BF16 | Build fails: "No matching rules found for input operand types" on 3D ConvTranspose |
| builder_optimization_level=0 | Engine builds, but crashes at runtime: IExecutionContext::enqueueV3: Error Code 1: Myelin ([cask.cpp:2974: exec] Platform (Cuda) error) |
| builder_optimization_level=1 | Build fails |
| builder_optimization_level=2 | Build fails |
| PREFER_PRECISION_CONSTRAINTS flag | Build fails |
| TRT 10.14.1, 10.13.3 | Cannot initialize on sm_120: CUDA initialization failure with error 35 |

Analysis

The failing fused node spans the entire post-processing network (from stem_2x_cast to disp_castOut). The network contains 3D ConvTranspose operations (in the cost aggregation upsampling path) fused with mixed-precision Cast nodes. The Myelin kernel library for sm_120 appears to be missing implementations for this fused op pattern.

The feature_runner network (same model, backbone only, no 3D ConvTranspose) builds and executes correctly at 5.6ms on the same GPU, confirming the issue is specific to the post_runner's op mix.

Impact

Without TRT acceleration for the post_runner, the full inference pipeline runs at 7-10 fps (PyTorch) instead of the expected ~30 fps (full TRT) at 1920x1088 on RTX PRO 6000. The post_runner accounts for 85-135ms of the total 93-146ms latency.
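For context, the numbers above work out as follows (back-of-envelope only; the 25 ms TRT post_runner figure is an assumption chosen to illustrate the ~30 fps target, not a measurement):

```python
# Latency budget from the measurements above
total_ms = (93, 146)   # full-pipeline latency range, PyTorch post_runner
post_ms = (85, 135)    # post_runner's share of that latency
rest_ms = tuple(t - p for t, p in zip(total_ms, post_ms))  # everything else

fps_now = tuple(round(1000 / t, 1) for t in total_ms)      # current fps range
# Assumed ~25 ms TRT post_runner would put totals near 33-36 ms, i.e. ~30 fps
fps_target = tuple(round(1000 / (r + 25), 1) for r in rest_ms)
print(rest_ms, fps_now, fps_target)
```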

Attachments

The post_runner.onnx file can be reproduced using the steps above, or I can attach it directly if needed.

Additional Note

The ONNX export required patching ChannelAttentionEnhancement to replace nn.AdaptiveAvgPool2d(1) / nn.AdaptiveMaxPool2d(1) with torch.mean() / torch.amax() reduce ops. At 1920x1088, the 1/4 resolution feature maps are 480x272, creating pooling kernels that exceed TRT's maxKernelDimsProduct limit. This is a separate issue from the Myelin failure but affects the same model at this resolution.
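The replacement described above amounts to swapping the pooling modules for reduce ops, roughly as follows (a sketch; the rest of ChannelAttentionEnhancement.forward() is unchanged and elided):

```python
import torch

def global_avg_max(x: torch.Tensor):
    """Drop-in for nn.AdaptiveAvgPool2d(1) / nn.AdaptiveMaxPool2d(1).

    Exports as ONNX ReduceMean / ReduceMax instead of pooling ops whose
    480x272 kernel (1/4 resolution of 1920x1088) exceeds TRT's
    maxKernelDimsProduct limit.
    """
    avg = x.mean(dim=[2, 3], keepdim=True)  # (N, C, 1, 1)
    mx = x.amax(dim=[2, 3], keepdim=True)   # (N, C, 1, 1)
    return avg, mx
```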
