diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala new file mode 100644 index 000000000000..5efe7fd33bdf --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/ClientTypeOps.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types.ops + +import org.apache.arrow.vector.types.pojo.ArrowType + +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Optional client-side type operations for the Types Framework. + * + * This trait extends TypeApiOps with operations needed by client-facing infrastructure: Arrow + * conversion (ArrowUtils), JDBC mapping (JdbcUtils), Python interop (EvaluatePython), Hive + * formatting (HiveResult), and Thrift type mapping (SparkExecuteStatementOperation). + * + * Lives in sql/api so it's visible from sql/core and sql/hive-thriftserver. + * + * USAGE - integration points use ClientTypeOps(dt) which returns Option[ClientTypeOps]: + * {{{ + * // Forward lookup (most files): + * ClientTypeOps(dt).map(_.toArrowType(timeZoneId)).getOrElse { ... 
} + * + * // Reverse lookup (ArrowUtils.fromArrowType): + * ClientTypeOps.fromArrowType(at).getOrElse { ... } + * }}} + * + * @see + * TimeTypeApiOps for a reference implementation + * @since 4.2.0 + */ +trait ClientTypeOps { self: TypeApiOps => + + // ==================== Utilities ==================== + + /** + * Null-safe conversion helper. Returns null for null input, applies the partial function for + * non-null input, and returns null for unmatched values. + */ + protected def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, (_: Any) => null) + } + } + + // ==================== Arrow Conversion ==================== + + /** + * Converts this DataType to its Arrow representation. + * + * Used by ArrowUtils.toArrowType. + * + * @param timeZoneId + * the session timezone (needed by some temporal types) + * @return + * the corresponding ArrowType + */ + def toArrowType(timeZoneId: String): ArrowType + + // ==================== JDBC Mapping ==================== + + /** + * Returns the java.sql.Types constant for this type. + * + * Used by JdbcUtils.getCommonJDBCType for JDBC write path. + * + * @return + * java.sql.Types constant (e.g., java.sql.Types.TIME) + */ + def getJdbcType: Int + + /** + * Returns the DDL type name string for this type. + * + * Used by JdbcUtils for CREATE TABLE DDL generation. + * + * @return + * DDL type string (e.g., "TIME") + */ + def jdbcTypeName: String + + // ==================== Python Interop ==================== + + /** + * Returns true if values of this type need conversion when passed to/from Python. + * + * Used by EvaluatePython.needConversionInPython. + */ + def needConversionInPython: Boolean + + /** + * Creates a converter function for Python/Py4J interop. + * + * Used by EvaluatePython.makeFromJava. The returned function handles null-safe conversion of + * Java/Py4J values to the internal Catalyst representation. 
+ * + * @return + * a function that converts a Java value to the internal representation + */ + def makeFromJava: Any => Any + + // ==================== Hive Formatting ==================== + + /** + * Formats an external-type value for Hive output. + * + * Used by HiveResult.toHiveString. The input is an external-type value (e.g., + * java.time.LocalTime for TimeType), NOT the internal representation. + * + * @param value + * the external-type value to format + * @return + * formatted string representation + */ + def formatExternal(value: Any): String + + // ==================== Thrift Mapping ==================== + + /** + * Returns the Thrift TTypeId name for this type. + * + * Used by SparkExecuteStatementOperation.toTTypeId. Returns a String that maps to a TTypeId + * enum value (e.g., "STRING_TYPE") since TTypeId is only available in the hive-thriftserver + * module. + * + * @return + * TTypeId enum name (e.g., "STRING_TYPE") + */ + def thriftTypeName: String +} + +/** + * Factory object for ClientTypeOps lookup. + * + * Delegates to TypeApiOps and narrows via collect to find implementations that mix in + * ClientTypeOps. + */ +object ClientTypeOps { + + /** + * Returns a ClientTypeOps instance for the given DataType, if available. + * + * @param dt + * the DataType to get operations for + * @return + * Some(ClientTypeOps) if supported, None otherwise + */ + // Delegates to TypeApiOps and narrows: a type must implement TypeApiOps AND mix in + // ClientTypeOps to be found here. No separate registration needed — the collect + // filter handles incremental trait adoption automatically. + def apply(dt: DataType): Option[ClientTypeOps] = + TypeApiOps(dt).collect { case co: ClientTypeOps => co } + + /** + * Reverse lookup: converts an Arrow type to a Spark DataType, if it belongs to a + * framework-managed type. + * + * Used by ArrowUtils.fromArrowType. 
Returns None if the Arrow type doesn't correspond to any + * framework-managed type, or the framework is disabled. + * + * @param at + * the ArrowType to convert + * @return + * Some(DataType) if recognized, None otherwise + */ + def fromArrowType(at: ArrowType): Option[DataType] = { + import org.apache.arrow.vector.types.TimeUnit + if (!SqlApiConf.get.typesFrameworkEnabled) return None + at match { + case t: ArrowType.Time if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 8 * 8 => + Some(TimeType(TimeType.MICROS_PRECISION)) + // Add new framework types here + case _ => None + } + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala index 581ffffff2f9..1ddceabd61b2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.types.ops +import java.time.LocalTime + +import org.apache.arrow.vector.types.TimeUnit +import org.apache.arrow.vector.types.pojo.ArrowType + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder -import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, TimeFormatter} +import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, SparkDateTimeUtils, TimeFormatter} import org.apache.spark.sql.types.{DataType, TimeType} /** @@ -29,6 +34,13 @@ import org.apache.spark.sql.types.{DataType, TimeType} * - String formatting: uses FractionTimeFormatter for consistent output * - Row encoding: uses LocalTimeEncoder for java.time.LocalTime * + * Additionally, it implements ClientTypeOps for: + * - Arrow conversion (ArrowUtils) + * - JDBC mapping (JdbcUtils) + * - Python interop (EvaluatePython) + * - Hive formatting (HiveResult) + * - Thrift type mapping 
(SparkExecuteStatementOperation) + * * RELATIONSHIP TO TimeTypeOps: TimeTypeOps (in catalyst package) extends this class to inherit * client-side operations while adding server-side operations (physical type, literals, etc.). * @@ -36,7 +48,7 @@ import org.apache.spark.sql.types.{DataType, TimeType} * The TimeType with precision information * @since 4.2.0 */ -class TimeTypeApiOps(val t: TimeType) extends TypeApiOps { +class TimeTypeApiOps(val t: TimeType) extends TypeApiOps with ClientTypeOps { override def dataType: DataType = t @@ -56,4 +68,30 @@ class TimeTypeApiOps(val t: TimeType) extends TypeApiOps { // ==================== Row Encoding ==================== override def getEncoder: AgnosticEncoder[_] = LocalTimeEncoder + + // ==================== Client Type Operations (ClientTypeOps) ==================== + + override def toArrowType(timeZoneId: String): ArrowType = { + new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8) + } + + override def getJdbcType: Int = java.sql.Types.TIME + + override def jdbcTypeName: String = "TIME" + + override def needConversionInPython: Boolean = true + + override def makeFromJava: Any => Any = (obj: Any) => + nullSafeConvert(obj) { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong + } + + override def formatExternal(value: Any): String = { + val nanos = SparkDateTimeUtils.localTimeToNanos(value.asInstanceOf[LocalTime]) + timeFormatter.format(nanos) + } + + override def thriftTypeName: String = "STRING_TYPE" } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 92b52d4ae634..257f209dee9d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -29,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import 
org.apache.spark.SparkException import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.util.ArrayImplicits._ private[sql] object ArrowUtils { @@ -39,6 +40,14 @@ private[sql] object ArrowUtils { /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = + ClientTypeOps(dt) + .map(_.toArrowType(timeZoneId)) + .getOrElse(toArrowTypeDefault(dt, timeZoneId, largeVarTypes)) + + private def toArrowTypeDefault( + dt: DataType, + timeZoneId: String, + largeVarTypes: Boolean): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) @@ -67,7 +76,10 @@ private[sql] object ArrowUtils { throw ExecutionErrors.unsupportedDataTypeError(dt) } - def fromArrowType(dt: ArrowType): DataType = dt match { + def fromArrowType(dt: ArrowType): DataType = + ClientTypeOps.fromArrowType(dt).getOrElse(fromArrowTypeDefault(dt)) + + private def fromArrowTypeDefault(dt: ArrowType): DataType = dt match { case ArrowType.Bool.INSTANCE => BooleanType case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 080794643fa0..875898952691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, 
isNativeEncoder} import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption} +import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils, STUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -289,6 +290,14 @@ object DeserializerBuildHelper { * @param isTopLevel true if we are creating a deserializer for the top level value. */ private def createDeserializer( + enc: AgnosticEncoder[_], + path: Expression, + walkedTypePath: WalkedTypePath, + isTopLevel: Boolean = false): Expression = + CatalystTypeOps(enc.dataType).map(_.createDeserializer(path)) + .getOrElse(createDeserializerDefault(enc, path, walkedTypePath, isTopLevel)) + + private def createDeserializerDefault( enc: AgnosticEncoder[_], path: Expression, walkedTypePath: WalkedTypePath, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index b8b2406a5813..d1e263d98276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor} import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, 
UnsafeArrayData} import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils, STUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -332,7 +333,12 @@ object SerializerBuildHelper { * representation. The mapping between the external and internal representations is described * by encoder `enc`. */ - private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match { + private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = + CatalystTypeOps(enc.dataType).map(_.createSerializer(input)) + .getOrElse(createSerializerDefault(enc, input)) + + private def createSerializerDefault( + enc: AgnosticEncoder[_], input: Expression): Expression = enc match { case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input) case _ if isNativeEncoder(enc) => input case BoxedBooleanEncoder => createSerializerForBoolean(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala new file mode 100644 index 000000000000..e366436fea39 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/CatalystTypeOps.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.types.ops + +import org.apache.arrow.vector.ValueVector + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.arrow.ArrowFieldWriter +import org.apache.spark.sql.types.DataType + +/** + * Optional catalyst-layer type operations for the Types Framework. + * + * This trait extends TypeOps with operations needed by catalyst-level client infrastructure: + * serializer/deserializer expression building (SerializerBuildHelper, DeserializerBuildHelper) + * and Arrow field writer creation (ArrowWriter). + * + * USAGE - integration points use CatalystTypeOps(dt) which returns Option[CatalystTypeOps]: + * {{{ + * // DataType-keyed (ArrowWriter): + * CatalystTypeOps(dt).map(_.createArrowFieldWriter(vector)).getOrElse { ... } + * + * // Encoder-keyed (SerializerBuildHelper): use enc.dataType to get the DataType + * CatalystTypeOps(enc.dataType).map(_.createSerializer(input)).getOrElse { ... } + * }}} + * + * @see + * TimeTypeOps for a reference implementation + * @since 4.2.0 + */ +trait CatalystTypeOps { self: TypeOps => + + /** + * Creates a serializer expression that converts an external value to its internal + * Catalyst representation. + * + * Used by SerializerBuildHelper for Dataset[T] serialization. 
+ * + * @param input + * the input expression representing the external value + * @return + * an Expression that performs the conversion + */ + def createSerializer(input: Expression): Expression + + /** + * Creates a deserializer expression that converts an internal Catalyst value + * to its external representation. + * + * Used by DeserializerBuildHelper for Dataset[T] deserialization. + * + * @param path + * the expression representing the internal value + * @return + * an Expression that performs the conversion + */ + def createDeserializer(path: Expression): Expression + + /** + * Creates an ArrowFieldWriter for writing values of this type to an Arrow vector. + * + * Used by ArrowWriter for Arrow-based data exchange. + * + * @param vector + * the Arrow ValueVector to write to (must be the correct vector type for this data type) + * @return + * an ArrowFieldWriter configured for this type + */ + def createArrowFieldWriter(vector: ValueVector): ArrowFieldWriter +} + +/** + * Factory object for CatalystTypeOps lookup. + * + * Delegates to TypeOps and narrows via collect to find implementations that mix in + * CatalystTypeOps. Returns None if the type is not supported, the framework is disabled, + * or the type's TypeOps does not implement CatalystTypeOps. + */ +object CatalystTypeOps { + + /** + * Returns a CatalystTypeOps instance for the given DataType, if available. + * + * @param dt + * the DataType to get operations for + * @return + * Some(CatalystTypeOps) if supported, None otherwise + */ + // Delegates to TypeOps and narrows: a type must implement TypeOps AND mix in + // CatalystTypeOps to be found here. 
+ def apply(dt: DataType): Option[CatalystTypeOps] = + TypeOps(dt).collect { case co: CatalystTypeOps => co } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala index 74198c956edc..5464ae5b1d2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.catalyst.types.ops import java.time.LocalTime +import org.apache.arrow.vector.{TimeNanoVector, ValueVector} + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, MutableLong, MutableValue} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, MutableLong, MutableValue} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.TimeType +import org.apache.spark.sql.execution.arrow.{ArrowFieldWriter, TimeWriter} +import org.apache.spark.sql.types.{ObjectType, TimeType} import org.apache.spark.sql.types.ops.TimeTypeApiOps /** @@ -38,6 +42,10 @@ import org.apache.spark.sql.types.ops.TimeTypeApiOps * - String formatting (FractionTimeFormatter) * - Row encoding (LocalTimeEncoder) * + * Additionally, it implements CatalystTypeOps for: + * - Serializer/deserializer expression building (SerializerBuildHelper, DeserializerBuildHelper) + * - Arrow field writer creation (ArrowWriter) + * * INTERNAL REPRESENTATION: * - Values stored as Long nanoseconds since midnight * - Range: 0 to 86,399,999,999,999 @@ -48,7 +56,8 @@ import org.apache.spark.sql.types.ops.TimeTypeApiOps * The TimeType with precision information * @since 4.2.0 */ -case class TimeTypeOps(override 
val t: TimeType) extends TimeTypeApiOps(t) with TypeOps { +case class TimeTypeOps(override val t: TimeType) + extends TimeTypeApiOps(t) with TypeOps with CatalystTypeOps { // ==================== Physical Type Representation ==================== @@ -81,4 +90,28 @@ case class TimeTypeOps(override val t: TimeType) extends TimeTypeApiOps(t) with override def toScalaImpl(row: InternalRow, column: Int): Any = { DateTimeUtils.nanosToLocalTime(row.getLong(column)) } + + // ==================== Catalyst Type Operations (CatalystTypeOps) ==================== + + override def createSerializer(input: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + t, + "localTimeToNanos", + input :: Nil, + returnNullable = false) + } + + override def createDeserializer(path: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.time.LocalTime]), + "nanosToLocalTime", + path :: Nil, + returnNullable = false) + } + + override def createArrowFieldWriter(vector: ValueVector): ArrowFieldWriter = { + new TimeWriter(vector.asInstanceOf[TimeNanoVector]) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index b5269da035f3..0f6edbf60f76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps import org.apache.spark.sql.catalyst.util.STUtils import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ @@ -52,7 +53,14 @@ object ArrowWriter { private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { 
val field = vector.getField() - (ArrowUtils.fromArrowField(field), vector) match { + val dt = ArrowUtils.fromArrowField(field) + CatalystTypeOps(dt).map(_.createArrowFieldWriter(vector)) + .getOrElse(createFieldWriterDefault(dt, vector)) + } + + private[sql] def createFieldWriterDefault( + dt: DataType, vector: ValueVector): ArrowFieldWriter = { + (dt, vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) @@ -146,7 +154,7 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { } } -private[arrow] abstract class ArrowFieldWriter { +private[sql] abstract class ArrowFieldWriter { def valueVector: ValueVector @@ -371,7 +379,7 @@ private[arrow] class TimestampNTZWriter( } } -private[arrow] class TimeWriter( +private[sql] class TimeWriter( val valueVector: TimeNanoVector) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 82029025a7f0..a05d6f80a9cf 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.connect.common.types.ops.ConnectArrowTypeOps import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors} import org.apache.spark.sql.types.Decimal import 
org.apache.spark.sql.util.{CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator} @@ -88,6 +89,14 @@ object ArrowDeserializers { } private[arrow] def deserializerFor( + encoder: AgnosticEncoder[_], + data: AnyRef, + timeZoneId: String): Deserializer[Any] = + ConnectArrowTypeOps(encoder) + .map(_.createArrowDeserializer(encoder, data, timeZoneId).asInstanceOf[Deserializer[Any]]) + .getOrElse(deserializerForDefault(encoder, data, timeZoneId)) + + private def deserializerForDefault( encoder: AgnosticEncoder[_], data: AnyRef, timeZoneId: String): Deserializer[Any] = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index d547c81afe5a..17a3da981265 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} +import org.apache.spark.sql.connect.common.types.ops.ConnectArrowTypeOps import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator} @@ -239,7 +240,12 @@ object ArrowSerializer { } // TODO throw better errors on class cast exceptions. 
- private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { + private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = + ConnectArrowTypeOps(encoder) + .map(_.createArrowSerializer(v).asInstanceOf[Serializer]) + .getOrElse(serializerForDefault(encoder, v)) + + private def serializerForDefault[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { (encoder, v) match { case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => new FieldSerializer[Boolean, BitVector](v) { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala index ea57e0e1c77f..d483d1bb5de7 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, Ti import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ +import org.apache.spark.sql.connect.common.types.ops.ConnectArrowTypeOps import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, UpCastRule, YearMonthIntervalType} import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.SparkStringUtils @@ -66,6 +67,14 @@ private[arrow] abstract class ArrowVectorReader { object ArrowVectorReader { def apply( + targetDataType: DataType, + vector: FieldVector, + timeZoneId: String): ArrowVectorReader = + ConnectArrowTypeOps(targetDataType) + .map(_.createArrowVectorReader(vector).asInstanceOf[ArrowVectorReader]) + .getOrElse(applyDefault(targetDataType, vector, timeZoneId)) + + private 
def applyDefault( targetDataType: DataType, vector: FieldVector, timeZoneId: String): ArrowVectorReader = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala new file mode 100644 index 000000000000..958423978539 --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/types/ops/TimeTypeConnectOps.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.arrow + +import java.time.LocalTime + +import org.apache.arrow.vector.{FieldVector, TimeNanoVector} + +import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder +import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils +import org.apache.spark.sql.connect.common.types.ops.{ConnectArrowTypeOps, ProtoTypeOps} +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Combined Connect operations for TimeType. 
+ * + * Implements both ProtoTypeOps (proto DataType/Literal conversions) and ConnectArrowTypeOps + * (Arrow serialization/deserialization) in a single class. + * + * Lives in the arrow package to access arrow-private types (ArrowVectorReader, TimeVectorReader, + * ArrowSerializer.Serializer, ArrowDeserializers.LeafFieldDeserializer). Placed under types/ops + * subdirectory to separate ops implementations from core arrow infrastructure. + * + * @param t + * The TimeType with precision information + * @since 4.2.0 + */ +private[connect] class TimeTypeConnectOps(val t: TimeType) + extends ProtoTypeOps + with ConnectArrowTypeOps { + + override def dataType: DataType = t + + override def encoder: AgnosticEncoder[_] = LocalTimeEncoder + + // ==================== ProtoTypeOps ==================== + + override def toConnectProtoType: proto.DataType = { + proto.DataType + .newBuilder() + .setTime(proto.DataType.Time.newBuilder().setPrecision(t.precision).build()) + .build() + } + + override def toLiteralProto( + value: Any, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder = { + val v = value.asInstanceOf[LocalTime] + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(TimeType.DEFAULT_PRECISION)) + } + + override def toLiteralProtoWithType( + value: Any, + dt: DataType, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder = { + val v = value.asInstanceOf[LocalTime] + val timeType = dt.asInstanceOf[TimeType] + builder.setTime( + builder.getTimeBuilder + .setNano(SparkDateTimeUtils.localTimeToNanos(v)) + .setPrecision(timeType.precision)) + } + + override def getScalaConverter: proto.Expression.Literal => Any = { v => + SparkDateTimeUtils.nanosToLocalTime(v.getTime.getNano) + } + + override def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): proto.DataType = { + val timeBuilder = proto.DataType.Time.newBuilder() + if (literal.getTime.hasPrecision) 
{ + timeBuilder.setPrecision(literal.getTime.getPrecision) + } + proto.DataType.newBuilder().setTime(timeBuilder.build()).build() + } + + // ==================== ConnectArrowTypeOps ==================== + + override def createArrowSerializer(vector: Any): Any = { + val v = vector.asInstanceOf[TimeNanoVector] + new ArrowSerializer.Serializer { + override def write(index: Int, value: Any): Unit = { + if (value != null) { + v.setSafe(index, SparkDateTimeUtils.localTimeToNanos(value.asInstanceOf[LocalTime])) + } else { + v.setNull(index) + } + } + } + } + + override def createArrowDeserializer( + enc: AgnosticEncoder[_], + vector: Any, + timeZoneId: String): Any = { + val v = vector.asInstanceOf[FieldVector] + new ArrowDeserializers.LeafFieldDeserializer[LocalTime](enc, v, timeZoneId) { + override def value(i: Int): LocalTime = reader.getLocalTime(i) + } + } + + override def createArrowVectorReader(vector: Any): Any = { + new TimeVectorReader(vector.asInstanceOf[TimeNanoVector]) + } +} diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index ceccf780f586..f5948d9b1463 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.connect.common.types.ops.ProtoTypeOps import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkClassUtils @@ -29,7 +30,10 @@ import org.apache.spark.util.SparkClassUtils * Helper class for conversions between [[DataType]] and [[proto.DataType]]. 
*/ object DataTypeProtoConverter { - def toCatalystType(t: proto.DataType): DataType = { + def toCatalystType(t: proto.DataType): DataType = + ProtoTypeOps.toCatalystType(t).getOrElse(toCatalystTypeDefault(t)) + + private def toCatalystTypeDefault(t: proto.DataType): DataType = { t.getKindCase match { case proto.DataType.KindCase.NULL => NullType @@ -174,7 +178,12 @@ object DataTypeProtoConverter { toConnectProtoTypeInternal(t, bytesToBinary) } - private def toConnectProtoTypeInternal(t: DataType, bytesToBinary: Boolean): proto.DataType = { + private def toConnectProtoTypeInternal(t: DataType, bytesToBinary: Boolean): proto.DataType = + ProtoTypeOps(t) + .map(_.toConnectProtoType) + .getOrElse(toConnectProtoTypeDefault(t, bytesToBinary)) + + private def toConnectProtoTypeDefault(t: DataType, bytesToBinary: Boolean): proto.DataType = { t match { case NullType => ProtoDataTypes.NullType diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 026f5441c6ca..19e39befc8fb 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} import org.apache.spark.sql.connect.common.DataTypeProtoConverter._ +import org.apache.spark.sql.connect.common.types.ops.ProtoTypeOps import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -57,7 +58,15 @@ object LiteralValueProtoConverter { literal: Any, options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { val builder = 
proto.Expression.Literal.newBuilder() + ProtoTypeOps.toLiteralProtoForValue(literal, builder).getOrElse { + toLiteralProtoBuilderDefault(literal, builder, options) + } + } + private def toLiteralProtoBuilderDefault( + literal: Any, + builder: proto.Expression.Literal.Builder, + options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { def decimalBuilder(precision: Int, scale: Int, value: String) = { builder.getDecimalBuilder.setPrecision(precision).setScale(scale).setValue(value) } @@ -127,6 +136,16 @@ object LiteralValueProtoConverter { dataType: DataType, options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() + ProtoTypeOps(dataType) + .map(_.toLiteralProtoWithType(literal, dataType, builder)) + .getOrElse(toLiteralProtoWithTypeDefault(literal, dataType, builder, options)) + } + + private def toLiteralProtoWithTypeDefault( + literal: Any, + dataType: DataType, + builder: proto.Expression.Literal.Builder, + options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { def arrayBuilder(scalaValue: Any, elementType: DataType) = { val ab = builder.getArrayBuilder @@ -384,7 +403,8 @@ object LiteralValueProtoConverter { getScalaConverter(getProtoDataType(literal))(literal) } - private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { + private def getScalaConverterDefault( + dataType: proto.DataType): proto.Expression.Literal => Any = { val converter: proto.Expression.Literal => Any = dataType.getKindCase match { case proto.DataType.KindCase.NULL => v => @@ -428,6 +448,14 @@ object LiteralValueProtoConverter { "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", Map("typeInfo" -> dataType.getKindCase.toString)) } + converter + } + + private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = { + val converter: proto.Expression.Literal => Any = + ProtoTypeOps.getScalaConverterForKind(dataType.getKindCase).getOrElse { + 
getScalaConverterDefault(dataType) + } v => if (v.hasNull) null else converter(v) } @@ -500,102 +528,9 @@ object LiteralValueProtoConverter { if (literal.getLiteralTypeCase == proto.Expression.Literal.LiteralTypeCase.NULL) { literal.getNull } else { - val builder = proto.DataType.newBuilder() - literal.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.BINARY => - builder.setBinary(proto.DataType.Binary.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => - builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.BYTE => - builder.setByte(proto.DataType.Byte.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.SHORT => - builder.setShort(proto.DataType.Short.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.INTEGER => - builder.setInteger(proto.DataType.Integer.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.LONG => - builder.setLong(proto.DataType.Long.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.FLOAT => - builder.setFloat(proto.DataType.Float.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DOUBLE => - builder.setDouble(proto.DataType.Double.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DECIMAL => - val decimal = Decimal.apply(literal.getDecimal.getValue) - var precision = decimal.precision - if (literal.getDecimal.hasPrecision) { - precision = math.max(precision, literal.getDecimal.getPrecision) - } - var scale = decimal.scale - if (literal.getDecimal.hasScale) { - scale = math.max(scale, literal.getDecimal.getScale) - } - builder.setDecimal( - proto.DataType.Decimal - .newBuilder() - .setPrecision(math.max(precision, scale)) - .setScale(scale) - .build()) - case proto.Expression.Literal.LiteralTypeCase.STRING => - builder.setString(proto.DataType.String.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DATE 
=> - builder.setDate(proto.DataType.Date.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => - builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => - builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => - builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => - builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => - builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) - case proto.Expression.Literal.LiteralTypeCase.TIME => - val timeBuilder = proto.DataType.Time.newBuilder() - if (literal.getTime.hasPrecision) { - timeBuilder.setPrecision(literal.getTime.getPrecision) - } - builder.setTime(timeBuilder.build()) - case proto.Expression.Literal.LiteralTypeCase.ARRAY => - if (literal.getArray.hasElementType) { - builder.setArray( - proto.DataType.Array - .newBuilder() - .setElementType(literal.getArray.getElementType) - .setContainsNull(true) - .build()) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.ARRAY_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case proto.Expression.Literal.LiteralTypeCase.MAP => - if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { - builder.setMap( - proto.DataType.Map - .newBuilder() - .setKeyType(literal.getMap.getKeyType) - .setValueType(literal.getMap.getValueType) - .setValueContainsNull(true) - .build()) - } else { - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case proto.Expression.Literal.LiteralTypeCase.STRUCT => - if (literal.getStruct.hasStructType) { - builder.setStruct(literal.getStruct.getStructType.getStruct) - } else { - throw 
InvalidPlanInput( - "CONNECT_INVALID_PLAN.STRUCT_LITERAL_MISSING_DATA_TYPE", - Map.empty) - } - case _ => - val literalCase = literal.getLiteralTypeCase - throw InvalidPlanInput( - "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", - Map("typeInfo" -> s"${literalCase.name}(${literalCase.getNumber})")) - } - builder.build() + ProtoTypeOps + .getProtoDataTypeFromLiteral(literal) + .getOrElse(getProtoDataTypeDefault(literal)) } } @@ -610,6 +545,103 @@ object LiteralValueProtoConverter { dataType } + private def getProtoDataTypeDefault(literal: proto.Expression.Literal): proto.DataType = { + val builder = proto.DataType.newBuilder() + literal.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.BINARY => + builder.setBinary(proto.DataType.Binary.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => + builder.setBoolean(proto.DataType.Boolean.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.BYTE => + builder.setByte(proto.DataType.Byte.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.SHORT => + builder.setShort(proto.DataType.Short.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.INTEGER => + builder.setInteger(proto.DataType.Integer.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.LONG => + builder.setLong(proto.DataType.Long.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.FLOAT => + builder.setFloat(proto.DataType.Float.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DOUBLE => + builder.setDouble(proto.DataType.Double.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DECIMAL => + val decimal = Decimal.apply(literal.getDecimal.getValue) + var precision = decimal.precision + if (literal.getDecimal.hasPrecision) { + precision = math.max(precision, literal.getDecimal.getPrecision) + } + var scale = decimal.scale + if (literal.getDecimal.hasScale) { + scale = math.max(scale, 
literal.getDecimal.getScale) + } + builder.setDecimal( + proto.DataType.Decimal + .newBuilder() + .setPrecision(math.max(precision, scale)) + .setScale(scale) + .build()) + case proto.Expression.Literal.LiteralTypeCase.STRING => + builder.setString(proto.DataType.String.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DATE => + builder.setDate(proto.DataType.Date.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP => + builder.setTimestamp(proto.DataType.Timestamp.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ => + builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => + builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => + builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => + builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder().build()) + case proto.Expression.Literal.LiteralTypeCase.TIME => + val timeBuilder = proto.DataType.Time.newBuilder() + if (literal.getTime.hasPrecision) { + timeBuilder.setPrecision(literal.getTime.getPrecision) + } + builder.setTime(timeBuilder.build()) + case proto.Expression.Literal.LiteralTypeCase.ARRAY => + if (literal.getArray.hasElementType) { + builder.setArray( + proto.DataType.Array + .newBuilder() + .setElementType(literal.getArray.getElementType) + .setContainsNull(true) + .build()) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.ARRAY_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case proto.Expression.Literal.LiteralTypeCase.MAP => + if (literal.getMap.hasKeyType && literal.getMap.hasValueType) { + builder.setMap( + proto.DataType.Map + .newBuilder() + .setKeyType(literal.getMap.getKeyType) + 
.setValueType(literal.getMap.getValueType) + .setValueContainsNull(true) + .build()) + } else { + throw InvalidPlanInput("CONNECT_INVALID_PLAN.MAP_LITERAL_MISSING_DATA_TYPE", Map.empty) + } + case proto.Expression.Literal.LiteralTypeCase.STRUCT => + if (literal.getStruct.hasStructType) { + builder.setStruct(literal.getStruct.getStructType.getStruct) + } else { + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.STRUCT_LITERAL_MISSING_DATA_TYPE", + Map.empty) + } + case _ => + val literalCase = literal.getLiteralTypeCase + throw InvalidPlanInput( + "CONNECT_INVALID_PLAN.UNSUPPORTED_LITERAL_TYPE", + Map("typeInfo" -> s"${literalCase.name}(${literalCase.getNumber})")) + } + builder.build() + } + private def toScalaArrayInternal( literal: proto.Expression.Literal, arrayType: proto.DataType.Array): Array[_] = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala new file mode 100644 index 000000000000..da1357d579c1 --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ConnectArrowTypeOps.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.common.types.ops + +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder +import org.apache.spark.sql.connect.client.arrow.TimeTypeConnectOps +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Optional type operations for Spark Connect Arrow serialization/deserialization. + * + * Handles Arrow-based data exchange on the Connect client side for framework-managed types. Used + * by ArrowSerializer, ArrowDeserializer, and ArrowVectorReader. + * + * NOTE: No feature flag check - the Connect client must handle whatever types the server sends. + * The feature flag controls server-side engine behavior; the client always needs to handle types + * that exist in the encoder. + * + * Methods return Any to avoid referencing arrow-private types from the types.ops package. Call + * sites in the arrow package cast to the expected types. + * + * @since 4.2.0 + */ +trait ConnectArrowTypeOps extends Serializable { + + def encoder: AgnosticEncoder[_] + + /** Creates an Arrow serializer for writing values to a vector. Returns a Serializer. */ + def createArrowSerializer(vector: Any): Any + + /** Creates an Arrow deserializer for reading values from a vector. Returns a Deserializer. */ + def createArrowDeserializer(enc: AgnosticEncoder[_], vector: Any, timeZoneId: String): Any + + /** Creates an ArrowVectorReader for this type's vector. Returns an ArrowVectorReader. */ + def createArrowVectorReader(vector: Any): Any +} + +/** + * Factory object for ConnectArrowTypeOps lookup. + * + * No feature flag check - the Connect client always handles registered types. + */ +object ConnectArrowTypeOps { + + /** Encoder-keyed dispatch (for ArrowSerializer, ArrowDeserializer). 
*/ + def apply(enc: AgnosticEncoder[_]): Option[ConnectArrowTypeOps] = + enc match { + case LocalTimeEncoder => Some(new TimeTypeConnectOps(TimeType())) + case _ => None + } + + /** DataType-keyed dispatch (for ArrowVectorReader which doesn't have an encoder). */ + def apply(dt: DataType): Option[ConnectArrowTypeOps] = + dt match { + case tt: TimeType => Some(new TimeTypeConnectOps(tt)) + case _ => None + } +} diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala new file mode 100644 index 000000000000..c4e24597dbf1 --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/types/ops/ProtoTypeOps.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.common.types.ops + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.client.arrow.TimeTypeConnectOps +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.types.{DataType, TimeType} + +/** + * Optional type operations for Spark Connect protobuf conversions. 
+ * + * Handles bidirectional DataType <-> proto and Literal <-> proto conversions for + * framework-managed types in DataTypeProtoConverter and LiteralValueProtoConverter. + * + * @since 4.2.0 + */ +trait ProtoTypeOps extends Serializable { + + def dataType: DataType + + /** Converts this DataType to its Connect proto representation. */ + def toConnectProtoType: proto.DataType + + /** Converts a value to a proto literal builder (generic, no DataType context). */ + def toLiteralProto( + value: Any, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder + + /** Converts a value to a proto literal builder (with DataType context). */ + def toLiteralProtoWithType( + value: Any, + dt: DataType, + builder: proto.Expression.Literal.Builder): proto.Expression.Literal.Builder + + /** Returns a converter from proto literal to Scala value. */ + def getScalaConverter: proto.Expression.Literal => Any + + /** Returns a proto DataType inferred from a proto literal (for type inference). */ + def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): proto.DataType +} + +/** + * Factory object for ProtoTypeOps lookup. + */ +object ProtoTypeOps { + + def apply(dt: DataType): Option[ProtoTypeOps] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + dt match { + case tt: TimeType => Some(new TimeTypeConnectOps(tt)) + case _ => None + } + } + + /** Reverse lookup by value class for the generic literal builder. */ + def toLiteralProtoForValue( + value: Any, + builder: proto.Expression.Literal.Builder): Option[proto.Expression.Literal.Builder] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + value match { + case v: java.time.LocalTime => + Some(new TimeTypeConnectOps(TimeType()).toLiteralProto(v, builder)) + case _ => None + } + } + + /** + * Reverse lookup: converts a proto DataType to a Spark DataType, if it belongs to a + * framework-managed type. 
+ */ + def toCatalystType(t: proto.DataType): Option[DataType] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + t.getKindCase match { + case proto.DataType.KindCase.TIME => + val time = t.getTime + if (time.hasPrecision) Some(TimeType(time.getPrecision)) + else Some(TimeType()) + case _ => None + } + } + + /** + * Reverse lookup: returns a Scala converter for a proto literal KindCase. + */ + def getScalaConverterForKind( + kindCase: proto.DataType.KindCase): Option[proto.Expression.Literal => Any] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + kindCase match { + case proto.DataType.KindCase.TIME => + Some(new TimeTypeConnectOps(TimeType()).getScalaConverter) + case _ => None + } + } + + /** + * Reverse lookup: returns the proto DataType inferred from a proto literal's type case, if the + * literal type belongs to a framework-managed type. + */ + def getProtoDataTypeFromLiteral(literal: proto.Expression.Literal): Option[proto.DataType] = { + if (!SqlApiConf.get.typesFrameworkEnabled) return None + literal.getLiteralTypeCase match { + case proto.Expression.Literal.LiteralTypeCase.TIME => + Some(new TimeTypeConnectOps(TimeType()).getProtoDataTypeFromLiteral(literal)) + case _ => None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 8a666bbb9dad..f748592e9071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTab import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import 
org.apache.spark.util.ArrayImplicits._ @@ -112,6 +113,17 @@ object HiveResult extends SQLConfHelper { formatters: TimeFormatters, binaryFormatter: BinaryFormatter): String = a match { case (null, _) => if (nested) "null" else "NULL" + case (value, dt) => + ClientTypeOps(dt).map(_.formatExternal(value)).getOrElse { + toHiveStringDefault(a, nested, formatters, binaryFormatter) + } + } + + private def toHiveStringDefault( + a: (Any, DataType), + nested: Boolean, + formatters: TimeFormatters, + binaryFormatter: BinaryFormatter): String = a match { case (b, BooleanType) => b.toString case (d: Date, DateType) => formatters.date.format(d) case (ld: LocalDate, DateType) => formatters.date.format(ld) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 28882a45b7f8..1c28bc2e105d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect} import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -145,7 +146,11 @@ object JdbcUtils extends Logging with SQLConfHelper { * @param dt The datatype (e.g. 
[[org.apache.spark.sql.types.StringType]]) * @return The default JdbcType for this DataType */ - def getCommonJDBCType(dt: DataType): Option[JdbcType] = { + def getCommonJDBCType(dt: DataType): Option[JdbcType] = + ClientTypeOps(dt).map(ops => JdbcType(ops.jdbcTypeName, ops.getJdbcType)) + .orElse(getCommonJDBCTypeDefault(dt)) + + private def getCommonJDBCTypeDefault(dt: DataType): Option[JdbcType] = { dt match { case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 33622ca7349a..d329ca8494d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData, STUtils} import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String, VariantVal} object EvaluatePython { @@ -41,7 +42,11 @@ object EvaluatePython { */ private[python] class BytesWrapper(val data: Array[Byte]) - def needConversionInPython(dt: DataType): Boolean = dt match { + def needConversionInPython(dt: DataType): Boolean = + ClientTypeOps(dt).map(_.needConversionInPython) + .getOrElse(needConversionInPythonDefault(dt)) + + private def needConversionInPythonDefault(dt: DataType): Boolean = dt match { case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType | _: TimeType | _: GeometryType | _: GeographyType => true case _: StructType => true @@ -111,7 +116,10 @@ object 
EvaluatePython { * Make a converter that converts `obj` to the type specified by the data type, or returns * null if the type of obj is unexpected. Because Python doesn't enforce the type. */ - def makeFromJava(dataType: DataType): Any => Any = dataType match { + def makeFromJava(dataType: DataType): Any => Any = + ClientTypeOps(dataType).map(_.makeFromJava).getOrElse(makeFromJavaDefault(dataType)) + + private def makeFromJavaDefault(dataType: DataType): Any => Any = dataType match { case BooleanType => (obj: Any) => nullSafeConvert(obj) { case b: Boolean => b } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index cd2bdefcc306..12d16467101d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution} import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.ops.ClientTypeOps import org.apache.spark.util.{Utils => SparkUtils} private[hive] class SparkExecuteStatementOperation( @@ -326,7 +327,11 @@ private[hive] class SparkExecuteStatementOperation( object SparkExecuteStatementOperation { - def toTTypeId(typ: DataType): TTypeId = typ match { + def toTTypeId(typ: DataType): TTypeId = + ClientTypeOps(typ).map(ops => TTypeId.valueOf(ops.thriftTypeName)) + .getOrElse(toTTypeIdDefault(typ)) + + private def toTTypeIdDefault(typ: DataType): TTypeId = typ match { case NullType => TTypeId.NULL_TYPE case BooleanType => TTypeId.BOOLEAN_TYPE case ByteType => 
TTypeId.TINYINT_TYPE