Skip to content

Commit 3ba8f65

Browse files
committed
[ET Device Support] Add NonConstBufferDevice schema for per-buffer device mapping
Adds the NonConstBufferDevice table to the FlatBuffer schema (program.fbs) and the corresponding Python dataclass to schema.py. This enables mapping each non-constant planned memory buffer to a specific device type (CPU, CUDA, etc.). The field is optional and absent for CPU-only programs, ensuring zero binary size regression. Differential Revision: [D97335597](https://our.internmc.facebook.com/intern/diff/D97335597/) ghstack-source-id: 354722549 Pull Request resolved: #18330
1 parent 5ef6700 commit 3ba8f65

6 files changed

Lines changed: 256 additions & 2 deletions

File tree

exir/_serialize/generated/executorch_flatbuffer/ExecutionPlan.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from executorch.exir._serialize.generated.executorch_flatbuffer.Chain import Chain
1111
from executorch.exir._serialize.generated.executorch_flatbuffer.ContainerMetadata import ContainerMetadata
1212
from executorch.exir._serialize.generated.executorch_flatbuffer.EValue import EValue
13+
from executorch.exir._serialize.generated.executorch_flatbuffer.NonConstBufferDevice import NonConstBufferDevice
1314
from executorch.exir._serialize.generated.executorch_flatbuffer.Operator import Operator
1415
from typing import Optional
1516
np = import_numpy()
@@ -230,8 +231,32 @@ def NonConstBufferSizesIsNone(self) -> bool:
230231
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
231232
return o == 0
232233

234+
# ExecutionPlan
235+
def NonConstBufferDevice(self, j: int) -> Optional[NonConstBufferDevice]:
236+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
237+
if o != 0:
238+
x = self._tab.Vector(o)
239+
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
240+
x = self._tab.Indirect(x)
241+
obj = NonConstBufferDevice()
242+
obj.Init(self._tab.Bytes, x)
243+
return obj
244+
return None
245+
246+
# ExecutionPlan
247+
def NonConstBufferDeviceLength(self) -> int:
248+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
249+
if o != 0:
250+
return self._tab.VectorLen(o)
251+
return 0
252+
253+
# ExecutionPlan
254+
def NonConstBufferDeviceIsNone(self) -> bool:
255+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
256+
return o == 0
257+
233258
def ExecutionPlanStart(builder: flatbuffers.Builder):
234-
builder.StartObject(9)
259+
builder.StartObject(10)
235260

236261
def Start(builder: flatbuffers.Builder):
237262
ExecutionPlanStart(builder)
@@ -332,6 +357,18 @@ def ExecutionPlanStartNonConstBufferSizesVector(builder, numElems: int) -> int:
332357
def StartNonConstBufferSizesVector(builder, numElems: int) -> int:
333358
return ExecutionPlanStartNonConstBufferSizesVector(builder, numElems)
334359

360+
def ExecutionPlanAddNonConstBufferDevice(builder: flatbuffers.Builder, nonConstBufferDevice: int):
361+
builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(nonConstBufferDevice), 0)
362+
363+
def AddNonConstBufferDevice(builder: flatbuffers.Builder, nonConstBufferDevice: int):
364+
ExecutionPlanAddNonConstBufferDevice(builder, nonConstBufferDevice)
365+
366+
def ExecutionPlanStartNonConstBufferDeviceVector(builder, numElems: int) -> int:
367+
return builder.StartVector(4, numElems, 4)
368+
369+
def StartNonConstBufferDeviceVector(builder, numElems: int) -> int:
370+
return ExecutionPlanStartNonConstBufferDeviceVector(builder, numElems)
371+
335372
def ExecutionPlanEnd(builder: flatbuffers.Builder) -> int:
336373
return builder.EndObject()
337374

@@ -342,6 +379,7 @@ def End(builder: flatbuffers.Builder) -> int:
342379
from executorch.exir._serialize.generated.executorch_flatbuffer import Chain
343380
from executorch.exir._serialize.generated.executorch_flatbuffer import ContainerMetadata
344381
from executorch.exir._serialize.generated.executorch_flatbuffer import EValue
382+
from executorch.exir._serialize.generated.executorch_flatbuffer import NonConstBufferDevice
345383
from executorch.exir._serialize.generated.executorch_flatbuffer import Operator
346384
try:
347385
from typing import List, Optional
@@ -361,6 +399,7 @@ def __init__(self):
361399
self.operators = None # type: List[executorch_flatbuffer.Operator.OperatorT]
362400
self.delegates = None # type: List[executorch_flatbuffer.BackendDelegate.BackendDelegateT]
363401
self.nonConstBufferSizes = None # type: List[int]
402+
self.nonConstBufferDevice = None # type: List[executorch_flatbuffer.NonConstBufferDevice.NonConstBufferDeviceT]
364403

365404
@classmethod
366405
def InitFromBuf(cls, buf, pos):
@@ -389,7 +428,8 @@ def __eq__(self, other):
389428
self.chains == other.chains and \
390429
self.operators == other.operators and \
391430
self.delegates == other.delegates and \
392-
self.nonConstBufferSizes == other.nonConstBufferSizes
431+
self.nonConstBufferSizes == other.nonConstBufferSizes and \
432+
self.nonConstBufferDevice == other.nonConstBufferDevice
393433

394434
# ExecutionPlanT
395435
def _UnPack(self, executionPlan):
@@ -451,6 +491,14 @@ def _UnPack(self, executionPlan):
451491
self.nonConstBufferSizes.append(executionPlan.NonConstBufferSizes(i))
452492
else:
453493
self.nonConstBufferSizes = executionPlan.NonConstBufferSizesAsNumpy()
494+
if not executionPlan.NonConstBufferDeviceIsNone():
495+
self.nonConstBufferDevice = []
496+
for i in range(executionPlan.NonConstBufferDeviceLength()):
497+
if executionPlan.NonConstBufferDevice(i) is None:
498+
self.nonConstBufferDevice.append(None)
499+
else:
500+
nonConstBufferDevice_ = executorch_flatbuffer.NonConstBufferDevice.NonConstBufferDeviceT.InitFromObj(executionPlan.NonConstBufferDevice(i))
501+
self.nonConstBufferDevice.append(nonConstBufferDevice_)
454502

455503
# ExecutionPlanT
456504
def Pack(self, builder):
@@ -514,6 +562,14 @@ def Pack(self, builder):
514562
for i in reversed(range(len(self.nonConstBufferSizes))):
515563
builder.PrependInt64(self.nonConstBufferSizes[i])
516564
nonConstBufferSizes = builder.EndVector()
565+
if self.nonConstBufferDevice is not None:
566+
nonConstBufferDevicelist = []
567+
for i in range(len(self.nonConstBufferDevice)):
568+
nonConstBufferDevicelist.append(self.nonConstBufferDevice[i].Pack(builder))
569+
ExecutionPlanStartNonConstBufferDeviceVector(builder, len(self.nonConstBufferDevice))
570+
for i in reversed(range(len(self.nonConstBufferDevice))):
571+
builder.PrependUOffsetTRelative(nonConstBufferDevicelist[i])
572+
nonConstBufferDevice = builder.EndVector()
517573
ExecutionPlanStart(builder)
518574
if self.name is not None:
519575
ExecutionPlanAddName(builder, name)
@@ -533,5 +589,7 @@ def Pack(self, builder):
533589
ExecutionPlanAddDelegates(builder, delegates)
534590
if self.nonConstBufferSizes is not None:
535591
ExecutionPlanAddNonConstBufferSizes(builder, nonConstBufferSizes)
592+
if self.nonConstBufferDevice is not None:
593+
ExecutionPlanAddNonConstBufferDevice(builder, nonConstBufferDevice)
536594
executionPlan = ExecutionPlanEnd(builder)
537595
return executionPlan
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# automatically generated by the FlatBuffers compiler, do not modify
2+
3+
# namespace: executorch_flatbuffer
4+
5+
import flatbuffers
6+
from flatbuffers.compat import import_numpy
7+
from typing import Any
8+
np = import_numpy()
9+
10+
class NonConstBufferDevice(object):
11+
__slots__ = ['_tab']
12+
13+
@classmethod
14+
def GetRootAs(cls, buf, offset: int = 0):
15+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
16+
x = NonConstBufferDevice()
17+
x.Init(buf, n + offset)
18+
return x
19+
20+
@classmethod
21+
def GetRootAsNonConstBufferDevice(cls, buf, offset=0):
22+
"""This method is deprecated. Please switch to GetRootAs."""
23+
return cls.GetRootAs(buf, offset)
24+
@classmethod
25+
def NonConstBufferDeviceBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
26+
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x45\x54\x31\x32", size_prefixed=size_prefixed)
27+
28+
# NonConstBufferDevice
29+
def Init(self, buf: bytes, pos: int):
30+
self._tab = flatbuffers.table.Table(buf, pos)
31+
32+
# NonConstBufferDevice
33+
def BufferIdx(self):
34+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
35+
if o != 0:
36+
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
37+
return 0
38+
39+
# NonConstBufferDevice
40+
def DeviceType(self):
41+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
42+
if o != 0:
43+
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
44+
return 0
45+
46+
# NonConstBufferDevice
47+
def DeviceIndex(self):
48+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
49+
if o != 0:
50+
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
51+
return 0
52+
53+
def NonConstBufferDeviceStart(builder: flatbuffers.Builder):
54+
builder.StartObject(3)
55+
56+
def Start(builder: flatbuffers.Builder):
57+
NonConstBufferDeviceStart(builder)
58+
59+
def NonConstBufferDeviceAddBufferIdx(builder: flatbuffers.Builder, bufferIdx: int):
60+
builder.PrependInt32Slot(0, bufferIdx, 0)
61+
62+
def AddBufferIdx(builder: flatbuffers.Builder, bufferIdx: int):
63+
NonConstBufferDeviceAddBufferIdx(builder, bufferIdx)
64+
65+
def NonConstBufferDeviceAddDeviceType(builder: flatbuffers.Builder, deviceType: int):
66+
builder.PrependInt8Slot(1, deviceType, 0)
67+
68+
def AddDeviceType(builder: flatbuffers.Builder, deviceType: int):
69+
NonConstBufferDeviceAddDeviceType(builder, deviceType)
70+
71+
def NonConstBufferDeviceAddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int):
72+
builder.PrependInt8Slot(2, deviceIndex, 0)
73+
74+
def AddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int):
75+
NonConstBufferDeviceAddDeviceIndex(builder, deviceIndex)
76+
77+
def NonConstBufferDeviceEnd(builder: flatbuffers.Builder) -> int:
78+
return builder.EndObject()
79+
80+
def End(builder: flatbuffers.Builder) -> int:
81+
return NonConstBufferDeviceEnd(builder)
82+
83+
84+
class NonConstBufferDeviceT(object):
85+
86+
# NonConstBufferDeviceT
87+
def __init__(self):
88+
self.bufferIdx = 0 # type: int
89+
self.deviceType = 0 # type: int
90+
self.deviceIndex = 0 # type: int
91+
92+
@classmethod
93+
def InitFromBuf(cls, buf, pos):
94+
nonConstBufferDevice = NonConstBufferDevice()
95+
nonConstBufferDevice.Init(buf, pos)
96+
return cls.InitFromObj(nonConstBufferDevice)
97+
98+
@classmethod
99+
def InitFromPackedBuf(cls, buf, pos=0):
100+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
101+
return cls.InitFromBuf(buf, pos+n)
102+
103+
@classmethod
104+
def InitFromObj(cls, nonConstBufferDevice):
105+
x = NonConstBufferDeviceT()
106+
x._UnPack(nonConstBufferDevice)
107+
return x
108+
109+
def __eq__(self, other):
110+
return type(self) == type(other) and \
111+
self.bufferIdx == other.bufferIdx and \
112+
self.deviceType == other.deviceType and \
113+
self.deviceIndex == other.deviceIndex
114+
115+
# NonConstBufferDeviceT
116+
def _UnPack(self, nonConstBufferDevice):
117+
if nonConstBufferDevice is None:
118+
return
119+
self.bufferIdx = nonConstBufferDevice.BufferIdx()
120+
self.deviceType = nonConstBufferDevice.DeviceType()
121+
self.deviceIndex = nonConstBufferDevice.DeviceIndex()
122+
123+
# NonConstBufferDeviceT
124+
def Pack(self, builder):
125+
NonConstBufferDeviceStart(builder)
126+
NonConstBufferDeviceAddBufferIdx(builder, self.bufferIdx)
127+
NonConstBufferDeviceAddDeviceType(builder, self.deviceType)
128+
NonConstBufferDeviceAddDeviceIndex(builder, self.deviceIndex)
129+
nonConstBufferDevice = NonConstBufferDeviceEnd(builder)
130+
return nonConstBufferDevice

exir/_serialize/generated/executorch_flatbuffer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from . import KernelTypes
3232
from . import MoveCall
3333
from . import NamedData
34+
from . import NonConstBufferDevice
3435
from . import Null
3536
from . import Operator
3637
from . import OptionalTensorList
@@ -75,6 +76,7 @@
7576
"KernelTypes",
7677
"MoveCall",
7778
"NamedData",
79+
"NonConstBufferDevice",
7880
"Null",
7981
"Operator",
8082
"OptionalTensorList",

exir/_serialize/test/test_program.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838
ContainerMetadata,
3939
DataLocation,
4040
DataSegment,
41+
DeviceType,
4142
ExecutionPlan,
43+
NonConstBufferDevice,
4244
Program,
4345
SubsegmentOffsets,
4446
)
@@ -477,6 +479,32 @@ def test_round_trip_large_buffer_sizes(self) -> None:
477479
program, deserialize_pte_binary(flatbuffer_from_py).program
478480
)
479481

482+
def test_round_trip_with_non_const_buffer_device(self) -> None:
483+
"""Tests that non_const_buffer_device survives round-trip
484+
serialization/deserialization. This verifies the schema extension
485+
for per-buffer device mapping works correctly.
486+
"""
487+
program = get_test_program()
488+
program.execution_plan[0].non_const_buffer_device = [
489+
NonConstBufferDevice(buffer_idx=0, device_type=DeviceType.CPU, device_index=0),
490+
NonConstBufferDevice(buffer_idx=1, device_type=DeviceType.CUDA, device_index=0),
491+
]
492+
flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program)))
493+
self.assert_programs_equal(
494+
program, deserialize_pte_binary(flatbuffer_from_py).program
495+
)
496+
497+
def test_round_trip_without_non_const_buffer_device(self) -> None:
498+
"""Tests backward compatibility: a program without non_const_buffer_device
499+
(the default) round-trips correctly and the field remains None.
500+
"""
501+
program = get_test_program()
502+
self.assertIsNone(program.execution_plan[0].non_const_buffer_device)
503+
flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program)))
504+
deserialized = deserialize_pte_binary(flatbuffer_from_py).program
505+
self.assert_programs_equal(program, deserialized)
506+
self.assertIsNone(deserialized.execution_plan[0].non_const_buffer_device)
507+
480508
def test_round_trip_no_segments_and_no_header(self) -> None:
481509
"""Tests that a Program serialized with extract_delegate_segments=True
482510
when there are no segments does not contain an extended header,

exir/schema.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,18 @@ class Operator:
268268
overload: str
269269

270270

271+
@dataclass
272+
class NonConstBufferDevice:
273+
"""Maps a non-constant buffer to the device where it should be allocated."""
274+
275+
# Index into the non_const_buffer_sizes list.
276+
buffer_idx: int = 0
277+
# The device type for this buffer (CPU, CUDA, etc.).
278+
device_type: DeviceType = DeviceType.CPU
279+
# The device index for multi-device scenarios (e.g., cuda:0, cuda:1).
280+
device_index: int = 0
281+
282+
271283
@dataclass
272284
class ExecutionPlan:
273285
name: str
@@ -283,6 +295,9 @@ class ExecutionPlan:
283295
# Runtime should use the len(constant_buffer) as the ground truth of
284296
# constant memory buffer size, and ignore non_const_buffer_sizes[0].
285297
non_const_buffer_sizes: List[int]
298+
# Per-buffer device mapping. Each entry maps a non-constant buffer to the
299+
# device where it should be allocated. For CPU-only programs, this is None.
300+
non_const_buffer_device: Optional[List[NonConstBufferDevice]] = None
286301

287302

288303
@dataclass

schema/program.fbs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,27 @@ table ExecutionPlan {
401401
// constants memory buffer size, and ignore non_const_buffer_sizes[0].
402402
non_const_buffer_sizes: [int64];
403403

404+
// [Optional] Per-buffer device mapping, parallel to non_const_buffer_sizes.
405+
// Each entry maps a non-constant buffer to the device where it should be
406+
// allocated. For CPU-only programs, this field is absent and all buffers
407+
// default to CPU, ensuring zero binary size regression.
408+
non_const_buffer_device: [NonConstBufferDevice];
409+
410+
}
411+
412+
// Maps a non-constant buffer to the device where it should be allocated.
413+
// When present as part of ExecutionPlan.non_const_buffer_device, each entry
414+
// describes the device placement for the corresponding planned memory buffer.
415+
// For CPU-only programs, this table is absent (all buffers default to CPU).
416+
table NonConstBufferDevice {
417+
// Index into the non_const_buffer_sizes list.
418+
buffer_idx: int;
419+
420+
// The device type for this buffer (CPU, CUDA, etc.).
421+
device_type: DeviceType = CPU;
422+
423+
// The device index for multi-device scenarios (e.g., cuda:0, cuda:1).
424+
device_index: byte = 0;
404425
}
405426

406427
// Constant tensor data stored directly in the flatbuffer.

0 commit comments

Comments
 (0)