From e028a0318c0d6cda78cb6c58929d0df0c9cbce7d Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Tue, 31 Mar 2026 22:25:34 +0530 Subject: [PATCH 1/6] feat(parser): add SQL Server PIVOT/UNPIVOT clause parsing (#456) Add support for SQL Server and Oracle PIVOT/UNPIVOT operators in FROM clauses. PIVOT transforms rows to columns via an aggregate function, while UNPIVOT performs the reverse column-to-row transformation. - Add PivotClause and UnpivotClause AST node types - Add Pivot/Unpivot fields to TableReference struct - Implement parsePivotClause/parseUnpivotClause in new pivot.go - Wire parsing into parseFromTableReference and parseJoinedTableRef - Add PIVOT/UNPIVOT to tokenizer keyword map for correct token typing - Update formatter to render PIVOT/UNPIVOT clauses - Enable testdata/mssql/11_pivot.sql and 12_unpivot.sql - Add 4 dedicated tests covering subquery+alias, plain table, AS alias Closes #456 Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/formatter/render.go | 18 +++ pkg/sql/ast/ast.go | 52 ++++++++- pkg/sql/parser/pivot.go | 184 ++++++++++++++++++++++++++++++ pkg/sql/parser/select_subquery.go | 78 +++++++++++++ pkg/sql/parser/tsql_test.go | 141 ++++++++++++++++++++++- pkg/sql/tokenizer/tokenizer.go | 3 + 6 files changed, 473 insertions(+), 3 deletions(-) create mode 100644 pkg/sql/parser/pivot.go diff --git a/pkg/formatter/render.go b/pkg/formatter/render.go index 77baa0f5..c1ea1501 100644 --- a/pkg/formatter/render.go +++ b/pkg/formatter/render.go @@ -1186,6 +1186,24 @@ func tableRefSQL(t *ast.TableReference) string { } else { sb.WriteString(t.Name) } + if t.Pivot != nil { + sb.WriteString(" PIVOT (") + sb.WriteString(exprSQL(t.Pivot.AggregateFunction)) + sb.WriteString(" FOR ") + sb.WriteString(t.Pivot.PivotColumn) + sb.WriteString(" IN (") + sb.WriteString(strings.Join(t.Pivot.InValues, ", ")) + sb.WriteString("))") + } + if t.Unpivot != nil { + sb.WriteString(" UNPIVOT (") + sb.WriteString(t.Unpivot.ValueColumn) + sb.WriteString(" FOR ") + sb.WriteString(t.Unpivot.NameColumn) + sb.WriteString(" IN (") + sb.WriteString(strings.Join(t.Unpivot.InColumns, ", ")) + sb.WriteString("))") + } if t.Alias != "" { sb.WriteString(" ") sb.WriteString(t.Alias) diff --git a/pkg/sql/ast/ast.go b/pkg/sql/ast/ast.go index a1b43060..c050427b 100644 --- a/pkg/sql/ast/ast.go +++ b/pkg/sql/ast/ast.go @@ -231,6 +231,12 @@ type TableReference struct { // ForSystemTime is the MariaDB temporal table clause (10.3.4+). // Example: SELECT * FROM t FOR SYSTEM_TIME AS OF '2024-01-01' ForSystemTime *ForSystemTimeClause // MariaDB temporal query + // Pivot is the SQL Server / Oracle PIVOT clause for row-to-column transformation. + // Example: SELECT * FROM t PIVOT (SUM(sales) FOR region IN ([North], [South])) AS pvt + Pivot *PivotClause + // Unpivot is the SQL Server / Oracle UNPIVOT clause for column-to-row transformation. + // Example: SELECT * FROM t UNPIVOT (sales FOR region IN (north_sales, south_sales)) AS unpvt + Unpivot *UnpivotClause } func (t *TableReference) statementNode() {} @@ -244,10 +250,17 @@ func (t TableReference) TokenLiteral() string { return "subquery" } func (t TableReference) Children() []Node { + var nodes []Node if t.Subquery != nil { - return []Node{t.Subquery} + nodes = append(nodes, t.Subquery) } - return nil + if t.Pivot != nil { + nodes = append(nodes, t.Pivot) + } + if t.Unpivot != nil { + nodes = append(nodes, t.Unpivot) + } + return nodes } // OrderByExpression represents an ORDER BY clause element with direction and NULL ordering @@ -1969,6 +1982,41 @@ func (c ForSystemTimeClause) Children() []Node { return nodes } +// PivotClause represents the SQL Server / Oracle PIVOT operator for row-to-column +// transformation in a FROM clause. +// +// PIVOT (SUM(sales) FOR region IN ([North], [South], [East], [West])) AS pvt +type PivotClause struct { + AggregateFunction Expression // The aggregate function, e.g. SUM(sales) + PivotColumn string // The column used for pivoting, e.g. region + InValues []string // The values to pivot on, e.g. [North], [South] + Pos models.Location // Source position of the PIVOT keyword +} + +func (p *PivotClause) expressionNode() {} +func (p PivotClause) TokenLiteral() string { return "PIVOT" } +func (p PivotClause) Children() []Node { + if p.AggregateFunction != nil { + return []Node{p.AggregateFunction} + } + return nil +} + +// UnpivotClause represents the SQL Server / Oracle UNPIVOT operator for column-to-row +// transformation in a FROM clause. +// +// UNPIVOT (sales FOR region IN (north_sales, south_sales, east_sales)) AS unpvt +type UnpivotClause struct { + ValueColumn string // The target value column, e.g. sales + NameColumn string // The target name column, e.g. region + InColumns []string // The source columns to unpivot, e.g. north_sales, south_sales + Pos models.Location // Source position of the UNPIVOT keyword +} + +func (u *UnpivotClause) expressionNode() {} +func (u UnpivotClause) TokenLiteral() string { return "UNPIVOT" } +func (u UnpivotClause) Children() []Node { return nil } + // PeriodDefinition represents a PERIOD FOR clause in CREATE TABLE. // // PERIOD FOR app_time (start_col, end_col) diff --git a/pkg/sql/parser/pivot.go b/pkg/sql/parser/pivot.go new file mode 100644 index 00000000..646d6da8 --- /dev/null +++ b/pkg/sql/parser/pivot.go @@ -0,0 +1,184 @@ +// Copyright 2026 GoSQLX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package parser - pivot.go +// SQL Server / Oracle PIVOT and UNPIVOT clause parsing. + +package parser + +import ( + "strings" + + "github.com/ajitpratap0/GoSQLX/pkg/models" + "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" +) + +// isPivotKeyword returns true if the current token is the PIVOT keyword. +func (p *Parser) isPivotKeyword() bool { + return p.isType(models.TokenTypeKeyword) && + strings.EqualFold(p.currentToken.Token.Value, "PIVOT") +} + +// isUnpivotKeyword returns true if the current token is the UNPIVOT keyword. +func (p *Parser) isUnpivotKeyword() bool { + return p.isType(models.TokenTypeKeyword) && + strings.EqualFold(p.currentToken.Token.Value, "UNPIVOT") +} + +// parsePivotClause parses PIVOT (aggregate FOR column IN (values)). +// The current token must be the PIVOT keyword. +func (p *Parser) parsePivotClause() (*ast.PivotClause, error) { + pos := p.currentLocation() + p.advance() // consume PIVOT + + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("( after PIVOT") + } + p.advance() // consume ( + + // Parse aggregate function expression (e.g. SUM(sales)) + aggFunc, err := p.parseExpression() + if err != nil { + return nil, err + } + + // Expect FOR keyword + if !p.isType(models.TokenTypeFor) { + return nil, p.expectedError("FOR in PIVOT clause") + } + p.advance() // consume FOR + + // Parse pivot column name + if !p.isIdentifier() { + return nil, p.expectedError("column name after FOR in PIVOT") + } + pivotCol := p.currentToken.Token.Value + p.advance() + + // Expect IN keyword + if !p.isType(models.TokenTypeIn) { + return nil, p.expectedError("IN in PIVOT clause") + } + p.advance() // consume IN + + // Expect opening parenthesis for value list + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("( after IN in PIVOT") + } + p.advance() // consume ( + + // Parse IN values — identifiers (possibly bracket-quoted in SQL Server) + var inValues []string + for !p.isType(models.TokenTypeRParen) && !p.isType(models.TokenTypeEOF) { + if !p.isIdentifier() && !p.isType(models.TokenTypeNumber) && !p.isStringLiteral() { + return nil, p.expectedError("value in PIVOT IN list") + } + inValues = append(inValues, p.currentToken.Token.Value) + p.advance() + if p.isType(models.TokenTypeComma) { + p.advance() + } + } + + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(") to close PIVOT IN list") + } + p.advance() // close IN list ) + + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(") to close PIVOT clause") + } + p.advance() // close PIVOT ) + + return &ast.PivotClause{ + AggregateFunction: aggFunc, + PivotColumn: pivotCol, + InValues: inValues, + Pos: pos, + }, nil +} + +// parseUnpivotClause parses UNPIVOT (value_col FOR name_col IN (columns)). +// The current token must be the UNPIVOT keyword. +func (p *Parser) parseUnpivotClause() (*ast.UnpivotClause, error) { + pos := p.currentLocation() + p.advance() // consume UNPIVOT + + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("( after UNPIVOT") + } + p.advance() // consume ( + + // Parse value column name + if !p.isIdentifier() { + return nil, p.expectedError("value column name in UNPIVOT") + } + valueCol := p.currentToken.Token.Value + p.advance() + + // Expect FOR keyword + if !p.isType(models.TokenTypeFor) { + return nil, p.expectedError("FOR in UNPIVOT clause") + } + p.advance() // consume FOR + + // Parse name column + if !p.isIdentifier() { + return nil, p.expectedError("name column after FOR in UNPIVOT") + } + nameCol := p.currentToken.Token.Value + p.advance() + + // Expect IN keyword + if !p.isType(models.TokenTypeIn) { + return nil, p.expectedError("IN in UNPIVOT clause") + } + p.advance() // consume IN + + // Expect opening parenthesis for column list + if !p.isType(models.TokenTypeLParen) { + return nil, p.expectedError("( after IN in UNPIVOT") + } + p.advance() // consume ( + + // Parse IN columns + var cols []string + for !p.isType(models.TokenTypeRParen) && !p.isType(models.TokenTypeEOF) { + if !p.isIdentifier() { + return nil, p.expectedError("column name in UNPIVOT IN list") + } + cols = append(cols, p.currentToken.Token.Value) + p.advance() + if p.isType(models.TokenTypeComma) { + p.advance() + } + } + + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(") to close UNPIVOT IN list") + } + p.advance() // close IN list ) + + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(") to close UNPIVOT clause") + } + p.advance() // close UNPIVOT ) + + return &ast.UnpivotClause{ + ValueColumn: valueCol, + NameColumn: nameCol, + InColumns: cols, + Pos: pos, + }, nil +} diff --git a/pkg/sql/parser/select_subquery.go b/pkg/sql/parser/select_subquery.go index f1eff0ae..af20082d 100644 --- a/pkg/sql/parser/select_subquery.go +++ b/pkg/sql/parser/select_subquery.go @@ -125,6 +125,46 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) { } } + // SQL Server / Oracle PIVOT clause + if p.isPivotKeyword() { + pivot, err := p.parsePivotClause() + if err != nil { + return tableRef, err + } + tableRef.Pivot = pivot + // PIVOT result often has its own alias: PIVOT (...) AS pvt + if p.isType(models.TokenTypeAs) { + p.advance() // consume AS + if p.isIdentifier() { + tableRef.Alias = p.currentToken.Token.Value + p.advance() + } + } else if p.isIdentifier() { + tableRef.Alias = p.currentToken.Token.Value + p.advance() + } + } + + // SQL Server / Oracle UNPIVOT clause + if p.isUnpivotKeyword() { + unpivot, err := p.parseUnpivotClause() + if err != nil { + return tableRef, err + } + tableRef.Unpivot = unpivot + // UNPIVOT result alias: UNPIVOT (...) AS unpvt + if p.isType(models.TokenTypeAs) { + p.advance() // consume AS + if p.isIdentifier() { + tableRef.Alias = p.currentToken.Token.Value + p.advance() + } + } else if p.isIdentifier() { + tableRef.Alias = p.currentToken.Token.Value + p.advance() + } + } + return tableRef, nil } @@ -217,6 +257,44 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error } } + // SQL Server / Oracle PIVOT clause + if p.isPivotKeyword() { + pivot, err := p.parsePivotClause() + if err != nil { + return ref, err + } + ref.Pivot = pivot + if p.isType(models.TokenTypeAs) { + p.advance() + if p.isIdentifier() { + ref.Alias = p.currentToken.Token.Value + p.advance() + } + } else if p.isIdentifier() { + ref.Alias = p.currentToken.Token.Value + p.advance() + } + } + + // SQL Server / Oracle UNPIVOT clause + if p.isUnpivotKeyword() { + unpivot, err := p.parseUnpivotClause() + if err != nil { + return ref, err + } + ref.Unpivot = unpivot + if p.isType(models.TokenTypeAs) { + p.advance() + if p.isIdentifier() { + ref.Alias = p.currentToken.Token.Value + p.advance() + } + } else if p.isIdentifier() { + ref.Alias = p.currentToken.Token.Value + p.advance() + } + } + return ref, nil } diff --git a/pkg/sql/parser/tsql_test.go b/pkg/sql/parser/tsql_test.go index 7a105e60..f815aed9 100644 --- a/pkg/sql/parser/tsql_test.go +++ b/pkg/sql/parser/tsql_test.go @@ -355,6 +355,8 @@ func TestTSQL_TestdataFiles(t *testing.T) { "08_window_row_number.sql": true, "09_window_rank.sql": true, "10_window_lag_lead.sql": true, + "11_pivot.sql": true, + "12_unpivot.sql": true, "13_cross_apply.sql": true, "14_outer_apply.sql": true, "15_try_convert.sql": true, @@ -391,9 +393,146 @@ func TestTSQL_TestdataFiles(t *testing.T) { t.Errorf("expected %s to parse, got: %v", name, parseErr) } } else { - // These are known to not yet be supported (PIVOT, UNPIVOT, OPTION) + // These are known to not yet be supported (OPTION) t.Logf("%s: %v (not yet supported)", name, parseErr) } }) } } + +func TestTSQL_PivotBasic(t *testing.T) { + sql := `SELECT * FROM ( + SELECT product, region, sales + FROM sales_data +) AS SourceTable +PIVOT ( + SUM(sales) FOR region IN ([North], [South], [East], [West]) +) AS PivotTable` + + result, err := ParseWithDialect(sql, keywords.DialectSQLServer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(result.Statements)) + } + stmt, ok := result.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", result.Statements[0]) + } + if len(stmt.From) == 0 { + t.Fatal("expected at least one FROM reference") + } + ref := stmt.From[0] + if ref.Pivot == nil { + t.Fatal("expected Pivot clause on table reference") + } + if ref.Pivot.PivotColumn != "region" { + t.Errorf("expected pivot column 'region', got %q", ref.Pivot.PivotColumn) + } + if len(ref.Pivot.InValues) != 4 { + t.Errorf("expected 4 IN values, got %d", len(ref.Pivot.InValues)) + } + expected := []string{"North", "South", "East", "West"} + for i, v := range expected { + if i < len(ref.Pivot.InValues) && ref.Pivot.InValues[i] != v { + t.Errorf("IN value [%d]: expected %q, got %q", i, v, ref.Pivot.InValues[i]) + } + } + if ref.Pivot.AggregateFunction == nil { + t.Error("expected aggregate function in PIVOT") + } + if ref.Alias != "PivotTable" { + t.Errorf("expected alias 'PivotTable', got %q", ref.Alias) + } +} + +func TestTSQL_UnpivotBasic(t *testing.T) { + sql := `SELECT product, region, sales FROM ( + SELECT product, north_sales, south_sales, east_sales, west_sales + FROM regional_sales +) AS SourceTable +UNPIVOT ( + sales FOR region IN (north_sales, south_sales, east_sales, west_sales) +) AS UnpivotTable` + + result, err := ParseWithDialect(sql, keywords.DialectSQLServer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(result.Statements)) + } + stmt, ok := result.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", result.Statements[0]) + } + if len(stmt.From) == 0 { + t.Fatal("expected at least one FROM reference") + } + ref := stmt.From[0] + if ref.Unpivot == nil { + t.Fatal("expected Unpivot clause on table reference") + } + if ref.Unpivot.ValueColumn != "sales" { + t.Errorf("expected value column 'sales', got %q", ref.Unpivot.ValueColumn) + } + if ref.Unpivot.NameColumn != "region" { + t.Errorf("expected name column 'region', got %q", ref.Unpivot.NameColumn) + } + if len(ref.Unpivot.InColumns) != 4 { + t.Errorf("expected 4 IN columns, got %d", len(ref.Unpivot.InColumns)) + } + expected := []string{"north_sales", "south_sales", "east_sales", "west_sales"} + for i, v := range expected { + if i < len(ref.Unpivot.InColumns) && ref.Unpivot.InColumns[i] != v { + t.Errorf("IN column [%d]: expected %q, got %q", i, v, ref.Unpivot.InColumns[i]) + } + } + if ref.Alias != "UnpivotTable" { + t.Errorf("expected alias 'UnpivotTable', got %q", ref.Alias) + } +} + +func TestTSQL_PivotWithoutAlias(t *testing.T) { + sql := `SELECT * FROM sales_data PIVOT (SUM(amount) FOR quarter IN (Q1, Q2, Q3, Q4))` + + result, err := ParseWithDialect(sql, keywords.DialectSQLServer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt, ok := result.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", result.Statements[0]) + } + ref := stmt.From[0] + if ref.Pivot == nil { + t.Fatal("expected Pivot clause") + } + if ref.Name != "sales_data" { + t.Errorf("expected table name 'sales_data', got %q", ref.Name) + } + if ref.Pivot.PivotColumn != "quarter" { + t.Errorf("expected pivot column 'quarter', got %q", ref.Pivot.PivotColumn) + } + if len(ref.Pivot.InValues) != 4 { + t.Errorf("expected 4 IN values, got %d", len(ref.Pivot.InValues)) + } +} + +func TestTSQL_PivotWithASAlias(t *testing.T) { + sql := `SELECT * FROM t PIVOT (COUNT(id) FOR status IN (active, inactive)) AS pvt` + + result, err := ParseWithDialect(sql, keywords.DialectSQLServer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt := result.Statements[0].(*ast.SelectStatement) + ref := stmt.From[0] + if ref.Pivot == nil { + t.Fatal("expected Pivot clause") + } + if ref.Alias != "pvt" { + t.Errorf("expected alias 'pvt', got %q", ref.Alias) + } +} diff --git a/pkg/sql/tokenizer/tokenizer.go b/pkg/sql/tokenizer/tokenizer.go index 39885055..8fbfda60 100644 --- a/pkg/sql/tokenizer/tokenizer.go +++ b/pkg/sql/tokenizer/tokenizer.go @@ -228,6 +228,9 @@ var keywordTokenTypes = map[string]models.TokenType{ // boundary detection. SETTINGS/FORMAT are common words and must NOT be here. "PREWHERE": models.TokenTypeKeyword, "FINAL": models.TokenTypeKeyword, + // SQL Server / Oracle PIVOT/UNPIVOT clause keywords + "PIVOT": models.TokenTypeKeyword, + "UNPIVOT": models.TokenTypeKeyword, } // Tokenizer provides high-performance SQL tokenization with zero-copy operations. From b8c058f0c779928462b975ce7ca95e34360c156e Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Wed, 1 Apr 2026 03:16:28 +0530 Subject: [PATCH 2/6] security: add CVE-2026-32285 to .trivyignore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CVE-2026-32285 affects github.com/buger/jsonparser v1.1.1, which is a transitive dependency via mark3labs/mcp-go → invopop/jsonschema → wk8/go-ordered-map → buger/jsonparser. No fixed version is available upstream. The package is not called directly by any GoSQLX code and risk is scoped to MCP JSON schema generation. Added to .trivyignore until a patched version is released. Fixes Trivy Repository Scan CI failures in PR #475 and #477. --- .trivyignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.trivyignore b/.trivyignore index f8e6bc1d..87b50db3 100644 --- a/.trivyignore +++ b/.trivyignore @@ -2,6 +2,13 @@ # Format: [expiry-date] [comment] # See: https://aquasecurity.github.io/trivy/latest/docs/configuration/filtering/#trivyignore +# CVE-2026-32285 — github.com/buger/jsonparser v1.1.1 +# Severity: HIGH/MEDIUM | No fixed version available (latest is v1.1.1, released 2021-01-08) +# Transitive dependency: mark3labs/mcp-go → invopop/jsonschema → wk8/go-ordered-map → buger/jsonparser +# Not called directly by any GoSQLX code. Risk is scoped to MCP JSON schema generation. +# Re-evaluate when buger/jsonparser releases a patched version or when mcp-go updates its dependency. +CVE-2026-32285 + # GHSA-6g7g-w4f8-9c9x — buger/jsonparser v1.1.1 # Severity: MEDIUM | No fixed version available (latest is v1.1.1, released 2021-01-08) # Transitive dependency: mark3labs/mcp-go → invopop/jsonschema → wk8/go-ordered-map → buger/jsonparser From a1e78c06894b20d7d2d99049bb002773be9a0bea Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Tue, 7 Apr 2026 11:03:00 +0530 Subject: [PATCH 3/6] fix(parser): scope PIVOT/UNPIVOT to SQL Server/Oracle dialects PIVOT and UNPIVOT were registered in the global tokenizer keyword map, which made them reserved words across every dialect. Queries like `SELECT pivot FROM users` then failed in PostgreSQL/MySQL/SQLite/ ClickHouse where these identifiers are perfectly legal. Changes: - Remove PIVOT/UNPIVOT from tokenizer keywordTokenTypes (they are now tokenized as identifiers in all dialects). - Gate isPivotKeyword/isUnpivotKeyword on the parser dialect (TSQL/ Oracle only) and accept identifier-typed tokens by value match. - Skip alias consumption in parseFromTableReference / parseJoinedTableRef when the upcoming identifier is a contextual PIVOT/UNPIVOT keyword, so the pivot-clause parser can claim it. - Fix unsafe single-value type assertion in TestTSQL_PivotWithASAlias to comply with the project's mandatory two-value form. - Add TestPivotIdentifierInNonTSQLDialects regression covering pivot/ unpivot as identifiers in PostgreSQL, MySQL, and SQLite. All parser/tokenizer/formatter tests pass with -race. --- pkg/sql/parser/pivot.go | 35 +++++++++++++++++++++++++------ pkg/sql/parser/select_subquery.go | 8 +++++-- pkg/sql/parser/tsql_test.go | 27 +++++++++++++++++++++++- pkg/sql/tokenizer/tokenizer.go | 7 ++++--- 4 files changed, 65 insertions(+), 12 deletions(-) diff --git a/pkg/sql/parser/pivot.go b/pkg/sql/parser/pivot.go index 646d6da8..b60bbc76 100644 --- a/pkg/sql/parser/pivot.go +++ b/pkg/sql/parser/pivot.go @@ -22,18 +22,41 @@ import ( "github.com/ajitpratap0/GoSQLX/pkg/models" "github.com/ajitpratap0/GoSQLX/pkg/sql/ast" + "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" ) -// isPivotKeyword returns true if the current token is the PIVOT keyword. +// pivotDialectAllowed reports whether PIVOT/UNPIVOT is a recognized clause +// for the parser's current dialect. PIVOT/UNPIVOT are SQL Server / Oracle +// extensions; in other dialects the words must remain valid identifiers. +func (p *Parser) pivotDialectAllowed() bool { + return p.dialect == string(keywords.DialectSQLServer) || + p.dialect == string(keywords.DialectOracle) +} + +// isPivotKeyword returns true if the current token is the contextual PIVOT +// keyword in a dialect that supports it. PIVOT is non-reserved, so it may +// arrive as either an identifier or a keyword token. func (p *Parser) isPivotKeyword() bool { - return p.isType(models.TokenTypeKeyword) && - strings.EqualFold(p.currentToken.Token.Value, "PIVOT") + if !p.pivotDialectAllowed() { + return false + } + t := p.currentToken.Token.Type + if t != models.TokenTypeKeyword && t != models.TokenTypeIdentifier { + return false + } + return strings.EqualFold(p.currentToken.Token.Value, "PIVOT") } -// isUnpivotKeyword returns true if the current token is the UNPIVOT keyword. +// isUnpivotKeyword mirrors isPivotKeyword for UNPIVOT. func (p *Parser) isUnpivotKeyword() bool { - return p.isType(models.TokenTypeKeyword) && - strings.EqualFold(p.currentToken.Token.Value, "UNPIVOT") + if !p.pivotDialectAllowed() { + return false + } + t := p.currentToken.Token.Type + if t != models.TokenTypeKeyword && t != models.TokenTypeIdentifier { + return false + } + return strings.EqualFold(p.currentToken.Token.Value, "UNPIVOT") } // parsePivotClause parses PIVOT (aggregate FOR column IN (values)). diff --git a/pkg/sql/parser/select_subquery.go b/pkg/sql/parser/select_subquery.go index af20082d..f4297053 100644 --- a/pkg/sql/parser/select_subquery.go +++ b/pkg/sql/parser/select_subquery.go @@ -87,7 +87,9 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) { // Check for table alias (required for derived tables, optional for regular tables). // Guard: in MariaDB, CONNECT followed by BY is a hierarchical query clause, not an alias. // Similarly, START followed by WITH is a hierarchical query seed, not an alias. - if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() { + // Don't consume PIVOT/UNPIVOT as a table alias — they are contextual + // keywords in SQL Server/Oracle and must reach the pivot-clause parser below. + if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() && !p.isPivotKeyword() && !p.isUnpivotKeyword() { if p.isType(models.TokenTypeAs) { p.advance() // Consume AS if !p.isIdentifier() { @@ -219,7 +221,9 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error // Optional alias. // Guard: in MariaDB, CONNECT followed by BY is a hierarchical query clause, not an alias. // Similarly, START followed by WITH is a hierarchical query seed, not an alias. - if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() { + // Don't consume PIVOT/UNPIVOT as a table alias — they are contextual + // keywords in SQL Server/Oracle and must reach the pivot-clause parser below. + if (p.isIdentifier() || p.isType(models.TokenTypeAs)) && !p.isMariaDBClauseStart() && !p.isPivotKeyword() && !p.isUnpivotKeyword() { if p.isType(models.TokenTypeAs) { p.advance() if !p.isIdentifier() { diff --git a/pkg/sql/parser/tsql_test.go b/pkg/sql/parser/tsql_test.go index f815aed9..a62298d0 100644 --- a/pkg/sql/parser/tsql_test.go +++ b/pkg/sql/parser/tsql_test.go @@ -527,7 +527,10 @@ func TestTSQL_PivotWithASAlias(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - stmt := result.Statements[0].(*ast.SelectStatement) + stmt, ok := result.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", result.Statements[0]) + } ref := stmt.From[0] if ref.Pivot == nil { t.Fatal("expected Pivot clause") @@ -536,3 +539,25 @@ func TestTSQL_PivotWithASAlias(t *testing.T) { t.Errorf("expected alias 'pvt', got %q", ref.Alias) } } + +// TestPivotIdentifierInNonTSQLDialects verifies PIVOT/UNPIVOT remain valid +// identifiers in dialects that don't recognize the contextual clause. +// Regression for global-tokenizer-keyword leak. +func TestPivotIdentifierInNonTSQLDialects(t *testing.T) { + cases := []struct { + dialect keywords.SQLDialect + sql string + }{ + {keywords.DialectPostgreSQL, "SELECT pivot FROM users"}, + {keywords.DialectPostgreSQL, "SELECT unpivot FROM users"}, + {keywords.DialectMySQL, "SELECT pivot, unpivot FROM t"}, + {keywords.DialectSQLite, "SELECT pivot FROM t AS pivot"}, + } + for _, tc := range cases { + t.Run(string(tc.dialect)+"_"+tc.sql, func(t *testing.T) { + if _, err := ParseWithDialect(tc.sql, tc.dialect); err != nil { + t.Fatalf("expected pivot/unpivot to parse as identifier in %s, got: %v", tc.dialect, err) + } + }) + } +} diff --git a/pkg/sql/tokenizer/tokenizer.go b/pkg/sql/tokenizer/tokenizer.go index 8fbfda60..8b01ad51 100644 --- a/pkg/sql/tokenizer/tokenizer.go +++ b/pkg/sql/tokenizer/tokenizer.go @@ -228,9 +228,10 @@ var keywordTokenTypes = map[string]models.TokenType{ // boundary detection. SETTINGS/FORMAT are common words and must NOT be here. "PREWHERE": models.TokenTypeKeyword, "FINAL": models.TokenTypeKeyword, - // SQL Server / Oracle PIVOT/UNPIVOT clause keywords - "PIVOT": models.TokenTypeKeyword, - "UNPIVOT": models.TokenTypeKeyword, + // NOTE: PIVOT/UNPIVOT are intentionally NOT in this global map. + // They are non-reserved contextual keywords (legal identifiers in + // PostgreSQL/MySQL/SQLite/ClickHouse), so the parser matches them + // by value+dialect via isPivotKeyword/isUnpivotKeyword instead. } // Tokenizer provides high-performance SQL tokenization with zero-copy operations. From ec3f4d4a1679a1e5f4f3fdc642fc0b39fd920d7f Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Tue, 7 Apr 2026 11:31:32 +0530 Subject: [PATCH 4/6] fix(parser): polish PIVOT/UNPIVOT round-trip and validation - Formatter emits AS before PIVOT/UNPIVOT aliases for clean round-trip. - Tokenizer records Quote='[' on SQL Server bracket-quoted identifiers; pivot parser uses renderQuotedIdent to preserve [North] etc. in PivotClause.InValues and UnpivotClause.InColumns. - Reject empty IN lists for both PIVOT and UNPIVOT. - Extract parsePivotAlias helper, collapsing four duplicated alias blocks in select_subquery.go. - Add TestPivotNegativeCases (missing parens, missing FOR/IN, empty IN) and TestPivotBracketedInValuesPreserved. Full test suite passes with -race. --- pkg/formatter/render.go | 8 ++++- pkg/sql/parser/pivot.go | 47 +++++++++++++++++++++++++++-- pkg/sql/parser/select_subquery.go | 46 +++------------------------- pkg/sql/parser/tsql_test.go | 50 ++++++++++++++++++++++++++++++- pkg/sql/tokenizer/tokenizer.go | 2 +- 5 files changed, 106 insertions(+), 47 deletions(-) diff --git a/pkg/formatter/render.go b/pkg/formatter/render.go index c1ea1501..c9c9d2a3 100644 --- a/pkg/formatter/render.go +++ b/pkg/formatter/render.go @@ -1205,7 +1205,13 @@ func tableRefSQL(t *ast.TableReference) string { sb.WriteString("))") } if t.Alias != "" { - sb.WriteString(" ") + // PIVOT/UNPIVOT aliases conventionally use AS to avoid ambiguity + // with the closing paren of the clause. + if t.Pivot != nil || t.Unpivot != nil { + sb.WriteString(" AS ") + } else { + sb.WriteString(" ") + } sb.WriteString(t.Alias) } return sb.String() diff --git a/pkg/sql/parser/pivot.go b/pkg/sql/parser/pivot.go index b60bbc76..8cf9f5ef 100644 --- a/pkg/sql/parser/pivot.go +++ b/pkg/sql/parser/pivot.go @@ -25,6 +25,43 @@ import ( "github.com/ajitpratap0/GoSQLX/pkg/sql/keywords" ) +// renderQuotedIdent reproduces the original delimiters of a quoted identifier +// token so the parsed value round-trips through the formatter. The tokenizer +// strips delimiters but records them in QuoteStyle. +func renderQuotedIdent(tok models.Token) string { + q := tok.Quote + if q == 0 && tok.Word != nil { + q = tok.Word.QuoteStyle + } + switch q { + case '[': + return "[" + tok.Value + "]" + case '"': + return "\"" + tok.Value + "\"" + case '`': + return "`" + tok.Value + "`" + } + return tok.Value +} + +// parsePivotAlias consumes an optional alias (with or without AS) following a +// PIVOT/UNPIVOT clause. Extracted to avoid four copies of the same logic in +// the table-reference and join paths. +func (p *Parser) parsePivotAlias(ref *ast.TableReference) { + if p.isType(models.TokenTypeAs) { + p.advance() // consume AS + if p.isIdentifier() { + ref.Alias = p.currentToken.Token.Value + p.advance() + } + return + } + if p.isIdentifier() { + ref.Alias = p.currentToken.Token.Value + p.advance() + } +} + // pivotDialectAllowed reports whether PIVOT/UNPIVOT is a recognized clause // for the parser's current dialect. PIVOT/UNPIVOT are SQL Server / Oracle // extensions; in other dialects the words must remain valid identifiers. @@ -107,7 +144,7 @@ func (p *Parser) parsePivotClause() (*ast.PivotClause, error) { if !p.isIdentifier() && !p.isType(models.TokenTypeNumber) && !p.isStringLiteral() { return nil, p.expectedError("value in PIVOT IN list") } - inValues = append(inValues, p.currentToken.Token.Value) + inValues = append(inValues, renderQuotedIdent(p.currentToken.Token)) p.advance() if p.isType(models.TokenTypeComma) { p.advance() @@ -117,6 +154,9 @@ func (p *Parser) parsePivotClause() (*ast.PivotClause, error) { if !p.isType(models.TokenTypeRParen) { return nil, p.expectedError(") to close PIVOT IN list") } + if len(inValues) == 0 { + return nil, p.expectedError("at least one value in PIVOT IN list") + } p.advance() // close IN list ) if !p.isType(models.TokenTypeRParen) { @@ -181,7 +221,7 @@ func (p *Parser) parseUnpivotClause() (*ast.UnpivotClause, error) { if !p.isIdentifier() { return nil, p.expectedError("column name in UNPIVOT IN list") } - cols = append(cols, p.currentToken.Token.Value) + cols = append(cols, renderQuotedIdent(p.currentToken.Token)) p.advance() if p.isType(models.TokenTypeComma) { p.advance() @@ -191,6 +231,9 @@ func (p *Parser) parseUnpivotClause() (*ast.UnpivotClause, error) { if !p.isType(models.TokenTypeRParen) { return nil, p.expectedError(") to close UNPIVOT IN list") } + if len(cols) == 0 { + return nil, p.expectedError("at least one column in UNPIVOT IN list") + } p.advance() // close IN list ) if !p.isType(models.TokenTypeRParen) { diff --git a/pkg/sql/parser/select_subquery.go b/pkg/sql/parser/select_subquery.go index f4297053..2d4a2040 100644 --- a/pkg/sql/parser/select_subquery.go +++ b/pkg/sql/parser/select_subquery.go @@ -134,17 +134,7 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) { return tableRef, err } tableRef.Pivot = pivot - // PIVOT result often has its own alias: PIVOT (...) AS pvt - if p.isType(models.TokenTypeAs) { - p.advance() // consume AS - if p.isIdentifier() { - tableRef.Alias = p.currentToken.Token.Value - p.advance() - } - } else if p.isIdentifier() { - tableRef.Alias = p.currentToken.Token.Value - p.advance() - } + p.parsePivotAlias(&tableRef) } // SQL Server / Oracle UNPIVOT clause @@ -154,17 +144,7 @@ func (p *Parser) parseFromTableReference() (ast.TableReference, error) { return tableRef, err } tableRef.Unpivot = unpivot - // UNPIVOT result alias: UNPIVOT (...) AS unpvt - if p.isType(models.TokenTypeAs) { - p.advance() // consume AS - if p.isIdentifier() { - tableRef.Alias = p.currentToken.Token.Value - p.advance() - } - } else if p.isIdentifier() { - tableRef.Alias = p.currentToken.Token.Value - p.advance() - } + p.parsePivotAlias(&tableRef) } return tableRef, nil @@ -268,16 +248,7 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error return ref, err } ref.Pivot = pivot - if p.isType(models.TokenTypeAs) { - p.advance() - if p.isIdentifier() { - ref.Alias = p.currentToken.Token.Value - p.advance() - } - } else if p.isIdentifier() { - ref.Alias = p.currentToken.Token.Value - p.advance() - } + p.parsePivotAlias(&ref) } // SQL Server / Oracle UNPIVOT clause @@ -287,16 +258,7 @@ func (p *Parser) parseJoinedTableRef(joinType string) (ast.TableReference, error return ref, err } ref.Unpivot = unpivot - if p.isType(models.TokenTypeAs) { - p.advance() - if p.isIdentifier() { - ref.Alias = p.currentToken.Token.Value - p.advance() - } - } else if p.isIdentifier() { - ref.Alias = p.currentToken.Token.Value - p.advance() - } + p.parsePivotAlias(&ref) } return ref, nil diff --git a/pkg/sql/parser/tsql_test.go b/pkg/sql/parser/tsql_test.go index a62298d0..675f65b3 100644 --- a/pkg/sql/parser/tsql_test.go +++ b/pkg/sql/parser/tsql_test.go @@ -433,7 +433,7 @@ PIVOT ( if len(ref.Pivot.InValues) != 4 { t.Errorf("expected 4 IN values, got %d", len(ref.Pivot.InValues)) } - expected := []string{"North", "South", "East", "West"} + expected := []string{"[North]", "[South]", "[East]", "[West]"} for i, v := range expected { if i < len(ref.Pivot.InValues) && ref.Pivot.InValues[i] != v { t.Errorf("IN value [%d]: expected %q, got %q", i, v, ref.Pivot.InValues[i]) @@ -540,6 +540,54 @@ func TestTSQL_PivotWithASAlias(t *testing.T) { } } +// TestPivotNegativeCases covers parser error paths for malformed PIVOT/UNPIVOT. +func TestPivotNegativeCases(t *testing.T) { + cases := []struct { + name string + sql string + }{ + {"missing_lparen", "SELECT * FROM t PIVOT SUM(x) FOR c IN (a))"}, + {"missing_for", "SELECT * FROM t PIVOT (SUM(x) c IN (a))"}, + {"missing_in", "SELECT * FROM t PIVOT (SUM(x) FOR c (a))"}, + {"missing_in_lparen", "SELECT * FROM t PIVOT (SUM(x) FOR c IN a)"}, + {"empty_in_list", "SELECT * FROM t PIVOT (SUM(x) FOR c IN ())"}, + {"unpivot_missing_for", "SELECT * FROM t UNPIVOT (v c IN (a))"}, + {"unpivot_empty_in_list", "SELECT * FROM t UNPIVOT (v FOR n IN ())"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := ParseWithDialect(tc.sql, keywords.DialectSQLServer) + if err == nil { + t.Fatalf("expected parse error, got nil for: %s", tc.sql) + } + }) + } +} + +// TestPivotBracketedInValuesPreserved verifies SQL Server bracket-quoted IN +// values survive parsing so the formatter can re-emit them. +func TestPivotBracketedInValuesPreserved(t *testing.T) { + sql := `SELECT * FROM sales PIVOT (SUM(amt) FOR region IN ([North], [South])) AS p` + result, err := ParseWithDialect(sql, keywords.DialectSQLServer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + stmt, ok := result.Statements[0].(*ast.SelectStatement) + if !ok { + t.Fatalf("expected SelectStatement, got %T", result.Statements[0]) + } + got := stmt.From[0].Pivot.InValues + want := []string{"[North]", "[South]"} + if len(got) != len(want) { + t.Fatalf("expected %d values, got %d (%v)", len(want), len(got), got) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("InValues[%d] = %q, want %q", i, got[i], want[i]) + } + } +} + // TestPivotIdentifierInNonTSQLDialects verifies PIVOT/UNPIVOT remain valid // identifiers in dialects that don't recognize the contextual clause. // Regression for global-tokenizer-keyword leak. diff --git a/pkg/sql/tokenizer/tokenizer.go b/pkg/sql/tokenizer/tokenizer.go index 8b01ad51..f823cda9 100644 --- a/pkg/sql/tokenizer/tokenizer.go +++ b/pkg/sql/tokenizer/tokenizer.go @@ -1283,7 +1283,7 @@ func (t *Tokenizer) readPunctuation() (models.Token, error) { ch, chSize := utf8.DecodeRune(t.input[t.pos.Index:]) if ch == ']' { t.pos.AdvanceRune(ch, chSize) // Consume ] - return models.Token{Type: models.TokenTypeIdentifier, Value: string(ident)}, nil + return models.Token{Type: models.TokenTypeIdentifier, Value: string(ident), Quote: '['}, nil } ident = append(ident, t.input[t.pos.Index:t.pos.Index+chSize]...) t.pos.AdvanceRune(ch, chSize) From 865e8c6a5f6b8b7b8aea8f35174404193845eb90 Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Tue, 7 Apr 2026 11:46:19 +0530 Subject: [PATCH 5/6] fix(parser): escape embedded delimiters and fix empty-IN error order - renderQuotedIdent now doubles embedded `]`, `"`, and `` ` `` per dialect convention so identifiers like [foo]bar] round-trip unambiguously. - Empty PIVOT/UNPIVOT IN list check now runs before the closing-`)` check so the user-facing error is "at least one value/column..." instead of the misleading ") to close ... IN list". - Clarify renderQuotedIdent comment to reference Token.Quote and Word.QuoteStyle as the actual sources. --- pkg/sql/parser/pivot.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/pkg/sql/parser/pivot.go b/pkg/sql/parser/pivot.go index 8cf9f5ef..ab03fdf6 100644 --- a/pkg/sql/parser/pivot.go +++ b/pkg/sql/parser/pivot.go @@ -27,7 +27,9 @@ import ( // renderQuotedIdent reproduces the original delimiters of a quoted identifier // token so the parsed value round-trips through the formatter. The tokenizer -// strips delimiters but records them in QuoteStyle. +// strips delimiters but records the style in Token.Quote (or, for word +// tokens, Word.QuoteStyle). Embedded delimiters are escaped per dialect: +// SQL Server doubles `]`, ANSI doubles `"`, MySQL doubles “ ` “. func renderQuotedIdent(tok models.Token) string { q := tok.Quote if q == 0 && tok.Word != nil { @@ -35,11 +37,11 @@ func renderQuotedIdent(tok models.Token) string { } switch q { case '[': - return "[" + tok.Value + "]" + return "[" + strings.ReplaceAll(tok.Value, "]", "]]") + "]" case '"': - return "\"" + tok.Value + "\"" + return "\"" + strings.ReplaceAll(tok.Value, "\"", "\"\"") + "\"" case '`': - return "`" + tok.Value + "`" + return "`" + strings.ReplaceAll(tok.Value, "`", "``") + "`" } return tok.Value } @@ -151,12 +153,12 @@ func (p *Parser) parsePivotClause() (*ast.PivotClause, error) { } } - if !p.isType(models.TokenTypeRParen) { - return nil, p.expectedError(") to close PIVOT IN list") - } if len(inValues) == 0 { return nil, p.expectedError("at least one value in PIVOT IN list") } + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(") to close PIVOT IN list") + } p.advance() // close IN list ) if !p.isType(models.TokenTypeRParen) { @@ -228,12 +230,12 @@ func (p *Parser) parseUnpivotClause() (*ast.UnpivotClause, error) { } } - if !p.isType(models.TokenTypeRParen) { - return nil, p.expectedError(") to close UNPIVOT IN list") - } if len(cols) == 0 { return nil, p.expectedError("at least one column in UNPIVOT IN list") } + if !p.isType(models.TokenTypeRParen) { + return nil, p.expectedError(") to close UNPIVOT IN list") + } p.advance() // close IN list ) if !p.isType(models.TokenTypeRParen) { From dd0215b6f64127d7fdcae366538e4426f5bbd925 Mon Sep 17 00:00:00 2001 From: Ajit Pratap Singh Date: Tue, 7 Apr 2026 12:26:05 +0530 Subject: [PATCH 6/6] refactor(formatter): thread nodeFormatter through tableRefSQL and joinSQL Previously these package-level renderers hardcoded keyword literals (PIVOT, UNPIVOT, FOR, IN, LATERAL, JOIN, ON) which bypassed the caller's case policy (f.kw). Thread *nodeFormatter into both functions and route every keyword through f.kw so uppercase/lowercase options apply uniformly across FROM, JOIN, MERGE, DELETE USING, and UPDATE FROM paths. Addresses claude-review feedback on PR #477. All formatter, parser, and tokenizer tests pass with -race. --- pkg/formatter/render.go | 62 +++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/pkg/formatter/render.go b/pkg/formatter/render.go index c9c9d2a3..29669012 100644 --- a/pkg/formatter/render.go +++ b/pkg/formatter/render.go @@ -227,7 +227,7 @@ func renderSelect(s *ast.SelectStatement, opts ast.FormatOptions) string { sb.WriteString(" ") froms := make([]string, len(s.From)) for i := range s.From { - froms[i] = tableRefSQL(&s.From[i]) + froms[i] = tableRefSQL(&s.From[i], f) } sb.WriteString(strings.Join(froms, ", ")) } @@ -235,7 +235,7 @@ func renderSelect(s *ast.SelectStatement, opts ast.FormatOptions) string { for _, j := range s.Joins { j := j sb.WriteString(f.clauseSep()) - sb.WriteString(joinSQL(&j)) + sb.WriteString(joinSQL(&j, f)) } if s.Sample != nil { @@ -400,7 +400,7 @@ func renderUpdate(u *ast.UpdateStatement, opts ast.FormatOptions) string { sb.WriteString(" ") froms := make([]string, len(u.From)) for i := range u.From { - froms[i] = tableRefSQL(&u.From[i]) + froms[i] = tableRefSQL(&u.From[i], f) } sb.WriteString(strings.Join(froms, ", ")) } @@ -452,7 +452,7 @@ func renderDelete(d *ast.DeleteStatement, opts ast.FormatOptions) string { sb.WriteString(" ") usings := make([]string, len(d.Using)) for i := range d.Using { - usings[i] = tableRefSQL(&d.Using[i]) + usings[i] = tableRefSQL(&d.Using[i], f) } sb.WriteString(strings.Join(usings, ", ")) } @@ -883,7 +883,7 @@ func renderMerge(m *ast.MergeStatement, opts ast.FormatOptions) string { sb.WriteString(f.kw("MERGE INTO")) sb.WriteString(" ") - sb.WriteString(tableRefSQL(&m.TargetTable)) + sb.WriteString(tableRefSQL(&m.TargetTable, f)) if m.TargetAlias != "" { sb.WriteString(" ") sb.WriteString(m.TargetAlias) @@ -892,7 +892,7 @@ func renderMerge(m *ast.MergeStatement, opts ast.FormatOptions) string { sb.WriteString(f.clauseSep()) sb.WriteString(f.kw("USING")) sb.WriteString(" ") - sb.WriteString(tableRefSQL(&m.SourceTable)) + sb.WriteString(tableRefSQL(&m.SourceTable, f)) if m.SourceAlias != "" { sb.WriteString(" ") sb.WriteString(m.SourceAlias) @@ -1173,11 +1173,13 @@ func orderBySQL(orders []ast.OrderByExpression) string { return strings.Join(parts, ", ") } -// tableRefSQL renders a TableReference. -func tableRefSQL(t *ast.TableReference) string { +// tableRefSQL renders a TableReference. The formatter is threaded through +// so PIVOT/UNPIVOT/FOR/IN/AS/LATERAL keywords honor the caller's case policy. +func tableRefSQL(t *ast.TableReference, f *nodeFormatter) string { var sb strings.Builder if t.Lateral { - sb.WriteString("LATERAL ") + sb.WriteString(f.kw("LATERAL")) + sb.WriteString(" ") } if t.Subquery != nil { sb.WriteString("(") @@ -1187,20 +1189,32 @@ func tableRefSQL(t *ast.TableReference) string { sb.WriteString(t.Name) } if t.Pivot != nil { - sb.WriteString(" PIVOT (") + sb.WriteString(" ") + sb.WriteString(f.kw("PIVOT")) + sb.WriteString(" (") sb.WriteString(exprSQL(t.Pivot.AggregateFunction)) - sb.WriteString(" FOR ") + sb.WriteString(" ") + sb.WriteString(f.kw("FOR")) + sb.WriteString(" ") sb.WriteString(t.Pivot.PivotColumn) - sb.WriteString(" IN (") + sb.WriteString(" ") + sb.WriteString(f.kw("IN")) + sb.WriteString(" (") sb.WriteString(strings.Join(t.Pivot.InValues, ", ")) sb.WriteString("))") } if t.Unpivot != nil { - sb.WriteString(" UNPIVOT (") + sb.WriteString(" ") + sb.WriteString(f.kw("UNPIVOT")) + sb.WriteString(" (") sb.WriteString(t.Unpivot.ValueColumn) - sb.WriteString(" FOR ") + sb.WriteString(" ") + sb.WriteString(f.kw("FOR")) + sb.WriteString(" ") sb.WriteString(t.Unpivot.NameColumn) - sb.WriteString(" IN (") + sb.WriteString(" ") + sb.WriteString(f.kw("IN")) + sb.WriteString(" (") sb.WriteString(strings.Join(t.Unpivot.InColumns, ", ")) sb.WriteString("))") } @@ -1208,7 +1222,9 @@ func tableRefSQL(t *ast.TableReference) string { // PIVOT/UNPIVOT aliases conventionally use AS to avoid ambiguity // with the closing paren of the clause. if t.Pivot != nil || t.Unpivot != nil { - sb.WriteString(" AS ") + sb.WriteString(" ") + sb.WriteString(f.kw("AS")) + sb.WriteString(" ") } else { sb.WriteString(" ") } @@ -1241,13 +1257,17 @@ func sampleSQL(s *ast.SampleClause, f *nodeFormatter) string { } // joinSQL renders a JOIN clause. -func joinSQL(j *ast.JoinClause) string { +func joinSQL(j *ast.JoinClause, f *nodeFormatter) string { var sb strings.Builder - sb.WriteString(j.Type) - sb.WriteString(" JOIN ") - sb.WriteString(tableRefSQL(&j.Right)) + sb.WriteString(f.kw(j.Type)) + sb.WriteString(" ") + sb.WriteString(f.kw("JOIN")) + sb.WriteString(" ") + sb.WriteString(tableRefSQL(&j.Right, f)) if j.Condition != nil { - sb.WriteString(" ON ") + sb.WriteString(" ") + sb.WriteString(f.kw("ON")) + sb.WriteString(" ") sb.WriteString(exprSQL(j.Condition)) } return sb.String()