diff --git a/exir/_serialize/generated/executorch_flatbuffer/DeviceType.py b/exir/_serialize/generated/executorch_flatbuffer/DeviceType.py new file mode 100644 index 00000000000..1d17205a47c --- /dev/null +++ b/exir/_serialize/generated/executorch_flatbuffer/DeviceType.py @@ -0,0 +1,7 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: executorch_flatbuffer + +class DeviceType(object): + CPU = 0 + CUDA = 1 diff --git a/exir/_serialize/generated/executorch_flatbuffer/ExtraTensorInfo.py b/exir/_serialize/generated/executorch_flatbuffer/ExtraTensorInfo.py index 7622f49821b..d2c00382067 100644 --- a/exir/_serialize/generated/executorch_flatbuffer/ExtraTensorInfo.py +++ b/exir/_serialize/generated/executorch_flatbuffer/ExtraTensorInfo.py @@ -51,8 +51,22 @@ def Location(self): return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) return 0 + # ExtraTensorInfo + def DeviceType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + + # ExtraTensorInfo + def DeviceIndex(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) + return 0 + def ExtraTensorInfoStart(builder: flatbuffers.Builder): - builder.StartObject(3) + builder.StartObject(5) def Start(builder: flatbuffers.Builder): ExtraTensorInfoStart(builder) @@ -75,6 +89,18 @@ def ExtraTensorInfoAddLocation(builder: flatbuffers.Builder, location: int): def AddLocation(builder: flatbuffers.Builder, location: int): ExtraTensorInfoAddLocation(builder, location) +def ExtraTensorInfoAddDeviceType(builder: flatbuffers.Builder, deviceType: int): + builder.PrependInt8Slot(3, deviceType, 0) + +def AddDeviceType(builder: flatbuffers.Builder, deviceType: int): + ExtraTensorInfoAddDeviceType(builder, deviceType) + +def ExtraTensorInfoAddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int): + builder.PrependInt8Slot(4, deviceIndex, 0) + +def AddDeviceIndex(builder: flatbuffers.Builder, deviceIndex: int): + ExtraTensorInfoAddDeviceIndex(builder, deviceIndex) + def ExtraTensorInfoEnd(builder: flatbuffers.Builder) -> int: return builder.EndObject() @@ -89,6 +115,8 @@ def __init__(self): self.mutableDataSegmentsIdx = 0 # type: int self.fullyQualifiedName = None # type: str self.location = 0 # type: int + self.deviceType = 0 # type: int + self.deviceIndex = 0 # type: int @classmethod def InitFromBuf(cls, buf, pos): @@ -111,7 +139,9 @@ def __eq__(self, other): return type(self) == type(other) and \ self.mutableDataSegmentsIdx == other.mutableDataSegmentsIdx and \ self.fullyQualifiedName == other.fullyQualifiedName and \ - self.location == other.location + self.location == other.location and \ + self.deviceType == other.deviceType and \ + self.deviceIndex == other.deviceIndex # ExtraTensorInfoT def _UnPack(self, extraTensorInfo): @@ -120,6 +150,8 @@ def _UnPack(self, extraTensorInfo): self.mutableDataSegmentsIdx = extraTensorInfo.MutableDataSegmentsIdx() self.fullyQualifiedName = extraTensorInfo.FullyQualifiedName() self.location = extraTensorInfo.Location() + self.deviceType = extraTensorInfo.DeviceType() + self.deviceIndex = extraTensorInfo.DeviceIndex() # ExtraTensorInfoT def Pack(self, builder): @@ -130,5 +162,7 @@ def Pack(self, builder): if self.fullyQualifiedName is not None: ExtraTensorInfoAddFullyQualifiedName(builder, fullyQualifiedName) ExtraTensorInfoAddLocation(builder, self.location) + ExtraTensorInfoAddDeviceType(builder, self.deviceType) + ExtraTensorInfoAddDeviceIndex(builder, self.deviceIndex) extraTensorInfo = ExtraTensorInfoEnd(builder) return extraTensorInfo diff --git a/exir/_serialize/generated/executorch_flatbuffer/__init__.py b/exir/_serialize/generated/executorch_flatbuffer/__init__.py index ee27d60361c..df59751e724 100644 --- a/exir/_serialize/generated/executorch_flatbuffer/__init__.py +++ b/exir/_serialize/generated/executorch_flatbuffer/__init__.py @@ -13,6 +13,7 @@ from . import DataLocation from . import DataSegment from . import DelegateCall +from . import DeviceType from . import Double from . import DoubleList from . import EValue @@ -56,6 +57,7 @@ "DataLocation", "DataSegment", "DelegateCall", + "DeviceType", "Double", "DoubleList", "EValue", diff --git a/exir/schema.py b/exir/schema.py index 7dba623aebf..993a473dabb 100644 --- a/exir/schema.py +++ b/exir/schema.py @@ -48,6 +48,11 @@ class TensorDataLocation(IntEnum): EXTERNAL = 1 +class DeviceType(IntEnum): + CPU = 0 + CUDA = 1 + + @dataclass class ExtraTensorInfo: """ @@ -57,6 +62,8 @@ class ExtraTensorInfo: mutable_data_segments_idx: int = 0 fully_qualified_name: Optional[str] = None location: TensorDataLocation = TensorDataLocation.SEGMENT + device_type: DeviceType = DeviceType.CPU + device_index: int = 0 @dataclass diff --git a/schema/program.fbs b/schema/program.fbs index c177d60fd4c..ae95c56fa96 100644 --- a/schema/program.fbs +++ b/schema/program.fbs @@ -62,6 +62,13 @@ enum TensorDataLocation : byte { EXTERNAL = 1, } +// Device type enum indicating where a tensor resides or should be allocated. +// Please keep this in sync with executorch/runtime/core/portable_type/device.h +enum DeviceType : byte { + CPU = 0, + CUDA = 1, +} + // Table to put additional information about tensors in that is not applicable // to the vast majority of tensors in the vast majority of programs. table ExtraTensorInfo { @@ -80,6 +87,14 @@ table ExtraTensorInfo { // must be non-empty, and is used as a key to find the tensor's external // data. Tensor.data_buffer_idx is ignored. location: TensorDataLocation; + + // [Optional] The device type where this tensor resides or should be allocated. + // Defaults to CPU for backward compatibility with existing PTE files. + device_type: DeviceType = CPU; + + // [Optional] The device index for multi-device scenarios (e.g., cuda:0, cuda:1). + // Defaults to 0 (the first device of the given type). + device_index: byte = 0; } table Tensor {