Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 129 additions & 49 deletions tools/ais-check/ais-check
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ If components are missing the program exits with a non-zero exit code.
import argparse
import ctypes
import ctypes.util
import glob
import gzip
import os
import sys

# Global mapping of HIP runtime library paths to AIS support flags
hip_libraries = {}


def kernel_supports_p2pdma():
"""
Expand Down Expand Up @@ -46,67 +50,134 @@ def kernel_supports_p2pdma():
return False


def find_hip_runtimes():
"""
Populate a global list of HIP runtime libraries by looking in the
usual places.
"""

# NOTE: CodeQL will be unhappy if you are not careful about paths
# in this function

candidates = []

# 1. Respect runtime linker paths
for p in os.environ.get("LD_LIBRARY_PATH", "").split(":"):
if p:
# Clean up the path by removing `..`, etc. and getting
# an absolute path.
safe_p = os.path.abspath(os.path.normpath(p))

candidates.append(os.path.join(safe_p, "libamdhip64.so"))

# 2. Environment variables commonly set by ROCm or modules
for var in ("ROCM_HOME", "ROCM_PATH", "HIP_PATH"):
base = os.environ.get(var)
Comment on lines +64 to +75
if base:
# Also clean up this path
safe_base = os.path.abspath(os.path.normpath(base))
candidates += [
os.path.join(safe_base, "lib", "libamdhip64.so"),
os.path.join(safe_base, "lib64", "libamdhip64.so"),
]

# 3. Standard ROCm install paths
candidates += [
"/opt/rocm/lib/libamdhip64.so",
"/opt/rocm/lib64/libamdhip64.so",
]

# 4. Versioned installs (/opt/rocm-5.x, etc.)
for d in glob.glob("/opt/rocm*/lib*/libamdhip64.so"):
candidates.append(d)

# Drop any paths that don't exist
existing_paths = [
os.path.abspath(os.path.normpath(path))
for path in candidates
if os.path.exists(os.path.abspath(os.path.normpath(path)))
]

# Populate the global dictionary of paths
#
# Tell pylint to be quiet since returning the dictionary
# would result in uglier downstream code
global hip_libraries # pylint: disable=W0603
hip_libraries = dict.fromkeys(existing_paths, False)


def hip_runtime_supports_ais():
"""
Check if hipAmdFileRead and hipAmdFileWrite are available in HIP
"""
hip_path = ctypes.util.find_library("amdhip64")
if hip_path is None:
return False
find_hip_runtimes()

hip = ctypes.CDLL(hip_path)
# Check for AIS functions in the list of found HIP libraries
for hip_path in hip_libraries:

hipError_t = ctypes.c_int
hipDriverProcAddressQueryResult = ctypes.c_int
try:
hip = ctypes.CDLL(hip_path)
except OSError:
continue

hip.hipRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
hip.hipRuntimeGetVersion.restype = hipError_t
hipError_t = ctypes.c_int
hipDriverProcAddressQueryResult = ctypes.c_int

hip.hipGetProcAddress.argtypes = [
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_void_p),
ctypes.c_int,
ctypes.c_uint64,
ctypes.POINTER(hipDriverProcAddressQueryResult),
]
hip.hipGetProcAddress.restype = hipError_t
hip.hipRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
hip.hipRuntimeGetVersion.restype = hipError_t

hip.hipGetErrorString.argtypes = [hipError_t]
hip.hipGetErrorString.restype = ctypes.c_char_p
hip.hipGetProcAddress.argtypes = [
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_void_p),
ctypes.c_int,
ctypes.c_uint64,
ctypes.POINTER(hipDriverProcAddressQueryResult),
]
hip.hipGetProcAddress.restype = hipError_t

version = ctypes.c_int()
err = hip.hipRuntimeGetVersion(ctypes.byref(version))
if err != 0:
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipRuntimeGetVersion failed with err code {err} ({err_str})",
file=sys.stderr,
)
return False

for symbol in [b"hipAmdFileWrite", b"hipAmdFileRead"]:
func_ptr = ctypes.c_void_p()
symbol_status = hipDriverProcAddressQueryResult()
err = hip.hipGetProcAddress(
symbol,
ctypes.byref(func_ptr),
version.value,
0,
ctypes.byref(symbol_status),
)
hip.hipGetErrorString.argtypes = [hipError_t]
hip.hipGetErrorString.restype = ctypes.c_char_p

version = ctypes.c_int()
err = hip.hipRuntimeGetVersion(ctypes.byref(version))
if err != 0:
if symbol_status.value != 1:
symbol = symbol.decode("utf-8")
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipGetProcAddress({symbol}) failed with err code"
f" {err} ({err_str}) and symbolStatus"
f" {symbol_status.value}",
file=sys.stderr,
)
return False
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipRuntimeGetVersion failed with err code {err} ({err_str})",
file=sys.stderr,
)
continue

return True
# Track whether all required AIS symbols are available in this library
supported = True

for symbol in [b"hipAmdFileWrite", b"hipAmdFileRead"]:
func_ptr = ctypes.c_void_p()
symbol_status = hipDriverProcAddressQueryResult()
err = hip.hipGetProcAddress(
symbol,
ctypes.byref(func_ptr),
version.value,
0,
ctypes.byref(symbol_status),
)
if err != 0:
if symbol_status.value != 1:
symbol = symbol.decode("utf-8")
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipGetProcAddress({symbol}) failed with err code"
Comment on lines +164 to +169
f" {err} ({err_str}) and symbolStatus"
f" {symbol_status.value}",
file=sys.stderr,
)
supported = False
break

if supported:
hip_libraries[hip_path] = True

return any(hip_libraries.values())


def amdgpu_supports_ais():
Expand Down Expand Up @@ -154,8 +225,17 @@ def main():
u = os.uname()
print()
print(u.sysname, u.nodename, u.release, u.version, u.machine)

print()
print("Found these HIP libraries (some may be symlinks):")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth adding the word "redundant" (e.g. redundant symlinks) since /opt/rocm will likely be a duplicate entry (tbh I can't think of typical scenarios where it's not). To fix entirely, we could track ROCm installs by the realpath to get rid of symlinks.

for lib, support in hip_libraries.items():
if support:
pretty_supported = "supported"
else:
pretty_supported = "NOT supported"
print(f"\t{lib} (AIS {pretty_supported})")

print()
print("AIS support in:")
for name, supported in component_support:
print(f"\t{name:<24}: {supported}")
Expand Down
Loading