diff --git a/exir/capture/_config.py b/exir/capture/_config.py index 3fbc8ae7ef3..f8c3be6e7c8 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -115,5 +115,11 @@ class ExecutorchBackendConfig: # If set to true, we run quant fusion and constant propagation passes do_quant_fusion_and_const_prop: bool = False - # Experimental: If set to true, we run a pass to reinplace ops in the graph. + # If set to true, we run a pass to reinplace ops in the graph. run_reinplace_pass: bool = False + + # When True, memory planning partitions specs by device and runs the + # algorithm independently per device, producing separate buffers for CPU + # vs. accelerator memory. Default False preserves the legacy behavior + # where all tensors are planned into CPU memory regardless of device. + enable_non_cpu_memory_planning: bool = False diff --git a/exir/memory_planning.py b/exir/memory_planning.py index c5d3441bcde..f6e3234fce5 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -28,6 +28,7 @@ import torch from executorch.exir import memory from executorch.exir.control_flow import while_loop as exir_while +from executorch.exir.schema import DeviceType, NonConstBufferDevice from executorch.exir.delegate import executorch_call_delegate from executorch.exir.error import internal_assert, InternalError from executorch.exir.operator.convert import is_inplace_variant, is_out_variant @@ -1211,10 +1212,19 @@ def apply_algo( alloc_graph_input: bool = True, alloc_graph_output: bool = True, alloc_mutable_buffers: bool = True, + enable_non_cpu_memory_planning: bool = False, ) -> list[int]: """ Recursively apply algo to graph_module and its submodules for control flow. + Partitions specs by device type and device idx, and runs the memory planning + algorithm independently per device, then merges results into separate buffers. + This ensures device memory and CPU memory are never mixed. 
+ + When enable_non_cpu_memory_planning is False (default), all specs are planned + into a single CPU memory pool regardless of their device attribute. This + preserves the legacy behavior. Set to True to enable per-device partitioning. + Algo implementation should handle one of two meta entries for submodules: 1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by `algo` should start at the offset specified by this list; @@ -1229,18 +1239,19 @@ def apply_algo( `operand` arg. The memory for operands is unused. """ # Extract the nodes and their lifespans from the graph_module - # Difficult to just filter the list of specs returned by this due to - # how we flag trainable weights. _ = update_all_tensors_lifetime(graph_module, graph_signature) - # Filter specs based on alloc_graph_input and alloc_graph_output - specs = collect_specs_from_nodes( - graph_module.graph.nodes, - graph_signature, - do_assertion=False, - ignore_graph_input=not alloc_graph_input, - ignore_graph_output=not alloc_graph_output, - ignore_mutable_buffers=not alloc_mutable_buffers, + # Collect and materialize specs into a set so we can iterate multiple + # times and partition by device. + all_specs: set[TensorSpec] = set( + collect_specs_from_nodes( + graph_module.graph.nodes, + graph_signature, + do_assertion=False, + ignore_graph_input=not alloc_graph_input, + ignore_graph_output=not alloc_graph_output, + ignore_mutable_buffers=not alloc_mutable_buffers, + ) ) # Get temporary specs for submodules to set aside space during execution @@ -1249,29 +1260,78 @@ def apply_algo( algo, graph_module, alignment, graph_signature ) - # Update `input_mem_buffer_sizes` in graph_module. This will allow existing - # algos to work using `input_mem_buffer_sizes` or use - # `non_const_buffer_sizes` directly. - # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`. 
- graph_module.input_mem_buffer_sizes = submodule_bufsizes - # Get extra padding for XNNPACK if needed extra_padding = 0 if _contains_xnnpack_delegate(graph_module): extra_padding = 64 - # Pass the filtered specs to the algorithm - bufsizes: list[int] = algo( - alignment, - specs, - graph_module, - graph_signature, - extra_padding, + # 1. Partition specs by device + specs_by_device: dict[DeviceType, set[TensorSpec]] = defaultdict(set) + if enable_non_cpu_memory_planning: + for spec in all_specs: + specs_by_device[spec.device].add(spec) + else: + # Legacy behavior: all specs planned into CPU memory regardless of device + specs_by_device[DeviceType.CPU] = all_specs + + # 2. Plan each device independently + global_bufsizes: list[int] = [0] # index 0 reserved for constants + buffer_device_types: list[DeviceType] = [DeviceType.CPU] + + # Process CPU first (if present), then other devices sorted by enum value + device_order = sorted( + specs_by_device.keys(), + key=lambda d: (d != DeviceType.CPU, d.value), ) - # pyre-ignore[6]: Incompatible parameter type [6] - # In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]` - insert_calls_to_free(graph_module, specs) + for device_type in device_order: + device_specs = specs_by_device[device_type] - graph_module.meta.update({"non_const_buffer_sizes": bufsizes}) - return bufsizes + # Only apply submodule pre-allocation for CPU specs; device buffers + # do not share memory space with CPU submodule arenas. + # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`. 
+ graph_module.input_mem_buffer_sizes = ( + submodule_bufsizes if device_type == DeviceType.CPU else [] + ) + + # Run algorithm independently on this device's specs + device_bufsizes = algo( + alignment, device_specs, graph_module, graph_signature, extra_padding + ) + + # Calculate base mem_id in global space + base_mem_id = len(global_bufsizes) + + # Append buffer sizes (skip index 0 which is constants placeholder) + global_bufsizes.extend(device_bufsizes[1:]) + + # Track device type for each new buffer slot + for _ in device_bufsizes[1:]: + buffer_device_types.append(device_type) + + # Remap spec mem_ids from algo-local to global. + # The algorithm assigns mem_id starting from 1; remap to global position. + for spec in device_specs: + if spec.mem_id is not None: + spec.mem_id = (spec.mem_id - 1) + base_mem_id + + # Ensure backward compatibility: at least [0, 0] when no specs exist + if len(global_bufsizes) < 2: + global_bufsizes.append(0) + buffer_device_types.append(DeviceType.CPU) + + # 3. 
Insert free calls and build device buffer mapping + insert_calls_to_free(graph_module, all_specs) + + has_device_buffers = any(dt != DeviceType.CPU for dt in buffer_device_types) + non_const_buffer_device: Optional[list[NonConstBufferDevice]] = None + if has_device_buffers: + non_const_buffer_device = [ + NonConstBufferDevice(buffer_idx=i, device_type=dt, device_index=0) + for i, dt in enumerate(buffer_device_types) + ] + + graph_module.meta["non_const_buffer_sizes"] = global_bufsizes + if non_const_buffer_device is not None: + graph_module.meta["non_const_buffer_device"] = non_const_buffer_device + return global_bufsizes diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py index f3970f13b56..32c343a4607 100644 --- a/exir/passes/memory_planning_pass.py +++ b/exir/passes/memory_planning_pass.py @@ -153,6 +153,7 @@ def __init__( alloc_mutable_buffers: bool = True, share_mutable_buffers: bool = False, alignment: int = ALIGNMENT, + enable_non_cpu_memory_planning: bool = False, ) -> None: r""" alloc_graph_input/alloc_graph_output will have 4 different combinations @@ -173,6 +174,7 @@ def __init__( self.alloc_mutable_buffers = alloc_mutable_buffers self.share_mutable_buffers = share_mutable_buffers self.alignment = alignment + self.enable_non_cpu_memory_planning = enable_non_cpu_memory_planning self.state = _MemoryPlanningState() def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None: @@ -250,6 +252,7 @@ def run( # If mutable buffers are shared, then do not allocate them in the # main memory planning algo; they are allocated in run_multimethod. 
self.alloc_mutable_buffers and not self.share_mutable_buffers, + self.enable_non_cpu_memory_planning, ) if self.share_mutable_buffers and graph_signature is not None: diff --git a/exir/program/_program.py b/exir/program/_program.py index 9813b12d594..f1a22773b69 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1792,6 +1792,12 @@ def to_executorch( # noqa (FLAKE8) C901 else: memory_planning_pass = config.memory_planning_pass # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work + # Propagate enable_non_cpu_memory_planning from the top-level config + # to the pass instance so that device-aware partitioning is applied. + if hasattr(memory_planning_pass, "enable_non_cpu_memory_planning"): + memory_planning_pass.enable_non_cpu_memory_planning = ( + config.enable_non_cpu_memory_planning + ) if hasattr(memory_planning_pass, "run"): new_gm_res = memory_planning_pass.run(new_gm, new_signature) else: diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index f364541d900..27ecbdfe633 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -29,6 +29,8 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import ( _do_user_inputs_exist, + apply_algo, + collect_specs_from_nodes, filter_nodes, get_node_tensor_specs, greedy, @@ -45,6 +47,7 @@ ToOutVarPass, ) from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from executorch.exir.schema import DeviceType from executorch.exir.tensor import TensorSpec from functorch.experimental.control_flow import map as torch_map from parameterized import parameterized @@ -1259,3 +1262,169 @@ def reset(self, k_zeros: torch.Tensor, v_zeros: torch.Tensor) -> None: self.assertEqual(v_cache[0].val.allocation_info.memory_id, 2) self.assertEqual(v_cache[0].val.allocation_info.memory_offset_low, 256) 
self.assertEqual(v_cache[0].val.allocation_info.memory_offset_high, 0) + + +class TestDeviceAwareMemoryPlanning(unittest.TestCase): + """Tests for per-device memory planning (separate buffers per device type).""" + + def _prepare_model( + self, + ) -> Tuple[GraphModule, ExportGraphSignature]: + """Prepare ToyModelForMemPlanning through SpecPropPass + ToOutVarPass.""" + model = ToyModelForMemPlanning() + inputs = model.get_random_inputs() + edge = to_edge(export(model, inputs, strict=True)) + gm = edge.exported_program().graph_module + gs = edge.exported_program().graph_signature + gm = PassManager(passes=[SpecPropPass(), ToOutVarPass()])(gm).graph_module + return gm, gs + + def _get_planned_specs( + self, + gm: GraphModule, + gs: ExportGraphSignature, + ) -> list[TensorSpec]: + """Get the unique set of specs that apply_algo would plan.""" + return list( + collect_specs_from_nodes( + gm.graph.nodes, + gs, + do_assertion=False, + ignore_graph_input=False, + ignore_graph_output=False, + ignore_mutable_buffers=False, + ) + ) + + def test_cpu_only_unchanged(self) -> None: + """CPU-only specs produce bufsizes = [0, X] with no device metadata.""" + gm, gs = self._prepare_model() + + algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy]) + bufsizes = apply_algo( + algo, gm, 16, gs, enable_non_cpu_memory_planning=True + ) + + # All specs default to CPU, so buffers are [constants, cpu_activations] + self.assertEqual(bufsizes[0], 0) # constants + self.assertGreater(bufsizes[1], 0) # CPU activations + self.assertNotIn("non_const_buffer_device", gm.meta) + + def test_all_cuda_no_wasted_slots(self) -> None: + """CUDA-only specs produce [0, X] with CUDA at buffer index 1.""" + gm, gs = self._prepare_model() + specs = self._get_planned_specs(gm, gs) + for spec in specs: + spec.device = DeviceType.CUDA + + algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy]) + bufsizes = apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True) + + # [0, cuda_size] — no wasted CPU buffer slot + 
self.assertEqual(len(bufsizes), 2) + self.assertEqual(bufsizes[0], 0) + self.assertGreater(bufsizes[1], 0) + # Device mapping should be present + self.assertIn("non_const_buffer_device", gm.meta) + device_map = gm.meta["non_const_buffer_device"] + self.assertEqual(len(device_map), 2) + self.assertEqual(device_map[0].device_type, DeviceType.CPU) # constants + self.assertEqual(device_map[1].device_type, DeviceType.CUDA) + + def test_mixed_cpu_cuda_separate_buffers(self) -> None: + """CPU specs at mem_id=1, CUDA specs at mem_id=2, separate sizes.""" + gm, gs = self._prepare_model() + specs = self._get_planned_specs(gm, gs) + + # Set second half of specs to CUDA + mid = len(specs) // 2 + self.assertGreater(mid, 0) + cpu_specs = specs[:mid] + cuda_specs = specs[mid:] + for spec in cuda_specs: + spec.device = DeviceType.CUDA + + algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy]) + bufsizes = apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True) + + # [constants, cpu_activations, cuda_activations] + self.assertEqual(len(bufsizes), 3) + self.assertEqual(bufsizes[0], 0) + self.assertGreater(bufsizes[1], 0) + self.assertGreater(bufsizes[2], 0) + + # CPU specs should have mem_id=1, CUDA specs should have mem_id=2 + for spec in cpu_specs: + self.assertEqual(spec.mem_id, 1, f"CPU spec has wrong mem_id: {spec.mem_id}") + for spec in cuda_specs: + self.assertEqual(spec.mem_id, 2, f"CUDA spec has wrong mem_id: {spec.mem_id}") + + def test_mem_offset_correct_after_remap(self) -> None: + """After remapping, mem_offset is relative to its own buffer.""" + gm, gs = self._prepare_model() + specs = self._get_planned_specs(gm, gs) + + # Set the last spec to CUDA (sole CUDA tensor) + cuda_spec = specs[-1] + cuda_spec.device = DeviceType.CUDA + + algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy]) + bufsizes = apply_algo( + algo, gm, 16, gs, enable_non_cpu_memory_planning=True + ) + + # The CUDA spec is the only tensor in its buffer, so offset should be 0 + 
self.assertEqual(cuda_spec.mem_offset, 0) + # The CUDA buffer should fit exactly this tensor + cuda_mem_id = cuda_spec.mem_id + self.assertIsNotNone(cuda_mem_id) + assert cuda_mem_id is not None + self.assertGreaterEqual(bufsizes[cuda_mem_id], cuda_spec.allocated_memory) + + def test_no_cross_device_memory_sharing(self) -> None: + """Specs on different devices never share buffers, regardless of lifetime.""" + gm, gs = self._prepare_model() + specs = self._get_planned_specs(gm, gs) + self.assertGreaterEqual(len(specs), 2) + + # Assign alternating specs to CUDA to ensure some pairs have + # non-overlapping lifetimes (which greedy would normally share). + for i, spec in enumerate(specs): + if i % 2 == 0: + spec.device = DeviceType.CUDA + + algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy]) + apply_algo(algo, gm, 16, gs, enable_non_cpu_memory_planning=True) + + # Verify CPU and CUDA specs have disjoint mem_ids + cpu_mem_ids: set[int] = set() + cuda_mem_ids: set[int] = set() + for i, spec in enumerate(specs): + if spec.mem_id is not None: + if i % 2 == 0: + cuda_mem_ids.add(spec.mem_id) + else: + cpu_mem_ids.add(spec.mem_id) + + self.assertTrue( + cpu_mem_ids.isdisjoint(cuda_mem_ids), + f"CPU {cpu_mem_ids} and CUDA {cuda_mem_ids} should not share buffers", + ) + + def test_disabled_falls_back_to_cpu(self) -> None: + """With enable_non_cpu_memory_planning=False (default), CUDA specs are + planned into CPU memory — no device-specific buffers are created.""" + gm, gs = self._prepare_model() + specs = self._get_planned_specs(gm, gs) + for spec in specs: + spec.device = DeviceType.CUDA + + algo = MemoryPlanningAlgorithmSuite(algo_list=[greedy]) + # Default: enable_non_cpu_memory_planning=False + bufsizes = apply_algo(algo, gm, 16, gs) + + # All specs planned into a single CPU pool — same as CPU-only + self.assertEqual(len(bufsizes), 2) + self.assertEqual(bufsizes[0], 0) + self.assertGreater(bufsizes[1], 0) + self.assertNotIn("non_const_buffer_device", gm.meta)