332 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			332 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package sanitize
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"encoding/hex"
 | |
| 	"fmt"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 	"unicode/utf8"
 | |
| )
 | |
| 
 | |
type (
	// Part is one segment of a lexed SQL statement. It holds either a string,
	// which is raw SQL text that is copied through unchanged, or an int, which
	// is the 1-based number of an argument placeholder.
	Part any

	// Query is a SQL statement broken into its Parts, in source order.
	Query struct {
		Parts []Part
	}
)
 | |
| 
 | |
// replacementcharacterwidth is the encoded width, in bytes, of U+FFFD, the
// Unicode replacement character.
//
// utf8.DecodeRuneInString signals a decoding failure by returning
// utf8.RuneError, but utf8.RuneError is the very same rune as U+FFFD. A
// returned utf8.RuneError whose width is 3 is therefore a genuine replacement
// character present in the input, not an error.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
 | |
| 
 | |
| func (q *Query) Sanitize(args ...any) (string, error) {
 | |
| 	argUse := make([]bool, len(args))
 | |
| 	buf := &bytes.Buffer{}
 | |
| 
 | |
| 	for _, part := range q.Parts {
 | |
| 		var str string
 | |
| 		switch part := part.(type) {
 | |
| 		case string:
 | |
| 			str = part
 | |
| 		case int:
 | |
| 			argIdx := part - 1
 | |
| 
 | |
| 			if argIdx < 0 {
 | |
| 				return "", fmt.Errorf("first sql argument must be > 0")
 | |
| 			}
 | |
| 
 | |
| 			if argIdx >= len(args) {
 | |
| 				return "", fmt.Errorf("insufficient arguments")
 | |
| 			}
 | |
| 			arg := args[argIdx]
 | |
| 			switch arg := arg.(type) {
 | |
| 			case nil:
 | |
| 				str = "null"
 | |
| 			case int64:
 | |
| 				str = strconv.FormatInt(arg, 10)
 | |
| 			case float64:
 | |
| 				str = strconv.FormatFloat(arg, 'f', -1, 64)
 | |
| 			case bool:
 | |
| 				str = strconv.FormatBool(arg)
 | |
| 			case []byte:
 | |
| 				str = QuoteBytes(arg)
 | |
| 			case string:
 | |
| 				str = QuoteString(arg)
 | |
| 			case time.Time:
 | |
| 				str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
 | |
| 			default:
 | |
| 				return "", fmt.Errorf("invalid arg type: %T", arg)
 | |
| 			}
 | |
| 			argUse[argIdx] = true
 | |
| 
 | |
| 			// Prevent SQL injection via Line Comment Creation
 | |
| 			// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
 | |
| 			str = " " + str + " "
 | |
| 		default:
 | |
| 			return "", fmt.Errorf("invalid Part type: %T", part)
 | |
| 		}
 | |
| 		buf.WriteString(str)
 | |
| 	}
 | |
| 
 | |
| 	for i, used := range argUse {
 | |
| 		if !used {
 | |
| 			return "", fmt.Errorf("unused argument: %d", i)
 | |
| 		}
 | |
| 	}
 | |
| 	return buf.String(), nil
 | |
| }
 | |
| 
 | |
| func NewQuery(sql string) (*Query, error) {
 | |
| 	l := &sqlLexer{
 | |
| 		src:     sql,
 | |
| 		stateFn: rawState,
 | |
| 	}
 | |
| 
 | |
| 	for l.stateFn != nil {
 | |
| 		l.stateFn = l.stateFn(l)
 | |
| 	}
 | |
| 
 | |
| 	query := &Query{Parts: l.parts}
 | |
| 
 | |
| 	return query, nil
 | |
| }
 | |
| 
 | |
// QuoteString returns str wrapped in single quotes, with every single quote
// inside str doubled. No other byte is altered.
func QuoteString(str string) string {
	var b strings.Builder
	b.Grow(len(str) + 2)

	b.WriteByte('\'')
	for i := 0; i < len(str); i++ {
		c := str[i]
		if c == '\'' {
			// A quote inside the literal is written twice.
			b.WriteByte('\'')
		}
		b.WriteByte(c)
	}
	b.WriteByte('\'')

	return b.String()
}
 | |
| 
 | |
// QuoteBytes returns buf as a single-quoted literal consisting of the prefix
// \x followed by the lowercase hexadecimal encoding of buf.
func QuoteBytes(buf []byte) string {
	const prefix = `'\x`

	out := make([]byte, len(prefix)+hex.EncodedLen(len(buf))+1)
	n := copy(out, prefix)
	n += hex.Encode(out[n:], buf)
	out[n] = '\''

	return string(out)
}
 | |
| 
 | |
| type sqlLexer struct {
 | |
| 	src     string
 | |
| 	start   int
 | |
| 	pos     int
 | |
| 	nested  int // multiline comment nesting level.
 | |
| 	stateFn stateFn
 | |
| 	parts   []Part
 | |
| }
 | |
| 
 | |
| type stateFn func(*sqlLexer) stateFn
 | |
| 
 | |
| func rawState(l *sqlLexer) stateFn {
 | |
| 	for {
 | |
| 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 		l.pos += width
 | |
| 
 | |
| 		switch r {
 | |
| 		case 'e', 'E':
 | |
| 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if nextRune == '\'' {
 | |
| 				l.pos += width
 | |
| 				return escapeStringState
 | |
| 			}
 | |
| 		case '\'':
 | |
| 			return singleQuoteState
 | |
| 		case '"':
 | |
| 			return doubleQuoteState
 | |
| 		case '$':
 | |
| 			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if '0' <= nextRune && nextRune <= '9' {
 | |
| 				if l.pos-l.start > 0 {
 | |
| 					l.parts = append(l.parts, l.src[l.start:l.pos-width])
 | |
| 				}
 | |
| 				l.start = l.pos
 | |
| 				return placeholderState
 | |
| 			}
 | |
| 		case '-':
 | |
| 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if nextRune == '-' {
 | |
| 				l.pos += width
 | |
| 				return oneLineCommentState
 | |
| 			}
 | |
| 		case '/':
 | |
| 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if nextRune == '*' {
 | |
| 				l.pos += width
 | |
| 				return multilineCommentState
 | |
| 			}
 | |
| 		case utf8.RuneError:
 | |
| 			if width != replacementcharacterwidth {
 | |
| 				if l.pos-l.start > 0 {
 | |
| 					l.parts = append(l.parts, l.src[l.start:l.pos])
 | |
| 					l.start = l.pos
 | |
| 				}
 | |
| 				return nil
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func singleQuoteState(l *sqlLexer) stateFn {
 | |
| 	for {
 | |
| 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 		l.pos += width
 | |
| 
 | |
| 		switch r {
 | |
| 		case '\'':
 | |
| 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if nextRune != '\'' {
 | |
| 				return rawState
 | |
| 			}
 | |
| 			l.pos += width
 | |
| 		case utf8.RuneError:
 | |
| 			if width != replacementcharacterwidth {
 | |
| 				if l.pos-l.start > 0 {
 | |
| 					l.parts = append(l.parts, l.src[l.start:l.pos])
 | |
| 					l.start = l.pos
 | |
| 				}
 | |
| 				return nil
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func doubleQuoteState(l *sqlLexer) stateFn {
 | |
| 	for {
 | |
| 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 		l.pos += width
 | |
| 
 | |
| 		switch r {
 | |
| 		case '"':
 | |
| 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if nextRune != '"' {
 | |
| 				return rawState
 | |
| 			}
 | |
| 			l.pos += width
 | |
| 		case utf8.RuneError:
 | |
| 			if width != replacementcharacterwidth {
 | |
| 				if l.pos-l.start > 0 {
 | |
| 					l.parts = append(l.parts, l.src[l.start:l.pos])
 | |
| 					l.start = l.pos
 | |
| 				}
 | |
| 				return nil
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // placeholderState consumes a placeholder value. The $ must have already has
 | |
| // already been consumed. The first rune must be a digit.
 | |
| func placeholderState(l *sqlLexer) stateFn {
 | |
| 	num := 0
 | |
| 
 | |
| 	for {
 | |
| 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 		l.pos += width
 | |
| 
 | |
| 		if '0' <= r && r <= '9' {
 | |
| 			num *= 10
 | |
| 			num += int(r - '0')
 | |
| 		} else {
 | |
| 			l.parts = append(l.parts, num)
 | |
| 			l.pos -= width
 | |
| 			l.start = l.pos
 | |
| 			return rawState
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func escapeStringState(l *sqlLexer) stateFn {
 | |
| 	for {
 | |
| 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 		l.pos += width
 | |
| 
 | |
| 		switch r {
 | |
| 		case '\\':
 | |
| 			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			l.pos += width
 | |
| 		case '\'':
 | |
| 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if nextRune != '\'' {
 | |
| 				return rawState
 | |
| 			}
 | |
| 			l.pos += width
 | |
| 		case utf8.RuneError:
 | |
| 			if width != replacementcharacterwidth {
 | |
| 				if l.pos-l.start > 0 {
 | |
| 					l.parts = append(l.parts, l.src[l.start:l.pos])
 | |
| 					l.start = l.pos
 | |
| 				}
 | |
| 				return nil
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func oneLineCommentState(l *sqlLexer) stateFn {
 | |
| 	for {
 | |
| 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 		l.pos += width
 | |
| 
 | |
| 		switch r {
 | |
| 		case '\\':
 | |
| 			_, width = utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			l.pos += width
 | |
| 		case '\n', '\r':
 | |
| 			return rawState
 | |
| 		case utf8.RuneError:
 | |
| 			if width != replacementcharacterwidth {
 | |
| 				if l.pos-l.start > 0 {
 | |
| 					l.parts = append(l.parts, l.src[l.start:l.pos])
 | |
| 					l.start = l.pos
 | |
| 				}
 | |
| 				return nil
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func multilineCommentState(l *sqlLexer) stateFn {
 | |
| 	for {
 | |
| 		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 		l.pos += width
 | |
| 
 | |
| 		switch r {
 | |
| 		case '/':
 | |
| 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if nextRune == '*' {
 | |
| 				l.pos += width
 | |
| 				l.nested++
 | |
| 			}
 | |
| 		case '*':
 | |
| 			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
 | |
| 			if nextRune != '/' {
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			l.pos += width
 | |
| 			if l.nested == 0 {
 | |
| 				return rawState
 | |
| 			}
 | |
| 			l.nested--
 | |
| 
 | |
| 		case utf8.RuneError:
 | |
| 			if width != replacementcharacterwidth {
 | |
| 				if l.pos-l.start > 0 {
 | |
| 					l.parts = append(l.parts, l.src[l.start:l.pos])
 | |
| 					l.start = l.pos
 | |
| 				}
 | |
| 				return nil
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
 | |
| // as necessary. This function is only safe when standard_conforming_strings is
 | |
| // on.
 | |
| func SanitizeSQL(sql string, args ...any) (string, error) {
 | |
| 	query, err := NewQuery(sql)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	return query.Sanitize(args...)
 | |
| }
 |