package sql

import (
	"expvar"
	"fmt"
	"math"
	"math/rand/v2"
	"strconv"
	"strings"
	"time"

	"github.com/rqlite/rqlite/v9/command/proto"
	"github.com/rqlite/rqlite/v9/internal/random"
	"github.com/rqlite/sql"
	rsql "github.com/rqlite/sql"
)

// Keys of the counters published in the "sql-processor" expvar map.
const (
	numRewrittenStmts = "num_rewritten_stmts"
	numParserPanics   = "num_parser_panics"
)

// stats holds the expvar counters for the SQL processor.
var stats *expvar.Map

func init() {
	stats = expvar.NewMap("sql-processor")
	ResetStats()
}

// ResetStats clears the expvar map for this module and re-registers every
// counter with a value of zero. Mostly for test purposes.
func ResetStats() {
	stats.Init()
	for _, key := range []string{numRewrittenStmts, numParserPanics} {
		stats.Add(key, 0)
	}
}

// Process inspects each of the given SQL statements and rewrites it in place
// where required.
//
// When rwrand is set, calls to RANDOM() and RANDOMBLOB() are replaced with
// concrete values; when rwtime is set, 'now' arguments to the date and time
// functions are replaced with a concrete value. A statement with a RETURNING
// clause is flagged as a query (ForceQuery) so that its result set can be
// returned to the client, and an EXPLAIN statement is flagged via SqlExplain.
//
// Statements that the cheap substring pre-filters show to be uninteresting are
// never parsed. Statements that fail to parse or to rewrite are skipped rather
// than reported. A panic raised while processing is recovered, counted, and
// returned as an error.
func Process(stmts []*proto.Statement, rwrand, rwtime bool) (retErr error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		stats.Add(numParserPanics, 1)
		retErr = fmt.Errorf("panic during SQL processing: %v", r)
	}()

	for _, stmt := range stmts {
		lc := strings.ToLower(stmt.Sql)

		// Parsing is comparatively expensive, so only do it when the
		// statement text hints that there is something to act on.
		needsParse := (rwtime && ContainsTime(lc)) ||
			(rwrand && ContainsRandom(lc)) ||
			ContainsReturning(lc) ||
			ContainsExplain(lc)
		if !needsParse {
			continue
		}

		ast, err := rsql.NewParser(strings.NewReader(stmt.Sql)).ParseStatement()
		if err != nil {
			// Best effort: leave statements the parser rejects untouched.
			continue
		}
		_, isExplain := ast.(*sql.ExplainStatement)
		stmt.SqlExplain = isExplain

		rw := NewRewriter()
		rw.RewriteRand, rw.RewriteTime = rwrand, rwtime
		out, changed, forceQuery, err := rw.Do(ast)
		if err != nil {
			continue
		}
		if changed {
			stats.Add(numRewrittenStmts, 1)
			stmt.Sql = out.String()
		}
		stmt.ForceQuery = forceQuery
	}
	return nil
}

// ContainsTime returns true if the statement contains a time-related function.
// The function performs a lower-case comparison so it is up to the caller to // ensure the statement is lower-cased. func ContainsTime(stmt string) bool { // Since this is a simple substring search, it also matches datetime( // and strftime(. targets := []string{"time(", "date(", "julianday(", "unixepoch(", "timediff("} for _, target := range targets { if strings.Contains(stmt, target) { return true } } return false } // ContainsRandom returns true if the statement contains a random-related function. // The function performs a lower-case comparison so it is up to the caller to // ensure the statement is lower-cased. func ContainsRandom(stmt string) bool { targets := []string{"random(", "randomblob("} for _, target := range targets { if strings.Contains(stmt, target) { return true } } return false } // ContainsReturning returns true if the statement contains a RETURNING clause. // The function performs a lower-case comparison so it is up to the caller to // ensure the statement is lower-cased. func ContainsReturning(stmt string) bool { return strings.Contains(stmt, "returning ") } // ContainsExplain returns true if the statement contains an EXPLAIN clause. // The function performs a lower-case comparison so it is up to the caller to // ensure the statement is lower-cased. func ContainsExplain(stmt string) bool { return strings.Contains(stmt, "explain ") } // Rewriter rewrites SQL statements. type Rewriter struct { RewriteRand bool RewriteTime bool randFn func() int64 nowFn func() time.Time orderedBy bool modified bool returning bool } // NewRewriter returns a new Rewriter. This object is not thread // safe, and should not be shared between goroutines. func NewRewriter() *Rewriter { return &Rewriter{ RewriteRand: true, RewriteTime: true, randFn: rand.Int64, nowFn: time.Now, } } // Do rewrites the provided statement. If the statement is rewritten, the second return value is true. 
func (rw *Rewriter) Do(stmt sql.Statement) (sql.Statement, bool, bool, error) { rw.modified = false node, err := sql.Walk(rw, stmt) if err != nil { return nil, false, false, err } return node.(sql.Statement), rw.modified, rw.returning, nil } func (rw *Rewriter) Visit(node sql.Node) (w sql.Visitor, n sql.Node, err error) { retNode := node switch n := retNode.(type) { case *sql.ReturningClause: rw.returning = true case *sql.OrderingTerm: // NO random() rewriting past this point. rw.orderedBy = true return rw, node, nil case *sql.Call: // If used, ensure the value is same for the duration of the statement jd := julianDayAsNumberLit(rw.nowFn()) if rw.RewriteTime && len(n.Args) > 0 && (strings.EqualFold(n.Name.Name, "date") || strings.EqualFold(n.Name.Name, "time") || strings.EqualFold(n.Name.Name, "datetime") || strings.EqualFold(n.Name.Name, "julianday") || strings.EqualFold(n.Name.Name, "unixepoch")) { if isNow(n.Args[0]) { n.Args[0] = jd } rw.modified = true } else if rw.RewriteTime && len(n.Args) > 1 && strings.EqualFold(n.Name.Name, "strftime") { if isNow(n.Args[1]) { n.Args[1] = jd } rw.modified = true } else if rw.RewriteTime && len(n.Args) > 1 && strings.EqualFold(n.Name.Name, "timediff") { if isNow(n.Args[0]) { n.Args[0] = jd } if isNow(n.Args[1]) { n.Args[1] = jd } rw.modified = true } else if !rw.orderedBy && rw.RewriteRand && strings.EqualFold(n.Name.Name, "random") { retNode = &sql.NumberLit{Value: strconv.Itoa(int(rw.randFn()))} rw.modified = true } else if !rw.orderedBy && rw.RewriteRand && strings.EqualFold(n.Name.Name, "randomblob") { if len(n.Args) == 1 { lit, ok := n.Args[0].(*sql.NumberLit) if !ok { break } n, err := strconv.Atoi(lit.Value) if err != nil { break } retNode = &sql.BlobLit{Value: fmt.Sprintf(`%X`, random.Bytes(max(n, 1)))} rw.modified = true } } } return rw, retNode, nil } func (rw *Rewriter) VisitEnd(node sql.Node) (sql.Node, error) { switch node.(type) { case *sql.OrderingTerm: rw.orderedBy = false } return node, nil } func isNow(e 
sql.Expr) bool { if i, ok := e.(*sql.Ident); ok { return strings.EqualFold(i.Name, "now") } else if s, ok := e.(*sql.StringLit); ok { return strings.EqualFold(s.Value, "now") } return false } func julianDayAsNumberLit(t time.Time) *sql.NumberLit { return &sql.NumberLit{Value: fmt.Sprintf("%f", julianDay(t))} } func julianDay(t time.Time) float64 { year := t.Year() month := int(t.Month()) day := t.Day() hour := t.Hour() minute := t.Minute() second := t.Second() nanosecond := t.Nanosecond() // Adjust for months January and February if month <= 2 { year-- month += 12 } // Calculate the Julian Day Number A := year / 100 B := 2 - A + A/4 // Convert time to fractional day fractionalDay := (float64(hour) + float64(minute)/60 + (float64(second)+float64(nanosecond)/1e9)/3600) / 24.0 // Use math.Floor to correctly handle the integer parts jd := math.Floor(365.25*float64(year+4716)) + math.Floor(30.6001*float64(month+1)) + float64(day) + float64(B) - 1524.5 + fractionalDay return jd }