From dd6eeaf62e947ef6be83da46dd392ab87557e5e9 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 19 Mar 2026 13:16:06 +0000 Subject: [PATCH 1/8] supporting various clients through types framework --- .../spark/sql/types/ops/ClientTypeOps.scala | 177 ++++++++++++++++++ .../spark/sql/types/ops/TimeTypeApiOps.scala | 48 ++++- .../apache/spark/sql/util/ArrowUtils.scala | 3 + .../catalyst/DeserializerBuildHelper.scala | 3 + .../sql/catalyst/SerializerBuildHelper.scala | 3 + .../catalyst/types/ops/CatalystTypeOps.scala | 106 +++++++++++ .../sql/catalyst/types/ops/TimeTypeOps.scala | 39 +++- .../sql/execution/arrow/ArrowWriter.scala | 7 +- .../client/arrow/ArrowDeserializer.scala | 4 + .../client/arrow/ArrowSerializer.scala | 3 + .../client/arrow/ArrowVectorReader.scala | 4 + .../client/arrow/TimeTypeConnectOps.scala | 124 ++++++++++++ .../common/DataTypeProtoConverter.scala | 5 + .../common/LiteralValueProtoConverter.scala | 9 + .../types/ops/ConnectArrowTypeOps.scala | 74 ++++++++ .../common/types/ops/ProtoTypeOps.scala | 115 ++++++++++++ .../spark/sql/execution/HiveResult.scala | 3 + .../datasources/jdbc/JdbcUtils.scala | 3 + .../sql/execution/python/EvaluatePython.scala | 4 + .../SparkExecuteStatementOperation.scala | 3 + 20 files changed, 730 insertions(+), 7 deletions(-) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala create mode 100644 sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/TimeTypeConnectOps.scala create mode 100644 sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala create mode 100644 sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala new file mode 100644 index 000000000000..ca273680d8be --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types.ops + +import org.apache.arrow.vector.types.pojo.ArrowType + +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Optional client-side type operations for the Types Framework. + * + * This trait extends TypeApiOps with operations needed by client-facing infrastructure: + * Arrow conversion (ArrowUtils), JDBC mapping (JdbcUtils), Python interop (EvaluatePython), + * Hive formatting (HiveResult), and Thrift type mapping (SparkExecuteStatementOperation). + * + * Lives in sql/api so it's visible from sql/core and sql/hive-thriftserver. + * + * USAGE - integration points use ClientTypeOps(dt) which returns Option[ClientTypeOps]: + * {{{ + * // Forward lookup (most files): + * ClientTypeOps(dt).map(_.toArrowType(timeZoneId)).getOrElse { ... 
} + * + * // Reverse lookup (ArrowUtils.fromArrowType): + * ClientTypeOps.fromArrowType(at).getOrElse { ... } + * }}} + * + * @see + * TimeTypeApiOps for a reference implementation + * @since 4.2.0 + */ +trait ClientTypeOps { self: TypeApiOps => + + // ==================== Arrow Conversion ==================== + + /** + * Converts this DataType to its Arrow representation. + * + * Used by ArrowUtils.toArrowType. + * + * @param timeZoneId + * the session timezone (needed by some temporal types) + * @return + * the corresponding ArrowType + */ + def toArrowType(timeZoneId: String): ArrowType + + // ==================== JDBC Mapping ==================== + + /** + * Returns the java.sql.Types constant for this type. + * + * Used by JdbcUtils.getCommonJDBCType for JDBC write path. + * + * @return + * java.sql.Types constant (e.g., java.sql.Types.TIME) + */ + def getJdbcType: Int + + /** + * Returns the DDL type name string for this type. + * + * Used by JdbcUtils for CREATE TABLE DDL generation. + * + * @return + * DDL type string (e.g., "TIME") + */ + def jdbcTypeName: String + + // ==================== Python Interop ==================== + + /** + * Returns true if values of this type need conversion when passed to/from Python. + * + * Used by EvaluatePython.needConversionInPython. + */ + def needConversionInPython: Boolean + + /** + * Creates a converter function for Python/Py4J interop. + * + * Used by EvaluatePython.makeFromJava. The returned function handles null-safe + * conversion of Java/Py4J values to the internal Catalyst representation. + * + * @return + * a function that converts a Java value to the internal representation + */ + def makeFromJava: Any => Any + + // ==================== Hive Formatting ==================== + + /** + * Formats an external-type value for Hive output. + * + * Used by HiveResult.toHiveString. The input is an external-type value + * (e.g., java.time.LocalTime for TimeType), NOT the internal representation. 
+ * + * @param value + * the external-type value to format + * @return + * formatted string representation + */ + def formatExternal(value: Any): String + + // ==================== Thrift Mapping ==================== + + /** + * Returns the Thrift TTypeId name for this type. + * + * Used by SparkExecuteStatementOperation.toTTypeId. Returns a String that maps + * to a TTypeId enum value (e.g., "STRING_TYPE") since TTypeId is only available + * in the hive-thriftserver module. + * + * @return + * TTypeId enum name (e.g., "STRING_TYPE") + */ + def thriftTypeName: String +} + +/** + * Factory object for ClientTypeOps lookup. + * + * Delegates to TypeApiOps and narrows via collect to find implementations that mix in + * ClientTypeOps. + */ +object ClientTypeOps { + + /** + * Returns a ClientTypeOps instance for the given DataType, if available. + * + * @param dt + * the DataType to get operations for + * @return + * Some(ClientTypeOps) if supported, None otherwise + */ + def apply(dt: DataType): Option[ClientTypeOps] = + TypeApiOps(dt).collect { case co: ClientTypeOps => co } + + /** + * Reverse lookup: converts an Arrow type to a Spark DataType, if it belongs to a + * framework-managed type. + * + * Used by ArrowUtils.fromArrowType. Returns None if the Arrow type doesn't correspond + * to any framework-managed type, or the framework is disabled. 
+ * + * @param at + * the ArrowType to convert + * @return + * Some(DataType) if recognized, None otherwise + */ + def fromArrowType(at: ArrowType): Option[DataType] = { + import org.apache.arrow.vector.types.TimeUnit + if (!SqlApiConf.get.typesFrameworkEnabled) return None + at match { + case t: ArrowType.Time + if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 8 * 8 => + Some(TimeType(TimeType.MICROS_PRECISION)) + // Add new framework types here + case _ => None + } + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala index 581ffffff2f9..bebbbbf89613 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.types.ops +import java.time.LocalTime + +import org.apache.arrow.vector.types.TimeUnit +import org.apache.arrow.vector.types.pojo.ArrowType + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder -import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, TimeFormatter} +import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, SparkDateTimeUtils, TimeFormatter} import org.apache.spark.sql.types.{DataType, TimeType} /** @@ -29,6 +34,13 @@ import org.apache.spark.sql.types.{DataType, TimeType} * - String formatting: uses FractionTimeFormatter for consistent output * - Row encoding: uses LocalTimeEncoder for java.time.LocalTime * + * Additionally, it implements ClientTypeOps for: + * - Arrow conversion (ArrowUtils) + * - JDBC mapping (JdbcUtils) + * - Python interop (EvaluatePython) + * - Hive formatting (HiveResult) + * - Thrift type mapping (SparkExecuteStatementOperation) + * * RELATIONSHIP TO TimeTypeOps: TimeTypeOps (in catalyst package) extends this class to inherit * 
client-side operations while adding server-side operations (physical type, literals, etc.). * @@ -36,7 +48,7 @@ import org.apache.spark.sql.types.{DataType, TimeType} * The TimeType with precision information * @since 4.2.0 */ -class TimeTypeApiOps(val t: TimeType) extends TypeApiOps { +class TimeTypeApiOps(val t: TimeType) extends TypeApiOps with ClientTypeOps { override def dataType: DataType = t @@ -56,4 +68,36 @@ class TimeTypeApiOps(val t: TimeType) extends TypeApiOps { // ==================== Row Encoding ==================== override def getEncoder: AgnosticEncoder[_] = LocalTimeEncoder + + // ==================== Client Type Operations (ClientTypeOps) ==================== + + override def toArrowType(timeZoneId: String): ArrowType = { + new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8) + } + + override def getJdbcType: Int = java.sql.Types.TIME + + override def jdbcTypeName: String = "TIME" + + override def needConversionInPython: Boolean = true + + override def makeFromJava: Any => Any = (obj: Any) => { + if (obj == null) { + null + } else { + obj match { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong + case _ => null + } + } + } + + override def formatExternal(value: Any): String = { + val nanos = SparkDateTimeUtils.localTimeToNanos(value.asInstanceOf[LocalTime]) + timeFormatter.format(nanos) + } + + override def thriftTypeName: String = "STRING_TYPE" } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 92b52d4ae634..8e1f1347a7b0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -29,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.SparkException import org.apache.spark.sql.errors.ExecutionErrors import 
org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.util.ArrayImplicits._ private[sql] object ArrowUtils { @@ -58,6 +59,7 @@ private[sql] object ArrowUtils { case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) case TimestampNTZType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case dt if ClientTypeOps(dt).isDefined => ClientTypeOps(dt).get.toArrowType(timeZoneId) case _: TimeType => new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8) case NullType => ArrowType.Null.INSTANCE case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) @@ -89,6 +91,7 @@ private[sql] object ArrowUtils { if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType + case at if ClientTypeOps.fromArrowType(at).isDefined => ClientTypeOps.fromArrowType(at).get case t: ArrowType.Time if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 8 * 8 => TimeType(TimeType.MICROS_PRECISION) case ArrowType.Null.INSTANCE => NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 080794643fa0..62db0112148a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder} import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, 
InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption} +import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils, STUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -345,6 +346,8 @@ object DeserializerBuildHelper { createDeserializerForInstant(path) case LocalDateTimeEncoder => createDeserializerForLocalDateTime(path) + case enc if CatalystTypeOps(enc.dataType).isDefined => + CatalystTypeOps(enc.dataType).get.createDeserializer(path) case LocalTimeEncoder if !SQLConf.get.isTimeTypeEnabled => throw org.apache.spark.sql.errors.QueryCompilationErrors.unsupportedTimeTypeError() case LocalTimeEncoder => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index b8b2406a5813..5539a0bb5719 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor} import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData} import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils, STUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -367,6 +368,8 @@ object 
SerializerBuildHelper { case TimestampEncoder(false) => createSerializerForSqlTimestamp(input) case InstantEncoder(false) => createSerializerForJavaInstant(input) case LocalDateTimeEncoder => createSerializerForLocalDateTime(input) + case enc if CatalystTypeOps(enc.dataType).isDefined => + CatalystTypeOps(enc.dataType).get.createSerializer(input) case LocalTimeEncoder if !SQLConf.get.isTimeTypeEnabled => throw org.apache.spark.sql.errors.QueryCompilationErrors.unsupportedTimeTypeError() case LocalTimeEncoder => createSerializerForLocalTime(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala new file mode 100644 index 000000000000..651a778d7676 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.types.ops + +import org.apache.arrow.vector.ValueVector + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.arrow.ArrowFieldWriter +import org.apache.spark.sql.types.DataType + +/** + * Optional catalyst-layer type operations for the Types Framework. + * + * This trait extends TypeOps with operations needed by catalyst-level client infrastructure: + * serializer/deserializer expression building (SerializerBuildHelper, DeserializerBuildHelper) + * and Arrow field writer creation (ArrowWriter). + * + * USAGE - integration points use CatalystTypeOps(dt) which returns Option[CatalystTypeOps]: + * {{{ + * // DataType-keyed (ArrowWriter): + * CatalystTypeOps(dt).map(_.createArrowFieldWriter(vector)).getOrElse { ... } + * + * // Encoder-keyed (SerializerBuildHelper): use enc.dataType to get the DataType + * CatalystTypeOps(enc.dataType).map(_.createSerializer(input)).getOrElse { ... } + * }}} + * + * @see + * TimeTypeOps for a reference implementation + * @since 4.2.0 + */ +trait CatalystTypeOps { self: TypeOps => + + /** + * Creates a serializer expression that converts an external value to its internal + * Catalyst representation. + * + * Used by SerializerBuildHelper for Dataset[T] serialization. + * + * @param input + * the input expression representing the external value + * @return + * an Expression that performs the conversion + */ + def createSerializer(input: Expression): Expression + + /** + * Creates a deserializer expression that converts an internal Catalyst value + * to its external representation. + * + * Used by DeserializerBuildHelper for Dataset[T] deserialization. + * + * @param path + * the expression representing the internal value + * @return + * an Expression that performs the conversion + */ + def createDeserializer(path: Expression): Expression + + /** + * Creates an ArrowFieldWriter for writing values of this type to an Arrow vector. 
+ * + * Used by ArrowWriter for Arrow-based data exchange. + * + * @param vector + * the Arrow ValueVector to write to (must be the correct vector type for this data type) + * @return + * an ArrowFieldWriter configured for this type + */ + def createArrowFieldWriter(vector: ValueVector): ArrowFieldWriter +} + +/** + * Factory object for CatalystTypeOps lookup. + * + * Delegates to TypeOps and narrows via collect to find implementations that mix in + * CatalystTypeOps. Returns None if the type is not supported, the framework is disabled, + * or the type's TypeOps does not implement CatalystTypeOps. + */ +object CatalystTypeOps { + + /** + * Returns a CatalystTypeOps instance for the given DataType, if available. + * + * @param dt + * the DataType to get operations for + * @return + * Some(CatalystTypeOps) if supported, None otherwise + */ + def apply(dt: DataType): Option[CatalystTypeOps] = + TypeOps(dt).collect { case co: CatalystTypeOps => co } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala index 74198c956edc..5464ae5b1d2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.catalyst.types.ops import java.time.LocalTime +import org.apache.arrow.vector.{TimeNanoVector, ValueVector} + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, MutableLong, MutableValue} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, MutableLong, MutableValue} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import 
org.apache.spark.sql.types.TimeType +import org.apache.spark.sql.execution.arrow.{ArrowFieldWriter, TimeWriter} +import org.apache.spark.sql.types.{ObjectType, TimeType} import org.apache.spark.sql.types.ops.TimeTypeApiOps /** @@ -38,6 +42,10 @@ import org.apache.spark.sql.types.ops.TimeTypeApiOps * - String formatting (FractionTimeFormatter) * - Row encoding (LocalTimeEncoder) * + * Additionally, it implements CatalystTypeOps for: + * - Serializer/deserializer expression building (SerializerBuildHelper, DeserializerBuildHelper) + * - Arrow field writer creation (ArrowWriter) + * * INTERNAL REPRESENTATION: * - Values stored as Long nanoseconds since midnight * - Range: 0 to 86,399,999,999,999 @@ -48,7 +56,8 @@ import org.apache.spark.sql.types.ops.TimeTypeApiOps * The TimeType with precision information * @since 4.2.0 */ -case class TimeTypeOps(override val t: TimeType) extends TimeTypeApiOps(t) with TypeOps { +case class TimeTypeOps(override val t: TimeType) + extends TimeTypeApiOps(t) with TypeOps with CatalystTypeOps { // ==================== Physical Type Representation ==================== @@ -81,4 +90,28 @@ case class TimeTypeOps(override val t: TimeType) extends TimeTypeApiOps(t) with override def toScalaImpl(row: InternalRow, column: Int): Any = { DateTimeUtils.nanosToLocalTime(row.getLong(column)) } + + // ==================== Catalyst Type Operations (CatalystTypeOps) ==================== + + override def createSerializer(input: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + t, + "localTimeToNanos", + input :: Nil, + returnNullable = false) + } + + override def createDeserializer(path: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.time.LocalTime]), + "nanosToLocalTime", + path :: Nil, + returnNullable = false) + } + + override def createArrowFieldWriter(vector: ValueVector): ArrowFieldWriter = { + new TimeWriter(vector.asInstanceOf[TimeNanoVector]) + } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index b5269da035f3..e69400e93b65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps import org.apache.spark.sql.catalyst.util.STUtils import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ @@ -69,6 +70,8 @@ object ArrowWriter { case (DateType, vector: DateDayVector) => new DateWriter(vector) case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector) + case (dt, vector) if CatalystTypeOps(dt).isDefined => + CatalystTypeOps(dt).get.createArrowFieldWriter(vector) case (_: TimeType, vector: TimeNanoVector) => new TimeWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) @@ -146,7 +149,7 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { } } -private[arrow] abstract class ArrowFieldWriter { +private[sql] abstract class ArrowFieldWriter { def valueVector: ValueVector @@ -371,7 +374,7 @@ private[arrow] class TimestampNTZWriter( } } -private[arrow] class TimeWriter( +private[sql] class TimeWriter( val valueVector: TimeNanoVector) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 82029025a7f0..a8da26c49bd4 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.connect.common.types.ops.ConnectArrowTypeOps import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors} import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.util.{CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator} @@ -200,6 +201,9 @@ object ArrowDeserializers { new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) { override def value(i: Int): LocalDateTime = reader.getLocalDateTime(i) } + case (enc, v: FieldVector) if ConnectArrowTypeOps(enc).isDefined => + ConnectArrowTypeOps(enc).get + .createArrowDeserializer(enc, v, timeZoneId).asInstanceOf[Deserializer[Any]] case (LocalTimeEncoder, v: FieldVector) => new LeafFieldDeserializer[LocalTime](encoder, v, timeZoneId) { override def value(i: Int): LocalTime = reader.getLocalTime(i) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index d547c81afe5a..d0fca318a72d 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams import 
org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} +import org.apache.spark.sql.connect.common.types.ops.ConnectArrowTypeOps import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator} @@ -391,6 +392,8 @@ object ArrowSerializer { override def set(index: Int, value: LocalDateTime): Unit = vector.setSafe(index, SparkDateTimeUtils.localDateTimeToMicros(value)) } + case (enc, v) if ConnectArrowTypeOps(enc).isDefined => + ConnectArrowTypeOps(enc).get.createArrowSerializer(v).asInstanceOf[Serializer] case (LocalTimeEncoder, v: TimeNanoVector) => new FieldSerializer[LocalTime, TimeNanoVector](v) { override def set(index: Int, value: LocalTime): Unit = diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala index ea57e0e1c77f..46f5a939ff7e 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, Ti import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ +import org.apache.spark.sql.connect.common.types.ops.ConnectArrowTypeOps import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, UpCastRule, YearMonthIntervalType} import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.SparkStringUtils @@ -92,6 +93,9 @@ object 
ArrowVectorReader { case v: DateDayVector => new DateDayVectorReader(v, timeZoneId) case v: TimeStampMicroTZVector => new TimeStampMicroTZVectorReader(v) case v: TimeStampMicroVector => new TimeStampMicroVectorReader(v, timeZoneId) + case v if ConnectArrowTypeOps(targetDataType).isDefined => + ConnectArrowTypeOps(targetDataType).get + .createArrowVectorReader(v).asInstanceOf[ArrowVectorReader] case v: TimeNanoVector => new TimeVectorReader(v) case _: NullVector => NullVectorReader case _ => throw new RuntimeException("Unsupported Vector Type: " + vector.getClass) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/TimeTypeConnectOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/TimeTypeConnectOps.scala new file mode 100644 index 000000000000..2ccb37bb2b50 --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/TimeTypeConnectOps.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.client.arrow + +import java.time.LocalTime + +import org.apache.arrow.vector.{FieldVector, TimeNanoVector} + +import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder +import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils +import org.apache.spark.sql.connect.common.types.ops.{ConnectArrowTypeOps, ProtoTypeOps} +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Combined Connect operations for TimeType. + * + * Implements both ProtoTypeOps (proto DataType/Literal conversions) and ConnectArrowTypeOps + * (Arrow serialization/deserialization) in a single class. + * + * Lives in the arrow package to access arrow-private types (ArrowVectorReader, TimeVectorReader, + * ArrowSerializer.Serializer, ArrowDeserializers.LeafFieldDeserializer). + * + * @param t + * The TimeType with precision information + * @since 4.2.0 + */ +private[connect] class TimeTypeConnectOps(val t: TimeType) + extends ProtoTypeOps with ConnectArrowTypeOps { + + override def dataType: DataType = t + + override def encoder: AgnosticEncoder[_] = LocalTimeEncoder + + // ==================== ProtoTypeOps ==================== + + override def toConnectProtoType: proto.DataType = { + proto.DataType + .newBuilder() + .setTime(proto.DataType.Time.newBuilder().setPrecision(t.precision).build()) + .build() + } + + override def toLiteralProto( + value: Any, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder = { + val v = value.asInstanceOf[LocalTime] + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(TimeType.DEFAULT_PRECISION)) + } + + override def toLiteralProtoWithType( + value: Any, + dt: DataType, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder = { + val v = value.asInstanceOf[LocalTime] + val timeType 
= dt.asInstanceOf[TimeType] + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(timeType.precision)) + } + + override def getScalaConverter: proto.Expression.Literal => Any = { + v => SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) + } + + override def buildProtoDataType( + literal: proto.Expression.Literal, + builder: proto.DataType.Builder): Unit = { + val timeBuilder = proto.DataType.Time.newBuilder() + if (literal.getTime.hasPrecision) { + timeBuilder.setPrecision(literal.getTime.getPrecision) + } + builder.setTime(timeBuilder.build()) + } + + // ==================== ConnectArrowTypeOps ==================== + + override def createArrowSerializer(vector: Any): Any = { + val v = vector.asInstanceOf[TimeNanoVector] + new ArrowSerializer.Serializer { + override def write(index: Int, value: Any): Unit = { + if (value != null) { + v.setSafe(index, SparkDateTimeUtils.localTimeToNanos(value.asInstanceOf[LocalTime])) + } else { + v.setNull(index) + } + } + } + } + + override def createArrowDeserializer( + enc: AgnosticEncoder[_], + vector: Any, + timeZoneId: String): Any = { + val v = vector.asInstanceOf[FieldVector] + new ArrowDeserializers.LeafFieldDeserializer[LocalTime](enc, v, timeZoneId) { + override def value(i: Int): LocalTime = reader.getLocalTime(i) + } + } + + override def createArrowVectorReader(vector: Any): Any = { + new TimeVectorReader(vector.asInstanceOf[TimeNanoVector]) + } +} diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index ceccf780f586..a8b6cf721dd1 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._ 
import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.connect.common.types.ops.ProtoTypeOps import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkClassUtils @@ -53,6 +54,8 @@ object DataTypeProtoConverter { case proto.DataType.KindCase.DATE => DateType case proto.DataType.KindCase.TIMESTAMP => TimestampType case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType + case kindCase if ProtoTypeOps.toCatalystType(t).isDefined => + ProtoTypeOps.toCatalystType(t).get case proto.DataType.KindCase.TIME => if (t.getTime.hasPrecision) { TimeType(t.getTime.getPrecision) @@ -232,6 +235,8 @@ object DataTypeProtoConverter { case TimestampNTZType => ProtoDataTypes.TimestampNTZType + case dt if ProtoTypeOps(dt).isDefined => + ProtoTypeOps(dt).get.toConnectProtoType case TimeType(precision) => proto.DataType .newBuilder() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 026f5441c6ca..19c54e610738 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} import org.apache.spark.sql.connect.common.DataTypeProtoConverter._ +import org.apache.spark.sql.connect.common.types.ops.ProtoTypeOps import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -109,6 +110,8 @@ object LiteralValueProtoConverter { case v: Date => 
builder.setDate(SparkDateTimeUtils.fromJavaDate(v)) case v: Duration => builder.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)) case v: Period => builder.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)) + case v: LocalTime if ProtoTypeOps(TimeType()).isDefined => + ProtoTypeOps(TimeType()).get.toLiteralProto(v, builder) case v: LocalTime => builder.setTime( builder.getTimeBuilder @@ -216,6 +219,8 @@ object LiteralValueProtoConverter { builder.setMap(mapBuilder(v, keyType, valueType)) case (v, structType: StructType) => builder.setStruct(structBuilder(v, structType)) + case (v, dt) if ProtoTypeOps(dt).isDefined => + ProtoTypeOps(dt).get.toLiteralProtoWithType(v, dt, builder) case (v: LocalTime, timeType: TimeType) => builder.setTime( builder.getTimeBuilder @@ -410,6 +415,8 @@ object LiteralValueProtoConverter { v => SparkIntervalUtils.microsToDuration(v.getDayTimeInterval) case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => v => SparkIntervalUtils.monthsToPeriod(v.getYearMonthInterval) + case kindCase if ProtoTypeOps.getScalaConverterForKind(kindCase).isDefined => + ProtoTypeOps.getScalaConverterForKind(kindCase).get case proto.DataType.KindCase.TIME => v => SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue) @@ -548,6 +555,8 @@ object LiteralValueProtoConverter { builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) + case litCase if ProtoTypeOps.buildProtoDataTypeForLiteral(literal, builder) => + // Framework handled case proto.Expression.Literal.LiteralTypeCase.TIME => val timeBuilder = proto.DataType.Time.newBuilder() if (literal.getTime.hasPrecision) { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala new file mode 100644 index 000000000000..1c1aff44251e --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.common.types.ops + +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder +import org.apache.spark.sql.connect.client.arrow.TimeTypeConnectOps +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Optional type operations for Spark Connect Arrow serialization/deserialization. + * + * Handles Arrow-based data exchange on the Connect client side for framework-managed types. + * Used by ArrowSerializer, ArrowDeserializer, and ArrowVectorReader. + * + * NOTE: No feature flag check -the Connect client must handle whatever types the server sends. + * The feature flag controls server-side engine behavior; the client always needs to handle + * types that exist in the encoder. 
+ * + * Methods return Any to avoid referencing arrow-private types from the types.ops package. + * Call sites in the arrow package cast to the expected types. + * + * @since 4.2.0 + */ +trait ConnectArrowTypeOps extends Serializable { + + def encoder: AgnosticEncoder[_] + + /** Creates an Arrow serializer for writing values to a vector. Returns a Serializer. */ + def createArrowSerializer(vector: Any): Any + + /** Creates an Arrow deserializer for reading values from a vector. Returns a Deserializer. */ + def createArrowDeserializer(enc: AgnosticEncoder[_], vector: Any, timeZoneId: String): Any + + /** Creates an ArrowVectorReader for this type's vector. Returns an ArrowVectorReader. */ + def createArrowVectorReader(vector: Any): Any +} + +/** + * Factory object for ConnectArrowTypeOps lookup. + * + * No feature flag check -the Connect client always handles registered types. + */ +object ConnectArrowTypeOps { + + /** Encoder-keyed dispatch (for ArrowSerializer, ArrowDeserializer). */ + def apply(enc: AgnosticEncoder[_]): Option[ConnectArrowTypeOps] = + enc match { + case LocalTimeEncoder => Some(new TimeTypeConnectOps(TimeType())) + case _ => None + } + + /** DataType-keyed dispatch (for ArrowVectorReader which doesn't have an encoder). */ + def apply(dt: DataType): Option[ConnectArrowTypeOps] = + dt match { + case tt: TimeType => Some(new TimeTypeConnectOps(tt)) + case _ => None + } +} diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala new file mode 100644 index 000000000000..47bfed7da182 --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.common.types.ops + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.client.arrow.TimeTypeConnectOps +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Optional type operations for Spark Connect protobuf conversions. + * + * Handles bidirectional DataType <-> proto and Literal <-> proto conversions + * for framework-managed types in DataTypeProtoConverter and LiteralValueProtoConverter. + * + * @since 4.2.0 + */ +trait ProtoTypeOps extends Serializable { + + def dataType: DataType + + /** Converts this DataType to its Connect proto representation. */ + def toConnectProtoType: proto.DataType + + /** Converts a value to a proto literal builder (generic, no DataType context). */ + def toLiteralProto( + value: Any, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder + + /** Converts a value to a proto literal builder (with DataType context). */ + def toLiteralProtoWithType( + value: Any, + dt: DataType, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder + + /** Returns a converter from proto literal to Scala value. 
*/ + def getScalaConverter: proto.Expression.Literal => Any + + /** Builds a proto DataType from a proto literal (for type inference). */ + def buildProtoDataType( + literal: proto.Expression.Literal, + builder: proto.DataType.Builder): Unit +} + +/** + * Factory object for ProtoTypeOps lookup. + */ +object ProtoTypeOps { + + def apply(dt: DataType): Option[ProtoTypeOps] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + dt match { + case tt: TimeType => Some(new TimeTypeConnectOps(tt)) + case _ => None + } + } + + /** + * Reverse lookup: converts a proto DataType to a Spark DataType, if it belongs + * to a framework-managed type. + */ + def toCatalystType(t: proto.DataType): Option[DataType] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + t.getKindCase match { + case proto.DataType.KindCase.TIME => + val time = t.getTime + if (time.hasPrecision) Some(TimeType(time.getPrecision)) + else Some(TimeType()) + case _ => None + } + } + + /** + * Reverse lookup: returns a Scala converter for a proto literal KindCase. + */ + def getScalaConverterForKind( + kindCase: proto.DataType.KindCase): Option[proto.Expression.Literal => Any] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + kindCase match { + case proto.DataType.KindCase.TIME => + Some(new TimeTypeConnectOps(TimeType()).getScalaConverter) + case _ => None + } + } + + /** + * Reverse lookup: builds a proto DataType from a proto literal's type case. 
+ */ + def buildProtoDataTypeForLiteral( + literal: proto.Expression.Literal, + builder: proto.DataType.Builder): Boolean = { + if (!SqlApiConf.get.typesFrameworkEnabled) return false + literal.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.TIME => + new TimeTypeConnectOps(TimeType()).buildProtoDataType(literal, builder) + true + case _ => false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 8a666bbb9dad..97c61820f2e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTab import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.ArrayImplicits._ @@ -115,6 +116,8 @@ object HiveResult extends SQLConfHelper { case (b, BooleanType) => b.toString case (d: Date, DateType) => formatters.date.format(d) case (ld: LocalDate, DateType) => formatters.date.format(ld) + case (value, dt) if ClientTypeOps(dt).isDefined => + ClientTypeOps(dt).get.formatExternal(value) case (lt: LocalTime, _: TimeType) => formatters.time.format(lt) case (t: Timestamp, TimestampType) => formatters.timestamp.format(t) case (i: Instant, TimestampType) => formatters.timestamp.format(i) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 28882a45b7f8..dccdd521930b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect} import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -166,6 +167,8 @@ object JdbcUtils extends Logging with SQLConfHelper { case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) case t: DecimalType => Option( JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) + case dt if ClientTypeOps(dt).isDefined => + Option(JdbcType(ClientTypeOps(dt).get.jdbcTypeName, ClientTypeOps(dt).get.getJdbcType)) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 33622ca7349a..c6837ce2db27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData, STUtils} import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String, VariantVal} object EvaluatePython { @@ -42,6 +43,7 @@ object EvaluatePython { private[python] class BytesWrapper(val data: Array[Byte]) def needConversionInPython(dt: DataType): Boolean = dt match { + case dt if 
ClientTypeOps(dt).exists(_.needConversionInPython) => true case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType | _: TimeType | _: GeometryType | _: GeographyType => true case _: StructType => true @@ -162,6 +164,8 @@ object EvaluatePython { case c: Int => c } + case dt if ClientTypeOps(dt).isDefined => ClientTypeOps(dt).get.makeFromJava + case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => (obj: Any) => nullSafeConvert(obj) { case c: Long => c diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index cd2bdefcc306..13957351aa7a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution} import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.util.{Utils => SparkUtils} private[hive] class SparkExecuteStatementOperation( @@ -342,6 +343,8 @@ object SparkExecuteStatementOperation { case _: StringType => TTypeId.STRING_TYPE case _: DecimalType => TTypeId.DECIMAL_TYPE case DateType => TTypeId.DATE_TYPE + case dt if ClientTypeOps(dt).isDefined => + TTypeId.valueOf(ClientTypeOps(dt).get.thriftTypeName) case _: TimeType => TTypeId.STRING_TYPE // TODO: Shall use TIMESTAMPLOCALTZ_TYPE, keep AS-IS now for // unnecessary behavior change From 6a1776932cf1d95b7acad9ce9758d044af4f633d Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Fri, 20 Mar 2026 14:56:55 +0000 Subject: 
[PATCH 2/8] minor refactors + scalafmt --- .../spark/sql/types/ops/ClientTypeOps.scala | 46 +- .../spark/sql/types/ops/TimeTypeApiOps.scala | 16 +- .../apache/spark/sql/util/ArrowUtils.scala | 66 +-- .../catalyst/DeserializerBuildHelper.scala | 10 +- .../sql/catalyst/SerializerBuildHelper.scala | 9 +- .../catalyst/types/ops/CatalystTypeOps.scala | 2 + .../sql/execution/arrow/ArrowWriter.scala | 116 ++-- .../client/arrow/ArrowDeserializer.scala | 3 +- .../client/arrow/ArrowVectorReader.scala | 3 +- .../{ => types/ops}/TimeTypeConnectOps.scala | 10 +- .../common/DataTypeProtoConverter.scala | 507 +++++++++--------- .../common/LiteralValueProtoConverter.scala | 299 ++++++----- .../types/ops/ConnectArrowTypeOps.scala | 12 +- .../common/types/ops/ProtoTypeOps.scala | 12 +- .../spark/sql/execution/HiveResult.scala | 13 +- .../datasources/jdbc/JdbcUtils.scala | 50 +- .../sql/execution/python/EvaluatePython.scala | 38 +- .../SparkExecuteStatementOperation.scala | 8 +- 18 files changed, 635 insertions(+), 585 deletions(-) rename sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/{ => types/ops}/TimeTypeConnectOps.scala (94%) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala index ca273680d8be..5efe7fd33bdf 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.types.{DataType, TimeType} /** * Optional client-side type operations for the Types Framework. * - * This trait extends TypeApiOps with operations needed by client-facing infrastructure: - * Arrow conversion (ArrowUtils), JDBC mapping (JdbcUtils), Python interop (EvaluatePython), - * Hive formatting (HiveResult), and Thrift type mapping (SparkExecuteStatementOperation). 
+ * This trait extends TypeApiOps with operations needed by client-facing infrastructure: Arrow + * conversion (ArrowUtils), JDBC mapping (JdbcUtils), Python interop (EvaluatePython), Hive + * formatting (HiveResult), and Thrift type mapping (SparkExecuteStatementOperation). * * Lives in sql/api so it's visible from sql/core and sql/hive-thriftserver. * @@ -46,6 +46,20 @@ import org.apache.spark.sql.types.{DataType, TimeType} */ trait ClientTypeOps { self: TypeApiOps => + // ==================== Utilities ==================== + + /** + * Null-safe conversion helper. Returns null for null input, applies the partial function for + * non-null input, and returns null for unmatched values. + */ + protected def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, (_: Any) => null) + } + } + // ==================== Arrow Conversion ==================== /** @@ -94,8 +108,8 @@ trait ClientTypeOps { self: TypeApiOps => /** * Creates a converter function for Python/Py4J interop. * - * Used by EvaluatePython.makeFromJava. The returned function handles null-safe - * conversion of Java/Py4J values to the internal Catalyst representation. + * Used by EvaluatePython.makeFromJava. The returned function handles null-safe conversion of + * Java/Py4J values to the internal Catalyst representation. * * @return * a function that converts a Java value to the internal representation @@ -107,8 +121,8 @@ trait ClientTypeOps { self: TypeApiOps => /** * Formats an external-type value for Hive output. * - * Used by HiveResult.toHiveString. The input is an external-type value - * (e.g., java.time.LocalTime for TimeType), NOT the internal representation. + * Used by HiveResult.toHiveString. The input is an external-type value (e.g., + * java.time.LocalTime for TimeType), NOT the internal representation. 
* * @param value * the external-type value to format @@ -122,9 +136,9 @@ trait ClientTypeOps { self: TypeApiOps => /** * Returns the Thrift TTypeId name for this type. * - * Used by SparkExecuteStatementOperation.toTTypeId. Returns a String that maps - * to a TTypeId enum value (e.g., "STRING_TYPE") since TTypeId is only available - * in the hive-thriftserver module. + * Used by SparkExecuteStatementOperation.toTTypeId. Returns a String that maps to a TTypeId + * enum value (e.g., "STRING_TYPE") since TTypeId is only available in the hive-thriftserver + * module. * * @return * TTypeId enum name (e.g., "STRING_TYPE") @@ -148,6 +162,9 @@ object ClientTypeOps { * @return * Some(ClientTypeOps) if supported, None otherwise */ + // Delegates to TypeApiOps and narrows: a type must implement TypeApiOps AND mix in + // ClientTypeOps to be found here. No separate registration needed — the collect + // filter handles incremental trait adoption automatically. def apply(dt: DataType): Option[ClientTypeOps] = TypeApiOps(dt).collect { case co: ClientTypeOps => co } @@ -155,8 +172,8 @@ object ClientTypeOps { * Reverse lookup: converts an Arrow type to a Spark DataType, if it belongs to a * framework-managed type. * - * Used by ArrowUtils.fromArrowType. Returns None if the Arrow type doesn't correspond - * to any framework-managed type, or the framework is disabled. + * Used by ArrowUtils.fromArrowType. Returns None if the Arrow type doesn't correspond to any + * framework-managed type, or the framework is disabled. 
* * @param at * the ArrowType to convert @@ -167,9 +184,8 @@ object ClientTypeOps { import org.apache.arrow.vector.types.TimeUnit if (!SqlApiConf.get.typesFrameworkEnabled) return None at match { - case t: ArrowType.Time - if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 8 * 8 => - Some(TimeType(TimeType.MICROS_PRECISION)) + case t: ArrowType.Time if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 8 * 8 => + Some(TimeType(TimeType.MICROS_PRECISION)) // Add new framework types here case _ => None } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala index bebbbbf89613..1ddceabd61b2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala @@ -81,18 +81,12 @@ class TimeTypeApiOps(val t: TimeType) extends TypeApiOps with ClientTypeOps { override def needConversionInPython: Boolean = true - override def makeFromJava: Any => Any = (obj: Any) => { - if (obj == null) { - null - } else { - obj match { - case c: Long => c - // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs - case c: Int => c.toLong - case _ => null - } + override def makeFromJava: Any => Any = (obj: Any) => + nullSafeConvert(obj) { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong } - } override def formatExternal(value: Any): String = { val nanos = SparkDateTimeUtils.localTimeToNanos(value.asInstanceOf[LocalTime]) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 8e1f1347a7b0..0d8a78ac6b30 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -40,36 +40,43 @@ private[sql] object 
ArrowUtils { /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = - dt match { - case BooleanType => ArrowType.Bool.INSTANCE - case ByteType => new ArrowType.Int(8, true) - case ShortType => new ArrowType.Int(8 * 2, true) - case IntegerType => new ArrowType.Int(8 * 4, true) - case LongType => new ArrowType.Int(8 * 8, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE - case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE - case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE - case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE - case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale, 8 * 16) - case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType if timeZoneId == null => - throw SparkException.internalError("Missing timezoneId where it is mandatory.") - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) - case TimestampNTZType => - new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) - case dt if ClientTypeOps(dt).isDefined => ClientTypeOps(dt).get.toArrowType(timeZoneId) - case _: TimeType => new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8) - case NullType => ArrowType.Null.INSTANCE - case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) - case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) - case CalendarIntervalType => new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO) - case _ => - throw ExecutionErrors.unsupportedDataTypeError(dt) - } + ClientTypeOps(dt).map(_.toArrowType(timeZoneId)) + .getOrElse(toArrowTypeDefault(dt, timeZoneId, largeVarTypes)) + + private def toArrowTypeDefault( + dt: 
DataType, timeZoneId: String, largeVarTypes: Boolean): ArrowType = dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE + case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE + case DecimalType.Fixed(precision, scale) => + new ArrowType.Decimal(precision, scale, 8 * 16) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType if timeZoneId == null => + throw SparkException.internalError("Missing timezoneId where it is mandatory.") + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + case TimestampNTZType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case _: TimeType => new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8) + case NullType => ArrowType.Null.INSTANCE + case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) + case CalendarIntervalType => new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO) + case _ => + throw ExecutionErrors.unsupportedDataTypeError(dt) + } + + def fromArrowType(dt: ArrowType): DataType = + ClientTypeOps.fromArrowType(dt).getOrElse(fromArrowTypeDefault(dt)) - def fromArrowType(dt: ArrowType): DataType = dt match { + private def fromArrowTypeDefault(dt: ArrowType): DataType = dt match { case ArrowType.Bool.INSTANCE => BooleanType case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType case int: 
ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType @@ -91,7 +98,6 @@ private[sql] object ArrowUtils { if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType - case at if ClientTypeOps.fromArrowType(at).isDefined => ClientTypeOps.fromArrowType(at).get case t: ArrowType.Time if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 8 * 8 => TimeType(TimeType.MICROS_PRECISION) case ArrowType.Null.INSTANCE => NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 62db0112148a..875898952691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -290,6 +290,14 @@ object DeserializerBuildHelper { * @param isTopLevel true if we are creating a deserializer for the top level value. 
*/ private def createDeserializer( + enc: AgnosticEncoder[_], + path: Expression, + walkedTypePath: WalkedTypePath, + isTopLevel: Boolean = false): Expression = + CatalystTypeOps(enc.dataType).map(_.createDeserializer(path)) + .getOrElse(createDeserializerDefault(enc, path, walkedTypePath, isTopLevel)) + + private def createDeserializerDefault( enc: AgnosticEncoder[_], path: Expression, walkedTypePath: WalkedTypePath, @@ -346,8 +354,6 @@ object DeserializerBuildHelper { createDeserializerForInstant(path) case LocalDateTimeEncoder => createDeserializerForLocalDateTime(path) - case enc if CatalystTypeOps(enc.dataType).isDefined => - CatalystTypeOps(enc.dataType).get.createDeserializer(path) case LocalTimeEncoder if !SQLConf.get.isTimeTypeEnabled => throw org.apache.spark.sql.errors.QueryCompilationErrors.unsupportedTimeTypeError() case LocalTimeEncoder => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 5539a0bb5719..d1e263d98276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -333,7 +333,12 @@ object SerializerBuildHelper { * representation. The mapping between the external and internal representations is described * by encoder `enc`. 
*/ - private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match { + private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = + CatalystTypeOps(enc.dataType).map(_.createSerializer(input)) + .getOrElse(createSerializerDefault(enc, input)) + + private def createSerializerDefault( + enc: AgnosticEncoder[_], input: Expression): Expression = enc match { case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input) case _ if isNativeEncoder(enc) => input case BoxedBooleanEncoder => createSerializerForBoolean(input) @@ -368,8 +373,6 @@ object SerializerBuildHelper { case TimestampEncoder(false) => createSerializerForSqlTimestamp(input) case InstantEncoder(false) => createSerializerForJavaInstant(input) case LocalDateTimeEncoder => createSerializerForLocalDateTime(input) - case enc if CatalystTypeOps(enc.dataType).isDefined => - CatalystTypeOps(enc.dataType).get.createSerializer(input) case LocalTimeEncoder if !SQLConf.get.isTimeTypeEnabled => throw org.apache.spark.sql.errors.QueryCompilationErrors.unsupportedTimeTypeError() case LocalTimeEncoder => createSerializerForLocalTime(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala index 651a778d7676..e366436fea39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala @@ -101,6 +101,8 @@ object CatalystTypeOps { * @return * Some(CatalystTypeOps) if supported, None otherwise */ + // Delegates to TypeOps and narrows: a type must implement TypeOps AND mix in + // CatalystTypeOps to be found here. 
def apply(dt: DataType): Option[CatalystTypeOps] = TypeOps(dt).collect { case co: CatalystTypeOps => co } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index e69400e93b65..84ac1d311658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -53,62 +53,66 @@ object ArrowWriter { private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() - (ArrowUtils.fromArrowField(field), vector) match { - case (BooleanType, vector: BitVector) => new BooleanWriter(vector) - case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) - case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) - case (IntegerType, vector: IntVector) => new IntegerWriter(vector) - case (LongType, vector: BigIntVector) => new LongWriter(vector) - case (FloatType, vector: Float4Vector) => new FloatWriter(vector) - case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) - case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => - new DecimalWriter(vector, precision, scale) - case (StringType, vector: VarCharVector) => new StringWriter(vector) - case (StringType, vector: LargeVarCharVector) => new LargeStringWriter(vector) - case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) - case (BinaryType, vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector) - case (DateType, vector: DateDayVector) => new DateWriter(vector) - case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) - case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector) - case (dt, vector) if CatalystTypeOps(dt).isDefined => - CatalystTypeOps(dt).get.createArrowFieldWriter(vector) - case (_: TimeType, vector: 
TimeNanoVector) => new TimeWriter(vector) - case (ArrayType(_, _), vector: ListVector) => - val elementVector = createFieldWriter(vector.getDataVector()) - new ArrayWriter(vector, elementVector) - case (MapType(_, _, _), vector: MapVector) => - val structVector = vector.getDataVector.asInstanceOf[StructVector] - val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME)) - val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME)) - new MapWriter(vector, structVector, keyWriter, valueWriter) - case (StructType(_), vector: StructVector) => - val children = (0 until vector.size()).map { ordinal => - createFieldWriter(vector.getChildByOrdinal(ordinal)) - } - new StructWriter(vector, children.toArray) - case (NullType, vector: NullVector) => new NullWriter(vector) - case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector) - case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) - case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => - new IntervalMonthDayNanoWriter(vector) - case (VariantType, vector: StructVector) => - val children = (0 until vector.size()).map { ordinal => - createFieldWriter(vector.getChildByOrdinal(ordinal)) - } - new StructWriter(vector, children.toArray) - case (dt: GeometryType, vector: StructVector) => - val children = (0 until vector.size()).map { ordinal => - createFieldWriter(vector.getChildByOrdinal(ordinal)) - } - new GeometryWriter(dt, vector, children.toArray) - case (dt: GeographyType, vector: StructVector) => - val children = (0 until vector.size()).map { ordinal => - createFieldWriter(vector.getChildByOrdinal(ordinal)) - } - new GeographyWriter(dt, vector, children.toArray) - case (dt, _) => - throw ExecutionErrors.unsupportedDataTypeError(dt) - } + val dt = ArrowUtils.fromArrowField(field) + CatalystTypeOps(dt).map(_.createArrowFieldWriter(vector)) + .getOrElse(createFieldWriterDefault(dt, vector)) + } + + 
private[sql] def createFieldWriterDefault( + dt: DataType, vector: ValueVector): ArrowFieldWriter = (dt, vector) match { + case (BooleanType, vector: BitVector) => new BooleanWriter(vector) + case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) + case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) + case (IntegerType, vector: IntVector) => new IntegerWriter(vector) + case (LongType, vector: BigIntVector) => new LongWriter(vector) + case (FloatType, vector: Float4Vector) => new FloatWriter(vector) + case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) + case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => + new DecimalWriter(vector, precision, scale) + case (StringType, vector: VarCharVector) => new StringWriter(vector) + case (StringType, vector: LargeVarCharVector) => new LargeStringWriter(vector) + case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) + case (BinaryType, vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector) + case (DateType, vector: DateDayVector) => new DateWriter(vector) + case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) + case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector) + case (_: TimeType, vector: TimeNanoVector) => new TimeWriter(vector) + case (ArrayType(_, _), vector: ListVector) => + val elementVector = createFieldWriter(vector.getDataVector()) + new ArrayWriter(vector, elementVector) + case (MapType(_, _, _), vector: MapVector) => + val structVector = vector.getDataVector.asInstanceOf[StructVector] + val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME)) + val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME)) + new MapWriter(vector, structVector, keyWriter, valueWriter) + case (StructType(_), vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + 
new StructWriter(vector, children.toArray) + case (NullType, vector: NullVector) => new NullWriter(vector) + case (_: YearMonthIntervalType, vector: IntervalYearVector) => + new IntervalYearWriter(vector) + case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) + case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => + new IntervalMonthDayNanoWriter(vector) + case (VariantType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new StructWriter(vector, children.toArray) + case (dt: GeometryType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new GeometryWriter(dt, vector, children.toArray) + case (dt: GeographyType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new GeographyWriter(dt, vector, children.toArray) + case (dt, _) => + throw ExecutionErrors.unsupportedDataTypeError(dt) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index a8da26c49bd4..93d01c2ce67f 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -203,7 +203,8 @@ object ArrowDeserializers { } case (enc, v: FieldVector) if ConnectArrowTypeOps(enc).isDefined => ConnectArrowTypeOps(enc).get - .createArrowDeserializer(enc, v, timeZoneId).asInstanceOf[Deserializer[Any]] + .createArrowDeserializer(enc, v, timeZoneId) + .asInstanceOf[Deserializer[Any]] case (LocalTimeEncoder, v: FieldVector) => new LeafFieldDeserializer[LocalTime](encoder, v, 
timeZoneId) { override def value(i: Int): LocalTime = reader.getLocalTime(i) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala index 46f5a939ff7e..a1eb39573aad 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala @@ -95,7 +95,8 @@ object ArrowVectorReader { case v: TimeStampMicroVector => new TimeStampMicroVectorReader(v, timeZoneId) case v if ConnectArrowTypeOps(targetDataType).isDefined => ConnectArrowTypeOps(targetDataType).get - .createArrowVectorReader(v).asInstanceOf[ArrowVectorReader] + .createArrowVectorReader(v) + .asInstanceOf[ArrowVectorReader] case v: TimeNanoVector => new TimeVectorReader(v) case _: NullVector => NullVectorReader case _ => throw new RuntimeException("Unsupported Vector Type: " + vector.getClass) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/TimeTypeConnectOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala similarity index 94% rename from sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/TimeTypeConnectOps.scala rename to sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala index 2ccb37bb2b50..7168d2620bbb 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/TimeTypeConnectOps.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala @@ -35,14 +35,16 @@ import org.apache.spark.sql.types.{DataType, TimeType} * (Arrow serialization/deserialization) in a single class. 
* * Lives in the arrow package to access arrow-private types (ArrowVectorReader, TimeVectorReader, - * ArrowSerializer.Serializer, ArrowDeserializers.LeafFieldDeserializer). + * ArrowSerializer.Serializer, ArrowDeserializers.LeafFieldDeserializer). Placed under types/ops + * subdirectory to separate ops implementations from core arrow infrastructure. * * @param t * The TimeType with precision information * @since 4.2.0 */ private[connect] class TimeTypeConnectOps(val t: TimeType) - extends ProtoTypeOps with ConnectArrowTypeOps { + extends ProtoTypeOps + with ConnectArrowTypeOps { override def dataType: DataType = t @@ -79,8 +81,8 @@ private[connect] class TimeTypeConnectOps(val t: TimeType) .setPrecision(timeType.precision)) } - override def getScalaConverter: proto.Expression.Literal => Any = { - v => SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) + override def getScalaConverter: proto.Expression.Literal => Any = { v => + SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) } override def buildProtoDataType( diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index a8b6cf721dd1..24f740e9925d 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -30,72 +30,71 @@ import org.apache.spark.util.SparkClassUtils * Helper class for conversions between [[DataType]] and [[proto.DataType]]. 
*/ object DataTypeProtoConverter { - def toCatalystType(t: proto.DataType): DataType = { - t.getKindCase match { - case proto.DataType.KindCase.NULL => NullType - - case proto.DataType.KindCase.BINARY => BinaryType - - case proto.DataType.KindCase.BOOLEAN => BooleanType - - case proto.DataType.KindCase.BYTE => ByteType - case proto.DataType.KindCase.SHORT => ShortType - case proto.DataType.KindCase.INTEGER => IntegerType - case proto.DataType.KindCase.LONG => LongType - - case proto.DataType.KindCase.FLOAT => FloatType - case proto.DataType.KindCase.DOUBLE => DoubleType - case proto.DataType.KindCase.DECIMAL => toCatalystDecimalType(t.getDecimal) - - case proto.DataType.KindCase.STRING => toCatalystStringType(t.getString) - case proto.DataType.KindCase.CHAR => CharType(t.getChar.getLength) - case proto.DataType.KindCase.VAR_CHAR => VarcharType(t.getVarChar.getLength) - - case proto.DataType.KindCase.DATE => DateType - case proto.DataType.KindCase.TIMESTAMP => TimestampType - case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType - case kindCase if ProtoTypeOps.toCatalystType(t).isDefined => - ProtoTypeOps.toCatalystType(t).get - case proto.DataType.KindCase.TIME => - if (t.getTime.hasPrecision) { - TimeType(t.getTime.getPrecision) - } else { - TimeType() - } - - case proto.DataType.KindCase.CALENDAR_INTERVAL => CalendarIntervalType - case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => - toCatalystYearMonthIntervalType(t.getYearMonthInterval) - case proto.DataType.KindCase.DAY_TIME_INTERVAL => - toCatalystDayTimeIntervalType(t.getDayTimeInterval) - - case proto.DataType.KindCase.ARRAY => toCatalystArrayType(t.getArray) - case proto.DataType.KindCase.STRUCT => toCatalystStructType(t.getStruct) - case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap) - case proto.DataType.KindCase.VARIANT => VariantType - - case proto.DataType.KindCase.GEOMETRY => - val srid = t.getGeometry.getSrid - if (srid == GeometryType.MIXED_SRID) { - GeometryType("ANY") - } 
else { - GeometryType(srid) - } - case proto.DataType.KindCase.GEOGRAPHY => - val srid = t.getGeography.getSrid - if (srid == GeographyType.MIXED_SRID) { - GeographyType("ANY") - } else { - GeographyType(srid) - } - - case proto.DataType.KindCase.UDT => toCatalystUDT(t.getUdt) - - case _ => - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_PROTO_TO_CATALYST", - Map("kindCase" -> t.getKindCase.toString)) - } + def toCatalystType(t: proto.DataType): DataType = + ProtoTypeOps.toCatalystType(t).getOrElse(toCatalystTypeDefault(t)) + + private def toCatalystTypeDefault(t: proto.DataType): DataType = t.getKindCase match { + case proto.DataType.KindCase.NULL => NullType + + case proto.DataType.KindCase.BINARY => BinaryType + + case proto.DataType.KindCase.BOOLEAN => BooleanType + + case proto.DataType.KindCase.BYTE => ByteType + case proto.DataType.KindCase.SHORT => ShortType + case proto.DataType.KindCase.INTEGER => IntegerType + case proto.DataType.KindCase.LONG => LongType + + case proto.DataType.KindCase.FLOAT => FloatType + case proto.DataType.KindCase.DOUBLE => DoubleType + case proto.DataType.KindCase.DECIMAL => toCatalystDecimalType(t.getDecimal) + + case proto.DataType.KindCase.STRING => toCatalystStringType(t.getString) + case proto.DataType.KindCase.CHAR => CharType(t.getChar.getLength) + case proto.DataType.KindCase.VAR_CHAR => VarcharType(t.getVarChar.getLength) + + case proto.DataType.KindCase.DATE => DateType + case proto.DataType.KindCase.TIMESTAMP => TimestampType + case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType + case proto.DataType.KindCase.TIME => + if (t.getTime.hasPrecision) { + TimeType(t.getTime.getPrecision) + } else { + TimeType() + } + + case proto.DataType.KindCase.CALENDAR_INTERVAL => CalendarIntervalType + case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => + toCatalystYearMonthIntervalType(t.getYearMonthInterval) + case proto.DataType.KindCase.DAY_TIME_INTERVAL => + 
toCatalystDayTimeIntervalType(t.getDayTimeInterval) + + case proto.DataType.KindCase.ARRAY => toCatalystArrayType(t.getArray) + case proto.DataType.KindCase.STRUCT => toCatalystStructType(t.getStruct) + case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap) + case proto.DataType.KindCase.VARIANT => VariantType + + case proto.DataType.KindCase.GEOMETRY => + val srid = t.getGeometry.getSrid + if (srid == GeometryType.MIXED_SRID) { + GeometryType("ANY") + } else { + GeometryType(srid) + } + case proto.DataType.KindCase.GEOGRAPHY => + val srid = t.getGeography.getSrid + if (srid == GeographyType.MIXED_SRID) { + GeographyType("ANY") + } else { + GeographyType(srid) + } + + case proto.DataType.KindCase.UDT => toCatalystUDT(t.getUdt) + + case _ => + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_PROTO_TO_CATALYST", + Map("kindCase" -> t.getKindCase.toString)) } private def toCatalystDecimalType(t: proto.DataType.Decimal): DecimalType = { @@ -177,227 +176,229 @@ object DataTypeProtoConverter { toConnectProtoTypeInternal(t, bytesToBinary) } - private def toConnectProtoTypeInternal(t: DataType, bytesToBinary: Boolean): proto.DataType = { - t match { - case NullType => ProtoDataTypes.NullType + private def toConnectProtoTypeInternal(t: DataType, bytesToBinary: Boolean): proto.DataType = + ProtoTypeOps(t).map(_.toConnectProtoType) + .getOrElse(toConnectProtoTypeDefault(t, bytesToBinary)) - case BooleanType => ProtoDataTypes.BooleanType + private def toConnectProtoTypeDefault( + t: DataType, bytesToBinary: Boolean): proto.DataType = t match { + case NullType => ProtoDataTypes.NullType - case BinaryType => ProtoDataTypes.BinaryType + case BooleanType => ProtoDataTypes.BooleanType - case ByteType => ProtoDataTypes.ByteType + case BinaryType => ProtoDataTypes.BinaryType - case ShortType => ProtoDataTypes.ShortType + case ByteType => ProtoDataTypes.ByteType - case IntegerType => ProtoDataTypes.IntegerType + case ShortType => ProtoDataTypes.ShortType 
- case LongType => ProtoDataTypes.LongType + case IntegerType => ProtoDataTypes.IntegerType - case FloatType => ProtoDataTypes.FloatType + case LongType => ProtoDataTypes.LongType - case DoubleType => ProtoDataTypes.DoubleType + case FloatType => ProtoDataTypes.FloatType - case DecimalType.Fixed(precision, scale) => - proto.DataType - .newBuilder() - .setDecimal( - proto.DataType.Decimal.newBuilder().setPrecision(precision).setScale(scale).build()) - .build() + case DoubleType => ProtoDataTypes.DoubleType - case c: CharType => - proto.DataType - .newBuilder() - .setChar(proto.DataType.Char.newBuilder().setLength(c.length).build()) - .build() + case DecimalType.Fixed(precision, scale) => + proto.DataType + .newBuilder() + .setDecimal( + proto.DataType.Decimal.newBuilder().setPrecision(precision).setScale(scale).build()) + .build() - case v: VarcharType => - proto.DataType - .newBuilder() - .setVarChar(proto.DataType.VarChar.newBuilder().setLength(v.length).build()) - .build() + case c: CharType => + proto.DataType + .newBuilder() + .setChar(proto.DataType.Char.newBuilder().setLength(c.length).build()) + .build() - // StringType must be matched after CharType and VarcharType - case s: StringType => - val stringBuilder = proto.DataType.String.newBuilder() - // Send collation only for explicit collations (including explicit UTF8_BINARY). - // Default STRING (case object) has no explicit collation and should omit it. 
- if (!s.eq(StringType)) { - stringBuilder.setCollation(CollationFactory.fetchCollation(s.collationId).collationName) - } - proto.DataType - .newBuilder() - .setString(stringBuilder.build()) - .build() + case v: VarcharType => + proto.DataType + .newBuilder() + .setVarChar(proto.DataType.VarChar.newBuilder().setLength(v.length).build()) + .build() - case DateType => ProtoDataTypes.DateType + // StringType must be matched after CharType and VarcharType + case s: StringType => + val stringBuilder = proto.DataType.String.newBuilder() + // Send collation only for explicit collations (including explicit UTF8_BINARY). + // Default STRING (case object) has no explicit collation and should omit it. + if (!s.eq(StringType)) { + stringBuilder.setCollation( + CollationFactory.fetchCollation(s.collationId).collationName) + } + proto.DataType + .newBuilder() + .setString(stringBuilder.build()) + .build() - case TimestampType => ProtoDataTypes.TimestampType + case DateType => ProtoDataTypes.DateType - case TimestampNTZType => ProtoDataTypes.TimestampNTZType + case TimestampType => ProtoDataTypes.TimestampType - case dt if ProtoTypeOps(dt).isDefined => - ProtoTypeOps(dt).get.toConnectProtoType - case TimeType(precision) => - proto.DataType - .newBuilder() - .setTime(proto.DataType.Time.newBuilder().setPrecision(precision).build()) - .build() + case TimestampNTZType => ProtoDataTypes.TimestampNTZType - case CalendarIntervalType => ProtoDataTypes.CalendarIntervalType + case TimeType(precision) => + proto.DataType + .newBuilder() + .setTime(proto.DataType.Time.newBuilder().setPrecision(precision).build()) + .build() - case YearMonthIntervalType(startField, endField) => - proto.DataType - .newBuilder() - .setYearMonthInterval( - proto.DataType.YearMonthInterval - .newBuilder() - .setStartField(startField) - .setEndField(endField) - .build()) - .build() + case CalendarIntervalType => ProtoDataTypes.CalendarIntervalType - case DayTimeIntervalType(startField, endField) => - 
proto.DataType - .newBuilder() - .setDayTimeInterval( - proto.DataType.DayTimeInterval - .newBuilder() - .setStartField(startField) - .setEndField(endField) - .build()) - .build() - - case ArrayType(elementType: DataType, containsNull: Boolean) => - if (elementType == ByteType && bytesToBinary) { - proto.DataType + case YearMonthIntervalType(startField, endField) => + proto.DataType + .newBuilder() + .setYearMonthInterval( + proto.DataType.YearMonthInterval .newBuilder() - .setBinary(proto.DataType.Binary.newBuilder().build()) - .build() - } else { - proto.DataType + .setStartField(startField) + .setEndField(endField) + .build()) + .build() + + case DayTimeIntervalType(startField, endField) => + proto.DataType + .newBuilder() + .setDayTimeInterval( + proto.DataType.DayTimeInterval .newBuilder() - .setArray( - proto.DataType.Array - .newBuilder() - .setElementType(toConnectProtoTypeInternal(elementType, bytesToBinary)) - .setContainsNull(containsNull) - .build()) - .build() - } - - case StructType(fields: Array[StructField]) => - val protoFields = fields.toImmutableArraySeq.map { - case StructField( - name: String, - dataType: DataType, - nullable: Boolean, - metadata: Metadata) => - if (metadata.equals(Metadata.empty)) { - proto.DataType.StructField - .newBuilder() - .setName(name) - .setDataType(toConnectProtoTypeInternal(dataType, bytesToBinary)) - .setNullable(nullable) - .build() - } else { - proto.DataType.StructField - .newBuilder() - .setName(name) - .setDataType(toConnectProtoTypeInternal(dataType, bytesToBinary)) - .setNullable(nullable) - .setMetadata(metadata.json) - .build() - } - } - proto.DataType - .newBuilder() - .setStruct( - proto.DataType.Struct - .newBuilder() - .addAllFields(protoFields.asJava) - .build()) - .build() + .setStartField(startField) + .setEndField(endField) + .build()) + .build() - case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => + case ArrayType(elementType: DataType, containsNull: Boolean) => + 
if (elementType == ByteType && bytesToBinary) { proto.DataType .newBuilder() - .setMap( - proto.DataType.Map - .newBuilder() - .setKeyType(toConnectProtoTypeInternal(keyType, bytesToBinary)) - .setValueType(toConnectProtoTypeInternal(valueType, bytesToBinary)) - .setValueContainsNull(valueContainsNull) - .build()) - .build() - - case g: GeographyType => - proto.DataType - .newBuilder() - .setGeography( - proto.DataType.Geography - .newBuilder() - .setSrid(g.srid) - .build()) + .setBinary(proto.DataType.Binary.newBuilder().build()) .build() - - case g: GeometryType => + } else { proto.DataType .newBuilder() - .setGeometry( - proto.DataType.Geometry + .setArray( + proto.DataType.Array .newBuilder() - .setSrid(g.srid) + .setElementType(toConnectProtoTypeInternal(elementType, bytesToBinary)) + .setContainsNull(containsNull) .build()) .build() + } - case VariantType => ProtoDataTypes.VariantType - - case pyudt: PythonUserDefinedType => - // Python UDT - proto.DataType - .newBuilder() - .setUdt( - proto.DataType.UDT + case StructType(fields: Array[StructField]) => + val protoFields = fields.toImmutableArraySeq.map { + case StructField( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata) => + if (metadata.equals(Metadata.empty)) { + proto.DataType.StructField .newBuilder() - .setType("udt") - .setPythonClass(pyudt.pyUDT) - .setSqlType(toConnectProtoTypeInternal(pyudt.sqlType, bytesToBinary)) - .setSerializedPythonClass(pyudt.serializedPyClass) - .build()) - .build() - - case udt: UserDefinedType[_] => - // Scala/Java UDT - udt.getClass.getName match { - // To avoid making connect-common depend on ml, - // we use class name to identify VectorUDT and MatrixUDT. 
- case "org.apache.spark.ml.linalg.VectorUDT" => - ProtoDataTypes.VectorUDT - - case "org.apache.spark.ml.linalg.MatrixUDT" => - ProtoDataTypes.MatrixUDT - - case className => - val builder = proto.DataType.UDT.newBuilder() - builder - .setType("udt") - .setJvmClass(className) - .setSqlType(toConnectProtoTypeInternal(udt.sqlType, bytesToBinary)) - - if (udt.pyUDT != null) { - builder.setPythonClass(udt.pyUDT) - } - - proto.DataType + .setName(name) + .setDataType(toConnectProtoTypeInternal(dataType, bytesToBinary)) + .setNullable(nullable) + .build() + } else { + proto.DataType.StructField .newBuilder() - .setUdt(builder.build()) + .setName(name) + .setDataType(toConnectProtoTypeInternal(dataType, bytesToBinary)) + .setNullable(nullable) + .setMetadata(metadata.json) .build() - } + } + } + proto.DataType + .newBuilder() + .setStruct( + proto.DataType.Struct + .newBuilder() + .addAllFields(protoFields.asJava) + .build()) + .build() + + case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => + proto.DataType + .newBuilder() + .setMap( + proto.DataType.Map + .newBuilder() + .setKeyType(toConnectProtoTypeInternal(keyType, bytesToBinary)) + .setValueType(toConnectProtoTypeInternal(valueType, bytesToBinary)) + .setValueContainsNull(valueContainsNull) + .build()) + .build() + + case g: GeographyType => + proto.DataType + .newBuilder() + .setGeography( + proto.DataType.Geography + .newBuilder() + .setSrid(g.srid) + .build()) + .build() + + case g: GeometryType => + proto.DataType + .newBuilder() + .setGeometry( + proto.DataType.Geometry + .newBuilder() + .setSrid(g.srid) + .build()) + .build() + + case VariantType => ProtoDataTypes.VariantType + + case pyudt: PythonUserDefinedType => + // Python UDT + proto.DataType + .newBuilder() + .setUdt( + proto.DataType.UDT + .newBuilder() + .setType("udt") + .setPythonClass(pyudt.pyUDT) + .setSqlType(toConnectProtoTypeInternal(pyudt.sqlType, bytesToBinary)) + 
.setSerializedPythonClass(pyudt.serializedPyClass) + .build()) + .build() + + case udt: UserDefinedType[_] => + // Scala/Java UDT + udt.getClass.getName match { + // To avoid making connect-common depend on ml, + // we use class name to identify VectorUDT and MatrixUDT. + case "org.apache.spark.ml.linalg.VectorUDT" => + ProtoDataTypes.VectorUDT + + case "org.apache.spark.ml.linalg.MatrixUDT" => + ProtoDataTypes.MatrixUDT + + case className => + val builder = proto.DataType.UDT.newBuilder() + builder + .setType("udt") + .setJvmClass(className) + .setSqlType(toConnectProtoTypeInternal(udt.sqlType, bytesToBinary)) + + if (udt.pyUDT != null) { + builder.setPythonClass(udt.pyUDT) + } - case _ => - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_CATALYST_TO_PROTO", - Map("typeName" -> t.typeName)) - } + proto.DataType + .newBuilder() + .setUdt(builder.build()) + .build() + } + + case _ => + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_CATALYST_TO_PROTO", + Map("typeName" -> t.typeName)) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 19c54e610738..8ec2e0897da3 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -110,13 +110,13 @@ object LiteralValueProtoConverter { case v: Date => builder.setDate(SparkDateTimeUtils.fromJavaDate(v)) case v: Duration => builder.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)) case v: Period => builder.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)) - case v: LocalTime if ProtoTypeOps(TimeType()).isDefined => - ProtoTypeOps(TimeType()).get.toLiteralProto(v, builder) case v: LocalTime => - builder.setTime( - 
builder.getTimeBuilder - .setNano(SparkDateTimeUtils.localTimeToNanos(v)) - .setPrecision(TimeType.DEFAULT_PRECISION)) + ProtoTypeOps(TimeType()).map(_.toLiteralProto(v, builder)).getOrElse { + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(TimeType.DEFAULT_PRECISION)) + } case v: Array[_] => builder.setArray(arrayBuilder(v)) case v: CalendarInterval => builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds)) @@ -219,13 +219,13 @@ object LiteralValueProtoConverter { builder.setMap(mapBuilder(v, keyType, valueType)) case (v, structType: StructType) => builder.setStruct(structBuilder(v, structType)) - case (v, dt) if ProtoTypeOps(dt).isDefined => - ProtoTypeOps(dt).get.toLiteralProtoWithType(v, dt, builder) case (v: LocalTime, timeType: TimeType) => - builder.setTime( - builder.getTimeBuilder - .setNano(SparkDateTimeUtils.localTimeToNanos(v)) - .setPrecision(timeType.precision)) + ProtoTypeOps(timeType).map(_.toLiteralProtoWithType(v, timeType, builder)).getOrElse { + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(timeType.precision)) + } case _ => toLiteralProtoBuilderInternal(literal, options) } @@ -390,51 +390,52 @@ object LiteralValueProtoConverter { } private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { - val converter: proto.Expression.Literal => Any = dataType.getKindCase match { - case proto.DataType.KindCase.NULL => - v => - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.EXPECTED_NULL_VALUE", - Map("literalTypeCase" -> v.getLiteralTypeCase.toString)) - case proto.DataType.KindCase.SHORT => v => v.getShort.toShort - case proto.DataType.KindCase.INTEGER => v => v.getInteger - case proto.DataType.KindCase.LONG => v => v.getLong - case proto.DataType.KindCase.DOUBLE => v => v.getDouble - case proto.DataType.KindCase.BYTE => v => v.getByte.toByte - case 
proto.DataType.KindCase.FLOAT => v => v.getFloat - case proto.DataType.KindCase.BOOLEAN => v => v.getBoolean - case proto.DataType.KindCase.STRING => v => v.getString - case proto.DataType.KindCase.BINARY => v => v.getBinary.toByteArray - case proto.DataType.KindCase.DATE => - v => SparkDateTimeUtils.toJavaDate(v.getDate) - case proto.DataType.KindCase.TIMESTAMP => - v => SparkDateTimeUtils.toJavaTimestamp(v.getTimestamp) - case proto.DataType.KindCase.TIMESTAMP_NTZ => - v => SparkDateTimeUtils.microsToLocalDateTime(v.getTimestampNtz) - case proto.DataType.KindCase.DAY_TIME_INTERVAL => - v => SparkIntervalUtils.microsToDuration(v.getDayTimeInterval) - case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => - v => SparkIntervalUtils.monthsToPeriod(v.getYearMonthInterval) - case kindCase if ProtoTypeOps.getScalaConverterForKind(kindCase).isDefined => - ProtoTypeOps.getScalaConverterForKind(kindCase).get - case proto.DataType.KindCase.TIME => - v => SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) - case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue) - case proto.DataType.KindCase.CALENDAR_INTERVAL => - v => - val interval = v.getCalendarInterval - new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds) - case proto.DataType.KindCase.ARRAY => - v => toScalaArrayInternal(v, dataType.getArray) - case proto.DataType.KindCase.MAP => - v => toScalaMapInternal(v, dataType.getMap) - case proto.DataType.KindCase.STRUCT => - v => toScalaStructInternal(v, dataType.getStruct) - case _ => - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", - Map("typeInfo" -> dataType.getKindCase.toString)) - } + val converter: proto.Expression.Literal => Any = + ProtoTypeOps.getScalaConverterForKind(dataType.getKindCase).getOrElse { + dataType.getKindCase match { + case proto.DataType.KindCase.NULL => + v => + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.EXPECTED_NULL_VALUE", + Map("literalTypeCase" -> 
v.getLiteralTypeCase.toString)) + case proto.DataType.KindCase.SHORT => v => v.getShort.toShort + case proto.DataType.KindCase.INTEGER => v => v.getInteger + case proto.DataType.KindCase.LONG => v => v.getLong + case proto.DataType.KindCase.DOUBLE => v => v.getDouble + case proto.DataType.KindCase.BYTE => v => v.getByte.toByte + case proto.DataType.KindCase.FLOAT => v => v.getFloat + case proto.DataType.KindCase.BOOLEAN => v => v.getBoolean + case proto.DataType.KindCase.STRING => v => v.getString + case proto.DataType.KindCase.BINARY => v => v.getBinary.toByteArray + case proto.DataType.KindCase.DATE => + v => SparkDateTimeUtils.toJavaDate(v.getDate) + case proto.DataType.KindCase.TIMESTAMP => + v => SparkDateTimeUtils.toJavaTimestamp(v.getTimestamp) + case proto.DataType.KindCase.TIMESTAMP_NTZ => + v => SparkDateTimeUtils.microsToLocalDateTime(v.getTimestampNtz) + case proto.DataType.KindCase.DAY_TIME_INTERVAL => + v => SparkIntervalUtils.microsToDuration(v.getDayTimeInterval) + case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => + v => SparkIntervalUtils.monthsToPeriod(v.getYearMonthInterval) + case proto.DataType.KindCase.TIME => + v => SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) + case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue) + case proto.DataType.KindCase.CALENDAR_INTERVAL => + v => + val interval = v.getCalendarInterval + new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds) + case proto.DataType.KindCase.ARRAY => + v => toScalaArrayInternal(v, dataType.getArray) + case proto.DataType.KindCase.MAP => + v => toScalaMapInternal(v, dataType.getMap) + case proto.DataType.KindCase.STRUCT => + v => toScalaStructInternal(v, dataType.getStruct) + case _ => + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", + Map("typeInfo" -> dataType.getKindCase.toString)) + } + } v => if (v.hasNull) null else converter(v) } @@ -508,101 +509,101 @@ object LiteralValueProtoConverter { 
literal.getNull } else { val builder = proto.DataType.newBuilder() - literal.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.BINARY => - builder.setBinary(proto.DataType.Binary.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => - builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BYTE => - builder.setByte(proto.DataType.Byte.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.SHORT => - builder.setShort(proto.DataType.Short.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.INTEGER => - builder.setInteger(proto.DataType.Integer.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.LONG => - builder.setLong(proto.DataType.Long.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.FLOAT => - builder.setFloat(proto.DataType.Float.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DOUBLE => - builder.setDouble(proto.DataType.Double.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DECIMAL => - val decimal = Decimal.apply(literal.getDecimal.getValue) - var precision = decimal.precision - if (literal.getDecimal.hasPrecision) { - precision = math.max(precision, literal.getDecimal.getPrecision) - } - var scale = decimal.scale - if (literal.getDecimal.hasScale) { - scale = math.max(scale, literal.getDecimal.getScale) - } - builder.setDecimal( - proto.DataType.Decimal - .newBuilder() - .setPrecision(math.max(precision, scale)) - .setScale(scale) - .build()) - case proto.Expression.Literal.LiteralTypeCase.STRING => - builder.setString(proto.DataType.String.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DATE => - builder.setDate(proto.DataType.Date.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => - builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) - case 
proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => - builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => - builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => - builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => - builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) - case litCase if ProtoTypeOps.buildProtoDataTypeForLiteral(literal, builder) => - // Framework handled - case proto.Expression.Literal.LiteralTypeCase.TIME => - val timeBuilder = proto.DataType.Time.newBuilder() - if (literal.getTime.hasPrecision) { - timeBuilder.setPrecision(literal.getTime.getPrecision) - } - builder.setTime(timeBuilder.build()) - case proto.Expression.Literal.LiteralTypeCase.ARRAY => - if (literal.getArray.hasElementType) { - builder.setArray( - proto.DataType.Array + if (!ProtoTypeOps.buildProtoDataTypeForLiteral(literal, builder)) { + literal.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.BINARY => + builder.setBinary(proto.DataType.Binary.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => + builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.BYTE => + builder.setByte(proto.DataType.Byte.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.SHORT => + builder.setShort(proto.DataType.Short.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.INTEGER => + builder.setInteger(proto.DataType.Integer.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.LONG => + builder.setLong(proto.DataType.Long.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.FLOAT => + 
builder.setFloat(proto.DataType.Float.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DOUBLE => + builder.setDouble(proto.DataType.Double.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DECIMAL => + val decimal = Decimal.apply(literal.getDecimal.getValue) + var precision = decimal.precision + if (literal.getDecimal.hasPrecision) { + precision = math.max(precision, literal.getDecimal.getPrecision) + } + var scale = decimal.scale + if (literal.getDecimal.hasScale) { + scale = math.max(scale, literal.getDecimal.getScale) + } + builder.setDecimal( + proto.DataType.Decimal .newBuilder() - .setElementType(literal.getArray.getElementType) - .setContainsNull(true) + .setPrecision(math.max(precision, scale)) + .setScale(scale) .build()) - } else { + case proto.Expression.Literal.LiteralTypeCase.STRING => + builder.setString(proto.DataType.String.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DATE => + builder.setDate(proto.DataType.Date.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => + builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => + builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => + builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => + builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => + builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIME => + val timeBuilder = proto.DataType.Time.newBuilder() + if (literal.getTime.hasPrecision) { + timeBuilder.setPrecision(literal.getTime.getPrecision) + } + 
builder.setTime(timeBuilder.build()) + case proto.Expression.Literal.LiteralTypeCase.ARRAY => + if (literal.getArray.hasElementType) { + builder.setArray( + proto.DataType.Array + .newBuilder() + .setElementType(literal.getArray.getElementType) + .setContainsNull(true) + .build()) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.ARRAY_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case proto.Expression.Literal.LiteralTypeCase.MAP => + if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { + builder.setMap( + proto.DataType.Map + .newBuilder() + .setKeyType(literal.getMap.getKeyType) + .setValueType(literal.getMap.getValueType) + .setValueContainsNull(true) + .build()) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case proto.Expression.Literal.LiteralTypeCase.STRUCT => + if (literal.getStruct.hasStructType) { + builder.setStruct(literal.getStruct.getStructType.getStruct) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.STRUCT_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case _ => + val literalCase = literal.getLiteralTypeCase throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.ARRAY_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case proto.Expression.Literal.LiteralTypeCase.MAP => - if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { - builder.setMap( - proto.DataType.Map - .newBuilder() - .setKeyType(literal.getMap.getKeyType) - .setValueType(literal.getMap.getValueType) - .setValueContainsNull(true) - .build()) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case proto.Expression.Literal.LiteralTypeCase.STRUCT => - if (literal.getStruct.hasStructType) { - builder.setStruct(literal.getStruct.getStructType.getStruct) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.STRUCT_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case _ => - val literalCase = literal.getLiteralTypeCase - throw 
InvalidPlanInput( - "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", - Map("typeInfo" -> s"${literalCase.name}(${literalCase.getNumber})")) + "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", + Map("typeInfo" -> s"${literalCase.name}(${literalCase.getNumber})")) + } } builder.build() } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala index 1c1aff44251e..da1357d579c1 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala @@ -25,15 +25,15 @@ import org.apache.spark.sql.types.{DataType, TimeType} /** * Optional type operations for Spark Connect Arrow serialization/deserialization. * - * Handles Arrow-based data exchange on the Connect client side for framework-managed types. - * Used by ArrowSerializer, ArrowDeserializer, and ArrowVectorReader. + * Handles Arrow-based data exchange on the Connect client side for framework-managed types. Used + * by ArrowSerializer, ArrowDeserializer, and ArrowVectorReader. * * NOTE: No feature flag check -the Connect client must handle whatever types the server sends. - * The feature flag controls server-side engine behavior; the client always needs to handle - * types that exist in the encoder. + * The feature flag controls server-side engine behavior; the client always needs to handle types + * that exist in the encoder. * - * Methods return Any to avoid referencing arrow-private types from the types.ops package. - * Call sites in the arrow package cast to the expected types. + * Methods return Any to avoid referencing arrow-private types from the types.ops package. Call + * sites in the arrow package cast to the expected types. 
* * @since 4.2.0 */ diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala index 47bfed7da182..4f9d6b2b0d4b 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.types.{DataType, TimeType} /** * Optional type operations for Spark Connect protobuf conversions. * - * Handles bidirectional DataType <-> proto and Literal <-> proto conversions - * for framework-managed types in DataTypeProtoConverter and LiteralValueProtoConverter. + * Handles bidirectional DataType <-> proto and Literal <-> proto conversions for + * framework-managed types in DataTypeProtoConverter and LiteralValueProtoConverter. * * @since 4.2.0 */ @@ -52,9 +52,7 @@ trait ProtoTypeOps extends Serializable { def getScalaConverter: proto.Expression.Literal => Any /** Builds a proto DataType from a proto literal (for type inference). */ - def buildProtoDataType( - literal: proto.Expression.Literal, - builder: proto.DataType.Builder): Unit + def buildProtoDataType(literal: proto.Expression.Literal, builder: proto.DataType.Builder): Unit } /** @@ -71,8 +69,8 @@ object ProtoTypeOps { } /** - * Reverse lookup: converts a proto DataType to a Spark DataType, if it belongs - * to a framework-managed type. + * Reverse lookup: converts a proto DataType to a Spark DataType, if it belongs to a + * framework-managed type. 
*/ def toCatalystType(t: proto.DataType): Option[DataType] = { if (!SqlApiConf.get.typesFrameworkEnabled) return None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 97c61820f2e3..f748592e9071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -113,11 +113,20 @@ object HiveResult extends SQLConfHelper { formatters: TimeFormatters, binaryFormatter: BinaryFormatter): String = a match { case (null, _) => if (nested) "null" else "NULL" + case (value, dt) => + ClientTypeOps(dt).map(_.formatExternal(value)).getOrElse { + toHiveStringDefault(a, nested, formatters, binaryFormatter) + } + } + + private def toHiveStringDefault( + a: (Any, DataType), + nested: Boolean, + formatters: TimeFormatters, + binaryFormatter: BinaryFormatter): String = a match { case (b, BooleanType) => b.toString case (d: Date, DateType) => formatters.date.format(d) case (ld: LocalDate, DateType) => formatters.date.format(ld) - case (value, dt) if ClientTypeOps(dt).isDefined => - ClientTypeOps(dt).get.formatExternal(value) case (lt: LocalTime, _: TimeType) => formatters.time.format(lt) case (t: Timestamp, TimestampType) => formatters.timestamp.format(t) case (i: Instant, TimestampType) => formatters.timestamp.format(i) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index dccdd521930b..834fdde6d581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -146,31 +146,31 @@ object JdbcUtils extends Logging with SQLConfHelper { * @param dt The datatype (e.g. 
[[org.apache.spark.sql.types.StringType]]) * @return The default JdbcType for this DataType */ - def getCommonJDBCType(dt: DataType): Option[JdbcType] = { - dt match { - case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) - case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) - case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) - case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) - case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) - case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) - case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) - case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) - case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) - case c: CharType => Option(JdbcType(s"CHAR(${c.length})", java.sql.Types.CHAR)) - case v: VarcharType => Option(JdbcType(s"VARCHAR(${v.length})", java.sql.Types.VARCHAR)) - case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) - // This is a common case of timestamp without time zone. Most of the databases either only - // support TIMESTAMP type or use TIMESTAMP as an alias for TIMESTAMP WITHOUT TIME ZONE. - // Note that some dialects override this setting, e.g. as SQL Server. 
- case TimestampNTZType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) - case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) - case t: DecimalType => Option( - JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) - case dt if ClientTypeOps(dt).isDefined => - Option(JdbcType(ClientTypeOps(dt).get.jdbcTypeName, ClientTypeOps(dt).get.getJdbcType)) - case _ => None - } + def getCommonJDBCType(dt: DataType): Option[JdbcType] = + ClientTypeOps(dt).map(ops => JdbcType(ops.jdbcTypeName, ops.getJdbcType)) + .orElse(getCommonJDBCTypeDefault(dt)) + + private def getCommonJDBCTypeDefault(dt: DataType): Option[JdbcType] = dt match { + case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) + case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) + case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) + case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) + case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) + case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) + case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) + case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) + case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) + case c: CharType => Option(JdbcType(s"CHAR(${c.length})", java.sql.Types.CHAR)) + case v: VarcharType => Option(JdbcType(s"VARCHAR(${v.length})", java.sql.Types.VARCHAR)) + case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + // This is a common case of timestamp without time zone. Most of the databases either only + // support TIMESTAMP type or use TIMESTAMP as an alias for TIMESTAMP WITHOUT TIME ZONE. + // Note that some dialects override this setting, e.g. as SQL Server. 
+ case TimestampNTZType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) + case t: DecimalType => Option( + JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) + case _ => None } def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index c6837ce2db27..12034e113e2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -42,17 +42,19 @@ object EvaluatePython { */ private[python] class BytesWrapper(val data: Array[Byte]) - def needConversionInPython(dt: DataType): Boolean = dt match { - case dt if ClientTypeOps(dt).exists(_.needConversionInPython) => true - case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType - | _: TimeType | _: GeometryType | _: GeographyType => true - case _: StructType => true - case _: UserDefinedType[_] => true - case ArrayType(elementType, _) => needConversionInPython(elementType) - case MapType(keyType, valueType, _) => - needConversionInPython(keyType) || needConversionInPython(valueType) - case _ => false - } + def needConversionInPython(dt: DataType): Boolean = + ClientTypeOps(dt).map(_.needConversionInPython).getOrElse { + dt match { + case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType + | _: TimeType | _: GeometryType | _: GeographyType => true + case _: StructType => true + case _: UserDefinedType[_] => true + case ArrayType(elementType, _) => needConversionInPython(elementType) + case MapType(keyType, valueType, _) => + needConversionInPython(keyType) || needConversionInPython(valueType) + case _ => false + } + } /** * Helper for converting from 
Catalyst type to java type suitable for Pickle. @@ -113,7 +115,10 @@ object EvaluatePython { * Make a converter that converts `obj` to the type specified by the data type, or returns * null if the type of obj is unexpected. Because Python doesn't enforce the type. */ - def makeFromJava(dataType: DataType): Any => Any = dataType match { + def makeFromJava(dataType: DataType): Any => Any = + ClientTypeOps(dataType).map(_.makeFromJava).getOrElse(makeFromJavaDefault(dataType)) + + private def makeFromJavaDefault(dataType: DataType): Any => Any = dataType match { case BooleanType => (obj: Any) => nullSafeConvert(obj) { case b: Boolean => b } @@ -164,10 +169,8 @@ object EvaluatePython { case c: Int => c } - case dt if ClientTypeOps(dt).isDefined => ClientTypeOps(dt).get.makeFromJava - - case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => (obj: Any) => - nullSafeConvert(obj) { + case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => + (obj: Any) => nullSafeConvert(obj) { case c: Long => c // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs case c: Int => c.toLong @@ -232,7 +235,8 @@ object EvaluatePython { case VariantType => (obj: Any) => nullSafeConvert(obj) { case s: java.util.HashMap[_, _] => new VariantVal( - s.get("value").asInstanceOf[Array[Byte]], s.get("metadata").asInstanceOf[Array[Byte]] + s.get("value").asInstanceOf[Array[Byte]], + s.get("metadata").asInstanceOf[Array[Byte]] ) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 13957351aa7a..12d16467101d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala 
@@ -327,7 +327,11 @@ private[hive] class SparkExecuteStatementOperation( object SparkExecuteStatementOperation { - def toTTypeId(typ: DataType): TTypeId = typ match { + def toTTypeId(typ: DataType): TTypeId = + ClientTypeOps(typ).map(ops => TTypeId.valueOf(ops.thriftTypeName)) + .getOrElse(toTTypeIdDefault(typ)) + + private def toTTypeIdDefault(typ: DataType): TTypeId = typ match { case NullType => TTypeId.NULL_TYPE case BooleanType => TTypeId.BOOLEAN_TYPE case ByteType => TTypeId.TINYINT_TYPE @@ -343,8 +347,6 @@ object SparkExecuteStatementOperation { case _: StringType => TTypeId.STRING_TYPE case _: DecimalType => TTypeId.DECIMAL_TYPE case DateType => TTypeId.DATE_TYPE - case dt if ClientTypeOps(dt).isDefined => - TTypeId.valueOf(ClientTypeOps(dt).get.thriftTypeName) case _: TimeType => TTypeId.STRING_TYPE // TODO: Shall use TIMESTAMPLOCALTZ_TYPE, keep AS-IS now for // unnecessary behavior change From 453e22f3cf28806c1f98bbf0c09d8b7de61a7ce0 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Fri, 20 Mar 2026 20:07:36 +0000 Subject: [PATCH 3/8] scala style --- .../apache/spark/sql/util/ArrowUtils.scala | 7 +- .../client/arrow/ArrowDeserializer.scala | 592 +++++++++--------- .../client/arrow/ArrowSerializer.scala | 10 +- .../client/arrow/ArrowVectorReader.scala | 12 +- .../common/DataTypeProtoConverter.scala | 376 +++++------ .../common/LiteralValueProtoConverter.scala | 143 +++-- .../common/types/ops/ProtoTypeOps.scala | 12 + .../sql/execution/python/EvaluatePython.scala | 25 +- 8 files changed, 607 insertions(+), 570 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 0d8a78ac6b30..dbbb7ab396f0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -40,11 +40,14 @@ private[sql] object ArrowUtils { /** Maps data type from Spark to Arrow. 
NOTE: timeZoneId required for TimestampTypes */ def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = - ClientTypeOps(dt).map(_.toArrowType(timeZoneId)) + ClientTypeOps(dt) + .map(_.toArrowType(timeZoneId)) .getOrElse(toArrowTypeDefault(dt, timeZoneId, largeVarTypes)) private def toArrowTypeDefault( - dt: DataType, timeZoneId: String, largeVarTypes: Boolean): ArrowType = dt match { + dt: DataType, + timeZoneId: String, + largeVarTypes: Boolean): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) case ShortType => new ArrowType.Int(8 * 2, true) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 93d01c2ce67f..f42131e725ad 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -91,334 +91,336 @@ object ArrowDeserializers { private[arrow] def deserializerFor( encoder: AgnosticEncoder[_], data: AnyRef, - timeZoneId: String): Deserializer[Any] = { - (encoder, data) match { - case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: FieldVector) => - new LeafFieldDeserializer[Boolean](encoder, v, timeZoneId) { - override def value(i: Int): Boolean = reader.getBoolean(i) - } - case (PrimitiveByteEncoder | BoxedByteEncoder, v: FieldVector) => - new LeafFieldDeserializer[Byte](encoder, v, timeZoneId) { - override def value(i: Int): Byte = reader.getByte(i) - } - case (PrimitiveShortEncoder | BoxedShortEncoder, v: FieldVector) => - new LeafFieldDeserializer[Short](encoder, v, timeZoneId) { - override def value(i: Int): Short = reader.getShort(i) - } - case (PrimitiveIntEncoder | BoxedIntEncoder, v: FieldVector) => - new 
LeafFieldDeserializer[Int](encoder, v, timeZoneId) { - override def value(i: Int): Int = reader.getInt(i) - } - case (PrimitiveLongEncoder | BoxedLongEncoder, v: FieldVector) => - new LeafFieldDeserializer[Long](encoder, v, timeZoneId) { - override def value(i: Int): Long = reader.getLong(i) - } - case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: FieldVector) => - new LeafFieldDeserializer[Float](encoder, v, timeZoneId) { - override def value(i: Int): Float = reader.getFloat(i) - } - case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: FieldVector) => - new LeafFieldDeserializer[Double](encoder, v, timeZoneId) { - override def value(i: Int): Double = reader.getDouble(i) - } - case (NullEncoder, _: FieldVector) => - new Deserializer[Any] { - def get(i: Int): Any = null - } - case (StringEncoder, v: FieldVector) => - new LeafFieldDeserializer[String](encoder, v, timeZoneId) { - override def value(i: Int): String = reader.getString(i) - } - case (JavaEnumEncoder(tag), v: FieldVector) => - // It would be nice if we can get Enum.valueOf working... 
- val valueOf = methodLookup.findStatic( - tag.runtimeClass, - "valueOf", - MethodType.methodType(tag.runtimeClass, classOf[String])) - new LeafFieldDeserializer[Enum[_]](encoder, v, timeZoneId) { - override def value(i: Int): Enum[_] = { - valueOf.invoke(reader.getString(i)).asInstanceOf[Enum[_]] - } - } - case (ScalaEnumEncoder(parent, _), v: FieldVector) => - val mirror = scala.reflect.runtime.currentMirror - val module = mirror.classSymbol(parent).module.asModule - val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration] - new LeafFieldDeserializer[Enumeration#Value](encoder, v, timeZoneId) { - override def value(i: Int): Enumeration#Value = { - enumeration.withName(reader.getString(i)) - } - } - case (BinaryEncoder, v: FieldVector) => - new LeafFieldDeserializer[Array[Byte]](encoder, v, timeZoneId) { - override def value(i: Int): Array[Byte] = reader.getBytes(i) - } - case (SparkDecimalEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[Decimal](encoder, v, timeZoneId) { - override def value(i: Int): Decimal = reader.getDecimal(i) - } - case (ScalaDecimalEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[BigDecimal](encoder, v, timeZoneId) { - override def value(i: Int): BigDecimal = reader.getScalaDecimal(i) - } - case (JavaDecimalEncoder(_, _), v: FieldVector) => - new LeafFieldDeserializer[JBigDecimal](encoder, v, timeZoneId) { - override def value(i: Int): JBigDecimal = reader.getJavaDecimal(i) - } - case (ScalaBigIntEncoder, v: FieldVector) => - new LeafFieldDeserializer[BigInt](encoder, v, timeZoneId) { - override def value(i: Int): BigInt = reader.getScalaBigInt(i) - } - case (JavaBigIntEncoder, v: FieldVector) => - new LeafFieldDeserializer[JBigInteger](encoder, v, timeZoneId) { - override def value(i: Int): JBigInteger = reader.getJavaBigInt(i) - } - case (DayTimeIntervalEncoder, v: FieldVector) => - new LeafFieldDeserializer[Duration](encoder, v, timeZoneId) { - override def value(i: Int): Duration = 
reader.getDuration(i) - } - case (YearMonthIntervalEncoder, v: FieldVector) => - new LeafFieldDeserializer[Period](encoder, v, timeZoneId) { - override def value(i: Int): Period = reader.getPeriod(i) - } - case (DateEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[java.sql.Date](encoder, v, timeZoneId) { - override def value(i: Int): java.sql.Date = reader.getDate(i) - } - case (LocalDateEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[LocalDate](encoder, v, timeZoneId) { - override def value(i: Int): LocalDate = reader.getLocalDate(i) - } - case (TimestampEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[java.sql.Timestamp](encoder, v, timeZoneId) { - override def value(i: Int): java.sql.Timestamp = reader.getTimestamp(i) - } - case (InstantEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[Instant](encoder, v, timeZoneId) { - override def value(i: Int): Instant = reader.getInstant(i) - } - case (LocalDateTimeEncoder, v: FieldVector) => - new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) { - override def value(i: Int): LocalDateTime = reader.getLocalDateTime(i) + timeZoneId: String): Deserializer[Any] = + ConnectArrowTypeOps(encoder) + .map(_.createArrowDeserializer(encoder, data, timeZoneId).asInstanceOf[Deserializer[Any]]) + .getOrElse(deserializerForDefault(encoder, data, timeZoneId)) + + private def deserializerForDefault( + encoder: AgnosticEncoder[_], + data: AnyRef, + timeZoneId: String): Deserializer[Any] = (encoder, data) match { + case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: FieldVector) => + new LeafFieldDeserializer[Boolean](encoder, v, timeZoneId) { + override def value(i: Int): Boolean = reader.getBoolean(i) + } + case (PrimitiveByteEncoder | BoxedByteEncoder, v: FieldVector) => + new LeafFieldDeserializer[Byte](encoder, v, timeZoneId) { + override def value(i: Int): Byte = reader.getByte(i) + } + case (PrimitiveShortEncoder | BoxedShortEncoder, v: FieldVector) => + new 
LeafFieldDeserializer[Short](encoder, v, timeZoneId) { + override def value(i: Int): Short = reader.getShort(i) + } + case (PrimitiveIntEncoder | BoxedIntEncoder, v: FieldVector) => + new LeafFieldDeserializer[Int](encoder, v, timeZoneId) { + override def value(i: Int): Int = reader.getInt(i) + } + case (PrimitiveLongEncoder | BoxedLongEncoder, v: FieldVector) => + new LeafFieldDeserializer[Long](encoder, v, timeZoneId) { + override def value(i: Int): Long = reader.getLong(i) + } + case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: FieldVector) => + new LeafFieldDeserializer[Float](encoder, v, timeZoneId) { + override def value(i: Int): Float = reader.getFloat(i) + } + case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: FieldVector) => + new LeafFieldDeserializer[Double](encoder, v, timeZoneId) { + override def value(i: Int): Double = reader.getDouble(i) + } + case (NullEncoder, _: FieldVector) => + new Deserializer[Any] { + def get(i: Int): Any = null + } + case (StringEncoder, v: FieldVector) => + new LeafFieldDeserializer[String](encoder, v, timeZoneId) { + override def value(i: Int): String = reader.getString(i) + } + case (JavaEnumEncoder(tag), v: FieldVector) => + // It would be nice if we can get Enum.valueOf working... 
+ val valueOf = methodLookup.findStatic( + tag.runtimeClass, + "valueOf", + MethodType.methodType(tag.runtimeClass, classOf[String])) + new LeafFieldDeserializer[Enum[_]](encoder, v, timeZoneId) { + override def value(i: Int): Enum[_] = { + valueOf.invoke(reader.getString(i)).asInstanceOf[Enum[_]] } - case (enc, v: FieldVector) if ConnectArrowTypeOps(enc).isDefined => - ConnectArrowTypeOps(enc).get - .createArrowDeserializer(enc, v, timeZoneId) - .asInstanceOf[Deserializer[Any]] - case (LocalTimeEncoder, v: FieldVector) => - new LeafFieldDeserializer[LocalTime](encoder, v, timeZoneId) { - override def value(i: Int): LocalTime = reader.getLocalTime(i) + } + case (ScalaEnumEncoder(parent, _), v: FieldVector) => + val mirror = scala.reflect.runtime.currentMirror + val module = mirror.classSymbol(parent).module.asModule + val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration] + new LeafFieldDeserializer[Enumeration#Value](encoder, v, timeZoneId) { + override def value(i: Int): Enumeration#Value = { + enumeration.withName(reader.getString(i)) } + } + case (BinaryEncoder, v: FieldVector) => + new LeafFieldDeserializer[Array[Byte]](encoder, v, timeZoneId) { + override def value(i: Int): Array[Byte] = reader.getBytes(i) + } + case (SparkDecimalEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[Decimal](encoder, v, timeZoneId) { + override def value(i: Int): Decimal = reader.getDecimal(i) + } + case (ScalaDecimalEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[BigDecimal](encoder, v, timeZoneId) { + override def value(i: Int): BigDecimal = reader.getScalaDecimal(i) + } + case (JavaDecimalEncoder(_, _), v: FieldVector) => + new LeafFieldDeserializer[JBigDecimal](encoder, v, timeZoneId) { + override def value(i: Int): JBigDecimal = reader.getJavaDecimal(i) + } + case (ScalaBigIntEncoder, v: FieldVector) => + new LeafFieldDeserializer[BigInt](encoder, v, timeZoneId) { + override def value(i: Int): BigInt = reader.getScalaBigInt(i) + 
} + case (JavaBigIntEncoder, v: FieldVector) => + new LeafFieldDeserializer[JBigInteger](encoder, v, timeZoneId) { + override def value(i: Int): JBigInteger = reader.getJavaBigInt(i) + } + case (DayTimeIntervalEncoder, v: FieldVector) => + new LeafFieldDeserializer[Duration](encoder, v, timeZoneId) { + override def value(i: Int): Duration = reader.getDuration(i) + } + case (YearMonthIntervalEncoder, v: FieldVector) => + new LeafFieldDeserializer[Period](encoder, v, timeZoneId) { + override def value(i: Int): Period = reader.getPeriod(i) + } + case (DateEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[java.sql.Date](encoder, v, timeZoneId) { + override def value(i: Int): java.sql.Date = reader.getDate(i) + } + case (LocalDateEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[LocalDate](encoder, v, timeZoneId) { + override def value(i: Int): LocalDate = reader.getLocalDate(i) + } + case (TimestampEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[java.sql.Timestamp](encoder, v, timeZoneId) { + override def value(i: Int): java.sql.Timestamp = reader.getTimestamp(i) + } + case (InstantEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[Instant](encoder, v, timeZoneId) { + override def value(i: Int): Instant = reader.getInstant(i) + } + case (LocalDateTimeEncoder, v: FieldVector) => + new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) { + override def value(i: Int): LocalDateTime = reader.getLocalDateTime(i) + } + case (LocalTimeEncoder, v: FieldVector) => + new LeafFieldDeserializer[LocalTime](encoder, v, timeZoneId) { + override def value(i: Int): LocalTime = reader.getLocalTime(i) + } - case (OptionEncoder(value), v) => - val deserializer = deserializerFor(value, v, timeZoneId) - new Deserializer[Any] { - override def get(i: Int): Any = Option(deserializer.get(i)) - } + case (OptionEncoder(value), v) => + val deserializer = deserializerFor(value, v, timeZoneId) + new Deserializer[Any] { + override def get(i: Int): Any = 
Option(deserializer.get(i)) + } - case (ArrayEncoder(element, _), v: ListVector) => - val deserializer = deserializerFor(element, v.getDataVector, timeZoneId) - new VectorFieldDeserializer[AnyRef, ListVector](v) { - def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag) - } + case (ArrayEncoder(element, _), v: ListVector) => + val deserializer = deserializerFor(element, v.getDataVector, timeZoneId) + new VectorFieldDeserializer[AnyRef, ListVector](v) { + def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag) + } - case (IterableEncoder(tag, element, _, _), v: ListVector) => - val deserializer = deserializerFor(element, v.getDataVector, timeZoneId) - if (isSubClass(Classes.MUTABLE_ARRAY_SEQ, tag)) { - // mutable ArraySeq is a bit special because we need to use an array of the element type. - // Some parts of our codebase (unfortunately) rely on this for type inference on results. - new VectorFieldDeserializer[mutable.ArraySeq[Any], ListVector](v) { - def value(i: Int): mutable.ArraySeq[Any] = { - val array = getArray(vector, i, deserializer)(element.clsTag) - ScalaCollectionUtils.wrap(array) - } + case (IterableEncoder(tag, element, _, _), v: ListVector) => + val deserializer = deserializerFor(element, v.getDataVector, timeZoneId) + if (isSubClass(Classes.MUTABLE_ARRAY_SEQ, tag)) { + // mutable ArraySeq is a bit special because we need to use an array of the element type. + // Some parts of our codebase (unfortunately) rely on this for type inference on results. 
+ new VectorFieldDeserializer[mutable.ArraySeq[Any], ListVector](v) { + def value(i: Int): mutable.ArraySeq[Any] = { + val array = getArray(vector, i, deserializer)(element.clsTag) + ScalaCollectionUtils.wrap(array) } - } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) { - new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) { - def value(i: Int): immutable.ArraySeq[Any] = { - val array = getArray(vector, i, deserializer)(element.clsTag) - array.asInstanceOf[Array[_]].toImmutableArraySeq - } + } + } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) { + new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) { + def value(i: Int): immutable.ArraySeq[Any] = { + val array = getArray(vector, i, deserializer)(element.clsTag) + array.asInstanceOf[Array[_]].toImmutableArraySeq } - } else if (isSubClass(Classes.ITERABLE, tag)) { - val companion = ScalaCollectionUtils.getIterableCompanion(tag) - new VectorFieldDeserializer[Iterable[Any], ListVector](v) { - def value(i: Int): Iterable[Any] = { - val builder = companion.newBuilder[Any] - loadListIntoBuilder(vector, i, deserializer, builder) - builder.result() - } + } + } else if (isSubClass(Classes.ITERABLE, tag)) { + val companion = ScalaCollectionUtils.getIterableCompanion(tag) + new VectorFieldDeserializer[Iterable[Any], ListVector](v) { + def value(i: Int): Iterable[Any] = { + val builder = companion.newBuilder[Any] + loadListIntoBuilder(vector, i, deserializer, builder) + builder.result() } - } else if (isSubClass(Classes.JLIST, tag)) { - val newInstance = resolveJavaListCreator(tag) - new VectorFieldDeserializer[JList[Any], ListVector](v) { - def value(i: Int): JList[Any] = { - var index = v.getElementStartIndex(i) - val end = v.getElementEndIndex(i) - val list = newInstance(end - index) - while (index < end) { - list.add(deserializer.get(index)) - index += 1 - } - list + } + } else if (isSubClass(Classes.JLIST, tag)) { + val newInstance = resolveJavaListCreator(tag) + new 
VectorFieldDeserializer[JList[Any], ListVector](v) { + def value(i: Int): JList[Any] = { + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + val list = newInstance(end - index) + while (index < end) { + list.add(deserializer.get(index)) + index += 1 } + list } - } else { - throw unsupportedCollectionType(tag.runtimeClass) } + } else { + throw unsupportedCollectionType(tag.runtimeClass) + } - case (MapEncoder(tag, key, value, _), v: MapVector) => - val structVector = v.getDataVector.asInstanceOf[StructVector] - val keyDeserializer = - deserializerFor(key, structVector.getChild(MapVector.KEY_NAME), timeZoneId) - val valueDeserializer = - deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME), timeZoneId) - if (isSubClass(Classes.MAP, tag)) { - val companion = ScalaCollectionUtils.getMapCompanion(tag) - new VectorFieldDeserializer[Map[Any, Any], MapVector](v) { - def value(i: Int): Map[Any, Any] = { - val builder = companion.newBuilder[Any, Any] - var index = v.getElementStartIndex(i) - val end = v.getElementEndIndex(i) - builder.sizeHint(end - index) - while (index < end) { - builder += (keyDeserializer.get(index) -> valueDeserializer.get(index)) - index += 1 - } - builder.result() + case (MapEncoder(tag, key, value, _), v: MapVector) => + val structVector = v.getDataVector.asInstanceOf[StructVector] + val keyDeserializer = + deserializerFor(key, structVector.getChild(MapVector.KEY_NAME), timeZoneId) + val valueDeserializer = + deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME), timeZoneId) + if (isSubClass(Classes.MAP, tag)) { + val companion = ScalaCollectionUtils.getMapCompanion(tag) + new VectorFieldDeserializer[Map[Any, Any], MapVector](v) { + def value(i: Int): Map[Any, Any] = { + val builder = companion.newBuilder[Any, Any] + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + builder.sizeHint(end - index) + while (index < end) { + builder += (keyDeserializer.get(index) -> 
valueDeserializer.get(index)) + index += 1 } + builder.result() } - } else if (isSubClass(Classes.JMAP, tag)) { - val newInstance = resolveJavaMapCreator(tag) - new VectorFieldDeserializer[JMap[Any, Any], MapVector](v) { - def value(i: Int): JMap[Any, Any] = { - val map = newInstance() - var index = v.getElementStartIndex(i) - val end = v.getElementEndIndex(i) - while (index < end) { - map.put(keyDeserializer.get(index), valueDeserializer.get(index)) - index += 1 - } - map + } + } else if (isSubClass(Classes.JMAP, tag)) { + val newInstance = resolveJavaMapCreator(tag) + new VectorFieldDeserializer[JMap[Any, Any], MapVector](v) { + def value(i: Int): JMap[Any, Any] = { + val map = newInstance() + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + while (index < end) { + map.put(keyDeserializer.get(index), valueDeserializer.get(index)) + index += 1 } + map } - } else { - throw unsupportedCollectionType(tag.runtimeClass) } + } else { + throw unsupportedCollectionType(tag.runtimeClass) + } - case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) => - val outer = outerPointerGetter.map(_()).toSeq - // We should try to make this work with MethodHandles. 
- val Some(constructor) = - ScalaReflection.findConstructor( - tag.runtimeClass, - outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass)) - val deserializers = if (isTuple(tag.runtimeClass)) { - fields.zip(vectors).map { case (field, vector) => - deserializerFor(field.enc, vector, timeZoneId) - } - } else { - val outerDeserializer = outer.map { value => - new Deserializer[Any] { - override def get(i: Int): Any = value - } - } - val lookup = createFieldLookup(vectors) - outerDeserializer ++ fields.map { field => - deserializerFor(field.enc, lookup(field.name), timeZoneId) - } + case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) => + val outer = outerPointerGetter.map(_()).toSeq + // We should try to make this work with MethodHandles. + val Some(constructor) = + ScalaReflection.findConstructor( + tag.runtimeClass, + outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass)) + val deserializers = if (isTuple(tag.runtimeClass)) { + fields.zip(vectors).map { case (field, vector) => + deserializerFor(field.enc, vector, timeZoneId) } - new StructFieldSerializer[Any](struct) { - def value(i: Int): Any = { - constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef])) + } else { + val outerDeserializer = outer.map { value => + new Deserializer[Any] { + override def get(i: Int): Any = value } } - - case (r @ RowEncoder(fields), StructVectors(struct, vectors)) => val lookup = createFieldLookup(vectors) - val deserializers = fields.toArray.map { field => + outerDeserializer ++ fields.map { field => deserializerFor(field.enc, lookup(field.name), timeZoneId) } - new StructFieldSerializer[Any](struct) { - def value(i: Int): Any = { - val values = deserializers.map(_.get(i)) - new GenericRowWithSchema(values, r.schema) - } + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef])) } + } - case (_: GeometryEncoder, StructVectors(struct, vectors)) => - 
val gdser = new GeometryArrowSerDe - gdser.createDeserializer(struct, vectors, timeZoneId) - - case (_: GeographyEncoder, StructVectors(struct, vectors)) => - val gdser = new GeographyArrowSerDe - gdser.createDeserializer(struct, vectors, timeZoneId) - - case (VariantEncoder, StructVectors(struct, vectors)) => - assert(vectors.exists(_.getName == "value")) - assert( - vectors.exists(field => - field.getName == "metadata" && field.getField.getMetadata - .containsKey("variant") && field.getField.getMetadata.get("variant") == "true")) - val valueDecoder = - deserializerFor( - BinaryEncoder, - vectors - .find(_.getName == "value") - .getOrElse(throw CompilationErrors.columnNotFoundError("value")), - timeZoneId) - val metadataDecoder = - deserializerFor( - BinaryEncoder, - vectors - .find(_.getName == "metadata") - .getOrElse(throw CompilationErrors.columnNotFoundError("metadata")), - timeZoneId) - new StructFieldSerializer[VariantVal](struct) { - def value(i: Int): VariantVal = { - new VariantVal( - valueDecoder.get(i).asInstanceOf[Array[Byte]], - metadataDecoder.get(i).asInstanceOf[Array[Byte]]) - } + case (r @ RowEncoder(fields), StructVectors(struct, vectors)) => + val lookup = createFieldLookup(vectors) + val deserializers = fields.toArray.map { field => + deserializerFor(field.enc, lookup(field.name), timeZoneId) + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + val values = deserializers.map(_.get(i)) + new GenericRowWithSchema(values, r.schema) } + } - case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => - val constructor = - methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) - val lookup = createFieldLookup(vectors) - val setters = fields - .filter(_.writeMethod.isDefined) - .map { field => - val vector = lookup(field.name) - val deserializer = deserializerFor(field.enc, vector, timeZoneId) - val setter = methodLookup.findVirtual( - tag.runtimeClass, - field.writeMethod.get, - 
MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) - (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) - } - new StructFieldSerializer[Any](struct) { - def value(i: Int): Any = { - val instance = constructor.invoke() - setters.foreach(_(instance, i)) - instance - } + case (_: GeometryEncoder, StructVectors(struct, vectors)) => + val gdser = new GeometryArrowSerDe + gdser.createDeserializer(struct, vectors, timeZoneId) + + case (_: GeographyEncoder, StructVectors(struct, vectors)) => + val gdser = new GeographyArrowSerDe + gdser.createDeserializer(struct, vectors, timeZoneId) + + case (VariantEncoder, StructVectors(struct, vectors)) => + assert(vectors.exists(_.getName == "value")) + assert( + vectors.exists(field => + field.getName == "metadata" && field.getField.getMetadata + .containsKey("variant") && field.getField.getMetadata.get("variant") == "true")) + val valueDecoder = + deserializerFor( + BinaryEncoder, + vectors + .find(_.getName == "value") + .getOrElse(throw CompilationErrors.columnNotFoundError("value")), + timeZoneId) + val metadataDecoder = + deserializerFor( + BinaryEncoder, + vectors + .find(_.getName == "metadata") + .getOrElse(throw CompilationErrors.columnNotFoundError("metadata")), + timeZoneId) + new StructFieldSerializer[VariantVal](struct) { + def value(i: Int): VariantVal = { + new VariantVal( + valueDecoder.get(i).asInstanceOf[Array[Byte]], + metadataDecoder.get(i).asInstanceOf[Array[Byte]]) } + } - case (TransformingEncoder(_, encoder, provider, _), v) => - new Deserializer[Any] { - private[this] val codec = provider() - private[this] val deserializer = deserializerFor(encoder, v, timeZoneId) - override def get(i: Int): Any = codec.decode(deserializer.get(i)) + case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => + val constructor = + methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) + val lookup = createFieldLookup(vectors) + val setters = fields + 
.filter(_.writeMethod.isDefined) + .map { field => + val vector = lookup(field.name) + val deserializer = deserializerFor(field.enc, vector, timeZoneId) + val setter = methodLookup.findVirtual( + tag.runtimeClass, + field.writeMethod.get, + MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) + (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + val instance = constructor.invoke() + setters.foreach(_(instance, i)) + instance + } + } - case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => - throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType) + case (TransformingEncoder(_, encoder, provider, _), v) => + new Deserializer[Any] { + private[this] val codec = provider() + private[this] val deserializer = deserializerFor(encoder, v, timeZoneId) + override def get(i: Int): Any = codec.decode(deserializer.get(i)) + } - case _ => - throw new RuntimeException( - s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.") - } + case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => + throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType) + + case _ => + throw new RuntimeException( + s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.") } private val methodLookup = MethodHandles.lookup() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index d0fca318a72d..5dff3413223d 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -240,7 +240,12 @@ object ArrowSerializer { } // TODO throw better errors on class cast exceptions. 
- private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { + private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = + ConnectArrowTypeOps(encoder) + .map(_.createArrowSerializer(v).asInstanceOf[Serializer]) + .getOrElse(serializerForDefault(encoder, v)) + + private def serializerForDefault[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = (encoder, v) match { case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => new FieldSerializer[Boolean, BitVector](v) { @@ -392,8 +397,6 @@ object ArrowSerializer { override def set(index: Int, value: LocalDateTime): Unit = vector.setSafe(index, SparkDateTimeUtils.localDateTimeToMicros(value)) } - case (enc, v) if ConnectArrowTypeOps(enc).isDefined => - ConnectArrowTypeOps(enc).get.createArrowSerializer(v).asInstanceOf[Serializer] case (LocalTimeEncoder, v: TimeNanoVector) => new FieldSerializer[LocalTime, TimeNanoVector](v) { override def set(index: Int, value: LocalTime): Unit = @@ -520,7 +523,6 @@ object ArrowSerializer { case _ => throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.") } - } private val methodLookup = MethodHandles.lookup() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala index a1eb39573aad..d483d1bb5de7 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala @@ -67,6 +67,14 @@ private[arrow] abstract class ArrowVectorReader { object ArrowVectorReader { def apply( + targetDataType: DataType, + vector: FieldVector, + timeZoneId: String): ArrowVectorReader = + ConnectArrowTypeOps(targetDataType) + .map(_.createArrowVectorReader(vector).asInstanceOf[ArrowVectorReader]) + 
.getOrElse(applyDefault(targetDataType, vector, timeZoneId)) + + private def applyDefault( targetDataType: DataType, vector: FieldVector, timeZoneId: String): ArrowVectorReader = { @@ -93,10 +101,6 @@ object ArrowVectorReader { case v: DateDayVector => new DateDayVectorReader(v, timeZoneId) case v: TimeStampMicroTZVector => new TimeStampMicroTZVectorReader(v) case v: TimeStampMicroVector => new TimeStampMicroVectorReader(v, timeZoneId) - case v if ConnectArrowTypeOps(targetDataType).isDefined => - ConnectArrowTypeOps(targetDataType).get - .createArrowVectorReader(v) - .asInstanceOf[ArrowVectorReader] case v: TimeNanoVector => new TimeVectorReader(v) case _: NullVector => NullVectorReader case _ => throw new RuntimeException("Unsupported Vector Type: " + vector.getClass) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index 24f740e9925d..3ccdcb73a869 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -177,228 +177,228 @@ object DataTypeProtoConverter { } private def toConnectProtoTypeInternal(t: DataType, bytesToBinary: Boolean): proto.DataType = - ProtoTypeOps(t).map(_.toConnectProtoType) + ProtoTypeOps(t) + .map(_.toConnectProtoType) .getOrElse(toConnectProtoTypeDefault(t, bytesToBinary)) - private def toConnectProtoTypeDefault( - t: DataType, bytesToBinary: Boolean): proto.DataType = t match { - case NullType => ProtoDataTypes.NullType + private def toConnectProtoTypeDefault(t: DataType, bytesToBinary: Boolean): proto.DataType = + t match { + case NullType => ProtoDataTypes.NullType - case BooleanType => ProtoDataTypes.BooleanType + case BooleanType => ProtoDataTypes.BooleanType - case BinaryType => ProtoDataTypes.BinaryType + case 
BinaryType => ProtoDataTypes.BinaryType - case ByteType => ProtoDataTypes.ByteType + case ByteType => ProtoDataTypes.ByteType - case ShortType => ProtoDataTypes.ShortType + case ShortType => ProtoDataTypes.ShortType - case IntegerType => ProtoDataTypes.IntegerType + case IntegerType => ProtoDataTypes.IntegerType - case LongType => ProtoDataTypes.LongType + case LongType => ProtoDataTypes.LongType - case FloatType => ProtoDataTypes.FloatType + case FloatType => ProtoDataTypes.FloatType - case DoubleType => ProtoDataTypes.DoubleType + case DoubleType => ProtoDataTypes.DoubleType - case DecimalType.Fixed(precision, scale) => - proto.DataType - .newBuilder() - .setDecimal( - proto.DataType.Decimal.newBuilder().setPrecision(precision).setScale(scale).build()) - .build() + case DecimalType.Fixed(precision, scale) => + proto.DataType + .newBuilder() + .setDecimal( + proto.DataType.Decimal.newBuilder().setPrecision(precision).setScale(scale).build()) + .build() - case c: CharType => - proto.DataType - .newBuilder() - .setChar(proto.DataType.Char.newBuilder().setLength(c.length).build()) - .build() + case c: CharType => + proto.DataType + .newBuilder() + .setChar(proto.DataType.Char.newBuilder().setLength(c.length).build()) + .build() - case v: VarcharType => - proto.DataType - .newBuilder() - .setVarChar(proto.DataType.VarChar.newBuilder().setLength(v.length).build()) - .build() + case v: VarcharType => + proto.DataType + .newBuilder() + .setVarChar(proto.DataType.VarChar.newBuilder().setLength(v.length).build()) + .build() - // StringType must be matched after CharType and VarcharType - case s: StringType => - val stringBuilder = proto.DataType.String.newBuilder() - // Send collation only for explicit collations (including explicit UTF8_BINARY). - // Default STRING (case object) has no explicit collation and should omit it. 
- if (!s.eq(StringType)) { - stringBuilder.setCollation( - CollationFactory.fetchCollation(s.collationId).collationName) - } - proto.DataType - .newBuilder() - .setString(stringBuilder.build()) - .build() + // StringType must be matched after CharType and VarcharType + case s: StringType => + val stringBuilder = proto.DataType.String.newBuilder() + // Send collation only for explicit collations (including explicit UTF8_BINARY). + // Default STRING (case object) has no explicit collation and should omit it. + if (!s.eq(StringType)) { + stringBuilder.setCollation(CollationFactory.fetchCollation(s.collationId).collationName) + } + proto.DataType + .newBuilder() + .setString(stringBuilder.build()) + .build() - case DateType => ProtoDataTypes.DateType + case DateType => ProtoDataTypes.DateType - case TimestampType => ProtoDataTypes.TimestampType + case TimestampType => ProtoDataTypes.TimestampType - case TimestampNTZType => ProtoDataTypes.TimestampNTZType + case TimestampNTZType => ProtoDataTypes.TimestampNTZType - case TimeType(precision) => - proto.DataType - .newBuilder() - .setTime(proto.DataType.Time.newBuilder().setPrecision(precision).build()) - .build() + case TimeType(precision) => + proto.DataType + .newBuilder() + .setTime(proto.DataType.Time.newBuilder().setPrecision(precision).build()) + .build() - case CalendarIntervalType => ProtoDataTypes.CalendarIntervalType + case CalendarIntervalType => ProtoDataTypes.CalendarIntervalType - case YearMonthIntervalType(startField, endField) => - proto.DataType - .newBuilder() - .setYearMonthInterval( - proto.DataType.YearMonthInterval + case YearMonthIntervalType(startField, endField) => + proto.DataType + .newBuilder() + .setYearMonthInterval( + proto.DataType.YearMonthInterval + .newBuilder() + .setStartField(startField) + .setEndField(endField) + .build()) + .build() + + case DayTimeIntervalType(startField, endField) => + proto.DataType + .newBuilder() + .setDayTimeInterval( + proto.DataType.DayTimeInterval + 
.newBuilder() + .setStartField(startField) + .setEndField(endField) + .build()) + .build() + + case ArrayType(elementType: DataType, containsNull: Boolean) => + if (elementType == ByteType && bytesToBinary) { + proto.DataType .newBuilder() - .setStartField(startField) - .setEndField(endField) - .build()) - .build() - - case DayTimeIntervalType(startField, endField) => - proto.DataType - .newBuilder() - .setDayTimeInterval( - proto.DataType.DayTimeInterval + .setBinary(proto.DataType.Binary.newBuilder().build()) + .build() + } else { + proto.DataType .newBuilder() - .setStartField(startField) - .setEndField(endField) - .build()) - .build() + .setArray( + proto.DataType.Array + .newBuilder() + .setElementType(toConnectProtoTypeInternal(elementType, bytesToBinary)) + .setContainsNull(containsNull) + .build()) + .build() + } + + case StructType(fields: Array[StructField]) => + val protoFields = fields.toImmutableArraySeq.map { + case StructField( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata) => + if (metadata.equals(Metadata.empty)) { + proto.DataType.StructField + .newBuilder() + .setName(name) + .setDataType(toConnectProtoTypeInternal(dataType, bytesToBinary)) + .setNullable(nullable) + .build() + } else { + proto.DataType.StructField + .newBuilder() + .setName(name) + .setDataType(toConnectProtoTypeInternal(dataType, bytesToBinary)) + .setNullable(nullable) + .setMetadata(metadata.json) + .build() + } + } + proto.DataType + .newBuilder() + .setStruct( + proto.DataType.Struct + .newBuilder() + .addAllFields(protoFields.asJava) + .build()) + .build() - case ArrayType(elementType: DataType, containsNull: Boolean) => - if (elementType == ByteType && bytesToBinary) { + case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => proto.DataType .newBuilder() - .setBinary(proto.DataType.Binary.newBuilder().build()) + .setMap( + proto.DataType.Map + .newBuilder() + .setKeyType(toConnectProtoTypeInternal(keyType, 
bytesToBinary)) + .setValueType(toConnectProtoTypeInternal(valueType, bytesToBinary)) + .setValueContainsNull(valueContainsNull) + .build()) .build() - } else { + + case g: GeographyType => proto.DataType .newBuilder() - .setArray( - proto.DataType.Array + .setGeography( + proto.DataType.Geography .newBuilder() - .setElementType(toConnectProtoTypeInternal(elementType, bytesToBinary)) - .setContainsNull(containsNull) + .setSrid(g.srid) .build()) .build() - } - case StructType(fields: Array[StructField]) => - val protoFields = fields.toImmutableArraySeq.map { - case StructField( - name: String, - dataType: DataType, - nullable: Boolean, - metadata: Metadata) => - if (metadata.equals(Metadata.empty)) { - proto.DataType.StructField + case g: GeometryType => + proto.DataType + .newBuilder() + .setGeometry( + proto.DataType.Geometry .newBuilder() - .setName(name) - .setDataType(toConnectProtoTypeInternal(dataType, bytesToBinary)) - .setNullable(nullable) - .build() - } else { - proto.DataType.StructField + .setSrid(g.srid) + .build()) + .build() + + case VariantType => ProtoDataTypes.VariantType + + case pyudt: PythonUserDefinedType => + // Python UDT + proto.DataType + .newBuilder() + .setUdt( + proto.DataType.UDT .newBuilder() - .setName(name) - .setDataType(toConnectProtoTypeInternal(dataType, bytesToBinary)) - .setNullable(nullable) - .setMetadata(metadata.json) - .build() - } - } - proto.DataType - .newBuilder() - .setStruct( - proto.DataType.Struct - .newBuilder() - .addAllFields(protoFields.asJava) - .build()) - .build() - - case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => - proto.DataType - .newBuilder() - .setMap( - proto.DataType.Map - .newBuilder() - .setKeyType(toConnectProtoTypeInternal(keyType, bytesToBinary)) - .setValueType(toConnectProtoTypeInternal(valueType, bytesToBinary)) - .setValueContainsNull(valueContainsNull) - .build()) - .build() - - case g: GeographyType => - proto.DataType - .newBuilder() - .setGeography( - 
proto.DataType.Geography - .newBuilder() - .setSrid(g.srid) - .build()) - .build() - - case g: GeometryType => - proto.DataType - .newBuilder() - .setGeometry( - proto.DataType.Geometry - .newBuilder() - .setSrid(g.srid) - .build()) - .build() - - case VariantType => ProtoDataTypes.VariantType - - case pyudt: PythonUserDefinedType => - // Python UDT - proto.DataType - .newBuilder() - .setUdt( - proto.DataType.UDT - .newBuilder() - .setType("udt") - .setPythonClass(pyudt.pyUDT) - .setSqlType(toConnectProtoTypeInternal(pyudt.sqlType, bytesToBinary)) - .setSerializedPythonClass(pyudt.serializedPyClass) - .build()) - .build() - - case udt: UserDefinedType[_] => - // Scala/Java UDT - udt.getClass.getName match { - // To avoid making connect-common depend on ml, - // we use class name to identify VectorUDT and MatrixUDT. - case "org.apache.spark.ml.linalg.VectorUDT" => - ProtoDataTypes.VectorUDT - - case "org.apache.spark.ml.linalg.MatrixUDT" => - ProtoDataTypes.MatrixUDT - - case className => - val builder = proto.DataType.UDT.newBuilder() - builder - .setType("udt") - .setJvmClass(className) - .setSqlType(toConnectProtoTypeInternal(udt.sqlType, bytesToBinary)) - - if (udt.pyUDT != null) { - builder.setPythonClass(udt.pyUDT) - } + .setType("udt") + .setPythonClass(pyudt.pyUDT) + .setSqlType(toConnectProtoTypeInternal(pyudt.sqlType, bytesToBinary)) + .setSerializedPythonClass(pyudt.serializedPyClass) + .build()) + .build() - proto.DataType - .newBuilder() - .setUdt(builder.build()) - .build() - } + case udt: UserDefinedType[_] => + // Scala/Java UDT + udt.getClass.getName match { + // To avoid making connect-common depend on ml, + // we use class name to identify VectorUDT and MatrixUDT. 
+ case "org.apache.spark.ml.linalg.VectorUDT" => + ProtoDataTypes.VectorUDT + + case "org.apache.spark.ml.linalg.MatrixUDT" => + ProtoDataTypes.MatrixUDT + + case className => + val builder = proto.DataType.UDT.newBuilder() + builder + .setType("udt") + .setJvmClass(className) + .setSqlType(toConnectProtoTypeInternal(udt.sqlType, bytesToBinary)) + + if (udt.pyUDT != null) { + builder.setPythonClass(udt.pyUDT) + } + + proto.DataType + .newBuilder() + .setUdt(builder.build()) + .build() + } - case _ => - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_CATALYST_TO_PROTO", - Map("typeName" -> t.typeName)) - } + case _ => + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_CATALYST_TO_PROTO", + Map("typeName" -> t.typeName)) + } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 8ec2e0897da3..169fc2896480 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -58,71 +58,77 @@ object LiteralValueProtoConverter { literal: Any, options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() - - def decimalBuilder(precision: Int, scale: Int, value: String) = { - builder.getDecimalBuilder.setPrecision(precision).setScale(scale).setValue(value) - } - - def calendarIntervalBuilder(months: Int, days: Int, microseconds: Long) = { - builder.getCalendarIntervalBuilder - .setMonths(months) - .setDays(days) - .setMicroseconds(microseconds) + ProtoTypeOps.toLiteralProtoForValue(literal, builder).getOrElse { + toLiteralProtoBuilderDefault(literal, builder, options) } + } - def arrayBuilder(array: Array[_]) = { + private def 
toLiteralProtoBuilderDefault( + literal: Any, + builder: proto.Expression.Literal.Builder, + options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = literal match { + case v: Boolean => builder.setBoolean(v) + case v: Byte => builder.setByte(v) + case v: Short => builder.setShort(v) + case v: Int => builder.setInteger(v) + case v: Long => builder.setLong(v) + case v: Float => builder.setFloat(v) + case v: Double => builder.setDouble(v) + case v: BigDecimal => + builder.setDecimal( + builder.getDecimalBuilder + .setPrecision(v.precision) + .setScale(v.scale) + .setValue(v.toString)) + case v: JBigDecimal => + builder.setDecimal( + builder.getDecimalBuilder + .setPrecision(v.precision) + .setScale(v.scale) + .setValue(v.toString)) + case v: String => builder.setString(v) + case v: Char => builder.setString(v.toString) + case v: Array[Char] => builder.setString(String.valueOf(v)) + case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v)) + case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options) + case v: immutable.ArraySeq[_] => + toLiteralProtoBuilderInternal(v.unsafeArray, options) + case v: LocalDate => builder.setDate(v.toEpochDay.toInt) + case v: Decimal => + builder.setDecimal( + builder.getDecimalBuilder + .setPrecision(Math.max(v.precision, v.scale)) + .setScale(v.scale) + .setValue(v.toString)) + case v: Instant => builder.setTimestamp(SparkDateTimeUtils.instantToMicros(v)) + case v: Timestamp => builder.setTimestamp(SparkDateTimeUtils.fromJavaTimestamp(v)) + case v: LocalDateTime => + builder.setTimestampNtz(SparkDateTimeUtils.localDateTimeToMicros(v)) + case v: Date => builder.setDate(SparkDateTimeUtils.fromJavaDate(v)) + case v: Duration => builder.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)) + case v: Period => builder.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)) + case v: LocalTime => + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + 
.setPrecision(TimeType.DEFAULT_PRECISION)) + case v: Array[_] => val ab = builder.getArrayBuilder - array.foreach { x => + v.foreach { x => ab.addElements(toLiteralProtoBuilderInternal(x, options).build()) } if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) + ab.setElementType(toConnectProtoType(toDataType(v.getClass.getComponentType))) } - ab - } - - literal match { - case v: Boolean => builder.setBoolean(v) - case v: Byte => builder.setByte(v) - case v: Short => builder.setShort(v) - case v: Int => builder.setInteger(v) - case v: Long => builder.setLong(v) - case v: Float => builder.setFloat(v) - case v: Double => builder.setDouble(v) - case v: BigDecimal => - builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString)) - case v: JBigDecimal => - builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString)) - case v: String => builder.setString(v) - case v: Char => builder.setString(v.toString) - case v: Array[Char] => builder.setString(String.valueOf(v)) - case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v)) - case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options) - case v: immutable.ArraySeq[_] => - toLiteralProtoBuilderInternal(v.unsafeArray, options) - case v: LocalDate => builder.setDate(v.toEpochDay.toInt) - case v: Decimal => - builder.setDecimal(decimalBuilder(Math.max(v.precision, v.scale), v.scale, v.toString)) - case v: Instant => builder.setTimestamp(SparkDateTimeUtils.instantToMicros(v)) - case v: Timestamp => builder.setTimestamp(SparkDateTimeUtils.fromJavaTimestamp(v)) - case v: LocalDateTime => - builder.setTimestampNtz(SparkDateTimeUtils.localDateTimeToMicros(v)) - case v: Date => builder.setDate(SparkDateTimeUtils.fromJavaDate(v)) - case v: Duration => builder.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)) - case v: Period => builder.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)) - case v: 
LocalTime => - ProtoTypeOps(TimeType()).map(_.toLiteralProto(v, builder)).getOrElse { - builder.setTime( - builder.getTimeBuilder - .setNano(SparkDateTimeUtils.localTimeToNanos(v)) - .setPrecision(TimeType.DEFAULT_PRECISION)) - } - case v: Array[_] => builder.setArray(arrayBuilder(v)) - case v: CalendarInterval => - builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds)) - case null => builder.setNull(ProtoDataTypes.NullType) - case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).") - } + builder.setArray(ab) + case v: CalendarInterval => + builder.setCalendarInterval( + builder.getCalendarIntervalBuilder + .setMonths(v.months) + .setDays(v.days) + .setMicroseconds(v.microseconds)) + case null => builder.setNull(ProtoDataTypes.NullType) + case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).") } private def toLiteralProtoBuilderInternal( @@ -130,6 +136,16 @@ object LiteralValueProtoConverter { dataType: DataType, options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() + ProtoTypeOps(dataType) + .map(_.toLiteralProtoWithType(literal, dataType, builder)) + .getOrElse(toLiteralProtoWithTypeDefault(literal, dataType, builder, options)) + } + + private def toLiteralProtoWithTypeDefault( + literal: Any, + dataType: DataType, + builder: proto.Expression.Literal.Builder, + options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { def arrayBuilder(scalaValue: Any, elementType: DataType) = { val ab = builder.getArrayBuilder @@ -220,15 +236,12 @@ object LiteralValueProtoConverter { case (v, structType: StructType) => builder.setStruct(structBuilder(v, structType)) case (v: LocalTime, timeType: TimeType) => - ProtoTypeOps(timeType).map(_.toLiteralProtoWithType(v, timeType, builder)).getOrElse { - builder.setTime( - builder.getTimeBuilder - .setNano(SparkDateTimeUtils.localTimeToNanos(v)) - 
.setPrecision(timeType.precision)) - } + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(timeType.precision)) case _ => toLiteralProtoBuilderInternal(literal, options) } - } /** diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala index 4f9d6b2b0d4b..eaf0ec9e94b9 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala @@ -68,6 +68,18 @@ object ProtoTypeOps { } } + /** Reverse lookup by value class for the generic literal builder. */ + def toLiteralProtoForValue( + value: Any, + builder: proto.Expression.Literal.Builder): Option[proto.Expression.Literal.Builder] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + value match { + case v: java.time.LocalTime => + Some(new TimeTypeConnectOps(TimeType()).toLiteralProto(v, builder)) + case _ => None + } + } + /** * Reverse lookup: converts a proto DataType to a Spark DataType, if it belongs to a * framework-managed type. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 12034e113e2f..6a0a2353c6aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -43,18 +43,19 @@ object EvaluatePython { private[python] class BytesWrapper(val data: Array[Byte]) def needConversionInPython(dt: DataType): Boolean = - ClientTypeOps(dt).map(_.needConversionInPython).getOrElse { - dt match { - case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType - | _: TimeType | _: GeometryType | _: GeographyType => true - case _: StructType => true - case _: UserDefinedType[_] => true - case ArrayType(elementType, _) => needConversionInPython(elementType) - case MapType(keyType, valueType, _) => - needConversionInPython(keyType) || needConversionInPython(valueType) - case _ => false - } - } + ClientTypeOps(dt).map(_.needConversionInPython) + .getOrElse(needConversionInPythonDefault(dt)) + + private def needConversionInPythonDefault(dt: DataType): Boolean = dt match { + case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType + | _: TimeType | _: GeometryType | _: GeographyType => true + case _: StructType => true + case _: UserDefinedType[_] => true + case ArrayType(elementType, _) => needConversionInPython(elementType) + case MapType(keyType, valueType, _) => + needConversionInPython(keyType) || needConversionInPython(valueType) + case _ => false + } /** * Helper for converting from Catalyst type to java type suitable for Pickle. 
From 3e6546f080fecbd43b576d26136cb6571ac8a3d4 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Fri, 20 Mar 2026 21:00:47 +0000 Subject: [PATCH 4/8] styling --- .../apache/spark/sql/util/ArrowUtils.scala | 57 +- .../client/arrow/ArrowDeserializer.scala | 581 +++++++++--------- .../arrow/types/ops/TimeTypeConnectOps.scala | 7 +- .../common/DataTypeProtoConverter.scala | 109 ++-- .../common/LiteralValueProtoConverter.scala | 420 +++++++------ .../common/types/ops/ProtoTypeOps.scala | 19 +- .../sql/execution/python/EvaluatePython.scala | 7 +- 7 files changed, 605 insertions(+), 595 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index dbbb7ab396f0..ebe7ba6808ed 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -47,34 +47,35 @@ private[sql] object ArrowUtils { private def toArrowTypeDefault( dt: DataType, timeZoneId: String, - largeVarTypes: Boolean): ArrowType = dt match { - case BooleanType => ArrowType.Bool.INSTANCE - case ByteType => new ArrowType.Int(8, true) - case ShortType => new ArrowType.Int(8 * 2, true) - case IntegerType => new ArrowType.Int(8 * 4, true) - case LongType => new ArrowType.Int(8 * 8, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE - case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE - case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE - case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE - case DecimalType.Fixed(precision, scale) => - new ArrowType.Decimal(precision, scale, 8 * 16) - case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType if timeZoneId == null => - throw 
SparkException.internalError("Missing timezoneId where it is mandatory.") - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) - case TimestampNTZType => - new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) - case _: TimeType => new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8) - case NullType => ArrowType.Null.INSTANCE - case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) - case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) - case CalendarIntervalType => new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO) - case _ => - throw ExecutionErrors.unsupportedDataTypeError(dt) - } + largeVarTypes: Boolean): ArrowType = + dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE + case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE + case DecimalType.Fixed(precision, scale) => + new ArrowType.Decimal(precision, scale, 8 * 16) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType if timeZoneId == null => + throw SparkException.internalError("Missing timezoneId where it is mandatory.") + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + case TimestampNTZType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case _: TimeType => new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8) + case NullType => ArrowType.Null.INSTANCE + case _: YearMonthIntervalType => new 
ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) + case CalendarIntervalType => new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO) + case _ => + throw ExecutionErrors.unsupportedDataTypeError(dt) + } def fromArrowType(dt: ArrowType): DataType = ClientTypeOps.fromArrowType(dt).getOrElse(fromArrowTypeDefault(dt)) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index f42131e725ad..ba523660757e 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -99,329 +99,330 @@ object ArrowDeserializers { private def deserializerForDefault( encoder: AgnosticEncoder[_], data: AnyRef, - timeZoneId: String): Deserializer[Any] = (encoder, data) match { - case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: FieldVector) => - new LeafFieldDeserializer[Boolean](encoder, v, timeZoneId) { - override def value(i: Int): Boolean = reader.getBoolean(i) - } - case (PrimitiveByteEncoder | BoxedByteEncoder, v: FieldVector) => - new LeafFieldDeserializer[Byte](encoder, v, timeZoneId) { - override def value(i: Int): Byte = reader.getByte(i) - } - case (PrimitiveShortEncoder | BoxedShortEncoder, v: FieldVector) => - new LeafFieldDeserializer[Short](encoder, v, timeZoneId) { - override def value(i: Int): Short = reader.getShort(i) - } - case (PrimitiveIntEncoder | BoxedIntEncoder, v: FieldVector) => - new LeafFieldDeserializer[Int](encoder, v, timeZoneId) { - override def value(i: Int): Int = reader.getInt(i) - } - case (PrimitiveLongEncoder | BoxedLongEncoder, v: FieldVector) => - new LeafFieldDeserializer[Long](encoder, v, timeZoneId) { - override def value(i: Int): Long = reader.getLong(i) - } 
- case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: FieldVector) => - new LeafFieldDeserializer[Float](encoder, v, timeZoneId) { - override def value(i: Int): Float = reader.getFloat(i) - } - case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: FieldVector) => - new LeafFieldDeserializer[Double](encoder, v, timeZoneId) { - override def value(i: Int): Double = reader.getDouble(i) - } - case (NullEncoder, _: FieldVector) => - new Deserializer[Any] { - def get(i: Int): Any = null - } - case (StringEncoder, v: FieldVector) => - new LeafFieldDeserializer[String](encoder, v, timeZoneId) { - override def value(i: Int): String = reader.getString(i) - } - case (JavaEnumEncoder(tag), v: FieldVector) => - // It would be nice if we can get Enum.valueOf working... - val valueOf = methodLookup.findStatic( - tag.runtimeClass, - "valueOf", - MethodType.methodType(tag.runtimeClass, classOf[String])) - new LeafFieldDeserializer[Enum[_]](encoder, v, timeZoneId) { - override def value(i: Int): Enum[_] = { - valueOf.invoke(reader.getString(i)).asInstanceOf[Enum[_]] + timeZoneId: String): Deserializer[Any] = + (encoder, data) match { + case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: FieldVector) => + new LeafFieldDeserializer[Boolean](encoder, v, timeZoneId) { + override def value(i: Int): Boolean = reader.getBoolean(i) } - } - case (ScalaEnumEncoder(parent, _), v: FieldVector) => - val mirror = scala.reflect.runtime.currentMirror - val module = mirror.classSymbol(parent).module.asModule - val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration] - new LeafFieldDeserializer[Enumeration#Value](encoder, v, timeZoneId) { - override def value(i: Int): Enumeration#Value = { - enumeration.withName(reader.getString(i)) + case (PrimitiveByteEncoder | BoxedByteEncoder, v: FieldVector) => + new LeafFieldDeserializer[Byte](encoder, v, timeZoneId) { + override def value(i: Int): Byte = reader.getByte(i) + } + case (PrimitiveShortEncoder | BoxedShortEncoder, v: 
FieldVector) => + new LeafFieldDeserializer[Short](encoder, v, timeZoneId) { + override def value(i: Int): Short = reader.getShort(i) + } + case (PrimitiveIntEncoder | BoxedIntEncoder, v: FieldVector) => + new LeafFieldDeserializer[Int](encoder, v, timeZoneId) { + override def value(i: Int): Int = reader.getInt(i) + } + case (PrimitiveLongEncoder | BoxedLongEncoder, v: FieldVector) => + new LeafFieldDeserializer[Long](encoder, v, timeZoneId) { + override def value(i: Int): Long = reader.getLong(i) + } + case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: FieldVector) => + new LeafFieldDeserializer[Float](encoder, v, timeZoneId) { + override def value(i: Int): Float = reader.getFloat(i) + } + case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: FieldVector) => + new LeafFieldDeserializer[Double](encoder, v, timeZoneId) { + override def value(i: Int): Double = reader.getDouble(i) + } + case (NullEncoder, _: FieldVector) => + new Deserializer[Any] { + def get(i: Int): Any = null + } + case (StringEncoder, v: FieldVector) => + new LeafFieldDeserializer[String](encoder, v, timeZoneId) { + override def value(i: Int): String = reader.getString(i) + } + case (JavaEnumEncoder(tag), v: FieldVector) => + // It would be nice if we can get Enum.valueOf working... 
+ val valueOf = methodLookup.findStatic( + tag.runtimeClass, + "valueOf", + MethodType.methodType(tag.runtimeClass, classOf[String])) + new LeafFieldDeserializer[Enum[_]](encoder, v, timeZoneId) { + override def value(i: Int): Enum[_] = { + valueOf.invoke(reader.getString(i)).asInstanceOf[Enum[_]] + } + } + case (ScalaEnumEncoder(parent, _), v: FieldVector) => + val mirror = scala.reflect.runtime.currentMirror + val module = mirror.classSymbol(parent).module.asModule + val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration] + new LeafFieldDeserializer[Enumeration#Value](encoder, v, timeZoneId) { + override def value(i: Int): Enumeration#Value = { + enumeration.withName(reader.getString(i)) + } + } + case (BinaryEncoder, v: FieldVector) => + new LeafFieldDeserializer[Array[Byte]](encoder, v, timeZoneId) { + override def value(i: Int): Array[Byte] = reader.getBytes(i) + } + case (SparkDecimalEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[Decimal](encoder, v, timeZoneId) { + override def value(i: Int): Decimal = reader.getDecimal(i) + } + case (ScalaDecimalEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[BigDecimal](encoder, v, timeZoneId) { + override def value(i: Int): BigDecimal = reader.getScalaDecimal(i) + } + case (JavaDecimalEncoder(_, _), v: FieldVector) => + new LeafFieldDeserializer[JBigDecimal](encoder, v, timeZoneId) { + override def value(i: Int): JBigDecimal = reader.getJavaDecimal(i) + } + case (ScalaBigIntEncoder, v: FieldVector) => + new LeafFieldDeserializer[BigInt](encoder, v, timeZoneId) { + override def value(i: Int): BigInt = reader.getScalaBigInt(i) + } + case (JavaBigIntEncoder, v: FieldVector) => + new LeafFieldDeserializer[JBigInteger](encoder, v, timeZoneId) { + override def value(i: Int): JBigInteger = reader.getJavaBigInt(i) + } + case (DayTimeIntervalEncoder, v: FieldVector) => + new LeafFieldDeserializer[Duration](encoder, v, timeZoneId) { + override def value(i: Int): Duration = 
reader.getDuration(i) + } + case (YearMonthIntervalEncoder, v: FieldVector) => + new LeafFieldDeserializer[Period](encoder, v, timeZoneId) { + override def value(i: Int): Period = reader.getPeriod(i) + } + case (DateEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[java.sql.Date](encoder, v, timeZoneId) { + override def value(i: Int): java.sql.Date = reader.getDate(i) + } + case (LocalDateEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[LocalDate](encoder, v, timeZoneId) { + override def value(i: Int): LocalDate = reader.getLocalDate(i) + } + case (TimestampEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[java.sql.Timestamp](encoder, v, timeZoneId) { + override def value(i: Int): java.sql.Timestamp = reader.getTimestamp(i) + } + case (InstantEncoder(_), v: FieldVector) => + new LeafFieldDeserializer[Instant](encoder, v, timeZoneId) { + override def value(i: Int): Instant = reader.getInstant(i) + } + case (LocalDateTimeEncoder, v: FieldVector) => + new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) { + override def value(i: Int): LocalDateTime = reader.getLocalDateTime(i) + } + case (LocalTimeEncoder, v: FieldVector) => + new LeafFieldDeserializer[LocalTime](encoder, v, timeZoneId) { + override def value(i: Int): LocalTime = reader.getLocalTime(i) } - } - case (BinaryEncoder, v: FieldVector) => - new LeafFieldDeserializer[Array[Byte]](encoder, v, timeZoneId) { - override def value(i: Int): Array[Byte] = reader.getBytes(i) - } - case (SparkDecimalEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[Decimal](encoder, v, timeZoneId) { - override def value(i: Int): Decimal = reader.getDecimal(i) - } - case (ScalaDecimalEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[BigDecimal](encoder, v, timeZoneId) { - override def value(i: Int): BigDecimal = reader.getScalaDecimal(i) - } - case (JavaDecimalEncoder(_, _), v: FieldVector) => - new LeafFieldDeserializer[JBigDecimal](encoder, v, timeZoneId) { - override def 
value(i: Int): JBigDecimal = reader.getJavaDecimal(i) - } - case (ScalaBigIntEncoder, v: FieldVector) => - new LeafFieldDeserializer[BigInt](encoder, v, timeZoneId) { - override def value(i: Int): BigInt = reader.getScalaBigInt(i) - } - case (JavaBigIntEncoder, v: FieldVector) => - new LeafFieldDeserializer[JBigInteger](encoder, v, timeZoneId) { - override def value(i: Int): JBigInteger = reader.getJavaBigInt(i) - } - case (DayTimeIntervalEncoder, v: FieldVector) => - new LeafFieldDeserializer[Duration](encoder, v, timeZoneId) { - override def value(i: Int): Duration = reader.getDuration(i) - } - case (YearMonthIntervalEncoder, v: FieldVector) => - new LeafFieldDeserializer[Period](encoder, v, timeZoneId) { - override def value(i: Int): Period = reader.getPeriod(i) - } - case (DateEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[java.sql.Date](encoder, v, timeZoneId) { - override def value(i: Int): java.sql.Date = reader.getDate(i) - } - case (LocalDateEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[LocalDate](encoder, v, timeZoneId) { - override def value(i: Int): LocalDate = reader.getLocalDate(i) - } - case (TimestampEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[java.sql.Timestamp](encoder, v, timeZoneId) { - override def value(i: Int): java.sql.Timestamp = reader.getTimestamp(i) - } - case (InstantEncoder(_), v: FieldVector) => - new LeafFieldDeserializer[Instant](encoder, v, timeZoneId) { - override def value(i: Int): Instant = reader.getInstant(i) - } - case (LocalDateTimeEncoder, v: FieldVector) => - new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) { - override def value(i: Int): LocalDateTime = reader.getLocalDateTime(i) - } - case (LocalTimeEncoder, v: FieldVector) => - new LeafFieldDeserializer[LocalTime](encoder, v, timeZoneId) { - override def value(i: Int): LocalTime = reader.getLocalTime(i) - } - case (OptionEncoder(value), v) => - val deserializer = deserializerFor(value, v, timeZoneId) - new 
Deserializer[Any] { - override def get(i: Int): Any = Option(deserializer.get(i)) - } + case (OptionEncoder(value), v) => + val deserializer = deserializerFor(value, v, timeZoneId) + new Deserializer[Any] { + override def get(i: Int): Any = Option(deserializer.get(i)) + } - case (ArrayEncoder(element, _), v: ListVector) => - val deserializer = deserializerFor(element, v.getDataVector, timeZoneId) - new VectorFieldDeserializer[AnyRef, ListVector](v) { - def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag) - } + case (ArrayEncoder(element, _), v: ListVector) => + val deserializer = deserializerFor(element, v.getDataVector, timeZoneId) + new VectorFieldDeserializer[AnyRef, ListVector](v) { + def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag) + } - case (IterableEncoder(tag, element, _, _), v: ListVector) => - val deserializer = deserializerFor(element, v.getDataVector, timeZoneId) - if (isSubClass(Classes.MUTABLE_ARRAY_SEQ, tag)) { - // mutable ArraySeq is a bit special because we need to use an array of the element type. - // Some parts of our codebase (unfortunately) rely on this for type inference on results. - new VectorFieldDeserializer[mutable.ArraySeq[Any], ListVector](v) { - def value(i: Int): mutable.ArraySeq[Any] = { - val array = getArray(vector, i, deserializer)(element.clsTag) - ScalaCollectionUtils.wrap(array) + case (IterableEncoder(tag, element, _, _), v: ListVector) => + val deserializer = deserializerFor(element, v.getDataVector, timeZoneId) + if (isSubClass(Classes.MUTABLE_ARRAY_SEQ, tag)) { + // mutable ArraySeq is a bit special because we need to use an array of the element type. + // Some parts of our codebase (unfortunately) rely on this for type inference on results. 
+ new VectorFieldDeserializer[mutable.ArraySeq[Any], ListVector](v) { + def value(i: Int): mutable.ArraySeq[Any] = { + val array = getArray(vector, i, deserializer)(element.clsTag) + ScalaCollectionUtils.wrap(array) + } } - } - } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) { - new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) { - def value(i: Int): immutable.ArraySeq[Any] = { - val array = getArray(vector, i, deserializer)(element.clsTag) - array.asInstanceOf[Array[_]].toImmutableArraySeq + } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) { + new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) { + def value(i: Int): immutable.ArraySeq[Any] = { + val array = getArray(vector, i, deserializer)(element.clsTag) + array.asInstanceOf[Array[_]].toImmutableArraySeq + } } - } - } else if (isSubClass(Classes.ITERABLE, tag)) { - val companion = ScalaCollectionUtils.getIterableCompanion(tag) - new VectorFieldDeserializer[Iterable[Any], ListVector](v) { - def value(i: Int): Iterable[Any] = { - val builder = companion.newBuilder[Any] - loadListIntoBuilder(vector, i, deserializer, builder) - builder.result() + } else if (isSubClass(Classes.ITERABLE, tag)) { + val companion = ScalaCollectionUtils.getIterableCompanion(tag) + new VectorFieldDeserializer[Iterable[Any], ListVector](v) { + def value(i: Int): Iterable[Any] = { + val builder = companion.newBuilder[Any] + loadListIntoBuilder(vector, i, deserializer, builder) + builder.result() + } } - } - } else if (isSubClass(Classes.JLIST, tag)) { - val newInstance = resolveJavaListCreator(tag) - new VectorFieldDeserializer[JList[Any], ListVector](v) { - def value(i: Int): JList[Any] = { - var index = v.getElementStartIndex(i) - val end = v.getElementEndIndex(i) - val list = newInstance(end - index) - while (index < end) { - list.add(deserializer.get(index)) - index += 1 + } else if (isSubClass(Classes.JLIST, tag)) { + val newInstance = resolveJavaListCreator(tag) + new 
VectorFieldDeserializer[JList[Any], ListVector](v) { + def value(i: Int): JList[Any] = { + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + val list = newInstance(end - index) + while (index < end) { + list.add(deserializer.get(index)) + index += 1 + } + list } - list } + } else { + throw unsupportedCollectionType(tag.runtimeClass) } - } else { - throw unsupportedCollectionType(tag.runtimeClass) - } - case (MapEncoder(tag, key, value, _), v: MapVector) => - val structVector = v.getDataVector.asInstanceOf[StructVector] - val keyDeserializer = - deserializerFor(key, structVector.getChild(MapVector.KEY_NAME), timeZoneId) - val valueDeserializer = - deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME), timeZoneId) - if (isSubClass(Classes.MAP, tag)) { - val companion = ScalaCollectionUtils.getMapCompanion(tag) - new VectorFieldDeserializer[Map[Any, Any], MapVector](v) { - def value(i: Int): Map[Any, Any] = { - val builder = companion.newBuilder[Any, Any] - var index = v.getElementStartIndex(i) - val end = v.getElementEndIndex(i) - builder.sizeHint(end - index) - while (index < end) { - builder += (keyDeserializer.get(index) -> valueDeserializer.get(index)) - index += 1 + case (MapEncoder(tag, key, value, _), v: MapVector) => + val structVector = v.getDataVector.asInstanceOf[StructVector] + val keyDeserializer = + deserializerFor(key, structVector.getChild(MapVector.KEY_NAME), timeZoneId) + val valueDeserializer = + deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME), timeZoneId) + if (isSubClass(Classes.MAP, tag)) { + val companion = ScalaCollectionUtils.getMapCompanion(tag) + new VectorFieldDeserializer[Map[Any, Any], MapVector](v) { + def value(i: Int): Map[Any, Any] = { + val builder = companion.newBuilder[Any, Any] + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + builder.sizeHint(end - index) + while (index < end) { + builder += (keyDeserializer.get(index) -> 
valueDeserializer.get(index)) + index += 1 + } + builder.result() } - builder.result() } - } - } else if (isSubClass(Classes.JMAP, tag)) { - val newInstance = resolveJavaMapCreator(tag) - new VectorFieldDeserializer[JMap[Any, Any], MapVector](v) { - def value(i: Int): JMap[Any, Any] = { - val map = newInstance() - var index = v.getElementStartIndex(i) - val end = v.getElementEndIndex(i) - while (index < end) { - map.put(keyDeserializer.get(index), valueDeserializer.get(index)) - index += 1 + } else if (isSubClass(Classes.JMAP, tag)) { + val newInstance = resolveJavaMapCreator(tag) + new VectorFieldDeserializer[JMap[Any, Any], MapVector](v) { + def value(i: Int): JMap[Any, Any] = { + val map = newInstance() + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + while (index < end) { + map.put(keyDeserializer.get(index), valueDeserializer.get(index)) + index += 1 + } + map } - map } + } else { + throw unsupportedCollectionType(tag.runtimeClass) } - } else { - throw unsupportedCollectionType(tag.runtimeClass) - } - case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) => - val outer = outerPointerGetter.map(_()).toSeq - // We should try to make this work with MethodHandles. - val Some(constructor) = - ScalaReflection.findConstructor( - tag.runtimeClass, - outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass)) - val deserializers = if (isTuple(tag.runtimeClass)) { - fields.zip(vectors).map { case (field, vector) => - deserializerFor(field.enc, vector, timeZoneId) + case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) => + val outer = outerPointerGetter.map(_()).toSeq + // We should try to make this work with MethodHandles. 
+ val Some(constructor) = + ScalaReflection.findConstructor( + tag.runtimeClass, + outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass)) + val deserializers = if (isTuple(tag.runtimeClass)) { + fields.zip(vectors).map { case (field, vector) => + deserializerFor(field.enc, vector, timeZoneId) + } + } else { + val outerDeserializer = outer.map { value => + new Deserializer[Any] { + override def get(i: Int): Any = value + } + } + val lookup = createFieldLookup(vectors) + outerDeserializer ++ fields.map { field => + deserializerFor(field.enc, lookup(field.name), timeZoneId) + } } - } else { - val outerDeserializer = outer.map { value => - new Deserializer[Any] { - override def get(i: Int): Any = value + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef])) } } + + case (r @ RowEncoder(fields), StructVectors(struct, vectors)) => val lookup = createFieldLookup(vectors) - outerDeserializer ++ fields.map { field => + val deserializers = fields.toArray.map { field => deserializerFor(field.enc, lookup(field.name), timeZoneId) } - } - new StructFieldSerializer[Any](struct) { - def value(i: Int): Any = { - constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef])) + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + val values = deserializers.map(_.get(i)) + new GenericRowWithSchema(values, r.schema) + } } - } - case (r @ RowEncoder(fields), StructVectors(struct, vectors)) => - val lookup = createFieldLookup(vectors) - val deserializers = fields.toArray.map { field => - deserializerFor(field.enc, lookup(field.name), timeZoneId) - } - new StructFieldSerializer[Any](struct) { - def value(i: Int): Any = { - val values = deserializers.map(_.get(i)) - new GenericRowWithSchema(values, r.schema) + case (_: GeometryEncoder, StructVectors(struct, vectors)) => + val gdser = new GeometryArrowSerDe + gdser.createDeserializer(struct, vectors, timeZoneId) + + case (_: 
GeographyEncoder, StructVectors(struct, vectors)) => + val gdser = new GeographyArrowSerDe + gdser.createDeserializer(struct, vectors, timeZoneId) + + case (VariantEncoder, StructVectors(struct, vectors)) => + assert(vectors.exists(_.getName == "value")) + assert( + vectors.exists(field => + field.getName == "metadata" && field.getField.getMetadata + .containsKey("variant") && field.getField.getMetadata.get("variant") == "true")) + val valueDecoder = + deserializerFor( + BinaryEncoder, + vectors + .find(_.getName == "value") + .getOrElse(throw CompilationErrors.columnNotFoundError("value")), + timeZoneId) + val metadataDecoder = + deserializerFor( + BinaryEncoder, + vectors + .find(_.getName == "metadata") + .getOrElse(throw CompilationErrors.columnNotFoundError("metadata")), + timeZoneId) + new StructFieldSerializer[VariantVal](struct) { + def value(i: Int): VariantVal = { + new VariantVal( + valueDecoder.get(i).asInstanceOf[Array[Byte]], + metadataDecoder.get(i).asInstanceOf[Array[Byte]]) + } } - } - case (_: GeometryEncoder, StructVectors(struct, vectors)) => - val gdser = new GeometryArrowSerDe - gdser.createDeserializer(struct, vectors, timeZoneId) - - case (_: GeographyEncoder, StructVectors(struct, vectors)) => - val gdser = new GeographyArrowSerDe - gdser.createDeserializer(struct, vectors, timeZoneId) - - case (VariantEncoder, StructVectors(struct, vectors)) => - assert(vectors.exists(_.getName == "value")) - assert( - vectors.exists(field => - field.getName == "metadata" && field.getField.getMetadata - .containsKey("variant") && field.getField.getMetadata.get("variant") == "true")) - val valueDecoder = - deserializerFor( - BinaryEncoder, - vectors - .find(_.getName == "value") - .getOrElse(throw CompilationErrors.columnNotFoundError("value")), - timeZoneId) - val metadataDecoder = - deserializerFor( - BinaryEncoder, - vectors - .find(_.getName == "metadata") - .getOrElse(throw CompilationErrors.columnNotFoundError("metadata")), - timeZoneId) - new 
StructFieldSerializer[VariantVal](struct) { - def value(i: Int): VariantVal = { - new VariantVal( - valueDecoder.get(i).asInstanceOf[Array[Byte]], - metadataDecoder.get(i).asInstanceOf[Array[Byte]]) + case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => + val constructor = + methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) + val lookup = createFieldLookup(vectors) + val setters = fields + .filter(_.writeMethod.isDefined) + .map { field => + val vector = lookup(field.name) + val deserializer = deserializerFor(field.enc, vector, timeZoneId) + val setter = methodLookup.findVirtual( + tag.runtimeClass, + field.writeMethod.get, + MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) + (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + val instance = constructor.invoke() + setters.foreach(_(instance, i)) + instance + } } - } - case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => - val constructor = - methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) - val lookup = createFieldLookup(vectors) - val setters = fields - .filter(_.writeMethod.isDefined) - .map { field => - val vector = lookup(field.name) - val deserializer = deserializerFor(field.enc, vector, timeZoneId) - val setter = methodLookup.findVirtual( - tag.runtimeClass, - field.writeMethod.get, - MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) - (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) - } - new StructFieldSerializer[Any](struct) { - def value(i: Int): Any = { - val instance = constructor.invoke() - setters.foreach(_(instance, i)) - instance + case (TransformingEncoder(_, encoder, provider, _), v) => + new Deserializer[Any] { + private[this] val codec = provider() + private[this] val deserializer = deserializerFor(encoder, v, timeZoneId) + override def get(i: 
Int): Any = codec.decode(deserializer.get(i)) } - } - - case (TransformingEncoder(_, encoder, provider, _), v) => - new Deserializer[Any] { - private[this] val codec = provider() - private[this] val deserializer = deserializerFor(encoder, v, timeZoneId) - override def get(i: Int): Any = codec.decode(deserializer.get(i)) - } - case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => - throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType) + case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => + throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType) - case _ => - throw new RuntimeException( - s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.") - } + case _ => + throw new RuntimeException( + s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.") + } private val methodLookup = MethodHandles.lookup() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala index 7168d2620bbb..d451c88ec490 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala @@ -85,14 +85,13 @@ private[connect] class TimeTypeConnectOps(val t: TimeType) SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) } - override def buildProtoDataType( - literal: proto.Expression.Literal, - builder: proto.DataType.Builder): Unit = { + override def getProtoDataTypeFromLiteral( + literal: proto.Expression.Literal): proto.DataType = { val timeBuilder = proto.DataType.Time.newBuilder() if (literal.getTime.hasPrecision) { timeBuilder.setPrecision(literal.getTime.getPrecision) } - builder.setTime(timeBuilder.build()) + proto.DataType.newBuilder().setTime(timeBuilder.build()).build() } // 
==================== ConnectArrowTypeOps ==================== diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index 3ccdcb73a869..44fe0044f195 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -33,69 +33,70 @@ object DataTypeProtoConverter { def toCatalystType(t: proto.DataType): DataType = ProtoTypeOps.toCatalystType(t).getOrElse(toCatalystTypeDefault(t)) - private def toCatalystTypeDefault(t: proto.DataType): DataType = t.getKindCase match { - case proto.DataType.KindCase.NULL => NullType + private def toCatalystTypeDefault(t: proto.DataType): DataType = + t.getKindCase match { + case proto.DataType.KindCase.NULL => NullType - case proto.DataType.KindCase.BINARY => BinaryType + case proto.DataType.KindCase.BINARY => BinaryType - case proto.DataType.KindCase.BOOLEAN => BooleanType + case proto.DataType.KindCase.BOOLEAN => BooleanType - case proto.DataType.KindCase.BYTE => ByteType - case proto.DataType.KindCase.SHORT => ShortType - case proto.DataType.KindCase.INTEGER => IntegerType - case proto.DataType.KindCase.LONG => LongType + case proto.DataType.KindCase.BYTE => ByteType + case proto.DataType.KindCase.SHORT => ShortType + case proto.DataType.KindCase.INTEGER => IntegerType + case proto.DataType.KindCase.LONG => LongType - case proto.DataType.KindCase.FLOAT => FloatType - case proto.DataType.KindCase.DOUBLE => DoubleType - case proto.DataType.KindCase.DECIMAL => toCatalystDecimalType(t.getDecimal) + case proto.DataType.KindCase.FLOAT => FloatType + case proto.DataType.KindCase.DOUBLE => DoubleType + case proto.DataType.KindCase.DECIMAL => toCatalystDecimalType(t.getDecimal) - case proto.DataType.KindCase.STRING => 
toCatalystStringType(t.getString) - case proto.DataType.KindCase.CHAR => CharType(t.getChar.getLength) - case proto.DataType.KindCase.VAR_CHAR => VarcharType(t.getVarChar.getLength) + case proto.DataType.KindCase.STRING => toCatalystStringType(t.getString) + case proto.DataType.KindCase.CHAR => CharType(t.getChar.getLength) + case proto.DataType.KindCase.VAR_CHAR => VarcharType(t.getVarChar.getLength) - case proto.DataType.KindCase.DATE => DateType - case proto.DataType.KindCase.TIMESTAMP => TimestampType - case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType - case proto.DataType.KindCase.TIME => - if (t.getTime.hasPrecision) { - TimeType(t.getTime.getPrecision) - } else { - TimeType() - } + case proto.DataType.KindCase.DATE => DateType + case proto.DataType.KindCase.TIMESTAMP => TimestampType + case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType + case proto.DataType.KindCase.TIME => + if (t.getTime.hasPrecision) { + TimeType(t.getTime.getPrecision) + } else { + TimeType() + } - case proto.DataType.KindCase.CALENDAR_INTERVAL => CalendarIntervalType - case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => - toCatalystYearMonthIntervalType(t.getYearMonthInterval) - case proto.DataType.KindCase.DAY_TIME_INTERVAL => - toCatalystDayTimeIntervalType(t.getDayTimeInterval) - - case proto.DataType.KindCase.ARRAY => toCatalystArrayType(t.getArray) - case proto.DataType.KindCase.STRUCT => toCatalystStructType(t.getStruct) - case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap) - case proto.DataType.KindCase.VARIANT => VariantType - - case proto.DataType.KindCase.GEOMETRY => - val srid = t.getGeometry.getSrid - if (srid == GeometryType.MIXED_SRID) { - GeometryType("ANY") - } else { - GeometryType(srid) - } - case proto.DataType.KindCase.GEOGRAPHY => - val srid = t.getGeography.getSrid - if (srid == GeographyType.MIXED_SRID) { - GeographyType("ANY") - } else { - GeographyType(srid) - } + case proto.DataType.KindCase.CALENDAR_INTERVAL => 
CalendarIntervalType + case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => + toCatalystYearMonthIntervalType(t.getYearMonthInterval) + case proto.DataType.KindCase.DAY_TIME_INTERVAL => + toCatalystDayTimeIntervalType(t.getDayTimeInterval) + + case proto.DataType.KindCase.ARRAY => toCatalystArrayType(t.getArray) + case proto.DataType.KindCase.STRUCT => toCatalystStructType(t.getStruct) + case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap) + case proto.DataType.KindCase.VARIANT => VariantType + + case proto.DataType.KindCase.GEOMETRY => + val srid = t.getGeometry.getSrid + if (srid == GeometryType.MIXED_SRID) { + GeometryType("ANY") + } else { + GeometryType(srid) + } + case proto.DataType.KindCase.GEOGRAPHY => + val srid = t.getGeography.getSrid + if (srid == GeographyType.MIXED_SRID) { + GeographyType("ANY") + } else { + GeographyType(srid) + } - case proto.DataType.KindCase.UDT => toCatalystUDT(t.getUdt) + case proto.DataType.KindCase.UDT => toCatalystUDT(t.getUdt) - case _ => - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_PROTO_TO_CATALYST", - Map("kindCase" -> t.getKindCase.toString)) - } + case _ => + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_PROTO_TO_CATALYST", + Map("kindCase" -> t.getKindCase.toString)) + } private def toCatalystDecimalType(t: proto.DataType.Decimal): DecimalType = { (t.hasPrecision, t.hasScale) match { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 169fc2896480..c3df9dc4833f 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -66,70 +66,71 @@ object LiteralValueProtoConverter { private def toLiteralProtoBuilderDefault( literal: 
Any, builder: proto.Expression.Literal.Builder, - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = literal match { - case v: Boolean => builder.setBoolean(v) - case v: Byte => builder.setByte(v) - case v: Short => builder.setShort(v) - case v: Int => builder.setInteger(v) - case v: Long => builder.setLong(v) - case v: Float => builder.setFloat(v) - case v: Double => builder.setDouble(v) - case v: BigDecimal => - builder.setDecimal( - builder.getDecimalBuilder - .setPrecision(v.precision) - .setScale(v.scale) - .setValue(v.toString)) - case v: JBigDecimal => - builder.setDecimal( - builder.getDecimalBuilder - .setPrecision(v.precision) - .setScale(v.scale) - .setValue(v.toString)) - case v: String => builder.setString(v) - case v: Char => builder.setString(v.toString) - case v: Array[Char] => builder.setString(String.valueOf(v)) - case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v)) - case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options) - case v: immutable.ArraySeq[_] => - toLiteralProtoBuilderInternal(v.unsafeArray, options) - case v: LocalDate => builder.setDate(v.toEpochDay.toInt) - case v: Decimal => - builder.setDecimal( - builder.getDecimalBuilder - .setPrecision(Math.max(v.precision, v.scale)) - .setScale(v.scale) - .setValue(v.toString)) - case v: Instant => builder.setTimestamp(SparkDateTimeUtils.instantToMicros(v)) - case v: Timestamp => builder.setTimestamp(SparkDateTimeUtils.fromJavaTimestamp(v)) - case v: LocalDateTime => - builder.setTimestampNtz(SparkDateTimeUtils.localDateTimeToMicros(v)) - case v: Date => builder.setDate(SparkDateTimeUtils.fromJavaDate(v)) - case v: Duration => builder.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)) - case v: Period => builder.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)) - case v: LocalTime => - builder.setTime( - builder.getTimeBuilder - .setNano(SparkDateTimeUtils.localTimeToNanos(v)) - .setPrecision(TimeType.DEFAULT_PRECISION)) - 
case v: Array[_] => - val ab = builder.getArrayBuilder - v.foreach { x => - ab.addElements(toLiteralProtoBuilderInternal(x, options).build()) - } - if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(toDataType(v.getClass.getComponentType))) - } - builder.setArray(ab) - case v: CalendarInterval => - builder.setCalendarInterval( - builder.getCalendarIntervalBuilder - .setMonths(v.months) - .setDays(v.days) - .setMicroseconds(v.microseconds)) - case null => builder.setNull(ProtoDataTypes.NullType) - case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).") - } + options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = + literal match { + case v: Boolean => builder.setBoolean(v) + case v: Byte => builder.setByte(v) + case v: Short => builder.setShort(v) + case v: Int => builder.setInteger(v) + case v: Long => builder.setLong(v) + case v: Float => builder.setFloat(v) + case v: Double => builder.setDouble(v) + case v: BigDecimal => + builder.setDecimal( + builder.getDecimalBuilder + .setPrecision(v.precision) + .setScale(v.scale) + .setValue(v.toString)) + case v: JBigDecimal => + builder.setDecimal( + builder.getDecimalBuilder + .setPrecision(v.precision) + .setScale(v.scale) + .setValue(v.toString)) + case v: String => builder.setString(v) + case v: Char => builder.setString(v.toString) + case v: Array[Char] => builder.setString(String.valueOf(v)) + case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v)) + case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options) + case v: immutable.ArraySeq[_] => + toLiteralProtoBuilderInternal(v.unsafeArray, options) + case v: LocalDate => builder.setDate(v.toEpochDay.toInt) + case v: Decimal => + builder.setDecimal( + builder.getDecimalBuilder + .setPrecision(Math.max(v.precision, v.scale)) + .setScale(v.scale) + .setValue(v.toString)) + case v: Instant => builder.setTimestamp(SparkDateTimeUtils.instantToMicros(v)) + case v: 
Timestamp => builder.setTimestamp(SparkDateTimeUtils.fromJavaTimestamp(v)) + case v: LocalDateTime => + builder.setTimestampNtz(SparkDateTimeUtils.localDateTimeToMicros(v)) + case v: Date => builder.setDate(SparkDateTimeUtils.fromJavaDate(v)) + case v: Duration => builder.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)) + case v: Period => builder.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)) + case v: LocalTime => + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(TimeType.DEFAULT_PRECISION)) + case v: Array[_] => + val ab = builder.getArrayBuilder + v.foreach { x => + ab.addElements(toLiteralProtoBuilderInternal(x, options).build()) + } + if (options.useDeprecatedDataTypeFields) { + ab.setElementType(toConnectProtoType(toDataType(v.getClass.getComponentType))) + } + builder.setArray(ab) + case v: CalendarInterval => + builder.setCalendarInterval( + builder.getCalendarIntervalBuilder + .setMonths(v.months) + .setDays(v.days) + .setMicroseconds(v.microseconds)) + case null => builder.setNull(ProtoDataTypes.NullType) + case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).") + } private def toLiteralProtoBuilderInternal( literal: Any, @@ -402,52 +403,58 @@ object LiteralValueProtoConverter { getScalaConverter(getProtoDataType(literal))(literal) } + private def getScalaConverterDefault( + dataType: proto.DataType): proto.Expression.Literal => Any = { + val converter: proto.Expression.Literal => Any = dataType.getKindCase match { + case proto.DataType.KindCase.NULL => + v => + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.EXPECTED_NULL_VALUE", + Map("literalTypeCase" -> v.getLiteralTypeCase.toString)) + case proto.DataType.KindCase.SHORT => v => v.getShort.toShort + case proto.DataType.KindCase.INTEGER => v => v.getInteger + case proto.DataType.KindCase.LONG => v => v.getLong + case proto.DataType.KindCase.DOUBLE => v => v.getDouble + case 
proto.DataType.KindCase.BYTE => v => v.getByte.toByte + case proto.DataType.KindCase.FLOAT => v => v.getFloat + case proto.DataType.KindCase.BOOLEAN => v => v.getBoolean + case proto.DataType.KindCase.STRING => v => v.getString + case proto.DataType.KindCase.BINARY => v => v.getBinary.toByteArray + case proto.DataType.KindCase.DATE => + v => SparkDateTimeUtils.toJavaDate(v.getDate) + case proto.DataType.KindCase.TIMESTAMP => + v => SparkDateTimeUtils.toJavaTimestamp(v.getTimestamp) + case proto.DataType.KindCase.TIMESTAMP_NTZ => + v => SparkDateTimeUtils.microsToLocalDateTime(v.getTimestampNtz) + case proto.DataType.KindCase.DAY_TIME_INTERVAL => + v => SparkIntervalUtils.microsToDuration(v.getDayTimeInterval) + case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => + v => SparkIntervalUtils.monthsToPeriod(v.getYearMonthInterval) + case proto.DataType.KindCase.TIME => + v => SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) + case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue) + case proto.DataType.KindCase.CALENDAR_INTERVAL => + v => + val interval = v.getCalendarInterval + new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds) + case proto.DataType.KindCase.ARRAY => + v => toScalaArrayInternal(v, dataType.getArray) + case proto.DataType.KindCase.MAP => + v => toScalaMapInternal(v, dataType.getMap) + case proto.DataType.KindCase.STRUCT => + v => toScalaStructInternal(v, dataType.getStruct) + case _ => + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", + Map("typeInfo" -> dataType.getKindCase.toString)) + } + converter + } + private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { val converter: proto.Expression.Literal => Any = ProtoTypeOps.getScalaConverterForKind(dataType.getKindCase).getOrElse { - dataType.getKindCase match { - case proto.DataType.KindCase.NULL => - v => - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.EXPECTED_NULL_VALUE", - 
Map("literalTypeCase" -> v.getLiteralTypeCase.toString)) - case proto.DataType.KindCase.SHORT => v => v.getShort.toShort - case proto.DataType.KindCase.INTEGER => v => v.getInteger - case proto.DataType.KindCase.LONG => v => v.getLong - case proto.DataType.KindCase.DOUBLE => v => v.getDouble - case proto.DataType.KindCase.BYTE => v => v.getByte.toByte - case proto.DataType.KindCase.FLOAT => v => v.getFloat - case proto.DataType.KindCase.BOOLEAN => v => v.getBoolean - case proto.DataType.KindCase.STRING => v => v.getString - case proto.DataType.KindCase.BINARY => v => v.getBinary.toByteArray - case proto.DataType.KindCase.DATE => - v => SparkDateTimeUtils.toJavaDate(v.getDate) - case proto.DataType.KindCase.TIMESTAMP => - v => SparkDateTimeUtils.toJavaTimestamp(v.getTimestamp) - case proto.DataType.KindCase.TIMESTAMP_NTZ => - v => SparkDateTimeUtils.microsToLocalDateTime(v.getTimestampNtz) - case proto.DataType.KindCase.DAY_TIME_INTERVAL => - v => SparkIntervalUtils.microsToDuration(v.getDayTimeInterval) - case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => - v => SparkIntervalUtils.monthsToPeriod(v.getYearMonthInterval) - case proto.DataType.KindCase.TIME => - v => SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) - case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue) - case proto.DataType.KindCase.CALENDAR_INTERVAL => - v => - val interval = v.getCalendarInterval - new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds) - case proto.DataType.KindCase.ARRAY => - v => toScalaArrayInternal(v, dataType.getArray) - case proto.DataType.KindCase.MAP => - v => toScalaMapInternal(v, dataType.getMap) - case proto.DataType.KindCase.STRUCT => - v => toScalaStructInternal(v, dataType.getStruct) - case _ => - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", - Map("typeInfo" -> dataType.getKindCase.toString)) - } + getScalaConverterDefault(dataType) } v => if (v.hasNull) null else converter(v) } 
@@ -521,104 +528,8 @@ object LiteralValueProtoConverter { if (literal.getLiteralTypeCase == proto.Expression.Literal.LiteralTypeCase.NULL) { literal.getNull } else { - val builder = proto.DataType.newBuilder() - if (!ProtoTypeOps.buildProtoDataTypeForLiteral(literal, builder)) { - literal.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.BINARY => - builder.setBinary(proto.DataType.Binary.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => - builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BYTE => - builder.setByte(proto.DataType.Byte.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.SHORT => - builder.setShort(proto.DataType.Short.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.INTEGER => - builder.setInteger(proto.DataType.Integer.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.LONG => - builder.setLong(proto.DataType.Long.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.FLOAT => - builder.setFloat(proto.DataType.Float.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DOUBLE => - builder.setDouble(proto.DataType.Double.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DECIMAL => - val decimal = Decimal.apply(literal.getDecimal.getValue) - var precision = decimal.precision - if (literal.getDecimal.hasPrecision) { - precision = math.max(precision, literal.getDecimal.getPrecision) - } - var scale = decimal.scale - if (literal.getDecimal.hasScale) { - scale = math.max(scale, literal.getDecimal.getScale) - } - builder.setDecimal( - proto.DataType.Decimal - .newBuilder() - .setPrecision(math.max(precision, scale)) - .setScale(scale) - .build()) - case proto.Expression.Literal.LiteralTypeCase.STRING => - builder.setString(proto.DataType.String.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DATE => - 
builder.setDate(proto.DataType.Date.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => - builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => - builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => - builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => - builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => - builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIME => - val timeBuilder = proto.DataType.Time.newBuilder() - if (literal.getTime.hasPrecision) { - timeBuilder.setPrecision(literal.getTime.getPrecision) - } - builder.setTime(timeBuilder.build()) - case proto.Expression.Literal.LiteralTypeCase.ARRAY => - if (literal.getArray.hasElementType) { - builder.setArray( - proto.DataType.Array - .newBuilder() - .setElementType(literal.getArray.getElementType) - .setContainsNull(true) - .build()) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.ARRAY_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case proto.Expression.Literal.LiteralTypeCase.MAP => - if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { - builder.setMap( - proto.DataType.Map - .newBuilder() - .setKeyType(literal.getMap.getKeyType) - .setValueType(literal.getMap.getValueType) - .setValueContainsNull(true) - .build()) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case proto.Expression.Literal.LiteralTypeCase.STRUCT => - if (literal.getStruct.hasStructType) { - builder.setStruct(literal.getStruct.getStructType.getStruct) - } else { - throw 
InvalidPlanInput( - "CONNECT_INVALID_PLAN.STRUCT_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case _ => - val literalCase = literal.getLiteralTypeCase - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", - Map("typeInfo" -> s"${literalCase.name}(${literalCase.getNumber})")) - } - } - builder.build() + ProtoTypeOps.getProtoDataTypeFromLiteral(literal) + .getOrElse(getProtoDataTypeDefault(literal)) } } @@ -633,6 +544,105 @@ object LiteralValueProtoConverter { dataType } + private def getProtoDataTypeDefault(literal: proto.Expression.Literal): proto.DataType = { + val builder = proto.DataType.newBuilder() + literal.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.BINARY => + builder.setBinary(proto.DataType.Binary.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => + builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.BYTE => + builder.setByte(proto.DataType.Byte.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.SHORT => + builder.setShort(proto.DataType.Short.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.INTEGER => + builder.setInteger(proto.DataType.Integer.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.LONG => + builder.setLong(proto.DataType.Long.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.FLOAT => + builder.setFloat(proto.DataType.Float.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DOUBLE => + builder.setDouble(proto.DataType.Double.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DECIMAL => + val decimal = Decimal.apply(literal.getDecimal.getValue) + var precision = decimal.precision + if (literal.getDecimal.hasPrecision) { + precision = math.max(precision, literal.getDecimal.getPrecision) + } + var scale = decimal.scale + if (literal.getDecimal.hasScale) { + scale = math.max(scale, 
literal.getDecimal.getScale) + } + builder.setDecimal( + proto.DataType.Decimal + .newBuilder() + .setPrecision(math.max(precision, scale)) + .setScale(scale) + .build()) + case proto.Expression.Literal.LiteralTypeCase.STRING => + builder.setString(proto.DataType.String.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DATE => + builder.setDate(proto.DataType.Date.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => + builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => + builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => + builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => + builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => + builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIME => + val timeBuilder = proto.DataType.Time.newBuilder() + if (literal.getTime.hasPrecision) { + timeBuilder.setPrecision(literal.getTime.getPrecision) + } + builder.setTime(timeBuilder.build()) + case proto.Expression.Literal.LiteralTypeCase.ARRAY => + if (literal.getArray.hasElementType) { + builder.setArray( + proto.DataType.Array + .newBuilder() + .setElementType(literal.getArray.getElementType) + .setContainsNull(true) + .build()) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.ARRAY_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case proto.Expression.Literal.LiteralTypeCase.MAP => + if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { + builder.setMap( + proto.DataType.Map + .newBuilder() + .setKeyType(literal.getMap.getKeyType) + 
.setValueType(literal.getMap.getValueType) + .setValueContainsNull(true) + .build()) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case proto.Expression.Literal.LiteralTypeCase.STRUCT => + if (literal.getStruct.hasStructType) { + builder.setStruct(literal.getStruct.getStructType.getStruct) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.STRUCT_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case _ => + val literalCase = literal.getLiteralTypeCase + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", + Map("typeInfo" -> s"${literalCase.name}(${literalCase.getNumber})")) + } + builder.build() + } + private def toScalaArrayInternal( literal: proto.Expression.Literal, arrayType: proto.DataType.Array): Array[_] = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala index eaf0ec9e94b9..0887c06591a9 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala @@ -51,8 +51,8 @@ trait ProtoTypeOps extends Serializable { /** Returns a converter from proto literal to Scala value. */ def getScalaConverter: proto.Expression.Literal => Any - /** Builds a proto DataType from a proto literal (for type inference). */ - def buildProtoDataType(literal: proto.Expression.Literal, builder: proto.DataType.Builder): Unit + /** Returns a proto DataType inferred from a proto literal (for type inference). */ + def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): proto.DataType } /** @@ -109,17 +109,16 @@ object ProtoTypeOps { } /** - * Reverse lookup: builds a proto DataType from a proto literal's type case. 
+ * Reverse lookup: returns the proto DataType inferred from a proto literal's type case, if the + * literal type belongs to a framework-managed type. */ - def buildProtoDataTypeForLiteral( - literal: proto.Expression.Literal, - builder: proto.DataType.Builder): Boolean = { - if (!SqlApiConf.get.typesFrameworkEnabled) return false + def getProtoDataTypeFromLiteral( + literal: proto.Expression.Literal): Option[proto.DataType] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None literal.getLiteralTypeCase match { case proto.Expression.Literal.LiteralTypeCase.TIME => - new TimeTypeConnectOps(TimeType()).buildProtoDataType(literal, builder) - true - case _ => false + Some(new TimeTypeConnectOps(TimeType()).getProtoDataTypeFromLiteral(literal)) + case _ => None } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 6a0a2353c6aa..d329ca8494d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -170,8 +170,8 @@ object EvaluatePython { case c: Int => c } - case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => - (obj: Any) => nullSafeConvert(obj) { + case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: TimeType => (obj: Any) => + nullSafeConvert(obj) { case c: Long => c // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs case c: Int => c.toLong @@ -236,8 +236,7 @@ object EvaluatePython { case VariantType => (obj: Any) => nullSafeConvert(obj) { case s: java.util.HashMap[_, _] => new VariantVal( - s.get("value").asInstanceOf[Array[Byte]], - s.get("metadata").asInstanceOf[Array[Byte]] + s.get("value").asInstanceOf[Array[Byte]], s.get("metadata").asInstanceOf[Array[Byte]] ) } From a7ad71daed1419e03835e52d0e4ec0973dc79fa3 Mon Sep 17 00:00:00 
2001 From: David Milicevic Date: Fri, 20 Mar 2026 21:15:24 +0000 Subject: [PATCH 5/8] more style --- .../apache/spark/sql/util/ArrowUtils.scala | 3 +- .../sql/execution/arrow/ArrowWriter.scala | 111 +++++++++--------- .../client/arrow/ArrowDeserializer.scala | 3 +- .../client/arrow/ArrowSerializer.scala | 3 +- .../arrow/types/ops/TimeTypeConnectOps.scala | 3 +- .../common/DataTypeProtoConverter.scala | 6 +- .../common/LiteralValueProtoConverter.scala | 7 +- .../common/types/ops/ProtoTypeOps.scala | 3 +- 8 files changed, 70 insertions(+), 69 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index ebe7ba6808ed..257f209dee9d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -60,8 +60,7 @@ private[sql] object ArrowUtils { case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE - case DecimalType.Fixed(precision, scale) => - new ArrowType.Decimal(precision, scale, 8 * 16) + case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale, 8 * 16) case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType if timeZoneId == null => throw SparkException.internalError("Missing timezoneId where it is mandatory.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 84ac1d311658..3f3ad0773a3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -59,61 +59,62 @@ object ArrowWriter { } private[sql] def createFieldWriterDefault( - dt: 
DataType, vector: ValueVector): ArrowFieldWriter = (dt, vector) match { - case (BooleanType, vector: BitVector) => new BooleanWriter(vector) - case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) - case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) - case (IntegerType, vector: IntVector) => new IntegerWriter(vector) - case (LongType, vector: BigIntVector) => new LongWriter(vector) - case (FloatType, vector: Float4Vector) => new FloatWriter(vector) - case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) - case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => - new DecimalWriter(vector, precision, scale) - case (StringType, vector: VarCharVector) => new StringWriter(vector) - case (StringType, vector: LargeVarCharVector) => new LargeStringWriter(vector) - case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) - case (BinaryType, vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector) - case (DateType, vector: DateDayVector) => new DateWriter(vector) - case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) - case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector) - case (_: TimeType, vector: TimeNanoVector) => new TimeWriter(vector) - case (ArrayType(_, _), vector: ListVector) => - val elementVector = createFieldWriter(vector.getDataVector()) - new ArrayWriter(vector, elementVector) - case (MapType(_, _, _), vector: MapVector) => - val structVector = vector.getDataVector.asInstanceOf[StructVector] - val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME)) - val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME)) - new MapWriter(vector, structVector, keyWriter, valueWriter) - case (StructType(_), vector: StructVector) => - val children = (0 until vector.size()).map { ordinal => - createFieldWriter(vector.getChildByOrdinal(ordinal)) - } - new StructWriter(vector, children.toArray) - 
case (NullType, vector: NullVector) => new NullWriter(vector) - case (_: YearMonthIntervalType, vector: IntervalYearVector) => - new IntervalYearWriter(vector) - case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) - case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => - new IntervalMonthDayNanoWriter(vector) - case (VariantType, vector: StructVector) => - val children = (0 until vector.size()).map { ordinal => - createFieldWriter(vector.getChildByOrdinal(ordinal)) - } - new StructWriter(vector, children.toArray) - case (dt: GeometryType, vector: StructVector) => - val children = (0 until vector.size()).map { ordinal => - createFieldWriter(vector.getChildByOrdinal(ordinal)) - } - new GeometryWriter(dt, vector, children.toArray) - case (dt: GeographyType, vector: StructVector) => - val children = (0 until vector.size()).map { ordinal => - createFieldWriter(vector.getChildByOrdinal(ordinal)) - } - new GeographyWriter(dt, vector, children.toArray) - case (dt, _) => - throw ExecutionErrors.unsupportedDataTypeError(dt) - } + dt: DataType, vector: ValueVector): ArrowFieldWriter = + (dt, vector) match { + case (BooleanType, vector: BitVector) => new BooleanWriter(vector) + case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) + case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) + case (IntegerType, vector: IntVector) => new IntegerWriter(vector) + case (LongType, vector: BigIntVector) => new LongWriter(vector) + case (FloatType, vector: Float4Vector) => new FloatWriter(vector) + case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) + case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => + new DecimalWriter(vector, precision, scale) + case (StringType, vector: VarCharVector) => new StringWriter(vector) + case (StringType, vector: LargeVarCharVector) => new LargeStringWriter(vector) + case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) + case (BinaryType, 
vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector) + case (DateType, vector: DateDayVector) => new DateWriter(vector) + case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) + case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector) + case (_: TimeType, vector: TimeNanoVector) => new TimeWriter(vector) + case (ArrayType(_, _), vector: ListVector) => + val elementVector = createFieldWriter(vector.getDataVector()) + new ArrayWriter(vector, elementVector) + case (MapType(_, _, _), vector: MapVector) => + val structVector = vector.getDataVector.asInstanceOf[StructVector] + val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME)) + val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME)) + new MapWriter(vector, structVector, keyWriter, valueWriter) + case (StructType(_), vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new StructWriter(vector, children.toArray) + case (NullType, vector: NullVector) => new NullWriter(vector) + case (_: YearMonthIntervalType, vector: IntervalYearVector) => + new IntervalYearWriter(vector) + case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) + case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => + new IntervalMonthDayNanoWriter(vector) + case (VariantType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new StructWriter(vector, children.toArray) + case (dt: GeometryType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new GeometryWriter(dt, vector, children.toArray) + case (dt: GeographyType, vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + 
createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new GeographyWriter(dt, vector, children.toArray) + case (dt, _) => + throw ExecutionErrors.unsupportedDataTypeError(dt) + } } class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index ba523660757e..a05d6f80a9cf 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -99,7 +99,7 @@ object ArrowDeserializers { private def deserializerForDefault( encoder: AgnosticEncoder[_], data: AnyRef, - timeZoneId: String): Deserializer[Any] = + timeZoneId: String): Deserializer[Any] = { (encoder, data) match { case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: FieldVector) => new LeafFieldDeserializer[Boolean](encoder, v, timeZoneId) { @@ -423,6 +423,7 @@ object ArrowDeserializers { throw new RuntimeException( s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.") } + } private val methodLookup = MethodHandles.lookup() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index 5dff3413223d..17a3da981265 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -245,7 +245,7 @@ object ArrowSerializer { .map(_.createArrowSerializer(v).asInstanceOf[Serializer]) .getOrElse(serializerForDefault(encoder, v)) - private def serializerForDefault[E](encoder: AgnosticEncoder[E], v: AnyRef): 
Serializer = + private def serializerForDefault[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { (encoder, v) match { case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => new FieldSerializer[Boolean, BitVector](v) { @@ -523,6 +523,7 @@ object ArrowSerializer { case _ => throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.") } + } private val methodLookup = MethodHandles.lookup() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala index d451c88ec490..958423978539 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala @@ -85,8 +85,7 @@ private[connect] class TimeTypeConnectOps(val t: TimeType) SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) } - override def getProtoDataTypeFromLiteral( - literal: proto.Expression.Literal): proto.DataType = { + override def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): proto.DataType = { val timeBuilder = proto.DataType.Time.newBuilder() if (literal.getTime.hasPrecision) { timeBuilder.setPrecision(literal.getTime.getPrecision) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index 44fe0044f195..f5948d9b1463 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -33,7 +33,7 @@ object DataTypeProtoConverter { def toCatalystType(t: proto.DataType): DataType = 
ProtoTypeOps.toCatalystType(t).getOrElse(toCatalystTypeDefault(t)) - private def toCatalystTypeDefault(t: proto.DataType): DataType = + private def toCatalystTypeDefault(t: proto.DataType): DataType = { t.getKindCase match { case proto.DataType.KindCase.NULL => NullType @@ -97,6 +97,7 @@ object DataTypeProtoConverter { "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_PROTO_TO_CATALYST", Map("kindCase" -> t.getKindCase.toString)) } + } private def toCatalystDecimalType(t: proto.DataType.Decimal): DecimalType = { (t.hasPrecision, t.hasScale) match { @@ -182,7 +183,7 @@ object DataTypeProtoConverter { .map(_.toConnectProtoType) .getOrElse(toConnectProtoTypeDefault(t, bytesToBinary)) - private def toConnectProtoTypeDefault(t: DataType, bytesToBinary: Boolean): proto.DataType = + private def toConnectProtoTypeDefault(t: DataType, bytesToBinary: Boolean): proto.DataType = { t match { case NullType => ProtoDataTypes.NullType @@ -402,4 +403,5 @@ object DataTypeProtoConverter { "CONNECT_INVALID_PLAN.DATA_TYPE_UNSUPPORTED_CATALYST_TO_PROTO", Map("typeName" -> t.typeName)) } + } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index c3df9dc4833f..5478e6d35c65 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -528,7 +528,8 @@ object LiteralValueProtoConverter { if (literal.getLiteralTypeCase == proto.Expression.Literal.LiteralTypeCase.NULL) { literal.getNull } else { - ProtoTypeOps.getProtoDataTypeFromLiteral(literal) + ProtoTypeOps + .getProtoDataTypeFromLiteral(literal) .getOrElse(getProtoDataTypeDefault(literal)) } } @@ -622,9 +623,7 @@ object LiteralValueProtoConverter { .setValueContainsNull(true) .build()) } else { - throw 
InvalidPlanInput( - "CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", - Map.empty) + throw InvalidPlanInput("CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", Map.empty) } case proto.Expression.Literal.LiteralTypeCase.STRUCT => if (literal.getStruct.hasStructType) { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala index 0887c06591a9..c4e24597dbf1 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala @@ -112,8 +112,7 @@ object ProtoTypeOps { * Reverse lookup: returns the proto DataType inferred from a proto literal's type case, if the * literal type belongs to a framework-managed type. */ - def getProtoDataTypeFromLiteral( - literal: proto.Expression.Literal): Option[proto.DataType] = { + def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): Option[proto.DataType] = { if (!SqlApiConf.get.typesFrameworkEnabled) return None literal.getLiteralTypeCase match { case proto.Expression.Literal.LiteralTypeCase.TIME => From 775d5a3bb2f641e5e2b35c05d1adc750dccd932c Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Fri, 20 Mar 2026 21:20:39 +0000 Subject: [PATCH 6/8] more style --- .../sql/execution/arrow/ArrowWriter.scala | 6 +-- .../datasources/jdbc/JdbcUtils.scala | 44 ++++++++++--------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 3f3ad0773a3f..0f6edbf60f76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala 
@@ -59,7 +59,7 @@ object ArrowWriter { } private[sql] def createFieldWriterDefault( - dt: DataType, vector: ValueVector): ArrowFieldWriter = + dt: DataType, vector: ValueVector): ArrowFieldWriter = { (dt, vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) @@ -92,8 +92,7 @@ object ArrowWriter { } new StructWriter(vector, children.toArray) case (NullType, vector: NullVector) => new NullWriter(vector) - case (_: YearMonthIntervalType, vector: IntervalYearVector) => - new IntervalYearWriter(vector) + case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector) case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => new IntervalMonthDayNanoWriter(vector) @@ -115,6 +114,7 @@ object ArrowWriter { case (dt, _) => throw ExecutionErrors.unsupportedDataTypeError(dt) } + } } class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 834fdde6d581..1c28bc2e105d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -150,27 +150,29 @@ object JdbcUtils extends Logging with SQLConfHelper { ClientTypeOps(dt).map(ops => JdbcType(ops.jdbcTypeName, ops.getJdbcType)) .orElse(getCommonJDBCTypeDefault(dt)) - private def getCommonJDBCTypeDefault(dt: DataType): Option[JdbcType] = dt match { - case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) - case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) - case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) - 
case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) - case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) - case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) - case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) - case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) - case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) - case c: CharType => Option(JdbcType(s"CHAR(${c.length})", java.sql.Types.CHAR)) - case v: VarcharType => Option(JdbcType(s"VARCHAR(${v.length})", java.sql.Types.VARCHAR)) - case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) - // This is a common case of timestamp without time zone. Most of the databases either only - // support TIMESTAMP type or use TIMESTAMP as an alias for TIMESTAMP WITHOUT TIME ZONE. - // Note that some dialects override this setting, e.g. as SQL Server. - case TimestampNTZType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) - case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) - case t: DecimalType => Option( - JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) - case _ => None + private def getCommonJDBCTypeDefault(dt: DataType): Option[JdbcType] = { + dt match { + case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) + case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) + case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) + case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) + case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) + case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) + case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) + case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) + case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) + case c: CharType => Option(JdbcType(s"CHAR(${c.length})", java.sql.Types.CHAR)) + 
case v: VarcharType => Option(JdbcType(s"VARCHAR(${v.length})", java.sql.Types.VARCHAR)) + case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + // This is a common case of timestamp without time zone. Most of the databases either only + // support TIMESTAMP type or use TIMESTAMP as an alias for TIMESTAMP WITHOUT TIME ZONE. + // Note that some dialects override this setting, e.g. SQL Server. + case TimestampNTZType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) + case t: DecimalType => Option( + JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) + case _ => None + } } def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { From 9fdf5c6ece6765ab0a9ffc1019c7a747bb3f7ce0 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Fri, 20 Mar 2026 21:27:04 +0000 Subject: [PATCH 7/8] more style --- .../common/LiteralValueProtoConverter.scala | 59 +++++++++---------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 5478e6d35c65..8c8b9a3429d3 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -66,7 +66,29 @@ object LiteralValueProtoConverter { private def toLiteralProtoBuilderDefault( literal: Any, builder: proto.Expression.Literal.Builder, - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = + options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { + def decimalBuilder(precision: Int, scale: Int, value: String) = { + builder.getDecimalBuilder.setPrecision(precision).setScale(scale).setValue(value) + } + + def
calendarIntervalBuilder(months: Int, days: Int, microseconds: Long) = { + builder.getCalendarIntervalBuilder + .setMonths(months) + .setDays(days) + .setMicroseconds(microseconds) + } + + def arrayBuilder(array: Array[_]) = { + val ab = builder.getArrayBuilder + array.foreach { x => + ab.addElements(toLiteralProtoBuilderInternal(x, options).build()) + } + if (options.useDeprecatedDataTypeFields) { + ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) + } + ab + } + literal match { case v: Boolean => builder.setBoolean(v) case v: Byte => builder.setByte(v) @@ -76,17 +98,9 @@ object LiteralValueProtoConverter { case v: Float => builder.setFloat(v) case v: Double => builder.setDouble(v) case v: BigDecimal => - builder.setDecimal( - builder.getDecimalBuilder - .setPrecision(v.precision) - .setScale(v.scale) - .setValue(v.toString)) + builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString)) case v: JBigDecimal => - builder.setDecimal( - builder.getDecimalBuilder - .setPrecision(v.precision) - .setScale(v.scale) - .setValue(v.toString)) + builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString)) case v: String => builder.setString(v) case v: Char => builder.setString(v.toString) case v: Array[Char] => builder.setString(String.valueOf(v)) @@ -96,11 +110,7 @@ object LiteralValueProtoConverter { toLiteralProtoBuilderInternal(v.unsafeArray, options) case v: LocalDate => builder.setDate(v.toEpochDay.toInt) case v: Decimal => - builder.setDecimal( - builder.getDecimalBuilder - .setPrecision(Math.max(v.precision, v.scale)) - .setScale(v.scale) - .setValue(v.toString)) + builder.setDecimal(decimalBuilder(Math.max(v.precision, v.scale), v.scale, v.toString)) case v: Instant => builder.setTimestamp(SparkDateTimeUtils.instantToMicros(v)) case v: Timestamp => builder.setTimestamp(SparkDateTimeUtils.fromJavaTimestamp(v)) case v: LocalDateTime => @@ -113,24 +123,13 @@ object LiteralValueProtoConverter { builder.getTimeBuilder 
.setNano(SparkDateTimeUtils.localTimeToNanos(v)) .setPrecision(TimeType.DEFAULT_PRECISION)) - case v: Array[_] => - val ab = builder.getArrayBuilder - v.foreach { x => - ab.addElements(toLiteralProtoBuilderInternal(x, options).build()) - } - if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(toDataType(v.getClass.getComponentType))) - } - builder.setArray(ab) + case v: Array[_] => builder.setArray(arrayBuilder(v)) case v: CalendarInterval => - builder.setCalendarInterval( - builder.getCalendarIntervalBuilder - .setMonths(v.months) - .setDays(v.days) - .setMicroseconds(v.microseconds)) + builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds)) case null => builder.setNull(ProtoDataTypes.NullType) case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).") } + } private def toLiteralProtoBuilderInternal( literal: Any, From 98668b1c7c85dcc961504a0b2e734aa826c98a5c Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Fri, 20 Mar 2026 21:49:49 +0000 Subject: [PATCH 8/8] minor last style hopefully --- .../spark/sql/connect/common/LiteralValueProtoConverter.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 8c8b9a3429d3..19e39befc8fb 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -242,6 +242,7 @@ object LiteralValueProtoConverter { .setPrecision(timeType.precision)) case _ => toLiteralProtoBuilderInternal(literal, options) } + } /**