Skip to content

Commit 6210816

Browse files
committed
Add plan assert as well
1 parent ea8e467 commit 6210816

2 files changed

Lines changed: 44 additions & 40 deletions

File tree

datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
101101
}) => match (left.as_ref(), right.as_ref()) {
102102
(Expr::Column(l), Expr::Column(r)) => match op {
103103
Operator::Eq => eq_keys.push((l.clone(), r.clone())),
104-
Operator::IsNotDistinctFrom => indistinct_keys.push((l.clone(), r.clone())),
104+
Operator::IsNotDistinctFrom => {
105+
indistinct_keys.push((l.clone(), r.clone()))
106+
}
105107
_ => unreachable!(),
106108
},
107109
_ => accum_filters.push(expr.clone()),
@@ -110,26 +112,26 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
110112
}
111113
}
112114

113-
let (join_keys, null_equality) = match (eq_keys.is_empty(), indistinct_keys.is_empty())
114-
{
115-
// Mixed: use eq_keys as equijoin keys, demote indistinct keys to filter
116-
(false, false) => {
117-
for (l, r) in &indistinct_keys {
118-
accum_filters.push(Expr::BinaryExpr(BinaryExpr {
119-
left: Box::new(Expr::Column(l.clone())),
120-
op: Operator::IsNotDistinctFrom,
121-
right: Box::new(Expr::Column(r.clone())),
122-
}));
115+
let (join_keys, null_equality) =
116+
match (eq_keys.is_empty(), indistinct_keys.is_empty()) {
117+
// Mixed: use eq_keys as equijoin keys, demote indistinct keys to filter
118+
(false, false) => {
119+
for (l, r) in &indistinct_keys {
120+
accum_filters.push(Expr::BinaryExpr(BinaryExpr {
121+
left: Box::new(Expr::Column(l.clone())),
122+
op: Operator::IsNotDistinctFrom,
123+
right: Box::new(Expr::Column(r.clone())),
124+
}));
125+
}
126+
(eq_keys, NullEquality::NullEqualsNothing)
123127
}
124-
(eq_keys, NullEquality::NullEqualsNothing)
125-
}
126-
// Only eq keys
127-
(false, true) => (eq_keys, NullEquality::NullEqualsNothing),
128-
// Only indistinct keys
129-
(true, false) => (indistinct_keys, NullEquality::NullEqualsNull),
130-
// No keys at all
131-
(true, true) => (vec![], NullEquality::NullEqualsNothing),
132-
};
128+
// Only eq keys
129+
(false, true) => (eq_keys, NullEquality::NullEqualsNothing),
130+
// Only indistinct keys
131+
(true, false) => (indistinct_keys, NullEquality::NullEqualsNull),
132+
// No keys at all
133+
(true, true) => (vec![], NullEquality::NullEqualsNothing),
134+
};
133135

134136
let join_filter = accum_filters.into_iter().reduce(Expr::and);
135137
(join_keys, null_equality, join_filter)
@@ -224,7 +226,8 @@ mod tests {
224226
assert_eq!(null_eq, NullEquality::NullEqualsNothing);
225227

226228
// The IsNotDistinctFrom predicate should be demoted to the filter.
227-
let filter = filter.expect("filter should contain the demoted indistinct predicate");
229+
let filter =
230+
filter.expect("filter should contain the demoted indistinct predicate");
228231
match &filter {
229232
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
230233
assert_eq!(*op, Operator::IsNotDistinctFrom);

datafusion/substrait/tests/cases/consumer_integration.rs

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -763,37 +763,38 @@ mod tests {
763763
Ok(())
764764
}
765765

766-
/// Regression test: a Substrait join expression containing both `equal` and
767-
/// `is_not_distinct_from` (as Spark can produce) must preserve the
768-
/// null-safe semantics of `IS NOT DISTINCT FROM` by demoting it to the
769-
/// join filter when mixed with regular equality keys.
770-
///
771-
/// The plan is loaded from a JSON-encoded Substrait protobuf to exercise
772-
/// the full consumer path (`from_substrait_plan` → `from_join_rel`).
766+
/// Regression: a Substrait join with both `equal` and `is_not_distinct_from`
767+
/// must demote `IS NOT DISTINCT FROM` to the join filter (matching the SQL
768+
/// planner behavior tested in `join_is_not_distinct_from.slt:179-205`).
773769
#[tokio::test]
774770
async fn test_mixed_join_equal_and_indistinct_from_substrait_plan() -> Result<()> {
771+
let plan_str =
772+
test_plan_to_string("mixed_join_equal_and_indistinct.json").await?;
773+
// Eq becomes the equijoin key; IS NOT DISTINCT FROM is demoted to filter.
774+
assert_snapshot!(
775+
plan_str,
776+
@r#"
777+
Projection: left.id, left.val, left.comment, right.id AS id0, right.val AS val0, right.comment AS comment0
778+
Inner Join: left.id = right.id Filter: left.val IS NOT DISTINCT FROM right.val
779+
SubqueryAlias: left
780+
Values: (Utf8("1"), Utf8("a"), Utf8("c1")), (Utf8("2"), Utf8("b"), Utf8("c2")), (Utf8("3"), Utf8(NULL), Utf8("c3")), (Utf8("4"), Utf8(NULL), Utf8("c4")), (Utf8("5"), Utf8("e"), Utf8("c5"))...
781+
SubqueryAlias: right
782+
Values: (Utf8("1"), Utf8("a"), Utf8("c1")), (Utf8("2"), Utf8("b"), Utf8("c2")), (Utf8("3"), Utf8(NULL), Utf8("c3")), (Utf8("4"), Utf8(NULL), Utf8("c4")), (Utf8("5"), Utf8("e"), Utf8("c5"))...
783+
"#
784+
);
785+
786+
// Also execute to verify NULL=NULL rows (ids 3,4) are preserved.
775787
let path = "tests/testdata/test_plans/mixed_join_equal_and_indistinct.json";
776788
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
777789
File::open(path).expect("file not found"),
778790
))
779791
.expect("failed to parse json");
780-
781792
let ctx = SessionContext::new();
782793
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
783-
784-
// Execute and count rows.
785-
// Both tables have 6 identical rows; rows 3 and 4 have val=NULL.
786-
// With correct handling, IS NOT DISTINCT FROM is demoted to the join
787-
// filter, so NULL=NULL matches and all 6 rows appear in the output.
788794
let df = ctx.execute_logical_plan(plan).await?;
789795
let results = df.collect().await?;
790796
let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
791-
792-
assert_eq!(
793-
total_rows, 6,
794-
"Expected 6 rows (including NULL=NULL matches via IS NOT DISTINCT FROM), \
795-
got {total_rows}. Mixed equal/is_not_distinct_from lost null-safe semantics."
796-
);
797+
assert_eq!(total_rows, 6);
797798

798799
Ok(())
799800
}

0 commit comments

Comments
 (0)