diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs index 48e93c04bb034..bab8b9beb5930 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs @@ -86,22 +86,15 @@ pub async fn from_read_rel( match &read.read_type { Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, + // Normalize table names using DataFusion's identifier normalization + // (via TableReference::parse_str). Since Substrait has no concept of + // quoted identifiers, all names are treated as unquoted — this ensures + // interoperability with producers like Calcite/Isthmus that emit + // uppercase names (e.g. "LINEITEM" -> "lineitem"). + let table_reference = if nt.names.is_empty() { + return plan_err!("No table name found in NamedTable"); + } else { + TableReference::parse_str(&nt.names.join(".")) }; read_with_schema( diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index a92fc2957cae3..e9363fe38d035 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -53,13 +53,13 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER - Sort: LINEITEM.L_RETURNFLAG ASC NULLS LAST, LINEITEM.L_LINESTATUS ASC NULLS LAST - Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]] - Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT - Filter: LINEITEM.L_SHIPDATE <= Date32("1998-12-01") - IntervalDayTime("IntervalDayTime { days: 0, milliseconds: 10368000 }") - TableScan: LINEITEM - "# + Projection: lineitem.L_RETURNFLAG, lineitem.L_LINESTATUS, sum(lineitem.L_QUANTITY) AS SUM_QTY, sum(lineitem.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS SUM_DISC_PRICE, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT * Int32(1) + lineitem.L_TAX) AS SUM_CHARGE, avg(lineitem.L_QUANTITY) AS AVG_QTY, avg(lineitem.L_EXTENDEDPRICE) AS AVG_PRICE, avg(lineitem.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER + Sort: lineitem.L_RETURNFLAG ASC NULLS LAST, lineitem.L_LINESTATUS ASC NULLS LAST + Aggregate: groupBy=[[lineitem.L_RETURNFLAG, lineitem.L_LINESTATUS]], aggr=[[sum(lineitem.L_QUANTITY), sum(lineitem.L_EXTENDEDPRICE), sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT), sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT * Int32(1) + lineitem.L_TAX), avg(lineitem.L_QUANTITY), avg(lineitem.L_EXTENDEDPRICE), avg(lineitem.L_DISCOUNT), count(Int64(1))]] + Projection: lineitem.L_RETURNFLAG, lineitem.L_LINESTATUS, lineitem.L_QUANTITY, lineitem.L_EXTENDEDPRICE, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT), lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + lineitem.L_TAX), lineitem.L_DISCOUNT + Filter: lineitem.L_SHIPDATE <= Date32("1998-12-01") - IntervalDayTime("IntervalDayTime { days: 0, milliseconds: 10368000 }") + TableScan: lineitem + "# ); Ok(()) } @@ -70,31 +70,31 @@ mod tests { assert_snapshot!( plan_str, @r#" - Limit: skip=0, fetch=100 - Sort: SUPPLIER.S_ACCTBAL DESC NULLS FIRST, NATION.N_NAME ASC NULLS LAST, SUPPLIER.S_NAME ASC NULLS LAST, PART.P_PARTKEY ASC NULLS LAST - Projection: SUPPLIER.S_ACCTBAL, SUPPLIER.S_NAME, NATION.N_NAME, PART.P_PARTKEY, PART.P_MFGR, SUPPLIER.S_ADDRESS, SUPPLIER.S_PHONE, SUPPLIER.S_COMMENT - Filter: PART.P_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND PART.P_SIZE = Int32(15) AND PART.P_TYPE LIKE CAST(Utf8("%BRASS") AS Utf8) AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") AND PARTSUPP.PS_SUPPLYCOST = () - Subquery: - Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]] - Projection: PARTSUPP.PS_SUPPLYCOST - Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") + Limit: skip=0, fetch=100 + Sort: supplier.S_ACCTBAL DESC NULLS FIRST, nation.N_NAME ASC NULLS LAST, supplier.S_NAME ASC NULLS LAST, part.P_PARTKEY ASC NULLS LAST + Projection: supplier.S_ACCTBAL, supplier.S_NAME, nation.N_NAME, part.P_PARTKEY, part.P_MFGR, supplier.S_ADDRESS, supplier.S_PHONE, supplier.S_COMMENT + Filter: part.P_PARTKEY = partsupp.PS_PARTKEY AND supplier.S_SUPPKEY = partsupp.PS_SUPPKEY AND part.P_SIZE = Int32(15) AND part.P_TYPE LIKE CAST(Utf8("%BRASS") AS Utf8) AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_REGIONKEY = region.R_REGIONKEY AND region.R_NAME = Utf8("EUROPE") AND partsupp.PS_SUPPLYCOST = () + Subquery: + Aggregate: groupBy=[[]], aggr=[[min(partsupp.PS_SUPPLYCOST)]] + Projection: partsupp.PS_SUPPLYCOST + Filter: partsupp.PS_PARTKEY = partsupp.PS_PARTKEY AND supplier.S_SUPPKEY = partsupp.PS_SUPPKEY AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_REGIONKEY = region.R_REGIONKEY AND region.R_NAME = Utf8("EUROPE") + Cross Join: + Cross Join: Cross Join: - Cross Join: - Cross Join: - TableScan: PARTSUPP - TableScan: SUPPLIER - TableScan: NATION - TableScan: REGION + TableScan: partsupp + TableScan: supplier + TableScan: nation + TableScan: region + Cross Join: + Cross Join: Cross Join: Cross Join: - Cross Join: - Cross Join: - TableScan: PART - TableScan: SUPPLIER - TableScan: PARTSUPP - TableScan: NATION - TableScan: REGION - "# + TableScan: part + TableScan: supplier + TableScan: partsupp + TableScan: nation + TableScan: region + "# ); Ok(()) } @@ -105,19 +105,19 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: LINEITEM.L_ORDERKEY, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY - Limit: skip=0, fetch=10 - Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST, ORDERS.O_ORDERDATE ASC NULLS LAST - Projection: LINEITEM.L_ORDERKEY, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY - Aggregate: groupBy=[[LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] - Projection: LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: CUSTOMER.C_MKTSEGMENT = Utf8("BUILDING") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8("1995-03-15") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8("1995-03-15") AS Date32) - Cross Join: - Cross Join: - TableScan: LINEITEM - TableScan: CUSTOMER - TableScan: ORDERS - "# + Projection: lineitem.L_ORDERKEY, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS REVENUE, orders.O_ORDERDATE, orders.O_SHIPPRIORITY + Limit: skip=0, fetch=10 + Sort: sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) DESC NULLS FIRST, orders.O_ORDERDATE ASC NULLS LAST + Projection: lineitem.L_ORDERKEY, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT), orders.O_ORDERDATE, orders.O_SHIPPRIORITY + Aggregate: groupBy=[[lineitem.L_ORDERKEY, orders.O_ORDERDATE, orders.O_SHIPPRIORITY]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT)]] + Projection: lineitem.L_ORDERKEY, orders.O_ORDERDATE, orders.O_SHIPPRIORITY, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: customer.C_MKTSEGMENT = Utf8("BUILDING") AND customer.C_CUSTKEY = orders.O_CUSTKEY AND lineitem.L_ORDERKEY = orders.O_ORDERKEY AND orders.O_ORDERDATE < CAST(Utf8("1995-03-15") AS Date32) AND lineitem.L_SHIPDATE > CAST(Utf8("1995-03-15") AS Date32) + Cross Join: + Cross Join: + TableScan: lineitem + TableScan: customer + TableScan: orders + "# ); Ok(()) } @@ -128,16 +128,16 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: ORDERS.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT - Sort: ORDERS.O_ORDERPRIORITY ASC NULLS LAST - Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]] - Projection: ORDERS.O_ORDERPRIORITY - Filter: ORDERS.O_ORDERDATE >= CAST(Utf8("1993-07-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1993-10-01") AS Date32) AND EXISTS () - Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE - TableScan: LINEITEM - TableScan: ORDERS - "# + Projection: orders.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT + Sort: orders.O_ORDERPRIORITY ASC NULLS LAST + Aggregate: groupBy=[[orders.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]] + Projection: orders.O_ORDERPRIORITY + Filter: orders.O_ORDERDATE >= CAST(Utf8("1993-07-01") AS Date32) AND orders.O_ORDERDATE < CAST(Utf8("1993-10-01") AS Date32) AND EXISTS () + Subquery: + Filter: lineitem.L_ORDERKEY = lineitem.L_ORDERKEY AND lineitem.L_COMMITDATE < lineitem.L_RECEIPTDATE + TableScan: lineitem + TableScan: orders + "# ); Ok(()) } @@ -148,23 +148,23 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: NATION.N_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE - Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST - Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] - Projection: NATION.N_NAME, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND LINEITEM.L_SUPPKEY = SUPPLIER.S_SUPPKEY AND CUSTOMER.C_NATIONKEY = SUPPLIER.S_NATIONKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("ASIA") AND ORDERS.O_ORDERDATE >= CAST(Utf8("1994-01-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1995-01-01") AS Date32) + Projection: nation.N_NAME, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS REVENUE + Sort: sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) DESC NULLS FIRST + Aggregate: groupBy=[[nation.N_NAME]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT)]] + Projection: nation.N_NAME, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: customer.C_CUSTKEY = orders.O_CUSTKEY AND lineitem.L_ORDERKEY = orders.O_ORDERKEY AND lineitem.L_SUPPKEY = supplier.S_SUPPKEY AND customer.C_NATIONKEY = supplier.S_NATIONKEY AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_REGIONKEY = region.R_REGIONKEY AND region.R_NAME = Utf8("ASIA") AND orders.O_ORDERDATE >= CAST(Utf8("1994-01-01") AS Date32) AND orders.O_ORDERDATE < CAST(Utf8("1995-01-01") AS Date32) + Cross Join: + Cross Join: Cross Join: Cross Join: Cross Join: - Cross Join: - Cross Join: - TableScan: CUSTOMER - TableScan: ORDERS - TableScan: LINEITEM - TableScan: SUPPLIER - TableScan: NATION - TableScan: REGION - "# + TableScan: customer + TableScan: orders + TableScan: lineitem + TableScan: supplier + TableScan: nation + TableScan: region + "# ); Ok(()) } @@ -175,11 +175,11 @@ mod tests { assert_snapshot!( plan_str, @r#" - Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT) AS REVENUE]] - Projection: LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT - Filter: LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) AND LINEITEM.L_DISCOUNT >= Decimal128(Some(5),3,2) AND LINEITEM.L_DISCOUNT <= Decimal128(Some(7),3,2) AND LINEITEM.L_QUANTITY < CAST(Int32(24) AS Decimal128(15, 2)) - TableScan: LINEITEM - "# + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * lineitem.L_DISCOUNT) AS REVENUE]] + Projection: lineitem.L_EXTENDEDPRICE * lineitem.L_DISCOUNT + Filter: lineitem.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND lineitem.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) AND lineitem.L_DISCOUNT >= Decimal128(Some(5),3,2) AND lineitem.L_DISCOUNT <= Decimal128(Some(7),3,2) AND lineitem.L_QUANTITY < CAST(Int32(24) AS Decimal128(15, 2)) + TableScan: lineitem + "# ); Ok(()) } @@ -214,21 +214,21 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE, CUSTOMER.C_ACCTBAL, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_PHONE, CUSTOMER.C_COMMENT - Limit: skip=0, fetch=20 - Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST - Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), CUSTOMER.C_ACCTBAL, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_PHONE, CUSTOMER.C_COMMENT - Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] - Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE >= CAST(Utf8("1993-10-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_RETURNFLAG = Utf8("R") AND CUSTOMER.C_NATIONKEY = NATION.N_NATIONKEY + Projection: customer.C_CUSTKEY, customer.C_NAME, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS REVENUE, customer.C_ACCTBAL, nation.N_NAME, customer.C_ADDRESS, customer.C_PHONE, customer.C_COMMENT + Limit: skip=0, fetch=20 + Sort: sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) DESC NULLS FIRST + Projection: customer.C_CUSTKEY, customer.C_NAME, sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT), customer.C_ACCTBAL, nation.N_NAME, customer.C_ADDRESS, customer.C_PHONE, customer.C_COMMENT + Aggregate: groupBy=[[customer.C_CUSTKEY, customer.C_NAME, customer.C_ACCTBAL, customer.C_PHONE, nation.N_NAME, customer.C_ADDRESS, customer.C_COMMENT]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT)]] + Projection: customer.C_CUSTKEY, customer.C_NAME, customer.C_ACCTBAL, customer.C_PHONE, nation.N_NAME, customer.C_ADDRESS, customer.C_COMMENT, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: customer.C_CUSTKEY = orders.O_CUSTKEY AND lineitem.L_ORDERKEY = orders.O_ORDERKEY AND orders.O_ORDERDATE >= CAST(Utf8("1993-10-01") AS Date32) AND orders.O_ORDERDATE < CAST(Utf8("1994-01-01") AS Date32) AND lineitem.L_RETURNFLAG = Utf8("R") AND customer.C_NATIONKEY = nation.N_NATIONKEY + Cross Join: + Cross Join: Cross Join: - Cross Join: - Cross Join: - TableScan: CUSTOMER - TableScan: ORDERS - TableScan: LINEITEM - TableScan: NATION - "# + TableScan: customer + TableScan: orders + TableScan: lineitem + TableScan: nation + "# ); Ok(()) } @@ -239,28 +239,28 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: PARTSUPP.PS_PARTKEY, sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) AS value - Sort: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) DESC NULLS FIRST - Filter: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) > () - Subquery: - Projection: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) * Decimal128(Some(1000000),11,10) - Aggregate: groupBy=[[]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]] - Projection: PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) - Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("JAPAN") - Cross Join: - Cross Join: - TableScan: PARTSUPP - TableScan: SUPPLIER - TableScan: NATION - Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]] - Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) - Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("JAPAN") + Projection: partsupp.PS_PARTKEY, sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY) AS value + Sort: sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY) DESC NULLS FIRST + Filter: sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY) > () + Subquery: + Projection: sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY) * Decimal128(Some(1000000),11,10) + Aggregate: groupBy=[[]], aggr=[[sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY)]] + Projection: partsupp.PS_SUPPLYCOST * CAST(partsupp.PS_AVAILQTY AS Decimal128(19, 0)) + Filter: partsupp.PS_SUPPKEY = supplier.S_SUPPKEY AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_NAME = Utf8("JAPAN") Cross Join: Cross Join: - TableScan: PARTSUPP - TableScan: SUPPLIER - TableScan: NATION - "# + TableScan: partsupp + TableScan: supplier + TableScan: nation + Aggregate: groupBy=[[partsupp.PS_PARTKEY]], aggr=[[sum(partsupp.PS_SUPPLYCOST * partsupp.PS_AVAILQTY)]] + Projection: partsupp.PS_PARTKEY, partsupp.PS_SUPPLYCOST * CAST(partsupp.PS_AVAILQTY AS Decimal128(19, 0)) + Filter: partsupp.PS_SUPPKEY = supplier.S_SUPPKEY AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_NAME = Utf8("JAPAN") + Cross Join: + Cross Join: + TableScan: partsupp + TableScan: supplier + TableScan: nation + "# ); Ok(()) } @@ -271,15 +271,15 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: LINEITEM.L_SHIPMODE, sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END) AS HIGH_LINE_COUNT, sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END) AS LOW_LINE_COUNT - Sort: LINEITEM.L_SHIPMODE ASC NULLS LAST - Aggregate: groupBy=[[LINEITEM.L_SHIPMODE]], aggr=[[sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END)]] - Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8("1-URGENT") OR ORDERS.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8("1-URGENT") AND ORDERS.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END - Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8("MAIL") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("SHIP") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND LINEITEM.L_RECEIPTDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8("1995-01-01") AS Date32) - Cross Join: - TableScan: ORDERS - TableScan: LINEITEM - "# + Projection: lineitem.L_SHIPMODE, sum(CASE WHEN orders.O_ORDERPRIORITY = Utf8("1-URGENT") OR orders.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END) AS HIGH_LINE_COUNT, sum(CASE WHEN orders.O_ORDERPRIORITY != Utf8("1-URGENT") AND orders.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END) AS LOW_LINE_COUNT + Sort: lineitem.L_SHIPMODE ASC NULLS LAST + Aggregate: groupBy=[[lineitem.L_SHIPMODE]], aggr=[[sum(CASE WHEN orders.O_ORDERPRIORITY = Utf8("1-URGENT") OR orders.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN orders.O_ORDERPRIORITY != Utf8("1-URGENT") AND orders.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END)]] + Projection: lineitem.L_SHIPMODE, CASE WHEN orders.O_ORDERPRIORITY = Utf8("1-URGENT") OR orders.O_ORDERPRIORITY = Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END, CASE WHEN orders.O_ORDERPRIORITY != Utf8("1-URGENT") AND orders.O_ORDERPRIORITY != Utf8("2-HIGH") THEN Int32(1) ELSE Int32(0) END + Filter: orders.O_ORDERKEY = lineitem.L_ORDERKEY AND (lineitem.L_SHIPMODE = CAST(Utf8("MAIL") AS Utf8) OR lineitem.L_SHIPMODE = CAST(Utf8("SHIP") AS Utf8)) AND lineitem.L_COMMITDATE < lineitem.L_RECEIPTDATE AND lineitem.L_SHIPDATE < lineitem.L_COMMITDATE AND lineitem.L_RECEIPTDATE >= CAST(Utf8("1994-01-01") AS Date32) AND lineitem.L_RECEIPTDATE < CAST(Utf8("1995-01-01") AS Date32) + Cross Join: + TableScan: orders + TableScan: lineitem + "# ); Ok(()) } @@ -290,17 +290,17 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST - Sort: count(Int64(1)) DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST - Projection: count(ORDERS.O_ORDERKEY), count(Int64(1)) - Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count(Int64(1))]] - Projection: count(ORDERS.O_ORDERKEY) - Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]] - Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY - Left Join: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY Filter: NOT ORDERS.O_COMMENT LIKE CAST(Utf8("%special%requests%") AS Utf8) - TableScan: CUSTOMER - TableScan: ORDERS - "# ); + Projection: count(orders.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST + Sort: count(Int64(1)) DESC NULLS FIRST, count(orders.O_ORDERKEY) DESC NULLS FIRST + Projection: count(orders.O_ORDERKEY), count(Int64(1)) + Aggregate: groupBy=[[count(orders.O_ORDERKEY)]], aggr=[[count(Int64(1))]] + Projection: count(orders.O_ORDERKEY) + Aggregate: groupBy=[[customer.C_CUSTKEY]], aggr=[[count(orders.O_ORDERKEY)]] + Projection: customer.C_CUSTKEY, orders.O_ORDERKEY + Left Join: customer.C_CUSTKEY = orders.O_CUSTKEY Filter: NOT orders.O_COMMENT LIKE CAST(Utf8("%special%requests%") AS Utf8) + TableScan: customer + TableScan: orders + "# ); Ok(()) } @@ -310,14 +310,14 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: Decimal128(Some(10000),5,2) * sum(CASE WHEN PART.P_TYPE LIKE Utf8("PROMO%") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END) / sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS PROMO_REVENUE - Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN PART.P_TYPE LIKE Utf8("PROMO%") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]] - Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8("PROMO%") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32("1995-09-01") AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-10-01") AS Date32) - Cross Join: - TableScan: LINEITEM - TableScan: PART - "# + Projection: Decimal128(Some(10000),5,2) * sum(CASE WHEN part.P_TYPE LIKE Utf8("PROMO%") THEN lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END) / sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS PROMO_REVENUE + Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.P_TYPE LIKE Utf8("PROMO%") THEN lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT)]] + Projection: CASE WHEN part.P_TYPE LIKE CAST(Utf8("PROMO%") AS Utf8) THEN lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: lineitem.L_PARTKEY = part.P_PARTKEY AND lineitem.L_SHIPDATE >= Date32("1995-09-01") AND lineitem.L_SHIPDATE < CAST(Utf8("1995-10-01") AS Date32) + Cross Join: + TableScan: lineitem + TableScan: part + "# ); Ok(()) } @@ -336,19 +336,19 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: PART.P_BRAND, PART.P_TYPE, PART.P_SIZE, count(DISTINCT PARTSUPP.PS_SUPPKEY) AS SUPPLIER_CNT - Sort: count(DISTINCT PARTSUPP.PS_SUPPKEY) DESC NULLS FIRST, PART.P_BRAND ASC NULLS LAST, PART.P_TYPE ASC NULLS LAST, PART.P_SIZE ASC NULLS LAST - Aggregate: groupBy=[[PART.P_BRAND, PART.P_TYPE, PART.P_SIZE]], aggr=[[count(DISTINCT PARTSUPP.PS_SUPPKEY)]] - Projection: PART.P_BRAND, PART.P_TYPE, PART.P_SIZE, PARTSUPP.PS_SUPPKEY - Filter: PART.P_PARTKEY = PARTSUPP.PS_PARTKEY AND PART.P_BRAND != Utf8("Brand#45") AND NOT PART.P_TYPE LIKE CAST(Utf8("MEDIUM POLISHED%") AS Utf8) AND (PART.P_SIZE = Int32(49) OR PART.P_SIZE = Int32(14) OR PART.P_SIZE = Int32(23) OR PART.P_SIZE = Int32(45) OR PART.P_SIZE = Int32(19) OR PART.P_SIZE = Int32(3) OR PART.P_SIZE = Int32(36) OR PART.P_SIZE = Int32(9)) AND NOT PARTSUPP.PS_SUPPKEY IN () - Subquery: - Projection: SUPPLIER.S_SUPPKEY - Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8("%Customer%Complaints%") AS Utf8) - TableScan: SUPPLIER - Cross Join: - TableScan: PARTSUPP - TableScan: PART - "# + Projection: part.P_BRAND, part.P_TYPE, part.P_SIZE, count(DISTINCT partsupp.PS_SUPPKEY) AS SUPPLIER_CNT + Sort: count(DISTINCT partsupp.PS_SUPPKEY) DESC NULLS FIRST, part.P_BRAND ASC NULLS LAST, part.P_TYPE ASC NULLS LAST, part.P_SIZE ASC NULLS LAST + Aggregate: groupBy=[[part.P_BRAND, part.P_TYPE, part.P_SIZE]], aggr=[[count(DISTINCT partsupp.PS_SUPPKEY)]] + Projection: part.P_BRAND, part.P_TYPE, part.P_SIZE, partsupp.PS_SUPPKEY + Filter: part.P_PARTKEY = partsupp.PS_PARTKEY AND part.P_BRAND != Utf8("Brand#45") AND NOT part.P_TYPE LIKE CAST(Utf8("MEDIUM POLISHED%") AS Utf8) AND (part.P_SIZE = Int32(49) OR part.P_SIZE = Int32(14) OR part.P_SIZE = Int32(23) OR part.P_SIZE = Int32(45) OR part.P_SIZE = Int32(19) OR part.P_SIZE = Int32(3) OR part.P_SIZE = Int32(36) OR part.P_SIZE = Int32(9)) AND NOT partsupp.PS_SUPPKEY IN () + Subquery: + Projection: supplier.S_SUPPKEY + Filter: supplier.S_COMMENT LIKE CAST(Utf8("%Customer%Complaints%") AS Utf8) + TableScan: supplier + Cross Join: + TableScan: partsupp + TableScan: part + "# ); Ok(()) } @@ -366,25 +366,25 @@ mod tests { let plan_str = tpch_plan_to_string(18).await?; assert_snapshot!( plan_str, - @r#" - Projection: CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE, sum(LINEITEM.L_QUANTITY) AS EXPR$5 - Limit: skip=0, fetch=100 - Sort: ORDERS.O_TOTALPRICE DESC NULLS FIRST, ORDERS.O_ORDERDATE ASC NULLS LAST - Aggregate: groupBy=[[CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE]], aggr=[[sum(LINEITEM.L_QUANTITY)]] - Projection: CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE, LINEITEM.L_QUANTITY - Filter: ORDERS.O_ORDERKEY IN () AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY - Subquery: - Projection: LINEITEM.L_ORDERKEY - Filter: sum(LINEITEM.L_QUANTITY) > CAST(Int32(300) AS Decimal128(15, 2)) - Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]] - Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY - TableScan: LINEITEM - Cross Join: - Cross Join: - TableScan: CUSTOMER - TableScan: ORDERS - TableScan: LINEITEM - "# + @" + Projection: customer.C_NAME, customer.C_CUSTKEY, orders.O_ORDERKEY, orders.O_ORDERDATE, orders.O_TOTALPRICE, sum(lineitem.L_QUANTITY) AS EXPR$5 + Limit: skip=0, fetch=100 + Sort: orders.O_TOTALPRICE DESC NULLS FIRST, orders.O_ORDERDATE ASC NULLS LAST + Aggregate: groupBy=[[customer.C_NAME, customer.C_CUSTKEY, orders.O_ORDERKEY, orders.O_ORDERDATE, orders.O_TOTALPRICE]], aggr=[[sum(lineitem.L_QUANTITY)]] + Projection: customer.C_NAME, customer.C_CUSTKEY, orders.O_ORDERKEY, orders.O_ORDERDATE, orders.O_TOTALPRICE, lineitem.L_QUANTITY + Filter: orders.O_ORDERKEY IN () AND customer.C_CUSTKEY = orders.O_CUSTKEY AND orders.O_ORDERKEY = lineitem.L_ORDERKEY + Subquery: + Projection: lineitem.L_ORDERKEY + Filter: sum(lineitem.L_QUANTITY) > CAST(Int32(300) AS Decimal128(15, 2)) + Aggregate: groupBy=[[lineitem.L_ORDERKEY]], aggr=[[sum(lineitem.L_QUANTITY)]] + Projection: lineitem.L_ORDERKEY, lineitem.L_QUANTITY + TableScan: lineitem + Cross Join: + Cross Join: + TableScan: customer + TableScan: orders + TableScan: lineitem + " ); Ok(()) } @@ -394,13 +394,13 @@ mod tests { assert_snapshot!( plan_str, @r#" - Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE]] - Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) - Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#12") AND (PART.P_CONTAINER = CAST(Utf8("SM CASE") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM PACK") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("SM PKG") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#23") AND (PART.P_CONTAINER = CAST(Utf8("MED BAG") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED PKG") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("MED PACK") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#34") AND (PART.P_CONTAINER = CAST(Utf8("LG CASE") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG BOX") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG PACK") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8("LG PKG") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") - Cross Join: - TableScan: LINEITEM - TableScan: PART - "# + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.L_EXTENDEDPRICE * Int32(1) - lineitem.L_DISCOUNT) AS REVENUE]] + Projection: lineitem.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - lineitem.L_DISCOUNT) + Filter: part.P_PARTKEY = lineitem.L_PARTKEY AND part.P_BRAND = Utf8("Brand#12") AND (part.P_CONTAINER = CAST(Utf8("SM CASE") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("SM BOX") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("SM PACK") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("SM PKG") AS Utf8)) AND lineitem.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND lineitem.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND part.P_SIZE >= Int32(1) AND part.P_SIZE <= Int32(5) AND (lineitem.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR lineitem.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND lineitem.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR part.P_PARTKEY = lineitem.L_PARTKEY AND part.P_BRAND = Utf8("Brand#23") AND (part.P_CONTAINER = CAST(Utf8("MED BAG") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("MED BOX") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("MED PKG") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("MED PACK") AS Utf8)) AND lineitem.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND lineitem.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND part.P_SIZE >= Int32(1) AND part.P_SIZE <= Int32(10) AND (lineitem.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR lineitem.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND lineitem.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") OR part.P_PARTKEY = lineitem.L_PARTKEY AND part.P_BRAND = Utf8("Brand#34") AND (part.P_CONTAINER = CAST(Utf8("LG CASE") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("LG BOX") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("LG PACK") AS Utf8) OR part.P_CONTAINER = CAST(Utf8("LG PKG") AS Utf8)) AND lineitem.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND lineitem.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND part.P_SIZE >= Int32(1) AND part.P_SIZE <= Int32(15) AND (lineitem.L_SHIPMODE = CAST(Utf8("AIR") AS Utf8) OR lineitem.L_SHIPMODE = CAST(Utf8("AIR REG") AS Utf8)) AND lineitem.L_SHIPINSTRUCT = Utf8("DELIVER IN PERSON") + Cross Join: + TableScan: lineitem + TableScan: part + "# ); Ok(()) } @@ -411,27 +411,27 @@ mod tests { assert_snapshot!( plan_str, @r#" - Sort: SUPPLIER.S_NAME ASC NULLS LAST - Projection: SUPPLIER.S_NAME, SUPPLIER.S_ADDRESS - Filter: SUPPLIER.S_SUPPKEY IN () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("CANADA") - Subquery: - Projection: PARTSUPP.PS_SUPPKEY - Filter: PARTSUPP.PS_PARTKEY IN () AND CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) > () - Subquery: - Projection: PART.P_PARTKEY - Filter: PART.P_NAME LIKE CAST(Utf8("forest%") AS Utf8) - TableScan: PART - Subquery: - Projection: Decimal128(Some(5),2,1) * sum(LINEITEM.L_QUANTITY) - Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_QUANTITY)]] - Projection: LINEITEM.L_QUANTITY - Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) - TableScan: LINEITEM - TableScan: PARTSUPP - Cross Join: - TableScan: SUPPLIER - TableScan: NATION - "# + Sort: supplier.S_NAME ASC NULLS LAST + Projection: supplier.S_NAME, supplier.S_ADDRESS + Filter: supplier.S_SUPPKEY IN () AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_NAME = Utf8("CANADA") + Subquery: + Projection: partsupp.PS_SUPPKEY + Filter: partsupp.PS_PARTKEY IN () AND CAST(partsupp.PS_AVAILQTY AS Decimal128(19, 0)) > () + Subquery: + Projection: part.P_PARTKEY + Filter: part.P_NAME LIKE CAST(Utf8("forest%") AS Utf8) + TableScan: part + Subquery: + Projection: Decimal128(Some(5),2,1) * sum(lineitem.L_QUANTITY) + Aggregate: groupBy=[[]], aggr=[[sum(lineitem.L_QUANTITY)]] + Projection: lineitem.L_QUANTITY + Filter: lineitem.L_PARTKEY = lineitem.L_ORDERKEY AND lineitem.L_SUPPKEY = lineitem.L_PARTKEY AND lineitem.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND lineitem.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) + TableScan: lineitem + TableScan: partsupp + Cross Join: + TableScan: supplier + TableScan: nation + "# ); Ok(()) } @@ -442,25 +442,25 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: SUPPLIER.S_NAME, count(Int64(1)) AS NUMWAIT + Projection: supplier.S_NAME, count(Int64(1)) AS NUMWAIT Limit: skip=0, fetch=100 - Sort: count(Int64(1)) DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST - Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count(Int64(1))]] - Projection: SUPPLIER.S_NAME - Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8("F") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("SAUDI ARABIA") + Sort: count(Int64(1)) DESC NULLS FIRST, supplier.S_NAME ASC NULLS LAST + Aggregate: groupBy=[[supplier.S_NAME]], aggr=[[count(Int64(1))]] + Projection: supplier.S_NAME + Filter: supplier.S_SUPPKEY = lineitem.L_SUPPKEY AND orders.O_ORDERKEY = lineitem.L_ORDERKEY AND orders.O_ORDERSTATUS = Utf8("F") AND lineitem.L_RECEIPTDATE > lineitem.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND supplier.S_NATIONKEY = nation.N_NATIONKEY AND nation.N_NAME = Utf8("SAUDI ARABIA") Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS - TableScan: LINEITEM + Filter: lineitem.L_ORDERKEY = lineitem.L_TAX AND lineitem.L_SUPPKEY != lineitem.L_LINESTATUS + TableScan: lineitem Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE - TableScan: LINEITEM + Filter: lineitem.L_ORDERKEY = lineitem.L_TAX AND lineitem.L_SUPPKEY != lineitem.L_LINESTATUS AND lineitem.L_RECEIPTDATE > lineitem.L_COMMITDATE + TableScan: lineitem Cross Join: Cross Join: Cross Join: - TableScan: SUPPLIER - TableScan: LINEITEM - TableScan: ORDERS - TableScan: NATION + TableScan: supplier + TableScan: lineitem + TableScan: orders + TableScan: nation "# ); Ok(()) @@ -472,20 +472,20 @@ mod tests { assert_snapshot!( plan_str, @r#" - Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL - Sort: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST - Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(CUSTOMER.C_ACCTBAL)]] - Projection: substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)), CUSTOMER.C_ACCTBAL - Filter: (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) AND CUSTOMER.C_ACCTBAL > () AND NOT EXISTS () + Projection: substr(customer.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(customer.C_ACCTBAL) AS TOTACCTBAL + Sort: substr(customer.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST + Aggregate: groupBy=[[substr(customer.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(customer.C_ACCTBAL)]] + Projection: substr(customer.C_PHONE, Int32(1), Int32(2)), customer.C_ACCTBAL + Filter: (substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) AND customer.C_ACCTBAL > () AND NOT EXISTS () Subquery: - Aggregate: groupBy=[[]], aggr=[[avg(CUSTOMER.C_ACCTBAL)]] - Projection: CUSTOMER.C_ACCTBAL - Filter: CUSTOMER.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) - TableScan: CUSTOMER + Aggregate: groupBy=[[]], aggr=[[avg(customer.C_ACCTBAL)]] + Projection: customer.C_ACCTBAL + Filter: customer.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(customer.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) + TableScan: customer Subquery: - Filter: ORDERS.O_CUSTKEY = ORDERS.O_ORDERKEY - TableScan: ORDERS - TableScan: CUSTOMER + Filter: orders.O_CUSTKEY = orders.O_ORDERKEY + TableScan: orders + TableScan: customer "# ); Ok(()) @@ -639,11 +639,11 @@ mod tests { assert_snapshot!( plan_str, - @r#" - Projection: count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR - WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] - TableScan: DATA - "# + @" + Projection: count(Int64(1)) PARTITION BY [data.PART] ORDER BY [data.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR + WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [data.PART] ORDER BY [data.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: data + " ); Ok(()) } @@ -708,4 +708,73 @@ mod tests { Ok(()) } + + /// Verifies that uppercase table names from Substrait (e.g. Calcite/Isthmus) + /// are normalized to lowercase, matching DataFusion's default catalog behavior. + #[tokio::test] + async fn test_uppercase_table_name_resolves_to_lowercase() -> Result<()> { + // The simple_select plan references table "DATA" (uppercase) + let path = "tests/testdata/test_plans/simple_select.substrait.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let ctx = SessionContext::new(); + let schema = datafusion::arrow::datatypes::Schema::new(vec![ + datafusion::arrow::datatypes::Field::new( + "a", + datafusion::arrow::datatypes::DataType::Int32, + true, + ), + datafusion::arrow::datatypes::Field::new( + "b", + datafusion::arrow::datatypes::DataType::Int32, + true, + ), + ]); + let table = + datafusion::datasource::empty::EmptyTable::new(std::sync::Arc::new(schema)); + ctx.register_table("data", std::sync::Arc::new(table))?; + + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + let plan_str = format!("{plan}"); + + assert_snapshot!( + plan_str, + @r" + Projection: data.a, data.b + TableScan: data + " + ); + + ctx.state().create_physical_plan(&plan).await?; + Ok(()) + } + + /// Same as above but uses add_plan_schemas_to_ctx which also normalizes. + #[tokio::test] + async fn test_uppercase_table_name_with_plan_schemas() -> Result<()> { + let path = "tests/testdata/test_plans/simple_select.substrait.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; + + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + let plan_str = format!("{plan}"); + + assert_snapshot!( + plan_str, + @r" + Projection: data.a, data.b + TableScan: data + " + ); + + ctx.state().create_physical_plan(&plan).await?; + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs index e916b4cb0e1a9..ad238093e8ce4 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -38,10 +38,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: DATA.A AS a, DATA.B AS b, DATA.A + Int64(1) AS add1 - TableScan: DATA - "# + @" + Projection: data.A AS a, data.B AS b, data.A + Int64(1) AS add1 + TableScan: data + " ); Ok(()) } @@ -57,11 +57,11 @@ mod tests { assert_snapshot!( plan, // Note that duplicate references in the remap are aliased - @r#" - Projection: DATA.B, DATA.A AS A1, DATA.A AS DATA.A__temp__0 AS A2 - Filter: DATA.B = Int64(2) - TableScan: DATA - "# + @" + Projection: data.B, data.A AS A1, data.A AS data.A__temp__0 AS A2 + Filter: data.B = Int64(2) + TableScan: data + " ); Ok(()) } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 426f3c12e5a15..4a3bb32e1e4d1 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -43,10 +43,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: NOT DATA.D AS EXPR$0 - TableScan: DATA - "# + @" + Projection: NOT data.D AS EXPR$0 + TableScan: data + " ); // Trigger execution to ensure plan validity @@ -74,11 +74,11 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR - WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] - TableScan: DATA - "# + @" + Projection: sum(data.D) PARTITION BY [data.PART] ORDER BY [data.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR + WindowAggr: windowExpr=[[sum(data.D) PARTITION BY [data.PART] ORDER BY [data.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: data + " ); // Trigger execution to ensure plan validity @@ -101,11 +101,11 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW__temp__0 AS ALIASED - WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - TableScan: DATA - "# + @" + Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW__temp__0 AS ALIASED + WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: data + " ); // Trigger execution to ensure plan validity @@ -130,12 +130,12 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() PARTITION BY [DATA.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$1 - WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - WindowAggr: windowExpr=[[row_number() PARTITION BY [DATA.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - TableScan: DATA - "# + @" + Projection: row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$0, row_number() PARTITION BY [data.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS EXPR$1 + WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[row_number() PARTITION BY [data.A] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: data + " ); // Trigger execution to ensure plan validity @@ -164,18 +164,18 @@ mod tests { settings.bind(|| { assert_snapshot!( plan, - @r#" + @" Projection: left.A, left.[UUID] AS C, right.D, Utf8(NULL) AS [UUID] AS E Left Join: left.A = right.A SubqueryAlias: left Union - Projection: A.A, Utf8(NULL) AS [UUID] - TableScan: A - Projection: B.A, CAST(B.C AS Utf8) - TableScan: B + Projection: a.A, Utf8(NULL) AS [UUID] + TableScan: a + Projection: b.A, CAST(b.C AS Utf8) + TableScan: b SubqueryAlias: right - TableScan: C - "# + TableScan: c + " ); }); diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index c8cc3fe9940ce..4a95580cce932 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -34,7 +34,7 @@ mod tests { table_name: &str, fields: Vec<(&str, DataType, bool)>, ) -> Result { - let table_ref = TableReference::bare(table_name); + let table_ref = TableReference::parse_str(table_name); let fields: Vec<(Option, Arc)> = fields .into_iter() .map(|pair| { @@ -69,10 +69,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: DATA.a, DATA.b - TableScan: DATA - "# + @" + Projection: data.a, data.b + TableScan: data + " ); Ok(()) } @@ -92,10 +92,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: DATA.a, DATA.b - TableScan: DATA projection=[a, b] - "# + @" + Projection: data.a, data.b + TableScan: data projection=[a, b] + " ); Ok(()) } @@ -117,10 +117,10 @@ mod tests { assert_snapshot!( plan, - @r#" - Projection: DATA.a, DATA.b - TableScan: DATA projection=[a, b] - "# + @" + Projection: data.a, data.b + TableScan: data projection=[a, b] + " ); Ok(()) } diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index f84594312b634..db582f221148b 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -124,22 +124,10 @@ pub mod test { } fn collect_named_table(&mut self, read: &ReadRel, nt: &NamedTable) -> Result<()> { - let table_reference = match nt.names.len() { - 0 => { - panic!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, + let table_reference = if nt.names.is_empty() { + panic!("No table name found in NamedTable"); + } else { + TableReference::parse_str(&nt.names.join(".")) }; let substrait_schema =