@@ -32,7 +32,7 @@ use datafusion_expr::{Expr, ExpressionPlacement, Projection};
3232
3333use crate :: optimizer:: ApplyOrder ;
3434use crate :: push_down_filter:: replace_cols_by_name;
35- use crate :: utils:: has_all_column_refs;
35+ use crate :: utils:: { ColumnReference , has_all_column_refs, schema_columns } ;
3636use crate :: { OptimizerConfig , OptimizerRule } ;
3737
3838/// Prefix for aliases generated by the extraction optimizer passes.
@@ -213,10 +213,11 @@ fn extract_from_plan(
213213 . collect ( ) ;
214214
215215 // Build per-input column sets for routing expressions to the correct input
216- let input_column_sets: Vec < std:: collections:: HashSet < Column > > = input_schemas
217- . iter ( )
218- . map ( |schema| schema_columns ( schema. as_ref ( ) ) )
219- . collect ( ) ;
216+ let input_column_sets: Vec < std:: collections:: HashSet < ColumnReference > > =
217+ input_schemas
218+ . iter ( )
219+ . map ( |schema| schema_columns ( schema. as_ref ( ) ) )
220+ . collect ( ) ;
220221
221222 // Transform expressions via map_expressions with routing
222223 let transformed = plan. map_expressions ( |expr| {
@@ -272,7 +273,7 @@ fn extract_from_plan(
272273/// in both sides of a join).
273274fn find_owning_input (
274275 expr : & Expr ,
275- input_column_sets : & [ std:: collections:: HashSet < Column > ] ,
276+ input_column_sets : & [ std:: collections:: HashSet < ColumnReference > ] ,
276277) -> Option < usize > {
277278 let mut found = None ;
278279 for ( idx, cols) in input_column_sets. iter ( ) . enumerate ( ) {
@@ -292,7 +293,7 @@ fn find_owning_input(
292293fn routing_extract (
293294 expr : Expr ,
294295 extractors : & mut [ LeafExpressionExtractor ] ,
295- input_column_sets : & [ std:: collections:: HashSet < Column > ] ,
296+ input_column_sets : & [ std:: collections:: HashSet < ColumnReference > ] ,
296297) -> Result < Transformed < Expr > > {
297298 expr. transform_down ( |e| {
298299 // Skip expressions already aliased with extracted expression pattern
@@ -340,19 +341,6 @@ fn routing_extract(
340341 } )
341342}
342343
343- /// Returns all columns in the schema (both qualified and unqualified forms)
344- fn schema_columns ( schema : & DFSchema ) -> std:: collections:: HashSet < Column > {
345- schema
346- . iter ( )
347- . flat_map ( |( qualifier, field) | {
348- [
349- Column :: new ( qualifier. cloned ( ) , field. name ( ) ) ,
350- Column :: new_unqualified ( field. name ( ) ) ,
351- ]
352- } )
353- . collect ( )
354- }
355-
356344/// Rewrites extraction pairs and column references from one qualifier
357345/// space to another.
358346///
@@ -1072,7 +1060,7 @@ fn route_to_inputs(
10721060 pairs : & [ ( Expr , String ) ] ,
10731061 columns : & IndexSet < Column > ,
10741062 node : & LogicalPlan ,
1075- input_column_sets : & [ std:: collections:: HashSet < Column > ] ,
1063+ input_column_sets : & [ std:: collections:: HashSet < ColumnReference > ] ,
10761064 input_schemas : & [ Arc < DFSchema > ] ,
10771065) -> Result < Option < Vec < ExtractionTarget > > > {
10781066 let num_inputs = input_schemas. len ( ) ;
@@ -1173,7 +1161,7 @@ fn try_push_into_inputs(
11731161 // Build per-input schemas and column sets for routing
11741162 let input_schemas: Vec < Arc < DFSchema > > =
11751163 inputs. iter ( ) . map ( |i| Arc :: clone ( i. schema ( ) ) ) . collect ( ) ;
1176- let input_column_sets: Vec < std:: collections:: HashSet < Column > > =
1164+ let input_column_sets: Vec < std:: collections:: HashSet < ColumnReference > > =
11771165 input_schemas. iter ( ) . map ( |s| schema_columns ( s) ) . collect ( ) ;
11781166
11791167 // Route pairs and columns to the appropriate inputs
@@ -2436,16 +2424,18 @@ mod tests {
24362424 // Simulate schema_columns output for two sides of a join where both
24372425 // have a "user" column — each set contains the qualified and
24382426 // unqualified form.
2439- let left_cols: HashSet < Column > = [
2440- Column :: new ( Some ( "test" ) , "user" ) ,
2441- Column :: new_unqualified ( "user" ) ,
2427+ let relation = "test" . into ( ) ;
2428+ let left_cols: HashSet < ColumnReference > = [
2429+ ColumnReference :: new ( Some ( & relation) , "user" ) ,
2430+ ColumnReference :: new_unqualified ( "user" ) ,
24422431 ]
24432432 . into_iter ( )
24442433 . collect ( ) ;
24452434
2446- let right_cols: HashSet < Column > = [
2447- Column :: new ( Some ( "right" ) , "user" ) ,
2448- Column :: new_unqualified ( "user" ) ,
2435+ let relation = "right" . into ( ) ;
2436+ let right_cols: HashSet < ColumnReference > = [
2437+ ColumnReference :: new ( Some ( & relation) , "user" ) ,
2438+ ColumnReference :: new_unqualified ( "user" ) ,
24492439 ]
24502440 . into_iter ( )
24512441 . collect ( ) ;
0 commit comments