Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mlx/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "mlx/stream.h"
#include "mlx/backend/cpu/device_info.h"
#include "mlx/backend/gpu/device_info.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/scheduler.h"

#include <array>
Expand Down Expand Up @@ -57,6 +58,16 @@ std::vector<Stream> get_streams() {
return streams;
}

void register_stream(Stream s) {
if (s.device == Device::gpu) {
if (!gpu::is_available()) {
throw std::invalid_argument(
"[register_stream] Cannot register gpu stream without gpu backend.");
}
gpu::new_stream(s);
}
}

Stream new_stream(Device d) {
if (!gpu::is_available() && d == Device::gpu) {
throw std::invalid_argument(
Expand Down
16 changes: 16 additions & 0 deletions mlx/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ MLX_API void set_default_stream(Stream s);
/** Make a new stream on the given device. */
MLX_API Stream new_stream(Device d);

/**
* Register an existing stream on the calling thread.
*
* GPU streams use thread-local command encoders. When a stream is created,
* its encoder is registered only on the creating thread. If a different
* thread later calls eval() on arrays from that stream, it will fail
* because the encoder does not exist on the new thread.
*
* Call this function on any thread that needs to eval() arrays from a
* stream that was created on a different thread.
*
* Safe to call multiple times or on the creating thread (no-op if already
* registered).
*/
MLX_API void register_stream(Stream s);

/** Get all available streams. */
MLX_API std::vector<Stream> get_streams();

Expand Down
17 changes: 17 additions & 0 deletions python/src/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,23 @@ void init_stream(nb::module_& m) {
&mx::new_stream,
"device"_a,
R"pbdoc(Make a new stream on the given device.)pbdoc");
m.def(
"register_stream",
&mx::register_stream,
"stream"_a,
R"pbdoc(
Register a stream on the calling thread.

GPU streams use thread-local command encoders. When a stream is
created, its encoder is only registered on the creating thread.
Call this on any new thread that needs to evaluate arrays from
that stream.

Safe to call multiple times (no-op if already registered).

Args:
stream (Stream): The stream to register.
)pbdoc");

nb::class_<PyStreamContext>(m, "StreamContext", R"pbdoc(
A context manager for setting the current device and stream.
Expand Down
33 changes: 33 additions & 0 deletions python/tests/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,39 @@ def test_success():
t1.join()
t2.join()

@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_register_stream_cross_thread(self):
"""register_stream allows eval on a thread that did not create the stream."""
s = mx.new_stream(mx.gpu)
x = mx.ones((4, 4), stream=s)
y = mx.abs(x, stream=s)

errors = []

def eval_on_thread():
try:
mx.register_stream(s)
mx.eval(y)
except Exception as e:
errors.append(e)

t = threading.Thread(target=eval_on_thread)
t.start()
t.join()

self.assertEqual(len(errors), 0, f"eval failed on new thread: {errors}")
self.assertTrue(mx.array_equal(y, mx.ones((4, 4))))

@unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
def test_register_stream_idempotent(self):
"""Calling register_stream multiple times does not error."""
s = mx.new_stream(mx.gpu)
mx.register_stream(s)
mx.register_stream(s)
x = mx.ones((3,), stream=s)
mx.eval(x)
self.assertEqual(x.tolist(), [1.0, 1.0, 1.0])


if __name__ == "__main__":
mlx_tests.MLXTestRunner()