import argparse
import ctypes
import ctypes.util
import glob
import gzip
import os
import sys

# Global mapping of HIP runtime library paths to AIS support flags
hip_libraries = {}


def _clean_path(path):
    """Return *path* as an absolute path with `..`, etc. removed."""
    return os.path.abspath(os.path.normpath(path))


def find_hip_runtimes():
    """
    Populate the global `hip_libraries` mapping with the HIP runtime
    libraries found by looking in the usual places.

    Every key is the cleaned-up absolute path of an existing
    `libamdhip64.so`; every value is reset to False. The mapping is
    updated in place, so existing references to it stay valid.
    """

    # NOTE: CodeQL will be unhappy if you are not careful about paths
    # in this function

    candidates = []

    # 1. Respect runtime linker paths
    for entry in os.environ.get("LD_LIBRARY_PATH", "").split(":"):
        if entry:
            candidates.append(
                os.path.join(_clean_path(entry), "libamdhip64.so")
            )

    # 2. Environment variables commonly set by ROCm or modules
    for var in ("ROCM_HOME", "ROCM_PATH", "HIP_PATH"):
        base = os.environ.get(var)
        if base:
            safe_base = _clean_path(base)
            candidates += [
                os.path.join(safe_base, "lib", "libamdhip64.so"),
                os.path.join(safe_base, "lib64", "libamdhip64.so"),
            ]

    # 3. Standard ROCm install paths
    candidates += [
        "/opt/rocm/lib/libamdhip64.so",
        "/opt/rocm/lib64/libamdhip64.so",
    ]

    # 4. Versioned installs (/opt/rocm-5.x, etc.)
    #
    # glob returns matches in arbitrary order; sort them so the report
    # lists libraries in the same order on every run.
    candidates += sorted(glob.glob("/opt/rocm*/lib*/libamdhip64.so"))

    # Drop any paths that don't exist. dict.fromkeys also removes
    # candidates that were found through more than one route.
    existing_paths = [
        path for path in map(_clean_path, candidates) if os.path.exists(path)
    ]

    hip_libraries.clear()
    hip_libraries.update(dict.fromkeys(existing_paths, False))


def _library_supports_ais(hip_path):
    """
    Return True if the HIP runtime at `hip_path` provides both
    hipAmdFileWrite and hipAmdFileRead.

    Libraries that cannot be loaded, or that lack the HIP entry points
    needed to run the query, are treated as not supporting AIS.
    """
    try:
        hip = ctypes.CDLL(hip_path)
    except OSError:
        return False

    hipError_t = ctypes.c_int
    hipDriverProcAddressQueryResult = ctypes.c_int

    try:
        hip.hipRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
        hip.hipRuntimeGetVersion.restype = hipError_t

        hip.hipGetProcAddress.argtypes = [
            ctypes.c_char_p,
            ctypes.POINTER(ctypes.c_void_p),
            ctypes.c_int,
            ctypes.c_uint64,
            ctypes.POINTER(hipDriverProcAddressQueryResult),
        ]
        hip.hipGetProcAddress.restype = hipError_t

        hip.hipGetErrorString.argtypes = [hipError_t]
        hip.hipGetErrorString.restype = ctypes.c_char_p
    except AttributeError:
        # ctypes raises AttributeError for a function the library does
        # not export. Such a runtime (for example another install matched
        # by the /opt/rocm* glob) cannot be queried, so it must not abort
        # the scan of the remaining libraries.
        return False

    version = ctypes.c_int()
    err = hip.hipRuntimeGetVersion(ctypes.byref(version))
    if err != 0:
        err_str = hip.hipGetErrorString(err).decode("utf-8")
        print(
            f"hipRuntimeGetVersion failed with err code {err} ({err_str})",
            file=sys.stderr,
        )
        return False

    for symbol in [b"hipAmdFileWrite", b"hipAmdFileRead"]:
        func_ptr = ctypes.c_void_p()
        symbol_status = hipDriverProcAddressQueryResult()
        err = hip.hipGetProcAddress(
            symbol,
            ctypes.byref(func_ptr),
            version.value,
            0,
            ctypes.byref(symbol_status),
        )
        if err != 0:
            # Status 1 is handled quietly; presumably it means "symbol not
            # found" -- TODO confirm against the HIP headers. Any other
            # status is unexpected and reported.
            if symbol_status.value != 1:
                symbol_name = symbol.decode("utf-8")
                err_str = hip.hipGetErrorString(err).decode("utf-8")
                print(
                    f"hipGetProcAddress({symbol_name}) failed with err code"
                    f" {err} ({err_str}) and symbolStatus"
                    f" {symbol_status.value}",
                    file=sys.stderr,
                )
            return False

    return True


def hip_runtime_supports_ais():
    """
    Check if hipAmdFileRead and hipAmdFileWrite are available in HIP

    Refreshes the global `hip_libraries` mapping, records the result for
    each library found, and returns True if at least one supports AIS.
    """
    find_hip_runtimes()

    # Only values change here, so updating while iterating keys is safe
    for hip_path in hip_libraries:
        hip_libraries[hip_path] = _library_supports_ais(hip_path)

    return any(hip_libraries.values())