[JAX] HLO FFI tests#2593
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci jax |
Greptile SummaryAdds FFI compatibility tests to prevent breaking changes in JAX FFI interfaces. The tests load and execute pre-generated StableHLO files to validate FFI bindings remain compatible with older HLO code. Major Changes:
Critical Issues Found:
Confidence Score: 1/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant Test as test_ffi_compatibility
participant Parser as _make_args_based_on_input
participant File as HLO File System
participant JAX as JAX Backend
participant FFI as TE FFI Bindings
Test->>File: Read HLO text file
File-->>Test: StableHLO text content
Test->>Parser: Parse function signature
Parser->>Parser: Extract @main signature with regex
Parser->>Parser: Parse tensor shapes and dtypes
Parser->>Parser: Create dummy JAX arrays
Parser-->>Test: Return args list
Test->>JAX: compile_and_load(stablehlo_text)
JAX->>FFI: Resolve custom_call references
FFI-->>JAX: Return registered functions
JAX-->>Test: Compiled executable
Test->>JAX: executable.execute(args)
JAX->>FFI: Execute TE operations
FFI-->>JAX: Results
JAX-->>Test: Execution results
|
|
|
||
| @pytest.fixture(name="ffi_hlo_name") | ||
| def hlo_fixture(shape): | ||
| for file in os.listdir(TestFFICompatibility.HLO_DIR): | ||
| file_path = os.path.join(TestFFICompatibility.HLO_DIR, file) |
There was a problem hiding this comment.
syntax: The fixture parameter shape is undefined and not used in the function body. This will cause an error when pytest tries to parametrize this fixture.
| @pytest.fixture(name="ffi_hlo_name") | |
| def hlo_fixture(shape): | |
| for file in os.listdir(TestFFICompatibility.HLO_DIR): | |
| file_path = os.path.join(TestFFICompatibility.HLO_DIR, file) | |
| @pytest.fixture(name="ffi_hlo_name") | |
| def hlo_fixture(self): | |
| for file in os.listdir(TestFFICompatibility.HLO_DIR): | |
| file_path = os.path.join(TestFFICompatibility.HLO_DIR, file) | |
| if os.path.isfile(file_path): | |
| yield file.split(".")[0] |
| """Parses the StableHLO text to extract input tensor shapes and dtypes, and creates dummy JAX arrays accordingly.""" | ||
| # Parse function signature to extract argument information |
There was a problem hiding this comment.
logic: The regex pattern uses non-greedy matching (.*?) which will not match multiline content by default. Since the @main function signature spans multiple lines in the HLO file, this pattern will fail to capture the arguments.
| """Parses the StableHLO text to extract input tensor shapes and dtypes, and creates dummy JAX arrays accordingly.""" | |
| # Parse function signature to extract argument information | |
| pattern = r"@main\((.*?)\{" | |
| match = re.search(pattern, stablehlo_text, re.DOTALL) |
| for arg_num, shape_and_dtype_str in arg_matches: | ||
| print(f"Parsing argument {arg_num} with shape and dtype: {shape_and_dtype_str}") | ||
| # Parse shape: "32x32xbf16" -> [32, 32] | ||
| dtype_str = shape_and_dtype_str.split("x")[-1] |
There was a problem hiding this comment.
logic: Missing handling for None return from dtype_map.get(). If an unsupported dtype is encountered, this will pass None to jnp.ones() causing an error.
| dtype_str = shape_and_dtype_str.split("x")[-1] | |
| dtype = dtype_map.get(dtype_str) | |
| if dtype is None: | |
| raise ValueError(f"Unsupported dtype in HLO: {dtype_str}") |
| args_str = match.group(1) | ||
|
|
||
| # Parse individual arguments |
There was a problem hiding this comment.
logic: The parsing logic assumes shape dimensions are separated by 'x' and the last element is always the dtype. This will fail for scalar tensors (e.g., tensor<bf16>) where there are no dimensions, causing int() conversion to fail on the dtype string.
| args_str = match.group(1) | |
| # Parse individual arguments | |
| # Parse shape: "32x32xbf16" -> [32, 32], handle scalars like "bf16" | |
| parts = shape_and_dtype_str.split("x") | |
| dtype_str = parts[-1] | |
| shape = [int(dim) for dim in parts[:-1]] if len(parts) > 1 else [] |
|
/te-ci L0 jax |
Description
Prevents TE/JAX from changing FFI interfaces and breaking backwards compatibility with older HLO on accident.
Type of change
Changes
test_custom_call.pyand associated HLO text file to ensure both of the followingChecklist: