From ba5c37f9abb6f5b65f2bae600e6744e56ae46f3c Mon Sep 17 00:00:00 2001
From: Tom Turney
Date: Sat, 11 Apr 2026 22:28:42 -0500
Subject: [PATCH] Add register_stream to allow cross-thread GPU stream usage

---
 mlx/stream.cpp               | 11 +++++++++++
 mlx/stream.h                 | 16 ++++++++++++++++
 python/src/stream.cpp        | 17 +++++++++++++++++
 python/tests/test_threads.py | 33 +++++++++++++++++++++++++++++++++
 4 files changed, 77 insertions(+)

diff --git a/mlx/stream.cpp b/mlx/stream.cpp
index ee1db01629..d6d177f69a 100644
--- a/mlx/stream.cpp
+++ b/mlx/stream.cpp
@@ -3,6 +3,7 @@
 #include "mlx/stream.h"
 #include "mlx/backend/cpu/device_info.h"
 #include "mlx/backend/gpu/device_info.h"
+#include "mlx/backend/gpu/eval.h"
 #include "mlx/scheduler.h"
 
 #include <stdexcept>
@@ -57,6 +58,16 @@ std::vector<Stream> get_streams() {
   return streams;
 }
 
+void register_stream(Stream s) {
+  if (s.device == Device::gpu) {
+    if (!gpu::is_available()) {
+      throw std::invalid_argument(
+          "[register_stream] Cannot register gpu stream without gpu backend.");
+    }
+    gpu::new_stream(s);
+  }
+}
+
 Stream new_stream(Device d) {
   if (!gpu::is_available() && d == Device::gpu) {
     throw std::invalid_argument(
diff --git a/mlx/stream.h b/mlx/stream.h
index 54f7d82015..7370b0f35b 100644
--- a/mlx/stream.h
+++ b/mlx/stream.h
@@ -30,6 +30,22 @@ MLX_API void set_default_stream(Stream s);
 /** Make a new stream on the given device. */
 MLX_API Stream new_stream(Device d);
 
+/**
+ * Register an existing stream on the calling thread.
+ *
+ * GPU streams use thread-local command encoders. When a stream is created,
+ * its encoder is registered only on the creating thread. If a different
+ * thread later calls eval() on arrays from that stream, it will fail
+ * because the encoder does not exist on the new thread.
+ *
+ * Call this function on any thread that needs to eval() arrays from a
+ * stream that was created on a different thread.
+ *
+ * Safe to call multiple times or on the creating thread (no-op if already
+ * registered).
+ */
+MLX_API void register_stream(Stream s);
+
 /** Get all available streams. */
 MLX_API std::vector<Stream> get_streams();
 
diff --git a/python/src/stream.cpp b/python/src/stream.cpp
index 17e16a4145..28a55279cc 100644
--- a/python/src/stream.cpp
+++ b/python/src/stream.cpp
@@ -133,6 +133,23 @@ void init_stream(nb::module_& m) {
       &mx::new_stream,
       "device"_a,
       R"pbdoc(Make a new stream on the given device.)pbdoc");
+  m.def(
+      "register_stream",
+      &mx::register_stream,
+      "stream"_a,
+      R"pbdoc(
+      Register a stream on the calling thread.
+
+      GPU streams use thread-local command encoders. When a stream is
+      created, its encoder is only registered on the creating thread.
+      Call this on any new thread that needs to evaluate arrays from
+      that stream.
+
+      Safe to call multiple times (no-op if already registered).
+
+      Args:
+          stream (Stream): The stream to register.
+      )pbdoc");
   nb::class_<mx::StreamContext>(m, "StreamContext", R"pbdoc(
         A context manager for setting the current device and stream.
 
diff --git a/python/tests/test_threads.py b/python/tests/test_threads.py
index 5e125469d7..7235998ecb 100644
--- a/python/tests/test_threads.py
+++ b/python/tests/test_threads.py
@@ -39,6 +39,39 @@ def test_success():
         t1.join()
         t2.join()
 
+    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
+    def test_register_stream_cross_thread(self):
+        """register_stream allows eval on a thread that did not create the stream."""
+        s = mx.new_stream(mx.gpu)
+        x = mx.ones((4, 4), stream=s)
+        y = mx.abs(x, stream=s)
+
+        errors = []
+
+        def eval_on_thread():
+            try:
+                mx.register_stream(s)
+                mx.eval(y)
+            except Exception as e:
+                errors.append(e)
+
+        t = threading.Thread(target=eval_on_thread)
+        t.start()
+        t.join()
+
+        self.assertEqual(len(errors), 0, f"eval failed on new thread: {errors}")
+        self.assertTrue(mx.array_equal(y, mx.ones((4, 4))))
+
+    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
+    def test_register_stream_idempotent(self):
+        """Calling register_stream multiple times does not error."""
+        s = mx.new_stream(mx.gpu)
+        mx.register_stream(s)
+        mx.register_stream(s)
+        x = mx.ones((3,), stream=s)
+        mx.eval(x)
+        self.assertEqual(x.tolist(), [1.0, 1.0, 1.0])
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()