From 9d1bd7cab009cdbd5f57a7c0fbb8fb5a4ad1ebd8 Mon Sep 17 00:00:00 2001
From: Chapaman <14204271+Chapaman@users.noreply.github.com>
Date: Fri, 30 Jan 2026 16:35:04 -0300
Subject: [PATCH 1/2] Basic all_gather implementation

---
 exla/lib/exla/defn.ex       | 18 ++++++++++++++++++
 exla/lib/exla/mlir/value.ex | 27 +++++++++++++++++++++++++++
 nx/lib/nx/defn/evaluator.ex |  5 +++++
 nx/lib/nx/defn/expr.ex      | 27 +++++++++++++++++++++++++++
 nx/lib/nx/defn/kernel.ex    | 26 ++++++++++++++++++++++++++
 5 files changed, 103 insertions(+)

diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index 592a8279b4..4cd674ab97 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -1471,6 +1471,24 @@ defmodule EXLA.Defn do
     EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
   end
 
+  ## to_operator collective ops
+
+  defp to_operator(:all_gather, [%Value{} = tensor, opts], ans, _state) do
+    all_gather_dim = Keyword.fetch!(opts, :all_gather_dim)
+    replica_groups = Keyword.fetch!(opts, :replica_groups)
+    use_global_device_ids = Keyword.get(opts, :use_global_device_ids, false)
+
+    Value.all_gather(
+      [tensor],
+      expr_to_typespec(ans),
+      all_gather_dim,
+      replica_groups,
+      use_global_device_ids,
+      Keyword.take(opts, [:channel_id])
+    )
+    |> hd()
+  end
+
   defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
     n = opts[:length]
     axis = opts[:axis]
diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex
index 393b6d57a8..9b6822c6dd 100644
--- a/exla/lib/exla/mlir/value.ex
+++ b/exla/lib/exla/mlir/value.ex
@@ -64,6 +64,33 @@ defmodule EXLA.MLIR.Value do
     end
   end
 
+  def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, opts \\ []) do
+    result_types = typespecs_to_mlir_types([typespec])
+
+    opts =
+      Keyword.validate!(opts, [
+        channel_id: nil
+      ])
+
+      num_groups = length(replica_groups)
+      group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0
+      flat_groups = List.flatten(replica_groups)
+
+      attributes = [
+        all_gather_dim: attr_i64(all_gather_dim),
+        replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}),
+        use_global_device_ids: attr_boolean(use_global_device_ids)
+      ]
+
+    attributes =
+      if opts[:channel_id] do
+        attributes ++ [channel_id: attr_i64(opts[:channel_id])]
+      else
+        attributes end
+
+    op(func, "stablehlo.all_gather", operands, result_types, attributes: attributes)
+  end
+
   defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
     %{type: lhs_type} = get_typespec(lhs)
     %{type: rhs_type} = get_typespec(rhs)
diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex
index 601f750942..5e33925625 100644
--- a/nx/lib/nx/defn/evaluator.ex
+++ b/nx/lib/nx/defn/evaluator.ex
@@ -478,6 +478,11 @@ defmodule Nx.Defn.Evaluator do
         {Nx.Shared.list_impl!(args), [ans | args]}
       end
 
+    if op == :all_gather and not function_exported?(mod, :all_gather, 3) do
+      raise ArgumentError,
+            "all_gather/3 is not supported by backend #{inspect(mod)}."
+    end
+
     {apply(mod, op, args), caches}
   end
 
diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex
index 899b430da4..ab11c46af1 100644
--- a/nx/lib/nx/defn/expr.ex
+++ b/nx/lib/nx/defn/expr.ex
@@ -1166,6 +1166,33 @@ defmodule Nx.Defn.Expr do
     expr(out, context, :gather, [tensor, indices, opts])
   end
 
+  def all_gather(tensor, opts) do
+    {[tensor], context} = to_exprs([tensor])
+
+    all_gather_dim = opts[:all_gather_dim]
+    replica_groups = opts[:replica_groups]
+
+    # Calculate group size (number of replicas per group)
+    group_size =
+      case replica_groups do
+        [first_group | _] -> length(first_group)
+        [] -> 1
+      end
+
+    # Calculate output shape by multiplying the gather dimension by group_size
+    output_shape =
+      tensor.shape
+      |> Tuple.to_list()
+      |> List.update_at(all_gather_dim, &(&1 * group_size))
+      |> List.to_tuple()
+
+    # Create output tensor with the new shape
+    out = %{tensor | shape: output_shape}
+
+    expr(out, context, :all_gather, [tensor, opts])
+  end
+
   @impl true
   def reverse(out, tensor, axes) do
     tensor = to_expr(tensor)
diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex
index ab913ab61f..a0cf4f4493 100644
--- a/nx/lib/nx/defn/kernel.ex
+++ b/nx/lib/nx/defn/kernel.ex
@@ -1669,6 +1669,32 @@ defmodule Nx.Defn.Kernel do
     end
   end
 
+  @doc """
+  Gathers tensors from all replicas along a specified dimension.
+
+  This operation concatenates tensors from multiple replicas/devices along
+  the specified dimension. It requires a backend that supports multi-device
+  operations.
+
+  ## Parameters
+
+  * `tensor` - The input tensor to gather
+  * `all_gather_dim` - The dimension along which to gather
+  * `replica_groups` - 2D list defining how replicas are grouped (required)
+  * `opts` - Optional keyword list:
+    * `:use_global_device_ids` - Whether to use global device IDs (default: false)
+    * `:channel_id` - Channel ID for communication (optional)
+
+  ## Examples
+
+      all_gather(tensor, 0, [[0, 1, 2, 3]])
+      all_gather(tensor, 1, [[0, 1], [2, 3]], use_global_device_ids: true)
+  """
+  def all_gather(tensor, all_gather_dim, replica_groups, opts \\ []) do
+    opts = Keyword.put(opts, :all_gather_dim, all_gather_dim)
+    opts = Keyword.put(opts, :replica_groups, replica_groups)
+    Nx.Defn.Expr.all_gather(tensor, opts)
+  end
+
   @definitions (Module.definitions_in(__MODULE__, :def) ++
                   Module.definitions_in(__MODULE__, :defmacro)) -- [
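For context before the review changes in the second patch: the kernel entry point
added above takes the gather dimension and replica groups positionally, matching
its @doc examples. A minimal sketch of calling it inside defn (GatherExample is an
illustrative module name, and actually running this requires a compiler that
lowers :all_gather, which this patch only wires up for EXLA):

    defmodule GatherExample do
      import Nx.Defn

      defn gather_rows(x) do
        # Concatenate the per-replica shards of `x` along dimension 0,
        # with all four replicas in a single group.
        all_gather(x, 0, [[0, 1, 2, 3]])
      end
    end

Inside defn bodies Nx.Defn.Kernel is imported automatically, so the unqualified
all_gather/3 call resolves to the function defined in this patch.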
From cc8761d4fa10441bfbe658dd30610a9d1d1c74c1 Mon Sep 17 00:00:00 2001
From: Chapaman <14204271+Chapaman@users.noreply.github.com>
Date: Mon, 2 Feb 2026 19:33:10 -0300
Subject: [PATCH 2/2] Changes due to code review by @polvalente

---
 exla/lib/exla.ex                      | 15 ++++++
 exla/lib/exla/defn.ex                 | 21 ++++----
 exla/lib/exla/mlir/value.ex           | 30 +++++------
 exla/test/exla/defn/sharding_test.exs | 73 ++++++++++++++++++++++++++-
 nx/lib/nx/defn/evaluator.ex           |  5 --
 nx/lib/nx/defn/kernel.ex              | 14 ++---
 nx/test/nx/defn_test.exs              | 13 +++++
 7 files changed, 128 insertions(+), 43 deletions(-)

diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex
index 78c9016361..403c6fbe76 100644
--- a/exla/lib/exla.ex
+++ b/exla/lib/exla.ex
@@ -215,6 +215,21 @@ defmodule EXLA do
   The metadata is:
 
     * `:key` - the compilation key for debugging
+
+  ## Sharding
+
+  EXLA supports sharding, a way to partition a computation across multiple devices.
+  The collective operations available to sharded computations are documented below.
+
+  ### [`all_gather`](https://openxla.org/stablehlo/spec#all_gather)
+
+  #### Options
+
+    * `:all_gather_dim` - the dimension along which to gather
+    * `:replica_groups` - 2D list defining how replicas are grouped
+    * `:use_global_device_ids` - whether to use global device IDs (default: `false`)
+    * `:channel_id` - channel ID for communication (optional)
+
   """
 
   @behaviour Nx.Defn.Compiler
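One thing the option list above does not spell out: the result grows along
`:all_gather_dim` by the number of replicas per group. This is the shape rule
from Nx.Defn.Expr.all_gather/2 in the first patch, worked by hand on
illustrative values:

    # Shape bookkeeping as in Nx.Defn.Expr.all_gather/2 (values are examples).
    all_gather_dim = 0
    replica_groups = [[0, 1], [2, 3]]
    input_shape = {4, 3}

    # Replicas per group; every group is assumed to be the same size.
    group_size = replica_groups |> hd() |> length()

    output_shape =
      input_shape
      |> Tuple.to_list()
      |> List.update_at(all_gather_dim, &(&1 * group_size))
      |> List.to_tuple()

    # output_shape == {8, 3}: each group of 2 replicas concatenates
    # its {4, 3} shards along dimension 0.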
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index 4cd674ab97..65239373ba 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -1478,15 +1478,18 @@ defmodule EXLA.Defn do
     replica_groups = Keyword.fetch!(opts, :replica_groups)
     use_global_device_ids = Keyword.get(opts, :use_global_device_ids, false)
 
-    Value.all_gather(
-      [tensor],
-      expr_to_typespec(ans),
-      all_gather_dim,
-      replica_groups,
-      use_global_device_ids,
-      Keyword.take(opts, [:channel_id])
-    )
-    |> hd()
+    # We might want to surface all_gather as an operation that takes a
+    # container of operands instead of a single one.
+    [result] =
+      Value.all_gather(
+        [tensor],
+        expr_to_typespec(ans),
+        all_gather_dim,
+        replica_groups,
+        use_global_device_ids,
+        opts[:channel_id]
+      )
+
+    result
   end
 
   defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex
index 9b6822c6dd..e548693497 100644
--- a/exla/lib/exla/mlir/value.ex
+++ b/exla/lib/exla/mlir/value.ex
@@ -64,29 +64,25 @@ defmodule EXLA.MLIR.Value do
     end
   end
 
-  def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, opts \\ []) do
+  def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, channel_id \\ nil) do
     result_types = typespecs_to_mlir_types([typespec])
 
-    opts =
-      Keyword.validate!(opts, [
-        channel_id: nil
-      ])
+    num_groups = length(replica_groups)
+    group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0
+    flat_groups = List.flatten(replica_groups)
 
-      num_groups = length(replica_groups)
-      group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0
-      flat_groups = List.flatten(replica_groups)
-
-      attributes = [
-        all_gather_dim: attr_i64(all_gather_dim),
-        replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}),
-        use_global_device_ids: attr_boolean(use_global_device_ids)
-      ]
+    attributes = [
+      all_gather_dim: attr_i64(all_gather_dim),
+      replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}),
+      use_global_device_ids: attr_boolean(use_global_device_ids)
+    ]
 
     attributes =
-      if opts[:channel_id] do
-        attributes ++ [channel_id: attr_i64(opts[:channel_id])]
+      if channel_id do
+        Keyword.put(attributes, :channel_id, attr_i64(channel_id))
       else
-        attributes end
+        attributes
+      end
 
     op(func, "stablehlo.all_gather", operands, result_types, attributes: attributes)
   end
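Since the dense-elements encoding above is easy to get wrong, here is the
attribute arithmetic from all_gather/6 traced on a concrete grouping (a worked
example; attr_dense_elements/3 is the module's own helper):

    # With two groups of two replicas each:
    replica_groups = [[0, 1], [2, 3]]

    num_groups = length(replica_groups)         # 2
    group_size = length(hd(replica_groups))     # 2
    flat_groups = List.flatten(replica_groups)  # [0, 1, 2, 3]

    # attr_dense_elements([0, 1, 2, 3], {:s, 64}, {2, 2}) then encodes the
    # StableHLO attribute replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>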
diff --git a/exla/test/exla/defn/sharding_test.exs b/exla/test/exla/defn/sharding_test.exs
index ed46f76b6a..058e2683b6 100644
--- a/exla/test/exla/defn/sharding_test.exs
+++ b/exla/test/exla/defn/sharding_test.exs
@@ -6,7 +6,8 @@ defmodule EXLA.Defn.ShardingTest do
   describe "MLIR module generation with sharding" do
     @moduletag :multi_device
     test "generates correct MLIR with simple 2D mesh and sharding" do
-      fun = fn x, y -> Nx.add(x, y) end
+      fun = fn x, y -> Nx.add(x, y)
+      end
 
       mesh = %Mesh{name: "mesh", shape: {2, 2}}
       # First arg: shard dim 0 on mesh axis 0, dim 1 on mesh axis 1
@@ -737,5 +738,75 @@ defmodule EXLA.Defn.ShardingTest do
       assert result.mlir_module =~ ~r/"axis_0"/
       assert result.mlir_module =~ ~r/"axis_1"/
     end
+
+    @moduletag :multi_device
+    test "generates correct MLIR with all_gather" do
+      fun = fn x, y ->
+        Nx.add(x, y)
+        |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]])
+        |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]])
+      end
+
+      mesh = %Mesh{name: "mesh", shape: {2, 2}}
+      # First arg: shard dim 0 on mesh axis 0, dim 1 on mesh axis 1
+      # Second arg: shard dim 0 on mesh axis 0, dim 1 not sharded
+      input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0]}]
+
+      # For mesh {2, 2}, we have 4 partitions.
+      # Each partition gets a shard of the inputs:
+      # First input: shape {8, 2} sharded as [[0], [1]] -> each partition gets {4, 1}
+      # Second input: shape {8, 1} sharded as [[0], []] -> each partition gets {4, 1}
+      args = [
+        # partition 0
+        [Nx.iota({4, 1}), Nx.iota({4, 1})],
+        # partition 1
+        [Nx.iota({4, 1}), Nx.iota({4, 1})],
+        # partition 2
+        [Nx.iota({4, 1}), Nx.iota({4, 1})],
+        # partition 3
+        [Nx.iota({4, 1}), Nx.iota({4, 1})]
+      ]
+
+      result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings)
+
+      expected_mlir = """
+      module {
+        sdy.mesh @mesh = <["axis_0"=2, "axis_1"=2]>
+        func.func public @main(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}, %arg1: tensor<8x1xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {?}p0]>}) -> tensor<8x2xi32> {
+          %0 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<8x1xi32>) -> tensor<8x2xi32>
+          %1 = stablehlo.add %arg0, %0 : tensor<8x2xi32>
+          %2 = "stablehlo.all_gather"(%1) <{all_gather_dim = 0 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32>
+          %3 = "stablehlo.all_gather"(%2) <{all_gather_dim = 1 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32>
+          return %3 : tensor<8x2xi32>
+        }
+      }
+      """
+
+      assert expected_mlir == result.mlir_module
+
+      results = EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args)
+
+      assert length(results) == 4
+
+      # After all_gather on both dims, each partition has the full tensor: add(iota, iota) -> 2*iota.
+      # Each shard had iota({4, 1}) = [[0], [1], [2], [3]], so add gives [[0], [2], [4], [6]].
+      # After gathering: replicated 8x2 with pattern [[0, 0], [2, 2], [4, 4], [6, 6], [0, 0], ...]
+      expected_result =
+        Nx.tensor([
+          [0, 0],
+          [2, 2],
+          [4, 4],
+          [6, 6],
+          [0, 0],
+          [2, 2],
+          [4, 4],
+          [6, 6]
+        ])
+
+      for r <- results do
+        assert_equal(r, expected_result)
+      end
+    end
   end
 end
diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex
index 5e33925625..601f750942 100644
--- a/nx/lib/nx/defn/evaluator.ex
+++ b/nx/lib/nx/defn/evaluator.ex
@@ -478,11 +478,6 @@ defmodule Nx.Defn.Evaluator do
         {Nx.Shared.list_impl!(args), [ans | args]}
       end
 
-    if op == :all_gather and not function_exported?(mod, :all_gather, 3) do
-      raise ArgumentError,
-            "all_gather/3 is not supported by backend #{inspect(mod)}."
-    end
-
     {apply(mod, op, args), caches}
   end
 
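The evaluator revert above means :all_gather is no longer special-cased: dispatch
falls through to apply(mod, op, args), and a backend without an all_gather
callback fails with a plain UndefinedFunctionError, which is what the new defn
test below pins down. A condensed sketch of that behavior (test-style code,
assuming ExUnit assertions and the default Evaluator compiler):

    # Nx.BinaryBackend defines no all_gather callback, so the evaluator's
    # apply(mod, op, args) dispatch raises UndefinedFunctionError.
    fun =
      Nx.Defn.jit(
        fn t -> Nx.Defn.Kernel.all_gather(t, all_gather_dim: 0, replica_groups: [[0]]) end,
        compiler: Nx.Defn.Evaluator
      )

    assert_raise UndefinedFunctionError, fn -> fun.(Nx.tensor([1, 2, 3, 4])) end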
diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex
index a0cf4f4493..809a66b480 100644
--- a/nx/lib/nx/defn/kernel.ex
+++ b/nx/lib/nx/defn/kernel.ex
@@ -1678,20 +1678,12 @@ defmodule Nx.Defn.Kernel do
   ## Parameters
 
   * `tensor` - The input tensor to gather
-  * `all_gather_dim` - The dimension along which to gather
-  * `replica_groups` - 2D list defining how replicas are grouped (required)
-  * `opts` - Optional keyword list:
-    * `:use_global_device_ids` - Whether to use global device IDs (default: false)
-    * `:channel_id` - Channel ID for communication (optional)
 
-  ## Examples
+  * `opts` - Optional keyword list. These are backend- and compiler-specific;
+    see your backend or compiler docs for supported options.
 
-      all_gather(tensor, 0, [[0, 1, 2, 3]])
-      all_gather(tensor, 1, [[0, 1], [2, 3]], use_global_device_ids: true)
   """
-  def all_gather(tensor, all_gather_dim, replica_groups, opts \\ []) do
-    opts = Keyword.put(opts, :all_gather_dim, all_gather_dim)
-    opts = Keyword.put(opts, :replica_groups, replica_groups)
+  def all_gather(tensor, opts \\ []) do
     Nx.Defn.Expr.all_gather(tensor, opts)
   end
 
diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs
index 62993b07a3..621b4f4e77 100644
--- a/nx/test/nx/defn_test.exs
+++ b/nx/test/nx/defn_test.exs
@@ -2952,4 +2952,17 @@ defmodule Nx.DefnTest do
       assert vectorized_metadata_tuple(x, z) == vec_nonvec_result
     end
   end
+
+  describe "sharding" do
+    defn all_gather_test(tensor) do
+      Nx.Defn.Kernel.all_gather(tensor, all_gather_dim: 0, replica_groups: [[0]])
+    end
+
+    @tag compiler: Evaluator
+    test "all_gather raises for backends without an all_gather callback" do
+      assert_raise UndefinedFunctionError, fn ->
+        all_gather_test(Nx.tensor([1, 2, 3, 4]))
+      end
+    end
+  end
 end
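With the series applied, end-to-end usage takes the keyword-list form of
all_gather. A condensed sketch based on the sharding test above (`Mesh` is the
same struct alias the test file uses, and four devices back the {2, 2} mesh):

    fun = fn x, y ->
      Nx.add(x, y)
      |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]])
    end

    mesh = %Mesh{name: "mesh", shape: {2, 2}}
    input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0]}]

    # One entry per partition, each holding that partition's input shards.
    args = for _partition <- 1..4, do: [Nx.iota({4, 1}), Nx.iota({4, 1})]

    # Compile once for the mesh, then run against the per-partition shards;
    # shard_jit returns one result per partition.
    results = EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args)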