Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.types.ops

import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{DataType, TimeType}

/**
* Optional client-side type operations for the Types Framework.
*
* This trait extends TypeApiOps with operations needed by client-facing infrastructure: Arrow
* conversion (ArrowUtils), JDBC mapping (JdbcUtils), Python interop (EvaluatePython), Hive
* formatting (HiveResult), and Thrift type mapping (SparkExecuteStatementOperation).
*
* Lives in sql/api so it's visible from sql/core and sql/hive-thriftserver.
*
* USAGE - integration points use ClientTypeOps(dt) which returns Option[ClientTypeOps]:
* {{{
* // Forward lookup (most files):
* ClientTypeOps(dt).map(_.toArrowType(timeZoneId)).getOrElse { ... }
*
* // Reverse lookup (ArrowUtils.fromArrowType):
* ClientTypeOps.fromArrowType(at).getOrElse { ... }
* }}}
*
* @see
* TimeTypeApiOps for a reference implementation
* @since 4.2.0
*/
trait ClientTypeOps { self: TypeApiOps =>

  // ==================== Utilities ====================

  /**
   * Applies the partial function `f` to a non-null input, yielding `null` both when the input
   * itself is `null` and when `f` does not match the input value.
   */
  protected def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any =
    input match {
      case null => null
      case other => f.applyOrElse(other, (_: Any) => null)
    }

  // ==================== Arrow Conversion ====================

  /**
   * The Arrow representation of this DataType.
   *
   * Consumed by ArrowUtils.toArrowType.
   *
   * @param timeZoneId
   *   session timezone, relevant only for temporal types that need it
   * @return
   *   the matching ArrowType
   */
  def toArrowType(timeZoneId: String): ArrowType

  // ==================== JDBC Mapping ====================

  /**
   * The java.sql.Types constant corresponding to this type.
   *
   * Consumed by JdbcUtils.getCommonJDBCType on the JDBC write path.
   *
   * @return
   *   a java.sql.Types constant (e.g., java.sql.Types.TIME)
   */
  def getJdbcType: Int

  /**
   * The DDL type name emitted for this type.
   *
   * Consumed by JdbcUtils when generating CREATE TABLE statements.
   *
   * @return
   *   DDL type string (e.g., "TIME")
   */
  def jdbcTypeName: String

  // ==================== Python Interop ====================

  /**
   * Whether values of this type must be converted when crossing the Python boundary.
   *
   * Consumed by EvaluatePython.needConversionInPython.
   */
  def needConversionInPython: Boolean

  /**
   * Builds a converter for Python/Py4J interop.
   *
   * Consumed by EvaluatePython.makeFromJava; the resulting function performs a null-safe
   * translation of Java/Py4J values into the internal Catalyst representation.
   *
   * @return
   *   a function mapping a Java value to its internal representation
   */
  def makeFromJava: Any => Any

  // ==================== Hive Formatting ====================

  /**
   * Renders an external-type value for Hive output.
   *
   * Consumed by HiveResult.toHiveString. Note the argument is the external representation
   * (e.g., java.time.LocalTime for TimeType), never the internal one.
   *
   * @param value
   *   the external-type value to render
   * @return
   *   the formatted string
   */
  def formatExternal(value: Any): String

  // ==================== Thrift Mapping ====================

  /**
   * The Thrift TTypeId name for this type.
   *
   * Consumed by SparkExecuteStatementOperation.toTTypeId. A String is returned (mapping to a
   * TTypeId enum value such as "STRING_TYPE") because the TTypeId enum itself lives in the
   * hive-thriftserver module and is not visible from here.
   *
   * @return
   *   a TTypeId enum name (e.g., "STRING_TYPE")
   */
  def thriftTypeName: String
}

/**
* Factory object for ClientTypeOps lookup.
*
* Delegates to TypeApiOps and narrows via collect to find implementations that mix in
* ClientTypeOps.
*/
object ClientTypeOps {

  /**
   * Looks up client-side operations for the given DataType.
   *
   * Delegates to TypeApiOps and narrows the result: a type is found here only when its
   * TypeApiOps implementation also mixes in ClientTypeOps. The pattern match performs the
   * narrowing, so traits can be adopted incrementally with no separate registration step.
   *
   * @param dt
   *   the DataType to look up
   * @return
   *   Some(ClientTypeOps) when supported, None otherwise
   */
  def apply(dt: DataType): Option[ClientTypeOps] =
    TypeApiOps(dt) match {
      case Some(ops: ClientTypeOps) => Some(ops)
      case _ => None
    }

  /**
   * Reverse lookup: maps an Arrow type back to a Spark DataType when it belongs to a
   * framework-managed type.
   *
   * Used by ArrowUtils.fromArrowType. Yields None when the Arrow type is not recognized or
   * the Types Framework is disabled.
   *
   * @param at
   *   the ArrowType to map
   * @return
   *   Some(DataType) when recognized, None otherwise
   */
  def fromArrowType(at: ArrowType): Option[DataType] = {
    import org.apache.arrow.vector.types.TimeUnit
    if (!SqlApiConf.get.typesFrameworkEnabled) {
      None
    } else {
      at match {
        // 64-bit nanosecond Arrow time maps to Spark's TIME with microsecond precision.
        case time: ArrowType.Time
            if time.getUnit == TimeUnit.NANOSECOND && time.getBitWidth == 64 =>
          Some(TimeType(TimeType.MICROS_PRECISION))
        // Extend this match when new types join the framework.
        case _ => None
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@

package org.apache.spark.sql.types.ops

import java.time.LocalTime

import org.apache.arrow.vector.types.TimeUnit
import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.LocalTimeEncoder
import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, TimeFormatter}
import org.apache.spark.sql.catalyst.util.{FractionTimeFormatter, SparkDateTimeUtils, TimeFormatter}
import org.apache.spark.sql.types.{DataType, TimeType}

/**
Expand All @@ -29,14 +34,21 @@ import org.apache.spark.sql.types.{DataType, TimeType}
* - String formatting: uses FractionTimeFormatter for consistent output
* - Row encoding: uses LocalTimeEncoder for java.time.LocalTime
*
* Additionally, it implements ClientTypeOps for:
* - Arrow conversion (ArrowUtils)
* - JDBC mapping (JdbcUtils)
* - Python interop (EvaluatePython)
* - Hive formatting (HiveResult)
* - Thrift type mapping (SparkExecuteStatementOperation)
*
* RELATIONSHIP TO TimeTypeOps: TimeTypeOps (in catalyst package) extends this class to inherit
* client-side operations while adding server-side operations (physical type, literals, etc.).
*
* @param t
* The TimeType with precision information
* @since 4.2.0
*/
class TimeTypeApiOps(val t: TimeType) extends TypeApiOps {
class TimeTypeApiOps(val t: TimeType) extends TypeApiOps with ClientTypeOps {

/** The TimeType instance these operations are bound to. */
override def dataType: DataType = t

Expand All @@ -56,4 +68,30 @@ class TimeTypeApiOps(val t: TimeType) extends TypeApiOps {
// ==================== Row Encoding ====================

/** Row encoding uses LocalTimeEncoder, i.e. java.time.LocalTime as the external type. */
override def getEncoder: AgnosticEncoder[_] = LocalTimeEncoder

// ==================== Client Type Operations (ClientTypeOps) ====================

/**
 * Arrow representation: a 64-bit (8 * 8) time value with nanosecond unit.
 * `timeZoneId` is unused here — TIME values carry no timezone.
 */
override def toArrowType(timeZoneId: String): ArrowType = {
new ArrowType.Time(TimeUnit.NANOSECOND, 8 * 8)
}

/** JDBC write path maps this type to java.sql.Types.TIME. */
override def getJdbcType: Int = java.sql.Types.TIME

/** DDL type name used when generating CREATE TABLE statements. */
override def jdbcTypeName: String = "TIME"

/** TIME values require conversion when crossing the Python boundary. */
override def needConversionInPython: Boolean = true

/**
 * Null-safe converter from Py4J values to the internal Long representation.
 * Unmatched or null inputs yield null (see nullSafeConvert).
 */
override def makeFromJava: Any => Any = { obj =>
  nullSafeConvert(obj) {
    case micros: Long => micros
    case micros: Int =>
      // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs.
      micros.toLong
  }
}

/**
 * Formats a java.time.LocalTime for Hive output: converts it to a nanosecond count via
 * SparkDateTimeUtils.localTimeToNanos, then delegates to `timeFormatter`.
 * NOTE(review): `timeFormatter` is defined elsewhere in this class (not visible in this
 * diff hunk) — presumably the FractionTimeFormatter imported above; confirm.
 */
override def formatExternal(value: Any): String = {
val nanos = SparkDateTimeUtils.localTimeToNanos(value.asInstanceOf[LocalTime])
timeFormatter.format(nanos)
}

/** TIME is surfaced to Thrift clients as a string (TTypeId name "STRING_TYPE"). */
override def thriftTypeName: String = "STRING_TYPE"
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.SparkException
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.ops.ClientTypeOps
import org.apache.spark.util.ArrayImplicits._

private[sql] object ArrowUtils {
Expand All @@ -39,6 +40,14 @@ private[sql] object ArrowUtils {

/** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType =
  // Framework-managed types (looked up through ClientTypeOps) take precedence; any other
  // type falls through to the built-in mapping in toArrowTypeDefault.
  ClientTypeOps(dt)
    .map(_.toArrowType(timeZoneId))
    .getOrElse(toArrowTypeDefault(dt, timeZoneId, largeVarTypes))

private def toArrowTypeDefault(
dt: DataType,
timeZoneId: String,
largeVarTypes: Boolean): ArrowType =
dt match {
case BooleanType => ArrowType.Bool.INSTANCE
case ByteType => new ArrowType.Int(8, true)
Expand Down Expand Up @@ -67,7 +76,10 @@ private[sql] object ArrowUtils {
throw ExecutionErrors.unsupportedDataTypeError(dt)
}

def fromArrowType(dt: ArrowType): DataType = dt match {
// Reverse mapping: consult the Types Framework first, then the built-in
// Arrow-to-Spark table in fromArrowTypeDefault.
def fromArrowType(dt: ArrowType): DataType =
  ClientTypeOps.fromArrowType(dt).getOrElse(fromArrowTypeDefault(dt))

private def fromArrowTypeDefault(dt: ArrowType): DataType = dt match {
case ArrowType.Bool.INSTANCE => BooleanType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils, STUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -289,6 +290,14 @@ object DeserializerBuildHelper {
* @param isTopLevel true if we are creating a deserializer for the top level value.
*/
private def createDeserializer(
    enc: AgnosticEncoder[_],
    path: Expression,
    walkedTypePath: WalkedTypePath,
    isTopLevel: Boolean = false): Expression =
  // Framework-managed types build their own deserializer expression from `path`; every
  // other encoder goes through the original encoder-driven logic in
  // createDeserializerDefault.
  CatalystTypeOps(enc.dataType).map(_.createDeserializer(path))
    .getOrElse(createDeserializerDefault(enc, path, walkedTypePath, isTopLevel))

private def createDeserializerDefault(
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, Bo
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.types.ops.CatalystTypeOps
import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils, STUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -332,7 +333,12 @@ object SerializerBuildHelper {
* representation. The mapping between the external and internal representations is described
* by encoder `enc`.
*/
private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression =
  // Framework-managed types supply their own serializer expression; otherwise fall back
  // to the encoder-match default in createSerializerDefault.
  CatalystTypeOps(enc.dataType).map(_.createSerializer(input))
    .getOrElse(createSerializerDefault(enc, input))

private def createSerializerDefault(
enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input)
case _ if isNativeEncoder(enc) => input
case BoxedBooleanEncoder => createSerializerForBoolean(input)
Expand Down
Loading