Skip to content

Commit ba67f10

Browse files
committed
[ET Device Support] Schema changes: device info on Tensor
Pull Request resolved: #17533

This diff adds device placement information to the ExecuTorch schema so that device type can be represented at the tensor level. This is a prerequisite for the upcoming tensor_parser updates, and is part of the Phase 1 implementation to make ET device type work end-to-end (E2E) without user-specified device placement.

Design doc: https://docs.google.com/document/d/1lwd9BlohmwkN5EEvRulO_b-XnZBwv1nMb5l2K3jfuwA/edit?tab=t.0#heading=h.o6anuvkix4bu

ghstack-source-id: 354940210
@exported-using-ghexport

Differential Revision: [D93635657](https://our.internmc.facebook.com/intern/diff/D93635657/)
1 parent 9076110 commit ba67f10

5 files changed

Lines changed: 67 additions & 2 deletions

File tree

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# automatically generated by the FlatBuffers compiler, do not modify
2+
3+
# namespace: executorch_flatbuffer
4+
5+
class DeviceType(object):
6+
CPU = 0
7+
CUDA = 1

exir/_serialize/generated/executorch_flatbuffer/ExtraTensorInfo.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,22 @@ def Location(self):
5151
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
5252
return 0
5353

54+
# ExtraTensorInfo
55+
def DeviceType(self):
56+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
57+
if o != 0:
58+
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
59+
return 0
60+
61+
# ExtraTensorInfo
62+
def DeviceIndex(self):
63+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
64+
if o != 0:
65+
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
66+
return 0
67+
5468
def ExtraTensorInfoStart(builder: flatbuffers.Builder):
55-
builder.StartObject(3)
69+
builder.StartObject(5)
5670

5771
def Start(builder: flatbuffers.Builder):
5872
ExtraTensorInfoStart(builder)
@@ -75,6 +89,18 @@ def ExtraTensorInfoAddLocation(builder: flatbuffers.Builder, location: int):
7589
def AddLocation(builder: flatbuffers.Builder, location: int):
7690
ExtraTensorInfoAddLocation(builder, location)
7791

92+
def ExtraTensorInfoAddDeviceType(builder: flatbuffers.Builder, deviceType: int):
93+
builder.PrependInt8Slot(3, deviceType, 0)
94+
95+
def AddDeviceType(builder: flatbuffers.Builder, deviceType: int):
96+
ExtraTensorInfoAddDeviceType(builder, deviceType)
97+
98+
def ExtraTensorInfoAddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int):
99+
builder.PrependInt8Slot(4, deviceIndex, 0)
100+
101+
def AddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int):
102+
ExtraTensorInfoAddDeviceIndex(builder, deviceIndex)
103+
78104
def ExtraTensorInfoEnd(builder: flatbuffers.Builder) -> int:
79105
return builder.EndObject()
80106

@@ -89,6 +115,8 @@ def __init__(self):
89115
self.mutableDataSegmentsIdx = 0 # type: int
90116
self.fullyQualifiedName = None # type: str
91117
self.location = 0 # type: int
118+
self.deviceType = 0 # type: int
119+
self.deviceIndex = 0 # type: int
92120

93121
@classmethod
94122
def InitFromBuf(cls, buf, pos):
@@ -111,7 +139,9 @@ def __eq__(self, other):
111139
return type(self) == type(other) and \
112140
self.mutableDataSegmentsIdx == other.mutableDataSegmentsIdx and \
113141
self.fullyQualifiedName == other.fullyQualifiedName and \
114-
self.location == other.location
142+
self.location == other.location and \
143+
self.deviceType == other.deviceType and \
144+
self.deviceIndex == other.deviceIndex
115145

116146
# ExtraTensorInfoT
117147
def _UnPack(self, extraTensorInfo):
@@ -120,6 +150,8 @@ def _UnPack(self, extraTensorInfo):
120150
self.mutableDataSegmentsIdx = extraTensorInfo.MutableDataSegmentsIdx()
121151
self.fullyQualifiedName = extraTensorInfo.FullyQualifiedName()
122152
self.location = extraTensorInfo.Location()
153+
self.deviceType = extraTensorInfo.DeviceType()
154+
self.deviceIndex = extraTensorInfo.DeviceIndex()
123155

124156
# ExtraTensorInfoT
125157
def Pack(self, builder):
@@ -130,5 +162,7 @@ def Pack(self, builder):
130162
if self.fullyQualifiedName is not None:
131163
ExtraTensorInfoAddFullyQualifiedName(builder, fullyQualifiedName)
132164
ExtraTensorInfoAddLocation(builder, self.location)
165+
ExtraTensorInfoAddDeviceType(builder, self.deviceType)
166+
ExtraTensorInfoAddDeviceIndex(builder, self.deviceIndex)
133167
extraTensorInfo = ExtraTensorInfoEnd(builder)
134168
return extraTensorInfo

exir/_serialize/generated/executorch_flatbuffer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from . import DataLocation
1414
from . import DataSegment
1515
from . import DelegateCall
16+
from . import DeviceType
1617
from . import Double
1718
from . import DoubleList
1819
from . import EValue
@@ -56,6 +57,7 @@
5657
"DataLocation",
5758
"DataSegment",
5859
"DelegateCall",
60+
"DeviceType",
5961
"Double",
6062
"DoubleList",
6163
"EValue",

exir/schema.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ class TensorDataLocation(IntEnum):
4848
EXTERNAL = 1
4949

5050

51+
class DeviceType(IntEnum):
52+
CPU = 0
53+
CUDA = 1
54+
55+
5156
@dataclass
5257
class ExtraTensorInfo:
5358
"""
@@ -57,6 +62,8 @@ class ExtraTensorInfo:
5762
mutable_data_segments_idx: int = 0
5863
fully_qualified_name: Optional[str] = None
5964
location: TensorDataLocation = TensorDataLocation.SEGMENT
65+
device_type: DeviceType = DeviceType.CPU
66+
device_index: int = 0
6067

6168

6269
@dataclass

schema/program.fbs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ enum TensorDataLocation : byte {
6262
EXTERNAL = 1,
6363
}
6464

65+
// Device type enum indicating where a tensor resides or should be allocated.
66+
// Please keep this in sync with executorch/runtime/core/portable_type/device.h
67+
enum DeviceType : byte {
68+
CPU = 0,
69+
CUDA = 1,
70+
}
71+
6572
// Table to put additional information about tensors in that is not applicable
6673
// to the vast majority of tensors in the vast majority of programs.
6774
table ExtraTensorInfo {
@@ -80,6 +87,14 @@ table ExtraTensorInfo {
8087
// must be non-empty, and is used as a key to find the tensor's external
8188
// data. Tensor.data_buffer_idx is ignored.
8289
location: TensorDataLocation;
90+
91+
// [Optional] The device type where this tensor resides or should be allocated.
92+
// Defaults to CPU for backward compatibility with existing PTE files.
93+
device_type: DeviceType = CPU;
94+
95+
// [Optional] The device index for multi-device scenarios (e.g., cuda:0, cuda:1).
96+
// Defaults to 0 (the first device of the given type).
97+
device_index: byte = 0;
8398
}
8499

85100
table Tensor {

0 commit comments

Comments
 (0)