Skip to content

Add CuTe DSL JAX demo#3103

Open
katjasrz wants to merge 1 commit intoNVIDIA:mainfrom
katjasrz:cute-dsl-jax-demo
Open

Add CuTe DSL JAX demo#3103
katjasrz wants to merge 1 commit intoNVIDIA:mainfrom
katjasrz:cute-dsl-jax-demo

Conversation

@katjasrz
Copy link

Summary

This PR adds a minimal CuTe DSL + JAX demo under examples/python/CuTeDSL/jax.

The demo shows how to define and invoke custom CuTe DSL kernels from JAX, providing a self-contained example that can serve as a starting point for experimentation and integration.

What This Adds

  • cute_dsl_jax.ipynb
    A walkthrough-style notebook demonstrating:

    • Defining CuTe DSL kernels
    • Calling them from JAX
    • Basic usage patterns and expected behavior
  • cute_dsl_jax_kernels.py
    Supporting kernel definitions used by the notebook.

Purpose

The goal is to provide:

  • A simple reference example for users exploring CuTe DSL + JAX
  • A starting template for extending CuTe DSL kernels within a JAX workflow
  • A lightweight demo that complements the existing Python examples

Notes

  • The example is self-contained and lives under examples/python/CuTeDSL/jax.
  • No changes to core functionality.
  • Intended primarily as an illustrative example.

@fengxie fengxie requested a review from Junkai-Wu March 23, 2026 02:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant