Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlx/backend/cpu/sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ struct StridedIterator {
return *this;
}

StridedIterator operator+(difference_type diff) {
StridedIterator operator+(difference_type diff) const {
return StridedIterator(ptr_, stride_, diff);
}

StridedIterator operator-(difference_type diff) {
StridedIterator operator-(difference_type diff) const {
return StridedIterator(ptr_, stride_, -diff);
}

Expand Down
8 changes: 5 additions & 3 deletions mlx/backend/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,19 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
worker_(d),
worker_(std::make_shared<Worker>(d)),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
worker_->start();
}

CommandEncoder::~CommandEncoder() {
synchronize();
worker_->stop();
}

void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
worker_->add_task(std::move(task));
}

void CommandEncoder::set_input_array(const array& arr) {
Expand Down Expand Up @@ -528,7 +530,7 @@ void CommandEncoder::commit() {
}

// Put completion handlers in a batch.
worker_.commit(stream_);
worker_->commit(stream_);
node_count_ = 0;
bytes_in_graph_ = 0;
}
Expand Down
7 changes: 5 additions & 2 deletions mlx/backend/cuda/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/stream.h"

#include <memory>
#include <unordered_map>

namespace mlx::core::cu {

// Compute a key and updatability flag for a CUDA graph by walking its nodes.
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph);

class Worker;

class CommandEncoder {
public:
struct CaptureContext {
Expand Down Expand Up @@ -136,7 +139,7 @@ class CommandEncoder {
Device& device_;
CudaStream stream_;
CudaGraph graph_;
Worker worker_;
std::shared_ptr<Worker> worker_;
int node_count_{0};
bool in_concurrent_{false};
std::vector<cudaGraphNode_t> from_nodes_;
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/cuda/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/cublas_utils.h"
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

Expand All @@ -16,7 +16,7 @@ void init() {
// Force initalization of CUDA, so CUDA runtime get destroyed last.
cudaFree(nullptr);
// Make sure CUDA event pool get destroyed after device and stream.
cu::CudaEvent::init_pool();
mlx::core::cu::CudaEvent::init_pool();
}

void new_stream(Stream s) {
Expand Down
7 changes: 4 additions & 3 deletions mlx/backend/cuda/quantized/qmm/cute_dequant.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cute/numeric/numeric_types.hpp>
#include <cute/tensor.hpp>
#include <cutlass/numeric_conversion.h>
#include <cuda/std/array>

namespace cutlass {

Expand Down Expand Up @@ -109,13 +110,13 @@ namespace cute {

// Required by tiled copy for 3/5/6-bit weights.
struct uint24_t {
std::array<std::uint8_t, 3> bytes;
cuda::std::array<std::uint8_t, 3> bytes;
};
struct uint40_t {
std::array<std::uint8_t, 5> bytes;
cuda::std::array<std::uint8_t, 5> bytes;
};
struct uint48_t {
std::array<std::uint8_t, 6> bytes;
cuda::std::array<std::uint8_t, 6> bytes;
};

template <>
Expand Down
17 changes: 13 additions & 4 deletions mlx/backend/cuda/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,25 @@ namespace mlx::core::cu {

Worker::Worker(Device& d)
: signal_stream_(d),
signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
worker_(&Worker::thread_fn, this) {}
signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync) {}

Worker::~Worker() {
Worker::~Worker() = default;

void Worker::start() {
// Note that |shared_from_this| can not be called in constructor.
worker_ = std::thread(&Worker::thread_fn, shared_from_this());
// Detach the thread and let it free itself after finishing tasks.
// This is to avoid deadlock when joining threads on exit on Windows:
// https://developercommunity.visualstudio.com/t/1654756
worker_.detach();
}

void Worker::stop() {
{
std::lock_guard lock(mtx_);
stop_ = true;
}
cond_.notify_one();
worker_.join();
}

void Worker::add_task(std::function<void()> task) {
Expand Down
6 changes: 5 additions & 1 deletion mlx/backend/cuda/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,24 @@
#include <condition_variable>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <thread>

namespace mlx::core::cu {

// Run tasks in worker thread, synchronized with cuda stream.
class Worker {
class Worker : public std::enable_shared_from_this<Worker> {
public:
explicit Worker(Device& d);
~Worker();

Worker(const Worker&) = delete;
Worker& operator=(const Worker&) = delete;

void start();
void stop();

// Add a pending |task| that will run when consumed or commited.
void add_task(std::function<void()> task);

Expand Down
1 change: 1 addition & 0 deletions mlx/distributed/nccl/nccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgotten?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file previously received &lt;thread&gt; transitively, because device.h included worker.h; since this PR removes that include from device.h, the direct #include &lt;thread&gt; is now required here.

#include <type_traits>

#include "mlx/backend/cuda/device.h"
Expand Down
Loading