Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions cpp/src/arrow/flight/flight_internals_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "arrow/buffer.h"
#include "arrow/flight/client_cookie_middleware.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/cookie_internal.h"
#include "arrow/flight/serialization_internal.h"
#include "arrow/flight/server.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/transport.h"
#include "arrow/flight/transport/grpc/util_internal.h"
#include "arrow/flight/types.h"
#include "arrow/ipc/reader.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/string.h"
Expand Down Expand Up @@ -730,6 +735,74 @@ TEST(GrpcTransport, FlightDataDeserialize) {
#endif
}

// ----------------------------------------------------------------------
// Transport-agnostic serialization roundtrip tests

TEST(FlightSerialization, RoundtripPayloadWithBody) {
// Use RecordBatchStream to generate FlightPayloads
auto schema = arrow::schema({arrow::field("a", arrow::int32())});
auto arr = ArrayFromJSON(arrow::int32(), "[1, 2, 3]");
auto batch = RecordBatch::Make(schema, 3, {arr});
auto reader = RecordBatchReader::Make({batch}).ValueOrDie();
RecordBatchStream stream(std::move(reader));

// Get a FlightPayload from the stream
ASSERT_OK_AND_ASSIGN(auto schema_payload, stream.GetSchemaPayload());
ASSERT_OK_AND_ASSIGN(auto flight_payload, stream.Next());

// Add app_metadata to the flight payload
flight_payload.app_metadata = Buffer::FromString("test-metadata");

// Serialize FlightPayload to BufferVector
ASSERT_OK_AND_ASSIGN(auto buffers, internal::SerializePayloadToBuffers(flight_payload));
ASSERT_GT(buffers.size(), 0);

// Concatenate to a single buffer for deserialization and deserialize.
ASSERT_OK_AND_ASSIGN(auto concat, ConcatenateBuffers(buffers));
ASSERT_OK_AND_ASSIGN(auto data, internal::DeserializeFlightData(concat));

// Verify IPC metadata (data_header) is present
ASSERT_NE(data.metadata, nullptr);
ASSERT_GT(data.metadata->size(), 0);

// Verify app_metadata
ASSERT_NE(data.app_metadata, nullptr);
ASSERT_EQ(data.app_metadata->ToString(), "test-metadata");

// Verify body and message are present
ASSERT_NE(data.body, nullptr);
ASSERT_GT(data.body->size(), 0);
ASSERT_OK_AND_ASSIGN(auto message, data.OpenMessage());
ASSERT_NE(message, nullptr);
// Also verify the RecordBatch roundtrips correctly
ipc::DictionaryMemo dict_memo;
ASSERT_OK_AND_ASSIGN(auto result_batch,
ipc::ReadRecordBatch(*message, schema, &dict_memo,
ipc::IpcReadOptions::Defaults()));
ASSERT_TRUE(result_batch->Equals(*batch));
}

TEST(FlightSerialization, RoundtripMetadataOnly) {
// A metadata-only payload (no IPC body, no descriptor)
auto app_meta = Buffer::FromString("metadata-only-message");

FlightPayload payload;
payload.app_metadata = std::move(app_meta);

// Serialize
ASSERT_OK_AND_ASSIGN(auto buffers, internal::SerializePayloadToBuffers(payload));
ASSERT_OK_AND_ASSIGN(auto concat, ConcatenateBuffers(buffers));

// Deserialize
ASSERT_OK_AND_ASSIGN(auto data, internal::DeserializeFlightData(concat));

// Verify: no descriptor, no IPC metadata, just app_metadata
ASSERT_EQ(data.descriptor, nullptr);
ASSERT_EQ(data.metadata, nullptr);
ASSERT_NE(data.app_metadata, nullptr);
ASSERT_EQ(data.app_metadata->ToString(), "metadata-only-message");
}

// ----------------------------------------------------------------------
// Transport abstraction tests

Expand Down
232 changes: 232 additions & 0 deletions cpp/src/arrow/flight/serialization_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,24 @@

#include "arrow/flight/serialization_internal.h"

#include <limits>
#include <memory>
#include <string>

#include <google/protobuf/any.pb.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/wire_format_lite.h>

#include "arrow/buffer.h"
#include "arrow/flight/protocol_internal.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/message.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/logging_internal.h"

// Lambda helper & CTAD
template <class... Ts>
Expand Down Expand Up @@ -612,6 +618,232 @@ Status ToProto(const CloseSessionResult& result, pb::CloseSessionResult* pb_resu
return Status::OK();
}

namespace {
using google::protobuf::internal::WireFormatLite;
using google::protobuf::io::ArrayOutputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::CodedOutputStream;
static constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
const uint8_t kPaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0};

// Update the sizes of our Protobuf fields based on the given IPC payload.
arrow::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool has_body,
size_t* header_size, int32_t* metadata_size) {
DCHECK_LE(ipc_msg.metadata->size(), kInt32Max);
*metadata_size = static_cast<int32_t>(ipc_msg.metadata->size());

// 1 byte for metadata tag
*header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size);

// 2 bytes for body tag
if (has_body) {
// We write the body tag in the header but not the actual body data
*header_size += 2 + WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) -
ipc_msg.body_length;
}

return arrow::Status::OK();
}

bool ReadBytesZeroCopy(const std::shared_ptr<Buffer>& source_data,
CodedInputStream* input, std::shared_ptr<Buffer>* out) {
uint32_t length;
if (!input->ReadVarint32(&length)) {
return false;
}
auto buf =
SliceBuffer(source_data, input->CurrentPosition(), static_cast<int64_t>(length));
*out = buf;
return input->Skip(static_cast<int>(length));
}

} // namespace

arrow::Result<arrow::BufferVector> SerializePayloadToBuffers(const FlightPayload& msg) {
// Size of the IPC body (protobuf: data_body)
size_t body_size = 0;
// Size of the Protobuf "header" (everything except for the body)
size_t header_size = 0;
// Size of IPC header metadata (protobuf: data_header)
int32_t metadata_size = 0;

// Write the descriptor if present
int32_t descriptor_size = 0;
if (msg.descriptor != nullptr) {
DCHECK_LE(msg.descriptor->size(), kInt32Max);
descriptor_size = static_cast<int32_t>(msg.descriptor->size());
header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size);
}

// App metadata tag if appropriate
int32_t app_metadata_size = 0;
if (msg.app_metadata && msg.app_metadata->size() > 0) {
DCHECK_LE(msg.app_metadata->size(), kInt32Max);
app_metadata_size = static_cast<int32_t>(msg.app_metadata->size());
header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size);
}

const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message;
// No data in this payload (metadata-only).
bool has_ipc = ipc_msg.type != ipc::MessageType::NONE;
bool has_body = has_ipc ? ipc::Message::HasBody(ipc_msg.type) : false;

if (has_ipc) {
DCHECK(has_body || ipc_msg.body_length == 0);
ARROW_RETURN_NOT_OK(
IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size));
body_size = static_cast<size_t>(ipc_msg.body_length);
}

// TODO(wesm): messages over 2GB unlikely to be yet supported
// Validated in WritePayload since returning error here causes gRPC to fail an assertion
DCHECK_LE(body_size, kInt32Max);

// Allocate and initialize buffers
arrow::BufferVector buffers;
ARROW_ASSIGN_OR_RAISE(auto header_buf, arrow::AllocateBuffer(header_size));

// Force the header_stream to be destructed, which actually flushes
// the data into the slice.
{
ArrayOutputStream header_writer(const_cast<uint8_t*>(header_buf->mutable_data()),
static_cast<int>(header_size));
CodedOutputStream header_stream(&header_writer);

// Write descriptor
if (msg.descriptor != nullptr) {
WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
header_stream.WriteVarint32(descriptor_size);
header_stream.WriteRawMaybeAliased(msg.descriptor->data(),
static_cast<int>(msg.descriptor->size()));
}

// Write header
if (has_ipc) {
WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
header_stream.WriteVarint32(metadata_size);
header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(),
static_cast<int>(ipc_msg.metadata->size()));
}

// Write app metadata
if (app_metadata_size > 0) {
WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
header_stream.WriteVarint32(app_metadata_size);
header_stream.WriteRawMaybeAliased(msg.app_metadata->data(),
static_cast<int>(msg.app_metadata->size()));
}

if (has_body) {
// Write body tag
WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
header_stream.WriteVarint32(static_cast<uint32_t>(body_size));

// Enqueue body buffers for writing, without copying
for (const auto& buffer : ipc_msg.body_buffers) {
// Buffer may be null when the row length is zero, or when all
// entries are invalid.
if (!buffer || buffer->size() == 0) continue;
buffers.push_back(buffer);

// Write padding if not multiple of 8
const auto remainder = static_cast<int>(
bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
if (remainder) {
buffers.push_back(std::make_shared<arrow::Buffer>(kPaddingBytes, remainder));
}
}
}

DCHECK_EQ(static_cast<int>(header_size), header_stream.ByteCount());
}
// Once header is written we add it as the first buffer in the output vector.
buffers.insert(buffers.begin(), std::move(header_buf));

return buffers;
}

// Read internal::FlightData from arrow::Buffer containing FlightData
// protobuf without copying
arrow::Result<arrow::flight::internal::FlightData> DeserializeFlightData(
const std::shared_ptr<arrow::Buffer>& buffer) {
if (!buffer) {
return Status::Invalid("No payload");
}

arrow::flight::internal::FlightData out;

auto buffer_length = static_cast<int>(buffer->size());
CodedInputStream pb_stream(buffer->data(), buffer_length);

pb_stream.SetTotalBytesLimit(buffer_length);

// This is the bytes remaining when using CodedInputStream like this
while (pb_stream.BytesUntilTotalBytesLimit()) {
const uint32_t tag = pb_stream.ReadTag();
const int field_number = WireFormatLite::GetTagFieldNumber(tag);
switch (field_number) {
case pb::FlightData::kFlightDescriptorFieldNumber: {
pb::FlightDescriptor pb_descriptor;
uint32_t length;
if (!pb_stream.ReadVarint32(&length)) {
return Status::Invalid("Unable to parse length of FlightDescriptor");
}
// Can't use ParseFromCodedStream as this reads the entire
// rest of the stream into the descriptor command field.
std::string buffer;
if (!pb_stream.ReadString(&buffer, length)) {
return Status::Invalid("Unable to read FlightDescriptor from protobuf");
}
if (!pb_descriptor.ParseFromString(buffer)) {
return Status::Invalid("Unable to parse FlightDescriptor");
}
arrow::flight::FlightDescriptor descriptor;
ARROW_RETURN_NOT_OK(
arrow::flight::internal::FromProto(pb_descriptor, &descriptor));
out.descriptor = std::make_unique<arrow::flight::FlightDescriptor>(descriptor);
} break;
case pb::FlightData::kDataHeaderFieldNumber: {
if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.metadata)) {
return Status::Invalid("Unable to read FlightData metadata");
}
} break;
case pb::FlightData::kAppMetadataFieldNumber: {
if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.app_metadata)) {
return Status::Invalid("Unable to read FlightData application metadata");
}
} break;
case pb::FlightData::kDataBodyFieldNumber: {
if (!ReadBytesZeroCopy(buffer, &pb_stream, &out.body)) {
return Status::Invalid("Unable to read FlightData body");
}
} break;
default: {
// Unknown field. We should skip it for compatibility.
if (!WireFormatLite::SkipField(&pb_stream, tag)) {
return Status::Invalid("Could not skip unknown field tag in FlightData");
}
break;
}
}
}

// TODO(wesm): Where and when should we verify that the FlightData is not
// malformed?

// Set the default value for an unspecified FlightData body. The other
// fields can be null if they're unspecified.
if (out.body == nullptr) {
out.body = std::make_shared<Buffer>(nullptr, 0);
}

return out;
}

} // namespace internal
} // namespace flight
} // namespace arrow
9 changes: 9 additions & 0 deletions cpp/src/arrow/flight/serialization_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ ARROW_FLIGHT_EXPORT Status ToProto(const CloseSessionResult& result,

Status ToPayload(const FlightDescriptor& descr, std::shared_ptr<Buffer>* out);

/// \brief Serialize a FlightPayload to a vector of buffers.
ARROW_FLIGHT_EXPORT
arrow::Result<arrow::BufferVector> SerializePayloadToBuffers(const FlightPayload& msg);

/// \brief Deserialize FlightData from a contiguous buffer.
ARROW_FLIGHT_EXPORT
arrow::Result<internal::FlightData> DeserializeFlightData(
const std::shared_ptr<arrow::Buffer>& buffer);

// We want to reuse RecordBatchStreamReader's implementation while
// (1) Adapting it to the Flight message format
// (2) Allowing pure-metadata messages before data is sent
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class FlightStatusDetail;
namespace internal {

/// Internal, not user-visible type used for memory-efficient reads
struct FlightData {
struct ARROW_FLIGHT_EXPORT FlightData {
/// Used only for puts, may be null
std::unique_ptr<FlightDescriptor> descriptor;

Expand Down
Loading
Loading