diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java index 9bb49d8b6688..6ad06f9f4887 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java @@ -18,8 +18,11 @@ package org.apache.beam.sdk.io.gcp.pubsub; import com.google.protobuf.ByteString; +import com.google.protobuf.CodedOutputStream; import com.google.protobuf.InvalidProtocolBufferException; +import java.io.IOException; import java.util.Map; +import javax.annotation.Nullable; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -28,26 +31,58 @@ public final class PubsubMessages { private PubsubMessages() {} public static com.google.pubsub.v1.PubsubMessage toProto(PubsubMessage input) { - Map attributes = input.getAttributeMap(); + @Nullable Map attributes = input.getAttributeMap(); com.google.pubsub.v1.PubsubMessage.Builder message = com.google.pubsub.v1.PubsubMessage.newBuilder() .setData(ByteString.copyFrom(input.getPayload())); // TODO(https://github.com/apache/beam/issues/19787) this should not be null - if (attributes != null) { + if (attributes != null && !attributes.isEmpty()) { message.putAllAttributes(attributes); } - String messageId = input.getMessageId(); - if (messageId != null) { + @Nullable String messageId = input.getMessageId(); + if (messageId != null && !messageId.isEmpty()) { message.setMessageId(messageId); } - String orderingKey = input.getOrderingKey(); - if (orderingKey != null) { + @Nullable String orderingKey = input.getOrderingKey(); + if (orderingKey != null && !orderingKey.isEmpty()) { message.setOrderingKey(orderingKey); } return message.build(); } + // 
Optimization of toProto(input).toByteArray() + private static byte[] toSerializedPubsubMessageProto(PubsubMessage input) { + @Nullable Map attributes = input.getAttributeMap(); + @Nullable String messageId = input.getMessageId(); + @Nullable String orderingKey = input.getOrderingKey(); + if ((attributes == null || attributes.isEmpty()) + && (messageId == null || messageId.isEmpty()) + && (orderingKey == null || orderingKey.isEmpty())) { + // Optimize the case where we are just sending a payload. + byte[] payload = input.getPayload(); + if (payload == null || payload.length == 0) { + return new byte[0]; + } + int size = + CodedOutputStream.computeByteArraySize( + com.google.pubsub.v1.PubsubMessage.DATA_FIELD_NUMBER, payload); + byte[] serialized = new byte[size]; + try { + CodedOutputStream output = CodedOutputStream.newInstance(serialized); + output.writeByteArray(com.google.pubsub.v1.PubsubMessage.DATA_FIELD_NUMBER, payload); + output.checkNoSpaceLeft(); + } catch (IOException e) { + // Should not happen since we are writing to a byte array of the exact size. + throw new RuntimeException( + "Unexpected error while serializing PubsubMessage to a byte array.", e); + } + return serialized; + } + // Fallback to general case by building up a protobuf and serializing it. + return toProto(input).toByteArray(); + } + public static PubsubMessage fromProto(com.google.pubsub.v1.PubsubMessage input) { return new PubsubMessage( input.getData().toByteArray(), @@ -56,12 +91,13 @@ public static PubsubMessage fromProto(com.google.pubsub.v1.PubsubMessage input) input.getOrderingKey()); } - // Convert the PubsubMessage to a PubsubMessage proto, then return its serialized representation. + // Convert the beam PubsubMessage to a serialized com.google.pubsub.v1.PubsubMessage proto + // representation. 
public static class ParsePayloadAsPubsubMessageProto implements SerializableFunction { @Override public byte[] apply(PubsubMessage input) { - return toProto(input).toByteArray(); + return toSerializedPubsubMessageProto(input); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessagesTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessagesTest.java new file mode 100644 index 000000000000..e7c3a08a3d59 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessagesTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.pubsub; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PubsubMessages}. 
*/ +@RunWith(JUnit4.class) +public class PubsubMessagesTest { + + @Test + public void testRoundTripToProto() { + byte[] payload = "test-payload".getBytes(StandardCharsets.UTF_8); + Map attributes = ImmutableMap.of("key1", "value1", "key2", "value2"); + String messageId = "test-message-id"; + String orderingKey = "test-ordering-key"; + + PubsubMessage originalMessage = new PubsubMessage(payload, attributes, messageId, orderingKey); + PubsubMessage roundTrippedMessage = + PubsubMessages.fromProto(PubsubMessages.toProto(originalMessage)); + + assertArrayEquals(originalMessage.getPayload(), roundTrippedMessage.getPayload()); + assertEquals(originalMessage.getAttributeMap(), roundTrippedMessage.getAttributeMap()); + assertEquals(originalMessage.getMessageId(), roundTrippedMessage.getMessageId()); + assertEquals(originalMessage.getOrderingKey(), roundTrippedMessage.getOrderingKey()); + } + + @Test + public void testRoundTripToProto_emptyAttributes() { + byte[] payload = "test-payload".getBytes(StandardCharsets.UTF_8); + Map attributes = Collections.emptyMap(); + String messageId = "test-message-id"; + String orderingKey = "test-ordering-key"; + + PubsubMessage originalMessage = new PubsubMessage(payload, attributes, messageId, orderingKey); + PubsubMessage roundTrippedMessage = + PubsubMessages.fromProto(PubsubMessages.toProto(originalMessage)); + + assertArrayEquals(originalMessage.getPayload(), roundTrippedMessage.getPayload()); + assertEquals(originalMessage.getAttributeMap(), roundTrippedMessage.getAttributeMap()); + assertEquals(originalMessage.getMessageId(), roundTrippedMessage.getMessageId()); + assertEquals(originalMessage.getOrderingKey(), roundTrippedMessage.getOrderingKey()); + } + + @Test + public void testRoundTripToProto_nullAttributes() { + byte[] payload = "test-payload".getBytes(StandardCharsets.UTF_8); + String messageId = "test-message-id"; + String orderingKey = "test-ordering-key"; + + PubsubMessage originalMessage = new PubsubMessage(payload, null, 
messageId, orderingKey); + PubsubMessage roundTrippedMessage = + PubsubMessages.fromProto(PubsubMessages.toProto(originalMessage)); + + assertArrayEquals(originalMessage.getPayload(), roundTrippedMessage.getPayload()); + // PubsubMessages.fromProto returns an empty map when the proto attributes map is empty. + assertEquals(Collections.emptyMap(), roundTrippedMessage.getAttributeMap()); + assertEquals(originalMessage.getMessageId(), roundTrippedMessage.getMessageId()); + assertEquals(originalMessage.getOrderingKey(), roundTrippedMessage.getOrderingKey()); + } + + @Test + public void testRoundTripToProto_nullMessageIdAndOrderingKey() { + byte[] payload = "test-payload".getBytes(StandardCharsets.UTF_8); + Map attributes = ImmutableMap.of("key", "value"); + + PubsubMessage originalMessage = new PubsubMessage(payload, attributes, null, null); + PubsubMessage roundTrippedMessage = + PubsubMessages.fromProto(PubsubMessages.toProto(originalMessage)); + + assertArrayEquals(originalMessage.getPayload(), roundTrippedMessage.getPayload()); + assertEquals(originalMessage.getAttributeMap(), roundTrippedMessage.getAttributeMap()); + // toProto leaves null strings unset, which protobuf reads back as empty strings. 
+ assertTrue(roundTrippedMessage.getMessageId().isEmpty()); + assertTrue(roundTrippedMessage.getOrderingKey().isEmpty()); + } + + @Test + public void testRoundTripToProto_messageIdAndOrderingKey() { + byte[] payload = "test-payload".getBytes(StandardCharsets.UTF_8); + Map attributes = ImmutableMap.of("key", "value"); + + PubsubMessage originalMessage = + new PubsubMessage(payload, attributes, "messageId", "orderingKey"); + PubsubMessage roundTrippedMessage = + PubsubMessages.fromProto(PubsubMessages.toProto(originalMessage)); + + assertArrayEquals(originalMessage.getPayload(), roundTrippedMessage.getPayload()); + assertEquals(originalMessage.getAttributeMap(), roundTrippedMessage.getAttributeMap()); + assertEquals(originalMessage.getOrderingKey(), roundTrippedMessage.getOrderingKey()); + assertEquals(originalMessage.getMessageId(), roundTrippedMessage.getMessageId()); + } + + @Test + public void testParsePayloadAsPubsubMessageProto() { + byte[] payload = "test-payload".getBytes(StandardCharsets.UTF_8); + PubsubMessage originalMessage = new PubsubMessage(payload, null, null, null); + + byte[] serialized = + new PubsubMessages.ParsePayloadAsPubsubMessageProto().apply(originalMessage); + PubsubMessage roundTrippedMessage = + new PubsubMessages.ParsePubsubMessageProtoAsPayload().apply(serialized); + + assertArrayEquals(originalMessage.getPayload(), roundTrippedMessage.getPayload()); + assertEquals(Collections.emptyMap(), roundTrippedMessage.getAttributeMap()); + assertTrue( + roundTrippedMessage.getMessageId() == null || roundTrippedMessage.getMessageId().isEmpty()); + assertTrue( + roundTrippedMessage.getOrderingKey() == null + || roundTrippedMessage.getOrderingKey().isEmpty()); + } + + @Test + public void testParsePayloadAsPubsubMessageProto_emptyPayload() { + byte[] payload = new byte[0]; + PubsubMessage originalMessage = new PubsubMessage(payload, null, null, null); + + byte[] serialized = + new 
PubsubMessages.ParsePayloadAsPubsubMessageProto().apply(originalMessage); + PubsubMessage roundTrippedMessage = + new PubsubMessages.ParsePubsubMessageProtoAsPayload().apply(serialized); + + assertArrayEquals(originalMessage.getPayload(), roundTrippedMessage.getPayload()); + assertEquals(Collections.emptyMap(), roundTrippedMessage.getAttributeMap()); + assertTrue( + roundTrippedMessage.getMessageId() == null || roundTrippedMessage.getMessageId().isEmpty()); + assertTrue( + roundTrippedMessage.getOrderingKey() == null + || roundTrippedMessage.getOrderingKey().isEmpty()); + } + + @Test + public void testParsePayloadAsPubsubMessageProto_withAttributes() { + byte[] payload = "test-payload".getBytes(StandardCharsets.UTF_8); + Map attributes = ImmutableMap.of("key1", "value1", "key2", "value2"); + PubsubMessage originalMessage = new PubsubMessage(payload, attributes, null, null); + + byte[] serialized = + new PubsubMessages.ParsePayloadAsPubsubMessageProto().apply(originalMessage); + PubsubMessage roundTrippedMessage = + new PubsubMessages.ParsePubsubMessageProtoAsPayload().apply(serialized); + + assertArrayEquals(originalMessage.getPayload(), roundTrippedMessage.getPayload()); + assertEquals(originalMessage.getAttributeMap(), roundTrippedMessage.getAttributeMap()); + assertTrue( + roundTrippedMessage.getMessageId() == null || roundTrippedMessage.getMessageId().isEmpty()); + assertTrue( + roundTrippedMessage.getOrderingKey() == null + || roundTrippedMessage.getOrderingKey().isEmpty()); + } +}