@@ -54,6 +54,36 @@ const mysqlInListQuery = `/* name: FooByList :many */
5454SELECT a, b FROM foo WHERE foo.a IN (?, ?);
5555`
5656
57+ const starExpansionSeriesSchema = `
58+ CREATE TABLE alertreport (
59+ eventdate date
60+ );
61+ `
62+
63+ const starExpansionSeriesQuery = `-- name: CountAlertReportBy :many
64+ select DATE_TRUNC($1,ts)::text as datetime,coalesce(count,0) as count from
65+ (
66+ SELECT DATE_TRUNC($1,eventdate) as hr ,count(*)
67+ FROM alertreport
68+ where eventdate between $2 and $3
69+ GROUP BY 1
70+ ) AS cnt
71+ right outer join ( SELECT * FROM generate_series ( $2, $3, CONCAT('1 ',$1)::interval) AS ts ) as dte
72+ on DATE_TRUNC($1, ts ) = cnt.hr
73+ order by 1 asc;
74+ `
75+
76+ const lowerSwitchedOrderSchema = `
77+ CREATE TABLE foo (
78+ bar text not null,
79+ bat text not null
80+ );
81+ `
82+
83+ const lowerSwitchedOrderQuery = `-- name: LowerSwitchedOrder :many
84+ SELECT bar FROM foo WHERE bar = $1 AND bat = LOWER($2);
85+ `
86+
5787type stubAnalyzer struct {
5888 analyze func (context.Context , ast.Node , string , []string , * named.ParamSet ) (* analysispb.Analysis , error )
5989}
@@ -126,6 +156,66 @@ func newMySQLInListCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
126156 }, stmts [0 ].Raw
127157}
128158
159+ func newStarExpansionSeriesCompiler (t * testing.T ) (* Compiler , * ast.RawStmt ) {
160+ t .Helper ()
161+
162+ parser := postgresql .NewParser ()
163+ catalog := postgresql .NewCatalog ()
164+
165+ schema , err := parser .Parse (strings .NewReader (starExpansionSeriesSchema ))
166+ if err != nil {
167+ t .Fatal (err )
168+ }
169+ if err := catalog .Build (schema ); err != nil {
170+ t .Fatal (err )
171+ }
172+
173+ stmts , err := parser .Parse (strings .NewReader (starExpansionSeriesQuery ))
174+ if err != nil {
175+ t .Fatal (err )
176+ }
177+ if len (stmts ) != 1 {
178+ t .Fatalf ("expected 1 statement, got %d" , len (stmts ))
179+ }
180+
181+ return & Compiler {
182+ conf : config.SQL {Engine : config .EnginePostgreSQL },
183+ parser : parser ,
184+ catalog : catalog ,
185+ selector : newDefaultSelector (),
186+ }, stmts [0 ].Raw
187+ }
188+
189+ func newLowerSwitchedOrderCompiler (t * testing.T ) (* Compiler , * ast.RawStmt ) {
190+ t .Helper ()
191+
192+ parser := postgresql .NewParser ()
193+ catalog := postgresql .NewCatalog ()
194+
195+ schema , err := parser .Parse (strings .NewReader (lowerSwitchedOrderSchema ))
196+ if err != nil {
197+ t .Fatal (err )
198+ }
199+ if err := catalog .Build (schema ); err != nil {
200+ t .Fatal (err )
201+ }
202+
203+ stmts , err := parser .Parse (strings .NewReader (lowerSwitchedOrderQuery ))
204+ if err != nil {
205+ t .Fatal (err )
206+ }
207+ if len (stmts ) != 1 {
208+ t .Fatalf ("expected 1 statement, got %d" , len (stmts ))
209+ }
210+
211+ return & Compiler {
212+ conf : config.SQL {Engine : config .EnginePostgreSQL },
213+ parser : parser ,
214+ catalog : catalog ,
215+ selector : newDefaultSelector (),
216+ }, stmts [0 ].Raw
217+ }
218+
129219func assertBatchParameterNames (t * testing.T , params []Parameter ) {
130220 t .Helper ()
131221
@@ -168,6 +258,73 @@ func assertBatchParameterNames(t *testing.T, params []Parameter) {
168258 }
169259}
170260
261+ func assertStarExpansionSeriesParameterNames (t * testing.T , params []Parameter ) {
262+ t .Helper ()
263+
264+ checks := []struct {
265+ idx int
266+ number int
267+ name string
268+ typ string
269+ }{
270+ {idx : 0 , number : 1 , name : "date_trunc" , typ : "text" },
271+ {idx : 1 , number : 2 , name : "eventdate" , typ : "date" },
272+ {idx : 2 , number : 3 , name : "eventdate" , typ : "date" },
273+ }
274+ if len (params ) != len (checks ) {
275+ t .Fatalf ("expected %d params, got %d" , len (checks ), len (params ))
276+ }
277+
278+ for _ , check := range checks {
279+ param := params [check .idx ]
280+ if param .Number != check .number {
281+ t .Fatalf ("param %d number mismatch: got %d want %d" , check .idx , param .Number , check .number )
282+ }
283+ if param .Column == nil {
284+ t .Fatalf ("param %d column is nil" , check .idx )
285+ }
286+ if param .Column .Name != check .name {
287+ t .Fatalf ("param %d name mismatch: got %q want %q" , check .idx , param .Column .Name , check .name )
288+ }
289+ if param .Column .DataType != check .typ && param .Column .DataType != "pg_catalog." + check .typ {
290+ t .Fatalf ("param %d type mismatch: got %q want %q or %q" , check .idx , param .Column .DataType , check .typ , "pg_catalog." + check .typ )
291+ }
292+ }
293+ }
294+
295+ func assertLowerSwitchedOrderParams (t * testing.T , params []Parameter ) {
296+ t .Helper ()
297+
298+ checks := []struct {
299+ idx int
300+ number int
301+ name string
302+ typ string
303+ }{
304+ {idx : 0 , number : 1 , name : "bar" , typ : "text" },
305+ {idx : 1 , number : 2 , name : "lower" , typ : "text" },
306+ }
307+ if len (params ) != len (checks ) {
308+ t .Fatalf ("expected %d params, got %d" , len (checks ), len (params ))
309+ }
310+
311+ for _ , check := range checks {
312+ param := params [check .idx ]
313+ if param .Number != check .number {
314+ t .Fatalf ("param %d number mismatch: got %d want %d" , check .idx , param .Number , check .number )
315+ }
316+ if param .Column == nil {
317+ t .Fatalf ("param %d column is nil" , check .idx )
318+ }
319+ if param .Column .Name != check .name {
320+ t .Fatalf ("param %d name mismatch: got %q want %q" , check .idx , param .Column .Name , check .name )
321+ }
322+ if param .Column .DataType != check .typ && param .Column .DataType != "pg_catalog." + check .typ {
323+ t .Fatalf ("param %d type mismatch: got %q want %q or %q" , check .idx , param .Column .DataType , check .typ , "pg_catalog." + check .typ )
324+ }
325+ }
326+ }
327+
171328func TestInferQueryPreservesInsertSelectParamNamesWithCTEAndMixedParams (t * testing.T ) {
172329 t .Parallel ()
173330
@@ -247,3 +404,50 @@ func TestInferQueryPreservesDistinctMySQLInListParams(t *testing.T) {
247404 }
248405 }
249406}
407+
408+ func TestInferQueryPreservesStarExpansionSeriesParamNames (t * testing.T ) {
409+ t .Parallel ()
410+
411+ comp , raw := newStarExpansionSeriesCompiler (t )
412+ anlys , err := comp .inferQuery (raw , starExpansionSeriesQuery )
413+ if err != nil {
414+ t .Fatal (err )
415+ }
416+ if anlys == nil {
417+ t .Fatal ("expected non-nil analysis" )
418+ }
419+
420+ assertStarExpansionSeriesParameterNames (t , anlys .Parameters )
421+ }
422+
423+ func TestParseQueryManagedDBPreservesStarExpansionSeriesParamNames (t * testing.T ) {
424+ t .Parallel ()
425+
426+ comp , raw := newStarExpansionSeriesCompiler (t )
427+ comp .analyzer = stubAnalyzer {analyze : func (_ context.Context , _ ast.Node , _ string , _ []string , _ * named.ParamSet ) (* analysispb.Analysis , error ) {
428+ return & analysispb.Analysis {Params : []* analysispb.Parameter {
429+ {Number : 1 , Column : & analysispb.Column {DataType : "pg_catalog.text" }},
430+ {Number : 2 , Column : & analysispb.Column {DataType : "pg_catalog.date" }},
431+ {Number : 3 , Column : & analysispb.Column {DataType : "pg_catalog.date" }},
432+ }}, nil
433+ }}
434+
435+ query , err := comp .parseQuery (raw , starExpansionSeriesQuery , opts.Parser {})
436+ if err != nil {
437+ t .Fatal (err )
438+ }
439+
440+ assertStarExpansionSeriesParameterNames (t , query .Params )
441+ }
442+
443+ func TestParseQueryPreservesLowerSwitchedOrderParamTypes (t * testing.T ) {
444+ t .Parallel ()
445+
446+ comp , raw := newLowerSwitchedOrderCompiler (t )
447+ query , err := comp .parseQuery (raw , lowerSwitchedOrderQuery , opts.Parser {})
448+ if err != nil {
449+ t .Fatal (err )
450+ }
451+
452+ assertLowerSwitchedOrderParams (t , query .Params )
453+ }
0 commit comments