Skip to content

Improve docstrings and type annotations in jax_utils.py#2022

Closed
irhyl wants to merge 1 commit intogoogle-deepmind:mainfrom
gsoc-2026:pr/jax-utils-docstrings
Closed

Improve docstrings and type annotations in jax_utils.py#2022
irhyl wants to merge 1 commit intogoogle-deepmind:mainfrom
gsoc-2026:pr/jax-utils-docstrings

Conversation

@irhyl
Copy link
Copy Markdown

@irhyl irhyl commented Mar 1, 2026

Summary

Improves developer experience and static analysis for jax_utils.py by adding missing docstrings and correcting type annotations.

Changes

  • Modified: torax/_src/jax_utils.py
    • Added Google-style docstrings to get_dtype, get_np_dtype, get_int_dtype, _init_pytree
    • Fixed return annotations from type(jnp.float32) (a runtime value) to type[jnp.float32] | type[jnp.float64] (a proper type hint)
    • Added type hints to _init_pytree(t: PyTree) -> PyTree and its inner init_array

Testing

All 15 existing jax_utils_test.py tests continue to pass.

- Add Google-style docstrings to get_dtype, get_np_dtype,
  get_int_dtype, and _init_pytree
- Fix return type annotations: use type[jnp.float32] | type[jnp.float64]
  instead of type(jnp.float32) for dtype-returning functions
- Add parameter and return type hints to _init_pytree and its inner
  init_array function
Comment thread torax/_src/jax_utils.py

def _init_pytree(t):
def _init_pytree(t: PyTree) -> PyTree:
"""Initializes a pytree of `ShapeDtypeStruct` leaves into concrete arrays.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a one line docstring is sufficient for this private function

@gsoc-2026 gsoc-2026 closed this by deleting the head repository Apr 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants