From daf095a43a99c8a27bba4a6e19b147ab9a4a1eb2 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 25 Mar 2026 11:10:44 -0700 Subject: [PATCH 1/2] Remove BatchMatmul. --- .../include/kernels/batch_matmul_kernels.h | 32 -- .../kernels/batch_matmul_kernels_cpu.h | 24 -- .../kernels/batch_matmul_kernels_gpu.h | 38 --- .../src/cuda/ops/batch_matmul_kernels.cu | 145 --------- .../src/hip/ops/batch_matmul_kernels.cpp | 159 --------- .../src/kernels/batch_matmul_kernels.cc | 120 ------- .../src/kernels/batch_matmul_kernels_cpu.cc | 23 -- .../test/src/test_batch_matmul_kernel.cc | 81 ----- .../local-execution/local_task_registry.cc | 4 - .../computation_graph_op_attrs.dtg.toml | 13 +- lib/op-attrs/include/op-attrs/get_op_type.h | 2 - lib/op-attrs/include/op-attrs/is_valid.h | 3 - .../include/op-attrs/ops/batch_matmul.h | 26 -- .../op-attrs/ops/batch_matmul_attrs.dtg.toml | 31 -- .../op-attrs/pcg_operator_attrs.dtg.toml | 17 +- .../src/op-attrs/get_incoming_tensor_roles.cc | 6 - lib/op-attrs/src/op-attrs/get_op_type.cc | 3 - lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 181 ----------- lib/op-attrs/src/op-attrs/shape_inference.cc | 57 +--- .../test/src/op-attrs/ops/batch_matmul.cc | 304 ------------------ .../parallel_computation_graph_builder.h | 5 - .../parallel_computation_graph_builder.cc | 29 -- .../parallel_computation_graph_builder.cc | 72 ----- .../realm-execution/tasks/task_id_t.dtg.toml | 6 - .../tasks/realm_task_registry.cc | 2 - .../src/realm-execution/tasks/task_id_t.cc | 7 - .../operator_pattern/get_attribute.h | 2 - .../operator_pattern/get_attribute.cc | 10 - .../include/task-spec/ops/impl/batch_matmul.h | 14 - .../src/task-spec/ops/impl/batch_matmul.cc | 95 ------ 30 files changed, 18 insertions(+), 1493 deletions(-) delete mode 100644 lib/kernels/include/kernels/batch_matmul_kernels.h delete mode 100644 lib/kernels/include/kernels/batch_matmul_kernels_cpu.h delete mode 100644 lib/kernels/include/kernels/batch_matmul_kernels_gpu.h delete mode 100644 lib/kernels/src/cuda/ops/batch_matmul_kernels.cu delete mode 100644 lib/kernels/src/hip/ops/batch_matmul_kernels.cpp delete mode 100644 lib/kernels/src/kernels/batch_matmul_kernels.cc delete mode 100644 lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc delete mode 100644 lib/kernels/test/src/test_batch_matmul_kernel.cc delete mode 100644 lib/op-attrs/include/op-attrs/ops/batch_matmul.h delete mode 100644 lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.dtg.toml delete mode 100644 lib/op-attrs/src/op-attrs/ops/batch_matmul.cc delete mode 100644 lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc delete mode 100644 lib/task-spec/include/task-spec/ops/impl/batch_matmul.h delete mode 100644 lib/task-spec/src/task-spec/ops/impl/batch_matmul.cc diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h deleted file mode 100644 index d54663f110..0000000000 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H -#define _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H - -#include "kernels/accessor.h" -#include "kernels/device_handle_t.dtg.h" -#include "kernels/device_stream_t.dtg.h" -#include "kernels/ff_handle.h" -#include "utils/nonnegative_int/nonnegative_int.h" - -namespace FlexFlow::Kernels::BatchMatmul { - -void forward_kernel(device_stream_t const &stream, - device_handle_t const &handle, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &input_a, - GenericTensorAccessorR const &input_b, - positive_int seq_length, - std::optional a_seq_length_dim, - std::optional b_seq_length_dim); - -void backward_kernel(device_stream_t const &stream, - device_handle_t const &handle, - GenericTensorAccessorR const &output, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorR const &input_a, - GenericTensorAccessorW const &input_a_grad, - GenericTensorAccessorR const &input_b, - GenericTensorAccessorW const &input_b_grad); - -} // namespace FlexFlow::Kernels::BatchMatmul - -#endif diff --git a/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h b/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h deleted file mode 100644 index 6d9c804be2..0000000000 --- a/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_BATCH_MATMUL_KERNELS_CPU_H -#define _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_BATCH_MATMUL_KERNELS_CPU_H - -#include "kernels/allocation.h" - -namespace FlexFlow::Kernels::BatchMatmul { - -void cpu_forward_kernel(GenericTensorAccessorW const &output, - GenericTensorAccessorR const &input_a, - GenericTensorAccessorR const &input_b, - positive_int seq_length, - std::optional a_seq_length_dim, - std::optional b_seq_length_dim); - -void cpu_backward_kernel(GenericTensorAccessorR const &output, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorR const &input_a, - GenericTensorAccessorW const &input_a_grad, - GenericTensorAccessorR const &input_b, - GenericTensorAccessorW const &input_b_grad); - -} // namespace FlexFlow::Kernels::BatchMatmul - -#endif diff --git a/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h b/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h deleted file mode 100644 index 1e13755b81..0000000000 --- a/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_BATCH_MATMUL_KERNELS_GPU_H -#define _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_BATCH_MATMUL_KERNELS_GPU_H - -#include "kernels/allocation.h" -#include "kernels/device.h" -#include "kernels/ff_handle.h" - -namespace FlexFlow::Kernels::BatchMatmul { - -void gpu_forward_kernel(ffStream_t stream, - PerDeviceFFHandle const &handle, - float *output_ptr, - float const *input_a_ptr, - float const *input_b_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim); - -void gpu_backward_kernel(ffStream_t stream, - PerDeviceFFHandle const &handle, - float const *output_ptr, - float const *output_grad_ptr, - float const *input_a_ptr, - float *input_a_grad_ptr, - float const *input_b_ptr, - float *input_b_grad_ptr, - int m, - int n, - int k, - int batch); - -} // namespace FlexFlow::Kernels::BatchMatmul - -#endif diff --git a/lib/kernels/src/cuda/ops/batch_matmul_kernels.cu b/lib/kernels/src/cuda/ops/batch_matmul_kernels.cu deleted file mode 100644 index 39f5beea21..0000000000 --- a/lib/kernels/src/cuda/ops/batch_matmul_kernels.cu +++ /dev/null @@ -1,145 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "internal/device.h" -#include "kernels/batch_matmul_kernels_gpu.h" - -namespace FlexFlow { -namespace Kernels { -namespace BatchMatmul { - -void gpu_forward_kernel(cudaStream_t stream, - PerDeviceFFHandle const &handle, - float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int a_seq_length_dim, - int b_seq_length_dim, - int seq_length) { - checkCUBLAS(cublasSetStream(handle.blas, stream)); - checkCUDNN(cudnnSetStream(handle.dnn, stream)); - int lda = k; - int ldb = m; - int ldo = m; - long long int strideA = (long long int)n * k; - long long int strideB = (long long int)k * m; - long long int strideO = (long long int)n * m; - if ((a_seq_length_dim == 0) && (seq_length >= 0)) { - assert(seq_length <= k); - k = seq_length; - assert(b_seq_length_dim == 1); - } else if ((a_seq_length_dim == 1) && (seq_length >= 0)) { - assert(seq_length <= n); - n = seq_length; - } else { - // currently only support a_seq_length_dim = 0 or 1 - assert((a_seq_length_dim < 0) || (seq_length < 0)); - } - if ((b_seq_length_dim == 0) && (seq_length >= 0)) { - assert(seq_length <= m); - m = seq_length; - } else if ((b_seq_length_dim == 1) && (seq_length >= 0)) { - assert(a_seq_length_dim == 0); - assert(k == seq_length); - } else { - // currently only support a_seq_length_dim = 0 or 1 - assert((b_seq_length_dim < 0) || (seq_length < 0)); - } - - float alpha = 1.0f, beta = 0.0f; - checkCUBLAS(cublasSgemmStridedBatched(handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_N, - m, - n, - k, - &alpha, - b_input_ptr, - ldb, - strideB, - a_input_ptr, - lda, - strideA, - &beta, - output_ptr, - ldo, - strideO, - batch)); -} - -void gpu_backward_kernel(cudaStream_t stream, - PerDeviceFFHandle const &handle, - float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch) { - checkCUBLAS(cublasSetStream(handle.blas, stream)); - checkCUDNN(cudnnSetStream(handle.dnn, stream)); - - int a_stride = n * k; - int b_stride = m * k; - int o_stride = n * m; - float alpha = 1.0f; - checkCUBLAS(cublasSgemmStridedBatched(handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - k, - n, - m, - &alpha, - b_ptr, - m, - b_stride, - o_grad_ptr, - m, - o_stride, - &alpha, - a_grad_ptr, - k, - a_stride, - batch)); - checkCUBLAS(cublasSgemmStridedBatched(handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m, - k, - n, - &alpha, - o_grad_ptr, - m, - o_stride, - a_ptr, - k, - a_stride, - &alpha, - b_grad_ptr, - m, - b_stride, - batch)); -} - -} // namespace BatchMatmul -} // namespace Kernels -} // namespace FlexFlow diff --git a/lib/kernels/src/hip/ops/batch_matmul_kernels.cpp b/lib/kernels/src/hip/ops/batch_matmul_kernels.cpp deleted file mode 100644 index 6d9ae8a268..0000000000 --- a/lib/kernels/src/hip/ops/batch_matmul_kernels.cpp +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernels/batch_matmul_kernels.h" -#include "internal/device.h" -#include - -namespace FlexFlow { -namespace Kernels { -namespace BatchMatmul { - -/* -A: (batch, n, k) -B: (batch, k, m) -O: (batch, n, m) -O = A * B -*/ -void forward_kernel(hipStream_t stream, - PerDeviceFFHandle const &handle, - float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int a_seq_length_dim, - int b_seq_length_dim, - int seq_length) { - checkCUDA(hipblasSetStream(handle.blas, stream)); - checkCUDNN(miopenSetStream(handle.dnn, stream)); - int lda = k; - int ldb = m; - int ldo = m; - long long int strideA = (long long int)n * k; - long long int strideB = (long long int)k * m; - long long int strideO = (long long int)n * m; - if ((a_seq_length_dim == 0) && (seq_length >= 0)) { - assert(seq_length <= k); - k = seq_length; - assert(b_seq_length_dim == 1); - } else if ((a_seq_length_dim == 1) && (seq_length >= 0)) { - assert(seq_length <= n); - n = seq_length; - } else { - // currently only support a_seq_length_dim = 0 or 1 - assert((a_seq_length_dim < 0) || (seq_length < 0)); - } - if ((b_seq_length_dim == 0) && (seq_length >= 0)) { - assert(seq_length <= m); - m = seq_length; - } else if ((b_seq_length_dim == 1) && (seq_length >= 0)) { - assert(a_seq_length_dim == 0); - assert(k == seq_length); - } else { - // currently only support a_seq_length_dim = 0 or 1 - assert((b_seq_length_dim < 0) || (seq_length < 0)); - } - - float alpha = 1.0f, beta = 0.0f; - checkCUDA(hipblasSgemmStridedBatched(handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_N, - m, - n, - k, - &alpha, - b_ptr, - ldb, - strideB, - a_ptr, - lda, - strideA, - &beta, - o_ptr, - ldo, - strideO, - batch)); -} - -/* -A, AGrad: (batch, n, k) -B, BGrad: (batch, k, m) -O, OGrad: (batch, n, m) -AGrad = OGrad * B^T -BGrad = A^T * OGrad -*/ -void backward_kernel(hipStream_t stream, - PerDeviceFFHandle const &handle, - float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch) { - checkCUDA(hipblasSetStream(handle.blas, stream)); - checkCUDNN(miopenSetStream(handle.dnn, stream)); - - int a_stride = n * k; - int b_stride = m * k; - int o_stride = n * m; - float alpha = 1.0f; - checkCUDA(hipblasSgemmStridedBatched(handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_N, - k, - n, - m, - &alpha, - b_ptr, - m, - b_stride, - o_grad_ptr, - m, - o_stride, - &alpha, - a_grad_ptr, - k, - a_stride, - batch)); - checkCUDA(hipblasSgemmStridedBatched(handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m, - k, - n, - &alpha, - o_grad_ptr, - m, - o_stride, - a_ptr, - k, - a_stride, - &alpha, - b_grad_ptr, - m, - b_stride, - batch)); -} - -} // namespace BatchMatmul -} // namespace Kernels -} // namespace FlexFlow diff --git a/lib/kernels/src/kernels/batch_matmul_kernels.cc b/lib/kernels/src/kernels/batch_matmul_kernels.cc deleted file mode 100644 index a6ac364900..0000000000 --- a/lib/kernels/src/kernels/batch_matmul_kernels.cc +++ /dev/null @@ -1,120 +0,0 @@ -#include "kernels/batch_matmul_kernels.h" -#include "kernels/batch_matmul_kernels_cpu.h" -#include "kernels/batch_matmul_kernels_gpu.h" -#include "utils/containers/require_same.h" - -namespace FlexFlow::Kernels::BatchMatmul { - -static std::tuple - get_params(TensorDims const &input_a_dims, - TensorDims const &input_b_dims, - TensorDims const &output_dims) { - positive_int m = require_same(dim_at_idx(input_b_dims, relative_ff_dim_t{-1}), - dim_at_idx(output_dims, relative_ff_dim_t{-1})); - - positive_int n = require_same(dim_at_idx(input_a_dims, relative_ff_dim_t{-2}), - dim_at_idx(output_dims, relative_ff_dim_t{-2})); - - positive_int k = - require_same(dim_at_idx(input_a_dims, relative_ff_dim_t{-1}), - dim_at_idx(input_b_dims, relative_ff_dim_t{-2})); - - TensorDims leading_dims = require_same( - slice_tensor_dims( - input_a_dims, relative_ff_dim_t{0}, relative_ff_dim_t{-2}), - slice_tensor_dims( - input_b_dims, relative_ff_dim_t{0}, relative_ff_dim_t{-2})); - - positive_int batch = get_num_elements(leading_dims); - - return {m, n, k, batch}; -} - -void forward_kernel(device_stream_t const &stream, - device_handle_t const &handle, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &input_a, - GenericTensorAccessorR const &input_b, - positive_int seq_length, - std::optional a_seq_length_dim, - std::optional b_seq_length_dim) { - - auto [m, n, k, batch] = - get_params(input_a.shape.dims, input_b.shape.dims, output.shape.dims); - - auto get_raw_seq_len = [](std::optional seq_len) -> int { - return transform(seq_len, - [](positive_int x) { return x.int_from_positive_int(); }) - .value_or(-1); - }; - - if (stream.is_gpu()) { - gpu_forward_kernel( - /*stream=*/stream.require_gpu(), - /*handle=*/handle.require_for_gpu(), - /*output_ptr=*/output.get_float_ptr(), - /*a_input_ptr=*/input_a.get_float_ptr(), - /*b_input_ptr=*/input_b.get_float_ptr(), - /*m=*/m.int_from_positive_int(), - /*n=*/n.int_from_positive_int(), - /*k=*/k.int_from_positive_int(), - /*batch=*/batch.int_from_positive_int(), - /*seq_length=*/seq_length.int_from_positive_int(), - /*a_seq_length_dim=*/get_raw_seq_len(a_seq_length_dim), - /*b_seq_length_dim=*/get_raw_seq_len(b_seq_length_dim)); - } else { - ASSERT(stream.is_cpu()); - ASSERT(handle.is_for_cpu()); - cpu_forward_kernel( - /*output=*/output, - /*input_a=*/input_a, - /*input_b=*/input_b, - /*seq_length=*/seq_length, - /*a_seq_length_dim=*/a_seq_length_dim, - /*b_seq_length_dim=*/b_seq_length_dim); - } -} - -void backward_kernel(device_stream_t const &stream, - device_handle_t const &handle, - GenericTensorAccessorR const &output, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorR const &input_a, - GenericTensorAccessorW const &input_a_grad, - GenericTensorAccessorR const &input_b, - GenericTensorAccessorW const &input_b_grad) { - TensorShape input_a_shape = require_same(input_a.shape, input_a_grad.shape); - TensorShape input_b_shape = require_same(input_b.shape, input_b_grad.shape); - TensorShape output_shape = require_same(output.shape, output_grad.shape); - - auto [m, n, k, batch] = - get_params(input_a_shape.dims, input_b_shape.dims, output_shape.dims); - - if (stream.is_gpu()) { - gpu_backward_kernel( - /*stream=*/stream.require_gpu(), - /*handle=*/handle.require_for_gpu(), - /*output_ptr=*/output.get_float_ptr(), - /*output_grad_ptr=*/output_grad.get_float_ptr(), - /*input_a_ptr=*/input_a.get_float_ptr(), - /*input_a_grad_ptr=*/input_a_grad.get_float_ptr(), - /*input_b_ptr=*/input_b.get_float_ptr(), - /*input_b_grad_ptr=*/input_b_grad.get_float_ptr(), - /*m=*/m.int_from_positive_int(), - /*n=*/n.int_from_positive_int(), - /*k=*/k.int_from_positive_int(), - /*batch=*/batch.int_from_positive_int()); - } else { - ASSERT(stream.is_cpu()); - ASSERT(handle.is_for_cpu()); - cpu_backward_kernel( - /*output=*/output, - /*output_grad=*/output_grad, - /*input_a=*/input_a, - /*input_a_grad=*/input_a_grad, - /*input_b=*/input_b, - /*input_b_grad=*/input_b_grad); - } -} - -} // namespace FlexFlow::Kernels::BatchMatmul diff --git a/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc b/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc deleted file mode 100644 index 292841d19f..0000000000 --- a/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "kernels/batch_matmul_kernels_cpu.h" - -namespace FlexFlow::Kernels::BatchMatmul { - -void cpu_forward_kernel(GenericTensorAccessorW const &output, - GenericTensorAccessorR const &input_a, - GenericTensorAccessorR const &input_b, - positive_int seq_length, - std::optional a_seq_length_dim, - std::optional b_seq_length_dim) { - NOT_IMPLEMENTED(); -} - -void cpu_backward_kernel(GenericTensorAccessorR const &output, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorR const &input_a, - GenericTensorAccessorW const &input_a_grad, - GenericTensorAccessorR const &input_b, - GenericTensorAccessorW const &input_b_grad) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow::Kernels::BatchMatmul diff --git a/lib/kernels/test/src/test_batch_matmul_kernel.cc b/lib/kernels/test/src/test_batch_matmul_kernel.cc deleted file mode 100644 index 8a904b7a0d..0000000000 --- a/lib/kernels/test/src/test_batch_matmul_kernel.cc +++ /dev/null @@ -1,81 +0,0 @@ -#include "internal/test_utils.h" -#include "kernels/batch_matmul_kernels_gpu.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_CUDA_TEST_SUITE) { - TEST_CASE("Test BatchMatmul Kernel") { - positive_int m = 10_p; - positive_int n = 10_p; - positive_int k = 10_p; - positive_int batch = 5_p; - int a_seq_length_dim = -1; - int b_seq_length_dim = -1; - int seq_length = -1; - - ManagedFFStream managed_stream{}; - ManagedPerDeviceFFHandle managed_handle = initialize_single_gpu_handle( - /*workSpaceSize=*/1024 * 1024, - /*allowTensorOpMathConversion=*/true); - - Allocator allocator = create_local_cuda_memory_allocator(); - - TensorShape input_shape_a = TensorShape{ - TensorDims{FFOrdered{batch, k, m}}, - DataType::FLOAT, - }; - TensorShape input_shape_b = TensorShape{ - TensorDims{FFOrdered{batch, n, k}}, - DataType::FLOAT, - }; - TensorShape output_shape = TensorShape{ - TensorDims{FFOrdered{batch, n, m}}, - DataType::FLOAT, - }; - - GenericTensorAccessorW a_accessor = - create_random_filled_accessor_w(input_shape_a, allocator); - GenericTensorAccessorW b_accessor = - create_random_filled_accessor_w(input_shape_b, allocator); - GenericTensorAccessorW output_accessor = - create_random_filled_accessor_w(output_shape, allocator); - - SUBCASE("gpu_forward_kernel") { - Kernels::BatchMatmul::gpu_forward_kernel(managed_stream.raw_stream(), - managed_handle.raw_handle(), - output_accessor.get_float_ptr(), - a_accessor.get_float_ptr(), - b_accessor.get_float_ptr(), - m.int_from_positive_int(), - n.int_from_positive_int(), - k.int_from_positive_int(), - batch.int_from_positive_int(), - a_seq_length_dim, - b_seq_length_dim, - seq_length); - } - - SUBCASE("gpu_backward_kernel") { - GenericTensorAccessorW o_grad_accessor = - create_random_filled_accessor_w(output_shape, allocator); - GenericTensorAccessorW a_grad_accessor = - allocator.allocate_tensor(input_shape_a); - GenericTensorAccessorW b_grad_accessor = - allocator.allocate_tensor(input_shape_b); - - Kernels::BatchMatmul::gpu_backward_kernel(managed_stream.raw_stream(), - managed_handle.raw_handle(), - output_accessor.get_float_ptr(), - o_grad_accessor.get_float_ptr(), - a_accessor.get_float_ptr(), - a_grad_accessor.get_float_ptr(), - b_accessor.get_float_ptr(), - b_grad_accessor.get_float_ptr(), - m.int_from_positive_int(), - n.int_from_positive_int(), - k.int_from_positive_int(), - batch.int_from_positive_int()); - } - } -} diff --git a/lib/local-execution/src/local-execution/local_task_registry.cc b/lib/local-execution/src/local-execution/local_task_registry.cc index abf6595cf4..4c351b9f02 100644 --- a/lib/local-execution/src/local-execution/local_task_registry.cc +++ b/lib/local-execution/src/local-execution/local_task_registry.cc @@ -2,7 +2,6 @@ #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "task-spec/loss_functions.h" #include "task-spec/ops/impl/attention.h" -#include "task-spec/ops/impl/batch_matmul.h" #include "task-spec/ops/impl/batch_norm.h" #include "task-spec/ops/impl/broadcast.h" #include "task-spec/ops/impl/cast.h" @@ -37,7 +36,6 @@ std::optional get_init_task_impl_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { return op_attrs.visit>(overload{ - [](BatchMatmulAttrs const &) { return std::nullopt; }, [](BatchNormAttrs const &) { return get_batch_norm_init_task_impl(); }, [](BroadcastAttrs const &) { return std::nullopt; }, [](CastAttrs const &) { return std::nullopt; }, @@ -76,7 +74,6 @@ std::optional get_fwd_task_impl_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { return op_attrs.visit>(overload{ - [](BatchMatmulAttrs const &) { return get_batch_matmul_fwd_task_impl(); }, [](BatchNormAttrs const &) { return get_batch_norm_fwd_task_impl(); }, [](BroadcastAttrs const &) { return get_broadcast_fwd_task_impl(); }, [](CastAttrs const &) { return get_cast_fwd_task_impl(); }, @@ -115,7 +112,6 @@ std::optional get_bwd_task_impl_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { return op_attrs.visit>(overload{ - [](BatchMatmulAttrs const &) { return get_batch_matmul_bwd_task_impl(); }, [](BatchNormAttrs const &) { return get_batch_norm_bwd_task_impl(); }, [](BroadcastAttrs const &) { return get_broadcast_bwd_task_impl(); }, [](CastAttrs const &) { return get_cast_bwd_task_impl(); }, diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.toml index c8c646bd19..066c5ce394 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.toml +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.toml @@ -11,12 +11,11 @@ features = [ ] includes = [ - "op-attrs/ops/attention_attrs.dtg.h", - "op-attrs/ops/batch_matmul_attrs.dtg.h", - "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", "op-attrs/ops/broadcast_attrs.dtg.h", - "op-attrs/ops/cast_attrs.dtg.h", - "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", "op-attrs/ops/conv_2d_attrs.dtg.h", "op-attrs/ops/dropout_attrs.dtg.h", "op-attrs/ops/element_binary_attrs.dtg.h", @@ -39,10 +38,6 @@ includes = [ "op-attrs/ops/weight_attrs.dtg.h", ] -[[values]] -type = "::FlexFlow::BatchMatmulAttrs" -key = "batch_matmul" - [[values]] type = "::FlexFlow::BatchNormAttrs" key = "batch_norm" diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index 7799900709..5598a666d0 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H #include "op-attrs/ops/attention_attrs.dtg.h" -#include "op-attrs/ops/batch_matmul_attrs.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/ops/broadcast_attrs.dtg.h" #include "op-attrs/ops/cast_attrs.dtg.h" @@ -34,7 +33,6 @@ namespace FlexFlow { -OperatorType get_op_type(BatchMatmulAttrs const &); OperatorType get_op_type(BatchNormAttrs const &); OperatorType get_op_type(BroadcastAttrs const &); OperatorType get_op_type(CastAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/is_valid.h b/lib/op-attrs/include/op-attrs/is_valid.h index 2d91307e19..9c3d2d27f5 100644 --- a/lib/op-attrs/include/op-attrs/is_valid.h +++ b/lib/op-attrs/include/op-attrs/is_valid.h @@ -24,9 +24,6 @@ bool is_valid(T const &t, std::vector const &shapes) { bool is_valid_internal(MultiHeadAttentionAttrs const &, std::vector const &); -bool is_valid_internal(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &); bool is_valid_internal(ConcatAttrs const &, std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h deleted file mode 100644 index f17757ac85..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H - -#include "op-attrs/ops/batch_matmul_attrs.dtg.h" -#include "op-attrs/parallel_tensor_shape.dtg.h" -#include "op-attrs/tensor_shape.dtg.h" -#include - -namespace FlexFlow { - -bool is_valid(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); - -tl::expected - get_output_shape(BatchMatmulAttrs const &attrs, - TensorShape const &input_lhs, - TensorShape const &input_rhs); - -tl::expected - get_output_shape(BatchMatmulAttrs const &attrs, - ParallelTensorShape const &input_lhs, - ParallelTensorShape const &input_rhs); -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.dtg.toml deleted file mode 100644 index 7a82d89e8d..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.dtg.toml +++ /dev/null @@ -1,31 +0,0 @@ -namespace = "FlexFlow" -name = "BatchMatmulAttrs" -type = "struct" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", - "", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -[[fields]] -name = "a_seq_length_dim" -type = "std::optional<::FlexFlow::positive_int>" - -[[fields]] -name = "b_seq_length_dim" -type = "std::optional<::FlexFlow::positive_int>" diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml index 88a65f75c5..3de04b7308 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml @@ -11,13 +11,12 @@ features = [ ] includes = [ - "op-attrs/ops/attention_attrs.dtg.h", - "op-attrs/ops/batch_matmul_attrs.dtg.h", - "op-attrs/ops/batch_norm_attrs.dtg.h", - "op-attrs/ops/broadcast_attrs.dtg.h", - "op-attrs/ops/cast_attrs.dtg.h", - "op-attrs/ops/combine_attrs.dtg.h", - "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", "op-attrs/ops/conv_2d_attrs.dtg.h", "op-attrs/ops/dropout_attrs.dtg.h", "op-attrs/ops/element_binary_attrs.dtg.h", @@ -43,10 +42,6 @@ includes = [ "op-attrs/ops/weight_attrs.dtg.h", ] -[[values]] -type = "::FlexFlow::BatchMatmulAttrs" -key = "batch_matmul" - [[values]] type = "::FlexFlow::BatchNormAttrs" key = "batch_norm" diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc index eec9ae869c..7a21bc33b6 100644 --- a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -21,12 +21,6 @@ std::unordered_map get_incoming_tensor_roles(PCGOperatorAttrs const &pcg_op_attrs) { return pcg_op_attrs .visit>(overload{ - [](BatchMatmulAttrs const &) { - return std::unordered_map{ - {TensorSlotName::LHS_INPUT, IncomingTensorRole::INPUT}, - {TensorSlotName::RHS_INPUT, IncomingTensorRole::INPUT}, - }; - }, [](BatchNormAttrs const &attrs) { return get_batch_norm_incoming_tensor_roles(attrs); }, diff --git a/lib/op-attrs/src/op-attrs/get_op_type.cc b/lib/op-attrs/src/op-attrs/get_op_type.cc index c941098cb8..a93e08877e 100644 --- a/lib/op-attrs/src/op-attrs/get_op_type.cc +++ b/lib/op-attrs/src/op-attrs/get_op_type.cc @@ -2,9 +2,6 @@ namespace FlexFlow { -OperatorType get_op_type(BatchMatmulAttrs const &) { - return OperatorType::BATCHMATMUL; -} OperatorType get_op_type(BatchNormAttrs const &) { return OperatorType::BATCHNORM; } diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc deleted file mode 100644 index 8fb34dc191..0000000000 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ /dev/null @@ -1,181 +0,0 @@ -#include "op-attrs/ops/batch_matmul.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/tensor_dims.h" -#include "utils/exception.h" - -namespace FlexFlow { - -// bool BatchMatmulAttrs::is_valid( -// ParallelTensorShape const &lhs, -// ParallelTensorShape const &rhs) const { -// if (!lhs.is_valid() || !rhs.is_valid()) { -// return false; -// } -// if (lhs.num_dims() != rhs.num_dims()) { -// return false; -// } -// for (int i = lhs.num_dims() - 1; i >= 2; i--) { -// if (lhs.at(i) != rhs.at(i)) { -// return false; -// } -// } -// if (lhs.at(0) != rhs.at(1)) { -// return false; -// } -// -// return true; -// } - -bool is_valid(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} - -tl::expected - get_output_shape(BatchMatmulAttrs const &attrs, - TensorShape const &input_lhs, - TensorShape const &input_rhs) { - /** - * If input_lhs is a (b×n×m) tensor, - * input_rhs is a (b×m×p) tensor, - * out will be a (b×n×p) tensor. - * https://pytorch.org/docs/stable/generated/torch.bmm.html - */ - - if (get_num_dims(input_lhs.dims) != 3) { - return tl::unexpected( - fmt::format("LHS input has incorrect number of shard dims: {} != {}", - get_num_dims(input_lhs.dims), - 3)); - } - if (get_num_dims(input_rhs.dims) != 3) { - return tl::unexpected( - fmt::format("RHS input has incorrect number of shard dims: {} != {}", - get_num_dims(input_rhs.dims), - 3)); - } - if (input_lhs.data_type != input_rhs.data_type) { - return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", - input_lhs.data_type, - input_rhs.data_type)); - } - - positive_int lhs_b = dim_at_idx(input_lhs.dims, relative_ff_dim_t{0}); - positive_int n = dim_at_idx(input_lhs.dims, relative_ff_dim_t{1}); - positive_int lhs_m = dim_at_idx(input_lhs.dims, relative_ff_dim_t{2}); - - positive_int rhs_b = dim_at_idx(input_rhs.dims, relative_ff_dim_t{0}); - positive_int rhs_m = dim_at_idx(input_rhs.dims, relative_ff_dim_t{1}); - positive_int p = dim_at_idx(input_rhs.dims, relative_ff_dim_t{2}); - - if (lhs_b != rhs_b) { - return tl::unexpected( - fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); - } - if (lhs_m != rhs_m) { - return tl::unexpected( - fmt::format("RHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); - } - - return TensorShape{ - TensorDims{ - FFOrdered{ - lhs_b, - n, - p, - }, - }, - input_lhs.data_type, - }; -} - -tl::expected - get_output_shape(BatchMatmulAttrs const &attrs, - ParallelTensorShape const &input_lhs, - ParallelTensorShape const &input_rhs) { - if (num_shard_dims(input_lhs).value != 3) { - return tl::unexpected( - fmt::format("LHS input has incorrect number of shard dims: {} != {}", - num_shard_dims(input_lhs), - 3)); - } - if (num_shard_dims(input_rhs).value != 3) { - return tl::unexpected( - fmt::format("RHS input has incorrect number of shard dims: {} != {}", - num_shard_dims(input_rhs), - 3)); - } - if (input_lhs.data_type != input_rhs.data_type) { - return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", - input_lhs.data_type, - input_rhs.data_type)); - } - - assert(get_total_parallel_degree(input_lhs) == - get_total_parallel_degree(input_rhs)); - - ShardParallelDim lhs_b = shard_dim_at_idx(input_lhs, relative_ff_dim_t{0}); - ShardParallelDim n = shard_dim_at_idx(input_lhs, relative_ff_dim_t{1}); - ShardParallelDim lhs_m = shard_dim_at_idx(input_lhs, relative_ff_dim_t{2}); - - ShardParallelDim rhs_b = shard_dim_at_idx(input_rhs, relative_ff_dim_t{0}); - ShardParallelDim rhs_m = shard_dim_at_idx(input_rhs, relative_ff_dim_t{1}); - ShardParallelDim p = shard_dim_at_idx(input_rhs, relative_ff_dim_t{2}); - - if (lhs_b != rhs_b) { - return tl::unexpected( - fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); - } - - if (lhs_m != rhs_m) { - return tl::unexpected( - fmt::format("LHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); - } - - if (get_discard_copy_degree(input_lhs) != - get_sum_degree(input_rhs) * p.degree) { - return tl::unexpected(fmt::format("Unexpected number of replicas in LHS: " - "lhs.= ({}) != rhs.+ ({}) * rhs.p ({})", - get_discard_copy_degree(input_lhs), - get_sum_degree(input_rhs), - p.degree)); - } - - if (get_discard_copy_degree(input_rhs) != - get_sum_degree(input_lhs) * n.degree) { - return tl::unexpected(fmt::format("Unexpected number of replicas in RHS: " - "rhs.= ({}) != lhs.+ ({}) * lhs.n ({})", - get_discard_copy_degree(input_rhs), - get_sum_degree(input_lhs), - n.degree)); - } - - ShardParallelDim output_b = lhs_b; - ShardParallelDim output_n = n; - ShardParallelDim output_p = p; - - positive_int output_discard_copy_degree = 1_p; - positive_int output_sum_degree = - positive_int{get_total_parallel_degree(input_lhs) / - (output_b.degree * output_n.degree * output_p.degree)}; - - ParallelTensorShape result = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - output_b, - output_n, - output_p, - }, - ReplicaParallelDimSet{ - SumDegree{output_sum_degree}, - DiscardCopyDegree{output_discard_copy_degree}, - }, - }, - input_lhs.data_type, - }; - - return result; -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/shape_inference.cc b/lib/op-attrs/src/op-attrs/shape_inference.cc index a3f8066dee..84b990127e 100644 --- a/lib/op-attrs/src/op-attrs/shape_inference.cc +++ b/lib/op-attrs/src/op-attrs/shape_inference.cc @@ -1,6 +1,5 @@ #include "op-attrs/shape_inference.h" #include "op-attrs/ops/attention.h" -#include "op-attrs/ops/batch_matmul.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/cast.h" #include "op-attrs/ops/combine.h" @@ -64,21 +63,6 @@ std::unordered_map get_output_shapes( std::unordered_map const &input_shapes) { return op_attrs.visit>( overload{ - [&](BatchMatmulAttrs const &attrs) - -> std::unordered_map { - auto [lhs, rhs] = require_two_keys(input_shapes, - TensorSlotName::LHS_INPUT, - TensorSlotName::RHS_INPUT); - - return { - { - TensorSlotName::OUTPUT, - TensorShape{ - throw_if_unexpected(get_output_shape(attrs, lhs, rhs)), - }, - }, - }; - }, [&](BatchNormAttrs const &attrs) -> std::unordered_map { TensorShape input = @@ -294,7 +278,8 @@ std::unordered_map get_output_shapes( [&](auto const &attrs) -> std::unordered_map { NOT_IMPLEMENTED(); - }}); + }, + }); } std::unordered_map get_weight_shapes( @@ -302,14 +287,6 @@ std::unordered_map get_weight_shapes( std::unordered_map const &input_shapes) { return op_attrs.visit>( overload{ - [&](BatchMatmulAttrs const &attrs) - -> std::unordered_map { - require_two_keys(input_shapes, - TensorSlotName::LHS_INPUT, - TensorSlotName::RHS_INPUT); - - return {}; - }, [&](BatchNormAttrs const &attrs) -> std::unordered_map { TensorShape input = @@ -427,7 +404,8 @@ std::unordered_map get_weight_shapes( [&](auto const &attrs) -> std::unordered_map { NOT_IMPLEMENTED(); - }}); + }, + }); } std::unordered_map get_output_shapes( @@ -436,19 +414,6 @@ std::unordered_map get_output_shapes( &input_shapes) { return pcg_op_attrs .visit>(overload{ - [&](BatchMatmulAttrs const &attrs) - -> std::unordered_map { - auto [lhs, rhs] = require_two_keys(input_shapes, - TensorSlotName::LHS_INPUT, - TensorSlotName::RHS_INPUT); - - return { - { - TensorSlotName::OUTPUT, - throw_if_unexpected(get_output_shape(attrs, lhs, rhs)), - }, - }; - }, [&](BatchNormAttrs const &attrs) -> std::unordered_map { ParallelTensorShape input = @@ -707,7 +672,8 @@ std::unordered_map get_output_shapes( [&](auto const &attrs) -> std::unordered_map { NOT_IMPLEMENTED(); - }}); + }, + }); } std::unordered_map get_weight_shapes( @@ -716,14 +682,6 @@ std::unordered_map get_weight_shapes( &input_shapes) { return pcg_op_attrs .visit>(overload{ - [&](BatchMatmulAttrs const &attrs) - -> std::unordered_map { - require_two_keys(input_shapes, - TensorSlotName::LHS_INPUT, - TensorSlotName::RHS_INPUT); - - return {}; - }, [&](BatchNormAttrs const &attrs) -> std::unordered_map { ParallelTensorShape input = @@ -876,7 +834,8 @@ std::unordered_map get_weight_shapes( [&](auto const &attrs) -> std::unordered_map { NOT_IMPLEMENTED(); - }}); + }, + }); } } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc deleted file mode 100644 index 1044c379f0..0000000000 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc +++ /dev/null @@ -1,304 +0,0 @@ -#include "op-attrs/ops/batch_matmul.h" -#include "test/utils/doctest/fmt/expected.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { - positive_int b = 4_p; - positive_int m = 6_p; - positive_int n = 8_p; - positive_int p = 10_p; - - BatchMatmulAttrs attrs = BatchMatmulAttrs{ - /*a_seq_length_dim=*/1_p, // TODO figure out if these arguments are - // still relevant - /*b_seq_length_dim=*/1_p, - }; - - TensorShape input_lhs_shape = TensorShape{ - TensorDims{ - FFOrdered{ - b, - n, - m, - }, - }, - DataType::FLOAT, - }; - - SUBCASE("valid") { - TensorShape input_rhs_shape = TensorShape{ - TensorDims{ - FFOrdered{ - b, - m, - p, - }, - }, - DataType::FLOAT, - }; - - tl::expected result = - get_output_shape(attrs, input_lhs_shape, input_rhs_shape); - - tl::expected correct_output_shape = TensorShape{ - TensorDims{ - FFOrdered{ - b, - n, - p, - }, - }, - DataType::FLOAT, - }; - - CHECK(result == correct_output_shape); - } - - SUBCASE("mismatched b") { - TensorShape input_rhs_shape = TensorShape{ - TensorDims{ - FFOrdered{ - b + 1_p, - m, - p, - }, - }, - DataType::FLOAT, - }; - - tl::expected result = - get_output_shape(attrs, input_lhs_shape, input_rhs_shape); - - CHECK(!result.has_value()); - } - - SUBCASE("mismatched m") { - TensorShape input_rhs_shape = TensorShape{ - TensorDims{ - FFOrdered{ - b, - m + 1_p, - p, - }, - }, - DataType::FLOAT, - }; - - tl::expected result = - get_output_shape(attrs, input_lhs_shape, input_rhs_shape); - - CHECK(!result.has_value()); - } - } - - TEST_CASE("get_output_shape(BatchMatmulAttrs, ParallelTensorShape)") { - positive_int b = 2_p * 2_p; - positive_int o_b = 2_p; - positive_int m = 3_p * 3_p; - positive_int o_m = 3_p; - positive_int n = 5_p * 5_p; - positive_int o_n = 5_p; - positive_int p = 7_p * 7_p; - positive_int o_p = 7_p; - positive_int o_sum = 11_p; - - BatchMatmulAttrs attrs = BatchMatmulAttrs{ - /*a_seq_length_dim=*/0_p, // TODO figure out if these arguments are - // still relevant - /*b_seq_length_dim=*/0_p, - }; - - auto make_lhs = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - positive_int o_b, - positive_int o_n, - positive_int o_m) { - return ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{b, o_b}, - ShardParallelDim{n, o_n}, - ShardParallelDim{m, o_m}, - }, - ReplicaParallelDimSet{ - o_sum, - o_eq, - }, - }, - DataType::FLOAT, - }; - }; - - auto make_rhs = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - positive_int o_b, - positive_int o_m, - positive_int o_p) { - return ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{b, o_b}, - ShardParallelDim{m, o_m}, - ShardParallelDim{p, o_p}, - }, - ReplicaParallelDimSet{ - o_sum, - o_eq, - }, - }, - DataType::FLOAT, - }; - }; - - auto make_output = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - positive_int o_b, - positive_int o_n, - positive_int o_p) { - return ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{b, o_b}, - ShardParallelDim{n, o_n}, - ShardParallelDim{p, o_p}, - }, - ReplicaParallelDimSet{ - o_sum, - o_eq, - }, - }, - DataType::FLOAT, - }; - }; - - SUBCASE("data parallel") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{1_p}, DiscardCopyDegree{1_p}, o_b, 1_p, 1_p), - make_rhs(SumDegree{1_p}, DiscardCopyDegree{1_p}, o_b, 1_p, 1_p)); - tl::expected correct = - make_output(SumDegree{1_p}, DiscardCopyDegree{1_p}, o_b, 1_p, 1_p); - - CHECK(result == correct); - } - - SUBCASE("n parallel") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{1_p}, DiscardCopyDegree{1_p}, 1_p, o_n, 1_p), - make_rhs(SumDegree{1_p}, DiscardCopyDegree{o_n}, 1_p, 1_p, 1_p)); - tl::expected correct = - make_output(SumDegree{1_p}, DiscardCopyDegree{1_p}, 1_p, o_n, 1_p); - - CHECK(result == correct); - } - - SUBCASE("p parallel") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{1_p}, DiscardCopyDegree{o_p}, 1_p, 1_p, 1_p), - make_rhs(SumDegree{1_p}, DiscardCopyDegree{1_p}, 1_p, 1_p, o_p)); - tl::expected correct = - make_output(SumDegree{1_p}, DiscardCopyDegree{1_p}, 1_p, 1_p, o_p); - - CHECK(result == correct); - } - - SUBCASE("reduction parallel") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{1_p}, DiscardCopyDegree{1_p}, 1_p, 1_p, o_m), - make_rhs(SumDegree{1_p}, DiscardCopyDegree{1_p}, 1_p, o_m, 1_p)); - tl::expected correct = - make_output(SumDegree{o_m}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p); - - CHECK(result == correct); - } - - SUBCASE("propagate reduction lhs") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{o_sum}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p), - make_rhs(SumDegree{1_p}, DiscardCopyDegree{o_sum}, 1_p, 1_p, 1_p)); - tl::expected correct = - make_output(SumDegree{o_sum}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p); - - CHECK(result == correct); - } - - SUBCASE("propagate reduction rhs") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{1_p}, DiscardCopyDegree{o_sum}, 1_p, 1_p, 1_p), - make_rhs(SumDegree{o_sum}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p)); - tl::expected correct = - make_output(SumDegree{o_sum}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p); - - CHECK(result == correct); - } - - SUBCASE("reduction lhs & reduction rhs") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum}, 1_p, 1_p, 1_p), - make_rhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum}, 1_p, 1_p, 1_p)); - tl::expected correct = make_output( - SumDegree{o_sum * o_sum}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p); - - CHECK(result == correct); - } - - SUBCASE("reduction lhs & rhs (invalid)") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{o_sum}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p), - make_rhs(SumDegree{o_sum}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p)); - - CHECK_MESSAGE( - !result.has_value(), "Unexpected successful value: ", result); - } - - SUBCASE("reduction lhs & n") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{o_sum}, DiscardCopyDegree{1_p}, 1_p, o_n, 1_p), - make_rhs( - SumDegree{1_p}, DiscardCopyDegree{o_sum * o_n}, 1_p, 1_p, 1_p)); - tl::expected correct = - make_output(SumDegree{o_sum}, DiscardCopyDegree{1_p}, 1_p, o_n, 1_p); - - CHECK(result == correct); - } - - SUBCASE("reduction lhs & reduction rhs & n") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum}, 1_p, o_n, 1_p), - make_rhs( - SumDegree{o_sum}, DiscardCopyDegree{o_sum * o_n}, 1_p, 1_p, 1_p)); - tl::expected correct = make_output( - SumDegree{o_sum * o_sum}, DiscardCopyDegree{1_p}, 1_p, o_n, 1_p); - - CHECK(result == correct); - } - - SUBCASE("reduction lhs & reduction rhs & n & m") { - tl::expected result = get_output_shape( - attrs, - make_lhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum}, 1_p, o_n, o_m), - make_rhs( - SumDegree{o_sum}, DiscardCopyDegree{o_sum * o_n}, 1_p, o_m, 1_p)); - tl::expected correct = - make_output(SumDegree{o_sum * o_sum * o_m}, - DiscardCopyDegree{1_p}, - 1_p, - o_n, - 1_p); - - CHECK(result == correct); - } - } -} diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index b0adec3ab1..88df8128e7 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -20,11 +20,6 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t const &rhs, std::optional const &name = std::nullopt); - parallel_tensor_guid_t - batch_matmul(parallel_tensor_guid_t const &a, - parallel_tensor_guid_t const &b, - std::optional const &name = std::nullopt); - parallel_tensor_guid_t cast(parallel_tensor_guid_t const &input, DataType result_type, diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index d18fc17621..a8c09dde9a 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -2,7 +2,6 @@ #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/ops/attention.h" #include "op-attrs/ops/attention_attrs.dtg.h" -#include "op-attrs/ops/batch_matmul_attrs.dtg.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/ops/cast_attrs.dtg.h" @@ -117,34 +116,6 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::add( TensorSlotName::OUTPUT); } -parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_matmul( - parallel_tensor_guid_t const &a, - parallel_tensor_guid_t const &b, - std::optional const &maybe_name) { - - BatchMatmulAttrs attrs = BatchMatmulAttrs{ - /*a_seq_length_dim=*/std::nullopt, - /*b_seq_length_dim=*/std::nullopt, - }; - - std::string name = - maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); - - ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - - return require_only_key(this->add_layer(layer, - {{ - TensorSlotName::LHS_INPUT, - a, - }, - { - TensorSlotName::RHS_INPUT, - b, - }}, - {}), - TensorSlotName::OUTPUT); -} - parallel_tensor_guid_t ParallelComputationGraphBuilder::cast( parallel_tensor_guid_t const &input, DataType result_type, diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index fd314ebaea..b9721da7f7 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -95,78 +95,6 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("ParallelComputationGraphBuilder::batch_matmul") { - ParallelComputationGraphBuilder b; - - TensorShape a_shape = TensorShape{ - TensorDims{ - FFOrdered{ - 4_p, - 10_p, - 15_p, - }, - }, - DataType::FLOAT, - }; - - TensorShape b_shape = TensorShape{ - TensorDims{ - FFOrdered{ - 4_p, - 15_p, - 10_p, - }, - }, - DataType::FLOAT, - }; - - parallel_tensor_guid_t a_tensor = b.create_input_tensor(a_shape); - parallel_tensor_guid_t b_tensor = b.create_input_tensor(b_shape); - - parallel_tensor_guid_t out = b.batch_matmul(a_tensor, b_tensor); - parallel_layer_guid_t layer = get_source_layer(out); - - SUBCASE("incoming") { - std::unordered_map result = - get_incoming_tensors(b.pcg, layer); - std::unordered_map correct = { - { - TensorSlotName::LHS_INPUT, - a_tensor, - }, - { - TensorSlotName::RHS_INPUT, - b_tensor, - }, - }; - - CHECK(result == correct); - } - - SUBCASE("outputs") { - std::unordered_map result = - get_outgoing_tensors(b.pcg, layer); - std::unordered_map correct = { - { - TensorSlotName::OUTPUT, - out, - }, - }; - - CHECK(result == correct); - } - - SUBCASE("op attrs") { - PCGOperatorAttrs result = get_parallel_layer_attrs(b.pcg, layer).op_attrs; - PCGOperatorAttrs correct = PCGOperatorAttrs{ - BatchMatmulAttrs{/*a_seq_length_dim=*/std::nullopt, - /*b_seq_length_dim=*/std::nullopt}, - }; - - CHECK(result == correct); - } - } - TEST_CASE("ParallelComputationGraphBuilder::cast") { ParallelComputationGraphBuilder b; diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index b1e5e07e28..9ba9bdf579 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -120,12 +120,6 @@ name = "BATCHNORM_FWD_TASK_ID" [[values]] name = "BATCHNORM_BWD_TASK_ID" -[[values]] -name = "BATCHMATMUL_FWD_TASK_ID" - -[[values]] -name = "BATCHMATMUL_BWD_TASK_ID" - [[values]] name = "LAYERNORM_INIT_TASK_ID" diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index e7a8948f8d..809f743856 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -66,7 +66,6 @@ Realm::Event register_all_tasks() { std::vector task_ids = { // Forward tasks - task_id_t::BATCHMATMUL_FWD_TASK_ID, task_id_t::BATCHNORM_FWD_TASK_ID, task_id_t::BROADCAST_FWD_TASK_ID, task_id_t::CAST_FWD_TASK_ID, @@ -95,7 +94,6 @@ Realm::Event register_all_tasks() { task_id_t::TRANSPOSE_FWD_TASK_ID, // Backward tasks - task_id_t::BATCHMATMUL_BWD_TASK_ID, task_id_t::BATCHNORM_BWD_TASK_ID, task_id_t::BROADCAST_BWD_TASK_ID, task_id_t::CAST_BWD_TASK_ID, diff --git a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index dd4b0a66ca..a1ad841a2a 100644 --- a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -32,7 +32,6 @@ std::optional get_init_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { return op_attrs.visit>(overload{ - [](BatchMatmulAttrs const &) { return std::nullopt; }, [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_INIT_TASK_ID; }, [](BroadcastAttrs const &) { return std::nullopt; }, [](CastAttrs const &) { return std::nullopt; }, @@ -81,9 +80,6 @@ std::optional get_fwd_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { return op_attrs.visit>(overload{ - [](BatchMatmulAttrs const &) { - return task_id_t::BATCHMATMUL_FWD_TASK_ID; - }, [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_FWD_TASK_ID; }, [](BroadcastAttrs const &) { return task_id_t::BROADCAST_FWD_TASK_ID; }, [](CastAttrs const &) { return task_id_t::CAST_FWD_TASK_ID; }, @@ -132,9 +128,6 @@ std::optional get_bwd_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { return op_attrs.visit>(overload{ - [](BatchMatmulAttrs const &) { - return task_id_t::BATCHMATMUL_BWD_TASK_ID; - }, [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_BWD_TASK_ID; }, [](BroadcastAttrs const &) { return task_id_t::BROADCAST_BWD_TASK_ID; }, [](CastAttrs const &) { return task_id_t::CAST_BWD_TASK_ID; }, diff --git a/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h index a5f0cc6fdc..76a5b2f168 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h +++ b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h @@ -10,8 +10,6 @@ namespace FlexFlow { std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); -std::optional get_attribute(BatchMatmulAttrs const &, - OperatorAttributeKey); std::optional get_attribute(BatchNormAttrs const &, OperatorAttributeKey); std::optional get_attribute(BroadcastAttrs const &, diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 9183278fe1..6f15682192 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -4,16 +4,6 @@ namespace FlexFlow { -std::optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey key) { - switch (key) { - case OperatorAttributeKey::OP_TYPE: - return OperatorAttributeValue{get_op_type(p)}; - default: - return std::nullopt; - } -} - std::optional get_attribute(BatchNormAttrs const &p, OperatorAttributeKey key) { switch (key) { diff --git a/lib/task-spec/include/task-spec/ops/impl/batch_matmul.h b/lib/task-spec/include/task-spec/ops/impl/batch_matmul.h deleted file mode 100644 index 6184e194f2..0000000000 --- a/lib/task-spec/include/task-spec/ops/impl/batch_matmul.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_BATCH_MATMUL_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_BATCH_MATMUL_H - -#include "op-attrs/ops/batch_matmul_attrs.dtg.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -TaskImplFunction get_batch_matmul_fwd_task_impl(); -TaskImplFunction get_batch_matmul_bwd_task_impl(); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/src/task-spec/ops/impl/batch_matmul.cc b/lib/task-spec/src/task-spec/ops/impl/batch_matmul.cc deleted file mode 100644 index 43bc185b0d..0000000000 --- a/lib/task-spec/src/task-spec/ops/impl/batch_matmul.cc +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "task-spec/ops/impl/batch_matmul.h" -#include "kernels/batch_matmul_kernels.h" -#include "op-attrs/ops/batch_matmul.h" -#include "task-spec/profiling.h" -#include "utils/containers/transform.h" -#include "utils/nonnegative_int/nonnegative_range.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::BatchMatmul; - -static std::optional - forward_task_impl(TaskArgumentAccessor const &acc) { - auto a_input = acc.get_tensor(TensorSlotName::LHS_INPUT); - auto b_input = acc.get_tensor(TensorSlotName::RHS_INPUT); - auto output = acc.get_tensor(TensorSlotName::OUTPUT); - BatchMatmulAttrs attrs = acc.get_op_attrs().require_batch_matmul(); - device_handle_t handle = acc.get_ff_handle(); - - ProfilingSettings profiling = acc.get_profiling_settings(); - FFIterationConfig iter_config = acc.get_iteration_config(); - DeviceType kernel_device_type = acc.get_kernel_device_type(); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[BatchMatmul] forward_time = {:.2lf}ms\n", - handle, - output, - a_input, - b_input, - iter_config.seq_length, - attrs.a_seq_length_dim, - attrs.b_seq_length_dim); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - FFIterationConfig iter_config = acc.get_iteration_config(); - ProfilingSettings profiling = acc.get_profiling_settings(); - device_handle_t handle = acc.get_ff_handle(); - DeviceType kernel_device_type = acc.get_kernel_device_type(); - - auto output = acc.get_tensor(TensorSlotName::OUTPUT); - auto output_grad = - acc.get_tensor_grad(TensorSlotName::OUTPUT); - ASSERT(output.shape == output_grad.shape); - - auto a_input = acc.get_tensor(TensorSlotName::LHS_INPUT); - auto a_input_grad = - acc.get_tensor_grad(TensorSlotName::LHS_INPUT); - ASSERT(a_input.shape == a_input_grad.shape); - - auto b_input = acc.get_tensor(TensorSlotName::RHS_INPUT); - auto b_input_grad = - acc.get_tensor_grad(TensorSlotName::RHS_INPUT); - ASSERT(b_input.shape == b_input_grad.shape); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[BatchMatmul] backward_time = {:.2lf}ms\n", - handle, - output, - output_grad, - a_input, - a_input_grad, - b_input, - b_input_grad); -} - -TaskImplFunction get_batch_matmul_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction get_batch_matmul_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -}; // namespace FlexFlow From 1f4dafa57a9fb08ff6908d7230599c407997c657 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 25 Mar 2026 10:59:46 -0700 Subject: [PATCH 2/2] Remove FFIterationConfig. --- .../computation_graph_instance.h | 6 --- .../cost_estimator/local_cost_estimator.h | 4 -- .../local_task_argument_accessor.h | 3 -- .../per_device_op_state_initialization.h | 3 -- .../include/local-execution/task_execution.h | 2 - .../computation_graph_instance.cc | 12 ------ .../cost_estimator/local_cost_estimator.cc | 9 +---- .../local_task_argument_accessor.cc | 8 +--- .../per_device_op_state_initialization.cc | 4 -- .../src/local-execution/task_execution.cc | 4 -- .../local-execution/local_cost_estimator.cc | 2 - .../local_task_argument_accessor.cc | 1 - .../src/local-execution/loss_functions.cc | 2 - .../test/src/local-execution/test_e2e.cc | 4 -- ...buted_per_device_op_state_initialization.h | 2 - .../include/realm-execution/pcg_instance.h | 16 +++----- .../realm-execution/tasks/impl/op_task.h | 2 - .../tasks/impl/op_task_args.dtg.toml | 5 --- .../impl/per_device_op_state_init_task.h | 2 - ...er_device_op_state_init_task_args.dtg.toml | 5 --- .../impl/serializable_op_task_args.dtg.toml | 5 --- ...er_device_op_state_init_task_args.dtg.toml | 5 --- ...uted_per_device_op_state_initialization.cc | 2 - .../src/realm-execution/pcg_instance.cc | 38 ++++++------------- .../src/realm-execution/tasks/impl/op_task.cc | 3 -- .../impl/per_device_op_state_init_task.cc | 3 -- .../tasks/impl/serializable_op_task_args.cc | 2 - ...able_per_device_op_state_init_task_args.cc | 2 - .../test/src/realm-execution/test_e2e.cc | 12 ++---- .../task-spec/ff_iteration_config.dtg.toml | 19 ---------- .../task_argument_accessor/index.dox | 2 +- .../itask_argument_accessor.h | 2 - .../task_argument_accessor.h | 2 - .../task_argument_accessor.cc | 4 -- 34 files changed, 24 insertions(+), 173 deletions(-) delete mode 100644 lib/task-spec/include/task-spec/ff_iteration_config.dtg.toml diff --git a/lib/local-execution/include/local-execution/computation_graph_instance/computation_graph_instance.h b/lib/local-execution/include/local-execution/computation_graph_instance/computation_graph_instance.h index c43001397b..a4ded5edaf 100644 --- a/lib/local-execution/include/local-execution/computation_graph_instance/computation_graph_instance.h +++ b/lib/local-execution/include/local-execution/computation_graph_instance/computation_graph_instance.h @@ -13,7 +13,6 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" #include "utils/units/milliseconds_t.h" #include #include @@ -50,7 +49,6 @@ ComputationGraphInstance create_computation_graph_instance( Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, device_id_t device_idx); std::unordered_map> @@ -58,27 +56,23 @@ std::unordered_map> ComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx); std::unordered_map> perform_forward_pass_for_computation_graph_instance( ComputationGraphInstance const &instance, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx); std::unordered_map> perform_backward_pass_for_computation_graph_instance( ComputationGraphInstance const &instance, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx); void perform_update_pass_for_computation_graph_instance( ComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/cost_estimator/local_cost_estimator.h b/lib/local-execution/include/local-execution/cost_estimator/local_cost_estimator.h index 653067da8a..fdff1153ff 100644 --- a/lib/local-execution/include/local-execution/cost_estimator/local_cost_estimator.h +++ b/lib/local-execution/include/local-execution/cost_estimator/local_cost_estimator.h @@ -7,7 +7,6 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_interconnect_specification.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" namespace FlexFlow { @@ -17,7 +16,6 @@ struct LocalCostEstimator : public ICostEstimator { Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, device_id_t device_idx); LocalCostEstimator(LocalCostEstimator const &) = delete; @@ -33,7 +31,6 @@ struct LocalCostEstimator : public ICostEstimator { Allocator allocator; ProfilingSettings profiling_settings; device_handle_t device_handle; - FFIterationConfig iteration_config; device_id_t device_idx; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalCostEstimator); @@ -43,7 +40,6 @@ CostEstimator get_local_cost_estimator( Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, device_id_t device_idx); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/local_task_argument_accessor.h b/lib/local-execution/include/local-execution/local_task_argument_accessor.h index 638bea247e..12eab4a76d 100644 --- a/lib/local-execution/include/local-execution/local_task_argument_accessor.h +++ b/lib/local-execution/include/local-execution/local_task_argument_accessor.h @@ -20,7 +20,6 @@ struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor { std::optional const &op_attrs, std::optional const &loss_attrs, std::optional const &per_device_op_state, - FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, device_id_t device_idx); @@ -38,7 +37,6 @@ struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor { PCGOperatorAttrs get_op_attrs() const override; LossAttrs get_loss_attrs() const override; PerDeviceOpState get_per_device_op_state() const override; - FFIterationConfig get_iteration_config() const override; OptimizerAttrs get_optimizer_attrs() const override; Allocator get_allocator() const override; @@ -56,7 +54,6 @@ struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor { std::optional op_attrs; std::optional loss_attrs; std::optional per_device_op_state; - FFIterationConfig iteration_config; std::optional optimizer_attrs; device_id_t device_idx; diff --git a/lib/local-execution/include/local-execution/per_device_op_state_initialization.h b/lib/local-execution/include/local-execution/per_device_op_state_initialization.h index abf24cdfd1..ff0ffd73f8 100644 --- a/lib/local-execution/include/local-execution/per_device_op_state_initialization.h +++ b/lib/local-execution/include/local-execution/per_device_op_state_initialization.h @@ -7,7 +7,6 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" namespace FlexFlow { @@ -18,7 +17,6 @@ DynamicNodeInvocation Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, device_id_t device_idx); @@ -30,7 +28,6 @@ DynamicOpenDataflowGraph perform_per_device_op_state_initialization( Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, device_id_t device_idx); diff --git a/lib/local-execution/include/local-execution/task_execution.h b/lib/local-execution/include/local-execution/task_execution.h index 61a57dbfa0..3bb0c6b92b 100644 --- a/lib/local-execution/include/local-execution/task_execution.h +++ b/lib/local-execution/include/local-execution/task_execution.h @@ -15,7 +15,6 @@ TaskArgumentAccessor make_task_argument_accessor_for_invocation( ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, std::optional const &per_device_op_state, - FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, device_id_t device_idx); @@ -25,7 +24,6 @@ std::optional execute_dynamic_node_invocation( ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, std::optional const &per_device_op_state, - FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, device_id_t device_idx); diff --git a/lib/local-execution/src/local-execution/computation_graph_instance/computation_graph_instance.cc b/lib/local-execution/src/local-execution/computation_graph_instance/computation_graph_instance.cc index 961dfae3f1..998920589c 100644 --- a/lib/local-execution/src/local-execution/computation_graph_instance/computation_graph_instance.cc +++ b/lib/local-execution/src/local-execution/computation_graph_instance/computation_graph_instance.cc @@ -67,7 +67,6 @@ ComputationGraphInstance create_computation_graph_instance( Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, device_id_t device_idx) { DynamicOpenDataflowGraph dg = make_dynamic_open_dataflow_graph_from_cg(cg); dg = perform_pass_expansion(dg); @@ -96,7 +95,6 @@ ComputationGraphInstance create_computation_graph_instance( allocator, profiling_settings, device_handle, - iteration_config, optimizer_attrs, device_idx); @@ -118,7 +116,6 @@ static std::unordered_map> OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx) { return unordered_map_from_pairs( transform(invocations, [&](DynamicNodeInvocation const &invocation) { @@ -133,7 +130,6 @@ static std::unordered_map> return get_per_device_op_state_from_device_specific( op_state, device_idx); }), - /*iteration_config=*/iteration_config, /*optimizer_attrs=*/optimizer_attrs, /*device_idx=*/device_idx); return std::pair{invocation.node_attrs.layer_guid, timing}; @@ -145,7 +141,6 @@ std::unordered_map> ComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx) { std::vector execution_order = instance.get_execution_order(); @@ -156,7 +151,6 @@ std::unordered_map> /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*ff_handle=*/ff_handle, - /*iteration_config=*/iteration_config, /*device_idx=*/device_idx); instance.update_optimizer_attrs_for_next_iter(); return result; @@ -167,7 +161,6 @@ std::unordered_map> ComputationGraphInstance const &instance, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx) { std::vector execution_order = filter(instance.get_execution_order(), @@ -183,7 +176,6 @@ std::unordered_map> /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*ff_handle=*/ff_handle, - /*iteration_config=*/iteration_config, /*device_idx=*/device_idx); } @@ -192,7 +184,6 @@ std::unordered_map> ComputationGraphInstance const &instance, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx) { std::vector execution_order = filter(instance.get_execution_order(), @@ -208,7 +199,6 @@ std::unordered_map> /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*ff_handle=*/ff_handle, - /*iteration_config=*/iteration_config, /*device_idx=*/device_idx); } @@ -216,7 +206,6 @@ void perform_update_pass_for_computation_graph_instance( ComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, - FFIterationConfig iteration_config, device_id_t device_idx) { std::vector execution_order = filter(instance.get_execution_order(), @@ -232,7 +221,6 @@ void perform_update_pass_for_computation_graph_instance( /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, /*ff_handle=*/ff_handle, - /*iteration_config=*/iteration_config, /*device_idx=*/device_idx); instance.update_optimizer_attrs_for_next_iter(); } diff --git a/lib/local-execution/src/local-execution/cost_estimator/local_cost_estimator.cc b/lib/local-execution/src/local-execution/cost_estimator/local_cost_estimator.cc index 89010c543e..b966ff9423 100644 --- a/lib/local-execution/src/local-execution/cost_estimator/local_cost_estimator.cc +++ b/lib/local-execution/src/local-execution/cost_estimator/local_cost_estimator.cc @@ -32,12 +32,10 @@ LocalCostEstimator::LocalCostEstimator( Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, device_id_t device_idx) : interconnect_specification(interconnect_specification), allocator(allocator), profiling_settings(profiling_settings), - device_handle(device_handle), iteration_config(iteration_config), - device_idx(device_idx) {} + device_handle(device_handle), device_idx(device_idx) {} static ComputationGraph computation_graph_for_local_cost_estimation( ComputationGraphOpAttrs const &op, @@ -127,7 +125,6 @@ OpCostMetrics LocalCostEstimator::estimate_cost( /*allocator=*/allocator, /*profiling_settings=*/this->profiling_settings, /*device_handle=*/this->device_handle, - /*iteration_config=*/this->iteration_config, /*device_idx=*/this->device_idx); // execute layer @@ -138,7 +135,6 @@ OpCostMetrics LocalCostEstimator::estimate_cost( instance, this->profiling_settings, this->device_handle, - this->iteration_config, this->device_idx); milliseconds_t fwd = fwd_timing.at(operator_layer_guid).value(); std::unordered_map> @@ -146,7 +142,6 @@ OpCostMetrics LocalCostEstimator::estimate_cost( instance, this->profiling_settings, this->device_handle, - this->iteration_config, this->device_idx); milliseconds_t bwd = bwd_timing.at(operator_layer_guid).value(); @@ -187,13 +182,11 @@ CostEstimator get_local_cost_estimator( Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, device_id_t device_idx) { return CostEstimator::create(interconnect_specification, allocator, profiling_settings, device_handle, - iteration_config, device_idx); } diff --git a/lib/local-execution/src/local-execution/local_task_argument_accessor.cc b/lib/local-execution/src/local-execution/local_task_argument_accessor.cc index 796d122a23..b8feca720e 100644 --- a/lib/local-execution/src/local-execution/local_task_argument_accessor.cc +++ b/lib/local-execution/src/local-execution/local_task_argument_accessor.cc @@ -17,15 +17,13 @@ LocalTaskArgumentAccessor::LocalTaskArgumentAccessor( std::optional const &op_attrs, std::optional const &loss_attrs, std::optional const &per_device_op_state, - FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, device_id_t device_idx) : allocator(allocator), tensor_slots_backing(tensor_slots_backing), profiling_settings(profiling_settings), ff_handle(ff_handle), op_attrs(op_attrs), loss_attrs(loss_attrs), per_device_op_state(per_device_op_state), - iteration_config(iteration_config), optimizer_attrs(optimizer_attrs), - device_idx(device_idx) {} + optimizer_attrs(optimizer_attrs), device_idx(device_idx) {} TensorShape LocalTaskArgumentAccessor::get_tensor_shape(TensorSlotName slot) const { @@ -101,10 +99,6 @@ PerDeviceOpState LocalTaskArgumentAccessor::get_per_device_op_state() const { return assert_unwrap(this->per_device_op_state); } -FFIterationConfig LocalTaskArgumentAccessor::get_iteration_config() const { - return this->iteration_config; -} - OptimizerAttrs LocalTaskArgumentAccessor::get_optimizer_attrs() const { return assert_unwrap(this->optimizer_attrs); } diff --git a/lib/local-execution/src/local-execution/per_device_op_state_initialization.cc b/lib/local-execution/src/local-execution/per_device_op_state_initialization.cc index 2cd53b428b..bf72843daf 100644 --- a/lib/local-execution/src/local-execution/per_device_op_state_initialization.cc +++ b/lib/local-execution/src/local-execution/per_device_op_state_initialization.cc @@ -23,7 +23,6 @@ DynamicNodeInvocation Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, device_id_t device_idx) { if (!i.node_attrs.op_attrs.has_value() || @@ -44,7 +43,6 @@ DynamicNodeInvocation /*profiling_settings=*/profiling_settings, /*ff_handle=*/device_handle, /*per_device_op_state=*/std::nullopt, - /*iteration_config=*/iteration_config, /*optimizer_attrs=*/optimizer_attrs, /*device_idx=*/device_idx); @@ -62,7 +60,6 @@ DynamicOpenDataflowGraph perform_per_device_op_state_initialization( Allocator &allocator, ProfilingSettings const &profiling_settings, device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, device_id_t device_idx) { @@ -73,7 +70,6 @@ DynamicOpenDataflowGraph perform_per_device_op_state_initialization( allocator, profiling_settings, device_handle, - iteration_config, optimizer_attrs, device_idx); }); diff --git a/lib/local-execution/src/local-execution/task_execution.cc b/lib/local-execution/src/local-execution/task_execution.cc index c96c834d4a..0b1bc3513a 100644 --- a/lib/local-execution/src/local-execution/task_execution.cc +++ b/lib/local-execution/src/local-execution/task_execution.cc @@ -48,7 +48,6 @@ TaskArgumentAccessor make_task_argument_accessor_for_invocation( ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, std::optional const &per_device_op_state, - FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, device_id_t device_idx) { auto make_param = [&](DynamicTensorSlot const &slot) { @@ -78,7 +77,6 @@ TaskArgumentAccessor make_task_argument_accessor_for_invocation( return op_attrs.try_require_loss(); }), /*per_device_op_state=*/per_device_op_state, - /*iteration_config=*/iteration_config, /*optimizer_attrs=*/optimizer_attrs, /*device_idx=*/device_idx); } @@ -89,7 +87,6 @@ std::optional execute_dynamic_node_invocation( ProfilingSettings const &profiling_settings, device_handle_t const &ff_handle, std::optional const &per_device_op_state, - FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, device_id_t device_idx) { TaskArgumentAccessor arg_accessor = @@ -99,7 +96,6 @@ std::optional execute_dynamic_node_invocation( /*profiling_settings=*/profiling_settings, /*ff_handle=*/ff_handle, /*per_device_op_state=*/per_device_op_state, - /*iteration_config=*/iteration_config, /*optimizer_attrs=*/optimizer_attrs, /*device_idx=*/device_idx); diff --git a/lib/local-execution/test/src/local-execution/local_cost_estimator.cc b/lib/local-execution/test/src/local-execution/local_cost_estimator.cc index f3dcab7f82..ac20c1ad75 100644 --- a/lib/local-execution/test/src/local-execution/local_cost_estimator.cc +++ b/lib/local-execution/test/src/local-execution/local_cost_estimator.cc @@ -43,7 +43,6 @@ TEST_SUITE(FF_TEST_SUITE) { ProfilingSettings{/*warmup_iters=*/0, /*measure_iters=*/1}, /*device_handle=*/ff_handle, - /*iteration_config=*/FFIterationConfig{1_p}, /*device_idx=*/device_idx); SUBCASE("estimate operator cost") { @@ -116,7 +115,6 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { ProfilingSettings{/*warmup_iters=*/0, /*measure_iters=*/1}, /*device_handle=*/ff_handle, - /*iteration_config=*/FFIterationConfig{1_p}, /*device_idx=*/device_idx); SUBCASE("estimate operator cost") { diff --git a/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc b/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc index 2f2dbbd503..07bb869d5f 100644 --- a/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc +++ b/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc @@ -63,7 +63,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*op_attrs=*/PCGOperatorAttrs{InputAttrs{input_tensor_shape}}, /*loss_attrs=*/std::nullopt, /*per_device_op_state=*/std::nullopt, - /*iteration_config=*/FFIterationConfig{1_p}, /*optimizer_attrs=*/std::nullopt, /*device_idx=*/device_idx, }; diff --git a/lib/local-execution/test/src/local-execution/loss_functions.cc b/lib/local-execution/test/src/local-execution/loss_functions.cc index 39aa5f138a..ace0f0ad9d 100644 --- a/lib/local-execution/test/src/local-execution/loss_functions.cc +++ b/lib/local-execution/test/src/local-execution/loss_functions.cc @@ -107,14 +107,12 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { /*allocator=*/allocator, /*profiling_settings=*/ProfilingSettings{0, 1}, /*device_handle=*/ff_handle, - /*iteration_config=*/FFIterationConfig{1_p}, /*device_idx=*/device_idx); perform_all_passes_for_computation_graph_instance( /*instance=*/computation_graph_instance, /*profiling_settings=*/ProfilingSettings{0, 0}, /*ff_handle=*/ff_handle, - /*iteration_config=*/FFIterationConfig{1_p}, /*device_idx=*/device_idx); assert_unwrap(computation_graph_instance.get_loss_tensor_accessor()); }; diff --git a/lib/local-execution/test/src/local-execution/test_e2e.cc b/lib/local-execution/test/src/local-execution/test_e2e.cc index da62d22071..8156bf1ae9 100644 --- a/lib/local-execution/test/src/local-execution/test_e2e.cc +++ b/lib/local-execution/test/src/local-execution/test_e2e.cc @@ -159,7 +159,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*allocator=*/allocator, /*profiling_settings=*/ProfilingSettings{0, 0}, /*device_handle=*/ff_handle, - /*iteration_config=*/FFIterationConfig{1_p}, /*device_idx=*/device_idx); // begin training loop @@ -171,7 +170,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*instance=*/computation_graph_instance, /*profiling_settings=*/ProfilingSettings{0, 0}, /*ff_handle=*/ff_handle, - /*iteration_config=*/FFIterationConfig{1_p}, /*device_idx=*/device_idx); loss_values.push_back(copy_tensor_accessor_r( computation_graph_instance.get_loss_tensor_accessor().value(), @@ -330,7 +328,6 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { /*allocator=*/allocator, /*profiling_settings=*/ProfilingSettings{0, 0}, /*device_handle=*/ff_handle, - /*iteration_config=*/FFIterationConfig{1_p}, /*device_idx=*/device_idx); // begin training loop @@ -344,7 +341,6 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { /*instance=*/computation_graph_instance, /*profiling_settings=*/ProfilingSettings{0, 0}, /*ff_handle=*/ff_handle, - /*iteration_config=*/FFIterationConfig{1_p}, /*device_idx=*/device_idx); loss_values.push_back(copy_tensor_accessor_r( computation_graph_instance.get_loss_tensor_accessor().value(), diff --git a/lib/realm-execution/include/realm-execution/distributed_per_device_op_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_per_device_op_state_initialization.h index 0da97089ce..5d52f8caaf 100644 --- a/lib/realm-execution/include/realm-execution/distributed_per_device_op_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_per_device_op_state_initialization.h @@ -8,7 +8,6 @@ #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" namespace FlexFlow { @@ -25,7 +24,6 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( TensorInstanceBacking const &tensor_instance_backing, ProfilingSettings const &profiling_settings, DistributedFfHandle const &device_handle, - FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, Realm::Event precondition); diff --git a/lib/realm-execution/include/realm-execution/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance.h index 2443e4e66a..7b86d6d383 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance.h @@ -15,7 +15,6 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" #include "utils/units/milliseconds_t.h" #include @@ -87,8 +86,7 @@ PCGInstance create_pcg_instance( std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &ff_handle, - FFIterationConfig const &iteration_config); + DistributedFfHandle const &ff_handle); /** * \brief Dispatch a training iteration for a \ref PCGInstance. @@ -105,29 +103,25 @@ std::unordered_map perform_all_passes_for_pcg_instance( PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &ff_handle, - FFIterationConfig iteration_config); + DistributedFfHandle const &ff_handle); std::unordered_map perform_forward_pass_for_pcg_instance( PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &ff_handle, - FFIterationConfig iteration_config); + DistributedFfHandle const &ff_handle); std::unordered_map perform_backward_pass_for_pcg_instance( PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &ff_handle, - FFIterationConfig iteration_config); + DistributedFfHandle const &ff_handle); std::unordered_map perform_update_pass_for_pcg_instance( PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &ff_handle, - FFIterationConfig iteration_config); + DistributedFfHandle const &ff_handle); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 9ad8a6ed38..b5acd8e582 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -10,7 +10,6 @@ #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" #include "task-spec/per_device_op_state.dtg.h" #include @@ -60,7 +59,6 @@ Realm::Event spawn_op_task( std::optional> const &device_state, ProfilingSettings const &profiling_settings, DeviceSpecificPtr const &device_handle, - FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, Realm::Event precondition); diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml index 90202bcbf3..bfc705ba04 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml @@ -10,7 +10,6 @@ includes = [ "realm-execution/device_specific_ptr.h", "realm-execution/tensor_instance_backing.dtg.h", "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", - "task-spec/ff_iteration_config.dtg.h", "task-spec/per_device_op_state.dtg.h", ] @@ -34,10 +33,6 @@ type = "::FlexFlow::ProfilingSettings" name = "device_handle" type = "::FlexFlow::DeviceSpecificPtr<::FlexFlow::ManagedPerDeviceFFHandle>" -[[fields]] -name = "iteration_config" -type = "::FlexFlow::FFIterationConfig" - [[fields]] name = "optimizer_attrs" type = "std::optional<::FlexFlow::OptimizerAttrs>" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/per_device_op_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/per_device_op_state_init_task.h index 11437d5df8..8fd7aafbfa 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/per_device_op_state_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/per_device_op_state_init_task.h @@ -9,7 +9,6 @@ #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" #include "task-spec/per_device_op_state.dtg.h" namespace FlexFlow { @@ -39,7 +38,6 @@ std::optional spawn_per_device_op_state_init_task( TensorInstanceBacking const &tensor_backing, ProfilingSettings const &profiling_settings, DeviceSpecificPtr const &device_handle, - FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, DeviceSpecificPtr *result_ptr, Realm::Event precondition); diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/per_device_op_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/per_device_op_state_init_task_args.dtg.toml index 98bbdb6a7b..a84c5d60b0 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/per_device_op_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/per_device_op_state_init_task_args.dtg.toml @@ -18,7 +18,6 @@ includes = [ "realm-execution/realm.h", "task-spec/device_specific_per_device_op_state.dtg.h", "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", - "task-spec/ff_iteration_config.dtg.h", "task-spec/per_device_op_state.dtg.h", ] @@ -38,10 +37,6 @@ type = "::FlexFlow::ProfilingSettings" name = "device_handle" type = "::FlexFlow::DeviceSpecificPtr<::FlexFlow::ManagedPerDeviceFFHandle>" -[[fields]] -name = "iteration_config" -type = "::FlexFlow::FFIterationConfig" - [[fields]] name = "optimizer_attrs" type = "::FlexFlow::OptimizerAttrs" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml index adac6631ee..d189323d48 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml @@ -14,7 +14,6 @@ includes = [ "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", "realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.h", "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", - "task-spec/ff_iteration_config.dtg.h", ] src_includes = [ @@ -42,10 +41,6 @@ type = "::FlexFlow::ProfilingSettings" name = "device_handle" type = "::FlexFlow::SerializableDeviceSpecificPtr" -[[fields]] -name = "iteration_config" -type = "::FlexFlow::FFIterationConfig" - [[fields]] name = "optimizer_attrs" type = "std::optional<::FlexFlow::OptimizerAttrs>" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_per_device_op_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_per_device_op_state_init_task_args.dtg.toml index 0e53767862..7a6a48d0b6 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_per_device_op_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_per_device_op_state_init_task_args.dtg.toml @@ -16,7 +16,6 @@ includes = [ "realm-execution/tasks/serializer/serializable_tensor_instance_backing.dtg.h", "task-spec/device_specific_per_device_op_state.dtg.h", "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", - "task-spec/ff_iteration_config.dtg.h", ] [[fields]] @@ -35,10 +34,6 @@ type = "::FlexFlow::ProfilingSettings" name = "device_handle" type = "::FlexFlow::SerializableDeviceSpecificPtr" -[[fields]] -name = "iteration_config" -type = "::FlexFlow::FFIterationConfig" - [[fields]] name = "optimizer_attrs" type = "::FlexFlow::OptimizerAttrs" diff --git a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc index 1d517a8fe4..8cf4c21b25 100644 --- a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc @@ -21,7 +21,6 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( TensorInstanceBacking const &tensor_instance_backing, ProfilingSettings const &profiling_settings, DistributedFfHandle const &device_handle, - FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, Realm::Event precondition) { @@ -50,7 +49,6 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( tensor_backing, profiling_settings, device_handle.at(target_proc), - iteration_config, optimizer_attrs, device_state_ptr, precondition); diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 0ecd02143e..1ef0c4270b 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -85,8 +85,7 @@ PCGInstance create_pcg_instance( std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &device_handle, - FFIterationConfig const &iteration_config) { + DistributedFfHandle const &device_handle) { DynamicOpenDataflowGraph dg = make_dynamic_open_dataflow_graph_from_mapped_pcg(mpcg); @@ -141,7 +140,6 @@ PCGInstance create_pcg_instance( tensor_instance_backing, profiling_settings, device_handle, - iteration_config, optimizer_attrs, ctx.get_outstanding_events()); @@ -175,8 +173,7 @@ static Realm::Event spawn_dynamic_node_invocation( PerDeviceOpStateBacking const &device_state_backing, OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &device_handle, - FFIterationConfig iteration_config) { + DistributedFfHandle const &device_handle) { Realm::Event precondition = Realm::Event::merge_events( Realm::Event::merge_events(input_dependencies), Realm::Event::merge_events(output_dependencies)); @@ -195,7 +192,6 @@ static Realm::Event spawn_dynamic_node_invocation( try_at(device_state_backing.backing, invocation), profiling_settings, device_handle.at(target_proc), - iteration_config, optimizer_attrs, precondition); }; @@ -238,8 +234,7 @@ static std::unordered_map PerDeviceOpStateBacking const &device_state_backing, OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &device_handle, - FFIterationConfig iteration_config) { + DistributedFfHandle const &device_handle) { // For simplicity we'll track a dependency on all outstanding operations up to // this point. This will create an effective barrier between phases. DependencySet dependency_set{ctx.get_outstanding_events()}; @@ -265,8 +260,7 @@ static std::unordered_map device_state_backing, optimizer_attrs, profiling_settings, - device_handle, - iteration_config); + device_handle); for (DynamicValueAttrs const &value : values(invocation.inputs)) { dependency_set.add_reader(value, result); @@ -282,8 +276,7 @@ std::unordered_map perform_all_passes_for_pcg_instance( PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &device_handle, - FFIterationConfig iteration_config) { + DistributedFfHandle const &device_handle) { std::vector execution_order = pcg_instance.get_execution_order(); std::unordered_map result = @@ -295,8 +288,7 @@ std::unordered_map /*device_state_backing=*/pcg_instance.get_device_state_backing(), /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*device_handle=*/device_handle, - /*iteration_config=*/iteration_config); + /*device_handle=*/device_handle); pcg_instance.update_optimizer_attrs_for_next_iter(); return result; } @@ -305,8 +297,7 @@ std::unordered_map perform_forward_pass_for_pcg_instance( PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &device_handle, - FFIterationConfig iteration_config) { + DistributedFfHandle const &device_handle) { std::vector execution_order = filter(pcg_instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { @@ -322,16 +313,14 @@ std::unordered_map /*device_state_backing=*/pcg_instance.get_device_state_backing(), /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*device_handle=*/device_handle, - /*iteration_config=*/iteration_config); + /*device_handle=*/device_handle); } std::unordered_map perform_backward_pass_for_pcg_instance( PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &device_handle, - FFIterationConfig iteration_config) { + DistributedFfHandle const &device_handle) { std::vector execution_order = filter(pcg_instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { @@ -347,16 +336,14 @@ std::unordered_map /*device_state_backing=*/pcg_instance.get_device_state_backing(), /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*device_handle=*/device_handle, - /*iteration_config=*/iteration_config); + /*device_handle=*/device_handle); } std::unordered_map perform_update_pass_for_pcg_instance( PCGInstance &pcg_instance, ProfilingSettings const &profiling_settings, - DistributedFfHandle const &device_handle, - FFIterationConfig iteration_config) { + DistributedFfHandle const &device_handle) { std::vector execution_order = filter(pcg_instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { @@ -374,8 +361,7 @@ std::unordered_map /*device_state_backing=*/pcg_instance.get_device_state_backing(), /*optimizer_attrs=*/pcg_instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*device_handle=*/device_handle, - /*iteration_config=*/iteration_config); + /*device_handle=*/device_handle); pcg_instance.update_optimizer_attrs_for_next_iter(); return result; } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 0d20baa0a3..02b0dd7f54 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -56,7 +56,6 @@ void op_task_body(void const *args, return d.get(ctx.get_current_device_idx()); }), [](PerDeviceOpState *ptr) { return *ptr; }), - /*iteration_config=*/task_args.iteration_config, /*optimizer_attrs=*/task_args.optimizer_attrs, /*device_idx=*/ctx.get_current_device_idx()); } @@ -69,7 +68,6 @@ Realm::Event spawn_op_task( std::optional> const &device_state, ProfilingSettings const &profiling_settings, DeviceSpecificPtr const &device_handle, - FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, Realm::Event precondition) { @@ -79,7 +77,6 @@ Realm::Event spawn_op_task( device_state, profiling_settings, device_handle, - iteration_config, optimizer_attrs, }; diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc index 753fccf74b..c05932ff61 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc @@ -54,7 +54,6 @@ void per_device_op_state_init_task_body(void const *args, ctx.get_current_device_allocator(), task_args.profiling_settings, device_handle, - task_args.iteration_config, task_args.optimizer_attrs, ctx.get_current_device_idx()); DeviceSpecificPerDeviceOpState result_state = @@ -80,7 +79,6 @@ std::optional spawn_per_device_op_state_init_task( TensorInstanceBacking const &tensor_backing, ProfilingSettings const &profiling_settings, DeviceSpecificPtr const &device_handle, - FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, DeviceSpecificPtr *result_ptr, Realm::Event precondition) { @@ -89,7 +87,6 @@ std::optional spawn_per_device_op_state_init_task( tensor_backing, profiling_settings, device_handle, - iteration_config, optimizer_attrs, ctx.get_current_processor(), result_ptr, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc index 32d54adc37..a4adff1261 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc @@ -16,7 +16,6 @@ SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { device_specific_ptr_to_serializable), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/device_specific_ptr_to_serializable(args.device_handle), - /*iteration_config=*/args.iteration_config, /*optimizer_attrs=*/args.optimizer_attrs, }; } @@ -33,7 +32,6 @@ OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &args) { /*device_handle=*/ device_specific_ptr_from_serializable( args.device_handle), - /*iteration_config=*/args.iteration_config, /*optimizer_attrs=*/args.optimizer_attrs, }; } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_per_device_op_state_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_per_device_op_state_init_task_args.cc index 7b52d9c03d..9e719eeb6e 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_per_device_op_state_init_task_args.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_per_device_op_state_init_task_args.cc @@ -15,7 +15,6 @@ SerializablePerDeviceOpStateInitTaskArgs tensor_instance_backing_to_serializable(args.tensor_backing), /*profiling_settings=*/args.profiling_settings, /*device_handle=*/device_specific_ptr_to_serializable(args.device_handle), - /*iteration_config=*/args.iteration_config, /*optimizer_attrs=*/args.optimizer_attrs, /*origin_proc=*/realm_processor_to_serializable(args.origin_proc), /*origin_result_ptr=*/reinterpret_cast(args.origin_result_ptr), @@ -33,7 +32,6 @@ PerDeviceOpStateInitTaskArgs /*device_handle=*/ device_specific_ptr_from_serializable( args.device_handle), - /*iteration_config=*/args.iteration_config, /*optimizer_attrs=*/args.optimizer_attrs, /*origin_proc=*/realm_processor_from_serializable(args.origin_proc), /*origin_result_ptr=*/ diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 4a8edb3b6c..afdb1b2343 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -223,8 +223,7 @@ TEST_SUITE(FF_TEST_SUITE) { }, /*input_tensors=*/input_tensors, /*profiling_settings=*/ProfilingSettings{0, 0}, - /*device_handle=*/device_handle, - /*iteration_config=*/FFIterationConfig{1_p}); + /*device_handle=*/device_handle); // begin training loop int num_epochs = 5; @@ -234,8 +233,7 @@ TEST_SUITE(FF_TEST_SUITE) { perform_all_passes_for_pcg_instance( /*instance=*/pcg_instance, /*profiling_settings=*/ProfilingSettings{0, 0}, - /*device_handle=*/device_handle, - /*iteration_config=*/FFIterationConfig{1_p}); + /*device_handle=*/device_handle); loss_values.push_back(copy_tensor_accessor_r( dynamic_tensor_accessor_from_instance( pcg_instance.get_loss_tensor_instance().value(), @@ -453,8 +451,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { }, /*input_tensors=*/input_tensors, /*profiling_settings=*/ProfilingSettings{0, 0}, - /*device_handle=*/device_handle, - /*iteration_config=*/FFIterationConfig{1_p}); + /*device_handle=*/device_handle); // begin training loop int num_epochs = 5; @@ -464,8 +461,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { perform_all_passes_for_pcg_instance( /*instance=*/pcg_instance, /*profiling_settings=*/ProfilingSettings{0, 0}, - /*device_handle=*/device_handle, - /*iteration_config=*/FFIterationConfig{1_p}); + /*device_handle=*/device_handle); loss_values.push_back(copy_tensor_accessor_r( dynamic_tensor_accessor_from_instance( pcg_instance.get_loss_tensor_instance().value(), diff --git a/lib/task-spec/include/task-spec/ff_iteration_config.dtg.toml b/lib/task-spec/include/task-spec/ff_iteration_config.dtg.toml deleted file mode 100644 index c8a8bb61b0..0000000000 --- a/lib/task-spec/include/task-spec/ff_iteration_config.dtg.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "FFIterationConfig" -type = "struct" -features = [ - "eq", - "ord", - "hash", - "fmt", - "rapidcheck", - "json", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "seq_length" -type = "::FlexFlow::positive_int" diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/index.dox b/lib/task-spec/include/task-spec/task_argument_accessor/index.dox index 9c42a19188..4dae416725 100644 --- a/lib/task-spec/include/task-spec/task_argument_accessor/index.dox +++ b/lib/task-spec/include/task-spec/task_argument_accessor/index.dox @@ -76,7 +76,7 @@ Instances of ITaskArgumentAccessor provide access to the following arguments: - One of \ref PCGOperatorAttrs, \ref LossAttrs, or \ref OptimizerAttrs depending on whether this task is for an operator, an optimizer, or a loss function. - Two pieces of device-specific state: \ref device_handle_t (aka \ref PerDeviceFFHandle) and \ref PerDeviceOpState. As both of these contain pointers and hold device-specific initialization state, in distributed execution their addresses (rather than their contents) are passed around, and they are only valid on the device they originated on. One \ref PerDeviceFFHandle should be created per device, while one \ref PerDeviceOpState should be create for every operator for every device it runs on. -- A few simple value types communicating runtime-wide settings: \ref ProfilingSettings, \ref DeviceType, and \ref FFIterationConfig. +- A few simple value types communicating runtime-wide settings: \ref ProfilingSettings and \ref DeviceType. */ } diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/itask_argument_accessor.h b/lib/task-spec/include/task-spec/task_argument_accessor/itask_argument_accessor.h index 3d08101915..b4c8dcdf36 100644 --- a/lib/task-spec/include/task-spec/task_argument_accessor/itask_argument_accessor.h +++ b/lib/task-spec/include/task-spec/task_argument_accessor/itask_argument_accessor.h @@ -10,7 +10,6 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "task-spec/concrete_arg_spec.h" -#include "task-spec/ff_iteration_config.dtg.h" #include "task-spec/ops/arg_slot_id_t.dtg.h" #include "task-spec/per_device_op_state.dtg.h" #include "task-spec/privilege_tensor_accessor.h" @@ -35,7 +34,6 @@ struct ITaskArgumentAccessor { virtual PCGOperatorAttrs get_op_attrs() const = 0; virtual LossAttrs get_loss_attrs() const = 0; virtual PerDeviceOpState get_per_device_op_state() const = 0; - virtual FFIterationConfig get_iteration_config() const = 0; virtual OptimizerAttrs get_optimizer_attrs() const = 0; virtual Allocator get_allocator() const = 0; diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/task_argument_accessor.h b/lib/task-spec/include/task-spec/task_argument_accessor/task_argument_accessor.h index 29f3f625f6..0c63643e6e 100644 --- a/lib/task-spec/include/task-spec/task_argument_accessor/task_argument_accessor.h +++ b/lib/task-spec/include/task-spec/task_argument_accessor/task_argument_accessor.h @@ -9,7 +9,6 @@ #include "pcg/optimizer_attrs.dtg.h" #include "pcg/optimizer_slot_name.dtg.h" #include "task-spec/device_specific.h" -#include "task-spec/ff_iteration_config.dtg.h" #include "task-spec/per_device_op_state.dtg.h" #include "task-spec/task_argument_accessor/itask_argument_accessor.h" #include "task-spec/task_argument_accessor/task_tensor_parameter.h" @@ -23,7 +22,6 @@ struct TaskArgumentAccessor { PCGOperatorAttrs get_op_attrs() const; LossAttrs get_loss_attrs() const; PerDeviceOpState get_per_device_op_state() const; - FFIterationConfig get_iteration_config() const; OptimizerAttrs get_optimizer_attrs() const; TensorShape get_tensor_shape(TensorSlotName slot) const { diff --git a/lib/task-spec/src/task-spec/task_argument_accessor/task_argument_accessor.cc b/lib/task-spec/src/task-spec/task_argument_accessor/task_argument_accessor.cc index 97f6069d68..e1055696d4 100644 --- a/lib/task-spec/src/task-spec/task_argument_accessor/task_argument_accessor.cc +++ b/lib/task-spec/src/task-spec/task_argument_accessor/task_argument_accessor.cc @@ -25,10 +25,6 @@ PerDeviceOpState TaskArgumentAccessor::get_per_device_op_state() const { return this->ptr->get_per_device_op_state(); } -FFIterationConfig TaskArgumentAccessor::get_iteration_config() const { - return this->ptr->get_iteration_config(); -} - OptimizerAttrs TaskArgumentAccessor::get_optimizer_attrs() const { return this->ptr->get_optimizer_attrs(); }