diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 2f7d50fe764ae..3c210ca7d985b 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -6541,6 +6541,14 @@
     ],
     "sqlState" : "42601"
   },
+  "STORAGE_PARTITION_JOIN_INCOMPATIBLE_REDUCED_TYPES" : {
+    "message" : [
+      "Storage-partition join partition transforms produced incompatible reduced types,",
+      "left reducers: <leftReducers> returned: <leftReducedDataTypes>,",
+      "right reducers: <rightReducers> returned: <rightReducedDataTypes>."
+    ],
+    "sqlState" : "42K09"
+  },
   "STREAMING_CHECKPOINT_MISSING_METADATA_FILE" : {
     "message" : [
       "Checkpoint location is in an inconsistent state: the metadata file is missing but offset and/or commit logs contain data. Please restore the metadata file or create a new checkpoint directory."
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
index 54d22057b45a2..28520aa56258c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.connector.catalog.functions;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.types.DataType;
 
 /**
  * A 'reducer' for output of user-defined functions.
@@ -31,9 +32,10 @@
  *   <li>More generally, there exists reducer functions r1(x) and r2(x) such that
  *   r1(f_source(x)) = r2(f_target(x)) for all input x.</li>
  * </ul>
  *
+ * where = means both value and data type match.
  *
- * @param <I> reducer input type
- * @param <O> reducer output type
+ * @param <I> the physical Java type of the input
+ * @param <O> the physical Java type of the output
  * @since 4.0.0
  */
 @Evolving
@@ -47,4 +49,11 @@ public interface Reducer<I, O> {
   default String displayName() {
     return getClass().getSimpleName();
   }
+
+  /**
+   * Returns the {@link DataType data type} of values produced by this reducer.
+   *
+   * @return the data type of values produced by this reducer.
+   */
+  DataType resultType();
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 28a9225b6ce23..8b11860d0eb87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -465,10 +465,11 @@ case class KeyedPartitioning(
   /**
    * Reduces this partitioning's partition keys by applying the given reducers.
-   * Returns the distinct reduced keys.
+   * Returns the reduced keys and their data types.
    */
-  def reduceKeys(reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] =
-    KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers).distinct
+  def reduceKeys(
+      reducers: Seq[Option[Reducer[_, _]]]): (Seq[DataType], Seq[InternalRowComparableWrapper]) =
+    KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers)
 
   override def satisfies0(required: Distribution): Boolean = {
     nonGroupedSatisfies(required) || groupedSatisfies(required)
   }
@@ -586,10 +587,14 @@ object KeyedPartitioning {
   def reduceKeys(
       keys: Seq[InternalRowComparableWrapper],
       dataTypes: Seq[DataType],
-      reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] = {
+      reducers: Seq[Option[Reducer[_, _]]]): (Seq[DataType], Seq[InternalRowComparableWrapper]) = {
+    val reducedDataTypes = dataTypes.zip(reducers).map {
+      case (_, Some(reducer: Reducer[Any, Any])) => reducer.resultType()
+      case (t, _) => t
+    }
     val comparableKeyWrapperFactory =
-      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes)
-    keys.map { key =>
+      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(reducedDataTypes)
+    val reducedKeys = keys.map { key =>
       val keyValues = key.row.toSeq(dataTypes)
       val reducedKey = keyValues.zip(reducers).map {
         case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
@@ -597,6 +602,8 @@
         case (v, _) => v
       }.toArray
       comparableKeyWrapperFactory(new GenericInternalRow(reducedKey))
     }
+
+    (reducedDataTypes, reducedKeys)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 8985bdb519d19..d32d42f20004c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
 import org.apache.spark.sql.catalyst.util.{sideBySide, CharsetProvider, DateTimeUtils, FailFastMode, IntervalUtils, MapData}
 import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+import org.apache.spark.sql.connector.catalog.functions.Reducer
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE
@@ -3128,6 +3129,30 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
     )
   }
 
+  def storagePartitionJoinIncompatibleReducedTypesError(
+      leftReducers: Option[Seq[Option[Reducer[_, _]]]],
+      leftReducedDataTypes: Seq[DataType],
+      rightReducers: Option[Seq[Option[Reducer[_, _]]]],
+      rightReducedDataTypes: Seq[DataType]): Throwable = {
+    def reducersNames(reducers: Option[Seq[Option[Reducer[_, _]]]]) = {
+      reducers.toSeq.flatMap(_.map(_.map(_.displayName()).getOrElse("identity")))
+        .mkString("[", ", ", "]")
+    }
+
+    def dataTypeNames(dataTypes: Seq[DataType]) = {
+      dataTypes.map(toSQLType).mkString("[", ", ", "]")
+    }
+
+    new SparkException(
+      errorClass = "STORAGE_PARTITION_JOIN_INCOMPATIBLE_REDUCED_TYPES",
+      messageParameters = Map(
+        "leftReducers" -> reducersNames(leftReducers),
+        "leftReducedDataTypes" -> dataTypeNames(leftReducedDataTypes),
+        "rightReducers" -> reducersNames(rightReducers),
+        "rightReducedDataTypes" -> dataTypeNames(rightReducedDataTypes)),
+      cause = null)
+  }
+
   def notAbsolutePathError(path: Path): SparkException = {
     SparkException.internalError(s"$path is not absolute path.")
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 407d592f82199..e7762565f47ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -203,10 +203,10 @@ abstract class InMemoryBaseTable(
     case YearsTransform(ref) =>
       extractor(ref.fieldNames, cleanedSchema, row) match {
         case (days: Int, DateType) =>
-          ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+          ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)).toInt
         case (micros: Long, TimestampType) =>
           val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
-          ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+          ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt
         case (v, t) =>
           throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
       }
@@ -225,7 +225,7 @@
         case (days, DateType) =>
           days
         case (micros: Long, TimestampType) =>
-          ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
+          ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)).toInt
         case (v, t) =>
           throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
index 10e59c7b36114..455f7f85d2b51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
@@ -141,7 +141,7 @@ case class GroupPartitionsExec(
     )(keyedPartitioning.projectKeys)
 
     // Reduce keys if reducers are specified
-    val reducedKeys = reducers.fold(projectedKeys)(
+    val (reducedDataTypes, reducedKeys) = reducers.fold((projectedDataTypes, projectedKeys))(
       KeyedPartitioning.reduceKeys(projectedKeys, projectedDataTypes, _))
 
     val keyToPartitionIndices = reducedKeys.zipWithIndex.groupMap(_._1)(_._2)
@@ -149,7 +149,7 @@
     if (expectedPartitionKeys.isDefined) {
       alignToExpectedKeys(keyToPartitionIndices)
     } else {
-      (groupAndSortByKeys(keyToPartitionIndices, projectedDataTypes), true)
+      (groupAndSortByKeys(keyToPartitionIndices, reducedDataTypes), true)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index cca37558584f0..66cc7c90b63e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
 import org.apache.spark.sql.connector.catalog.functions.Reducer
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.v2.GroupPartitionsExec
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
@@ -509,16 +510,24 @@ case class EnsureRequirements(
           // in case of compatible but not identical partition expressions, we apply 'reduce'
           // transforms to group one side's partitions as well as the common partition values
           val leftReducers = leftSpec.reducers(rightSpec)
-          val leftReducedKeys =
-            leftReducers.fold(leftPartitioning.partitionKeys)(leftPartitioning.reduceKeys)
           val rightReducers = rightSpec.reducers(leftSpec)
-          val rightReducedKeys =
-            rightReducers.fold(rightPartitioning.partitionKeys)(rightPartitioning.reduceKeys)
+          val (leftReducedDataTypes, leftReducedKeys) = leftReducers.fold(
+            (leftPartitioning.expressionDataTypes, leftPartitioning.partitionKeys)
+          )(leftPartitioning.reduceKeys)
+          val (rightReducedDataTypes, rightReducedKeys) = rightReducers.fold(
+            (rightPartitioning.expressionDataTypes, rightPartitioning.partitionKeys)
+          )(rightPartitioning.reduceKeys)
+          if (leftReducedDataTypes != rightReducedDataTypes) {
+            throw QueryExecutionErrors.storagePartitionJoinIncompatibleReducedTypesError(
+              leftReducers = leftReducers,
+              leftReducedDataTypes = leftReducedDataTypes,
+              rightReducers = rightReducers,
+              rightReducedDataTypes = rightReducedDataTypes)
+          }
 
           // merge values on both sides
-          var mergedPartitionKeys =
-            mergePartitions(leftReducedKeys, rightReducedKeys, joinType, leftPartitioning.keyOrdering)
-              .map((_, 1))
+          var mergedPartitionKeys = mergeAndDedupPartitions(leftReducedKeys, rightReducedKeys,
+            joinType, leftPartitioning.keyOrdering).map((_, 1))
 
           logInfo(log"After merging, there are " +
             log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.size)} partitions")
@@ -752,7 +761,7 @@
   }
 
   /**
-   * Merge and sort partitions keys for SPJ and optionally enable partition filtering.
+   * Merge, dedup and sort partition keys for SPJ and optionally enable partition filtering.
    * Both sides must have matching partition expressions.
    * @param leftPartitionKeys left side partition keys
    * @param rightPartitionKeys right side partition keys
    * @param joinType join type
@@ -760,20 +769,21 @@
    * @keyOrdering ordering to sort partition keys
    * @return merged and sorted partition values
    */
-  def mergePartitions(
+  def mergeAndDedupPartitions(
      leftPartitionKeys: Seq[InternalRowComparableWrapper],
      rightPartitionKeys: Seq[InternalRowComparableWrapper],
      joinType: JoinType,
      keyOrdering: Ordering[InternalRowComparableWrapper]): Seq[InternalRowComparableWrapper] = {
    val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) {
      joinType match {
-        case Inner => mergePartitionKeys(leftPartitionKeys, rightPartitionKeys, intersect = true)
-        case LeftOuter => leftPartitionKeys
-        case RightOuter => rightPartitionKeys
-        case _ => mergePartitionKeys(leftPartitionKeys, rightPartitionKeys)
+        case Inner =>
+          mergeAndDedupPartitionKeys(leftPartitionKeys, rightPartitionKeys, intersect = true)
+        case LeftOuter => leftPartitionKeys.distinct
+        case RightOuter => rightPartitionKeys.distinct
+        case _ => mergeAndDedupPartitionKeys(leftPartitionKeys, rightPartitionKeys)
      }
    } else {
-      mergePartitionKeys(leftPartitionKeys, rightPartitionKeys)
+      mergeAndDedupPartitionKeys(leftPartitionKeys, rightPartitionKeys)
    }
 
    // SPARK-41471: We keep to order of partitions to make sure the order of
@@ -781,7 +791,7 @@
    merged.sorted(keyOrdering)
  }
 
-  private def mergePartitionKeys(
+  private def mergeAndDedupPartitionKeys(
      leftPartitionKeys: Seq[InternalRowComparableWrapper],
      rightPartitionKeys: Seq[InternalRowComparableWrapper],
      intersect: Boolean = false) = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 4116de1b89783..19c00b58a15a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connector
 import java.sql.Timestamp
 import java.util.Collections
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
@@ -75,6 +75,20 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
     Column.create("dept_id", IntegerType),
     Column.create("data", StringType))
 
+  def withFunction[T](fn: UnboundFunction)(f: => T): T = {
+    val id = Identifier.of(Array.empty, fn.name())
+    val oldFn = Option.when(catalog.listFunctions(Array.empty).contains(id)) {
+      val fn = catalog.loadFunction(id)
+      catalog.dropFunction(id)
+      fn
+    }
+    catalog.createFunction(id, fn)
+    try f finally {
+      catalog.dropFunction(id)
+      oldFn.foreach(catalog.createFunction(id, _))
+    }
+  }
+
   test("clustered distribution: output partitioning should be KeyedPartitioning") {
     val partitions: Array[Transform] = Array(Expressions.years("ts"))
 
@@ -88,7 +102,7 @@
     var df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY ts")
     val catalystDistribution = physical.ClusteredDistribution(
       Seq(TransformExpression(YearsFunction, Seq(attr("ts")))))
-    val partitionKeys = Seq(50L, 51L, 52L).map(v => InternalRow.fromSeq(Seq(v)))
+    val partitionKeys = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v)))
 
     checkQueryPlan(df, catalystDistribution,
       physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys))
@@ -3385,4 +3399,83 @@
       checkKeywordsExistsInExplain(df, FormattedMode, formattedKeyword)
     }
   }
+
+  test("SPARK-56046: Reducers with same result types") {
+    val items_partitions = Array(days("arrive_time"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " +
+      s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " +
+      s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))")
+
+    val purchases_partitions = Array(years("time"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      s"(5, 44.0, cast('2020-01-15' as timestamp)), " +
+      s"(7, 46.5, cast('2021-02-08' as timestamp))")
+
+    withSQLConf(
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+      Seq(
+        s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.time = i.arrive_time",
+        s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.arrive_time = p.time"
+      ).foreach { joinString =>
+        val df = sql(
+          s"""
+             |${selectWithMergeJoinHint("i", "p")} id, item_id
+             |FROM $joinString
+             |ORDER BY id, item_id
+             |""".stripMargin)
+
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
+        val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan)
+        assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 2))
+
+        checkAnswer(df, Seq(Row(0, 1), Row(1, 1)))
+      }
+    }
+  }
+
+  test("SPARK-56046: Reducers with different result types") {
+    withFunction(UnboundDaysFunctionWithIncompatibleResultTypeReducer) {
+      val items_partitions = Array(days("arrive_time"))
+      createTable(items, itemsColumns, items_partitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " +
+        s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))")
+
+      val purchases_partitions = Array(years("time"))
+      createTable(purchases, purchasesColumns, purchases_partitions)
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        s"(5, 44.0, cast('2020-01-15' as timestamp)), " +
+        s"(7, 46.5, cast('2021-02-08' as timestamp))")
+
+      withSQLConf(
+        SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+        SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+        Seq(
+          s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.time = i.arrive_time",
+          s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.arrive_time = p.time"
+        ).foreach { joinString =>
+          val e = intercept[SparkException] {
+            sql(
+              s"""
+                 |${selectWithMergeJoinHint("i", "p")} id, item_id
+                 |FROM $joinString
+                 |ORDER BY id, item_id
+                 |""".stripMargin).collect()
+          }
+          assert(e.getMessage.contains(
+            "Storage-partition join partition transforms produced incompatible reduced types"))
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 7c4852c5e22d5..588490e07dfd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -1174,7 +1174,7 @@
         Invoke(
           Literal.create(YearsFunction, ObjectType(YearsFunction.getClass)),
           "invoke",
-          LongType,
+          IntegerType,
           Seq(Cast(attr("day"), TimestampType, Some("America/Los_Angeles"))),
           Seq(TimestampType),
           propagateNull = false),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index e7d05ab25e2dd..13e84e32d1dfb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -40,42 +40,97 @@ object UnboundYearsFunction extends UnboundFunction {
   override def name(): String = "years"
 }
 
-object YearsFunction extends ScalarFunction[Long] {
+object YearsFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
   override def inputTypes(): Array[DataType] = Array(TimestampType)
-  override def resultType(): DataType = LongType
+  override def resultType(): DataType = IntegerType
   override def name(): String = "years"
   override def canonicalName(): String = name()
 
   val UTC: ZoneId = ZoneId.of("UTC")
   val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate
 
-  def invoke(ts: Long): Long = {
+  def invoke(ts: Long): Int = {
     val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate
-    ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+    ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt
   }
+
+  override def reducer(otherFunction: ReducibleFunction[_, _]): Reducer[Int, Int] = null
 }
 
-object DaysFunction extends BoundFunction {
-  override def inputTypes(): Array[DataType] = Array(TimestampType)
-  override def resultType(): DataType = LongType
+abstract class UnboundDaysFunctionBase extends UnboundFunction {
+  protected def isValidType(dt: DataType): Boolean = dt match {
+    case DateType | TimestampType => true
+    case _ => false
+  }
+
+  override def description(): String = name()
   override def name(): String = "days"
-  override def canonicalName(): String = name()
 }
 
-object UnboundDaysFunction extends UnboundFunction {
+object UnboundDaysFunction extends UnboundDaysFunctionBase {
   override def bind(inputType: StructType): BoundFunction = {
     if (inputType.size == 1 && isValidType(inputType.head.dataType)) DaysFunction
     else throw new UnsupportedOperationException(
       "'days' only take date or timestamp as input type")
   }
+}
 
-  private def isValidType(dt: DataType): Boolean = dt match {
-    case DateType | TimestampType => true
-    case _ => false
+object UnboundDaysFunctionWithIncompatibleResultTypeReducer extends UnboundDaysFunctionBase {
+  override def bind(inputType: StructType): BoundFunction = {
+    if (inputType.size == 1 && isValidType(inputType.head.dataType)) {
+      DaysFunctionWithIncompatibleResultTypeReducer
+    } else throw new UnsupportedOperationException(
+      "'days' only take date or timestamp as input type")
   }
+}
 
-  override def description(): String = name()
+abstract class DaysFunctionBase extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
+  override def inputTypes(): Array[DataType] = Array(TimestampType)
+  override def resultType(): DataType = DateType
   override def name(): String = "days"
+  override def canonicalName(): String = name()
+}
+
+// This `days` function reduces `DateType` partition keys to `IntegerType` partition keys when
+// partitions are reduced to partitions of a `years` function, which produces `IntegerType` keys.
+object DaysFunction extends DaysFunctionBase {
+  override def reducer(otherFunc: ReducibleFunction[_, _]): Reducer[Int, Int] = {
+    if (otherFunc == YearsFunction) {
+      DaysToYearsReducer()
+    } else {
+      null
+    }
+  }
+}
+
+// This `days` function reduces `DateType` partition keys to `DateType` partition keys when
+// partitions are reduced to partitions of a `years` function, which produces `IntegerType` keys.
+object DaysFunctionWithIncompatibleResultTypeReducer extends DaysFunctionBase {
+  override def reducer(otherFunc: ReducibleFunction[_, _]): Reducer[Int, Int] = {
+    if (otherFunc == YearsFunction) {
+      DaysToYearsReducerWithIncompatibleResultType()
+    } else {
+      null
+    }
+  }
+}
+
+abstract class DaysToYearsReducerBase extends Reducer[Int, Int] {
+  val UTC: ZoneId = ZoneId.of("UTC")
+  val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate
+
+  override def reduce(days: Int): Int = {
+    val localDate = EPOCH_LOCAL_DATE.plusDays(days)
+    ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt
+  }
+}
+
+case class DaysToYearsReducer() extends DaysToYearsReducerBase {
+  override def resultType(): DataType = IntegerType
+}
+
+case class DaysToYearsReducerWithIncompatibleResultType() extends DaysToYearsReducerBase {
+  override def resultType(): DataType = DateType
+}
 
 object UnboundBucketFunction extends UnboundFunction {
@@ -114,6 +169,7 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In
 
 case class BucketReducer(divisor: Int) extends Reducer[Int, Int] {
   override def reduce(bucket: Int): Int = bucket % divisor
+  override def resultType(): DataType = IntegerType
   override def displayName(): String = toString
 }
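
Note for connector authors: after this patch, a `Reducer` must declare the Catalyst data type of the keys it emits via `resultType()`, and `EnsureRequirements` rejects a storage-partitioned join whenever the two sides reduce to different types. Below is a minimal sketch of a conforming reducer, modeled on the test reducers in transformFunctions.scala; the `HoursToDaysReducer` name and the hour-to-day granularity are illustrative only and not part of this patch.

```scala
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical example: collapses hour-granularity partition keys
// (hours since epoch) into day-granularity keys. resultType() reports
// IntegerType so that both join sides can be checked for type agreement
// before their reduced partition keys are merged.
case class HoursToDaysReducer() extends Reducer[Int, Int] {
  override def reduce(hours: Int): Int = Math.floorDiv(hours, 24)
  override def resultType(): DataType = IntegerType
  override def displayName(): String = "hours_to_days"
}
```

If the declared types disagree, planning now fails fast with `STORAGE_PARTITION_JOIN_INCOMPATIBLE_REDUCED_TYPES` instead of silently grouping partitions under mismatched key types at runtime.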