Skip to content

Commit 48dcbdf

Browse files
committed
fix: type bugs
1 parent e5dfa2b commit 48dcbdf

File tree

3 files changed

+365
-44
lines changed

3 files changed

+365
-44
lines changed

internal/compiler/parse_test.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,36 @@ const mysqlInListQuery = `/* name: FooByList :many */
5454
SELECT a, b FROM foo WHERE foo.a IN (?, ?);
5555
`
5656

57+
const starExpansionSeriesSchema = `
58+
CREATE TABLE alertreport (
59+
eventdate date
60+
);
61+
`
62+
63+
const starExpansionSeriesQuery = `-- name: CountAlertReportBy :many
64+
select DATE_TRUNC($1,ts)::text as datetime,coalesce(count,0) as count from
65+
(
66+
SELECT DATE_TRUNC($1,eventdate) as hr ,count(*)
67+
FROM alertreport
68+
where eventdate between $2 and $3
69+
GROUP BY 1
70+
) AS cnt
71+
right outer join ( SELECT * FROM generate_series ( $2, $3, CONCAT('1 ',$1)::interval) AS ts ) as dte
72+
on DATE_TRUNC($1, ts ) = cnt.hr
73+
order by 1 asc;
74+
`
75+
76+
const lowerSwitchedOrderSchema = `
77+
CREATE TABLE foo (
78+
bar text not null,
79+
bat text not null
80+
);
81+
`
82+
83+
const lowerSwitchedOrderQuery = `-- name: LowerSwitchedOrder :many
84+
SELECT bar FROM foo WHERE bar = $1 AND bat = LOWER($2);
85+
`
86+
5787
type stubAnalyzer struct {
5888
analyze func(context.Context, ast.Node, string, []string, *named.ParamSet) (*analysispb.Analysis, error)
5989
}
@@ -126,6 +156,66 @@ func newMySQLInListCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
126156
}, stmts[0].Raw
127157
}
128158

159+
func newStarExpansionSeriesCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
160+
t.Helper()
161+
162+
parser := postgresql.NewParser()
163+
catalog := postgresql.NewCatalog()
164+
165+
schema, err := parser.Parse(strings.NewReader(starExpansionSeriesSchema))
166+
if err != nil {
167+
t.Fatal(err)
168+
}
169+
if err := catalog.Build(schema); err != nil {
170+
t.Fatal(err)
171+
}
172+
173+
stmts, err := parser.Parse(strings.NewReader(starExpansionSeriesQuery))
174+
if err != nil {
175+
t.Fatal(err)
176+
}
177+
if len(stmts) != 1 {
178+
t.Fatalf("expected 1 statement, got %d", len(stmts))
179+
}
180+
181+
return &Compiler{
182+
conf: config.SQL{Engine: config.EnginePostgreSQL},
183+
parser: parser,
184+
catalog: catalog,
185+
selector: newDefaultSelector(),
186+
}, stmts[0].Raw
187+
}
188+
189+
func newLowerSwitchedOrderCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
190+
t.Helper()
191+
192+
parser := postgresql.NewParser()
193+
catalog := postgresql.NewCatalog()
194+
195+
schema, err := parser.Parse(strings.NewReader(lowerSwitchedOrderSchema))
196+
if err != nil {
197+
t.Fatal(err)
198+
}
199+
if err := catalog.Build(schema); err != nil {
200+
t.Fatal(err)
201+
}
202+
203+
stmts, err := parser.Parse(strings.NewReader(lowerSwitchedOrderQuery))
204+
if err != nil {
205+
t.Fatal(err)
206+
}
207+
if len(stmts) != 1 {
208+
t.Fatalf("expected 1 statement, got %d", len(stmts))
209+
}
210+
211+
return &Compiler{
212+
conf: config.SQL{Engine: config.EnginePostgreSQL},
213+
parser: parser,
214+
catalog: catalog,
215+
selector: newDefaultSelector(),
216+
}, stmts[0].Raw
217+
}
218+
129219
func assertBatchParameterNames(t *testing.T, params []Parameter) {
130220
t.Helper()
131221

@@ -168,6 +258,73 @@ func assertBatchParameterNames(t *testing.T, params []Parameter) {
168258
}
169259
}
170260

261+
func assertStarExpansionSeriesParameterNames(t *testing.T, params []Parameter) {
262+
t.Helper()
263+
264+
checks := []struct {
265+
idx int
266+
number int
267+
name string
268+
typ string
269+
}{
270+
{idx: 0, number: 1, name: "date_trunc", typ: "text"},
271+
{idx: 1, number: 2, name: "eventdate", typ: "date"},
272+
{idx: 2, number: 3, name: "eventdate", typ: "date"},
273+
}
274+
if len(params) != len(checks) {
275+
t.Fatalf("expected %d params, got %d", len(checks), len(params))
276+
}
277+
278+
for _, check := range checks {
279+
param := params[check.idx]
280+
if param.Number != check.number {
281+
t.Fatalf("param %d number mismatch: got %d want %d", check.idx, param.Number, check.number)
282+
}
283+
if param.Column == nil {
284+
t.Fatalf("param %d column is nil", check.idx)
285+
}
286+
if param.Column.Name != check.name {
287+
t.Fatalf("param %d name mismatch: got %q want %q", check.idx, param.Column.Name, check.name)
288+
}
289+
if param.Column.DataType != check.typ && param.Column.DataType != "pg_catalog."+check.typ {
290+
t.Fatalf("param %d type mismatch: got %q want %q or %q", check.idx, param.Column.DataType, check.typ, "pg_catalog."+check.typ)
291+
}
292+
}
293+
}
294+
295+
func assertLowerSwitchedOrderParams(t *testing.T, params []Parameter) {
296+
t.Helper()
297+
298+
checks := []struct {
299+
idx int
300+
number int
301+
name string
302+
typ string
303+
}{
304+
{idx: 0, number: 1, name: "bar", typ: "text"},
305+
{idx: 1, number: 2, name: "lower", typ: "text"},
306+
}
307+
if len(params) != len(checks) {
308+
t.Fatalf("expected %d params, got %d", len(checks), len(params))
309+
}
310+
311+
for _, check := range checks {
312+
param := params[check.idx]
313+
if param.Number != check.number {
314+
t.Fatalf("param %d number mismatch: got %d want %d", check.idx, param.Number, check.number)
315+
}
316+
if param.Column == nil {
317+
t.Fatalf("param %d column is nil", check.idx)
318+
}
319+
if param.Column.Name != check.name {
320+
t.Fatalf("param %d name mismatch: got %q want %q", check.idx, param.Column.Name, check.name)
321+
}
322+
if param.Column.DataType != check.typ && param.Column.DataType != "pg_catalog."+check.typ {
323+
t.Fatalf("param %d type mismatch: got %q want %q or %q", check.idx, param.Column.DataType, check.typ, "pg_catalog."+check.typ)
324+
}
325+
}
326+
}
327+
171328
func TestInferQueryPreservesInsertSelectParamNamesWithCTEAndMixedParams(t *testing.T) {
172329
t.Parallel()
173330

@@ -247,3 +404,50 @@ func TestInferQueryPreservesDistinctMySQLInListParams(t *testing.T) {
247404
}
248405
}
249406
}
407+
408+
func TestInferQueryPreservesStarExpansionSeriesParamNames(t *testing.T) {
409+
t.Parallel()
410+
411+
comp, raw := newStarExpansionSeriesCompiler(t)
412+
anlys, err := comp.inferQuery(raw, starExpansionSeriesQuery)
413+
if err != nil {
414+
t.Fatal(err)
415+
}
416+
if anlys == nil {
417+
t.Fatal("expected non-nil analysis")
418+
}
419+
420+
assertStarExpansionSeriesParameterNames(t, anlys.Parameters)
421+
}
422+
423+
func TestParseQueryManagedDBPreservesStarExpansionSeriesParamNames(t *testing.T) {
424+
t.Parallel()
425+
426+
comp, raw := newStarExpansionSeriesCompiler(t)
427+
comp.analyzer = stubAnalyzer{analyze: func(_ context.Context, _ ast.Node, _ string, _ []string, _ *named.ParamSet) (*analysispb.Analysis, error) {
428+
return &analysispb.Analysis{Params: []*analysispb.Parameter{
429+
{Number: 1, Column: &analysispb.Column{DataType: "pg_catalog.text"}},
430+
{Number: 2, Column: &analysispb.Column{DataType: "pg_catalog.date"}},
431+
{Number: 3, Column: &analysispb.Column{DataType: "pg_catalog.date"}},
432+
}}, nil
433+
}}
434+
435+
query, err := comp.parseQuery(raw, starExpansionSeriesQuery, opts.Parser{})
436+
if err != nil {
437+
t.Fatal(err)
438+
}
439+
440+
assertStarExpansionSeriesParameterNames(t, query.Params)
441+
}
442+
443+
func TestParseQueryPreservesLowerSwitchedOrderParamTypes(t *testing.T) {
444+
t.Parallel()
445+
446+
comp, raw := newLowerSwitchedOrderCompiler(t)
447+
query, err := comp.parseQuery(raw, lowerSwitchedOrderQuery, opts.Parser{})
448+
if err != nil {
449+
t.Fatal(err)
450+
}
451+
452+
assertLowerSwitchedOrderParams(t, query.Params)
453+
}

internal/compiler/resolve.go

Lines changed: 108 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,107 @@ func (comp *Compiler) incompatibleParamRefError(ref paramRef, existing, incoming
104104
}
105105
}
106106

107+
func sameTypeName(a, b *ast.TypeName) bool {
108+
if a == nil || b == nil {
109+
return a == nil && b == nil
110+
}
111+
return a.Catalog == b.Catalog &&
112+
a.Schema == b.Schema &&
113+
a.Name == b.Name &&
114+
arrayDims(a) == arrayDims(b)
115+
}
116+
117+
func isPolymorphicTypeName(t *ast.TypeName) bool {
118+
if t == nil {
119+
return false
120+
}
121+
122+
switch strings.ToLower(t.Name) {
123+
case "any",
124+
"anyarray",
125+
"anycompatible",
126+
"anycompatiblearray",
127+
"anycompatiblemultirange",
128+
"anycompatiblenonarray",
129+
"anycompatiblerange",
130+
"anyelement",
131+
"anyenum",
132+
"anymultirange",
133+
"anynonarray",
134+
"anyrange":
135+
return true
136+
default:
137+
return false
138+
}
139+
}
140+
141+
func funcCallArg(fn catalog.Function, idx int, namedArg string) *catalog.Argument {
142+
args := fn.InArgs()
143+
if namedArg != "" {
144+
for _, arg := range args {
145+
if arg.Name == namedArg {
146+
return arg
147+
}
148+
}
149+
return nil
150+
}
151+
if idx < 0 || idx >= len(args) {
152+
return nil
153+
}
154+
return args[idx]
155+
}
156+
157+
func funcCallArgMetadata(funcs []catalog.Function, idx int, namedArg string) (string, *ast.TypeName) {
158+
var (
159+
found bool
160+
name string
161+
nameConsistent = true
162+
typ *ast.TypeName
163+
typeConsistent = true
164+
concreteType *ast.TypeName
165+
concreteFound bool
166+
concreteAgree = true
167+
)
168+
169+
for _, fn := range funcs {
170+
arg := funcCallArg(fn, idx, namedArg)
171+
if arg == nil {
172+
continue
173+
}
174+
if !found {
175+
found = true
176+
name = arg.Name
177+
typ = arg.Type
178+
} else if name != arg.Name {
179+
nameConsistent = false
180+
}
181+
if found && !sameTypeName(typ, arg.Type) {
182+
typeConsistent = false
183+
}
184+
if !isPolymorphicTypeName(arg.Type) {
185+
if !concreteFound {
186+
concreteType = arg.Type
187+
concreteFound = true
188+
} else if !sameTypeName(concreteType, arg.Type) {
189+
concreteAgree = false
190+
}
191+
}
192+
}
193+
194+
if !found {
195+
return "", nil
196+
}
197+
if !nameConsistent {
198+
name = ""
199+
}
200+
if concreteFound && concreteAgree {
201+
typ = concreteType
202+
} else if !typeConsistent {
203+
typ = nil
204+
}
205+
return name, typ
206+
}
207+
107208
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
108209
c := comp.catalog
109210

@@ -519,7 +620,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
519620
}
520621

521622
case *ast.FuncCall:
522-
fun, err := c.ResolveFuncCall(n)
623+
funcs, err := c.ResolveFuncCalls(n)
624+
var fun *catalog.Function
523625
if err != nil {
524626
// Synthesize a function on the fly to avoid returning with an error
525627
// for an unknown Postgres function (e.g. defined in an extension)
@@ -534,6 +636,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
534636
Args: args,
535637
ReturnType: &ast.TypeName{Name: "any"},
536638
}
639+
funcs = []catalog.Function{*fun}
640+
} else {
641+
fun = &funcs[0]
537642
}
538643

539644
var added bool
@@ -592,24 +697,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
592697
continue
593698
}
594699

595-
var paramName string
596-
var paramType *ast.TypeName
597-
598-
if argName == "" {
599-
if i < len(fun.Args) {
600-
paramName = fun.Args[i].Name
601-
paramType = fun.Args[i].Type
602-
}
603-
} else {
700+
paramName, paramType := funcCallArgMetadata(funcs, i, argName)
701+
if argName != "" {
604702
paramName = argName
605-
for _, arg := range fun.Args {
606-
if arg.Name == argName {
607-
paramType = arg.Type
608-
}
609-
}
610-
if paramType == nil {
611-
panic(fmt.Sprintf("named argument %s has no type", paramName))
612-
}
613703
}
614704
if paramName == "" {
615705
paramName = funcName

0 commit comments

Comments
 (0)