Skip to content

[JAX API] Updating TransferToMemoryKind and jax.experimental.pallas.triton#1339

Closed
Steboss wants to merge 6 commits intoapple:mainfrom
Steboss:main
Closed

[JAX API] Updating TransferToMemoryKind and jax.experimental.pallas.triton#1339
Steboss wants to merge 6 commits intoapple:mainfrom
Steboss:main

Conversation

@Steboss
Copy link
Copy Markdown
Contributor

@Steboss Steboss commented Sep 10, 2025

@matthew-e-hopkins
Hey people, this is a huge update, to allow us to use JAX > 0.5.3 (we're currently testing AXLearn with JAX 0.7.2).
I've implemented the following changes:

  • I've created a back compatibility option _JAX_MEMORY_SPACE_SUPPORT so that all these changes can work with different versions of JAX
  • In utils.py JAX' from jax._src.sharding_impls import TransferToMemoryKind has been substituted with its correspondent version for JAX 0.7 (jax.memory.Space.*). I am preserving the previous option by checking the jax version:
if _JAX_MEMORY_SPACE_SUPPORT:
    MemoryKind = [jax.memory.Space.Device, jax.memory.Space.Host]
    DEVICE_MEMORY = jax.memory.Space.Device
    HOST_MEMORY = jax.memory.Space.Host

    def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor:
        return jax.device_put(tensor, memory_kind)

else:
    from jax._src.sharding_impls import TransferToMemoryKind  # pylint: disable=ungrouped-imports

    MemoryKind = Literal["device", "pinned_host"]
    DEVICE_MEMORY = "device"
    HOST_MEMORY = "pinned_host" 
  • These changes have been propagated to optimizers_test.py and optimizers.py
  • jax.experimental.pallas.triton.TritonCompilerParams has now changed in .CompilerParams, so gpu_attention.py, gpu_decoding.py, gpu_paged_attention.py and paged_kv_cache_gpu_kernel.py have been changed accordingly. Again, as before, I'm importing _JAX_MEMORY_SPACE_SUPPORT to check the JAX version and preserving the previous code.

I've tested the changes with Fuji models, it would be great to find an optimal solution for this, as we'd like to support AXLearn in JAX-Toolbox again newer JAX versions.
Please, let me know if you want some changes. Thank you

@github-actions
Copy link
Copy Markdown

github-actions bot commented Apr 4, 2026

This pull request has been automatically marked as stale because it has been inactive for 60 days. It will be closed in 7 days if no further activity occurs. If you would like to continue working on this, please remove the stale label or leave a comment.

@github-actions github-actions bot added the stale label Apr 4, 2026
@github-actions
Copy link
Copy Markdown

This pull request was closed because it has been inactive for more than 7 days since being marked as stale. Please feel free to reopen it if you would like to continue.

@github-actions github-actions bot closed this Apr 11, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant