Skip to content
74 changes: 59 additions & 15 deletions pkg/formatter/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,15 @@ func renderSelect(s *ast.SelectStatement, opts ast.FormatOptions) string {
sb.WriteString(" ")
froms := make([]string, len(s.From))
for i := range s.From {
froms[i] = tableRefSQL(&s.From[i])
froms[i] = tableRefSQL(&s.From[i], f)
}
sb.WriteString(strings.Join(froms, ", "))
}

for _, j := range s.Joins {
j := j
sb.WriteString(f.clauseSep())
sb.WriteString(joinSQL(&j))
sb.WriteString(joinSQL(&j, f))
}

if s.Sample != nil {
Expand Down Expand Up @@ -400,7 +400,7 @@ func renderUpdate(u *ast.UpdateStatement, opts ast.FormatOptions) string {
sb.WriteString(" ")
froms := make([]string, len(u.From))
for i := range u.From {
froms[i] = tableRefSQL(&u.From[i])
froms[i] = tableRefSQL(&u.From[i], f)
}
sb.WriteString(strings.Join(froms, ", "))
}
Expand Down Expand Up @@ -452,7 +452,7 @@ func renderDelete(d *ast.DeleteStatement, opts ast.FormatOptions) string {
sb.WriteString(" ")
usings := make([]string, len(d.Using))
for i := range d.Using {
usings[i] = tableRefSQL(&d.Using[i])
usings[i] = tableRefSQL(&d.Using[i], f)
}
sb.WriteString(strings.Join(usings, ", "))
}
Expand Down Expand Up @@ -883,7 +883,7 @@ func renderMerge(m *ast.MergeStatement, opts ast.FormatOptions) string {

sb.WriteString(f.kw("MERGE INTO"))
sb.WriteString(" ")
sb.WriteString(tableRefSQL(&m.TargetTable))
sb.WriteString(tableRefSQL(&m.TargetTable, f))
if m.TargetAlias != "" {
sb.WriteString(" ")
sb.WriteString(m.TargetAlias)
Expand All @@ -892,7 +892,7 @@ func renderMerge(m *ast.MergeStatement, opts ast.FormatOptions) string {
sb.WriteString(f.clauseSep())
sb.WriteString(f.kw("USING"))
sb.WriteString(" ")
sb.WriteString(tableRefSQL(&m.SourceTable))
sb.WriteString(tableRefSQL(&m.SourceTable, f))
if m.SourceAlias != "" {
sb.WriteString(" ")
sb.WriteString(m.SourceAlias)
Expand Down Expand Up @@ -1173,11 +1173,13 @@ func orderBySQL(orders []ast.OrderByExpression) string {
return strings.Join(parts, ", ")
}

// tableRefSQL renders a TableReference.
func tableRefSQL(t *ast.TableReference) string {
// tableRefSQL renders a TableReference. The formatter is threaded through
// so PIVOT/UNPIVOT/FOR/IN/AS/LATERAL keywords honor the caller's case policy.
func tableRefSQL(t *ast.TableReference, f *nodeFormatter) string {
var sb strings.Builder
if t.Lateral {
sb.WriteString("LATERAL ")
sb.WriteString(f.kw("LATERAL"))
sb.WriteString(" ")
}
if t.Subquery != nil {
sb.WriteString("(")
Expand All @@ -1186,8 +1188,46 @@ func tableRefSQL(t *ast.TableReference) string {
} else {
sb.WriteString(t.Name)
}
if t.Alias != "" {
if t.Pivot != nil {
sb.WriteString(" ")
sb.WriteString(f.kw("PIVOT"))
sb.WriteString(" (")
sb.WriteString(exprSQL(t.Pivot.AggregateFunction))
sb.WriteString(" ")
sb.WriteString(f.kw("FOR"))
sb.WriteString(" ")
sb.WriteString(t.Pivot.PivotColumn)
sb.WriteString(" ")
sb.WriteString(f.kw("IN"))
sb.WriteString(" (")
sb.WriteString(strings.Join(t.Pivot.InValues, ", "))
sb.WriteString("))")
}
if t.Unpivot != nil {
sb.WriteString(" ")
sb.WriteString(f.kw("UNPIVOT"))
sb.WriteString(" (")
sb.WriteString(t.Unpivot.ValueColumn)
sb.WriteString(" ")
sb.WriteString(f.kw("FOR"))
sb.WriteString(" ")
sb.WriteString(t.Unpivot.NameColumn)
sb.WriteString(" ")
sb.WriteString(f.kw("IN"))
sb.WriteString(" (")
sb.WriteString(strings.Join(t.Unpivot.InColumns, ", "))
sb.WriteString("))")
}
if t.Alias != "" {
// PIVOT/UNPIVOT aliases conventionally use AS to avoid ambiguity
// with the closing paren of the clause.
if t.Pivot != nil || t.Unpivot != nil {
sb.WriteString(" ")
sb.WriteString(f.kw("AS"))
sb.WriteString(" ")
} else {
sb.WriteString(" ")
}
sb.WriteString(t.Alias)
}
return sb.String()
Expand Down Expand Up @@ -1217,13 +1257,17 @@ func sampleSQL(s *ast.SampleClause, f *nodeFormatter) string {
}

// joinSQL renders a JOIN clause.
func joinSQL(j *ast.JoinClause) string {
func joinSQL(j *ast.JoinClause, f *nodeFormatter) string {
var sb strings.Builder
sb.WriteString(j.Type)
sb.WriteString(" JOIN ")
sb.WriteString(tableRefSQL(&j.Right))
sb.WriteString(f.kw(j.Type))
sb.WriteString(" ")
sb.WriteString(f.kw("JOIN"))
sb.WriteString(" ")
sb.WriteString(tableRefSQL(&j.Right, f))
if j.Condition != nil {
sb.WriteString(" ON ")
sb.WriteString(" ")
sb.WriteString(f.kw("ON"))
sb.WriteString(" ")
sb.WriteString(exprSQL(j.Condition))
}
return sb.String()
Expand Down
52 changes: 50 additions & 2 deletions pkg/sql/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ type TableReference struct {
// ForSystemTime is the MariaDB temporal table clause (10.3.4+).
// Example: SELECT * FROM t FOR SYSTEM_TIME AS OF '2024-01-01'
ForSystemTime *ForSystemTimeClause // MariaDB temporal query
// Pivot is the SQL Server / Oracle PIVOT clause for row-to-column transformation.
// Example: SELECT * FROM t PIVOT (SUM(sales) FOR region IN ([North], [South])) AS pvt
Pivot *PivotClause
// Unpivot is the SQL Server / Oracle UNPIVOT clause for column-to-row transformation.
// Example: SELECT * FROM t UNPIVOT (sales FOR region IN (north_sales, south_sales)) AS unpvt
Unpivot *UnpivotClause
}

func (t *TableReference) statementNode() {}
Expand All @@ -244,10 +250,17 @@ func (t TableReference) TokenLiteral() string {
return "subquery"
}
func (t TableReference) Children() []Node {
var nodes []Node
if t.Subquery != nil {
return []Node{t.Subquery}
nodes = append(nodes, t.Subquery)
}
return nil
if t.Pivot != nil {
nodes = append(nodes, t.Pivot)
}
if t.Unpivot != nil {
nodes = append(nodes, t.Unpivot)
}
return nodes
}

// OrderByExpression represents an ORDER BY clause element with direction and NULL ordering
Expand Down Expand Up @@ -1969,6 +1982,41 @@ func (c ForSystemTimeClause) Children() []Node {
return nodes
}

// PivotClause represents the SQL Server / Oracle PIVOT operator for row-to-column
// transformation in a FROM clause.
//
// PIVOT (SUM(sales) FOR region IN ([North], [South], [East], [West])) AS pvt
type PivotClause struct {
AggregateFunction Expression // The aggregate function, e.g. SUM(sales)
PivotColumn string // The column used for pivoting, e.g. region
InValues []string // The values to pivot on, e.g. [North], [South]
Pos models.Location // Source position of the PIVOT keyword
}

func (p *PivotClause) expressionNode() {}
func (p PivotClause) TokenLiteral() string { return "PIVOT" }
func (p PivotClause) Children() []Node {
if p.AggregateFunction != nil {
return []Node{p.AggregateFunction}
}
return nil
}

// UnpivotClause represents the SQL Server / Oracle UNPIVOT operator for column-to-row
// transformation in a FROM clause.
//
// UNPIVOT (sales FOR region IN (north_sales, south_sales, east_sales)) AS unpvt
type UnpivotClause struct {
ValueColumn string // The target value column, e.g. sales
NameColumn string // The target name column, e.g. region
InColumns []string // The source columns to unpivot, e.g. north_sales, south_sales
Pos models.Location // Source position of the UNPIVOT keyword
}

func (u *UnpivotClause) expressionNode() {}
func (u UnpivotClause) TokenLiteral() string { return "UNPIVOT" }
func (u UnpivotClause) Children() []Node { return nil }

// PeriodDefinition represents a PERIOD FOR clause in CREATE TABLE.
//
// PERIOD FOR app_time (start_col, end_col)
Expand Down
Loading
Loading