From 870c478648f35a2cf1d7fa72265e5357086e1f8c Mon Sep 17 00:00:00 2001 From: bvolpato Date: Tue, 17 Mar 2026 10:16:11 -0400 Subject: [PATCH] fix(substrait): normalize table names from Substrait NamedTable for Calcite interop Normalize Substrait NamedTable names using TableReference::parse_str, matching how DataFusion's SQL planner normalizes identifiers at parse time. Since Substrait has no concept of quoted identifiers, all names are treated as unquoted and lowercased. This fixes interoperability with producers like Apache Calcite/Isthmus which emit uppercase table names (e.g. LINEITEM) while DataFusion's catalog stores names in lowercase (e.g. lineitem). This addresses 118 out of 120 failing consumer-testing plans. --- .../src/logical_plan/consumer/rel/read_rel.rs | 25 +- .../tests/cases/consumer_integration.rs | 515 ++++++++++-------- .../substrait/tests/cases/emit_kind_tests.rs | 18 +- .../substrait/tests/cases/logical_plans.rs | 54 +- .../tests/cases/substrait_validations.rs | 26 +- datafusion/substrait/tests/utils.rs | 20 +- 6 files changed, 354 insertions(+), 304 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs index 48e93c04bb034..bab8b9beb5930 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs @@ -86,22 +86,15 @@ pub async fn from_read_rel( match &read.read_type { Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, + // Normalize table names using DataFusion's identifier normalization + // (via TableReference::parse_str). Since Substrait has no concept of + // quoted identifiers, all names are treated as unquoted — this ensures + // interoperability with producers like Calcite/Isthmus that emit + // uppercase names (e.g. "LINEITEM" -> "lineitem"). + let table_reference = if nt.names.is_empty() { + return plan_err!("No table name found in NamedTable"); + } else { + TableReference::parse_str(&nt.names.join(".")) }; read_with_schema( diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index a92fc2957cae3..e9363fe38d035 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -53,13 +53,13 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER - Sort: LINEITEM.L_RETURNFLAG ASC NULLS LAST, LINEITEM.L_LINESTATUS ASC NULLS LAST - Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]] - Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT - Filter: LINEITEM.L_SHIPDATE <= Date32("1998-12-01") - IntervalDayTime("IntervalDayTime { days: 0, milliseconds: 10368000 }") - TableScan: LINEITEM - "# + Projection: lineitem.L_RETURNFLAG, lineitem.L_LINESTATUS, sum(lineitem.L_QUANTITY) AS SUM_QTY, sum(lineitem.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS SUM_DISC_PRICE, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT * Int32(1) + lineitem.L_TAX) AS SUM_CHARGE, avg(lineitem.L_QUANTITY) AS AVG_QTY, avg(lineitem.L_EXTENDEDPRICE) AS AVG_PRICE, avg(lineitem.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER + Sort: lineitem.L_RETURNFLAG ASC NULLS LAST, lineitem.L_LINESTATUS ASC NULLS LAST + Aggregate: groupBy=[[lineitem.L_RETURNFLAG, lineitem.L_LINESTATUS]], aggr=[[sum(lineitem.L_QUANTITY), sum(lineitem.L_EXTENDEDPRICE), sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT), sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT * Int32(1) + lineitem.L_TAX), avg(lineitem.L_QUANTITY), avg(lineitem.L_EXTENDEDPRICE), avg(lineitem.L_DISCOUNT), count(Int64(1))]] + Projection: lineitem.L_RETURNFLAG, lineitem.L_LINESTATUS, lineitem.L_QUANTITY, lineitem.L_EXTENDEDPRICE, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT), lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + lineitem.L_TAX), lineitem.L_DISCOUNT + Filter: lineitem.L_SHIPDATE <= Date32("1998-12-01") - IntervalDayTime("IntervalDayTime { days: 0, milliseconds: 10368000 }") + TableScan: lineitem + "# ); Ok(()) } @@ -70,31 +70,31 @@ mod tests { assert_snapshot!( plan_str, @r#" - Limit: skip=0, fetch=100 - Sort: SUPPLIER.S_ACCTBAL DESC NULLS FIRST, NATION.N_NAME ASC NULLS LAST, SUPPLIER.S_NAME ASC NULLS LAST, PART.P_PARTKEY ASC NULLS LAST - Projection: SUPPLIER.S_ACCTBAL, SUPPLIER.S_NAME, NATION.N_NAME, PART.P_PARTKEY, PART.P_MFGR, SUPPLIER.S_ADDRESS, SUPPLIER.S_PHONE, SUPPLIER.S_COMMENT - Filter: PART.P_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND PART.P_SIZE = Int32(15) AND PART.P_TYPE LIKE CAST(Utf8("%BRASS") AS Utf8) AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") AND PARTSUPP.PS_SUPPLYCOST = () - Subquery: - Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]] - Projection: PARTSUPP.PS_SUPPLYCOST - Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") + Limit: skip=0, fetch=100 + Sort: supplier.S_ACCTBAL DESC NULLS FIRST, nation.N_NAME ASC NULLS LAST, supplier.S_NAME ASC NULLS LAST, part.P_PARTKEY ASC NULLS LAST + Projection: supplier.S_ACCTBAL, supplier.S_NAME, nation.N_NAME, part.P_PARTKEY, part.P_MFGR, supplier.S_ADDRESS, supplier.S_PHONE, supplier.S_COMMENT + Filter: part.P_PARTKEY = partsupp.PS_PARTKEY AND supplier.S_SUPPKEY = partsupp.PS_SUPPKEY AND part.P_SIZE = Int32(15) AND part.P_TYPE LIKE CAST(Utf8("%BRASS") AS Utf8) AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_REGIONKEY = region.R_REGIONKEY AND region.R_NAME = Utf8("EUROPE") AND partsupp.PS_SUPPLYCOST = () + Subquery: + Aggregate: groupBy=[[]], aggr=[[min(partsupp.PS_SUPPLYCOST)]] + Projection: partsupp.PS_SUPPLYCOST + Filter: partsupp.PS_PARTKEY = partsupp.PS_PARTKEY AND supplier.S_SUPPKEY = partsupp.PS_SUPPKEY AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_REGIONKEY = region.R_REGIONKEY AND region.R_NAME = Utf8("EUROPE") + Cross Join: + Cross Join: Cross Join: - Cross Join: - Cross Join: - TableScan: PARTSUPP - TableScan: SUPPLIER - TableScan: NATION - TableScan: REGION + TableScan: partsupp + TableScan: supplier + TableScan: nation + TableScan: region + Cross Join: + Cross Join: Cross Join: Cross Join: - Cross Join: - Cross Join: - TableScan: PART - TableScan: SUPPLIER - TableScan: PARTSUPP - TableScan: NATION - TableScan: REGION - "# + TableScan: part + TableScan: supplier + TableScan: partsupp + TableScan: nation + TableScan: region + "# ); Ok(()) } @@ -105,19 +105,19 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: LINEITEM.L_ORDERKEY, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY - Limit: skip=0, fetch=10 - Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST, ORDERS.O_ORDERDATE ASC NULLS LAST - Projection: LINEITEM.L_ORDERKEY, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY - Aggregate: groupBy=[[LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] - Projection: LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: CUSTOMER.C_MKTSEGMENT = Utf8("BUILDING") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8("1995-03-15") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8("1995-03-15") AS Date32) - Cross Join: - Cross Join: - TableScan: LINEITEM - TableScan: CUSTOMER - TableScan: ORDERS - "# + Projection: lineitem.L_ORDERKEY, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS REVENUE, orders.O_ORDERDATE, orders.O_SHIPPRIORITY + Limit: skip=0, fetch=10 + Sort: sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) DESC NULLS FIRST, orders.O_ORDERDATE ASC NULLS LAST + Projection: lineitem.L_ORDERKEY, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT), orders.O_ORDERDATE, orders.O_SHIPPRIORITY + Aggregate: groupBy=[[lineitem.L_ORDERKEY, orders.O_ORDERDATE, orders.O_SHIPPRIORITY]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT)]] + Projection: lineitem.L_ORDERKEY, orders.O_ORDERDATE, orders.O_SHIPPRIORITY, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: customer.C_MKTSEGMENT = Utf8("BUILDING") AND customer.C_CUSTKEY = orders.O_CUSTKEY AND lineitem.L_ORDERKEY = orders.O_ORDERKEY AND orders.O_ORDERDATE < CAST(Utf8("1995-03-15") AS Date32) AND lineitem.L_SHIPDATE > CAST(Utf8("1995-03-15") AS Date32) + Cross Join: + Cross Join: + TableScan: lineitem + TableScan: customer + TableScan: orders + "# ); Ok(()) } @@ -128,16 +128,16 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: ORDERS.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT - Sort: ORDERS.O_ORDERPRIORITY ASC NULLS LAST - Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]] - Projection: ORDERS.O_ORDERPRIORITY - Filter: ORDERS.O_ORDERDATE >= CAST(Utf8("1993-07-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1993-10-01") AS Date32) AND EXISTS () - Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE - TableScan: LINEITEM - TableScan: ORDERS - "# + Projection: orders.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT + Sort: orders.O_ORDERPRIORITY ASC NULLS LAST + Aggregate: groupBy=[[orders.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]] + Projection: orders.O_ORDERPRIORITY + Filter: orders.O_ORDERDATE >= CAST(Utf8("1993-07-01") AS Date32) AND orders.O_ORDERDATE < CAST(Utf8("1993-10-01") AS Date32) AND EXISTS () + Subquery: + Filter: lineitem.L_ORDERKEY = lineitem.L_ORDERKEY AND lineitem.L_COMMITDATE < lineitem.L_RECEIPTDATE + TableScan: lineitem + TableScan: orders + "# ); Ok(()) } @@ -148,23 +148,23 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: NATION.N_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE - Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST - Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] - Projection: NATION.N_NAME, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND LINEITEM.L_SUPPKEY = SUPPLIER.S_SUPPKEY AND CUSTOMER.C_NATIONKEY = SUPPLIER.S_NATIONKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("ASIA") AND ORDERS.O_ORDERDATE >= CAST(Utf8("1994-01-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1995-01-01") AS Date32) + Projection: nation.N_NAME, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS REVENUE + Sort: sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) DESC NULLS FIRST + Aggregate: groupBy=[[nation.N_NAME]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT)]] + Projection: nation.N_NAME, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: customer.C_CUSTKEY = orders.O_CUSTKEY AND lineitem.L_ORDERKEY = orders.O_ORDERKEY AND lineitem.L_SUPPKEY = supplier.S_SUPPKEY AND customer.C_NATIONKEY = supplier.S_NATIONKEY AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_REGIONKEY = region.R_REGIONKEY AND region.R_NAME = Utf8("ASIA") AND orders.O_ORDERDATE >= CAST(Utf8("1994-01-01") AS Date32) AND orders.O_ORDERDATE < CAST(Utf8("1995-01-01") AS Date32) + Cross Join: + Cross Join: Cross Join: Cross Join: Cross Join: - Cross Join: - Cross Join: - TableScan: CUSTOMER - TableScan: ORDERS - TableScan: LINEITEM - TableScan: SUPPLIER - TableScan: NATION - TableScan: REGION - "# + TableScan: customer + TableScan: orders + TableScan: lineitem + TableScan: supplier + TableScan: nation + TableScan: region + "# ); Ok(()) } @@ -175,11 +175,11 @@ mod tests { assert_snapshot!( plan_str, @r#" - Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT) AS REVENUE]] - Projection: LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT - Filter: LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) AND LINEITEM.L_DISCOUNT >= Decimal128(Some(5),3,2) AND LINEITEM.L_DISCOUNT <= Decimal128(Some(7),3,2) AND LINEITEM.L_QUANTITY < CAST(Int32(24) AS Decimal128(15, 2)) - TableScan: LINEITEM - "# + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * lineitem.L_DISCOUNT) AS REVENUE]] + Projection: lineitem.L_EXTENDEDPRICE * lineitem.L_DISCOUNT + Filter: lineitem.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND lineitem.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) AND lineitem.L_DISCOUNT >= Decimal128(Some(5),3,2) AND lineitem.L_DISCOUNT <= Decimal128(Some(7),3,2) AND lineitem.L_QUANTITY < CAST(Int32(24) AS Decimal128(15, 2)) + TableScan: lineitem + "# ); Ok(()) } @@ -214,21 +214,21 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE, CUSTOMER.C_ACCTBAL, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_PHONE, CUSTOMER.C_COMMENT - Limit: skip=0, fetch=20 - Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST - Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), CUSTOMER.C_ACCTBAL, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_PHONE, CUSTOMER.C_COMMENT - Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] - Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE >= CAST(Utf8("1993-10-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_RETURNFLAG = Utf8("R") AND CUSTOMER.C_NATIONKEY = NATION.N_NATIONKEY + Projection: customer.C_CUSTKEY, customer.C_NAME, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS REVENUE, customer.C_ACCTBAL, nation.N_NAME, customer.C_ADDRESS, customer.C_PHONE, customer.C_COMMENT + Limit: skip=0, fetch=20 + Sort: sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) DESC NULLS FIRST + Projection: customer.C_CUSTKEY, customer.C_NAME, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT), customer.C_ACCTBAL, nation.N_NAME, customer.C_ADDRESS, customer.C_PHONE, customer.C_COMMENT + Aggregate: groupBy=[[customer.C_CUSTKEY, customer.C_NAME, customer.C_ACCTBAL, customer.C_PHONE, nation.N_NAME, customer.C_ADDRESS, customer.C_COMMENT]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT)]] + Projection: customer.C_CUSTKEY, customer.C_NAME, customer.C_ACCTBAL, customer.C_PHONE, nation.N_NAME, customer.C_ADDRESS, customer.C_COMMENT, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: customer.C_CUSTKEY = orders.O_CUSTKEY AND lineitem.L_ORDERKEY = orders.O_ORDERKEY AND orders.O_ORDERDATE >= CAST(Utf8("1993-10-01") AS Date32) AND orders.O_ORDERDATE < CAST(Utf8("1994-01-01") AS Date32) AND lineitem.L_RETURNFLAG = Utf8("R") AND customer.C_NATIONKEY = nation.N_NATIONKEY + Cross Join: + Cross Join: Cross Join: - Cross Join: - Cross Join: - TableScan: CUSTOMER - TableScan: ORDERS - TableScan: LINEITEM - TableScan: NATION - "# + TableScan: customer + TableScan: orders + TableScan: lineitem + TableScan: nation + "# ); Ok(()) } @@ -239,28 +239,28 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: PARTSUPP.PS_PARTKEY, sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) AS value - Sort: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) DESC NULLS FIRST - Filter: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) > () - Subquery: - Projection: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) * Decimal128(Some(1000000),11,10) - Aggregate: groupBy=[[]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]] - Projection: PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) - Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("JAPAN") - Cross Join: - Cross Join: - TableScan: PARTSUPP - TableScan: SUPPLIER - TableScan: NATION - Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]] - Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) - Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("JAPAN") + Projection: partsupp.PS_PARTKEY, sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY) AS value + Sort: sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY) DESC NULLS FIRST + Filter: sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY) > () + Subquery: + Projection: sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY) * Decimal128(Some(1000000),11,10) + Aggregate: groupBy=[[]], aggr=[[sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY)]] + Projection: partsupp.PS_SUPPLYCOST * CAST(partsupp.PS_AVAILQTY AS Decimal128(19, 0)) + Filter: partsupp.PS_SUPPKEY = supplier.S_SUPPKEY AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_NAME = Utf8("JAPAN") Cross Join: Cross Join: - TableScan: PARTSUPP - TableScan: SUPPLIER - TableScan: NATION - "# + TableScan: partsupp + TableScan: supplier + TableScan: nation + Aggregate: groupBy=[[partsupp.PS_PARTKEY]], aggr=[[sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY)]] + Projection: partsupp.PS_PARTKEY, partsupp.PS_SUPPLYCOST * CAST(partsupp.PS_AVAILQTY AS Decimal128(19, 0)) + Filter: partsupp.PS_SUPPKEY = supplier.S_SUPPKEY AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_NAME = Utf8("JAPAN") + Cross Join: + Cross Join: + TableScan: partsupp + TableScan: supplier + TableScan: nation + "# ); Ok(()) } @@ -271,15 +271,15 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: LINEITEM.L_SHIPMODE, sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END) AS HIGH_LINE_COUNT, sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END) AS LOW_LINE_COUNT - Sort: LINEITEM.L_SHIPMODE ASC NULLS LAST - Aggregate: groupBy=[[LINEITEM.L_SHIPMODE]], aggr=[[sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END)]] - Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END - Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8("MAIL") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("SHIP") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND LINEITEM.L_RECEIPTDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8("1995-01-01") AS Date32) - Cross Join: - TableScan: ORDERS - TableScan: LINEITEM - "# + Projection: lineitem.L_SHIPMODE, sum(CASE WHEN orders.O_ORDERPRIORITY = Utf8("1-URGENT") OR orders.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END) AS HIGH_LINE_COUNT, sum(CASE WHEN orders.O_ORDERPRIORITY != Utf8("1-URGENT") AND orders.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END) AS LOW_LINE_COUNT + Sort: lineitem.L_SHIPMODE ASC NULLS LAST + Aggregate: groupBy=[[lineitem.L_SHIPMODE]], aggr=[[sum(CASE WHEN orders.O_ORDERPRIORITY = Utf8("1-URGENT") OR orders.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN orders.O_ORDERPRIORITY != Utf8("1-URGENT") AND orders.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END)]] + Projection: lineitem.L_SHIPMODE, CASE WHEN orders.O_ORDERPRIORITY = Utf8("1-URGENT") OR orders.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END, CASE WHEN orders.O_ORDERPRIORITY != Utf8("1-URGENT") AND orders.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END + Filter: orders.O_ORDERKEY = lineitem.L_ORDERKEY AND (lineitem.L_SHIPMODE = CAST(Utf8("MAIL") AS Utf8) OR lineitem.L_SHIPMODE = CAST(Utf8("SHIP") AS Utf8)) AND lineitem.L_COMMITDATE < lineitem.L_RECEIPTDATE AND lineitem.L_SHIPDATE < lineitem.L_COMMITDATE AND lineitem.L_RECEIPTDATE >= CAST(Utf8("1994-01-01") AS Date32) AND lineitem.L_RECEIPTDATE < CAST(Utf8("1995-01-01") AS Date32) + Cross Join: + TableScan: orders + TableScan: lineitem + "# ); Ok(()) } @@ -290,17 +290,17 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST - Sort: count(Int64(1)) DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST - Projection: count(ORDERS.O_ORDERKEY), count(Int64(1)) - Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count(Int64(1))]] - Projection: count(ORDERS.O_ORDERKEY) - Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]] - Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY - Left Join: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY Filter: NOT ORDERS.O_COMMENT LIKE CAST(Utf8("%special%requests%") AS Utf8) - TableScan: CUSTOMER - TableScan: ORDERS - "# ); + Projection: count(orders.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST + Sort: count(Int64(1)) DESC NULLS FIRST, count(orders.O_ORDERKEY) DESC NULLS FIRST + Projection: count(orders.O_ORDERKEY), count(Int64(1)) + Aggregate: groupBy=[[count(orders.O_ORDERKEY)]], aggr=[[count(Int64(1))]] + Projection: count(orders.O_ORDERKEY) + Aggregate: groupBy=[[customer.C_CUSTKEY]], aggr=[[count(orders.O_ORDERKEY)]] + Projection: customer.C_CUSTKEY, orders.O_ORDERKEY + Left Join: customer.C_CUSTKEY = orders.O_CUSTKEY Filter: NOT orders.O_COMMENT LIKE CAST(Utf8("%special%requests%") AS Utf8) + TableScan: customer + TableScan: orders + "# ); Ok(()) } @@ -310,14 +310,14 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: Decimal128(Some(10000),5,2) * sum(CASE WHEN PART.P_TYPE LIKE Utf8("PROMO%") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END) / sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS PROMO_REVENUE - Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN PART.P_TYPE LIKE Utf8("PROMO%") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] - Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8("PROMO%") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32("1995-09-01") AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-10-01") AS Date32) - Cross Join: - TableScan: LINEITEM - TableScan: PART - "# + Projection: Decimal128(Some(10000),5,2) * sum(CASE WHEN part.P_TYPE LIKE Utf8("PROMO%") THEN lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END) / sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS PROMO_REVENUE + Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.P_TYPE LIKE Utf8("PROMO%") THEN lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT)]] + Projection: CASE WHEN part.P_TYPE LIKE CAST(Utf8("PROMO%") AS Utf8) THEN lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: lineitem.L_PARTKEY = part.P_PARTKEY AND lineitem.L_SHIPDATE >= Date32("1995-09-01") AND lineitem.L_SHIPDATE < CAST(Utf8("1995-10-01") AS Date32) + Cross Join: + TableScan: lineitem + TableScan: part + "# ); Ok(()) } @@ -336,19 +336,19 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: PART.P_BRAND, PART.P_TYPE, PART.P_SIZE, count(DISTINCT PARTSUPP.PS_SUPPKEY) AS SUPPLIER_CNT - Sort: count(DISTINCT PARTSUPP.PS_SUPPKEY) DESC NULLS FIRST, PART.P_BRAND ASC NULLS LAST, PART.P_TYPE ASC NULLS LAST, PART.P_SIZE ASC NULLS LAST - Aggregate: groupBy=[[PART.P_BRAND, PART.P_TYPE, PART.P_SIZE]], aggr=[[count(DISTINCT PARTSUPP.PS_SUPPKEY)]] - Projection: PART.P_BRAND, PART.P_TYPE, PART.P_SIZE, PARTSUPP.PS_SUPPKEY - Filter: PART.P_PARTKEY = PARTSUPP.PS_PARTKEY AND PART.P_BRAND != Utf8("Brand#45") AND NOT PART.P_TYPE LIKE CAST(Utf8("MEDIUM POLISHED%") AS Utf8) AND (PART.P_SIZE = Int32(49) OR PART.P_SIZE = Int32(14) OR PART.P_SIZE = Int32(23) OR PART.P_SIZE = Int32(45) OR PART.P_SIZE = Int32(19) OR PART.P_SIZE = Int32(3) OR PART.P_SIZE = Int32(36) OR PART.P_SIZE = Int32(9)) AND NOT PARTSUPP.PS_SUPPKEY IN () - Subquery: - Projection: SUPPLIER.S_SUPPKEY - Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8("%Customer%Complaints%") AS Utf8) - TableScan: SUPPLIER - Cross Join: - TableScan: PARTSUPP - TableScan: PART - "# + Projection: part.P_BRAND, part.P_TYPE, part.P_SIZE, count(DISTINCT partsupp.PS_SUPPKEY) AS SUPPLIER_CNT + Sort: count(DISTINCT partsupp.PS_SUPPKEY) DESC NULLS FIRST, part.P_BRAND ASC NULLS LAST, part.P_TYPE ASC NULLS LAST, part.P_SIZE ASC NULLS LAST + Aggregate: groupBy=[[part.P_BRAND, part.P_TYPE, part.P_SIZE]], aggr=[[count(DISTINCT partsupp.PS_SUPPKEY)]] + Projection: part.P_BRAND, part.P_TYPE, part.P_SIZE, partsupp.PS_SUPPKEY + Filter: part.P_PARTKEY = partsupp.PS_PARTKEY AND part.P_BRAND != Utf8("Brand#45") AND NOT part.P_TYPE LIKE CAST(Utf8("MEDIUM POLISHED%") AS Utf8) AND (part.P_SIZE = Int32(49) OR part.P_SIZE = Int32(14) OR part.P_SIZE = Int32(23) OR part.P_SIZE = Int32(45) OR part.P_SIZE = Int32(19) OR part.P_SIZE = Int32(3) OR part.P_SIZE = Int32(36) OR part.P_SIZE = Int32(9)) AND NOT partsupp.PS_SUPPKEY IN () + Subquery: + Projection: supplier.S_SUPPKEY + Filter: supplier.S_COMMENT LIKE CAST(Utf8("%Customer%Complaints%") AS Utf8) + TableScan: supplier + Cross Join: + TableScan: partsupp + TableScan: part + "# ); Ok(()) } @@ -366,25 +366,25 @@ mod tests { let plan_str = tpch_plan_to_string(18).await?; assert_snapshot!( plan_str, - @r#" - Projection: CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE, sum(LINEITEM.L_QUANTITY) AS EXPR$5 - Limit: skip=0, fetch=100 - Sort: ORDERS.O_TOTALPRICE DESC NULLS FIRST, ORDERS.O_ORDERDATE ASC NULLS LAST - Aggregate: groupBy=[[CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE]], aggr=[[sum(LINEITEM.L_QUANTITY)]] - Projection: CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE, LINEITEM.L_QUANTITY - Filter: ORDERS.O_ORDERKEY IN () AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY - Subquery: - Projection: LINEITEM.L_ORDERKEY - Filter: sum(LINEITEM.L_QUANTITY) > CAST(Int32(300) AS Decimal128(15, 2)) - Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]] - Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY - TableScan: LINEITEM - Cross Join: - Cross Join: - TableScan: CUSTOMER - TableScan: ORDERS - TableScan: LINEITEM - "# + @" + Projection: customer.C_NAME, customer.C_CUSTKEY, orders.O_ORDERKEY, orders.O_ORDERDATE, orders.O_TOTALPRICE, sum(lineitem.L_QUANTITY) AS EXPR$5 + Limit: skip=0, fetch=100 + Sort: orders.O_TOTALPRICE DESC NULLS FIRST, orders.O_ORDERDATE ASC NULLS LAST + Aggregate: groupBy=[[customer.C_NAME, customer.C_CUSTKEY, orders.O_ORDERKEY, orders.O_ORDERDATE, orders.O_TOTALPRICE]], aggr=[[sum(lineitem.L_QUANTITY)]] + Projection: customer.C_NAME, customer.C_CUSTKEY, orders.O_ORDERKEY, orders.O_ORDERDATE, orders.O_TOTALPRICE, lineitem.L_QUANTITY + Filter: orders.O_ORDERKEY IN () AND customer.C_CUSTKEY = orders.O_CUSTKEY AND orders.O_ORDERKEY = lineitem.L_ORDERKEY + Subquery: + Projection: lineitem.L_ORDERKEY + Filter: sum(lineitem.L_QUANTITY) > CAST(Int32(300) AS Decimal128(15, 2)) + Aggregate: groupBy=[[lineitem.L_ORDERKEY]], aggr=[[sum(lineitem.L_QUANTITY)]] + Projection: lineitem.L_ORDERKEY, lineitem.L_QUANTITY + TableScan: lineitem + Cross Join: + Cross Join: + TableScan: customer + TableScan: orders + TableScan: lineitem + " ); Ok(()) } @@ -394,13 +394,13 @@ mod tests { assert_snapshot!( plan_str, @r#" - Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE]] - Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#12") AND (PART.P_CONTAINER = CAST(Utf8("SM CASE") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM PACK") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM PKG") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#23") AND (PART.P_CONTAINER = CAST(Utf8("MED BAG") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED PKG") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED PACK") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#34") AND (PART.P_CONTAINER = CAST(Utf8("LG CASE") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG PACK") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG PKG") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") - Cross Join: - TableScan: LINEITEM - TableScan: PART - "# + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS REVENUE]] + Projection: lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: part.P_PARTKEY = lineitem.L_PARTKEY AND part.P_BRAND = Utf8("Brand#12") AND (part.P_CONTAINER = CAST(Utf8("SM CASE") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("SM BOX") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("SM PACK") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("SM PKG") AS Utf8)) AND lineitem.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND lineitem.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND part.P_SIZE >= Int32(1) AND part.P_SIZE <= Int32(5) AND (lineitem.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR lineitem.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND lineitem.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR part.P_PARTKEY = lineitem.L_PARTKEY AND part.P_BRAND = Utf8("Brand#23") AND (part.P_CONTAINER = CAST(Utf8("MED BAG") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("MED BOX") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("MED PKG") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("MED PACK") AS Utf8)) AND lineitem.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND lineitem.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND part.P_SIZE >= Int32(1) AND part.P_SIZE <= Int32(10) AND (lineitem.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR lineitem.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND lineitem.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR part.P_PARTKEY = lineitem.L_PARTKEY AND part.P_BRAND = Utf8("Brand#34") AND (part.P_CONTAINER = CAST(Utf8("LG CASE") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("LG BOX") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("LG PACK") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("LG PKG") AS Utf8)) AND lineitem.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND lineitem.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND part.P_SIZE >= Int32(1) AND part.P_SIZE <= Int32(15) AND (lineitem.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR lineitem.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND lineitem.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") + Cross Join: + TableScan: lineitem + TableScan: part + "# ); Ok(()) } @@ -411,27 +411,27 @@ mod tests { assert_snapshot!( plan_str, @r#" - Sort: SUPPLIER.S_NAME ASC NULLS LAST - Projection: SUPPLIER.S_NAME, SUPPLIER.S_ADDRESS - Filter: SUPPLIER.S_SUPPKEY IN () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("CANADA") - Subquery: - Projection: PARTSUPP.PS_SUPPKEY - Filter: PARTSUPP.PS_PARTKEY IN () AND CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) > () - Subquery: - Projection: PART.P_PARTKEY - Filter: PART.P_NAME LIKE CAST(Utf8("forest%") AS Utf8) - TableScan: PART - Subquery: - Projection: Decimal128(Some(5),2,1) * sum(LINEITEM.L_QUANTITY) - Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_QUANTITY)]] - Projection: LINEITEM.L_QUANTITY - Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) - TableScan: LINEITEM - TableScan: PARTSUPP - Cross Join: - TableScan: SUPPLIER - TableScan: NATION - "# + Sort: supplier.S_NAME ASC NULLS LAST + Projection: supplier.S_NAME, supplier.S_ADDRESS + Filter: supplier.S_SUPPKEY IN () AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_NAME = Utf8("CANADA") + Subquery: + Projection: partsupp.PS_SUPPKEY + Filter: partsupp.PS_PARTKEY IN () AND CAST(partsupp.PS_AVAILQTY AS Decimal128(19, 0)) > () + Subquery: + Projection: part.P_PARTKEY + Filter: part.P_NAME LIKE CAST(Utf8("forest%") AS Utf8) + TableScan: part + Subquery: + Projection: Decimal128(Some(5),2,1) * sum(lineitem.L_QUANTITY) + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.L_QUANTITY)]] + Projection: lineitem.L_QUANTITY + Filter: lineitem.L_PARTKEY = lineitem.L_ORDERKEY AND lineitem.L_SUPPKEY = lineitem.L_PARTKEY AND lineitem.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND lineitem.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) + TableScan: lineitem + TableScan: partsupp + Cross Join: + TableScan: supplier + TableScan: nation + "# ); Ok(()) } @@ -442,25 +442,25 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: SUPPLIER.S_NAME, count(Int64(1)) AS NUMWAIT + Projection: supplier.S_NAME, count(Int64(1)) AS NUMWAIT Limit: skip=0, fetch=100 - Sort: count(Int64(1)) DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST - Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count(Int64(1))]] - Projection: SUPPLIER.S_NAME - Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8("F") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("SAUDI ARABIA") + Sort: count(Int64(1)) DESC NULLS FIRST, supplier.S_NAME ASC NULLS LAST + Aggregate: groupBy=[[supplier.S_NAME]], aggr=[[count(Int64(1))]] + Projection: supplier.S_NAME + Filter: supplier.S_SUPPKEY = lineitem.L_SUPPKEY AND orders.O_ORDERKEY = lineitem.L_ORDERKEY AND orders.O_ORDERSTATUS = Utf8("F") AND lineitem.L_RECEIPTDATE > lineitem.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_NAME = Utf8("SAUDI ARABIA") Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS - TableScan: LINEITEM + Filter: lineitem.L_ORDERKEY = lineitem.L_TAX AND lineitem.L_SUPPKEY != lineitem.L_LINESTATUS + TableScan: lineitem Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE - TableScan: LINEITEM + Filter: lineitem.L_ORDERKEY = lineitem.L_TAX AND lineitem.L_SUPPKEY != lineitem.L_LINESTATUS AND lineitem.L_RECEIPTDATE > lineitem.L_COMMITDATE + TableScan: lineitem Cross Join: Cross Join: Cross Join: - TableScan: SUPPLIER - TableScan: LINEITEM - TableScan: ORDERS - TableScan: NATION + TableScan: supplier + TableScan: lineitem + TableScan: orders + TableScan: nation "# ); Ok(()) @@ -472,20 +472,20 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL - Sort: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST - Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(CUSTOMER.C_ACCTBAL)]] - Projection: substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)), CUSTOMER.C_ACCTBAL - Filter: (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) AND CUSTOMER.C_ACCTBAL > () AND NOT EXISTS () + Projection: substr(customer.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(customer.C_ACCTBAL) AS TOTACCTBAL + Sort: substr(customer.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST + Aggregate: groupBy=[[substr(customer.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(customer.C_ACCTBAL)]] + Projection: substr(customer.C_PHONE, Int32(1), Int32(2)), customer.C_ACCTBAL + Filter: (substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) AND customer.C_ACCTBAL > () AND NOT EXISTS () Subquery: - Aggregate: groupBy=[[]], aggr=[[avg(CUSTOMER.C_ACCTBAL)]] - Projection: CUSTOMER.C_ACCTBAL - Filter: CUSTOMER.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) - TableScan: CUSTOMER + Aggregate: groupBy=[[]], aggr=[[avg(customer.C_ACCTBAL)]] + Projection: customer.C_ACCTBAL + Filter: customer.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) + TableScan: customer Subquery: - Filter: ORDERS.O_CUSTKEY = ORDERS.O_ORDERKEY - TableScan: ORDERS - TableScan: CUSTOMER + Filter: orders.O_CUSTKEY = orders.O_ORDERKEY + TableScan: orders + TableScan: customer "# ); Ok(()) @@ -639,11 +639,11 @@ mod tests { assert_snapshot!( plan_str, - @r#" - Projection: count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR - WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] - TableScan: DATA - "# + @" + Projection: count(Int64(1)) PARTITION BY [data.PART] ORDER BY [data.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR + WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [data.PART] ORDER BY [data.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: data + " ); Ok(()) } @@ -708,4 +708,73 @@ mod tests { Ok(()) } + + /// Verifies that uppercase table names from Substrait (e.g. Calcite/Isthmus) + /// are normalized to lowercase, matching DataFusion's default catalog behavior. + #[tokio::test] + async fn test_uppercase_table_name_resolves_to_lowercase() -> Result<()> { + // The simple_select plan references table "DATA" (uppercase) + let path = "tests/testdata/test_plans/simple_select.substrait.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let ctx = SessionContext::new(); + let schema = datafusion::arrow::datatypes::Schema::new(vec![ + datafusion::arrow::datatypes::Field::new( + "a", + datafusion::arrow::datatypes::DataType::Int32, + true, + ), + datafusion::arrow::datatypes::Field::new( + "b", + datafusion::arrow::datatypes::DataType::Int32, + true, + ), + ]); + let table = + datafusion::datasource::empty::EmptyTable::new(std::sync::Arc::new(schema)); + ctx.register_table("data", std::sync::Arc::new(table))?; + + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + let plan_str = format!("{plan}"); + + assert_snapshot!( + plan_str, + @r" + Projection: data.a, data.b + TableScan: data + " + ); + + ctx.state().create_physical_plan(&plan).await?; + Ok(()) + } + + /// Same as above but uses add_plan_schemas_to_ctx which also normalizes. + #[tokio::test] + async fn test_uppercase_table_name_with_plan_schemas() -> Result<()> { + let path = "tests/testdata/test_plans/simple_select.substrait.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; + + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + let plan_str = format!("{plan}"); + + assert_snapshot!( + plan_str, + @r" + Projection: data.a, data.b + TableScan: data + " + ); + + ctx.state().create_physical_plan(&plan).await?; + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs index e916b4cb0e1a9..ad238093e8ce4 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -38,10 +38,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: DATA.A AS a, DATA.B AS b, DATA.A + Int64(1) AS add1 - TableScan: DATA - "# + @" + Projection: data.A AS a, data.B AS b, data.A + Int64(1) AS add1 + TableScan: data + " ); Ok(()) } @@ -57,11 +57,11 @@ mod tests { assert_snapshot!( plan, // Note that duplicate references in the remap are aliased - @r#" - Projection: DATA.B, DATA.A AS A1, DATA.A AS DATA.A__temp__0 AS A2 - Filter: DATA.B = Int64(2) - TableScan: DATA - "# + @" + Projection: data.B, data.A AS A1, data.A AS data.A__temp__0 AS A2 + Filter: data.B = Int64(2) + TableScan: data + " ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 426f3c12e5a15..4a3bb32e1e4d1 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -43,10 +43,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: NOT DATA.D AS EXPR$0 - TableScan: DATA - "# + @" + Projection: NOT data.D AS EXPR$0 + TableScan: data + " ); // Trigger execution to ensure plan validity @@ -74,11 +74,11 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR - WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] - TableScan: DATA - "# + @" + Projection: sum(data.D) PARTITION BY [data.PART] ORDER BY [data.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR + WindowAggr: windowExpr=[[sum(data.D) PARTITION BY [data.PART] ORDER BY [data.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: data + " ); // Trigger execution to ensure plan validity @@ -101,11 +101,11 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW__temp__0 AS ALIASED - WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - TableScan: DATA - "# + @" + Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW__temp__0 AS ALIASED + WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: data + " ); // Trigger execution to ensure plan validity @@ -130,12 +130,12 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() PARTITION BY [DATA.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$1 - WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - WindowAggr: windowExpr=[[row_number() PARTITION BY [DATA.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - TableScan: DATA - "# + @" + Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() PARTITION BY [data.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$1 + WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[row_number() PARTITION BY [data.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: data + " ); // Trigger execution to ensure plan validity @@ -164,18 +164,18 @@ mod tests { settings.bind(|| { assert_snapshot!( plan, - @r#" + @" Projection: left.A, left.[UUID] AS C, right.D, Utf8(NULL) AS [UUID] AS E Left Join: left.A = right.A SubqueryAlias: left Union - Projection: A.A, Utf8(NULL) AS [UUID] - TableScan: A - Projection: B.A, CAST(B.C AS Utf8) - TableScan: B + Projection: a.A, Utf8(NULL) AS [UUID] + TableScan: a + Projection: b.A, CAST(b.C AS Utf8) + TableScan: b SubqueryAlias: right - TableScan: C - "# + TableScan: c + " ); }); diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index c8cc3fe9940ce..4a95580cce932 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -34,7 +34,7 @@ mod tests { table_name: &str, fields: Vec<(&str, DataType, bool)>, ) -> Result { - let table_ref = TableReference::bare(table_name); + let table_ref = TableReference::parse_str(table_name); let fields: Vec<(Option, Arc)> = fields .into_iter() .map(|pair| { @@ -69,10 +69,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: DATA.a, DATA.b - TableScan: DATA - "# + @" + Projection: data.a, data.b + TableScan: data + " ); Ok(()) } @@ -92,10 +92,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: DATA.a, DATA.b - TableScan: DATA projection=[a, b] - "# + @" + Projection: data.a, data.b + TableScan: data projection=[a, b] + " ); Ok(()) } @@ -117,10 +117,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: DATA.a, DATA.b - TableScan: DATA projection=[a, b] - "# + @" + Projection: data.a, data.b + TableScan: data projection=[a, b] + " ); Ok(()) } diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index f84594312b634..db582f221148b 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -124,22 +124,10 @@ pub mod test { } fn collect_named_table(&mut self, read: &ReadRel, nt: &NamedTable) -> Result<()> { - let table_reference = match nt.names.len() { - 0 => { - panic!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, + let table_reference = if nt.names.is_empty() { + panic!("No table name found in NamedTable"); + } else { + TableReference::parse_str(&nt.names.join(".")) }; let substrait_schema =