Skip to content

Commit 9303b07

Browse files
committed
[ET Device Support] TensorImpl carries device info
Pull Request resolved: #17534

This diff extends `TensorImpl` to carry device information, enabling the runtime tensor to track which device its data resides on (CPU, CUDA, etc.). This is a prerequisite for parsing device info from the schema and allocating device memory.

ghstack-source-id: 354946057
@exported-using-ghexport
Differential Revision: [D93635655](https://our.internmc.facebook.com/intern/diff/D93635655/)
1 parent ba67f10 commit 9303b07

5 files changed

Lines changed: 190 additions & 10 deletions

File tree

runtime/core/portable_type/device.h

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@ enum class DeviceType : int8_t {
2626
constexpr size_t kNumDeviceTypes = 2;
2727

2828
/// An index representing a specific device; e.g. GPU 0 vs GPU 1.
29-
/// -1 means the default/unspecified device for that type.
3029
using DeviceIndex = int8_t;
3130

3231
/**
@@ -41,7 +40,7 @@ struct Device final {
4140

4241
/// Constructs a new `Device` from a `DeviceType` and an optional device
4342
/// index.
44-
/* implicit */ Device(DeviceType type, DeviceIndex index = -1)
43+
/* implicit */ Device(DeviceType type, DeviceIndex index = 0)
4544
: type_(type), index_(index) {}
4645

4746
/// Returns the type of device the tensor data resides on.
@@ -54,7 +53,7 @@ struct Device final {
5453
return type_ == DeviceType::CPU;
5554
}
5655

57-
/// Returns the device index, or -1 if default/unspecified.
56+
/// Returns the device index.
5857
DeviceIndex index() const noexcept {
5958
return index_;
6059
}
@@ -69,7 +68,7 @@ struct Device final {
6968

7069
private:
7170
DeviceType type_;
72-
DeviceIndex index_ = -1;
71+
DeviceIndex index_ = 0;
7372
};
7473

7574
} // namespace etensor

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ TensorImpl::TensorImpl(
5050
void* data,
5151
DimOrderType* dim_order,
5252
StridesType* strides,
53-
TensorShapeDynamism dynamism)
53+
TensorShapeDynamism dynamism,
54+
DeviceType device_type,
55+
DeviceIndex device_index)
5456
: sizes_(sizes),
5557
dim_order_(dim_order),
5658
strides_(strides),
@@ -59,7 +61,8 @@ TensorImpl::TensorImpl(
5961
numel_(compute_numel(sizes, dim)),
6062
numel_bound_(numel_),
6163
type_(type),
62-
shape_dynamism_(dynamism) {
64+
shape_dynamism_(dynamism),
65+
device_(device_type, device_index) {
6366
ET_CHECK_MSG(
6467
isValid(type_), "Invalid type %" PRId8, static_cast<int8_t>(type_));
6568
ET_CHECK_MSG(dim_ >= 0, "Dimension must be non-negative, got %zd", dim_);

runtime/core/portable_type/tensor_impl.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/runtime/core/array_ref.h>
1212
#include <executorch/runtime/core/error.h>
13+
#include <executorch/runtime/core/portable_type/device.h>
1314
#include <executorch/runtime/core/portable_type/scalar_type.h>
1415
#include <executorch/runtime/core/tensor_shape_dynamism.h>
1516

@@ -99,6 +100,8 @@ class TensorImpl {
99100
* @param strides Strides of the tensor at each dimension. Must contain `dim`
100101
* entries.
101102
* @param dynamism The mutability of the shape of the tensor.
103+
* @param device_type The type of device where tensor data resides.
104+
* @param device_index The device index for multi-device scenarios.
102105
*/
103106
TensorImpl(
104107
ScalarType type,
@@ -107,7 +110,9 @@ class TensorImpl {
107110
void* data = nullptr,
108111
DimOrderType* dim_order = nullptr,
109112
StridesType* strides = nullptr,
110-
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC);
113+
TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC,
114+
DeviceType device_type = DeviceType::CPU,
115+
DeviceIndex device_index = 0);
111116

112117
/**
113118
* Returns the size of the tensor in bytes.
@@ -176,6 +181,21 @@ class TensorImpl {
176181
return shape_dynamism_;
177182
}
178183

184+
/// Returns the device where tensor data resides.
185+
Device device() const {
186+
return device_;
187+
}
188+
189+
/// Returns the type of device where tensor data resides.
190+
DeviceType device_type() const {
191+
return device_.type();
192+
}
193+
194+
/// Returns the device index (0 unless another index was given at construction).
195+
DeviceIndex device_index() const {
196+
return device_.index();
197+
}
198+
179199
/// Returns a pointer of type T to the constant underlying data blob.
180200
template <typename T>
181201
inline const T* data() const {
@@ -261,6 +281,9 @@ class TensorImpl {
261281

262282
/// Specifies the mutability of the shape of the tensor.
263283
const TensorShapeDynamism shape_dynamism_;
284+
285+
/// Device where tensor data resides (CPU, CUDA, etc.)
286+
Device device_;
264287
};
265288

266289
/**

runtime/core/portable_type/test/device_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ TEST(DeviceTest, CpuDefaultIndex) {
3434
Device d(DeviceType::CPU);
3535
EXPECT_TRUE(d.is_cpu());
3636
EXPECT_EQ(d.type(), DeviceType::CPU);
37-
EXPECT_EQ(d.index(), -1);
37+
EXPECT_EQ(d.index(), 0);
3838
}
3939

4040
TEST(DeviceTest, CpuExplicitIndex) {
@@ -49,7 +49,7 @@ TEST(DeviceTest, CudaDefaultIndex) {
4949
Device d(DeviceType::CUDA);
5050
EXPECT_FALSE(d.is_cpu());
5151
EXPECT_EQ(d.type(), DeviceType::CUDA);
52-
EXPECT_EQ(d.index(), -1);
52+
EXPECT_EQ(d.index(), 0);
5353
}
5454

5555
TEST(DeviceTest, CudaExplicitIndex) {
@@ -83,7 +83,7 @@ TEST(DeviceTest, EqualityDefaultIndices) {
8383
TEST(DeviceTest, ImplicitConstructionFromDeviceType) {
8484
// Device constructor is implicit, allowing DeviceType → Device conversion.
8585
Device d = DeviceType::CUDA;
86-
EXPECT_EQ(d.index(), -1);
86+
EXPECT_EQ(d.index(), 0);
8787
}
8888

8989
// --- Deprecated namespace aliases ---

runtime/core/portable_type/test/tensor_impl_test.cpp

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ using namespace ::testing;
2121
using executorch::runtime::ArrayRef;
2222
using executorch::runtime::Error;
2323
using executorch::runtime::TensorShapeDynamism;
24+
using executorch::runtime::etensor::Device;
25+
using executorch::runtime::etensor::DeviceIndex;
26+
using executorch::runtime::etensor::DeviceType;
2427
using executorch::runtime::etensor::ScalarType;
2528
using executorch::runtime::etensor::TensorImpl;
2629
using SizesType = TensorImpl::SizesType;
@@ -449,3 +452,155 @@ TEST_F(TensorImplTest, TestResizingTensorToZeroAndBack) {
449452
EXPECT_GT(t.numel(), 0);
450453
EXPECT_EQ(t.data(), data);
451454
}
455+
456+
// ============== Size Tests ==============
457+
458+
TEST_F(TensorImplTest, TestTensorImplSize) {
459+
// Verify TensorImpl size hasn't regressed after adding Device member.
460+
// Device (2 bytes) fits within existing padding after type_ and
461+
// shape_dynamism_, so sizeof(TensorImpl) should remain unchanged.
462+
//
463+
// Memory layout (64-bit):
464+
// sizes_ : 8 bytes (pointer)
465+
// dim_order_ : 8 bytes (pointer)
466+
// strides_ : 8 bytes (pointer)
467+
// data_ : 8 bytes (pointer)
468+
// dim_ : 8 bytes (ssize_t)
469+
// numel_ : 8 bytes (ssize_t)
470+
// numel_bound_ : 8 bytes (size_t)
471+
// type_ : 1 byte (ScalarType : int8_t)
472+
// shape_dynamism_ : 1 byte (TensorShapeDynamism : uint8_t)
473+
// device_ : 2 bytes (Device: DeviceType + DeviceIndex)
474+
// padding : 4 bytes (to align struct to 8 bytes)
475+
// Total : 64 bytes
476+
//
477+
// Memory layout (32-bit):
478+
// sizes_ : 4 bytes (pointer)
479+
// dim_order_ : 4 bytes (pointer)
480+
// strides_ : 4 bytes (pointer)
481+
// data_ : 4 bytes (pointer)
482+
// dim_ : 4 bytes (ssize_t)
483+
// numel_ : 4 bytes (ssize_t)
484+
// numel_bound_ : 4 bytes (size_t)
485+
// type_ : 1 byte (ScalarType : int8_t)
486+
// shape_dynamism_ : 1 byte (TensorShapeDynamism : uint8_t)
487+
// device_ : 2 bytes (Device: DeviceType + DeviceIndex)
488+
// Total : 32 bytes (no additional padding needed)
489+
490+
#if INTPTR_MAX == INT64_MAX
491+
// 64-bit architecture
492+
EXPECT_EQ(sizeof(TensorImpl), 64);
493+
#else
494+
// 32-bit architecture
495+
EXPECT_EQ(sizeof(TensorImpl), 32);
496+
#endif
497+
}
498+
499+
// ============== Device Tests ==============
500+
501+
TEST_F(TensorImplTest, TestDefaultDeviceIsCPU) {
502+
// TensorImpl constructed without device parameters should default to CPU
503+
SizesType sizes[2] = {3, 2};
504+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
505+
TensorImpl t(ScalarType::Float, 2, sizes, data);
506+
507+
EXPECT_EQ(t.device_type(), DeviceType::CPU);
508+
EXPECT_EQ(t.device_index(), 0);
509+
EXPECT_EQ(t.device(), Device(DeviceType::CPU, 0));
510+
}
511+
512+
TEST_F(TensorImplTest, TestExplicitCPUDevice) {
513+
// TensorImpl constructed with explicit CPU device
514+
SizesType sizes[2] = {3, 2};
515+
DimOrderType dim_order[2] = {0, 1};
516+
StridesType strides[2] = {2, 1};
517+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
518+
TensorImpl t(
519+
ScalarType::Float,
520+
2,
521+
sizes,
522+
data,
523+
dim_order,
524+
strides,
525+
TensorShapeDynamism::STATIC,
526+
DeviceType::CPU,
527+
0);
528+
529+
EXPECT_EQ(t.device_type(), DeviceType::CPU);
530+
EXPECT_EQ(t.device_index(), 0);
531+
EXPECT_EQ(t.device(), Device(DeviceType::CPU, 0));
532+
}
533+
534+
TEST_F(TensorImplTest, TestCUDADevice) {
535+
// TensorImpl constructed with CUDA device
536+
SizesType sizes[2] = {3, 2};
537+
DimOrderType dim_order[2] = {0, 1};
538+
StridesType strides[2] = {2, 1};
539+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
540+
TensorImpl t(
541+
ScalarType::Float,
542+
2,
543+
sizes,
544+
data,
545+
dim_order,
546+
strides,
547+
TensorShapeDynamism::STATIC,
548+
DeviceType::CUDA,
549+
0);
550+
551+
EXPECT_EQ(t.device_type(), DeviceType::CUDA);
552+
EXPECT_EQ(t.device_index(), 0);
553+
EXPECT_EQ(t.device(), Device(DeviceType::CUDA, 0));
554+
}
555+
556+
TEST_F(TensorImplTest, TestCUDADeviceMultiGPU) {
557+
// TensorImpl with CUDA device index 1 (second GPU)
558+
SizesType sizes[2] = {3, 2};
559+
DimOrderType dim_order[2] = {0, 1};
560+
StridesType strides[2] = {2, 1};
561+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
562+
TensorImpl t(
563+
ScalarType::Float,
564+
2,
565+
sizes,
566+
data,
567+
dim_order,
568+
strides,
569+
TensorShapeDynamism::STATIC,
570+
DeviceType::CUDA,
571+
1);
572+
573+
EXPECT_EQ(t.device_type(), DeviceType::CUDA);
574+
EXPECT_EQ(t.device_index(), 1);
575+
EXPECT_EQ(t.device(), Device(DeviceType::CUDA, 1));
576+
}
577+
578+
TEST_F(TensorImplTest, TestDeviceWithDynamicTensor) {
579+
// Device info should work correctly with dynamic tensors
580+
SizesType sizes[2] = {3, 2};
581+
DimOrderType dim_order[2] = {0, 1};
582+
StridesType strides[2] = {2, 1};
583+
float data[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
584+
TensorImpl t(
585+
ScalarType::Float,
586+
2,
587+
sizes,
588+
data,
589+
dim_order,
590+
strides,
591+
TensorShapeDynamism::DYNAMIC_BOUND,
592+
DeviceType::CUDA,
593+
0);
594+
595+
EXPECT_EQ(t.device_type(), DeviceType::CUDA);
596+
EXPECT_EQ(t.device_index(), 0);
597+
598+
// Resize should not affect device
599+
SizesType new_sizes[2] = {2, 2};
600+
Error err = resize_tensor_impl(&t, {new_sizes, 2});
601+
EXPECT_EQ(err, Error::Ok);
602+
603+
// Device should remain unchanged after resize
604+
EXPECT_EQ(t.device_type(), DeviceType::CUDA);
605+
EXPECT_EQ(t.device_index(), 0);
606+
}

0 commit comments

Comments
 (0)