Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions exir/emit/test/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ fbcode_target(_kind = runtime.python_test,
"//executorch/exir:schema",
"//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
"//executorch/exir/backend/test:backend_with_compiler_demo",
"//executorch/exir/emit:lib",
"//executorch/exir/passes:const_prop_pass",
"//executorch/exir/passes:constant_prop_pass",
"//executorch/exir/passes:init_mutable_pass",
"//executorch/exir/passes:propagate_device_pass",
"//executorch/exir/tests:lib",
"//executorch/exir/tests:models",
"//executorch/extension/pybindings:portable_lib",
Expand Down
125 changes: 125 additions & 0 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2518,3 +2518,128 @@ def forward(self):
for j in range(2):
expected_storage.append(j * 16 + i)
self.assertEqual([int(v) for v in storage_values], expected_storage)

def test_emit_device_info_propagated_to_serialized_tensor(self) -> None:
    """End-to-end check that a device tagged via the backend compile spec
    (consumed by PropagateDevicePass) survives emission and shows up as
    ExtraTensorInfo.device_type / device_index on the serialized tensors."""
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class OnlyAddSupported(OperatorSupportBase):
        # Limit partitioning to aten.add.Tensor call_function nodes.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            if node.op != "call_function":
                return False
            return node.target in [exir_ops.edge.aten.add.Tensor]

    class CudaTaggingPartitioner(Partitioner):
        # Tags every add partition for the demo backend with a compile
        # spec that requests device "cuda:0".
        def __init__(self):
            super().__init__()
            specs = [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__, specs
            )

        def partition(self, exported_program) -> PartitionResult:
            tags = {}
            found = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(OnlyAddSupported()),
            )
            for part in found:
                tag = f"tag{part.id}"
                tags[tag] = self.delegation_spec
                for node in part.nodes:
                    node.meta["delegation_tag"] = tag
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=tags,
            )

    class AddModule(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    edge_manager = to_edge(
        export(AddModule(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    delegated = edge_manager.to_backend(CudaTaggingPartitioner())
    emitted = delegated.to_executorch()._emitter_output.program

    exec_plan = emitted.execution_plan[0]
    self.assertGreater(len(exec_plan.delegates), 0)

    # Collect every serialized tensor whose extra info carries a CUDA device.
    on_cuda = []
    for value in exec_plan.values:
        t = value.val
        if not isinstance(t, Tensor):
            continue
        info = t.extra_tensor_info
        if info is not None and info.device_type == schema.DeviceType.CUDA:
            on_cuda.append(t)

    # add(a, b) has 2 delegate inputs + 1 delegate output = 3 CUDA tensors
    self.assertEqual(
        len(on_cuda),
        3,
        f"Expected exactly 3 CUDA tensors (2 inputs + 1 output for delegated add), got {len(on_cuda)}",
    )
    # Verify device_index is also correctly serialized (cuda:0 → index 0)
    for t in on_cuda:
        self.assertEqual(
            t.extra_tensor_info.device_index,
            0,
            "CUDA tensor device_index should be 0 for cuda:0",
        )

def test_emit_cpu_tensors_no_extra_device_info(self) -> None:
    """When all tensors are on CPU (default), ExtraTensorInfo should NOT be
    created solely for device info — it should remain None for activation tensors.
    """

    class AddModule(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    program = to_edge(
        export(AddModule(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    ).to_executorch()._emitter_output.program

    exec_plan = program.execution_plan[0]
    # No serialized tensor should carry CUDA device info in a pure-CPU program.
    on_cuda = [
        v.val
        for v in exec_plan.values
        if isinstance(v.val, Tensor)
        and v.val.extra_tensor_info is not None
        and v.val.extra_tensor_info.device_type == schema.DeviceType.CUDA
    ]
    self.assertEqual(
        len(on_cuda),
        0,
        "No tensor should have CUDA device when model runs entirely on CPU",
    )
14 changes: 13 additions & 1 deletion exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,18 @@ def to_list(
tensor_size = to_list(spec.shape)
tensor_dim_order = to_list(spec.dim_order)

extra_tensor_info = spec.extra_tensor_info
# Propagate device from TensorSpec into ExtraTensorInfo for serialization.
if spec.device != schema.DeviceType.CPU:
if extra_tensor_info is None:
extra_tensor_info = schema.ExtraTensorInfo(
device_type=spec.device,
device_index=spec.device_index,
)
else:
extra_tensor_info.device_type = spec.device
extra_tensor_info.device_index = spec.device_index

flatbuffer_tensor = schema.Tensor(
scalar_type=scalar_type_enum(spec.scalar_type),
# The runtime currently only supports tensors with offsets of zero.
Expand All @@ -377,7 +389,7 @@ def to_list(
allocation_info=allocation_info,
layout=layout_enum(spec.layout),
shape_dynamism=spec.shape_dynamism,
extra_tensor_info=spec.extra_tensor_info,
extra_tensor_info=extra_tensor_info,
)
return flatbuffer_tensor

Expand Down
Loading